Repository URL to install this package:
|
Version:
0.4.40 ▾
|
omni-code
/
models.py
|
|---|
"""Model configuration management for Omni Code.
Model sources (merged at runtime):
1. Project defaults - Built-in providers/models shipped with omni-code
2. User config - User providers/models in ~/.config/omni_code/models.json
User config overrides project defaults for any matching provider/model.
Config schema (models.json):
{
"version": 3,
"default": "gpt-5.1" | "azure-prod/gpt-5.2" | null,
"voice_default": "gpt-realtime" | "azure-prod/gpt" | null,
"providers": {
"openai": {
"type": "openai",
"api_key": "${OPENAI_API_KEY}",
"models": {"gpt-5.1": {"model": "gpt-5.1", "label": "GPT-5.1"}}
}
}
}
Model references:
- Unprefixed names ("gpt-5.1") imply the default provider profile "openai".
- Other providers use "<provider_name>/<model_name>" ("azure-prod/gpt-5.2").
"""
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from omni_code.config import PROJECT_NAME, PROJECT_NAME_WINDOWS
from omni_code.default_models import get_default_model_name_fallback, get_default_providers
# Schema version written to, and required when reading, models.json.
MODELS_CONFIG_VERSION: int = 3
# Provider profile implied by unprefixed model references such as "gpt-5.1".
DEFAULT_PROVIDER_NAME: str = "openai"
def get_config_dir() -> Path:
    """Return the directory holding omni_code's user configuration.

    On Windows this is ``%APPDATA%/<PROJECT_NAME_WINDOWS>``, using
    ``~/AppData/Roaming`` when APPDATA is unset or empty. Elsewhere it is
    ``$XDG_CONFIG_HOME/<PROJECT_NAME>`` when that variable is set and
    non-empty, otherwise ``~/.config/<PROJECT_NAME>``.
    """
    if os.name == "nt":
        appdata = os.getenv("APPDATA") or str(Path.home() / "AppData" / "Roaming")
        return Path(appdata) / PROJECT_NAME_WINDOWS
    xdg_home = os.getenv("XDG_CONFIG_HOME")
    base_dir = Path(xdg_home) if xdg_home else Path.home() / ".config"
    return base_dir / PROJECT_NAME
def get_models_config_path() -> Path:
    """Return the location of ``models.json`` inside the config directory."""
    config_dir = get_config_dir()
    return config_dir.joinpath("models.json")
def _expand_env_vars(value: Any) -> Any:
    """Recursively expand ``${VAR}`` and ``${VAR:-default}`` placeholders.

    Strings have every placeholder substituted. Dicts (values only; keys are
    left untouched) and lists are rebuilt with their elements expanded. Any
    other value is returned unchanged.

    Substitution rules:
    - ``${VAR}`` becomes the value of VAR, or ``""`` when VAR is unset.
    - ``${VAR:-default}`` becomes the value of VAR when it is set *and*
      non-empty, otherwise ``default``. This matches POSIX shell and
      Docker Compose interpolation.
    """
    if isinstance(value, str):
        pattern = r"\$\{([^}:]+)(?::-([^}]*))?\}"

        def replace_var(match: re.Match) -> str:
            var_name = match.group(1)
            default = match.group(2)
            env_value = os.getenv(var_name)
            if default is None:
                # Plain ${VAR}: an unset variable expands to an empty string.
                return env_value if env_value is not None else ""
            # ${VAR:-default}: as in the shell, an empty value also falls back,
            # so e.g. OPENAI_API_KEY="" does not defeat the configured default.
            return env_value if env_value else default

        return re.sub(pattern, replace_var, value)
    if isinstance(value, dict):
        return {k: _expand_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_expand_env_vars(v) for v in value]
    return value
def _empty_models_config() -> Dict[str, Any]:
    """Return a fresh, empty models config in the current schema version."""
    return {
        "version": MODELS_CONFIG_VERSION,
        "default": None,
        "voice_default": None,
        "providers": {},
    }


