Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
"""SFT (Supervised Fine-Tuning) training for omniagents.

This module provides a high-level API for supervised fine-tuning of local models
using session traces from omniagents.

SFT trains a model to imitate successful conversations by:
1. Collecting good examples (filtered by judgment, measure, or custom logic)
2. Training the model to reproduce assistant responses

Example:
    from omniagents.training import SFTTrainer, SFTTrainingConfig

    # Train from session traces
    trainer = SFTTrainer(model_name="Qwen/Qwen2.5-0.5B")
    result = trainer.train_from_traces(
        "my_project", "my_agent",
        judgment="acceptable",
    )
    result.save_model("./sft_trained")

    # Or export and train separately
    from omniagents.training import export_traces_for_sft
    dataset = export_traces_for_sft("my_project", "my_agent", judgment="acceptable")
    result = trainer.train(dataset)
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Any, Union, TYPE_CHECKING

if TYPE_CHECKING:
    from datasets import Dataset
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from trl import SFTTrainer as TRLSFTTrainer

try:
    from datasets import Dataset
except ImportError:
    Dataset = None  # type: ignore


@dataclass
class SFTTrainingConfig:
    """Configuration for SFT training.

    This wraps TRL's SFTConfig with sensible defaults for agent training.
    Only the fields below are forwarded by ``to_trl_config()``; every other
    SFTConfig option keeps TRL's own default.

    Attributes:
        # Training parameters
        num_train_epochs: Number of training epochs (default: 1)
        per_device_batch_size: Samples per batch, forwarded as
            ``per_device_train_batch_size`` (default: 2)
        gradient_accumulation_steps: Gradient accumulation (default: 4)
        learning_rate: Learning rate (default: 2e-5)
        max_grad_norm: Gradient clipping (default: 1.0)
        warmup_ratio: Warmup ratio (default: 0.1)

        # Sequence parameters
        max_length: Maximum sequence length (default: 2048)
        packing: Whether to pack sequences for efficiency (default: False)

        # Output
        output_dir: Directory for checkpoints and logs (default: "./sft_output")
        save_strategy: When to save ("no", "epoch", "steps"; default: "no")
        logging_steps: Log every N steps (default: 10)

        # Reporting
        report_to: Where to report metrics ("none", "wandb", "tensorboard";
            default: "none")
    """

    # Training parameters
    num_train_epochs: int = 1
    per_device_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.1

    # Sequence parameters
    max_length: int = 2048
    packing: bool = False

    # Output
    output_dir: str = "./sft_output"
    save_strategy: str = "no"
    logging_steps: int = 10

    # Reporting
    report_to: str = "none"

    def to_trl_config(self) -> "SFTConfig":
        """Convert to TRL's SFTConfig.

        Returns:
            A ``trl.SFTConfig`` populated from this configuration.

        Raises:
            ImportError: If the ``trl`` package is not installed.
        """
        try:
            from trl import SFTConfig
        except ImportError as err:
            # Quote the requirement: unquoted, a shell treats ">" as a
            # redirect and installs an unpinned trl.
            raise ImportError(
                "trl library is required for SFT training. "
                "Install with: pip install 'trl>=0.12.0'"
            ) from err

        # NOTE(review): older TRL releases named this field ``max_seq_length``;
        # confirm the minimum trl version above accepts ``max_length``.
        return SFTConfig(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            max_grad_norm=self.max_grad_norm,
            warmup_ratio=self.warmup_ratio,
            max_length=self.max_length,
            packing=self.packing,
            save_strategy=self.save_strategy,
            logging_steps=self.logging_steps,
            report_to=self.report_to,
        )


@dataclass
class SFTTrainingResult:
    """Result of SFT training.

    Attributes:
        model_path: Path to the saved model (set by ``save_model``)
        metrics: Training log history collected from the trainer
        final_loss: Final training loss
        trainer: The underlying TRL SFTTrainer (for advanced use)
    """

    model_path: Optional[str] = None
    metrics: List[Dict[str, Any]] = field(default_factory=list)
    final_loss: Optional[float] = None
    trainer: Optional[Any] = None

    def save_model(self, path: str) -> str:
        """Save the trained model to a path.

        Args:
            path: Directory to save the model

        Returns:
            The absolute path where the model was saved. It is also stored in
            ``self.model_path`` so ``to_ollama()`` can find it later.

        Raises:
            ValueError: If no trainer is attached to this result.
        """
        if self.trainer is None:
            raise ValueError("No trainer available - training may have failed")

        resolved = str(Path(path).resolve())
        self.trainer.save_model(resolved)
        self.model_path = resolved
        return resolved

    def to_ollama(
        self,
        model_name: str,
        *,
        model_path: Optional[str] = None,
        llama_cpp_path: Optional[str] = None,
        quantization: str = "q8_0",
    ) -> str:
        """Convert the trained model to Ollama format and register it.

        This converts the HuggingFace model (safetensors) to GGUF format using
        llama.cpp's convert_hf_to_gguf.py script, then registers it with Ollama.
        The GGUF file is written as ``model.gguf`` inside the model directory
        and kept. The temporary ``Modelfile`` is always removed.

        The chat template is automatically detected from the model architecture
        and fetched from a matching Ollama base model.

        Args:
            model_name: Name to register the model as in Ollama
            model_path: Path to the model (defaults to self.model_path)
            llama_cpp_path: Path to llama.cpp repo (searches common locations if not specified)
            quantization: GGUF quantization type (default: "q8_0"). Options: f32, f16, bf16, q8_0, etc.

        Returns:
            The Ollama model name that can be used with litellm/ollama_chat/

        Raises:
            ValueError: If no model path is available.
            FileNotFoundError: If llama.cpp's convert script cannot be found.
            RuntimeError: If the ``ollama`` CLI is missing or a conversion or
                registration step fails.

        Example:
            result = trainer.train(dataset)
            result.save_model("./trained_model")
            ollama_name = result.to_ollama("my-trained-agent")

            # Now use with omniagents
            agent.model = f"ollama_chat/{ollama_name}"
        """
        import shutil
        import subprocess
        import sys

        path = model_path or self.model_path
        if path is None:
            raise ValueError(
                "No model path available. Call save_model() first or provide model_path."
            )

        path = Path(path).resolve()
        gguf_path = path / "model.gguf"
        modelfile_path = path / "Modelfile"

        # Find llama.cpp convert script
        convert_script = self._find_llama_cpp_convert(llama_cpp_path)

        # Fail fast: the GGUF conversion below can take a long time, so make
        # sure the ollama CLI is available before starting it.
        if shutil.which("ollama") is None:
            raise RuntimeError(
                "ollama command not found. Please install Ollama: https://ollama.com/download"
            )

        print(f"Converting model to Ollama format: {model_name}")
        print(f"  Source: {path}")
        print(f"  Using: {convert_script}")

        try:
            # Step 1: Convert safetensors to GGUF using llama.cpp. Run it with
            # the current interpreter rather than whatever "python" resolves to
            # on PATH, which may be missing or a different environment.
            print(f"  Converting to GGUF format (quantization: {quantization})...")
            subprocess.run(
                [
                    sys.executable, str(convert_script),
                    str(path),
                    "--outtype", quantization,
                    "--outfile", str(gguf_path),
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            print(f"  Created GGUF: {gguf_path}")

            # Step 2: Get chat template from base model
            template, stop_tokens = self._get_ollama_template(path)

            # Step 3: Create Modelfile with template
            modelfile_content = f"FROM {gguf_path}\n\n"
            if template:
                modelfile_content += f'TEMPLATE """{template}"""\n\n'
            modelfile_content += "".join(
                f'PARAMETER stop "{stop}"\n' for stop in stop_tokens
            )
            modelfile_path.write_text(modelfile_content, encoding="utf-8")

            # Step 4: Create Ollama model from GGUF
            print("  Registering with Ollama...")
            subprocess.run(
                ["ollama", "create", model_name, "-f", str(modelfile_path)],
                check=True,
                capture_output=True,
                text=True,
            )
            print(f"Successfully created Ollama model: {model_name}")
            return model_name

        except subprocess.CalledProcessError as e:
            error_msg = e.stderr or e.stdout or str(e)
            raise RuntimeError(
                f"Failed to create Ollama model. Error: {error_msg}"
            ) from e
        except FileNotFoundError as e:
            if "ollama" in str(e):
                raise RuntimeError(
                    "ollama command not found. Please install Ollama: https://ollama.com/download"
                ) from e
            raise
        finally:
            # Clean up the Modelfile (keep the GGUF for potential reuse)
            modelfile_path.unlink(missing_ok=True)

    def _find_llama_cpp_convert(self, llama_cpp_path: Optional[str] = None) -> Path:
        """Find the llama.cpp convert_hf_to_gguf.py script.

        Lookup order: the explicit ``llama_cpp_path`` argument (which must
        contain the script), then the ``LLAMA_CPP_PATH`` environment variable,
        then a fixed list of common checkout locations.

        Raises:
            FileNotFoundError: If the script cannot be located.
        """
        import os

        script_name = "convert_hf_to_gguf.py"

        # An explicit argument is authoritative: no fallback search.
        if llama_cpp_path:
            script = Path(llama_cpp_path) / script_name
            if script.exists():
                return script
            raise FileNotFoundError(f"convert_hf_to_gguf.py not found at {llama_cpp_path}")

        candidates: List[Path] = []
        env_path = os.environ.get("LLAMA_CPP_PATH")
        if env_path:
            candidates.append(Path(env_path))
        candidates += [
            Path.home() / "code" / "llama.cpp",
            Path.home() / "llama.cpp",
            Path.home() / "projects" / "llama.cpp",
            Path("/opt/llama.cpp"),
            Path.cwd() / "llama.cpp",
        ]

        for base in candidates:
            script = base / script_name
            if script.exists():
                return script

        raise FileNotFoundError(
            "llama.cpp not found. Please either:\n"
            "  1. Set LLAMA_CPP_PATH in your .env file\n"
            "  2. Clone to ~/code/llama.cpp (auto-detected)\n"
            "  3. Pass llama_cpp_path='/path/to/llama.cpp' to to_ollama()\n\n"
            "To install:\n"
            "  git clone https://github.com/ggml-org/llama.cpp.git\n"
            "  pip install -r llama.cpp/requirements.txt"
        )

    def _get_ollama_template(self, model_path: Path) -> tuple:
        """Get the Ollama chat template based on the model architecture.

        Reads config.json to detect the model architecture, maps it to a known
        Ollama base model, and fetches the template from that model (pulling
        the base model once if ``ollama show`` fails). Every failure is
        reported as a warning and degrades to "no template".

        Returns:
            Tuple of (template_string or None, list_of_stop_tokens)
        """
        import json
        import subprocess

        # Map HuggingFace architectures to Ollama base models
        arch_to_ollama_model = {
            "Qwen2ForCausalLM": "qwen2.5:0.5b",
            "Qwen2_5ForCausalLM": "qwen2.5:0.5b",
            "Qwen3ForCausalLM": "qwen3:0.6b",
            "LlamaForCausalLM": "llama3.2:1b",
            "MistralForCausalLM": "mistral:7b",
            "GemmaForCausalLM": "gemma:2b",
            "Gemma2ForCausalLM": "gemma2:2b",
            "PhiForCausalLM": "phi3:mini",
            "Phi3ForCausalLM": "phi3:mini",
        }

        # Default stop tokens for common architectures
        arch_to_stop_tokens = {
            "Qwen2ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "Qwen2_5ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "Qwen3ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "LlamaForCausalLM": ["<|eot_id|>"],
            "MistralForCausalLM": ["</s>"],
            "GemmaForCausalLM": ["<end_of_turn>"],
            "Gemma2ForCausalLM": ["<end_of_turn>"],
            "PhiForCausalLM": ["<|end|>"],
            "Phi3ForCausalLM": ["<|end|>"],
        }

        # Read model config to get architecture
        config_path = model_path / "config.json"
        if not config_path.exists():
            print("  Warning: config.json not found, using basic template")
            return None, []

        # A malformed or unreadable config must not abort to_ollama() after
        # the (expensive) GGUF conversion has already succeeded.
        try:
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            print(f"  Warning: could not read config.json ({e}), using basic template")
            return None, []

        architectures = config.get("architectures", []) if isinstance(config, dict) else []
        if not architectures or not isinstance(architectures, list):
            print("  Warning: No architecture found in config.json")
            return None, []

        arch = architectures[0]
        print(f"  Detected architecture: {arch}")

        # Get corresponding Ollama model
        ollama_model = arch_to_ollama_model.get(arch)
        stop_tokens = arch_to_stop_tokens.get(arch, [])

        if not ollama_model:
            print(f"  Warning: Unknown architecture {arch}, using basic template")
            return None, stop_tokens

        def show_template():
            return subprocess.run(
                ["ollama", "show", ollama_model, "--template"],
                capture_output=True,
                text=True,
                timeout=30,
            )

        # Try to fetch template from Ollama (best effort)
        try:
            result = show_template()
            if result.returncode != 0 or not result.stdout.strip():
                # Model not available locally: pull it once and retry.
                print(f"  Pulling base model for template: {ollama_model}")
                subprocess.run(
                    ["ollama", "pull", ollama_model],
                    capture_output=True,
                    timeout=300,
                )
                result = show_template()
            if result.returncode == 0 and result.stdout.strip():
                print(f"  Using template from: {ollama_model}")
                return result.stdout, stop_tokens
        except Exception as e:
            print(f"  Warning: Could not fetch template: {e}")

        return None, stop_tokens


class SFTTrainer:
    """High-level SFT trainer for omniagents.

    This class wraps TRL's SFTTrainer with omniagents-specific functionality:
    - Direct training from session traces with filtering
    - Integration with judgment/measure-based filtering
    - Simplified API for common use cases

    The model and tokenizer are loaded lazily on the first ``train()`` call
    and reused by later calls.

    Example:
        from omniagents.training import SFTTrainer, SFTTrainingConfig

        # Create trainer
        trainer = SFTTrainer(model_name="Qwen/Qwen2.5-0.5B")

        # Train from session traces
        result = trainer.train_from_traces(
            "my_project", "my_agent",
            judgment="acceptable",
        )

        # Or from a dataset
        dataset = export_traces_for_sft(...)
        result = trainer.train(dataset)

        # Save the model
        result.save_model("./trained_model")
    """

    def __init__(
        self,
        model_name: str,
        *,
        config: Optional["SFTTrainingConfig"] = None,
        torch_dtype: str = "bfloat16",
        device_map: str = "auto",
    ):
        """Initialize the SFT trainer.

        Args:
            model_name: HuggingFace model name or path (e.g., "Qwen/Qwen2.5-0.5B")
            config: Training configuration (default: SFTTrainingConfig())
            torch_dtype: Torch dtype for model ("bfloat16", "float16", "float32",
                or "auto", which is passed to ``from_pretrained`` unchanged).
                Unknown values fall back to bfloat16 with a warning.
            device_map: Device map for model loading ("auto", "cpu", "cuda")
        """
        self.model_name = model_name
        self.config = config or SFTTrainingConfig()
        self.torch_dtype = torch_dtype
        self.device_map = device_map

        self._model = None
        self._tokenizer = None
        self._trl_trainer = None

    def _load_model(self):
        """Load the model and tokenizer into ``self._model`` / ``self._tokenizer``.

        Raises:
            ImportError: If ``transformers`` or ``torch`` is not installed.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as err:
            raise ImportError(
                "transformers and torch are required. "
                "Install with: pip install transformers torch"
            ) from err

        # Map dtype string to torch dtype
        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        if self.torch_dtype == "auto":
            # Forwarded as-is: from_pretrained accepts "auto".
            dtype = "auto"
        elif self.torch_dtype in dtype_map:
            dtype = dtype_map[self.torch_dtype]
        else:
            print(
                f"Warning: unknown torch_dtype {self.torch_dtype!r}, "
                "falling back to bfloat16"
            )
            dtype = torch.bfloat16

        print(f"Loading model: {self.model_name}...")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=dtype,
            device_map=self.device_map,
        )

        # Set pad token if needed (batching requires one)
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        param_count = sum(p.numel() for p in self._model.parameters())
        print(f"Model loaded! Parameters: {param_count:,}")

    def train(
        self,
        dataset: "Dataset",
    ) -> "SFTTrainingResult":
        """Train the model using SFT on a dataset.

        Args:
            dataset: HuggingFace Dataset with 'messages' column (conversational format)
                     or 'prompt'/'completion' columns

        Returns:
            SFTTrainingResult with trained model and metrics

        Raises:
            ValueError: If ``dataset`` is empty.
            ImportError: If ``trl`` (or the model dependencies) are missing.
        """
        # Reject empty input before importing heavy libraries or loading the
        # model; TRL would otherwise fail later with an obscure error.
        if len(dataset) == 0:
            raise ValueError("Cannot run SFT training on an empty dataset.")

        try:
            from trl import SFTTrainer as TRLSFTTrainer
        except ImportError as err:
            raise ImportError(
                "trl library is required for SFT training. "
                "Install with: pip install 'trl>=0.12.0'"
            ) from err

        # Load model if not already loaded
        if self._model is None:
            self._load_model()

        print(f"Training on {len(dataset)} samples...")

        # Create TRL trainer
        trl_config = self.config.to_trl_config()
        print("SFT Configuration:")
        print(f"  Max sequence length: {trl_config.max_length}")
        print(f"  Learning rate: {trl_config.learning_rate}")
        print(f"  Epochs: {trl_config.num_train_epochs}")
        print(f"  Packing: {trl_config.packing}")

        self._trl_trainer = TRLSFTTrainer(
            model=self._model,
            args=trl_config,
            train_dataset=dataset,
            processing_class=self._tokenizer,
        )

        # Train!
        print("\nStarting SFT training...\n")
        self._trl_trainer.train()

        # Collect metrics (the trainer state's log history, when present)
        state = getattr(self._trl_trainer, "state", None)
        metrics = getattr(state, "log_history", [])

        # The most recent log entry carrying a step-level "loss"
        final_loss = None
        if metrics:
            final_loss = next(
                (entry["loss"] for entry in reversed(metrics) if "loss" in entry),
                None,
            )

        return SFTTrainingResult(
            metrics=metrics,
            final_loss=final_loss,
            trainer=self._trl_trainer,
        )

    def train_from_traces(
        self,
        project_slug: str,
        agent_slug: str,
        *,
        judgment: Optional[Literal["acceptable", "unacceptable"]] = None,
        measure: Optional[Union[str, Callable]] = None,
        measure_passed: bool = True,
        filter_fn: Optional[Callable[[List[Dict[str, Any]]], bool]] = None,
        include_system: bool = True,
        max_sessions: Optional[int] = None,
    ) -> "SFTTrainingResult":
        """Train directly from session traces.

        This combines export_traces_for_sft() and train() into a single call.

        Args:
            project_slug: Project identifier for the sessions database
            agent_slug: Agent identifier for the sessions database
            judgment: Filter by trace judgment ("acceptable" or "unacceptable")
            measure: Measure name or function to filter by
            measure_passed: Include sessions where measure passed (True) or failed (False)
            filter_fn: Custom predicate on history
            include_system: Include system messages
            max_sessions: Maximum sessions to include

        Returns:
            SFTTrainingResult with trained model and metrics

        Raises:
            ValueError: If no sessions match the filters.

        Example:
            trainer = SFTTrainer(model_name="Qwen/Qwen2.5-0.5B")
            result = trainer.train_from_traces(
                "my_project", "my_agent",
                judgment="acceptable",
            )
            result.save_model("./sft_model")
        """
        # Export traces to dataset
        print(f"Exporting traces from {project_slug}/{agent_slug}...")
        dataset = export_traces_for_sft(
            project_slug,
            agent_slug,
            judgment=judgment,
            measure=measure,
            measure_passed=measure_passed,
            filter_fn=filter_fn,
            format="conversational",
            include_system=include_system,
            max_sessions=max_sessions,
        )

        if len(dataset) == 0:
            raise ValueError(
                f"No sessions found matching the filter criteria in {project_slug}/{agent_slug}. "
                "Try adjusting your filters or check that sessions exist."
            )

        print(f"Found {len(dataset)} conversations")
        return self.train(dataset)

    @property
    def model(self):
        """Access the underlying model (None until loaded)."""
        return self._model

    @property
    def tokenizer(self):
        """Access the tokenizer (None until loaded)."""
        return self._tokenizer


