diff --git a/README.md b/README.md
index 9ef267d0..fa5b5b65 100644
--- a/README.md
+++ b/README.md
@@ -91,23 +91,17 @@ OpenViking requires the following model capabilities:
 
 #### Supported VLM Providers
 
-OpenViking supports multiple VLM providers:
-
-| Provider | Model | Get API Key |
-|----------|-------|-------------|
-| `volcengine` | doubao | [Volcengine Console](https://console.volcengine.com/ark) |
-| `openai` | gpt | [OpenAI Platform](https://platform.openai.com) |
-| `anthropic` | claude | [Anthropic Console](https://console.anthropic.com) |
-| `deepseek` | deepseek | [DeepSeek Platform](https://platform.deepseek.com) |
-| `gemini` | gemini | [Google AI Studio](https://aistudio.google.com) |
-| `moonshot` | kimi | [Moonshot Platform](https://platform.moonshot.cn) |
-| `zhipu` | glm | [Zhipu Open Platform](https://open.bigmodel.cn) |
-| `dashscope` | qwen | [DashScope Console](https://dashscope.console.aliyun.com) |
-| `minimax` | minimax | [MiniMax Platform](https://platform.minimax.io) |
-| `openrouter` | (any model) | [OpenRouter](https://openrouter.ai) |
-| `vllm` | (local model) | — |
-
-> 💡 **Tip**: OpenViking uses a **Provider Registry** for unified model access. The system automatically detects the provider based on model name keywords, so you can switch between providers seamlessly.
+OpenViking supports three VLM providers:
+
+| Provider | Description | Get API Key |
+|----------|-------------|-------------|
+| `volcengine` | Volcengine Doubao models | [Volcengine Console](https://console.volcengine.com/ark) |
+| `openai` | Official OpenAI API | [OpenAI Platform](https://platform.openai.com) |
+| `litellm` | Unified access to many third-party models (Anthropic, DeepSeek, Gemini, vLLM, Ollama, etc.) | See [LiteLLM Providers](https://docs.litellm.ai/docs/providers) |
+
+> 💡 **Tip**:
+> - `litellm` calls many different models through a single unified interface; the `model` field must follow the [LiteLLM naming convention](https://docs.litellm.ai/docs/providers)
+> - Common models (e.g. `claude-*`, `deepseek-*`, `gemini-*`, `hosted_vllm/*`, `ollama/*`) are detected automatically; other models must include the full LiteLLM prefix
 
 #### Provider-Specific Notes
 
@@ -122,7 +116,7 @@ Volcengine supports both model names and endpoint IDs. Using model names is reco
     "provider": "volcengine",
     "model": "doubao-seed-1-6-240615",
     "api_key": "your-api-key",
-    "api_base" : "https://ark.cn-beijing.volces.com/api/v3",
+    "api_base": "https://ark.cn-beijing.volces.com/api/v3"
   }
 }
 ```
@@ -135,7 +129,7 @@ You can also use endpoint IDs (found in [Volcengine ARK Console](https://console
     "provider": "volcengine",
     "model": "ep-20241220174930-xxxxx",
     "api_key": "your-api-key",
-    "api_base" : "https://ark.cn-beijing.volces.com/api/v3",
+    "api_base": "https://ark.cn-beijing.volces.com/api/v3"
   }
 }
 ```
@@ -143,17 +137,30 @@ You can also use endpoint IDs (found in [Volcengine ARK Console](https://console
-Zhipu AI (智谱) +OpenAI + +Use OpenAI's official API: + +```json +{ + "vlm": { + "provider": "openai", + "model": "gpt-4o", + "api_key": "your-api-key", + "api_base": "https://api.openai.com/v1" + } +} +``` -If you're on Zhipu's coding plan, use the coding API endpoint: +You can also use a custom OpenAI-compatible endpoint: ```json { "vlm": { - "provider": "zhipu", - "model": "glm-4-plus", + "provider": "openai", + "model": "gpt-4o", "api_key": "your-api-key", - "api_base": "https://open.bigmodel.cn/api/coding/paas/v4" + "api_base": "https://your-custom-endpoint.com/v1" } } ``` @@ -161,44 +168,52 @@ If you're on Zhipu's coding plan, use the coding API endpoint:
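+Before pointing OpenViking at a custom endpoint, you can sanity-check it with the official `openai` Python client (a sketch only; the URL and model are the same placeholders as above):
+
+```python
+from openai import OpenAI
+
+# base_url is the same value you would put in "api_base".
+client = OpenAI(api_key="your-api-key", base_url="https://your-custom-endpoint.com/v1")
+resp = client.chat.completions.create(
+    model="gpt-4o",
+    messages=[{"role": "user", "content": "ping"}],
+)
+print(resp.choices[0].message.content)
+```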
-MiniMax (中国大陆) +LiteLLM (Anthropic, DeepSeek, Gemini, vLLM, Ollama, etc.) -For MiniMax's mainland China platform (minimaxi.com), specify the API base: +LiteLLM provides unified access to various models. The `model` field should follow LiteLLM's naming convention: ```json { "vlm": { - "provider": "minimax", - "model": "abab6.5s-chat", - "api_key": "your-api-key", - "api_base": "https://api.minimaxi.com/v1" + "provider": "litellm", + "model": "claude-3-5-sonnet-20240620", + "api_key": "your-anthropic-api-key" } } ``` -
+**Common model formats:** -
-Local Models (vLLM) +| Provider | Model Example | Notes | +|----------|---------------|-------| +| Anthropic | `claude-3-5-sonnet-20240620` | Auto-detected, uses `ANTHROPIC_API_KEY` | +| DeepSeek | `deepseek-chat` | Auto-detected, uses `DEEPSEEK_API_KEY` | +| Gemini | `gemini-pro` | Auto-detected, uses `GEMINI_API_KEY` | +| OpenRouter | `openrouter/openai/gpt-4o` | Full prefix required | +| vLLM | `hosted_vllm/llama-3.1-8b` | Set `api_base` to vLLM server | +| Ollama | `ollama/llama3.1` | Set `api_base` to Ollama server | -Run OpenViking with your own local models using vLLM: +**Local Models (vLLM / Ollama):** ```bash -# Start vLLM server -vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000 + +# Start Ollama +ollama serve ``` ```json +// Ollama { "vlm": { - "provider": "vllm", - "model": "meta-llama/Llama-3.1-8B-Instruct", - "api_key": "dummy", - "api_base": "http://localhost:8000/v1" + "provider": "litellm", + "model": "ollama/llama3.1", + "api_base": "http://localhost:11434" } } ``` +For complete model support, see [LiteLLM Providers Documentation](https://docs.litellm.ai/docs/providers). +
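+A matching vLLM configuration looks similar (a sketch mirroring the Ollama example; `llama-3.1-8b` and port 8000 are placeholders for whatever your vLLM server actually serves):
+
+```json
+// vLLM (model name and port are placeholders)
+{
+  "vlm": {
+    "provider": "litellm",
+    "model": "hosted_vllm/llama-3.1-8b",
+    "api_base": "http://localhost:8000/v1"
+  }
+}
+```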
### 3. Environment Configuration @@ -234,7 +249,7 @@ Create a configuration file `~/.openviking/ov.conf`, remove the comments before } ``` -> **Note**: For embedding models, currently `volcengine` (Doubao), `openai`, and `jina` providers are supported. For VLM models, we support multiple providers including volcengine, openai, deepseek, anthropic, gemini, moonshot, zhipu, dashscope, minimax, and more. +> **Note**: For embedding models, currently `volcengine` (Doubao), `openai`, and `jina` providers are supported. For VLM models, we support three providers: `volcengine`, `openai`, and `litellm`. The `litellm` provider supports various models including Anthropic (Claude), DeepSeek, Gemini, Moonshot, Zhipu, DashScope, MiniMax, vLLM, Ollama, and more. #### Configuration Examples diff --git a/openviking/models/vlm/__init__.py b/openviking/models/vlm/__init__.py index e9d01f4f..e58b5196 100644 --- a/openviking/models/vlm/__init__.py +++ b/openviking/models/vlm/__init__.py @@ -6,14 +6,7 @@ from .backends.openai_vlm import OpenAIVLM from .backends.volcengine_vlm import VolcEngineVLM from .base import VLMBase, VLMFactory -from .registry import ( - PROVIDERS, - ProviderSpec, - find_by_model, - find_by_name, - find_gateway, - get_all_provider_names, -) +from .registry import get_all_provider_names, is_valid_provider __all__ = [ "VLMBase", @@ -21,10 +14,6 @@ "OpenAIVLM", "VolcEngineVLM", "LiteLLMVLMProvider", - "ProviderSpec", - "PROVIDERS", - "find_by_model", - "find_by_name", - "find_gateway", "get_all_provider_names", + "is_valid_provider", ] diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index b31579c5..a537f106 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """LiteLLM VLM Provider implementation with multi-provider support.""" +import logging import os os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" @@ -15,16 +16,82 @@ from litellm import acompletion, completion from ..base import VLMBase -from ..registry import find_by_model, find_gateway + +logger = logging.getLogger(__name__) + +PROVIDER_CONFIGS: Dict[str, Dict[str, Any]] = { + "openrouter": { + "keywords": ("openrouter",), + "env_key": "OPENROUTER_API_KEY", + "litellm_prefix": "openrouter", + }, + "hosted_vllm": { + "keywords": ("hosted_vllm",), + "env_key": "HOSTED_VLLM_API_KEY", + "litellm_prefix": "hosted_vllm", + }, + "ollama": { + "keywords": ("ollama",), + "env_key": "OLLAMA_API_KEY", + "litellm_prefix": "ollama", + }, + "anthropic": { + "keywords": ("claude", "anthropic"), + "env_key": "ANTHROPIC_API_KEY", + "litellm_prefix": "anthropic", + }, + "deepseek": { + "keywords": ("deepseek",), + "env_key": "DEEPSEEK_API_KEY", + "litellm_prefix": "deepseek", + }, + "gemini": { + "keywords": ("gemini",), + "env_key": "GEMINI_API_KEY", + "litellm_prefix": "gemini", + }, + "openai": { + "keywords": ("gpt", "o1", "o3", "o4"), + "env_key": "OPENAI_API_KEY", + "litellm_prefix": "", + }, + "moonshot": { + "keywords": ("moonshot", "kimi"), + "env_key": "MOONSHOT_API_KEY", + "litellm_prefix": "moonshot", + }, + "zhipu": { + "keywords": ("glm", "zhipu"), + "env_key": "ZHIPUAI_API_KEY", + "litellm_prefix": "zhipu", + }, + "dashscope": { + "keywords": ("qwen", "dashscope"), + "env_key": "DASHSCOPE_API_KEY", + "litellm_prefix": "dashscope", + }, + "minimax": { + "keywords": ("minimax",), + "env_key": "MINIMAX_API_KEY", + "litellm_prefix": "minimax", + }, +} + + +def 
detect_provider_by_model(model: str) -> str | None: + """Detect provider by model name.""" + model_lower = model.lower() + for provider, config in PROVIDER_CONFIGS.items(): + if any(kw in model_lower for kw in config["keywords"]): + return provider + return None class LiteLLMVLMProvider(VLMBase): """ Multi-provider VLM implementation based on LiteLLM. - Supports OpenRouter, Anthropic, OpenAI, Gemini, DeepSeek, VolcEngine and many other providers - through a unified interface. Provider-specific logic is driven by the registry - (see providers/registry.py) — no if-elif chains needed here. + Supports various providers through LiteLLM's unified interface. """ def __init__(self, config: Dict[str, Any]): @@ -33,11 +100,10 @@ def __init__(self, config: Dict[str, Any]): self._provider_name = config.get("provider") self._extra_headers = config.get("extra_headers") or {} self._thinking = config.get("thinking", False) - - self._gateway = find_gateway(self._provider_name, self.api_key, self.api_base) + self._detected_provider: str | None = None if self.api_key: - self._setup_env(self.api_key, self.api_base, self.model) + self._setup_env(self.api_key, self.model) if self.api_base: litellm.api_base = self.api_base @@ -45,66 +111,63 @@ def __init__(self, config: Dict[str, Any]): litellm.suppress_debug_info = True litellm.drop_params = True - def _setup_env(self, api_key: str, api_base: str | None, model: str | None) -> None: + def _setup_env(self, api_key: str, model: str | None) -> None: """Set environment variables based on detected provider.""" - spec = self._gateway or find_by_model(model or "") - if not spec: - return - - if self._gateway: - os.environ[spec.env_key] = api_key + provider = self._provider_name + if not provider and model: + provider = detect_provider_by_model(model) + + if provider and provider in PROVIDER_CONFIGS: + env_key = PROVIDER_CONFIGS[provider]["env_key"] + os.environ[env_key] = api_key + self._detected_provider = provider else: - os.environ.setdefault(spec.env_key, api_key) - - effective_base = api_base or spec.default_api_base - for env_name, env_val in spec.env_extras: - resolved = env_val.replace("{api_key}", api_key) - resolved = resolved.replace("{api_base}", effective_base or "") - os.environ.setdefault(env_name, resolved) + os.environ["OPENAI_API_KEY"] = api_key def _resolve_model(self, model: str) -> str: - """Resolve model name by applying provider/gateway prefixes.""" - if self._gateway: - if self._gateway.strip_model_prefix: - model = model.split("/")[-1] - prefix = self._gateway.litellm_prefix + """Resolve model name by applying provider prefixes.""" + provider = self._detected_provider or detect_provider_by_model(model) + + if provider and provider in PROVIDER_CONFIGS: + prefix = PROVIDER_CONFIGS[provider]["litellm_prefix"] if prefix and not model.startswith(f"{prefix}/"): - model = f"{prefix}/{model}" + return f"{prefix}/{model}" return model - if self._provider_name == "openai" and self.api_base: - from openviking.models.vlm.registry import find_by_name - openai_spec = find_by_name("openai") - is_openai_official = "api.openai.com" in self.api_base - if openai_spec and not is_openai_official and not model.startswith("openai/"): - return f"openai/{model}" - - spec = find_by_model(model) - if spec and spec.litellm_prefix: - if not any(model.startswith(s) for s in spec.skip_prefixes): - model = f"{spec.litellm_prefix}/{model}" + if self.api_base and not model.startswith(("openai/", "hosted_vllm/", "ollama/")): + return f"openai/{model}" + return model - def 
_apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None: - """Apply model-specific parameter overrides from the registry.""" - model_lower = model.lower() - spec = find_by_model(model) - if spec: - for pattern, overrides in spec.model_overrides: - if pattern in model_lower: - kwargs.update(overrides) - return + def _detect_image_format(self, data: bytes) -> str: + """Detect image format from magic bytes. + + Supported formats: PNG, JPEG, GIF, WebP + """ + if len(data) < 8: + logger.warning(f"[LiteLLMVLM] Image data too small: {len(data)} bytes") + return "image/png" + + if data[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + elif data[:2] == b"\xff\xd8": + return "image/jpeg" + elif data[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + return "image/webp" - if self._provider_name == "volcengine": - kwargs["thinking"] = {"type": "enabled" if self._thinking else "disabled"} + logger.warning(f"[LiteLLMVLM] Unknown image format, magic bytes: {data[:8].hex()}") + return "image/png" def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: """Prepare image data for vision completion.""" if isinstance(image, bytes): b64 = base64.b64encode(image).decode("utf-8") + mime_type = self._detect_image_format(image) return { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{b64}"}, + "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } elif isinstance(image, Path) or ( isinstance(image, str) and not image.startswith(("http://", "https://")) @@ -119,7 +182,8 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: ".webp": "image/webp", }.get(suffix, "image/png") with open(path, "rb") as f: - b64 = base64.b64encode(f.read()).decode("utf-8") + data = f.read() + b64 = base64.b64encode(data).decode("utf-8") return { "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, @@ -135,8 +199,6 @@ def _build_kwargs(self, model: str, messages: list) -> dict[str, Any]: "temperature": self.temperature, } - self._apply_model_overrides(model, kwargs) - if self.api_key: kwargs["api_key"] = self.api_key if self.api_base: @@ -150,11 +212,7 @@ def get_completion(self, prompt: str, thinking: bool = False) -> str: """Get text completion synchronously.""" model = self._resolve_model(self.model or "gpt-4o-mini") messages = [{"role": "user", "content": prompt}] - original_thinking = self._thinking - if thinking: - self._thinking = thinking kwargs = self._build_kwargs(model, messages) - self._thinking = original_thinking response = completion(**kwargs) self._update_token_usage_from_response(response) @@ -166,11 +224,7 @@ async def get_completion_async( """Get text completion asynchronously.""" model = self._resolve_model(self.model or "gpt-4o-mini") messages = [{"role": "user", "content": prompt}] - original_thinking = self._thinking - if thinking: - self._thinking = thinking kwargs = self._build_kwargs(model, messages) - self._thinking = original_thinking last_error = None for attempt in range(max_retries + 1): @@ -202,11 +256,7 @@ def get_vision_completion( content.append({"type": "text", "text": prompt}) messages = [{"role": "user", "content": content}] - original_thinking = self._thinking - if thinking: - self._thinking = thinking kwargs = self._build_kwargs(model, messages) - self._thinking = original_thinking response = completion(**kwargs) self._update_token_usage_from_response(response) @@ -227,11 +277,7 @@ async def 
get_vision_completion_async( content.append({"type": "text", "text": prompt}) messages = [{"role": "user", "content": content}] - original_thinking = self._thinking - if thinking: - self._thinking = thinking kwargs = self._build_kwargs(model, messages) - self._thinking = original_thinking response = await acompletion(**kwargs) self._update_token_usage_from_response(response) diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index d6f6effa..de59b8dc 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -4,11 +4,14 @@ import asyncio import base64 +import logging from pathlib import Path from typing import Any, Dict, List, Union from ..base import VLMBase +logger = logging.getLogger(__name__) + class OpenAIVLM(VLMBase): """OpenAI VLM backend""" @@ -17,7 +20,6 @@ def __init__(self, config: Dict[str, Any]): super().__init__(config) self._sync_client = None self._async_client = None - # Ensure provider type is correct self.provider = "openai" def get_client(self): @@ -92,13 +94,35 @@ async def get_completion_async( else: raise RuntimeError("Unknown error in async completion") + def _detect_image_format(self, data: bytes) -> str: + """Detect image format from magic bytes. + + Supported formats: PNG, JPEG, GIF, WebP + """ + if len(data) < 8: + logger.warning(f"[OpenAIVLM] Image data too small: {len(data)} bytes") + return "image/png" + + if data[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + elif data[:2] == b"\xff\xd8": + return "image/jpeg" + elif data[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + return "image/webp" + + logger.warning(f"[OpenAIVLM] Unknown image format, magic bytes: {data[:8].hex()}") + return "image/png" + def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: - """Prepare image data""" + """Prepare image data for vision completion.""" if isinstance(image, bytes): b64 = base64.b64encode(image).decode("utf-8") + mime_type = self._detect_image_format(image) return { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{b64}"}, + "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } elif isinstance(image, Path) or ( isinstance(image, str) and not image.startswith(("http://", "https://")) @@ -113,7 +137,8 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: ".webp": "image/webp", }.get(suffix, "image/png") with open(path, "rb") as f: - b64 = base64.b64encode(f.read()).decode("utf-8") + data = f.read() + b64 = base64.b64encode(data).decode("utf-8") return { "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index e4c4d290..3b80bdc6 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -4,11 +4,14 @@ import asyncio import base64 +import logging from pathlib import Path from typing import Any, Dict, List, Union from .openai_vlm import OpenAIVLM +logger = logging.getLogger(__name__) + class VolcEngineVLM(OpenAIVLM): """VolcEngine VLM backend""" @@ -98,13 +101,83 @@ async def get_completion_async( else: raise RuntimeError("Unknown error in async completion") + def _detect_image_format(self, data: bytes) -> str: + """Detect image format from magic bytes. 
+
+        Returns the MIME type, or raises ValueError for unsupported formats like SVG.
+
+        Supported formats per VolcEngine docs:
+        https://www.volcengine.com/docs/82379/1362931
+        - JPEG, PNG, GIF, WEBP, BMP, TIFF, ICO, DIB, ICNS, SGI, JPEG2000, HEIC, HEIF
+        """
+        if len(data) < 12:
+            logger.warning(f"[VolcEngineVLM] Image data too small: {len(data)} bytes")
+            return "image/png"
+
+        # PNG: 89 50 4E 47 0D 0A 1A 0A
+        if data[:8] == b"\x89PNG\r\n\x1a\n":
+            return "image/png"
+        # JPEG: FF D8
+        elif data[:2] == b"\xff\xd8":
+            return "image/jpeg"
+        # GIF: GIF87a or GIF89a
+        elif data[:6] in (b"GIF87a", b"GIF89a"):
+            return "image/gif"
+        # WEBP: RIFF....WEBP
+        elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP":
+            return "image/webp"
+        # BMP: BM
+        elif data[:2] == b"BM":
+            return "image/bmp"
+        # TIFF (little-endian): 49 49 2A 00
+        # TIFF (big-endian): 4D 4D 00 2A
+        elif data[:4] == b"II*\x00" or data[:4] == b"MM\x00*":
+            return "image/tiff"
+        # ICO: 00 00 01 00
+        elif data[:4] == b"\x00\x00\x01\x00":
+            return "image/ico"
+        # ICNS: 69 63 6E 73 ("icns")
+        elif data[:4] == b"icns":
+            return "image/icns"
+        # SGI: 01 DA
+        elif data[:2] == b"\x01\xda":
+            return "image/sgi"
+        # JPEG2000: 00 00 00 0C 6A 50 20 20 (JP2 signature)
+        elif data[:8] == b"\x00\x00\x00\x0cjP  " or data[:4] == b"\xff\x4f\xff\x51":
+            return "image/jp2"
+        # HEIC/HEIF: ftyp box with heic/heif brand
+        # 00 00 00 XX 66 74 79 70 68 65 69 63 (heic)
+        # 00 00 00 XX 66 74 79 70 68 65 69 66 (heif)
+        elif len(data) >= 12 and data[4:8] == b"ftyp":
+            brand = data[8:12]
+            if brand == b"heic":
+                return "image/heic"
+            elif brand == b"heif":
+                return "image/heif"
+            elif brand[:3] == b"mif":
+                return "image/heif"
+        # SVG (not supported)
+        elif data[:4] == b"<svg" or data[:5] == b"<?xml":
+            raise ValueError("SVG format is not supported by VolcEngine VLM API")
+
+        logger.warning(f"[VolcEngineVLM] Unknown image format, magic bytes: {data[:8].hex()}")
+        return "image/png"
+
     def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]:
         """Prepare image data"""
         if isinstance(image, bytes):
             b64 = base64.b64encode(image).decode("utf-8")
+            mime_type = self._detect_image_format(image)
+            logger.info(
+                f"[VolcEngineVLM] Preparing image from bytes, size={len(image)}, detected mime={mime_type}"
+            )
             return {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{b64}"},
+                "image_url": {"url": f"data:{mime_type};base64,{b64}"},
             }
         elif isinstance(image, Path) or (
             isinstance(image, str) and not image.startswith(("http://", "https://"))
@@ -117,6 +190,21 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]:
             ".jpeg": "image/jpeg",
             ".gif": "image/gif",
             ".webp": "image/webp",
+            ".bmp": "image/bmp",
+            ".dib": "image/bmp",
+            ".tiff": "image/tiff",
+            ".tif": "image/tiff",
+            ".ico": "image/ico",
+            ".icns": "image/icns",
+            ".sgi": "image/sgi",
+            ".j2c": "image/jp2",
+            ".j2k": "image/jp2",
+            ".jp2": "image/jp2",
+            ".jpc": "image/jp2",
+            ".jpf": "image/jp2",
+            ".jpx": "image/jp2",
+            ".heic": "image/heic",
+            ".heif": "image/heif",
         }.get(suffix, "image/png")
         with open(path, "rb") as f:
             b64 = base64.b64encode(f.read()).decode("utf-8")
diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py
index cd563f9c..4b602f78 100644
--- a/openviking/models/vlm/base.py
+++ b/openviking/models/vlm/base.py
@@ -125,23 +125,22 @@ def create(config: Dict[str, Any]) -> VLMBase:
             ValueError: If provider is not supported
             ImportError: If related dependencies are not installed
         """
-        provider = config.get("provider") or config.get("backend") or "openai"
+        provider = (config.get("provider") or config.get("backend") or "openai").lower()
 
-        use_litellm = config.get("use_litellm", True)
+        if provider == "volcengine":
+            from .backends.volcengine_vlm import VolcEngineVLM
 
-        if not use_litellm:
-            
if provider == "openai": - from .backends.openai_vlm import OpenAIVLM + return VolcEngineVLM(config) - return OpenAIVLM(config) - elif provider == "volcengine": - from .backends.volcengine_vlm import VolcEngineVLM + elif provider == "openai": + from .backends.openai_vlm import OpenAIVLM - return VolcEngineVLM(config) + return OpenAIVLM(config) - from .backends.litellm_vlm import LiteLLMVLMProvider + else: + from .backends.litellm_vlm import LiteLLMVLMProvider - return LiteLLMVLMProvider(config) + return LiteLLMVLMProvider(config) @staticmethod def get_available_providers() -> List[str]: diff --git a/openviking/models/vlm/registry.py b/openviking/models/vlm/registry.py index 815b637a..ff3ddb3b 100644 --- a/openviking/models/vlm/registry.py +++ b/openviking/models/vlm/registry.py @@ -3,229 +3,19 @@ """ Provider Registry — single source of truth for LLM provider metadata. -Adding a new provider: - 1. Add a ProviderSpec to PROVIDERS below. - 2. Use it in config with providers["newprovider"] = {"api_key": "xxx"} - Done. Env vars, prefixing, config matching, status display all derive from here. - -Order matters — it controls match priority and fallback. Gateways first. -Every entry writes out all fields so you can copy-paste as a template. +Supported providers: volcengine, openai, litellm """ from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Optional - - -@dataclass(frozen=True) -class ProviderSpec: - """VLM Provider metadata definition. - - Placeholders in env_extras values: - {api_key} - the user's API key - {api_base} - api_base from config, or this spec's default_api_base - """ - - name: str - keywords: tuple[str, ...] - env_key: str - display_name: str = "" - - litellm_prefix: str = "" - skip_prefixes: tuple[str, ...] = () - - env_extras: tuple[tuple[str, str], ...] = () - - is_gateway: bool = False - is_local: bool = False - detect_by_key_prefix: str = "" - detect_by_base_keyword: str = "" - default_api_base: str = "" - - strip_model_prefix: bool = False - - model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () - - @property - def label(self) -> str: - return self.display_name or self.name.title() - - -PROVIDERS: tuple[ProviderSpec, ...] 
= ( - ProviderSpec( - name="custom", - keywords=(), - env_key="OPENAI_API_KEY", - display_name="Custom", - litellm_prefix="openai", - skip_prefixes=("openai/",), - is_gateway=True, - strip_model_prefix=True, - ), - - ProviderSpec( - name="openrouter", - keywords=("openrouter",), - env_key="OPENROUTER_API_KEY", - display_name="OpenRouter", - litellm_prefix="openrouter", - is_gateway=True, - detect_by_key_prefix="sk-or-", - detect_by_base_keyword="openrouter", - default_api_base="https://openrouter.ai/api/v1", - ), - - ProviderSpec( - name="volcengine", - keywords=("doubao", "volcengine", "ep-"), - env_key="VOLCENGINE_API_KEY", - display_name="VolcEngine", - litellm_prefix="volcengine", - skip_prefixes=("volcengine/",), - default_api_base="https://ark.cn-beijing.volces.com/api/v3", - ), - - ProviderSpec( - name="openai", - keywords=("openai", "gpt"), - env_key="OPENAI_API_KEY", - display_name="OpenAI", - litellm_prefix="", - ), - - ProviderSpec( - name="anthropic", - keywords=("anthropic", "claude"), - env_key="ANTHROPIC_API_KEY", - display_name="Anthropic", - litellm_prefix="", - ), - - ProviderSpec( - name="deepseek", - keywords=("deepseek",), - env_key="DEEPSEEK_API_KEY", - display_name="DeepSeek", - litellm_prefix="deepseek", - skip_prefixes=("deepseek/",), - ), - - ProviderSpec( - name="gemini", - keywords=("gemini",), - env_key="GEMINI_API_KEY", - display_name="Gemini", - litellm_prefix="gemini", - skip_prefixes=("gemini/",), - ), - - ProviderSpec( - name="moonshot", - keywords=("moonshot", "kimi"), - env_key="MOONSHOT_API_KEY", - display_name="Moonshot", - litellm_prefix="moonshot", - skip_prefixes=("moonshot/",), - env_extras=( - ("MOONSHOT_API_BASE", "{api_base}"), - ), - default_api_base="https://api.moonshot.ai/v1", - model_overrides=( - ("kimi-k2.5", {"temperature": 1.0}), - ), - ), - - ProviderSpec( - name="zhipu", - keywords=("zhipu", "glm", "zai"), - env_key="ZAI_API_KEY", - display_name="Zhipu AI", - litellm_prefix="zai", - skip_prefixes=("zhipu/", "zai/"), - env_extras=( - ("ZHIPUAI_API_KEY", "{api_key}"), - ), - ), - - ProviderSpec( - name="dashscope", - keywords=("qwen", "dashscope"), - env_key="DASHSCOPE_API_KEY", - display_name="DashScope", - litellm_prefix="dashscope", - skip_prefixes=("dashscope/",), - ), - - ProviderSpec( - name="minimax", - keywords=("minimax",), - env_key="MINIMAX_API_KEY", - display_name="MiniMax", - litellm_prefix="minimax", - skip_prefixes=("minimax/",), - default_api_base="https://api.minimax.io/v1", - ), - - ProviderSpec( - name="vllm", - keywords=("vllm",), - env_key="HOSTED_VLLM_API_KEY", - display_name="vLLM/Local", - litellm_prefix="hosted_vllm", - is_local=True, - ), -) - - -def find_by_model(model: str) -> ProviderSpec | None: - """Match a standard provider by model-name keyword (case-insensitive). - Skips gateways/local — those are matched by api_key/api_base instead.""" - model_lower = model.lower() - for spec in PROVIDERS: - if spec.is_gateway or spec.is_local: - continue - if any(kw in model_lower for kw in spec.keywords): - return spec - return None - - -def find_gateway( - provider_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, -) -> ProviderSpec | None: - """Detect gateway/local provider. - - Priority: - 1. provider_name — if it maps to a gateway/local spec, use it directly. - 2. api_key prefix — e.g. "sk-or-" → OpenRouter. - 3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix. 
-    """
-    if provider_name:
-        spec = find_by_name(provider_name)
-        if spec and (spec.is_gateway or spec.is_local):
-            return spec
-
-    for spec in PROVIDERS:
-        if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
-            return spec
-
-    for spec in PROVIDERS:
-        if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
-            return spec
-
-    return None
-
-
-def find_by_name(name: str) -> ProviderSpec | None:
-    """Find a provider spec by config field name, e.g. "dashscope"."""
-    for spec in PROVIDERS:
-        if spec.name == name:
-            return spec
-    return None
+VALID_PROVIDERS: tuple[str, ...] = ("volcengine", "openai", "litellm")
 
 
 def get_all_provider_names() -> list[str]:
     """Get all provider names list."""
-    return [spec.name for spec in PROVIDERS]
+    return list(VALID_PROVIDERS)
+
+
+def is_valid_provider(name: str) -> bool:
+    """Check if provider name is valid."""
+    return name.lower() in VALID_PROVIDERS
diff --git a/openviking/parse/parsers/media/utils.py b/openviking/parse/parsers/media/utils.py
index 1e8ad30d..4615c285 100644
--- a/openviking/parse/parsers/media/utils.py
+++ b/openviking/parse/parsers/media/utils.py
@@ -17,6 +17,38 @@
 logger = get_logger(__name__)
 
 
+def _is_svg(data: bytes) -> bool:
+    """Check if the data is an SVG file."""
+    return data[:4] == b"<svg" or data[:5] == b"<?xml"
+
+
+# def _convert_svg_to_png(svg_data: bytes) -> Optional[bytes]:
+#     """Convert SVG to PNG using cairosvg or wand."""
+#     try:
+#         import cairosvg
+#         return cairosvg.svg2png(bytestring=svg_data)
+#     except ImportError:
+#         pass
+#     except OSError:
+#         pass  # libcairo not installed
+#
+#     try:
+#         from wand.image import Image as WandImage
+#         with WandImage(blob=svg_data, format='svg') as img:
+#             img.format = 'png'
+#             return img.make_blob()
+#     except ImportError:
+#         pass
+#
+#     return None
+
+
 def get_media_type(source_path: Optional[str], source_format: Optional[str]) -> Optional[str]:
     """
     Determine media type from source path or format.
@@ -85,6 +117,14 @@ async def generate_image_summary(
     if not isinstance(image_bytes, bytes):
         raise ValueError(f"Expected bytes for image file, got {type(image_bytes)}")
 
+    # Check for unsupported formats (SVG, etc.) by detecting magic bytes
+    # SVG format is not supported by VolcEngine VLM API, skip VLM analysis
+    if _is_svg(image_bytes):
+        logger.info(
+            f"[MediaUtils.generate_image_summary] SVG format detected, skipping VLM analysis: {image_uri}"
+        )
+        return {"name": file_name, "summary": "SVG image (format not supported by VLM)"}
+
     logger.info(
         f"[MediaUtils.generate_image_summary] Generating summary for image: {image_uri}"
     )
@@ -107,6 +147,13 @@ async def generate_image_summary(
         )
         return {"name": file_name, "summary": response.strip()}
 
+    except ValueError as e:
+        if "SVG format" in str(e) or "not supported" in str(e):
+            logger.warning(
+                f"[MediaUtils.generate_image_summary] Unsupported image format for {image_uri}: {e}"
+            )
+            return {"name": file_name, "summary": f"Unsupported image format: {str(e)}"}
+        raise
     except Exception as e:
         logger.error(
             f"[MediaUtils.generate_image_summary] Failed to generate image summary: {e}",
diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py
index 411c7d76..e3ce672e 100644
--- a/openviking_cli/utils/config/vlm_config.py
+++ b/openviking_cli/utils/config/vlm_config.py
@@ -21,7 +21,7 @@ class VLMConfig(BaseModel):
 
     providers: Dict[str, Dict[str, Any]] = Field(
        default_factory=dict,
-        description="Multi-provider configuration, e.g. 
{'deepseek': {'api_key': 'xxx', 'api_base': 'xxx'}}", + description="Multi-provider configuration, e.g. {'openai': {'api_key': 'xxx', 'api_base': 'xxx'}}", ) default_provider: Optional[str] = Field(default=None, description="Default provider name") @@ -85,55 +85,31 @@ def _get_effective_api_key(self) -> str | None: return None def _match_provider(self, model: str | None = None) -> tuple[Dict[str, Any] | None, str | None]: - """Match provider config by model name. + """Match provider config. Returns: (provider_config_dict, provider_name) """ - from openviking.models.vlm.registry import PROVIDERS - - model_lower = (model or self.model or "").lower() - if self.provider: p = self.providers.get(self.provider) if p and p.get("api_key"): return p, self.provider - for spec in PROVIDERS: - p = self.providers.get(spec.name) - if p and any(kw in model_lower for kw in spec.keywords) and p.get("api_key"): - return p, spec.name - - for spec in PROVIDERS: - if spec.is_gateway: - p = self.providers.get(spec.name) - if p and p.get("api_key"): - return p, spec.name - - for spec in PROVIDERS: - if not spec.is_gateway: - p = self.providers.get(spec.name) - if p and p.get("api_key"): - return p, spec.name + for name, config in self.providers.items(): + if config.get("api_key"): + return config, name return None, None def get_provider_config( self, model: str | None = None - ) -> tuple[Dict[str, Any] | None, str | None, "Any | None"]: - """Get provider config and spec. + ) -> tuple[Dict[str, Any] | None, str | None]: + """Get provider config. Returns: - (provider_config_dict, provider_name, ProviderSpec) + (provider_config_dict, provider_name) """ - from openviking.models.vlm.registry import find_by_name, find_gateway - - config, name = self._match_provider(model) - if config and name: - spec = find_by_name(name) - gateway = find_gateway(name, config.get("api_key"), config.get("api_base")) - return config, name, gateway or spec - return None, None, None + return self._match_provider(model) def get_vlm_instance(self) -> Any: """Get VLM instance.""" @@ -146,7 +122,7 @@ def get_vlm_instance(self) -> Any: def _build_vlm_config_dict(self) -> Dict[str, Any]: """Build VLM instance config dict.""" - config, name, spec = self.get_provider_config() + config, name = self.get_provider_config() result = { "model": self.model, @@ -161,9 +137,6 @@ def _build_vlm_config_dict(self) -> Dict[str, Any]: result["api_base"] = config.get("api_base") result["extra_headers"] = config.get("extra_headers") - if spec and not result.get("api_base") and spec.default_api_base: - result["api_base"] = spec.default_api_base - return result def get_completion(self, prompt: str, thinking: bool = False) -> str:
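With the registry reduced to three provider names, dispatch is a single decision in `VLMFactory.create`: `volcengine` and `openai` get their dedicated backends, and everything else falls through to LiteLLM. A minimal usage sketch under that assumption (the Ollama model and endpoint below are placeholders, not defaults shipped by the project):

```python
from openviking.models.vlm import VLMFactory

# Any provider other than "volcengine"/"openai" (here "litellm") falls
# through to LiteLLMVLMProvider; "ollama/llama3.1" already carries its
# LiteLLM prefix, so _resolve_model keeps it unchanged.
vlm = VLMFactory.create(
    {
        "provider": "litellm",
        "model": "ollama/llama3.1",
        "api_base": "http://localhost:11434",
    }
)
print(vlm.get_completion("Describe this project in one sentence."))
```

Collapsing all third-party routing into LiteLLM is what lets `registry.py` shrink from a dataclass table to the three-name tuple above.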