def load_models_config() -> Dict[str, Any]:
    """Load and sanitize the user's models.json.

    Returns a dict with exactly the keys ``version``, ``default``,
    ``voice_default`` and ``providers``. An empty config is returned when:
    - the file is missing (silently);
    - the file is unreadable, not valid UTF-8/JSON, or not a JSON object
      (with a warning);
    - it declares a version other than MODELS_CONFIG_VERSION (with a
      warning).

    Fields with the wrong type are replaced by their empty value.
    Environment placeholders are not expanded here.
    """
    models_path = get_models_config_path()
    if not models_path.exists():
        return _empty_models_config()
    try:
        content = models_path.read_text(encoding="utf-8")
        raw = json.loads(content)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Failed to load {models_path}: {e}")
        return _empty_models_config()
    if not isinstance(raw, dict):
        # e.g. a top-level list or string; raw.get() below would crash.
        print(f"Warning: Failed to load {models_path}: expected a JSON object at the top level")
        return _empty_models_config()
    if raw.get("version") != MODELS_CONFIG_VERSION:
        print(
            f"Warning: Unsupported models.json version {raw.get('version')}. "
            f"Expected version {MODELS_CONFIG_VERSION}."
        )
        return _empty_models_config()
    providers = raw.get("providers")
    if not isinstance(providers, dict):
        providers = {}
    default_value = raw.get("default")
    if default_value is not None and not isinstance(default_value, str):
        default_value = None
    voice_default_value = raw.get("voice_default")
    if voice_default_value is not None and not isinstance(voice_default_value, str):
        voice_default_value = None
    return {
        "version": MODELS_CONFIG_VERSION,
        "default": default_value,
        "voice_default": voice_default_value,
        "providers": providers,
    }
def save_models_config(config: Dict[str, Any]) -> None:
    """Write ``config`` to models.json, owner-readable only on POSIX.

    Only the ``default``, ``voice_default`` and ``providers`` keys of
    ``config`` are persisted, and ``version`` is always
    MODELS_CONFIG_VERSION. The parent directory is created if needed. On
    Windows the file is written with default permissions.
    """
    path = get_models_config_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    to_write = {
        "version": MODELS_CONFIG_VERSION,
        "default": config.get("default"),
        "voice_default": config.get("voice_default"),
        "providers": config.get("providers") or {},
    }
    content = json.dumps(to_write, indent=2)
    if os.name == "nt":
        path.write_text(content, encoding="utf-8")
        return
    # The mode passed to os.open only applies when the file is created; an
    # existing models.json would otherwise keep looser permissions.
    fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as handle:
        try:
            # Tighten a pre-existing file as well: it may contain API keys.
            os.fchmod(handle.fileno(), 0o600)
        except OSError:
            # Best effort (e.g. file owned by another user, or a filesystem
            # without POSIX modes); still persist the config.
            pass
        handle.write(content)
def _merge_provider_config(
    base_provider: Dict[str, Any],
    override_provider: Dict[str, Any],
) -> Dict[str, Any]:
    """Layer ``override_provider`` on top of ``base_provider``.

    Every top-level field of the override except ``models`` replaces the
    base value. The two ``models`` maps are combined by model name, with
    override entries winning. A missing or non-dict ``models`` on either side
    counts as empty, so the result always has a dict under ``models``.

    Neither input dict is modified; nested model entries are shared, not
    copied.
    """
    def models_of(provider: Dict[str, Any]) -> Dict[str, Any]:
        candidate = provider.get("models")
        return candidate if isinstance(candidate, dict) else {}

    overrides = {key: value for key, value in override_provider.items() if key != "models"}
    result: Dict[str, Any] = {**base_provider, **overrides}
    result["models"] = {**models_of(base_provider), **models_of(override_provider)}
    return result
def get_merged_providers() -> Dict[str, Dict[str, Any]]:
    """Return every provider profile, with the user config layered on defaults.

    Starts from ``get_default_providers()``. Each provider in the user's
    models.json whose value is a dict is merged (``_merge_provider_config``)
    over the default profile of the same name, or over an empty profile when
    no usable default exists. Non-dict user entries are skipped.
    """
    # Copy so the mapping handed out by get_default_providers() is never
    # mutated, even if that function returns a shared object.
    merged = dict(get_default_providers())
    user_config = load_models_config()
    user_providers = user_config.get("providers")
    if not isinstance(user_providers, dict):
        return merged
    for provider_name, provider_config in user_providers.items():
        if not isinstance(provider_config, dict):
            continue
        existing = merged.get(provider_name)
        if not isinstance(existing, dict):
            existing = {}
        merged[provider_name] = _merge_provider_config(existing, provider_config)
    return merged