def train_sft(
    model_name: str,
    project_slug: str,
    agent_slug: str,
    *,
    config: Optional[SFTTrainingConfig] = None,
    judgment: Optional[Literal["acceptable", "unacceptable"]] = None,
    measure: Optional[Union[str, Callable]] = None,
    measure_passed: bool = True,
    filter_fn: Optional[Callable[[List[Dict[str, Any]]], bool]] = None,
    include_system: bool = True,
    max_sessions: Optional[int] = None,
    torch_dtype: str = "bfloat16",
    device_map: str = "auto",
) -> SFTTrainingResult:
    """Fine-tune ``model_name`` with SFT on filtered session traces.

    Convenience wrapper: builds an ``SFTTrainer`` from the model/config
    arguments and calls its ``train_from_traces`` method with the filtering
    arguments. Use ``SFTTrainer`` directly for more control (for example, to
    train on a pre-built dataset).

    Args:
        model_name: HuggingFace model name or path
        project_slug: Project identifier for sessions
        agent_slug: Agent identifier for sessions
        config: Training configuration (default: SFTTrainingConfig())
        judgment: Filter by trace judgment ("acceptable" or "unacceptable")
        measure: Measure name or function to filter by
        measure_passed: Include sessions where measure passed
        filter_fn: Custom predicate on history
        include_system: Include system messages
        max_sessions: Maximum sessions to include
        torch_dtype: Torch dtype for model
        device_map: Device map for model loading

    Returns:
        SFTTrainingResult with trained model and metrics

    Example:
        from omniagents.training import train_sft, SFTTrainingConfig

        config = SFTTrainingConfig(
            num_train_epochs=1,
            learning_rate=2e-5,
        )

        result = train_sft(
            model_name="Qwen/Qwen2.5-0.5B",
            project_slug="my_project",
            agent_slug="my_agent",
            judgment="acceptable",
            config=config,
        )

        result.save_model("./sft_trained")
    """
    # Options forwarded unchanged to SFTTrainer.train_from_traces().
    trace_options = {
        "judgment": judgment,
        "measure": measure,
        "measure_passed": measure_passed,
        "filter_fn": filter_fn,
        "include_system": include_system,
        "max_sessions": max_sessions,
    }

    sft_trainer = SFTTrainer(
        model_name,
        config=config,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    return sft_trainer.train_from_traces(project_slug, agent_slug, **trace_options)


def eval_results_to_sft_dataset(
    results: Any,  # EvalSuiteResults - avoid import cycle
    *,
    filter_passed: bool = True,
    filter_fn: Optional[Callable[[Any], bool]] = None,
    include_system: bool = True,
    include_tool_calls: bool = True,
) -> "Dataset":
    """
    Convert evaluation results to an SFT training dataset (distillation).

    Walks ``results.results``, keeps the runs that pass the filters and carry
    a non-empty ``history``, and turns each history into a list of chat
    messages. With ``include_tool_calls`` the conversion is delegated to
    ``_convert_history_to_openai_format`` so tool calls are preserved and the
    student can learn agentic behaviour. Otherwise only text turns from the
    system/user/assistant roles are kept.

    Args:
        results: EvalSuiteResults from running suite.run(agent)
        filter_passed: If True, only include results where all measures passed
        filter_fn: Optional custom filter function that takes an EvalResult
        include_system: Include system messages in training data
        include_tool_calls: Include tool calls in OpenAI format (essential for
                           teaching tool use to the student model)

    Returns:
        HuggingFace Dataset with 'messages' column in OpenAI chat format

    Raises:
        ImportError: If the 'datasets' package is not installed.

    Example:
        # Run teacher model to generate expert traces
        teacher_results = await suite.run(teacher_agent)

        # Convert successful runs to training dataset
        dataset = eval_results_to_sft_dataset(
            teacher_results,
            filter_passed=True,
            include_tool_calls=True,
        )

        # Train student model
        trainer = SFTTrainer("Qwen/Qwen2.5-0.5B-Instruct")
        result = trainer.train(dataset)
    """
    if Dataset is None:
        raise ImportError(
            "The 'datasets' package is required for eval_results_to_sft_dataset. "
            "Install it with: pip install datasets"
        )

    text_roles = ("system", "user", "assistant")

    def keep(run) -> bool:
        # Pass/fail gate first, then the caller's custom predicate.
        if filter_passed and not run.passed:
            return False
        if filter_fn and not filter_fn(run):
            return False
        return True

    def text_only(history) -> list:
        # Plain-text turns only: drops tool messages and empty content.
        picked = []
        for entry in history:
            role = entry.get("role")
            if role not in text_roles:
                continue
            if role == "system" and not include_system:
                continue
            text = _extract_content(entry)
            if text:
                picked.append({"role": role, "content": text})
        return picked

    conversations = []
    for run in results.results:
        if not keep(run):
            continue

        history = getattr(run, "history", None)
        if not history:
            continue

        if include_tool_calls:
            messages = _convert_history_to_openai_format(history, include_system)
        else:
            messages = text_only(history)

        if messages:
            conversations.append(messages)

    return Dataset.from_dict({"messages": conversations})


def export_traces_for_sft(
    project_slug: str,
    agent_slug: str,
    *,
    # Filtering options
    judgment: Optional[Literal["acceptable", "unacceptable"]] = None,
    measure: Optional[Union[str, Callable]] = None,
    measure_passed: bool = True,
    filter_fn: Optional[Callable[[List[Dict[str, Any]]], bool]] = None,
    # Output format
    format: Literal["conversational", "prompt_completion"] = "conversational",
    include_system: bool = True,
    # Limits
    max_sessions: Optional[int] = None,
) -> "Dataset":
    """
    Export session traces for SFT training.

    This function loads sessions from the omniagents session database and
    optionally filters them using judgments, measures, or custom predicates.
    Archived and empty sessions are always skipped. The output is a
    HuggingFace Dataset ready for SFTTrainer.

    Args:
        project_slug: Project identifier for the sessions database.
        agent_slug: Agent identifier for the sessions database.
        judgment: Filter by trace judgment from Studio ("acceptable" or "unacceptable").
                  Requires traces to have been analyzed in Studio.
        measure: Measure name (string) or measure function to replay on history.
                 Sessions are filtered based on whether the measure passes or fails.
                 Sessions whose measure raises an exception are skipped.
        measure_passed: If True (default), include sessions where measure passed.
                        If False, include sessions where measure failed.
        filter_fn: Custom predicate function that takes a history (list of messages)
                   and returns True to include, False to exclude. Sessions for
                   which it raises are skipped.
        format: Output format for SFTTrainer:
                - "conversational": {"messages": [{"role": ..., "content": ...}, ...]}
                - "prompt_completion": {"prompt": ..., "completion": ...}
        include_system: Include system messages in output (default True).
        max_sessions: Maximum number of sessions to include (after filtering).
                      0 or a negative value yields an empty dataset.

    Returns:
        HuggingFace Dataset ready for SFTTrainer.

    Raises:
        ImportError: If the 'datasets' package is not installed.
        ValueError: If ``format`` is unknown or ``measure`` is neither a
            string nor a callable. Both are checked before any session
            history is loaded.

    Example:
        # Export conversations marked as good in Studio UI
        dataset = export_traces_for_sft(
            "my_project", "my_agent",
            judgment="acceptable",
        )

        # Export conversations where a custom measure passed
        @measure
        def good_response(ctx):
            # ... evaluation logic
            return pass_reason() if ok else fail_reason("bad")

        dataset = export_traces_for_sft(
            "my_project", "my_agent",
            measure=good_response,
        )

        # Train with SFTTrainer
        from trl import SFTTrainer
        trainer = SFTTrainer(model="...", train_dataset=dataset)
        trainer.train()
    """
    if Dataset is None:
        raise ImportError(
            "The 'datasets' package is required for export_traces_for_sft. "
            "Install it with: pip install datasets"
        )

    # Resolve the output converter up front so an invalid ``format`` fails
    # immediately rather than after every session has been loaded.
    if format == "conversational":
        convert = _to_conversational_format
    elif format == "prompt_completion":
        convert = _to_prompt_completion_format
    else:
        raise ValueError(f"Unknown format: {format}. Must be 'conversational' or 'prompt_completion'")

    # Import here to avoid circular imports
    from omniagents.core.session.history_db import list_sessions, load_history
    from omniagents.core.evaluation.context import build_eval_context
    from omniagents.core.evaluation.registry import get_measure

    # Load judgment data if filtering by judgment
    judgments_by_session: Dict[str, str] = {}
    if judgment is not None:
        judgments_by_session = _load_judgments(project_slug)

    # Resolve measure function if provided as string
    measure_fn: Optional[Callable] = None
    if measure is not None:
        if isinstance(measure, str):
            measure_fn = get_measure(measure)
        elif callable(measure):
            measure_fn = measure
        else:
            raise ValueError(
                f"measure must be a string (measure name) or callable, got {type(measure)}"
            )

    # List all sessions
    sessions = list_sessions(project_slug=project_slug, agent_slug=agent_slug)

    # Collect filtered histories
    filtered_histories: List[List[Dict[str, Any]]] = []

    for session_info in sessions:
        # Enforce the limit before doing any work for this session. Checking
        # before appending (not after) makes max_sessions=0 return nothing
        # instead of one session.
        if max_sessions is not None and len(filtered_histories) >= max_sessions:
            break

        session_id = session_info["id"]

        # Skip archived sessions by default
        if session_info.get("archived", False):
            continue

        # Skip empty sessions
        if session_info.get("message_count", 0) == 0:
            continue

        # Filter by judgment (cheap) before loading the full history
        if judgment is not None:
            if judgments_by_session.get(session_id) != judgment:
                continue

        # Load full history
        history = load_history(
            session_id, project_slug=project_slug, agent_slug=agent_slug
        )

        if not history:
            continue

        # Filter by measure if specified
        if measure_fn is not None:
            ctx = build_eval_context(history=history)
            try:
                result = measure_fn(ctx)
                # Only dict results with a truthy "passed" key count as passing.
                passed = result.get("passed", False) if isinstance(result, dict) else False
            except Exception:
                # If measure fails to run, skip this session
                continue

            if passed != measure_passed:
                continue

        # Apply custom filter; a predicate that raises excludes the session
        if filter_fn is not None:
            try:
                if not filter_fn(history):
                    continue
            except Exception:
                continue

        filtered_histories.append(history)

    return Dataset.from_dict(convert(filtered_histories, include_system))


def _load_judgments(project_slug: str) -> Dict[str, str]:
    """Load judgments from trace analysis database.

    The trace analysis stores judgments by group_id, which often corresponds
    to session_id in omniagents.
    """
    from omniagents.core.paths import get_traces_dir

    judgments: Dict[str, str] = {}
    db_path = get_traces_dir() / f"{project_slug}.db"

    if not db_path.exists():
        return judgments

    try:
        import sqlite3

        conn = sqlite3.connect(str(db_path))
        conn.row_factory = sqlite3.Row
        cursor = conn.execute(
            "SELECT group_id, judgment FROM analysis WHERE judgment IS NOT NULL"
        )
        for row in cursor:
            judgments[row["group_id"]] = row["judgment"]
        conn.close()
    except Exception:
        pass

    return judgments


def _to_conversational_format(
    histories: List[List[Dict[str, Any]]],
    include_system: bool,
    include_tool_calls: bool = True,
) -> Dict[str, List[Any]]:
    """Convert histories to conversational format for SFTTrainer.

    Output format: {"messages": [[{"role": "user", "content": "..."}, ...], ...]}

    Args:
        histories: List of conversation histories
        include_system: Include system messages
        include_tool_calls: If True, convert tool calls to OpenAI format.
                           If False, only include text messages (original behavior).
    """
    all_messages: List[List[Dict[str, Any]]] = []

    for history in histories:
        if include_tool_calls:
            messages = _convert_history_to_openai_format(history, include_system)
        else:
            # Original behavior - text only
            messages = []
            for msg in history:
                role = msg.get("role")
                if role not in ("system", "user", "assistant"):
                    continue
                if role == "system" and not include_system:
                    continue

                content = _extract_content(msg)
                if content:
                    messages.append({"role": role, "content": content})

        if messages:
            all_messages.append(messages)

    return {"messages": all_messages}


def _convert_history_to_openai_format(
    history: List[Dict[str, Any]],
    include_system: bool = True,
) -> List[Dict[str, Any]]:
    """
    Convert omniagents history format to OpenAI chat format with tool calls.

    This is essential for SFT training on agentic tasks where the student
    needs to learn tool calling behavior from the teacher.

    omniagents format:
    - {"role": "system", "content": "..."}
    - {"role": "user", "content": "..."}
    - {"role": "assistant", "content": "...", "type": "message"}
    - {"type": "function_call", "name": "geocode", "arguments": {...}, "call_id": "..."}
    - {"type": "function_call_output", "call_id": "...", "output": "..."}

    OpenAI format:
    - {"role": "system", "content": "..."}
    - {"role": "user", "content": "..."}
    - {"role": "assistant", "content": "...", "tool_calls": [...]}
    - {"role": "tool", "tool_call_id": "...", "content": "..."}
    """
    import json as _json

    messages = []

    i = 0
    while i < len(history):
        item = history[i]

        if not isinstance(item, dict):
            i += 1
            continue

        role = item.get("role")
        item_type = item.get("type")

        # System message
        if role == "system":
            if include_system:
                content = item.get("content", "")
                if content:
                    messages.append({"role": "system", "content": content})
            i += 1
            continue

        # User message
        if role == "user":
            content = _extract_content(item)
            if content:
                messages.append({"role": "user", "content": content})
            i += 1
            continue

        # Assistant message (text response)
        if role == "assistant":
            content = _extract_content(item)

            # Look ahead for tool calls that follow this message
            tool_calls = []
            j = i + 1
            while j < len(history):
                next_item = history[j]
                if next_item.get("type") == "function_call":
                    call_id = next_item.get("call_id", f"call_{j}")
                    name = next_item.get("name", "")
                    args = next_item.get("arguments", {})
                    if isinstance(args, str):
                        args_str = args
                    else:
                        args_str = _json.dumps(args)

                    tool_calls.append({
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": args_str,
                        }
                    })
                    j += 1
                elif next_item.get("type") == "function_call_output":
                    # Tool output - we'll handle after adding assistant message
                    break
                elif next_item.get("role") in ("user", "assistant"):
                    # Next turn
                    break
                else:
                    j += 1

            # Build assistant message
            assistant_msg: Dict[str, Any] = {"role": "assistant"}
            if content:
                assistant_msg["content"] = content
            if tool_calls:
                assistant_msg["tool_calls"] = tool_calls

            # Only add if there's content or tool calls
            if content or tool_calls:
                messages.append(assistant_msg)

            i = j  # Skip to after tool calls
            continue

        # Function call (standalone, not preceded by assistant message)
        if item_type == "function_call":
            # Create an assistant message with this tool call
            call_id = item.get("call_id", f"call_{i}")
            name = item.get("name", "")
            args = item.get("arguments", {})
            if isinstance(args, str):
                args_str = args
            else:
                args_str = _json.dumps(args)

            # Look for more consecutive tool calls
            tool_calls = [{
                "id": call_id,
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": args_str,
                }
            }]

            j = i + 1
            while j < len(history):
                next_item = history[j]
                if next_item.get("type") == "function_call":
                    nc_id = next_item.get("call_id", f"call_{j}")
                    nc_name = next_item.get("name", "")
                    nc_args = next_item.get("arguments", {})
                    if isinstance(nc_args, str):
                        nc_args_str = nc_args
                    else:
                        nc_args_str = _json.dumps(nc_args)

                    tool_calls.append({
                        "id": nc_id,
                        "type": "function",
                        "function": {
                            "name": nc_name,
                            "arguments": nc_args_str,
                        }
                    })
                    j += 1
                else:
                    break

            messages.append({
                "role": "assistant",
                "content": "",
                "tool_calls": tool_calls,
            })
            i = j
            continue

        # Function call output (tool response)
        if item_type == "function_call_output":
            call_id = item.get("call_id", "")
            output = item.get("output", "")
            if isinstance(output, dict):
                output = _json.dumps(output)

            messages.append({
                "role": "tool",
                "tool_call_id": call_id,
                "content": str(output),
            })
            i += 1
            continue

        # Skip unknown items
        i += 1

    return messages


