Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
agentio / agentio / processor.py
Size: Mime:
import os
import json
import logging
import threading
from collections import deque
from typing import Optional, Any, Dict

try:
    import httpx
except ImportError:
    httpx = None

from agents.tracing import TracingProcessor
from .storage import SQLiteStorage
from .utils import fibo, validate_keys

logger = logging.getLogger(__name__)

class AgentIOProcessor(TracingProcessor):
    """Local trace processor that routes traces/spans to per-project SQLite databases.

    The project name is taken from the trace's 'agentio_project' metadata key,
    falling back to the configured default (or 'default'). Each project gets
    its own database file under ``~/.agentio/<project>.db``; storage handles
    are created lazily and cached.
    """

    def __init__(self, default_project: Optional[str] = None):
        """
        Args:
            default_project: Project used when a trace carries no
                'agentio_project' metadata key. If also None, 'default' is used.
        """
        self.default_project = default_project
        # project name -> SQLiteStorage; creation/teardown guarded by _lock
        self._storages: Dict[str, SQLiteStorage] = {}
        self._lock = threading.Lock()

    def _get_project_from_trace(self, trace: Any) -> str:
        """Return the project name for *trace*, falling back to the default."""
        if hasattr(trace, 'metadata') and trace.metadata:
            project = trace.metadata.get('agentio_project')
            if project:
                return project
        return self.default_project or 'default'

    def _get_storage(self, project: str) -> SQLiteStorage:
        """Get or lazily create the storage handle for *project* (thread-safe)."""
        with self._lock:
            if project not in self._storages:
                from pathlib import Path
                db_path = Path.home() / ".agentio" / f"{project}.db"
                # Ensure the directory exists so SQLite can create the db file.
                db_path.parent.mkdir(parents=True, exist_ok=True)
                self._storages[project] = SQLiteStorage(str(db_path))
            return self._storages[project]

    def on_trace_start(self, trace: Any):
        """Persist the trace as soon as it starts."""
        try:
            project = self._get_project_from_trace(trace)
            storage = self._get_storage(project)

            trace_data = trace.export()
            # Validate metadata keys if present
            if 'metadata' in trace_data:
                validate_keys(trace_data['metadata'])
            storage.insert_trace(trace_data)
        except Exception as e:
            # Tracing must never crash the host application.
            logger.debug("Failed to insert trace: %s", e)

    def on_trace_end(self, trace: Any):
        """Called when trace ends - SDK doesn't provide end data."""
        # Traces don't have timestamps in the SDK
        pass

    def on_span_start(self, span: Any):
        """Called when span starts - not needed"""
        pass

    @staticmethod
    def _enhance_span_export(span: Any) -> Dict[str, Any]:
        """Return ``span.export()`` enriched with the full response payload.

        ``export()`` omits most of the response data for 'response' spans,
        so the full response object is dumped (via pydantic's ``model_dump``
        when available) into the exported dict. Generation spans already
        export everything, so they are returned unchanged.
        """
        span_export = span.export()
        if hasattr(span, 'span_data'):
            span_data = span.span_data
            # Extract the span type safely
            span_type = getattr(span_data, 'type', None)
            if span_type == 'response' and hasattr(span_data, 'response'):
                response = span_data.response
                if response:
                    if hasattr(response, 'model_dump'):
                        # Pydantic model: dump all fields at once.
                        enhanced_response = response.model_dump()
                    else:
                        # Fallback for plain objects: vars() (dict() would raise
                        # TypeError on an arbitrary object with __dict__).
                        enhanced_response = vars(response) if hasattr(response, '__dict__') else str(response)
                    span_export['span_data']['response'] = enhanced_response
                    # Add input if it was set on the span.
                    if hasattr(span_data, 'input') and span_data.input:
                        span_export['span_data']['input'] = span_data.input
        return span_export

    def on_span_end(self, span: Any):
        """Store the full span data when the span ends."""
        try:
            # The span itself does not carry the project; use the active trace.
            from agents.tracing import get_current_trace
            trace = get_current_trace()
            if not trace:
                logger.debug("No active trace for span")
                return

            project = self._get_project_from_trace(trace)
            storage = self._get_storage(project)
            storage.insert_span(self._enhance_span_export(span))
        except Exception as e:
            logger.debug("Failed to insert span: %s", e)

    def shutdown(self):
        """Close all storages on app exit; never raise during teardown."""
        with self._lock:
            for storage in self._storages.values():
                try:
                    storage.close()
                except Exception:
                    # Best-effort cleanup; log instead of silently swallowing.
                    logger.debug("Failed to close storage", exc_info=True)
            self._storages.clear()

    def force_flush(self):
        """Force flush - SQLite auto-commits, so nothing to do."""
        pass