def _format_model_ref(provider_name: str, model_name: str) -> str:
    """Build the canonical reference string for a provider/model pair.

    Models of the default provider are written unprefixed (``"gpt-5.1"``);
    all others use ``"<provider_name>/<model_name>"``. A default-provider
    model whose name itself contains ``/`` keeps the explicit prefix. The
    unprefixed form would be parsed back (split on the first ``/``) as a
    different provider, breaking the round trip through _parse_model_ref.
    """
    if provider_name == DEFAULT_PROVIDER_NAME and "/" not in model_name:
        return model_name
    return f"{provider_name}/{model_name}"
def _parse_model_ref(ref: str) -> Optional[tuple[str, str]]:
    """Split a model reference into ``(provider_name, model_name)``.

    The text is split on its first ``/`` and both halves are stripped. A
    reference without ``/`` belongs to DEFAULT_PROVIDER_NAME. Non-strings,
    blank strings and references with an empty provider or model part yield
    None.
    """
    if not isinstance(ref, str):
        return None
    text = ref.strip()
    if not text:
        return None
    head, separator, tail = text.partition("/")
    if not separator:
        return DEFAULT_PROVIDER_NAME, text
    provider, model = head.strip(), tail.strip()
    if provider and model:
        return provider, model
    return None
def normalize_model_ref(ref: str) -> Optional[str]:
    """Return the canonical spelling of a model reference, or None if invalid.

    For example ``"openai/gpt-5.1"`` and ``" gpt-5.1 "`` both normalize to
    ``"gpt-5.1"``, while ``"azure-prod/gpt-5.2"`` stays prefixed.
    """
    parsed = _parse_model_ref(ref)
    return _format_model_ref(*parsed) if parsed else None
def get_model_config(ref: str) -> Optional[Dict[str, Any]]:
    """Return the flattened provider + model configuration for ``ref``.

    Returns None when ``ref`` cannot be parsed, or when the provider
    profile, its ``models`` map, or the model entry is missing or not a
    dict.

    Otherwise the result always holds:
    - ``name``: the normalized reference;
    - ``provider``: the profile's ``type``, else its ``provider`` key, else
      ``"unknown"``;
    - ``provider_name``;
    - ``model`` and ``label``: both fall back to the model name.

    Optional settings (api_key, base_url, realtime, ...) are copied from the
    provider profile and then overridden by the model entry. Values are
    returned as stored; placeholders are not expanded.
    """
    parsed = _parse_model_ref(ref)
    if parsed is None:
        return None
    provider_name, model_name = parsed
    provider_config = get_merged_providers().get(provider_name)
    if not isinstance(provider_config, dict):
        return None
    models = provider_config.get("models")
    model_config = models.get(model_name) if isinstance(models, dict) else None
    if not isinstance(model_config, dict):
        return None
    merged: Dict[str, Any] = {
        "name": _format_model_ref(provider_name, model_name),
        "provider": provider_config.get("type") or provider_config.get("provider") or "unknown",
        "provider_name": provider_name,
        "model": model_config.get("model") or model_name,
        "label": model_config.get("label") or model_name,
    }
    inheritable_keys = (
        "api_key",
        "base_url",
        "api_version",
        "realtime_url",
        "realtime_base_url",
        "realtime",
        "reasoning",
        "max_input_tokens",
        "max_output_tokens",
        "model_settings",
    )
    for key in inheritable_keys:
        # The model entry is applied second, so its value wins.
        for layer in (provider_config, model_config):
            if key in layer:
                merged[key] = layer[key]
    return merged
def get_default_model_name() -> Optional[str]:
    """Return the normalized reference of the default (text) model.

    Priority:
    1. The user's ``default`` in models.json, provided it resolves to a
       known, non-realtime model.
    2. The project fallback from ``get_default_model_name_fallback()``.
    """
    user_default = load_models_config().get("default")
    candidate = get_model_config(user_default) if user_default else None
    if candidate and not candidate.get("realtime"):
        return normalize_model_ref(user_default) or get_default_model_name_fallback()
    return get_default_model_name_fallback()
