import os
from typing import Any
from enum import Enum

from smolagents import InferenceClientModel, OpenAIServerModel  # type: ignore

from .models import LLMProviderType

# Environment variable that holds the API key for each provider.
ENV_KEY_MAP = {
    LLMProviderType.OPENAI: "OPENAI_API_KEY",
    LLMProviderType.GEMINI: "GEMINI_API_KEY",
    LLMProviderType.CLAUDE: "CLAUDE_API_KEY",
    LLMProviderType.HF: "HF_API_KEY",
    LLMProviderType.OPENROUTER: "OPENROUTER_API_KEY",
}

# Base URLs for OpenAI-compatible API providers
BASE_URL_MAP = {
    LLMProviderType.OPENAI: "https://api.openai.com/v1",
    LLMProviderType.OPENROUTER: "https://openrouter.ai/api/v1",
    LLMProviderType.CLAUDE: "https://api.anthropic.com/v1",
    LLMProviderType.GEMINI: "https://generativelanguage.googleapis.com/v1beta/openai/",
}

# Default request timeout (seconds) passed to every model client unless the
# caller supplies its own ``timeout`` to ``LLMProvider.get_model``.
_DEFAULT_TIMEOUT = 300


class LLMProvider:
    """Hold a provider/model selection and build the matching smolagents model client."""

    def __init__(self, provider: LLMProviderType, model_id: str):
        """Store the provider and model, and read the provider's API key.

        Args:
            provider: LLMProviderType enum value selecting the backend.
            model_id: Model name string for the provider.

        Raises:
            ValueError: If ``provider`` or ``model_id`` is None.
        """
        if provider is None:
            raise ValueError("LLMProvider requires a provider argument or LLM_PROVIDER env variable.")
        self.provider = provider

        if model_id is None:
            raise ValueError("LLMProvider requires a model argument.")
        self.model_id = model_id

        # The key is read once at construction time. A provider with no entry in
        # ENV_KEY_MAP ends up with api_key=None, and get_model() then raises.
        key_env = ENV_KEY_MAP.get(self.provider)
        self.api_key = os.getenv(key_env) if key_env else None

    def get_model(self, **kwargs: Any) -> Any:
        """Return a model client for the selected provider and model.

        Hugging Face uses ``InferenceClientModel``. OpenAI, OpenRouter, Claude
        and Gemini use ``OpenAIServerModel`` pointed at the provider's base
        URL from ``BASE_URL_MAP``.

        Args:
            **kwargs: Extra keyword arguments forwarded to the smolagents model
                constructor. ``timeout`` defaults to 300 seconds when omitted.
                Do not pass ``model_id``, ``token``, ``api_key`` or
                ``api_base``; this method sets those itself.

        Returns:
            The constructed smolagents model instance.

        Raises:
            ValueError: If the API key is missing, the model name is empty, the
                provider has no configured base URL, or the provider is unknown.
        """
        if not self.api_key:
            raise ValueError(f"API key for provider {self.provider} not found in environment.")
        if not self.model_id:
            # __init__ only rejects None, so an empty string is caught here.
            raise ValueError("Model name must be provided.")

        # Forward caller options instead of silently dropping them. Use setdefault
        # so an explicit timeout overrides the default instead of raising a
        # "multiple values for keyword argument" TypeError.
        kwargs.setdefault("timeout", _DEFAULT_TIMEOUT)

        if self.provider == LLMProviderType.HF:
            return InferenceClientModel(
                model_id=self.model_id,
                token=self.api_key,
                **kwargs,
            )

        if self.provider in (
            LLMProviderType.OPENAI,
            LLMProviderType.OPENROUTER,
            LLMProviderType.CLAUDE,
            LLMProviderType.GEMINI,
        ):
            api_base = BASE_URL_MAP.get(self.provider)
            if not api_base:
                raise ValueError(f"Base URL not configured for provider: {self.provider}")
            return OpenAIServerModel(
                model_id=self.model_id,
                api_key=self.api_key,
                api_base=api_base,
                **kwargs,
            )

        raise ValueError(f"Unknown provider: {self.provider}")