# NOTE: The lines below were page residue captured from the package index
# (repository install link, version 0.4.37, path omni-code/compact.py).
# They are kept as a comment so the module remains importable.
from __future__ import annotations
import asyncio
import contextlib
import copy
import json
import os
from pathlib import Path
from typing import Any
from agents import Runner
from omniagents import server_function
from omniagents.core.agents.builder import _default_build_agent
from omniagents.core.config.loader import load_agent_spec_from_yaml
from omniagents.core.runtime.compaction import (
build_resume_protocol_block,
extract_tool_events,
extract_transcript,
extract_user_messages,
format_tool_events_block,
format_user_messages_block,
history_since_last_summary,
is_context_length_exceeded,
)
from omniagents.core.session.manager import Session
from openai import APIStatusError, BadRequestError
def _last_user_message(history: list[dict]) -> str | None:
for item in reversed(history or []):
if item.get("role") == "user" and isinstance(item.get("content"), str):
content = item.get("content")
if content:
return content
return None
def _is_context_length_exceeded(err: Exception) -> bool:
    """True when *err* is an OpenAI API error signalling a context-length overflow."""
    if not isinstance(err, (BadRequestError, APIStatusError)):
        return False
    return is_context_length_exceeded(err)
def _summarization_window(history: list[dict]) -> list[dict]:
    """Delegate to history_since_last_summary, excluding the summary item itself."""
    return history_since_last_summary(history, include_summary=False)
def _extract_transcript(history: list[dict]) -> str:
    """Local alias for omniagents' extract_transcript."""
    return extract_transcript(history)
def _extract_user_messages(history: list[dict]) -> list[str]:
    """Local alias for omniagents' extract_user_messages."""
    return extract_user_messages(history)
def _format_user_messages_block(messages: list[str]) -> str:
    """Local alias for omniagents' format_user_messages_block."""
    return format_user_messages_block(messages)
def _extract_tool_events(history: list[dict]) -> list[dict]:
    """Local alias for omniagents' extract_tool_events."""
    return extract_tool_events(history)
def _format_tool_events_block(events: list[dict], *, max_items: int = 5) -> str:
    """Delegate to omniagents' format_tool_events_block, forwarding max_items."""
    return format_tool_events_block(events, max_items=max_items)
def _trim_text_head(text: str, max_chars: int) -> str:
if len(text) <= max_chars:
return text
return text[:max_chars] + f"\n...[TRUNCATED {len(text) - max_chars} chars]..."
def _trim_text_tail(text: str, max_chars: int) -> str:
if len(text) <= max_chars:
return text
return f"...[TRUNCATED {len(text) - max_chars} chars]...\n" + text[-max_chars:]
def _extract_tool_name_from_event(event: dict) -> str | None:
name = event.get("name")
if isinstance(name, str) and name:
return name
tool_name = event.get("tool_name")
if isinstance(tool_name, str) and tool_name:
return tool_name
tool_calls = event.get("tool_calls")
if isinstance(tool_calls, list) and tool_calls:
first = tool_calls[0]
if isinstance(first, dict):
fn = first.get("function")
if isinstance(fn, dict):
fn_name = fn.get("name")
if isinstance(fn_name, str) and fn_name:
return fn_name
omni = event.get("omniagents")
if isinstance(omni, dict):
omni_name = omni.get("tool_name") or omni.get("name")
if isinstance(omni_name, str) and omni_name:
return omni_name
return None
def _trim_large_strings(value: Any, *, max_chars: int) -> Any:
    """Recursively cap every string inside a JSON-like structure at max_chars.

    Strings are head-trimmed; lists and dicts are rebuilt with trimmed
    members; every other value is returned untouched.
    """
    if isinstance(value, str):
        return _trim_text_head(value, max_chars)
    if isinstance(value, dict):
        return {
            key: _trim_large_strings(item, max_chars=max_chars)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_trim_large_strings(item, max_chars=max_chars) for item in value]
    return value
def _trim_exec_bash_fields(value: Any, *, stdout_chars: int, stderr_chars: int) -> Any:
    """Tail-trim stdout/stderr-like string fields anywhere in the structure.

    Keys named "stdout"/"out" keep their last stdout_chars characters and
    keys named "stderr"/"err" their last stderr_chars (case-insensitive);
    everything else is recursed into unchanged.
    """
    if isinstance(value, list):
        return [
            _trim_exec_bash_fields(
                item, stdout_chars=stdout_chars, stderr_chars=stderr_chars
            )
            for item in value
        ]
    if isinstance(value, dict):
        trimmed: dict = {}
        for key, item in value.items():
            lowered = key.lower()
            if isinstance(item, str) and lowered in {"stdout", "out"}:
                trimmed[key] = _trim_text_tail(item, stdout_chars)
            elif isinstance(item, str) and lowered in {"stderr", "err"}:
                trimmed[key] = _trim_text_tail(item, stderr_chars)
            else:
                trimmed[key] = _trim_exec_bash_fields(
                    item, stdout_chars=stdout_chars, stderr_chars=stderr_chars
                )
        return trimmed
    return value