def get_voice_default_model_name() -> Optional[str]:
    """Return the normalized reference of the default voice (realtime) model.

    Uses ``voice_default`` from models.json when it resolves to a known
    model flagged ``realtime``. Returns None otherwise; there is no
    project-level fallback.
    """
    user_default = load_models_config().get("voice_default")
    if not user_default:
        return None
    normalized = normalize_model_ref(user_default)
    if not normalized:
        return None
    # A single lookup is enough: the raw and normalized spellings resolve to
    # the same provider/model entry, and every lookup re-reads models.json.
    merged = get_model_config(normalized)
    if merged and merged.get("realtime"):
        return normalized
    return None
def get_voice_default_model_config() -> Optional[Dict[str, Any]]:
    """Return the merged config of the default voice model, or None if unset."""
    voice_name = get_voice_default_model_name()
    return get_model_config(voice_name) if voice_name else None
def get_default_model_config() -> Optional[Dict[str, Any]]:
    """Return the merged config of the default text model, or None if unset."""
    default_ref = get_default_model_name()
    if not default_ref:
        return None
    return get_model_config(default_ref)
def is_user_model(ref: str) -> bool:
    """Return True when ``ref`` names a model defined in the user's models.json.

    Models that exist only in the project defaults, and unparseable
    references, return False.
    """
    parsed = _parse_model_ref(ref)
    if parsed is None:
        return False
    provider_name, model_name = parsed
    user_providers = load_models_config().get("providers")
    provider = user_providers.get(provider_name) if isinstance(user_providers, dict) else None
    models = provider.get("models") if isinstance(provider, dict) else None
    return isinstance(models, dict) and model_name in models
def list_models() -> List[Dict[str, Any]]:
    """List all available models (merged from project defaults and user config).

    Returns one summary dict per model, sorted by reference name, with keys:
    name, label, provider, provider_name, model, is_default,
    is_voice_default, is_user_defined, realtime, reasoning,
    max_input_tokens, max_output_tokens and has_api_key. ``model_settings``
    is added when non-empty. Providers or models whose config is not a dict
    are skipped.
    """
    providers = get_merged_providers()
    default_name = get_default_model_name()
    voice_default_name = get_voice_default_model_name()
    # Load the user's models.json once for the is_user_defined flags, instead
    # of calling is_user_model() per model (which re-reads the file each time).
    user_providers = load_models_config().get("providers")
    if not isinstance(user_providers, dict):
        user_providers = {}

    def _is_user_defined(provider_name: str, model_name: str) -> bool:
        """Return True if the user config defines this provider/model pair."""
        user_provider = user_providers.get(provider_name)
        if not isinstance(user_provider, dict):
            return False
        user_models = user_provider.get("models")
        return isinstance(user_models, dict) and model_name in user_models

    results: List[Dict[str, Any]] = []
    for provider_name, provider_config in providers.items():
        if not isinstance(provider_config, dict):
            continue
        provider_type = provider_config.get("type") or provider_config.get("provider") or "unknown"
        provider_models = provider_config.get("models")
        if not isinstance(provider_models, dict):
            continue
        for model_name, model_config in provider_models.items():
            if not isinstance(model_config, dict):
                continue
            ref = _format_model_ref(provider_name, model_name)
            merged = get_model_config(ref) or {}
            entry = {
                "name": ref,
                "label": merged.get("label", ref),
                "provider": provider_type,
                "provider_name": provider_name,
                "model": merged.get("model"),
                "is_default": ref == default_name,
                "is_voice_default": ref == voice_default_name,
                "is_user_defined": _is_user_defined(provider_name, model_name),
                "realtime": bool(merged.get("realtime")),
                "reasoning": merged.get("reasoning"),
                "max_input_tokens": merged.get("max_input_tokens"),
                "max_output_tokens": merged.get("max_output_tokens"),
                "has_api_key": bool(merged.get("api_key")),
            }
            if merged.get("model_settings"):
                entry["model_settings"] = merged["model_settings"]
            results.append(entry)
    results.sort(key=lambda item: item["name"])
    return results