class RemoteAgentIOProcessor(TracingProcessor):
    """Remote trace processor that sends events to an AgentIO server with queuing.

    A daemon thread establishes the HTTP client with Fibonacci backoff;
    events produced before the connection is up are buffered (up to 1000,
    oldest dropped first) and flushed once connected.
    """

    def __init__(self, endpoint: str, project: str, token: Optional[str] = None):
        """
        Args:
            endpoint: URL of the AgentIO trace ingestion endpoint.
            project: Project name attached to every event.
            token: Bearer token; falls back to the AGENTIO_TOKEN env var.

        Raises:
            ImportError: If httpx is not installed.
        """
        if httpx is None:
            raise ImportError("httpx is required for remote processor. Install with: pip install httpx")

        self.endpoint = endpoint
        self.project = project
        self.token = token or os.environ.get("AGENTIO_TOKEN")
        self.client = None  # published by the background thread once connected
        self._client_lock = threading.Lock()
        self._queued_events = deque(maxlen=1000)  # Buffer up to 1000 events
        self._shutdown = False
        self._client_thread = threading.Thread(target=self._init_client_with_retry)
        self._client_thread.daemon = True
        self._client_thread.start()

    def _init_client_with_retry(self):
        """Initialize the client with Fibonacci backoff, then flush queued events."""
        import time

        fib = fibo()
        for sleep_coefficient in fib:
            if self._shutdown:
                break

            client = None
            try:
                client = httpx.Client(timeout=httpx.Timeout(5.0))
                # Probe the server root to verify connectivity.
                test_url = self.endpoint.replace('/api/trace', '/').replace('/trace', '/')
                client.get(test_url, timeout=2.0)

                with self._client_lock:
                    self.client = client
                    # Flush events queued while we were disconnected.
                    while self._queued_events and not self._shutdown:
                        event = self._queued_events.popleft()
                        self._send_now(event['type'], event['data'])
                    logger.info(f"Connected to AgentIO server: {self.endpoint}")
                    break
            except Exception as e:
                # Close the probe client so sockets aren't leaked across retries
                # (never close it once it has been published as self.client).
                if client is not None and client is not self.client:
                    client.close()
                logger.debug("Failed to connect to server: %s", e)
                if sleep_coefficient and not self._shutdown:
                    time.sleep(0.1 * sleep_coefficient)  # Fibonacci backoff

    def on_trace_start(self, trace: Any):
        """Send (or queue) the trace as soon as it starts."""
        try:
            trace_data = trace.export()
            # Validate metadata keys if present
            if 'metadata' in trace_data:
                validate_keys(trace_data['metadata'])
            self._send("trace_start", trace_data)
        except Exception as e:
            logger.debug("Failed to process trace start: %s", e)

    def on_trace_end(self, trace: Any):
        """Notify the server that the trace ended (id only)."""
        # Traces don't have timestamps in the SDK, but we still notify
        try:
            trace_data = trace.export()
            self._send("trace_end", {"id": trace_data.get('id')})
        except Exception as e:
            logger.debug("Failed to process trace end: %s", e)

    def on_span_start(self, span: Any):
        """Called when span starts"""
        pass  # Skip to reduce traffic

    @staticmethod
    def _enhance_span_export(span: Any) -> Dict[str, Any]:
        """Return ``span.export()`` enriched with the full response payload.

        ``export()`` omits most of the response data for 'response' spans,
        so the full response object is dumped (via pydantic's ``model_dump``
        when available) into the exported dict.
        """
        span_export = span.export()
        if hasattr(span, 'span_data'):
            span_data = span.span_data
            # Extract the span type safely
            span_type = getattr(span_data, 'type', None)
            if span_type == 'response' and hasattr(span_data, 'response'):
                response = span_data.response
                if response:
                    if hasattr(response, 'model_dump'):
                        # Pydantic model: dump all fields at once.
                        enhanced_response = response.model_dump()
                    else:
                        # Fallback for plain objects: vars() (dict() would raise
                        # TypeError on an arbitrary object with __dict__).
                        enhanced_response = vars(response) if hasattr(response, '__dict__') else str(response)
                    span_export['span_data']['response'] = enhanced_response
                    # Add input if it was set on the span.
                    if hasattr(span_data, 'input') and span_data.input:
                        span_export['span_data']['input'] = span_data.input
        return span_export

    def on_span_end(self, span: Any):
        """Send (or queue) the enriched span data when the span ends."""
        try:
            self._send("span_end", self._enhance_span_export(span))
        except Exception as e:
            logger.debug("Failed to process span end: %s", e)

    def _send(self, event_type: str, data: Dict[str, Any]):
        """Send the event now if connected, otherwise buffer it."""
        with self._client_lock:
            if self.client is None:
                # Queue the event if not connected
                self._queued_events.append({"type": event_type, "data": data})
            else:
                self._send_now(event_type, data)

    def _send_now(self, event_type: str, data: Dict[str, Any]):
        """POST the event to the remote server; failures are logged, never raised."""
        if self._shutdown:
            return

        try:
            headers = {}
            if self.token:
                headers["Authorization"] = f"Bearer {self.token}"

            self.client.post(
                self.endpoint,
                json={
                    "event": event_type,
                    "project": self.project,
                    "data": data
                },
                headers=headers,
                timeout=2.0  # Don't block user code
            )
        except Exception as e:
            # Never crash user's app due to tracing
            logger.debug("Failed to send trace: %s", e)

    def shutdown(self):
        """Stop the connect thread (bounded wait) and close the client."""
        self._shutdown = True
        # Wait for client thread to finish sending queued events
        if self._client_thread.is_alive():
            self._client_thread.join(timeout=5.0)
        if self.client:
            self.client.close()

    def force_flush(self):
        """Drain any buffered events through the live client, if connected."""
        with self._client_lock:
            while self._queued_events and self.client and not self._shutdown:
                event = self._queued_events.popleft()
                self._send_now(event['type'], event['data'])