def _sanitize_tool_event(event: dict) -> dict:
    """Produce a JSON-safe, size-bounded copy of a tool event for prompting.

    read_file/convert_to_markdown payloads are kept verbatim; execute_bash
    output keeps its tail; everything else gets a blanket per-string cap.
    Events that still render above 8000 chars collapse to a minimal record.
    """
    tool_name = _extract_tool_name_from_event(event)
    try:
        clone = copy.deepcopy(event)
    except Exception:
        # Non-copyable payloads degrade to their string form.
        clone = {"event": str(event)}
    if tool_name in {"read_file", "convert_to_markdown"}:
        safe = clone
    elif tool_name == "execute_bash":
        safe = _trim_exec_bash_fields(clone, stdout_chars=2000, stderr_chars=4000)
        safe = _trim_large_strings(safe, max_chars=2000)
    else:
        safe = _trim_large_strings(clone, max_chars=2000)
    try:
        rendered = json.dumps(safe, ensure_ascii=False, default=str)
    except Exception:
        # Fall back to a trivially serializable wrapper.
        safe = {"tool_name": tool_name or "UNKNOWN", "event": str(clone)}
        rendered = json.dumps(safe, ensure_ascii=False, default=str)
    max_event_chars = 8000
    if len(rendered) <= max_event_chars:
        return safe
    # Still oversized: keep only identifying fields plus a tail snippet.
    minimal: dict = {
        "tool_name": tool_name or "UNKNOWN",
        "truncated": True,
        "original_chars": len(rendered),
    }
    if isinstance(safe, dict):
        for key in ("command", "cmd", "args", "exit_code", "returncode"):
            value = safe.get(key)
            if value is not None:
                minimal[key] = value
    minimal["snippet_tail"] = _trim_text_tail(rendered, max_event_chars)
    return minimal
def _extract_file_paths_from_tool_event(event: dict) -> list[str]:
    """Collect file paths referenced by a tool event, de-duplicated in order.

    Looks at top-level "file_path"/"path" keys for the file tools, inside
    (possibly JSON-encoded) tool-call arguments, and at the "path" of
    list_directory events. Only non-empty strings are collected; first-seen
    order is preserved.
    """
    tool_name = _extract_tool_name_from_event(event) or ""
    paths: list[str] = []

    def add_path(value: Any) -> None:
        # Only non-empty strings count as paths.
        if isinstance(value, str) and value:
            paths.append(value)

    if tool_name in {"read_file", "apply_patch", "write_file"}:
        for k in ("file_path", "path"):
            add_path(event.get(k))
    tool_calls = event.get("tool_calls")
    if isinstance(tool_calls, list):
        for call in tool_calls:
            if not isinstance(call, dict):
                continue
            fn = call.get("function")
            if not isinstance(fn, dict):
                continue
            args = fn.get("arguments")
            if isinstance(args, str):
                # Arguments may arrive JSON-encoded; decode best-effort and
                # ignore malformed payloads.
                try:
                    args = json.loads(args)
                except (ValueError, TypeError):
                    args = None
            if isinstance(args, dict):
                for k in ("file_path", "path"):
                    add_path(args.get(k))
    if tool_name == "list_directory":
        add_path(event.get("path"))
    # dict.fromkeys drops duplicates while preserving insertion order.
    return list(dict.fromkeys(paths))
def _resume_protocol_block(*, history: list[dict], tool_events: list[dict]) -> str:
    """Local alias for omniagents' build_resume_protocol_block."""
    return build_resume_protocol_block(history=history, tool_events=tool_events)