def add_model(
    *,
    provider_name: str,
    provider_type: str,
    model_name: str,
    model: str,
    label: Optional[str] = None,
    provider_api_key: Optional[str] = None,
    provider_base_url: Optional[str] = None,
    provider_api_version: Optional[str] = None,
    provider_realtime_url: Optional[str] = None,
    provider_realtime_base_url: Optional[str] = None,
    model_api_key: Optional[str] = None,
    realtime: Optional[bool] = None,
    realtime_url: Optional[str] = None,
    realtime_base_url: Optional[str] = None,
    reasoning: Optional[str] = None,
    max_input_tokens: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    model_settings: Optional[Dict[str, Any]] = None,
    set_default: bool = False,
    set_voice_default: bool = False,
) -> None:
    """Add or update a model under a provider profile in the user's models.json.

    The provider profile is created if missing. Its ``type`` is always set to
    ``provider_type``, and each ``provider_*`` argument that is not None
    overwrites the matching provider field.

    The model entry ``model_name`` is replaced wholesale by a new entry built
    from ``model``, ``label`` (defaulting to ``model_name``) and every other
    model-level argument that is not None (``reasoning`` only when truthy).

    Defaults:
    - The model becomes the text default when ``set_default`` is true or no
      default is configured yet, unless it is a realtime model.
    - It becomes the voice default only when ``set_voice_default`` and
      ``realtime`` are both true.
    """
    config = load_models_config()
    providers = config.setdefault("providers", {})
    provider_obj = providers.get(provider_name)
    if not isinstance(provider_obj, dict):
        provider_obj = {}
    provider_obj["type"] = provider_type
    if provider_api_key is not None:
        provider_obj["api_key"] = provider_api_key
    if provider_base_url is not None:
        provider_obj["base_url"] = provider_base_url
    if provider_api_version is not None:
        provider_obj["api_version"] = provider_api_version
    if provider_realtime_url is not None:
        provider_obj["realtime_url"] = provider_realtime_url
    if provider_realtime_base_url is not None:
        provider_obj["realtime_base_url"] = provider_realtime_base_url
    models_obj = provider_obj.get("models")
    if not isinstance(models_obj, dict):
        models_obj = {}
    model_config: Dict[str, Any] = {
        "model": model,
        "label": label or model_name,
    }
    if model_api_key is not None:
        model_config["api_key"] = model_api_key
    if realtime is not None:
        model_config["realtime"] = bool(realtime)
    if realtime_url is not None:
        model_config["realtime_url"] = realtime_url
    if realtime_base_url is not None:
        model_config["realtime_base_url"] = realtime_base_url
    if reasoning:
        model_config["reasoning"] = reasoning
    if max_input_tokens is not None:
        model_config["max_input_tokens"] = max_input_tokens
    if max_output_tokens is not None:
        model_config["max_output_tokens"] = max_output_tokens
    if model_settings is not None:
        model_config["model_settings"] = model_settings
    models_obj[model_name] = model_config
    provider_obj["models"] = models_obj
    providers[provider_name] = provider_obj
    model_ref = _format_model_ref(provider_name, model_name)
    # A realtime model must never become the text default:
    # get_default_model_name() ignores it and set_default_model() refuses it.
    # Storing one would also stop the next text model from being auto-selected.
    if not realtime and (set_default or not config.get("default")):
        config["default"] = model_ref
    if set_voice_default and realtime:
        config["voice_default"] = model_ref
    save_models_config(config)
def remove_model(name: str) -> bool:
    """Remove a user-defined model configuration. Returns True if removed.

    Only entries in the user's models.json can be removed; project default
    models are untouched. If the removed model was the configured text
    default or voice default, that setting is cleared. Returns False
    (nothing written) when the reference is invalid or the model is not in
    the user config.
    """
    config = load_models_config()
    parsed = _parse_model_ref(name)
    if not parsed:
        return False
    provider_name, model_name = parsed
    providers = config.get("providers")
    if not isinstance(providers, dict):
        return False
    provider = providers.get(provider_name)
    if not isinstance(provider, dict):
        return False
    models_obj = provider.get("models")
    if not isinstance(models_obj, dict) or model_name not in models_obj:
        return False
    # models_obj is the dict nested inside ``config``, so this updates config.
    del models_obj[model_name]
    removed_ref = _format_model_ref(provider_name, model_name)
    for setting in ("default", "voice_default"):
        # Compare normalized forms: models.json may spell the same model as
        # e.g. "openai/gpt-5.1" or "gpt-5.1".
        if normalize_model_ref(config.get(setting) or "") == removed_ref:
            config[setting] = None
    save_models_config(config)
    return True
