Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
"""GRPO (Group Relative Policy Optimization) training for omniagents.

This module provides a high-level API for training agents using GRPO
with omniagents evaluation measures as the reward signal.

GRPO is an RL algorithm that:
1. Generates multiple rollouts for each problem (with tool use)
2. Scores them with rewards (from omniagents measures)
3. Computes advantages by comparing within each problem's group
4. Updates the model to make better solutions more likely

The key benefit is that you can train agents with tools using only
pass/fail evaluation signals!
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
import subprocess

from .rewards import measure_to_reward, MeasureRewardAdapter, combine_rewards
from .dataset import eval_suite_to_hf_dataset


# Response schema for Qwen2.5 models (similar to TRL's qwen3_schema but without <think> blocks)
# This allows TRL's GRPOTrainer to parse tool calls from Qwen2.5 models.
# Format: <tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>
QWEN25_RESPONSE_SCHEMA = {
    # Top-level pattern with two named groups:
    #   content    - lazily matched text up to the first <tool_call>, <|im_end|>, or end of input
    #   tool_calls - optional run of one or more <tool_call>...</tool_call> blocks
    # An optional trailing <|im_end|> marker is consumed but not captured.
    "x-regex": r"^(?P<content>.*?)(?=(?:<tool_call>|<\|im_end\|>|$))(?P<tool_calls>(?:<tool_call>.+?</tool_call>\s*)+)?\s*(?:<\|im_end\|>|$)",
    "type": "object",
    "properties": {
        "role": {"const": "assistant"},
        "content": {"type": "string"},
        "tool_calls": {
            "type": "array",
            # Captures the payload between each <tool_call> / </tool_call> pair,
            # with surrounding whitespace excluded from the capture group.
            "x-regex-iterator": r"<tool_call>\s*(.+?)\s*</tool_call>",
            "items": {
                # Each payload is declared as JSON; the transform expression wraps the
                # parsed value ("@") as {type: 'function', function: <payload>}.
                "x-parser": "json",
                "x-parser-args": {"transform": "{type: 'function', function: @}"},
                "type": "object",
                "properties": {
                    "type": {"const": "function"},
                    "function": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "arguments": {
                                "type": "object",
                                # Empty schema: argument values of any type are accepted.
                                "additionalProperties": {},
                            },
                        },
                    },
                },
            },
        },
    },
}


if TYPE_CHECKING:
    from datasets import Dataset
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from trl import GRPOTrainer as TRLGRPOTrainer
    from omniagents.notebook.evaluation import EvalSuite
    from agents import Agent


def _extract_hf_model_name(model: str) -> str:
    """Extract HuggingFace model name from various formats.

    Supports:
    - Direct HF model: "Qwen/Qwen3-0.6B"
    - LiteLLM/Ollama: "litellm/ollama_chat/hf.co/Qwen/Qwen3-0.6B-GGUF:Q8_0"
    - Ollama HF: "hf.co/Qwen/Qwen3-0.6B-GGUF:Q8_0"

    Returns the canonical HuggingFace model name.
    """
    if model is None:
        raise ValueError("Agent must have a model specified for training")

    # Strip litellm/ollama_chat prefix
    if model.startswith("litellm/ollama_chat/"):
        model = model[len("litellm/ollama_chat/"):]

    # Handle hf.co prefix (Ollama HuggingFace models)
    if model.startswith("hf.co/"):
        model = model[len("hf.co/"):]
        # Format: Qwen/Qwen3-0.6B-GGUF:Q8_0
        # Strip quantization tag
        if ":" in model:
            model = model.split(":")[0]
        # Strip -GGUF suffix
        if model.endswith("-GGUF"):
            model = model[:-5]

    return model


class ToolSchemaWrapper(dict):
    """Dict-based tool schema that can also be called like the tool itself.

    An instance behaves in three ways at once:
    1. It is a ``dict`` holding the tool's JSON schema, so
       ``isinstance(tool, dict)`` is True for transformers.
    2. It exposes ``__name__`` (and ``__doc__``), so TRL's ``tool.__name__``
       lookup works.
    3. It is callable, forwarding every call to the wrapped function so TRL
       can execute the underlying tool.

    This bypasses the Google-style docstring limitation in transformers by
    using pre-parsed JSON schemas from the openai-agents SDK (which supports
    Google/NumPy/Sphinx docstrings via griffe).
    """

    def __init__(self, schema: Dict[str, Any], func: Callable):
        """Store ``schema`` as the dict contents and ``func`` as the call target.

        Args:
            schema: Tool schema; must contain ``schema["function"]["name"]``.
            func: Callable invoked when the wrapper is called.
        """
        super().__init__(schema)
        self._func = func

        function_spec = schema["function"]
        self.__name__ = function_spec["name"]
        # A missing description becomes an empty docstring.
        self.__doc__ = function_spec.get("description", "")

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the wrapped function."""
        target = self._func
        return target(*args, **kwargs)


def _extract_tools_as_json_schemas(agent: "Agent") -> List[ToolSchemaWrapper]:
    """Build TRL-compatible tool wrappers from an Agent's tools.

    Tools exposing ``params_json_schema`` (FunctionTool objects from the
    agents SDK) already carry a parsed JSON schema, so the schema is reused
    directly instead of being re-derived from docstrings. This bypasses the
    Google-style docstring limitation in transformers; the openai-agents SDK
    uses griffe, which auto-detects the docstring style (Google/NumPy/Sphinx).

    Tools without ``params_json_schema`` are skipped.

    Returns:
        A list of ToolSchemaWrapper objects that:
        - Are dict subclasses (so transformers uses them as JSON schemas directly)
        - Have a __name__ attribute (so TRL can build its tool dict)
        - Are callable (so TRL can execute the tools)
        The list is empty when the agent has no tools.

    Raises:
        ValueError: If a schema-bearing tool has no ``_original_func`` attribute.
    """
    if not agent.tools:
        return []

    wrapped: List[ToolSchemaWrapper] = []
    for tool in agent.tools:
        if not hasattr(tool, 'params_json_schema'):
            continue

        function_schema = {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": tool.params_json_schema
            }
        }

        # TRL invokes the tool directly, so the undecorated function is required.
        if not hasattr(tool, '_original_func'):
            raise ValueError(
                f"Tool '{tool.name}' missing _original_func attribute. "
                f"Make sure tools are created with omniagents' @function_tool decorator."
            )

        wrapped.append(ToolSchemaWrapper(function_schema, tool._original_func))

    return wrapped


