# Source code for oumi.cli.analyze

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Optional

import pandas as pd
import typer
from rich.table import Table

import oumi.cli.cli_utils as cli_utils
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
from oumi.utils.logging import logger

# Valid output formats for analysis results
_VALID_OUTPUT_FORMATS = ("csv", "json", "parquet")

if TYPE_CHECKING:
    # Imported only for type annotations; the real import is deferred inside
    # `analyze` to keep CLI startup fast.
    from oumi.core.analyze.dataset_analyzer import DatasetAnalyzer


def analyze(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(
            *cli_utils.CONFIG_FLAGS,
            help="Path to the configuration file for analysis.",
        ),
    ],
    output: Annotated[
        Optional[str],
        typer.Option(
            "--output",
            "-o",
            help="Output directory for analysis results. Overrides config output_path.",
        ),
    ] = None,
    output_format: Annotated[
        str,
        typer.Option(
            "--format",
            "-f",
            help="Output format for results: csv, json, or parquet (case-insensitive).",
        ),
    ] = "csv",
    level: cli_utils.LOG_LEVEL_TYPE = None,
    verbose: cli_utils.VERBOSE_TYPE = False,
):
    """Analyze a dataset to compute metrics and statistics.

    Args:
        ctx: The Typer context object.
        config: Path to the configuration file for analysis.
        output: Output directory for results. Overrides config output_path.
        output_format: Output format (csv, json, parquet). Case-insensitive.
        level: The logging level for the specified command.
        verbose: Enable verbose logging with additional debug information.
    """
    # Deferred import: keeps `oumi --help` and unrelated subcommands fast.
    from oumi.core.analyze.dataset_analyzer import DatasetAnalyzer

    # Validate output format early before any expensive operations
    output_format = output_format.lower()
    if output_format not in _VALID_OUTPUT_FORMATS:
        cli_utils.CONSOLE.print(
            f"[red]Error:[/red] Invalid output format '{output_format}'. "
            f"Supported formats: {', '.join(_VALID_OUTPUT_FORMATS)}"
        )
        raise typer.Exit(code=1)

    try:
        extra_args = cli_utils.parse_extra_cli_args(ctx)
        # Resolve aliases (e.g. short config names) and fetch remote configs.
        config = str(
            cli_utils.resolve_and_fetch_config(
                try_get_config_name_for_alias(config, AliasType.ANALYZE),
            )
        )

        with cli_utils.CONSOLE.status(
            "[green]Loading configuration...[/green]", spinner="dots"
        ):
            # Delayed imports
            from oumi.core.configs import AnalyzeConfig

            # Load configuration
            parsed_config: AnalyzeConfig = AnalyzeConfig.from_yaml_and_arg_list(
                config, extra_args, logger=logger
            )

            # Override output path if provided via CLI
            if output:
                parsed_config.output_path = output

            # Validate configuration
            parsed_config.finalize_and_validate()

        if verbose:
            parsed_config.print_config(logger)

        # Create analyzer
        with cli_utils.CONSOLE.status(
            "[green]Loading dataset...[/green]", spinner="dots"
        ):
            analyzer = DatasetAnalyzer(parsed_config)

        # Run analysis
        with cli_utils.CONSOLE.status(
            "[green]Running analysis...[/green]", spinner="dots"
        ):
            analyzer.analyze_dataset()

        # Display summary
        _display_analysis_summary(analyzer)

        # Export results
        if parsed_config.output_path:
            _export_results(analyzer, parsed_config.output_path, output_format)

    except FileNotFoundError as e:
        logger.error(f"Configuration file not found: {e}")
        cli_utils.CONSOLE.print(f"[red]Error:[/red] Configuration file not found: {e}")
        raise typer.Exit(code=1)
    except ValueError as e:
        logger.error(f"Invalid configuration: {e}")
        cli_utils.CONSOLE.print(f"[red]Error:[/red] Invalid configuration: {e}")
        raise typer.Exit(code=1)
    except RuntimeError as e:
        logger.error(f"Analysis failed: {e}")
        cli_utils.CONSOLE.print(f"[red]Error:[/red] Analysis failed: {e}")
        raise typer.Exit(code=1)
    except Exception as e:
        # Catch-all boundary for the CLI: log with traceback, show a short
        # message, and exit non-zero instead of dumping a stack to the user.
        logger.error(f"Unexpected error during analysis: {e}", exc_info=True)
        cli_utils.CONSOLE.print(f"[red]Unexpected error:[/red] {e}")
        raise typer.Exit(code=1)