def set_default_model(name: str) -> bool:
    """Make ``name`` the default text model. Returns True on success.

    Accepts project default and user-defined models alike. Returns False,
    leaving models.json unchanged, when the reference is invalid, unknown,
    or names a realtime model.
    """
    normalized = normalize_model_ref(name)
    if not normalized:
        return False
    target = get_model_config(normalized)
    if not target or target.get("realtime"):
        return False
    user_config = load_models_config()
    user_config["default"] = normalized
    save_models_config(user_config)
    return True
def set_voice_default_model(name: str) -> bool:
    """Make ``name`` the default voice model. Returns True on success.

    Returns False, leaving models.json unchanged, unless ``name`` resolves to
    a known model flagged ``realtime``.
    """
    normalized = normalize_model_ref(name)
    target = get_model_config(normalized) if normalized else None
    if not (target and target.get("realtime")):
        return False
    user_config = load_models_config()
    user_config["voice_default"] = normalized
    save_models_config(user_config)
    return True
def list_text_models() -> List[Dict[str, Any]]:
    """Return the entries of ``list_models()`` that are not realtime models."""
    return list(filter(lambda entry: not entry.get("realtime"), list_models()))
def list_voice_models() -> List[Dict[str, Any]]:
    """Return the entries of ``list_models()`` flagged as realtime models."""
    return list(filter(lambda entry: entry.get("realtime"), list_models()))
def resolve_voice_settings_for_realtime(model_name: Optional[str] = None) -> Dict[str, Any]:
    """Build the settings used to configure a realtime (voice) session.

    Args:
        model_name: Model reference to use. When falsy, the configured voice
            default from ``get_voice_default_model_name()`` is used.

    Returns:
        ``{}`` when no model resolves or the model is not flagged
        ``realtime``. Otherwise a dict containing:
        - ``openai_api_key`` (``""`` when no key resolves);
        - ``openai_realtime_url`` and ``openai_realtime_base_url`` (None
          unless configured, env-expanded when set);
        - for ``azure``: ``openai_base_url``/``azure_openai_endpoint`` (when
          a base URL is set), plus ``azure_openai_api_version`` and
          ``azure_openai_deployment_name`` when available;
        - for other providers: ``openai_base_url`` (the configured base URL
          or the public OpenAI endpoint);
        - ``openai_realtime_model`` whenever a model id resolves.
    """
    name = normalize_model_ref(model_name) if model_name else get_voice_default_model_name()
    if not name:
        return {}
    merged = get_model_config(name)
    # Only models flagged ``realtime`` can back a voice session.
    if not merged or not merged.get("realtime"):
        return {}
    # Runtime view of the same model (env-expanded model id, API key, base URL).
    runtime = resolve_model_for_runtime(name)
    provider = (runtime.get("provider") or "").strip().lower()
    api_key = runtime.get("api_key") or ""
    base_url = runtime.get("base_url")
    settings: Dict[str, Any] = {
        "openai_api_key": api_key,
        "openai_realtime_url": None,
        "openai_realtime_base_url": None,
    }
    if provider == "azure":
        if base_url:
            settings["openai_base_url"] = base_url
            settings["azure_openai_endpoint"] = base_url
        api_version = runtime.get("api_version")
        if api_version:
            settings["azure_openai_api_version"] = api_version
        deployment = runtime.get("model")
        if deployment:
            settings["azure_openai_deployment_name"] = deployment
    else:
        if base_url:
            settings["openai_base_url"] = base_url
        else:
            # No base URL configured: use the public OpenAI API endpoint.
            settings["openai_base_url"] = "https://api.openai.com/v1"
    # NOTE(review): applied for every provider, so Azure also receives its
    # deployment name here; confirm this is intended.
    model_id = runtime.get("model")
    if model_id:
        settings["openai_realtime_model"] = model_id
    # Realtime URLs are read from the raw merged config (provider level,
    # overridden by model level), so placeholders are expanded here.
    realtime_url = merged.get("realtime_url")
    if realtime_url:
        settings["openai_realtime_url"] = _expand_env_vars(realtime_url)
    realtime_base_url = merged.get("realtime_base_url")
    if realtime_base_url:
        settings["openai_realtime_base_url"] = _expand_env_vars(realtime_base_url)
    return settings