def _get_agent_instructions(agent: "Agent") -> Optional[str]:
    """Extract instructions from an Agent.

    Instructions can be a string or a callable. For training, we only
    support static string instructions.
    """
    instructions = agent.instructions
    if instructions is None:
        return None
    if callable(instructions):
        raise ValueError(
            "Dynamic instructions (callable) are not supported for training. "
            "Please use static string instructions."
        )
    return instructions


@dataclass
class GRPOTrainingConfig:
    """Settings for a GRPO training run.

    A thin layer over TRL's GRPOConfig with defaults chosen for agent
    training. Use ``to_trl_config()`` to obtain the TRL object and
    ``build_peft_config()`` to obtain the optional LoRA configuration.

    Attributes:
        # Training parameters
        num_train_epochs: Number of training epochs (default: 1)
        per_device_batch_size: Problems per batch (default: 2)
        gradient_accumulation_steps: Gradient accumulation (default: 2)
        learning_rate: Learning rate (default: 1e-5)
        max_grad_norm: Gradient clipping (default: 1.0)

        # GRPO-specific parameters
        num_generations: Number of rollouts per problem (default: 4)
        max_completion_length: Maximum tokens to generate (default: 1024).
            For agents with tools, this must be large enough to fit ALL tool
            calls plus the final response. Each tool call uses ~50-70 tokens.
            256 is too small for multi-step tool use.
        temperature: Sampling temperature for diversity (default: 0.8)
        mask_truncated_completions: Exclude truncated completions from loss
            calculation to prevent learning from garbage (default: True)

        # Memory optimization
        gradient_checkpointing: Trade compute for memory (default: False)
        torch_dtype: Model dtype name: "bfloat16", "float16" or "float32"
            (default: "bfloat16")

        # LoRA / QLoRA
        use_lora: Train LoRA adapters instead of the full model (default: False)
        lora_r: LoRA rank (default: 16)
        lora_alpha: LoRA alpha (default: 32)
        lora_dropout: LoRA dropout (default: 0.05)
        lora_target_modules: Modules to adapt; None selects "all-linear"
        load_in_4bit: 4-bit base model (QLoRA); also enables the LoRA config

        # Tool-related parameters
        max_tool_calling_iterations: Max tool call rounds (default: 5)

        # ECHO auxiliary loss
        world_model_coeff: Weight of the auxiliary world-modeling loss;
            0.0 means vanilla GRPO (default: 0.0)
        world_loss_normalization: "selected_tokens", "sequence_mean" or
            "seq_mean_token_sum_norm" (default: "selected_tokens")

        # vLLM generation
        use_vllm: Generate with vLLM (default: False)
        vllm_mode: "colocate" or "server" (default: "colocate")
        vllm_gpu_memory_utilization: GPU memory fraction for vLLM (default: 0.3)
        vllm_max_model_len: Optional maximum model length for vLLM
        vllm_tensor_parallel_size: vLLM tensor parallel size (default: 1)

        # Reproducibility
        seed: Random seed (default: 0)

        # Output
        output_dir: Directory for checkpoints and logs
        save_strategy: When to save ("no", "epoch", "steps")
        logging_steps: Log every N steps

        # Reporting
        report_to: Where to report metrics ("none", "wandb", "tensorboard")
    """

    # --- Training parameters ---
    num_train_epochs: int = 1
    per_device_batch_size: int = 2
    gradient_accumulation_steps: int = 2
    learning_rate: float = 1e-5
    max_grad_norm: float = 1.0

    # --- GRPO-specific parameters ---
    num_generations: int = 4
    max_completion_length: int = 1024
    temperature: float = 0.8
    mask_truncated_completions: bool = True

    # --- Memory optimization ---
    gradient_checkpointing: bool = False  # Trade compute for memory (useful for 8GB VRAM)
    torch_dtype: str = "bfloat16"  # Model dtype: "bfloat16", "float16", "float32"

    # --- LoRA / QLoRA (parameter-efficient fine-tuning) ---
    # Only small adapters are trained rather than the full model, which cuts
    # memory enough for a capable model to fit on a laptop/consumer GPU.
    # Works with ECHO (the aux loss trains the same adapters).
    # load_in_4bit (QLoRA) needs a CUDA GPU + bitsandbytes.
    use_lora: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None  # None -> all linear layers
    load_in_4bit: bool = False  # 4-bit base (QLoRA); implies use_lora

    # --- Tool-related parameters ---
    max_tool_calling_iterations: int = 5

    # --- ECHO auxiliary world-modeling loss ---
    # (Environment Cross-entropy Hybrid Objective.) A value > 0 adds an
    # on-policy cross-entropy loss over tool-result tokens (the "environment"
    # tokens the GRPO loss ignores), teaching the policy a world model of its
    # tools "for free". 0.0 = vanilla GRPO. Requires the agent to have tools.
    # See omniagents.core.training.echo.EchoGRPOTrainer.
    world_model_coeff: float = 0.0
    # How the auxiliary loss is normalized: "selected_tokens" (token-mean),
    # "sequence_mean", or "seq_mean_token_sum_norm".
    world_loss_normalization: str = "selected_tokens"

    # --- vLLM-accelerated generation (the throughput lever for RL) ---
    # "colocate" shares the GPU with training (single-GPU); "server" talks to
    # a separate vLLM server. Big speedup at scale; needs vllm installed and
    # a CUDA GPU.
    use_vllm: bool = False
    vllm_mode: str = "colocate"
    vllm_gpu_memory_utilization: float = 0.3
    vllm_max_model_len: Optional[int] = None
    vllm_tensor_parallel_size: int = 1

    # --- Reproducibility (for matched baseline-vs-ECHO runs across seeds) ---
    seed: int = 0

    # --- Output ---
    output_dir: str = "./grpo_output"
    save_strategy: str = "no"
    logging_steps: int = 5

    # --- Reporting ---
    report_to: str = "none"

    def to_trl_config(self) -> "GRPOConfig":
        """Build the equivalent TRL GRPOConfig.

        If ``per_device_batch_size * gradient_accumulation_steps`` is not a
        multiple of ``num_generations``, the accumulation steps passed to TRL
        are replaced by ``num_generations`` and a note is printed. The value
        stored on this object is left untouched.

        Raises:
            ImportError: If the trl package cannot be imported.
        """
        try:
            from trl import GRPOConfig
        except ImportError:
            raise ImportError(
                "trl library is required for GRPO training. "
                "Install with: pip install trl>=0.12.0"
            )

        accum_steps = self.gradient_accumulation_steps
        if (self.per_device_batch_size * accum_steps) % self.num_generations != 0:
            accum_steps = self.num_generations
            print(f"Note: Adjusting gradient_accumulation_steps from {self.gradient_accumulation_steps} "
                  f"to {accum_steps} to be divisible by num_generations={self.num_generations}")

        trl_kwargs = {
            "output_dir": self.output_dir,
            "num_train_epochs": self.num_train_epochs,
            "per_device_train_batch_size": self.per_device_batch_size,
            "gradient_accumulation_steps": accum_steps,
            "learning_rate": self.learning_rate,
            "max_grad_norm": self.max_grad_norm,
            "num_generations": self.num_generations,
            "max_completion_length": self.max_completion_length,
            "max_tool_calling_iterations": self.max_tool_calling_iterations,
            "temperature": self.temperature,
            "mask_truncated_completions": self.mask_truncated_completions,
            "gradient_checkpointing": self.gradient_checkpointing,
            "seed": self.seed,
            "use_vllm": self.use_vllm,
            "vllm_mode": self.vllm_mode,
            "vllm_gpu_memory_utilization": self.vllm_gpu_memory_utilization,
            # Note the differing spelling: TRL keyword vs. local field name.
            "vllm_max_model_length": self.vllm_max_model_len,
            "vllm_tensor_parallel_size": self.vllm_tensor_parallel_size,
            "save_strategy": self.save_strategy,
            "logging_steps": self.logging_steps,
            "report_to": self.report_to,
        }
        return GRPOConfig(**trl_kwargs)

    def build_peft_config(self):
        """Return a peft LoraConfig when LoRA or 4-bit loading is enabled, else None.

        Raises:
            ImportError: If LoRA is requested but peft cannot be imported.
        """
        lora_requested = self.use_lora or self.load_in_4bit
        if not lora_requested:
            return None

        try:
            from peft import LoraConfig
        except ImportError:
            raise ImportError("LoRA training requires peft. Install with: pip install peft")

        # No explicit module list (None or empty) selects every linear layer.
        target_modules = self.lora_target_modules or "all-linear"
        return LoraConfig(
            r=self.lora_r,
            lora_alpha=self.lora_alpha,
            lora_dropout=self.lora_dropout,
            target_modules=target_modules,
            task_type="CAUSAL_LM",
        )