def _to_prompt_completion_format(
    histories: List[List[Dict[str, Any]]],
    include_system: bool,
) -> Dict[str, List[str]]:
    """Convert histories to prompt-completion format for SFTTrainer.

    Output format: {"prompt": [...], "completion": [...]}

    For multi-turn conversations, the prompt includes all messages up to
    the final user message, and completion is the final assistant response.
    """
    prompts: List[str] = []
    completions: List[str] = []

    for history in histories:
        # Find the last user message and last assistant message
        last_user_idx = -1
        last_assistant_idx = -1

        for i, msg in enumerate(history):
            role = msg.get("role")
            if role == "user":
                last_user_idx = i
            elif role == "assistant":
                last_assistant_idx = i

        # Skip if no user or assistant message
        if last_user_idx == -1 or last_assistant_idx == -1:
            continue

        # Skip if assistant message comes before last user message
        # (we want prompt -> completion flow)
        if last_assistant_idx < last_user_idx:
            continue

        # Build prompt from all messages up to and including last user message
        prompt_parts: List[str] = []
        for i, msg in enumerate(history):
            if i > last_user_idx:
                break

            role = msg.get("role")
            if role not in ("system", "user", "assistant"):
                continue
            if role == "system" and not include_system:
                continue

            content = _extract_content(msg)
            if content:
                if role == "system":
                    prompt_parts.append(f"System: {content}")
                elif role == "user":
                    prompt_parts.append(f"User: {content}")
                elif role == "assistant":
                    prompt_parts.append(f"Assistant: {content}")

        # Get completion (final assistant response)
        completion = _extract_content(history[last_assistant_idx])

        if prompt_parts and completion:
            prompts.append("\n\n".join(prompt_parts))
            completions.append(completion)

    return {"prompt": prompts, "completion": completions}


def _extract_content(msg: Dict[str, Any]) -> str:
    """Extract text content from a message."""
    content = msg.get("content")

    if isinstance(content, str):
        return content.strip()

    # Handle structured content (list of content parts)
    if isinstance(content, list):
        parts: List[str] = []
        for part in content:
            if isinstance(part, dict):
                if part.get("type") == "text":
                    text = part.get("text", "")
                    if text:
                        parts.append(text)
                elif part.get("type") == "output_text":
                    text = part.get("text", "")
                    if text:
                        parts.append(text)
            elif isinstance(part, str):
                parts.append(part)
        return " ".join(parts).strip()

    return ""