def _display_analysis_summary(analyzer: "DatasetAnalyzer") -> None: """Display analysis summary in formatted tables to the console.""" summary = analyzer.analysis_summary # Dataset overview table overview = summary.get("dataset_overview", {}) if overview: table = Table( title="Dataset Overview", title_style="bold magenta", show_lines=True, ) table.add_column("Metric", style="cyan") table.add_column("Value", style="green") table.add_row("Dataset Name", str(overview.get("dataset_name", "N/A"))) table.add_row( "Total Conversations", str(overview.get("total_conversations", "N/A")) ) table.add_row( "Conversations Analyzed", str(overview.get("conversations_analyzed", "N/A")) ) table.add_row( "Coverage", f"{overview.get('dataset_coverage_percentage', 0):.1f}%", ) table.add_row("Total Messages", str(overview.get("total_messages", "N/A"))) table.add_row( "Analyzers Used", ", ".join(overview.get("analyzers_used", [])) or "None", ) cli_utils.CONSOLE.print(table) # Message-level summary msg_summary = summary.get("message_level_summary", {}) if msg_summary: for analyzer_name, metrics in msg_summary.items(): table = Table( title=f"Message-Level Metrics ({analyzer_name})", title_style="bold blue", show_lines=True, ) table.add_column("Metric", style="cyan") table.add_column("Mean", style="green") table.add_column("Std", style="yellow") table.add_column("Min", style="dim") table.add_column("Max", style="dim") table.add_column("Median", style="dim") for metric_name, stats in metrics.items(): if isinstance(stats, dict): table.add_row( metric_name, f"{stats.get('mean', 'N/A'):.2f}" if isinstance(stats.get("mean"), (int, float)) else "N/A", f"{stats.get('std', 'N/A'):.2f}" if isinstance(stats.get("std"), (int, float)) else "N/A", str(stats.get("min", "N/A")), str(stats.get("max", "N/A")), f"{stats.get('median', 'N/A'):.2f}" if isinstance(stats.get("median"), (int, float)) else "N/A", ) cli_utils.CONSOLE.print(table) # Conversation turns summary turns_summary = 
summary.get("conversation_turns", {}) if turns_summary and isinstance(turns_summary, dict) and turns_summary.get("count"): table = Table( title="Conversation Turns", title_style="bold yellow", show_lines=True, ) table.add_column("Statistic", style="cyan") table.add_column("Value", style="green") table.add_row("Count", str(turns_summary.get("count", "N/A"))) table.add_row( "Mean", f"{turns_summary.get('mean', 0):.2f}" if isinstance(turns_summary.get("mean"), (int, float)) else "N/A", ) table.add_row( "Std", f"{turns_summary.get('std', 0):.2f}" if isinstance(turns_summary.get("std"), (int, float)) else "N/A", ) table.add_row("Min", str(turns_summary.get("min", "N/A"))) table.add_row("Max", str(turns_summary.get("max", "N/A"))) table.add_row( "Median", f"{turns_summary.get('median', 0):.2f}" if isinstance(turns_summary.get("median"), (int, float)) else "N/A", ) cli_utils.CONSOLE.print(table) def _export_results( analyzer: "DatasetAnalyzer", output_path: str, output_format: str, ) -> None: """Export analysis results to files.""" output_dir = Path(output_path) output_dir.mkdir(parents=True, exist_ok=True) # Export message-level results if analyzer.message_df is not None and not analyzer.message_df.empty: msg_path = output_dir / f"message_analysis.{output_format}" _save_dataframe(analyzer.message_df, msg_path, output_format) cli_utils.CONSOLE.print(f"[green]Saved message analysis to:[/green] {msg_path}") # Export conversation-level results if analyzer.conversation_df is not None and not analyzer.conversation_df.empty: conv_path = output_dir / f"conversation_analysis.{output_format}" _save_dataframe(analyzer.conversation_df, conv_path, output_format) cli_utils.CONSOLE.print( f"[green]Saved conversation analysis to:[/green] {conv_path}" ) # Export summary as JSON summary_path = output_dir / "analysis_summary.json" with open(summary_path, "w") as f: json.dump(analyzer.analysis_summary, f, indent=2, default=str) cli_utils.CONSOLE.print(f"[green]Saved analysis summary 
to:[/green] {summary_path}") def _save_dataframe(df: pd.DataFrame, path: Path, output_format: str) -> None: """Save a DataFrame to the specified format.""" if output_format == "csv": df.to_csv(path, index=False) elif output_format == "json": df.to_json(path, orient="records", indent=2) elif output_format == "parquet": df.to_parquet(path, index=False)