@dataclass
class GRPOTrainingResult:
    """Result of GRPO training.

    Besides carrying the training outputs, this object offers helpers to save
    the model, export it to Ollama, and run a quick test generation.

    Attributes:
        model_path: Path to the saved model (if saved)
        metrics: Training metrics from the trainer
        final_loss: Final training loss
        trainer: The underlying TRL GRPOTrainer (for advanced use)
        tools: Tool schemas used during training (for inference)
        instructions: System instructions used during training (for inference)
    """

    model_path: Optional[str] = None
    metrics: List[Dict[str, Any]] = field(default_factory=list)
    final_loss: Optional[float] = None
    trainer: Optional[Any] = None
    tools: Optional[List[Dict[str, Any]]] = None
    instructions: Optional[str] = None
    # Set by _prepare_for_inference() so the eval-mode switch runs only once.
    _inference_ready: bool = field(default=False, repr=False)

    def save_model(self, path: str) -> str:
        """Save the trained model to a path.

        Args:
            path: Directory to save the model

        Returns:
            The resolved absolute path where the model was saved (also stored
            in ``self.model_path``)

        Raises:
            ValueError: If no trainer is attached to this result.
        """
        if self.trainer is None:
            raise ValueError("No trainer available - training may have failed")

        path = str(Path(path).resolve())
        self.trainer.save_model(path)
        self.model_path = path
        return path

    def to_ollama(
        self,
        model_name: str,
        *,
        model_path: Optional[str] = None,
        llama_cpp_path: Optional[str] = None,
        quantization: str = "q8_0",
    ) -> str:
        """Convert the trained model to Ollama format and register it.

        This converts the HuggingFace model (safetensors) to GGUF format using
        llama.cpp's convert_hf_to_gguf.py script, then registers it with Ollama.

        The chat template is automatically detected from the model architecture
        and fetched from a matching Ollama base model.

        The GGUF file is written to ``<model dir>/model.gguf`` and kept for
        reuse; the temporary Modelfile is removed afterwards.

        Args:
            model_name: Name to register the model as in Ollama
            model_path: Path to the model (defaults to self.model_path)
            llama_cpp_path: Path to llama.cpp repo (searches common locations if not specified)
            quantization: GGUF quantization type (default: "q8_0"). Options: f32, f16, bf16, q8_0, etc.

        Returns:
            The Ollama model name that can be used with litellm/ollama_chat/

        Raises:
            ValueError: If neither ``model_path`` nor ``self.model_path`` is set.
            FileNotFoundError: If the llama.cpp convert script cannot be located.
            RuntimeError: If the conversion or ``ollama create`` command fails,
                or the ``ollama`` executable is not found.

        Example:
            result = train_grpo(agent, suite, ...)
            result.save_model("./trained_model")
            ollama_name = result.to_ollama("my-trained-agent")

            # Now use with omniagents
            agent.model = f"litellm/ollama_chat/{ollama_name}"
        """
        import sys

        path = model_path or self.model_path
        if path is None:
            raise ValueError(
                "No model path available. Call save_model() first or provide model_path."
            )

        path = Path(path).resolve()
        gguf_path = path / "model.gguf"
        modelfile_path = path / "Modelfile"

        # Find llama.cpp convert script
        convert_script = self._find_llama_cpp_convert(llama_cpp_path)

        print(f"Converting model to Ollama format: {model_name}")
        print(f"  Source: {path}")
        print(f"  Using: {convert_script}")

        try:
            # Step 1: Convert safetensors to GGUF using llama.cpp.
            # Run the script with the interpreter executing this process so it
            # sees the same installed packages; a bare "python" may be missing
            # from PATH or belong to a different environment.
            print(f"  Converting to GGUF format (quantization: {quantization})...")
            convert_cmd = [
                sys.executable or "python", str(convert_script),
                str(path),
                "--outtype", quantization,
                "--outfile", str(gguf_path),
            ]
            subprocess.run(
                convert_cmd,
                check=True,
                capture_output=True,
                text=True,
            )
            print(f"  Created GGUF: {gguf_path}")

            # Step 2: Get chat template from base model
            template, stop_tokens = self._get_ollama_template(path)

            # Step 3: Create Modelfile with template and stop tokens
            modelfile_path.write_text(
                self._render_modelfile(gguf_path, template, stop_tokens)
            )

            # Step 4: Create Ollama model from GGUF
            print("  Registering with Ollama...")
            ollama_cmd = ["ollama", "create", model_name, "-f", str(modelfile_path)]
            subprocess.run(
                ollama_cmd,
                check=True,
                capture_output=True,
                text=True,
            )
            print(f"Successfully created Ollama model: {model_name}")
            return model_name

        except subprocess.CalledProcessError as e:
            error_msg = e.stderr or e.stdout or str(e)
            raise RuntimeError(
                f"Failed to create Ollama model. Error: {error_msg}"
            ) from e
        except FileNotFoundError as e:
            if "ollama" in str(e):
                raise RuntimeError(
                    "ollama command not found. Please install Ollama: https://ollama.com/download"
                ) from e
            raise
        finally:
            # Clean up the Modelfile (keep the GGUF for potential reuse)
            if modelfile_path.exists():
                modelfile_path.unlink()

    @staticmethod
    def _render_modelfile(
        gguf_path: Path,
        template: Optional[str],
        stop_tokens: List[str],
    ) -> str:
        """Render the Ollama Modelfile text for a GGUF file.

        The template is written verbatim between triple quotes; no escaping
        is applied to it.

        Args:
            gguf_path: Path of the GGUF file named in the FROM line
            template: Chat template text, or None/empty to omit TEMPLATE
            stop_tokens: Values emitted as ``PARAMETER stop`` lines

        Returns:
            The Modelfile contents as a single string
        """
        parts = [f"FROM {gguf_path}\n\n"]
        if template:
            parts.append(f'TEMPLATE """{template}"""\n\n')
        parts.extend(f'PARAMETER stop "{stop}"\n' for stop in stop_tokens)
        return "".join(parts)

    def _find_llama_cpp_convert(self, llama_cpp_path: Optional[str] = None) -> Path:
        """Find the llama.cpp convert_hf_to_gguf.py script.

        Lookup order: the explicit ``llama_cpp_path`` argument, the
        LLAMA_CPP_PATH environment variable, then a list of common checkout
        locations.

        Args:
            llama_cpp_path: Directory expected to contain the script

        Returns:
            Path to the first existing convert_hf_to_gguf.py

        Raises:
            FileNotFoundError: If ``llama_cpp_path`` is given but holds no
                script, or if no location contains the script.
        """
        import os

        script_name = "convert_hf_to_gguf.py"

        # An explicit argument is authoritative: fail instead of searching on.
        if llama_cpp_path:
            script = Path(llama_cpp_path) / script_name
            if script.exists():
                return script
            raise FileNotFoundError(f"convert_hf_to_gguf.py not found at {llama_cpp_path}")

        # The environment variable is best-effort: if it does not lead to the
        # script, fall through to the common locations.
        env_path = os.environ.get("LLAMA_CPP_PATH")
        if env_path:
            script = Path(env_path) / script_name
            if script.exists():
                return script

        # Search common locations
        search_paths = [
            Path.home() / "code" / "llama.cpp",
            Path.home() / "llama.cpp",
            Path.home() / "projects" / "llama.cpp",
            Path("/opt/llama.cpp"),
            Path.cwd() / "llama.cpp",
        ]

        for base in search_paths:
            script = base / script_name
            if script.exists():
                return script

        raise FileNotFoundError(
            "llama.cpp not found. Please either:\n"
            "  1. Set LLAMA_CPP_PATH in your .env file\n"
            "  2. Clone to ~/code/llama.cpp (auto-detected)\n"
            "  3. Pass llama_cpp_path='/path/to/llama.cpp' to to_ollama()\n\n"
            "To install:\n"
            "  git clone https://github.com/ggml-org/llama.cpp.git\n"
            "  pip install -r llama.cpp/requirements.txt"
        )

    def _get_ollama_template(self, model_path: Path) -> tuple[Optional[str], list[str]]:
        """Get the Ollama chat template based on the model architecture.

        Reads config.json to detect the model architecture, maps it to a known
        Ollama base model, and fetches the template from that model with
        ``ollama show --template``. If that yields nothing, the base model is
        pulled once and the lookup retried. Failures while talking to Ollama
        are reported as printed warnings and never raised.

        Args:
            model_path: Directory containing the model's config.json

        Returns:
            Tuple of (template_string, list_of_stop_tokens). The template is
            None when it could not be determined; the stop tokens are empty
            when the architecture is missing or unknown.
        """
        import json

        # Map HuggingFace architectures to Ollama base models
        ARCH_TO_OLLAMA_MODEL = {
            "Qwen2ForCausalLM": "qwen2.5:0.5b",
            "Qwen2_5ForCausalLM": "qwen2.5:0.5b",
            "Qwen3ForCausalLM": "qwen3:0.6b",
            "LlamaForCausalLM": "llama3.2:1b",
            "MistralForCausalLM": "mistral:7b",
            "GemmaForCausalLM": "gemma:2b",
            "Gemma2ForCausalLM": "gemma2:2b",
            "PhiForCausalLM": "phi3:mini",
            "Phi3ForCausalLM": "phi3:mini",
        }

        # Default stop tokens for common architectures
        ARCH_TO_STOP_TOKENS = {
            "Qwen2ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "Qwen2_5ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "Qwen3ForCausalLM": ["<|im_start|>", "<|im_end|>"],
            "LlamaForCausalLM": ["<|eot_id|>"],
            "MistralForCausalLM": ["</s>"],
            "GemmaForCausalLM": ["<end_of_turn>"],
            "Gemma2ForCausalLM": ["<end_of_turn>"],
            "PhiForCausalLM": ["<|end|>"],
            "Phi3ForCausalLM": ["<|end|>"],
        }

        # Read model config to get architecture
        config_path = model_path / "config.json"
        if not config_path.exists():
            print("  Warning: config.json not found, using basic template")
            return None, []

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        architectures = config.get("architectures", [])
        if not architectures:
            print("  Warning: No architecture found in config.json")
            return None, []

        arch = architectures[0]
        print(f"  Detected architecture: {arch}")

        # Get corresponding Ollama model
        ollama_model = ARCH_TO_OLLAMA_MODEL.get(arch)
        stop_tokens = ARCH_TO_STOP_TOKENS.get(arch, [])

        if not ollama_model:
            print(f"  Warning: Unknown architecture {arch}, using basic template")
            return None, stop_tokens

        def _show_template() -> Optional[str]:
            """Return the base model's template, or None if unavailable/empty."""
            shown = subprocess.run(
                ["ollama", "show", ollama_model, "--template"],
                capture_output=True,
                text=True,
                timeout=30,
            )
            if shown.returncode == 0 and shown.stdout.strip():
                return shown.stdout
            return None

        # Try to fetch template from Ollama
        try:
            template = _show_template()
            if template is None:
                # Model not available, try to pull it
                print(f"  Pulling base model for template: {ollama_model}")
                pull_result = subprocess.run(
                    ["ollama", "pull", ollama_model],
                    capture_output=True,
                    text=True,
                    timeout=300,
                )
                if pull_result.returncode == 0:
                    # Retry getting template
                    template = _show_template()
            if template is not None:
                print(f"  Using template from: {ollama_model}")
                return template, stop_tokens
        except subprocess.TimeoutExpired:
            print(f"  Warning: Timeout fetching template from {ollama_model}")
        except Exception as e:
            # Deliberately best-effort: a missing or broken Ollama install
            # must not abort the export, it only loses the template.
            print(f"  Warning: Failed to get template: {e}")

        return None, stop_tokens

    def _prepare_for_inference(self) -> None:
        """Prepare the model for inference (eval mode, disable gradient checkpointing).

        Runs at most once per result object; later calls return immediately.

        Raises:
            ValueError: If no trainer is attached to this result.
        """
        if self._inference_ready:
            return

        if self.trainer is None:
            raise ValueError("No trainer available - training may have failed")

        model = self.trainer.model
        model.eval()
        if hasattr(model, 'gradient_checkpointing_disable'):
            model.gradient_checkpointing_disable()

        self._inference_ready = True

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        repetition_penalty: float = 1.2,
        tools: Optional[List[Dict[str, Any]]] = None,
        instructions: Optional[str] = None,
    ) -> str:
        """Generate a response using the trained model.

        This is a convenience method for quick testing after training.
        For production use, convert to Ollama with `to_ollama()`.

        Args:
            prompt: The user prompt to respond to
            max_new_tokens: Maximum tokens to generate (default: 256)
            temperature: Sampling temperature (default: 0.7)
            repetition_penalty: Penalty for repetition (default: 1.2)
            tools: Override tool schemas (default: use training tools)
            instructions: Override system instructions (default: use training instructions)

        Returns:
            The generated response text, with any text up to and including
            the last ``</think>`` marker removed

        Raises:
            ValueError: If no trainer is attached to this result.

        Example:
            result = train_grpo(agent, suite, ...)
            response = result.generate("Compare the weather in London and Paris.")
            print(response)
        """
        import torch

        self._prepare_for_inference()

        model = self.trainer.model
        tokenizer = self.trainer.processing_class

        # Use provided overrides or fall back to training values
        effective_tools = tools if tools is not None else self.tools
        effective_instructions = instructions if instructions is not None else self.instructions

        # Build messages
        messages = []
        if effective_instructions:
            messages.append({"role": "system", "content": effective_instructions})
        messages.append({"role": "user", "content": prompt})

        # Apply chat template with tools if available
        template_kwargs = {
            "add_generation_prompt": True,
            "tokenize": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        if effective_tools:
            template_kwargs["tools"] = effective_tools

        inputs = tokenizer.apply_chat_template(messages, **template_kwargs)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
            )

        # Decode only the new tokens (everything after the prompt)
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )

        # Strip thinking tags if present (Qwen3 uses <think>...</think>)
        if "</think>" in response:
            response = response.split("</think>")[-1].strip()

        return response