def resolve_model_for_runtime(
    model_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Resolve a model configuration for runtime use.

    Args:
        model_name: Model reference to resolve. When falsy, the default
            model from ``get_default_model_name()`` is used.

    Returns:
        ``{}`` when no model can be resolved. Otherwise a dict with:
        - name: str (normalized model reference)
        - provider: str (lower-cased provider type, e.g. "openai", "azure",
          "openai-compatible", "litellm")
        - provider_name: str (provider profile name)
        - model: str (the model identifier)
        - label: str
        - api_key: Optional[str] (from the model/provider config, else from
          a provider-specific environment variable)
        - base_url: str (only present when configured)
        - api_version: str (only present when configured; for Azure)
        - reasoning: Optional[str] ("low", "medium", "high")
        - max_input_tokens: Optional[int] (model's max input context)
        - max_output_tokens: Optional[int] (model's max output tokens)
        - model_settings: Dict (only present when non-empty)

    Placeholders (``${VAR}`` / ``${VAR:-default}``) are expanded in model,
    api_key, base_url and api_version. Uses merged models (project defaults
    + user config).
    """
    name = normalize_model_ref(model_name) if model_name else get_default_model_name()
    if not name:
        return {}
    merged = get_model_config(name)
    if not merged:
        return {}
    provider = (merged.get("provider") or "openai").strip().lower()
    result: Dict[str, Any] = {
        "name": name,
        "provider": provider,
        "provider_name": merged.get("provider_name"),
        "model": _expand_env_vars(merged.get("model", "")),
        "label": merged.get("label", name),
        "reasoning": merged.get("reasoning"),
        "max_input_tokens": merged.get("max_input_tokens"),
        "max_output_tokens": merged.get("max_output_tokens"),
    }
    if merged.get("model_settings"):
        result["model_settings"] = merged["model_settings"]
    if merged.get("base_url"):
        result["base_url"] = _expand_env_vars(merged["base_url"])
    if merged.get("api_version"):
        # Expanded like base_url/api_key so "${AZURE_OPENAI_API_VERSION:-...}"
        # does not reach the SDK as a literal placeholder.
        result["api_version"] = _expand_env_vars(merged["api_version"])
    api_key_value = merged.get("api_key")
    result["api_key"] = _expand_env_vars(api_key_value) if api_key_value else None
    if not result["api_key"]:
        # Fall back to the environment variable conventionally used for the
        # provider (litellm: chosen from the model prefix).
        # NOTE(review): every other provider, including "openai-compatible",
        # falls back to OPENAI_API_KEY, which may forward the OpenAI key to a
        # third-party base_url; confirm this is intended.
        model_str = result.get("model", "")
        if provider == "azure":
            env_var = "AZURE_OPENAI_API_KEY"
        elif provider == "litellm" and model_str.startswith(("anthropic/", "claude")):
            env_var = "ANTHROPIC_API_KEY"
        elif provider == "litellm" and model_str.startswith(("gemini/", "google/")):
            env_var = "GOOGLE_API_KEY"
        else:
            env_var = "OPENAI_API_KEY"
        result["api_key"] = os.getenv(env_var)
    return result
def apply_api_keys_to_env() -> None:
    """Export the default model's API key to the environment variable its SDK reads.

    The default model is resolved via ``resolve_model_for_runtime()``. When
    it has an API key, the key goes into a provider-specific variable:
    - Azure: AZURE_OPENAI_API_KEY;
    - LiteLLM: ANTHROPIC_API_KEY, GOOGLE_API_KEY or OPENAI_API_KEY,
      depending on the model prefix;
    - anything else: OPENAI_API_KEY.

    A variable that is already set is never overwritten.
    """
    runtime = resolve_model_for_runtime()
    key = runtime.get("api_key") if runtime else None
    if not key:
        return
    provider = runtime.get("provider", "openai")
    if provider == "azure":
        env_var = "AZURE_OPENAI_API_KEY"
    elif provider == "litellm":
        model_id = runtime.get("model", "")
        if model_id.startswith(("anthropic/", "claude")):
            env_var = "ANTHROPIC_API_KEY"
        elif model_id.startswith(("gemini/", "google/")):
            env_var = "GOOGLE_API_KEY"
        else:
            env_var = "OPENAI_API_KEY"
    else:
        env_var = "OPENAI_API_KEY"
    # setdefault only writes when the variable is not already present.
    os.environ.setdefault(env_var, key)
def is_user_provider(name: str) -> bool:
    """Return True when ``name`` is a provider profile in the user's models.json."""
    user_providers = load_models_config().get("providers")
    return isinstance(user_providers, dict) and name in user_providers
def list_providers() -> List[Dict[str, Any]]:
    """Summarize every provider profile (project defaults merged with user config).

    Entries are sorted by name and non-dict profiles are skipped.
    ``has_api_key`` reflects only the provider-level ``api_key``.
    ``is_user_defined`` is True for profiles present in models.json.
    """
    merged = get_merged_providers()
    user_providers = load_models_config().get("providers")
    if not isinstance(user_providers, dict):
        user_providers = {}

    def summarize(name: str, provider: Dict[str, Any]) -> Dict[str, Any]:
        models = provider.get("models")
        return {
            "name": name,
            "type": provider.get("type") or provider.get("provider") or "unknown",
            "base_url": provider.get("base_url"),
            "has_api_key": bool(provider.get("api_key")),
            "model_count": len(models) if isinstance(models, dict) else 0,
            "is_user_defined": name in user_providers,
        }

    summaries = [
        summarize(name, provider)
        for name, provider in merged.items()
        if isinstance(provider, dict)
    ]
    return sorted(summaries, key=lambda item: item["name"])
def get_provider(name: str) -> Optional[Dict[str, Any]]:
    """Return a shallow copy of the merged profile for provider ``name``, or None."""
    profile = get_merged_providers().get(name)
    return dict(profile) if isinstance(profile, dict) else None
def add_provider(
    name: str,
    provider_type: str,
    *,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    api_version: Optional[str] = None,
    realtime_url: Optional[str] = None,
    realtime_base_url: Optional[str] = None,
) -> None:
    """Create or update provider ``name`` in the user's models.json.

    ``type`` is always set to ``provider_type``. Each optional field is
    written only when it is not None, so omitted fields keep their current
    value. Models already defined under the provider are preserved, and an
    empty ``models`` map is added when the key is absent.
    """
    config = load_models_config()
    providers = config.setdefault("providers", {})
    existing = providers.get(name)
    profile = existing if isinstance(existing, dict) else {}
    profile["type"] = provider_type
    optional_fields = {
        "api_key": api_key,
        "base_url": base_url,
        "api_version": api_version,
        "realtime_url": realtime_url,
        "realtime_base_url": realtime_base_url,
    }
    profile.update({field: value for field, value in optional_fields.items() if value is not None})
    profile.setdefault("models", {})
    providers[name] = profile
    save_models_config(config)
def update_provider(name: str, **kwargs: Any) -> bool:
    """Update fields of a provider defined in the user's models.json.

    Only ``type``, ``api_key``, ``base_url``, ``api_version``,
    ``realtime_url`` and ``realtime_base_url`` are applied. Other keyword
    arguments, and arguments whose value is None, are ignored.

    Returns False, without writing, when the provider is not in the user
    config (even if it exists among the project defaults).
    """
    config = load_models_config()
    providers = config.get("providers")
    profile = providers.get(name) if isinstance(providers, dict) else None
    if not isinstance(profile, dict):
        return False
    editable = {"type", "api_key", "base_url", "api_version", "realtime_url", "realtime_base_url"}
    for field, value in kwargs.items():
        if value is not None and field in editable:
            profile[field] = value
    providers[name] = profile
    save_models_config(config)
    return True
def remove_provider(name: str) -> bool:
    """Delete provider ``name`` and all its models from the user's models.json.

    Returns False when the provider is not in the user config. Any
    ``default`` or ``voice_default`` whose reference parses to this provider
    is cleared. This includes unprefixed references when ``name`` is the
    default provider.
    """
    config = load_models_config()
    providers = config.get("providers")
    if not isinstance(providers, dict) or name not in providers:
        return False
    del providers[name]
    for setting in ("default", "voice_default"):
        ref = config.get(setting)
        if not ref:
            continue
        parsed = _parse_model_ref(ref)
        if parsed and parsed[0] == name:
            config[setting] = None
    save_models_config(config)
    return True
def has_models_configured() -> bool:
    """Return True when the default model resolves to something usable.

    An ``openai-compatible`` provider counts as usable with a model id and a
    base URL (no API key required). Every other provider needs an API key.
    """
    runtime = resolve_model_for_runtime()
    if not runtime:
        return False
    if (runtime.get("provider") or "").strip().lower() == "openai-compatible":
        return bool(runtime.get("model") and runtime.get("base_url"))
    return bool(runtime.get("api_key"))