def _extract_response_text(resp: Any) -> str:
out = []
for item in getattr(resp, "output", None) or []:
if getattr(item, "type", None) != "message":
continue
if getattr(item, "role", None) != "assistant":
continue
for c in getattr(item, "content", None) or []:
if getattr(c, "type", None) in {"output_text", "text"}:
text = getattr(c, "text", None)
if isinstance(text, str) and text:
out.append(text)
return "".join(out).strip()
@contextlib.contextmanager
def _temporary_environ(overrides: dict[str, str | None]):
prior: dict[str, str | None] = {}
for key, value in (overrides or {}).items():
prior[key] = os.environ.get(key)
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
try:
yield
finally:
for key, value in prior.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def _litellm_env_overrides(
*, model: str | None, api_key: str | None
) -> dict[str, str | None]:
if not api_key:
return {}
model_name = (model or "").lower()
if model_name.startswith("anthropic/") or model_name.startswith("claude"):
return {"ANTHROPIC_API_KEY": api_key}
if model_name.startswith("gemini/") or model_name.startswith("google/"):
return {"GOOGLE_API_KEY": api_key}
return {"OPENAI_API_KEY": api_key}
def _build_model_provider(
    runtime: dict[str, Any],
) -> tuple[Any | None, dict[str, str | None]]:
    """Build a model provider from a runtime config dict.

    Returns ``(provider, env_overrides)``. ``env_overrides`` is only
    populated on the litellm path (key env vars to export for the run);
    ``(None, {})`` means no usable provider/credentials could be derived.
    """
    provider = (runtime.get("provider") or "").strip().lower()
    model = runtime.get("model")
    api_key = runtime.get("api_key")
    base_url = runtime.get("base_url")
    if provider == "litellm":
        # Lazy import: only needed when the litellm provider is selected.
        from agents.extensions.models.litellm_provider import LitellmProvider
        return LitellmProvider(), _litellm_env_overrides(model=model, api_key=api_key)
    if provider in {"openai", "openai-compatible", "azure"}:
        from openai import AsyncOpenAI
        from agents.models.openai_provider import OpenAIProvider
        resolved_base_url = base_url or "https://api.openai.com/v1"
        resolved_api_key = api_key
        if not resolved_api_key:
            # Fall back to environment credentials by provider flavor.
            if provider == "azure":
                resolved_api_key = os.getenv("AZURE_OPENAI_API_KEY")
            else:
                resolved_api_key = os.getenv("OPENAI_API_KEY")
        # Azure is detected from the base URL, not only the provider field.
        is_azure = resolved_base_url and (
            "openai.azure.com" in resolved_base_url
            or "services.ai.azure.com" in resolved_base_url
        )
        if is_azure:
            if not resolved_api_key:
                return None, {}
            # The real key is sent via the "api-key" header; the client's own
            # api_key is a placeholder.
            client = AsyncOpenAI(
                api_key="dummy",
                base_url=resolved_base_url,
                default_headers={"api-key": resolved_api_key},
            )
            return OpenAIProvider(openai_client=client), {}
        # Non-default endpoints (e.g. local servers) may not need a real key.
        if not resolved_api_key and resolved_base_url != "https://api.openai.com/v1":
            resolved_api_key = "dummy"
        if not resolved_api_key:
            return None, {}
        client = AsyncOpenAI(
            api_key=resolved_api_key,
            base_url=resolved_base_url,
        )
        return OpenAIProvider(openai_client=client), {}
    return None, {}
def _resolve_compaction_runtime(
    service: Any, session: Session
) -> tuple[str | None, dict[str, Any]]:
    """Pick the model reference and runtime config to use for compaction.

    Preference order: the session's active model, then the service spec's
    model name (each must resolve through resolve_model_for_runtime), then
    the session's raw model_config dict, then nothing.
    """
    candidates: list[str] = []
    for raw in (
        getattr(session, "active_model", None),
        getattr(getattr(service, "spec", None), "model_name", None),
    ):
        if isinstance(raw, str) and raw.strip():
            candidates.append(raw.strip())
    from omni_code.models import resolve_model_for_runtime
    for ref in candidates:
        runtime = resolve_model_for_runtime(ref)
        if runtime:
            return ref, runtime
    fallback_ref = candidates[0] if candidates else None
    model_config = getattr(session, "model_config", None)
    if isinstance(model_config, dict) and model_config.get("model"):
        runtime = {
            "provider": (model_config.get("provider") or "").strip().lower(),
            "model": model_config.get("model"),
            "api_key": model_config.get("api_key"),
            "base_url": model_config.get("base_url"),
            "api_version": model_config.get("api_version"),
        }
        return fallback_ref, runtime
    return fallback_ref, {}