class GRPOTrainer:
    """High-level GRPO trainer for omniagents agents.

    This class wraps TRL's GRPOTrainer with omniagents-specific functionality:
    - Uses Agent's tools, instructions, and model
    - Automatic conversion of measures to reward functions
    - Integration with EvalSuite for training data

    Example:
        from omniagents import Agent, function_tool
        from omniagents.notebook import EvalSuite, measure, pass_reason, fail_reason
        from omniagents.training import GRPOTrainer, GRPOTrainingConfig

        @function_tool
        def calculate(expr: str) -> str:
            '''Evaluate a math expression.'''
            return str(eval(expr))

        agent = Agent(
            name="Math Agent",
            model="Qwen/Qwen3-0.6B",
            tools=[calculate],
            instructions="You are a math assistant. Use the calculator tool.",
        )

        @measure
        def correct_answer(ctx):
            # ... grading logic ...
            return pass_reason("Correct") or fail_reason("Incorrect")

        suite = EvalSuite.from_records(data, input_fn=..., expect_fn=...)

        trainer = GRPOTrainer(
            agent=agent,
            reward_measures=["correct_answer"],
        )
        result = trainer.train(suite)
    """

    def __init__(
        self,
        agent: "Agent",
        reward_measures: List[Union[str, Callable]],
        *,
        config: Optional[GRPOTrainingConfig] = None,
        reward_weights: Optional[List[float]] = None,
        torch_dtype: str = "bfloat16",
        device_map: str = "auto",
        environment_factory: Optional[Callable[[], Any]] = None,
        extra_reward_funcs: Optional[List[Callable]] = None,
        rollout_func: Optional[Callable] = None,
    ):
        """Initialize the GRPO trainer.

        No model weights are loaded here; loading is deferred until
        ``train()`` is called.

        Args:
            agent: The omniagents Agent to train (must have model, can have tools/instructions)
            reward_measures: List of measure names or functions to use as rewards.
                May be empty (or None) when ``extra_reward_funcs`` supplies the rewards.
            config: Training configuration (default: GRPOTrainingConfig())
            reward_weights: Optional weights for combining multiple rewards
            torch_dtype: Torch dtype for model ("bfloat16", "float16", "float32").
                Unrecognised values fall back to bfloat16 with a warning.
            device_map: Device map for model loading ("auto", "cpu", "cuda")
            environment_factory: Optional zero-arg factory returning a per-rollout
                training environment (TRL ``environment_factory``). When set, the
                environment's public methods become the agent's tools (the agent's
                own tools are not forwarded to TRL), enabling sandboxed multi-turn
                rollouts. See ``omniagents.core.training.sandbox_env``.
            extra_reward_funcs: Optional TRL-style reward functions to use in
                addition to (or instead of) ``reward_measures`` -- e.g. a
                verifier reward that inspects each rollout's environment.
            rollout_func: Optional fully-custom rollout callable forwarded to TRL
                as ``rollout_func``. When set it takes precedence over both
                ``environment_factory`` and the agent's tools.
        """
        self.agent = agent
        self.reward_measures = reward_measures
        self.config = config or GRPOTrainingConfig()
        self.reward_weights = reward_weights
        self.torch_dtype = torch_dtype
        self.device_map = device_map
        self.environment_factory = environment_factory
        # Copy so later mutation of the caller's list does not affect training.
        self.extra_reward_funcs = list(extra_reward_funcs or [])
        # Optional fully-custom rollout (TRL `rollout_func`). For most cases the
        # recommended fast path is use_vllm=True on the environment_factory path,
        # where TRL drives vLLM generation + weight sync for you.
        self.rollout_func = rollout_func

        # Extract model name from agent
        self.model_name = _extract_hf_model_name(agent.model)

        # Extract tools as JSON schemas from agent
        self.tools = _extract_tools_as_json_schemas(agent)

        # Extract instructions from agent
        self.instructions = _get_agent_instructions(agent)

        # Populated lazily by _load_model() / _build_reward_fn() / train().
        self._model = None
        self._tokenizer = None
        self._reward_fn = None
        self._trl_trainer = None

    def _load_model(self):
        """Load the model and tokenizer into ``self._model`` / ``self._tokenizer``.

        Honors ``config.load_in_4bit`` (QLoRA), sets a pad token when the
        tokenizer has none, forces left padding, and installs the Qwen2.5
        response schema on the tokenizer when tools are in play.

        Raises:
            ImportError: If ``transformers`` or ``torch`` is not installed.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as err:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError(
                "transformers and torch are required. "
                "Install with: pip install transformers torch"
            ) from err

        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        dtype = dtype_map.get(self.torch_dtype)
        if dtype is None:
            # Keep the historical bfloat16 fallback, but stop doing it silently:
            # a typo such as "fp16" would otherwise train in the wrong precision.
            import warnings

            warnings.warn(
                f"Unknown torch_dtype {self.torch_dtype!r}; falling back to 'bfloat16'. "
                f"Supported values: {sorted(dtype_map)}",
                stacklevel=2,
            )
            dtype = torch.bfloat16

        print(f"Loading model: {self.model_name}...")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        model_kwargs = {"torch_dtype": dtype, "device_map": self.device_map}
        if self.config.load_in_4bit:
            # QLoRA: load the base in 4-bit (CUDA + bitsandbytes required).
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            )
            print("  Loading base in 4-bit (QLoRA)")
        self._model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)

        if self.config.load_in_4bit:
            from peft import prepare_model_for_kbit_training

            self._model = prepare_model_for_kbit_training(
                self._model,
                use_gradient_checkpointing=self.config.gradient_checkpointing,
            )

        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # Decoder-only models require left-padding for correct generation
        self._tokenizer.padding_side = "left"

        # Set response_schema for models that TRL doesn't natively support
        # This prevents TRL from calling add_response_schema() which only supports Qwen3.
        # Tools may come from the agent OR from a training environment (whose
        # methods become tools), so cover both cases.
        tool_using = bool(self.tools) or self.environment_factory is not None
        if tool_using and not getattr(self._tokenizer, "response_schema", None):
            model_lower = self.model_name.lower()
            if "qwen2.5" in model_lower or "qwen2-" in model_lower:
                print("Setting Qwen2.5 response schema for tool parsing...")
                self._tokenizer.response_schema = QWEN25_RESPONSE_SCHEMA

        param_count = sum(p.numel() for p in self._model.parameters())
        print(f"Model loaded! Parameters: {param_count:,}")

        if self.tools:
            print(f"Agent tools: {[t['function']['name'] for t in self.tools]}")
        if self.instructions:
            print(f"Agent instructions: {self.instructions[:100]}...")

    def _build_reward_fn(self):
        """Build the combined reward function from measures (None if no measures).

        Stores the result in ``self._reward_fn``: ``None`` when there are no
        measures, the single adapter when there is exactly one, otherwise the
        ``combine_rewards`` combination using ``self.reward_weights``.
        """
        # `reward_measures` may legitimately be empty/None when rewards come
        # solely from `extra_reward_funcs`; treat None like an empty list.
        measures = self.reward_measures or []
        reward_adapters = [measure_to_reward(m) for m in measures]

        if not reward_adapters:
            self._reward_fn = None
        elif len(reward_adapters) == 1:
            self._reward_fn = reward_adapters[0]
        else:
            self._reward_fn = combine_rewards(
                *reward_adapters,
                weights=self.reward_weights,
            )

    def train(
        self,
        suite: "EvalSuite",
    ) -> GRPOTrainingResult:
        """Train the agent using GRPO.

        Loads the model on first use, builds the reward functions, converts
        the suite to a dataset, constructs the TRL (or ECHO) trainer and runs it.

        Args:
            suite: EvalSuite containing training data

        Returns:
            GRPOTrainingResult with trained model and metrics

        Raises:
            ImportError: If ``trl`` (or ``transformers``/``torch``) is missing.
            ValueError: If no reward function is available, or if the ECHO
                world-model loss is enabled without any tools/environment.
        """
        try:
            from trl import GRPOTrainer as TRLGRPOTrainer
        except ImportError as err:
            # The version specifier is quoted so the hint can be pasted into a
            # shell without `>` being interpreted as an output redirection.
            raise ImportError(
                "trl library is required for GRPO training. "
                'Install with: pip install "trl>=0.12.0"'
            ) from err

        if self._model is None:
            self._load_model()

        self._build_reward_fn()
        reward_funcs = ([self._reward_fn] if self._reward_fn is not None else []) + self.extra_reward_funcs
        if not reward_funcs:
            raise ValueError(
                "No reward functions: provide reward_measures and/or extra_reward_funcs."
            )

        # Convert suite to HF dataset with conversational format
        print(f"Preparing dataset from {len(suite)} cases...")
        train_dataset = eval_suite_to_hf_dataset(
            suite,
            system_prompt=self.instructions,
        )
        print(f"Dataset ready: {len(train_dataset)} samples")

        # Create TRL trainer with tools
        trl_config = self.config.to_trl_config()
        print("GRPO Configuration:")
        print(f"  Generations per prompt: {trl_config.num_generations}")
        print(f"  Max completion length: {trl_config.max_completion_length}")
        print(f"  Temperature: {trl_config.temperature}")
        print(f"  Learning rate: {trl_config.learning_rate}")
        if self.tools:
            print(f"  Tools: {[t['function']['name'] for t in self.tools]}")

        # Build trainer kwargs
        trainer_kwargs = {
            "model": self._model,
            "args": trl_config,
            "train_dataset": train_dataset,
            "processing_class": self._tokenizer,
            "reward_funcs": reward_funcs,
        }

        # LoRA / QLoRA: TRL wraps the model with PEFT adapters; only the adapters
        # train (compatible with ECHO). Slashes memory for laptop/consumer GPUs.
        peft_config = self.config.build_peft_config()
        if peft_config is not None:
            trainer_kwargs["peft_config"] = peft_config
            print(f"  LoRA: r={self.config.lora_r}, alpha={self.config.lora_alpha}"
                  f"{', 4-bit base (QLoRA)' if self.config.load_in_4bit else ''}")

        # Precedence: custom rollout_func > environment_factory > agent tools.
        # Tools come from the training environment when one is provided (its
        # public methods become the tools); otherwise from the agent.
        if self.rollout_func is not None:
            trainer_kwargs["rollout_func"] = self.rollout_func
            print("  Rollout: custom rollout_func")
        elif self.environment_factory is not None:
            trainer_kwargs["environment_factory"] = self.environment_factory
            print("  Training environment: per-rollout sandbox (tools from environment)")
        elif self.tools:
            trainer_kwargs["tools"] = self.tools
        if self.config.use_vllm:
            print(f"  vLLM: {self.config.vllm_mode} (gpu_mem_util={self.config.vllm_gpu_memory_utilization}, "
                  f"max_model_len={self.config.vllm_max_model_len})")

        # Enable the ECHO auxiliary world-modeling loss when configured.
        if self.config.world_model_coeff > 0:
            if not self.tools and self.environment_factory is None:
                raise ValueError(
                    "ECHO world-model loss (world_model_coeff > 0) requires tools: it "
                    "predicts tool-result tokens, of which there are none without tools. "
                    "Add tools to the agent, supply an environment_factory, or set "
                    "world_model_coeff=0.0."
                )
            from .echo import EchoGRPOTrainer

            trainer_kwargs["world_model_coeff"] = self.config.world_model_coeff
            trainer_kwargs["world_loss_normalization"] = self.config.world_loss_normalization
            print(
                f"  ECHO world-model loss: coeff={self.config.world_model_coeff}, "
                f"normalization={self.config.world_loss_normalization}"
            )
            self._trl_trainer = EchoGRPOTrainer(**trainer_kwargs)
        else:
            self._trl_trainer = TRLGRPOTrainer(**trainer_kwargs)

        # Train!
        print("\nStarting GRPO training...\n")
        self._trl_trainer.train()

        # Collect metrics
        metrics = []
        if hasattr(self._trl_trainer, "state") and hasattr(self._trl_trainer.state, "log_history"):
            metrics = self._trl_trainer.state.log_history

        # The final loss is the most recent log entry that carries a "loss" key.
        final_loss = None
        if metrics:
            for entry in reversed(metrics):
                if "loss" in entry:
                    final_loss = entry["loss"]
                    break

        return GRPOTrainingResult(
            metrics=metrics,
            final_loss=final_loss,
            trainer=self._trl_trainer,
            tools=self.tools,
            instructions=self.instructions,
        )

    @property
    def model(self):
        """Access the underlying model (None until it has been loaded)."""
        return self._model

    @property
    def tokenizer(self):
        """Access the tokenizer (None until the model has been loaded)."""
        return self._tokenizer


def train_grpo(
    agent: "Agent",
    suite: "EvalSuite",
    reward_measures: List[Union[str, Callable]],
    *,
    config: Optional[GRPOTrainingConfig] = None,
    reward_weights: Optional[List[float]] = None,
    torch_dtype: Optional[str] = None,  # Defaults to config.torch_dtype
    device_map: str = "auto",
    environment_factory: Optional[Callable[[], Any]] = None,
    extra_reward_funcs: Optional[List[Callable]] = None,
    rollout_func: Optional[Callable] = None,
) -> GRPOTrainingResult:
    """One-call GRPO training of an agent against an omniagents EvalSuite.

    Convenience wrapper: builds a ``GRPOTrainer`` from the arguments and
    immediately runs ``train(suite)`` on it. Use ``GRPOTrainer`` directly
    when you need access to the trainer object itself.

    Args:
        agent: The omniagents Agent to train
        suite: EvalSuite containing training data
        reward_measures: List of measure names or functions to use as rewards
        config: Training configuration (default: GRPOTrainingConfig())
        reward_weights: Optional weights for combining multiple rewards
        torch_dtype: Torch dtype for model; when None, ``config.torch_dtype``
            is used
        device_map: Device map for model loading
        environment_factory: Forwarded unchanged to ``GRPOTrainer``
        extra_reward_funcs: Forwarded unchanged to ``GRPOTrainer``
        rollout_func: Forwarded unchanged to ``GRPOTrainer``

    Returns:
        GRPOTrainingResult with trained model and metrics

    Example:
        from omniagents import Agent, function_tool
        from omniagents.notebook import EvalSuite, measure, pass_reason, fail_reason
        from omniagents.training import train_grpo, GRPOTrainingConfig

        @function_tool
        def calculate(expr: str) -> str:
            '''Evaluate a math expression.'''
            return str(eval(expr))

        agent = Agent(
            name="Math Agent",
            model="Qwen/Qwen3-0.6B",
            tools=[calculate],
            instructions="Solve math problems using the calculator.",
        )

        @measure
        def correct_answer(ctx):
            response = ctx.final_assistant_message.text or ""
            expected = ctx.expect.get('answer')
            if expected in response:
                return pass_reason("Correct")
            return fail_reason("Incorrect")

        suite = EvalSuite.from_records(
            data[:100],
            input_fn=lambda r: r['question'],
            expect_fn=lambda r: {'answer': r['answer']},
            measures=["correct_answer"],
        )

        config = GRPOTrainingConfig(
            num_generations=4,
            num_train_epochs=1,
            learning_rate=1e-5,
        )

        result = train_grpo(
            agent=agent,
            suite=suite,
            reward_measures=["correct_answer"],
            config=config,
        )

        # Save the trained model
        result.save_model("./trained_model")

        # Convert to Ollama for fast inference
        result.to_ollama("my-math-agent")
    """
    # Fall back to a default configuration when the caller supplied none.
    cfg = config if config is not None else GRPOTrainingConfig()

    # An explicit torch_dtype argument wins; otherwise defer to the config.
    dtype_name = cfg.torch_dtype if torch_dtype is None else torch_dtype

    return GRPOTrainer(
        agent=agent,
        reward_measures=reward_measures,
        config=cfg,
        reward_weights=reward_weights,
        torch_dtype=dtype_name,
        device_map=device_map,
        environment_factory=environment_factory,
        extra_reward_funcs=extra_reward_funcs,
        rollout_func=rollout_func,
    ).train(suite)