|
46 | 46 | "litellm_prefix": "deepseek", |
47 | 47 | }, |
48 | 48 | "gemini": { |
49 | | - "keywords": ("gemini",), |
| 49 | + "keywords": ("gemini", "google"), |
50 | 50 | "env_key": "GEMINI_API_KEY", |
51 | 51 | "litellm_prefix": "gemini", |
52 | 52 | }, |
@@ -105,23 +105,24 @@ def __init__(self, config: Dict[str, Any]): |
105 | 105 | if self.api_key: |
106 | 106 | self._setup_env(self.api_key, self.model) |
107 | 107 |
|
108 | | - if self.api_base: |
109 | | - litellm.api_base = self.api_base |
110 | | - |
| 108 | + # Configure LiteLLM behavior (these are global but safe to re-set) |
111 | 109 | litellm.suppress_debug_info = True |
112 | 110 | litellm.drop_params = True |
113 | 111 |
|
114 | 112 | def _setup_env(self, api_key: str, model: str | None) -> None: |
115 | 113 | """Set environment variables based on detected provider.""" |
116 | 114 | provider = self._provider_name |
117 | | - if not provider and model: |
118 | | - provider = detect_provider_by_model(model) |
| 115 | + if (not provider or provider == "litellm") and model: |
| 116 | + detected = detect_provider_by_model(model) |
| 117 | + if detected: |
| 118 | + provider = detected |
119 | 119 |
|
120 | 120 | if provider and provider in PROVIDER_CONFIGS: |
121 | 121 | env_key = PROVIDER_CONFIGS[provider]["env_key"] |
122 | 122 | os.environ[env_key] = api_key |
123 | 123 | self._detected_provider = provider |
124 | 124 | else: |
| 125 | + # Fallback to OpenAI if provider is unknown or literal litellm |
125 | 126 | os.environ["OPENAI_API_KEY"] = api_key |
126 | 127 |
|
127 | 128 | def _resolve_model(self, model: str) -> str: |
@@ -202,7 +203,14 @@ def _build_kwargs(self, model: str, messages: list) -> dict[str, Any]: |
202 | 203 | if self.api_key: |
203 | 204 | kwargs["api_key"] = self.api_key |
204 | 205 | if self.api_base: |
205 | | - kwargs["api_base"] = self.api_base |
| 206 | + # For Gemini, LiteLLM constructs the URL itself. If user provides a full Google endpoint |
| 207 | + # as api_base, it might break the URL construction in LiteLLM. |
| 208 | + # We only pass api_base if it doesn't look like a standard Google endpoint versioned URL. |
| 209 | + is_google_endpoint = "generativelanguage.googleapis.com" in self.api_base and ( |
| 210 | + "/v1" in self.api_base or "/v1beta" in self.api_base |
| 211 | + ) |
| 212 | + if not is_google_endpoint: |
| 213 | + kwargs["api_base"] = self.api_base |
206 | 214 | if self._extra_headers: |
207 | 215 | kwargs["extra_headers"] = self._extra_headers |
208 | 216 |
|
|
0 commit comments