async def _run_compaction_agent(
    *,
    model: str | None,
    runtime: dict[str, Any] | None,
    prompt: str,
    session: Session | None,
) -> str:
    """Run the single-turn compactor agent on *prompt* and return its text.

    Loads the compactor agent spec from ../omni_agents/compactor/agent.yml
    (relative to this module's parent directory), optionally overrides its
    model, builds a model provider from *runtime*, and runs exactly one turn
    with any provider env vars applied temporarily.
    """
    # Spec lives two levels up from this file, under omni_agents/compactor.
    agent_yaml = (
        Path(__file__).resolve().parent.parent
        / "omni_agents"
        / "compactor"
        / "agent.yml"
    )
    spec = load_agent_spec_from_yaml(str(agent_yaml))
    if model:
        spec.model_name = model
    agent = await _default_build_agent(
        settings={}, mcp_servers=None, spec=spec, session=session
    )
    model_provider, env_overrides = _build_model_provider(runtime or {})
    from agents import RunConfig
    run_config = (
        RunConfig(model_provider=model_provider) if model_provider is not None else None
    )
    # max_turns=1: the compactor answers in a single completion.
    with _temporary_environ(env_overrides):
        result = await Runner.run(agent, prompt, max_turns=1, run_config=run_config)
    return result.final_output_as(str)
@server_function(
    description="Compact context by appending a summary marker",
    strict=True,
    name_override="compact",
)
async def compact(service: Any, session: Session) -> str:
    """Summarize the session history and append a context-summary checkpoint.

    Runs the compactor agent over the history since the last summary; when
    the model rejects the prompt for context length, the window is shrunk
    from the front and retried. The resulting checkpoint message (resume
    protocol + verbatim user messages + recent tool events + the summary)
    is appended to the session history.

    Returns a short status string for the client.
    """
    await service.set_client_status(
        "Compacting...",
        session=session,
        run_id=getattr(session, "active_run_id", None),
        show_spinner=True,
    )
    try:
        full_history = list(getattr(session, "history", []) or [])
        window = _summarization_window(full_history)
        model_ref, runtime = _resolve_compaction_runtime(service, session)
        model_name = runtime.get("model") if isinstance(runtime, dict) else None
        if (not isinstance(model_name, str) or not model_name.strip()) and isinstance(
            model_ref, str
        ):
            model_name = model_ref.strip()
        # Drop unknown "provider/" prefixes unless LiteLLM does the routing.
        if (
            isinstance(model_name, str)
            and "/" in model_name
            and runtime.get("provider") != "litellm"
        ):
            prefix = model_name.split("/", 1)[0]
            if prefix not in {"openai", "litellm"}:
                model_name = model_name.split("/", 1)[1]
        # User messages come from the full history, not the shrinking
        # window, so they are loop-invariant: compute them once.
        user_messages = _extract_user_messages(full_history)
        user_messages_block = _format_user_messages_block(user_messages)
        while True:
            transcript = _extract_transcript(window)
            prompt = (
                "All user messages (verbatim; last is highlighted):\n"
                f"{user_messages_block}\n\n"
                "Conversation transcript since last checkpoint (raw items; includes tool calls/outputs):\n"
                f"{transcript}"
            )
            try:
                summary_body = await _run_compaction_agent(
                    model=model_name,
                    runtime=runtime,
                    prompt=prompt,
                    session=session,
                )
                break
            except Exception as e:
                # Context overflow: drop the oldest window item and retry;
                # anything else (or an empty window) propagates.
                if _is_context_length_exceeded(e) and window:
                    window.pop(0)
                    continue
                raise
        summary_body = summary_body or "(no summary produced)"
        tool_events = _extract_tool_events(full_history)
        tool_events_block = _format_tool_events_block(tool_events, max_items=5)
        resume_protocol = _resume_protocol_block(
            history=full_history, tool_events=tool_events
        )
        summary_text = (
            "Another language model started to solve this problem and produced a summary of its thinking process. "
            "This message is a checkpoint to help you continue the work without duplicating effort. "
            f"{resume_protocol}\n"
            "Here are all user messages (verbatim; last is highlighted):\n"
            f"{user_messages_block}\n\n"
            "Here are the last 5 tool call/result events (verbatim):\n"
            f"{tool_events_block}\n\n"
            "Here is the summary produced by the other language model:\n\n"
            f"{summary_body}"
        )
        summary_item = {
            "role": "assistant",
            "content": summary_text,
            "omniagents": {"kind": "context_summary", "version": 1},
        }
        session.append_message(summary_item)
    finally:
        # Always clear the client spinner, even if compaction failed.
        await service.set_client_status(
            "",
            session=session,
            run_id=getattr(session, "active_run_id", None),
            show_spinner=False,
        )
    return "Context compacted."