From dd2be9a1ee46251d00822df87100e35eec672b05 Mon Sep 17 00:00:00 2001 From: Lyle Hopkins Date: Tue, 24 Mar 2026 20:39:42 +0000 Subject: [PATCH 1/3] feat: add Ollama native API support to inference proxy Add pattern detection, provider profile, and validation probe for Ollama's native /api/chat, /api/tags, and /api/show endpoints. Proxy changes (l7/inference.rs): - POST /api/chat -> ollama_chat protocol - GET /api/tags -> ollama_model_discovery protocol - POST /api/show -> ollama_model_discovery protocol Provider profile (openshell-core/inference.rs): - New 'ollama' provider type with default endpoint http://host.openshell.internal:11434 - Supports ollama_chat, ollama_model_discovery, and OpenAI-compatible protocols (openai_chat_completions, openai_completions, model_discovery) - Credential lookup via OLLAMA_API_KEY, base URL via OLLAMA_BASE_URL Validation (backend.rs): - Ollama validation probe sends minimal /api/chat request with stream:false Tests: 4 new tests for pattern detection (ollama chat, tags, show, and GET /api/chat rejection). 
Signed-off-by: Lyle Hopkins --- crates/openshell-core/src/inference.rs | 19 ++++++++ crates/openshell-router/src/backend.rs | 14 ++++++ crates/openshell-sandbox/src/l7/inference.rs | 49 ++++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/crates/openshell-core/src/inference.rs b/crates/openshell-core/src/inference.rs index a06c427f..e8611e21 100644 --- a/crates/openshell-core/src/inference.rs +++ b/crates/openshell-core/src/inference.rs @@ -56,6 +56,14 @@ const OPENAI_PROTOCOLS: &[&str] = &[ const ANTHROPIC_PROTOCOLS: &[&str] = &["anthropic_messages", "model_discovery"]; +const OLLAMA_PROTOCOLS: &[&str] = &[ + "ollama_chat", + "ollama_model_discovery", + "openai_chat_completions", + "openai_completions", + "model_discovery", +]; + static OPENAI_PROFILE: InferenceProviderProfile = InferenceProviderProfile { provider_type: "openai", default_base_url: "https://api.openai.com/v1", @@ -86,6 +94,16 @@ static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile { default_headers: &[], }; +static OLLAMA_PROFILE: InferenceProviderProfile = InferenceProviderProfile { + provider_type: "ollama", + default_base_url: "http://host.openshell.internal:11434", + protocols: OLLAMA_PROTOCOLS, + credential_key_names: &["OLLAMA_API_KEY"], + base_url_config_keys: &["OLLAMA_BASE_URL", "OLLAMA_HOST"], + auth: AuthHeader::Bearer, + default_headers: &[], +}; + /// Look up the inference provider profile for a given provider type. 
/// /// Returns `None` for provider types that don't support inference routing @@ -95,6 +113,7 @@ pub fn profile_for(provider_type: &str) -> Option<&'static InferenceProviderProf "openai" => Some(&OPENAI_PROFILE), "anthropic" => Some(&ANTHROPIC_PROFILE), "nvidia" => Some(&NVIDIA_PROFILE), + "ollama" => Some(&OLLAMA_PROFILE), _ => None, } } diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index d82ea082..6eb01a43 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -223,6 +223,20 @@ fn validation_probe(route: &ResolvedRoute) -> Result Vec { protocol: "anthropic_messages".to_string(), kind: "messages".to_string(), }, + InferenceApiPattern { + method: "POST".to_string(), + path_glob: "/api/chat".to_string(), + protocol: "ollama_chat".to_string(), + kind: "ollama_chat".to_string(), + }, + InferenceApiPattern { + method: "GET".to_string(), + path_glob: "/api/tags".to_string(), + protocol: "ollama_model_discovery".to_string(), + kind: "ollama_tags".to_string(), + }, + InferenceApiPattern { + method: "POST".to_string(), + path_glob: "/api/show".to_string(), + protocol: "ollama_model_discovery".to_string(), + kind: "ollama_show".to_string(), + }, InferenceApiPattern { method: "GET".to_string(), path_glob: "/v1/models".to_string(), @@ -372,6 +390,37 @@ mod tests { assert!(result.is_none()); } + #[test] + fn detect_ollama_chat() { + let patterns = default_patterns(); + let result = detect_inference_pattern("POST", "/api/chat", &patterns); + assert!(result.is_some()); + assert_eq!(result.unwrap().protocol, "ollama_chat"); + } + + #[test] + fn detect_ollama_tags() { + let patterns = default_patterns(); + let result = detect_inference_pattern("GET", "/api/tags", &patterns); + assert!(result.is_some()); + assert_eq!(result.unwrap().protocol, "ollama_model_discovery"); + } + + #[test] + fn detect_ollama_show() { + let patterns = default_patterns(); + let result = 
detect_inference_pattern("POST", "/api/show", &patterns); + assert!(result.is_some()); + assert_eq!(result.unwrap().protocol, "ollama_model_discovery"); + } + + #[test] + fn no_match_ollama_chat_get() { + let patterns = default_patterns(); + let result = detect_inference_pattern("GET", "/api/chat", &patterns); + assert!(result.is_none()); + } + #[test] fn detect_get_models() { let patterns = default_patterns(); From 1d605b789a1df2aa853b4c6a35f14bb192063e7a Mon Sep 17 00:00:00 2001 From: Lyle Hopkins Date: Tue, 24 Mar 2026 21:03:58 +0000 Subject: [PATCH 2/3] feat: implement multi-route inference proxy support - Proto: add InferenceModelEntry message with alias/provider/model fields; add repeated models field to ClusterInferenceConfig, Set/Get request/response - Server: add upsert_multi_model_route() for storing multiple model entries under a single route slot; update resolve_route_by_name() to expand multi-model configs into per-alias ResolvedRoute entries - Router: add select_route() with alias-first, protocol-fallback strategy; add model_hint parameter to proxy_with_candidates() variants - Sandbox proxy: extract model field from JSON body as routing hint - Tests: 7 new tests covering select_route, multi-model resolution, and bundle expansion; all 291 existing tests continue to pass Signed-off-by: Lyle Hopkins --- crates/openshell-router/src/lib.rs | 91 ++++- .../tests/backend_integration.rs | 10 + crates/openshell-sandbox/src/proxy.rs | 9 +- crates/openshell-server/src/inference.rs | 335 +++++++++++++++++- proto/inference.proto | 29 +- 5 files changed, 445 insertions(+), 29 deletions(-) diff --git a/crates/openshell-router/src/lib.rs b/crates/openshell-router/src/lib.rs index a5712d9a..0fcb470e 100644 --- a/crates/openshell-router/src/lib.rs +++ b/crates/openshell-router/src/lib.rs @@ -36,6 +36,28 @@ pub struct Router { client: reqwest::Client, } +/// Select a route from `candidates` using alias-first, protocol-fallback strategy. +/// +/// 1. 
If `model_hint` is provided, find a candidate whose `name` matches the hint +/// **and** whose protocols include `protocol`. +/// 2. Otherwise, return the first candidate whose protocols contain `protocol`. +fn select_route<'a>( + candidates: &'a [ResolvedRoute], + protocol: &str, + model_hint: Option<&str>, +) -> Option<&'a ResolvedRoute> { + if let Some(hint) = model_hint { + if let Some(r) = candidates.iter().find(|r| { + r.name == hint && r.protocols.iter().any(|p| p == protocol) + }) { + return Some(r); + } + } + candidates + .iter() + .find(|r| r.protocols.iter().any(|p| p == protocol)) +} + impl Router { pub fn new() -> Result { let client = reqwest::Client::builder() @@ -57,8 +79,10 @@ impl Router { /// Proxy a raw HTTP request to the first compatible route from `candidates`. /// - /// Filters candidates by `source_protocol` compatibility (exact match against - /// one of the route's `protocols`), then forwards to the first match. + /// When `model_hint` is provided, the router first looks for a candidate whose + /// `name` (alias) matches the hint. If no alias matches, it falls back to + /// protocol-based selection (first candidate whose `protocols` list contains + /// `source_protocol`). 
pub async fn proxy_with_candidates( &self, source_protocol: &str, @@ -67,11 +91,10 @@ impl Router { headers: Vec<(String, String)>, body: bytes::Bytes, candidates: &[ResolvedRoute], + model_hint: Option<&str>, ) -> Result { let normalized_source = source_protocol.trim().to_ascii_lowercase(); - let route = candidates - .iter() - .find(|r| r.protocols.iter().any(|p| p == &normalized_source)) + let route = select_route(candidates, &normalized_source, model_hint) .ok_or_else(|| RouterError::NoCompatibleRoute(source_protocol.to_string()))?; info!( @@ -111,11 +134,10 @@ impl Router { headers: Vec<(String, String)>, body: bytes::Bytes, candidates: &[ResolvedRoute], + model_hint: Option<&str>, ) -> Result { let normalized_source = source_protocol.trim().to_ascii_lowercase(); - let route = candidates - .iter() - .find(|r| r.protocols.iter().any(|p| p == &normalized_source)) + let route = select_route(candidates, &normalized_source, model_hint) .ok_or_else(|| RouterError::NoCompatibleRoute(source_protocol.to_string()))?; info!( @@ -187,4 +209,57 @@ mod tests { let err = Router::from_config(&config).unwrap_err(); assert!(matches!(err, RouterError::Internal(_))); } + + fn make_route(name: &str, protocols: Vec<&str>) -> ResolvedRoute { + ResolvedRoute { + name: name.to_string(), + endpoint: "http://localhost".to_string(), + model: format!("{name}-model"), + api_key: "key".to_string(), + protocols: protocols.into_iter().map(String::from).collect(), + auth: config::AuthHeader::Bearer, + default_headers: Vec::new(), + } + } + + #[test] + fn select_route_protocol_fallback_when_no_hint() { + let routes = vec![ + make_route("ollama-local", vec!["openai_chat_completions"]), + make_route("anthropic-prod", vec!["anthropic_messages"]), + ]; + let r = select_route(&routes, "anthropic_messages", None).unwrap(); + assert_eq!(r.name, "anthropic-prod"); + } + + #[test] + fn select_route_alias_match_takes_priority() { + let routes = vec![ + make_route("ollama-local", 
vec!["openai_chat_completions"]), + make_route("openai-prod", vec!["openai_chat_completions", "openai_responses"]), + ]; + // Both support openai_chat_completions, but hint selects the second one. + let r = select_route(&routes, "openai_chat_completions", Some("openai-prod")).unwrap(); + assert_eq!(r.name, "openai-prod"); + } + + #[test] + fn select_route_alias_must_also_match_protocol() { + let routes = vec![ + make_route("ollama-local", vec!["openai_chat_completions"]), + make_route("anthropic-prod", vec!["anthropic_messages"]), + ]; + // Hint says "anthropic-prod" but protocol is openai_chat_completions — can't use it. + // Falls back to protocol match → ollama-local. + let r = select_route(&routes, "openai_chat_completions", Some("anthropic-prod")).unwrap(); + assert_eq!(r.name, "ollama-local"); + } + + #[test] + fn select_route_no_match_returns_none() { + let routes = vec![ + make_route("ollama-local", vec!["openai_chat_completions"]), + ]; + assert!(select_route(&routes, "anthropic_messages", None).is_none()); + } } diff --git a/crates/openshell-router/tests/backend_integration.rs b/crates/openshell-router/tests/backend_integration.rs index 4861bd6d..28b12a21 100644 --- a/crates/openshell-router/tests/backend_integration.rs +++ b/crates/openshell-router/tests/backend_integration.rs @@ -66,6 +66,7 @@ async fn proxy_forwards_request_to_backend() { vec![("content-type".to_string(), "application/json".to_string())], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); @@ -98,6 +99,7 @@ async fn proxy_upstream_401_returns_error() { vec![], bytes::Bytes::new(), &candidates, + None, ) .await .unwrap(); @@ -127,6 +129,7 @@ async fn proxy_no_compatible_route_returns_error() { vec![], bytes::Bytes::new(), &candidates, + None, ) .await .unwrap_err(); @@ -160,6 +163,7 @@ async fn proxy_strips_auth_header() { vec![("authorization".to_string(), "Bearer client-key".to_string())], bytes::Bytes::new(), &candidates, + None, ) .await .unwrap(); @@ -194,6 +198,7 @@ 
async fn proxy_mock_route_returns_canned_response() { vec![("content-type".to_string(), "application/json".to_string())], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); @@ -239,6 +244,7 @@ async fn proxy_overrides_model_in_request_body() { vec![("content-type".to_string(), "application/json".to_string())], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); @@ -277,6 +283,7 @@ async fn proxy_inserts_model_when_absent_from_body() { vec![("content-type".to_string(), "application/json".to_string())], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); @@ -332,6 +339,7 @@ async fn proxy_uses_x_api_key_for_anthropic_route() { ], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); @@ -380,6 +388,7 @@ async fn proxy_anthropic_does_not_send_bearer_auth() { vec![("content-type".to_string(), "application/json".to_string())], bytes::Bytes::from(b"{}".to_vec()), &candidates, + None, ) .await .unwrap(); @@ -436,6 +445,7 @@ async fn proxy_forwards_client_anthropic_version_header() { ], bytes::Bytes::from(body), &candidates, + None, ) .await .unwrap(); diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 1f38a2cc..f706f8cd 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -105,7 +105,7 @@ impl InferenceContext { ) -> Result { let routes = self.system_routes.read().await; self.router - .proxy_with_candidates(protocol, method, path, headers, body, &routes) + .proxy_with_candidates(protocol, method, path, headers, body, &routes, None) .await } } @@ -993,6 +993,12 @@ async fn route_inference_request( return Ok(true); } + // Extract the model field from the JSON body as a routing hint. + // If parsing fails or model is absent, we fall back to protocol-only matching. 
+ let model_hint = serde_json::from_slice::<serde_json::Value>(&request.body) + .ok() + .and_then(|v| v.get("model")?.as_str().map(String::from)); + match ctx .router .proxy_with_candidates_streaming( @@ -1002,6 +1008,7 @@ filtered_headers, bytes::Bytes::from(request.body.clone()), &routes, + model_hint.as_deref(), ) .await { diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index bbabaf70..7004d45c 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -3,9 +3,9 @@ use openshell_core::proto::{ ClusterInferenceConfig, GetClusterInferenceRequest, GetClusterInferenceResponse, - GetInferenceBundleRequest, GetInferenceBundleResponse, InferenceRoute, Provider, ResolvedRoute, - SetClusterInferenceRequest, SetClusterInferenceResponse, ValidatedEndpoint, - inference_server::Inference, + GetInferenceBundleRequest, GetInferenceBundleResponse, InferenceModelEntry, InferenceRoute, + Provider, ResolvedRoute, SetClusterInferenceRequest, SetClusterInferenceResponse, + ValidatedEndpoint, inference_server::Inference, }; use openshell_router::config::ResolvedRoute as RouterResolvedRoute; use openshell_router::{ValidationFailureKind, verify_backend_endpoint}; @@ -81,6 +81,35 @@ impl Inference for InferenceService { let req = request.into_inner(); let route_name = effective_route_name(&req.route_name)?; let verify = !req.no_verify; + + // Multi-model path: when models list is non-empty, use it.
+ if !req.models.is_empty() { + let result = upsert_multi_model_route( + self.state.store.as_ref(), + route_name, + &req.models, + verify, + ) + .await?; + + let config = result + .route + .config + .as_ref() + .ok_or_else(|| Status::internal("managed route missing config"))?; + + return Ok(Response::new(SetClusterInferenceResponse { + provider_name: config.provider_name.clone(), + model_id: config.model_id.clone(), + version: result.route.version, + route_name: route_name.to_string(), + validation_performed: !result.validation.is_empty(), + validated_endpoints: result.validation, + models: config.models.clone(), + })); + } + + // Legacy single-model path. let route = upsert_cluster_inference_route( self.state.store.as_ref(), route_name, @@ -103,6 +132,7 @@ impl Inference for InferenceService { route_name: route_name.to_string(), validation_performed: !route.validation.is_empty(), validated_endpoints: route.validation, + models: Vec::new(), })) } @@ -140,6 +170,7 @@ impl Inference for InferenceService { model_id: config.model_id.clone(), version: route.version, route_name: route_name.to_string(), + models: config.models.clone(), })) } } @@ -204,10 +235,114 @@ async fn upsert_cluster_inference_route( Ok(UpsertedInferenceRoute { route, validation }) } +/// Upsert a multi-model inference route. +/// +/// Each entry in `models` is validated against its provider, then all entries +/// are stored atomically in a single `ClusterInferenceConfig`. +async fn upsert_multi_model_route( + store: &Store, + route_name: &str, + models: &[InferenceModelEntry], + verify: bool, +) -> Result<UpsertedInferenceRoute, Status> { + if models.is_empty() { + return Err(Status::invalid_argument("models list is empty")); + } + + // Validate aliases are unique.
+ let mut seen_aliases = std::collections::HashSet::new(); + for entry in models { + if entry.alias.trim().is_empty() { + return Err(Status::invalid_argument( + "each model entry must have a non-empty alias", + )); + } + if entry.provider_name.trim().is_empty() { + return Err(Status::invalid_argument(format!( + "model entry '{}' is missing provider_name", + entry.alias, + ))); + } + if entry.model_id.trim().is_empty() { + return Err(Status::invalid_argument(format!( + "model entry '{}' is missing model_id", + entry.alias, + ))); + } + let alias_key = entry.alias.trim().to_ascii_lowercase(); + if !seen_aliases.insert(alias_key) { + return Err(Status::invalid_argument(format!( + "duplicate alias '{}'", + entry.alias, + ))); + } + } + + // Validate each entry's provider exists and is inference-capable. + let mut validation = Vec::new(); + for entry in models { + let provider = store + .get_message_by_name::<Provider>(&entry.provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| { + Status::failed_precondition(format!( + "provider '{}' (for alias '{}') not found", + entry.provider_name, entry.alias, + )) + })?; + + let resolved = resolve_provider_route(&provider)?; + + if verify { + validation.push( + verify_provider_endpoint(&provider.name, &entry.model_id, &resolved).await?, + ); + } + } + + // Use the first entry's provider as the legacy single-model fields for + // backward compat (old clients reading the config see something useful).
+ let config = ClusterInferenceConfig { + provider_name: models[0].provider_name.clone(), + model_id: models[0].model_id.clone(), + models: models.to_vec(), + }; + + let existing = store + .get_message_by_name::<InferenceRoute>(route_name) + .await + .map_err(|e| Status::internal(format!("fetch route failed: {e}")))?; + + let route = if let Some(existing) = existing { + InferenceRoute { + id: existing.id, + name: existing.name, + config: Some(config), + version: existing.version.saturating_add(1), + } + } else { + InferenceRoute { + id: uuid::Uuid::new_v4().to_string(), + name: route_name.to_string(), + config: Some(config), + version: 1, + } + }; + + store + .put_message(&route) + .await + .map_err(|e| Status::internal(format!("persist route failed: {e}")))?; + + Ok(UpsertedInferenceRoute { route, validation }) + } + fn build_cluster_inference_config(provider: &Provider, model_id: &str) -> ClusterInferenceConfig { ClusterInferenceConfig { provider_name: provider.name.clone(), model_id: model_id.to_string(), + models: Vec::new(), } } @@ -371,12 +506,8 @@ fn find_provider_config_value(provider: &Provider, preferred_keys: &[&str]) -> O /// Resolve the inference bundle (all managed routes + revision hash). async fn resolve_inference_bundle(store: &Store) -> Result { let mut routes = Vec::new(); - if let Some(r) = resolve_route_by_name(store, CLUSTER_INFERENCE_ROUTE_NAME).await? { - routes.push(r); - } - if let Some(r) = resolve_route_by_name(store, SANDBOX_SYSTEM_ROUTE_NAME).await?
{ - routes.push(r); - } + routes.extend(resolve_route_by_name(store, CLUSTER_INFERENCE_ROUTE_NAME).await?); + routes.extend(resolve_route_by_name(store, SANDBOX_SYSTEM_ROUTE_NAME).await?); let now_ms = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -408,20 +539,48 @@ async fn resolve_inference_bundle(store: &Store) -> Result Result, Status> { +) -> Result<Vec<ResolvedRoute>, Status> { let route = store .get_message_by_name::<InferenceRoute>(route_name) .await .map_err(|e| Status::internal(format!("fetch route failed: {e}")))?; let Some(route) = route else { - return Ok(None); + return Ok(Vec::new()); }; let Some(config) = route.config.as_ref() else { - return Ok(None); + return Ok(Vec::new()); }; + // Multi-model path: each model entry becomes a separate ResolvedRoute with name = alias. + if !config.models.is_empty() { + let mut results = Vec::with_capacity(config.models.len()); + for entry in &config.models { + let provider = store + .get_message_by_name::<Provider>(&entry.provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| { + Status::failed_precondition(format!( + "configured provider '{}' was not found", + entry.provider_name + )) + })?; + let resolved = resolve_provider_route(&provider)?; + results.push(ResolvedRoute { + name: entry.alias.clone(), + base_url: resolved.route.endpoint, + model_id: entry.model_id.clone(), + api_key: resolved.route.api_key, + protocols: resolved.route.protocols, + provider_type: resolved.provider_type, + }); + } + return Ok(results); + } + + // Legacy single-model path.
if config.provider_name.trim().is_empty() { return Err(Status::failed_precondition(format!( "route '{route_name}' is missing provider_name" @@ -447,14 +606,14 @@ async fn resolve_route_by_name( let resolved = resolve_provider_route(&provider)?; - Ok(Some(ResolvedRoute { + Ok(vec![ResolvedRoute { name: route_name.to_string(), base_url: resolved.route.endpoint, model_id: config.model_id.clone(), api_key: resolved.route.api_key, protocols: resolved.route.protocols, provider_type: resolved.provider_type, - })) + }]) } #[cfg(test)] @@ -470,6 +629,7 @@ mod tests { config: Some(ClusterInferenceConfig { provider_name: provider_name.to_string(), model_id: model_id.to_string(), + models: Vec::new(), }), version: 1, } @@ -546,10 +706,10 @@ mod tests { .await .expect("store should connect"); - let route = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) + let routes = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) .await .expect("resolution should not fail"); - assert!(route.is_none()); + assert!(routes.is_empty()); } #[tokio::test] @@ -654,6 +814,7 @@ mod tests { config: Some(ClusterInferenceConfig { provider_name: "openai-dev".to_string(), model_id: "test/model".to_string(), + models: Vec::new(), }), version: 7, }; @@ -665,6 +826,8 @@ mod tests { let managed = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) .await .expect("route should resolve") + .into_iter() + .next() .expect("managed route should exist"); assert_eq!(managed.base_url, "https://station.example.com/v1"); @@ -702,6 +865,8 @@ mod tests { let first = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) .await .expect("route should resolve") + .into_iter() + .next() .expect("managed route should exist"); assert_eq!(first.api_key, "sk-initial"); @@ -721,6 +886,8 @@ mod tests { let second = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) .await .expect("route should resolve") + .into_iter() + .next() .expect("managed route should exist"); 
assert_eq!(second.api_key, "sk-rotated"); } @@ -1006,4 +1173,140 @@ mod tests { let err = effective_route_name("unknown-route").unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); } + + #[tokio::test] + async fn resolve_multi_model_route_returns_per_alias_routes() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let openai = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-openai"); + let anthropic = make_provider("anthropic-dev", "anthropic", "ANTHROPIC_API_KEY", "sk-ant"); + store.put_message(&openai).await.expect("persist openai"); + store + .put_message(&anthropic) + .await + .expect("persist anthropic"); + + let route = InferenceRoute { + id: "r-multi".to_string(), + name: CLUSTER_INFERENCE_ROUTE_NAME.to_string(), + config: Some(ClusterInferenceConfig { + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o".to_string(), + models: vec![ + InferenceModelEntry { + alias: "my-gpt".to_string(), + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o".to_string(), + }, + InferenceModelEntry { + alias: "my-claude".to_string(), + provider_name: "anthropic-dev".to_string(), + model_id: "claude-sonnet-4-20250514".to_string(), + }, + ], + }), + version: 1, + }; + store.put_message(&route).await.expect("persist route"); + + let resolved = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("should resolve"); + + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].name, "my-gpt"); + assert_eq!(resolved[0].model_id, "gpt-4o"); + assert_eq!(resolved[0].provider_type, "openai"); + assert_eq!(resolved[1].name, "my-claude"); + assert_eq!(resolved[1].model_id, "claude-sonnet-4-20250514"); + assert_eq!(resolved[1].provider_type, "anthropic"); + } + + #[tokio::test] + async fn bundle_with_multi_model_route_includes_all_aliases() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let openai = 
make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-openai"); + let anthropic = make_provider("anthropic-dev", "anthropic", "ANTHROPIC_API_KEY", "sk-ant"); + store.put_message(&openai).await.expect("persist openai"); + store + .put_message(&anthropic) + .await + .expect("persist anthropic"); + + let route = InferenceRoute { + id: "r-multi".to_string(), + name: CLUSTER_INFERENCE_ROUTE_NAME.to_string(), + config: Some(ClusterInferenceConfig { + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o".to_string(), + models: vec![ + InferenceModelEntry { + alias: "my-gpt".to_string(), + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o".to_string(), + }, + InferenceModelEntry { + alias: "my-claude".to_string(), + provider_name: "anthropic-dev".to_string(), + model_id: "claude-sonnet-4-20250514".to_string(), + }, + ], + }), + version: 1, + }; + store.put_message(&route).await.expect("persist route"); + + let bundle = resolve_inference_bundle(&store) + .await + .expect("bundle should resolve"); + + assert_eq!(bundle.routes.len(), 2); + let names: Vec<&str> = bundle.routes.iter().map(|r| r.name.as_str()).collect(); + assert!(names.contains(&"my-gpt")); + assert!(names.contains(&"my-claude")); + } + + #[tokio::test] + async fn upsert_multi_model_route_stores_and_retrieves() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let openai = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-openai"); + store.put_message(&openai).await.expect("persist"); + + let models = vec![InferenceModelEntry { + alias: "fast-gpt".to_string(), + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o-mini".to_string(), + }]; + + let result = upsert_multi_model_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + &models, + false, + ) + .await + .expect("upsert should succeed"); + + assert_eq!(result.route.version, 1); + assert_eq!(result.route.config.as_ref().unwrap().models.len(), 1); + 
assert_eq!(result.route.config.as_ref().unwrap().models[0].alias, "fast-gpt"); + + // Legacy fields should be populated from first entry + assert_eq!( + result.route.config.as_ref().unwrap().provider_name, + "openai-dev" + ); + assert_eq!( + result.route.config.as_ref().unwrap().model_id, + "gpt-4o-mini" + ); + } } diff --git a/proto/inference.proto b/proto/inference.proto index a15f4b84..51b32b10 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -31,10 +31,24 @@ service Inference { // Only `provider_name` and `model_id` are stored; endpoint, protocols, // credentials, and auth style are resolved from the provider at bundle time. message ClusterInferenceConfig { - // Provider record name backing this route. + // Provider record name backing this route (legacy single-model). string provider_name = 1; - // Model identifier to force on generation calls. + // Model identifier to force on generation calls (legacy single-model). string model_id = 2; + // Multi-model entries. When non-empty, takes precedence over the single + // provider_name/model_id fields above. + repeated InferenceModelEntry models = 3; +} + +// A single model entry within a multi-model inference configuration. +message InferenceModelEntry { + // Short alias used by agents in the "model" field (e.g. "qwen-coder"). + // The proxy matches this against the model field in API request payloads. + string alias = 1; + // Provider record name (e.g. "ollama-local", "openai-dev"). + string provider_name = 2; + // Backend model identifier (e.g. "qwen3-coder:30b", "gpt-5.4"). + string model_id = 3; } // Storage envelope for the managed cluster inference route. @@ -49,9 +63,9 @@ message InferenceRoute { } message SetClusterInferenceRequest { - // Provider record name to use for credentials + endpoint mapping. + // Provider record name to use for credentials + endpoint mapping (legacy single-model). string provider_name = 1; - // Model identifier to force on generation calls. 
+ // Model identifier to force on generation calls (legacy single-model). string model_id = 2; // Route name to target. Empty string defaults to "inference.local" (user-facing). // Use "sandbox-system" for the sandbox system-level inference route. @@ -60,6 +74,9 @@ message SetClusterInferenceRequest { bool verify = 4; // Skip synchronous endpoint validation before persistence. bool no_verify = 5; + // Multi-model entries. When non-empty, takes precedence over single + // provider_name/model_id above. + repeated InferenceModelEntry models = 6; } message ValidatedEndpoint { @@ -77,6 +94,8 @@ message SetClusterInferenceResponse { bool validation_performed = 5; // The concrete endpoints that were probed during validation, when available. repeated ValidatedEndpoint validated_endpoints = 6; + // Multi-model entries that were configured (echoed back). + repeated InferenceModelEntry models = 7; } message GetClusterInferenceRequest { @@ -91,6 +110,8 @@ message GetClusterInferenceResponse { uint64 version = 3; // Route name that was queried. string route_name = 4; + // Multi-model entries when configured. + repeated InferenceModelEntry models = 5; } message GetInferenceBundleRequest {} From ab7117584a010d1178bb789ad8bee108d48dded1 Mon Sep 17 00:00:00 2001 From: Lyle Hopkins Date: Wed, 25 Mar 2026 00:01:49 +0000 Subject: [PATCH 3/3] feat(cli): add multi-model inference CLI and codex URL fixes - Add --model-alias flag to 'inference set' for multi-model config (e.g. 
--model-alias gpt=openai/gpt-4 --model-alias claude=anthropic/claude-sonnet-4-20250514) - Add gateway_inference_set_multi() handler in run.rs - Update inference get/print to display multi-model entries - Import InferenceModelEntry proto type in CLI - Fix build_backend_url to always strip /v1 prefix for codex paths - Add /v1/codex/* inference pattern for openai_responses protocol - Fix backend tests to use /v1 endpoint suffix Signed-off-by: Lyle Hopkins --- architecture/inference-routing.md | 73 +++++++-- crates/openshell-cli/src/main.rs | 52 +++++-- crates/openshell-cli/src/run.rs | 141 +++++++++++++++++- crates/openshell-router/src/backend.rs | 23 ++- crates/openshell-router/src/lib.rs | 21 ++- .../tests/backend_integration.rs | 8 +- crates/openshell-sandbox/src/l7/inference.rs | 6 + crates/openshell-server/src/inference.rs | 58 +++++-- 8 files changed, 322 insertions(+), 60 deletions(-) diff --git a/architecture/inference-routing.md b/architecture/inference-routing.md index 0d3a95af..b2cf2f26 100644 --- a/architecture/inference-routing.md +++ b/architecture/inference-routing.md @@ -21,8 +21,9 @@ sequenceDiagram Agent->>Proxy: CONNECT inference.local:443 Proxy->>Proxy: TLS terminate (MITM) Proxy->>Proxy: Parse HTTP, detect pattern - Proxy->>Router: proxy_with_candidates() - Router->>Router: Select route by protocol + Proxy->>Proxy: Extract model hint from body + Proxy->>Router: proxy_with_candidates(model_hint) + Router->>Router: Select route by alias or protocol Router->>Router: Rewrite auth + model Router->>Backend: HTTPS request Backend->>Router: Response headers + body stream @@ -41,15 +42,16 @@ File: `crates/openshell-core/src/inference.rs` `InferenceProviderProfile` is the single source of truth for provider-specific inference knowledge: default endpoint, supported protocols, credential key lookup order, auth header style, and default headers. 
-Three profiles are defined: +Four profiles are defined: | Provider | Default Base URL | Protocols | Auth | Default Headers | -|----------|-----------------|-----------|------|-----------------| +|----------|-----------------|-----------|------|------------------| | `openai` | `https://api.openai.com/v1` | `openai_chat_completions`, `openai_completions`, `openai_responses`, `model_discovery` | `Authorization: Bearer` | (none) | | `anthropic` | `https://api.anthropic.com/v1` | `anthropic_messages`, `model_discovery` | `x-api-key` | `anthropic-version: 2023-06-01` | | `nvidia` | `https://integrate.api.nvidia.com/v1` | `openai_chat_completions`, `openai_completions`, `openai_responses`, `model_discovery` | `Authorization: Bearer` | (none) | +| `ollama` | `http://host.openshell.internal:11434` | `ollama_chat`, `ollama_model_discovery`, `openai_chat_completions`, `openai_completions`, `model_discovery` | `Authorization: Bearer` | (none) | -Each profile also defines `credential_key_names` (e.g. `["OPENAI_API_KEY"]`) and `base_url_config_keys` (e.g. `["OPENAI_BASE_URL"]`) used by the gateway to resolve credentials and endpoint overrides from provider records. +Each profile also defines `credential_key_names` (e.g. `["OPENAI_API_KEY"]`) and `base_url_config_keys` (e.g. `["OPENAI_BASE_URL"]`) used by the gateway to resolve credentials and endpoint overrides from provider records. The Ollama profile uses `OLLAMA_API_KEY` for credentials and checks both `OLLAMA_BASE_URL` and `OLLAMA_HOST` for endpoint overrides. Its default endpoint uses `host.openshell.internal` so sandboxes can reach an Ollama instance running on the gateway host. Unknown provider types return `None` from `profile_for()` and default to `Bearer` auth with no default headers via `auth_for_provider_type()`. @@ -70,7 +72,19 @@ The gateway implements the `Inference` gRPC service defined in `proto/inference. 5. Builds a managed route spec that stores only `provider_name` and `model_id`. 
The spec intentionally leaves `base_url`, `api_key`, and `protocols` empty -- these are resolved dynamically at bundle time from the provider record. 6. Upserts the route with name `inference.local`. Version starts at 1 and increments monotonically on each update. -`GetClusterInference` returns `provider_name`, `model_id`, and `version` for the managed route. Returns `NOT_FOUND` if cluster inference is not configured. +`GetClusterInference` returns `provider_name`, `model_id`, `version`, and any configured `models` entries for the managed route. Returns `NOT_FOUND` if cluster inference is not configured. + +### Multi-model routes + +`upsert_multi_model_route()` configures multiple provider/model pairs on a single route, each identified by a short alias: + +1. Validates that each `InferenceModelEntry` has non-empty `alias`, `provider_name`, and `model_id`. +2. Checks that aliases are unique (case-insensitive). +3. Verifies each provider exists and is inference-capable. +4. Optionally probes each endpoint (skipped with `--no-verify`). +5. Stores the full `models` vector in the route config. The first entry's provider/model are also written to the legacy single-model fields for backward compatibility. + +At bundle time, each `InferenceModelEntry` is resolved into a separate `ResolvedRoute` whose `name` is set to the alias. The router's alias-first selection (see Route Selection) then matches the agent's `model` field against these names. 
### Bundle delivery @@ -92,11 +106,15 @@ File: `proto/inference.proto` Key messages: -- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + optional `no_verify` override, with verification enabled by default -- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `version` +- `InferenceModelEntry` -- `alias` + `provider_name` + `model_id` (a single alias-to-provider mapping) +- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + optional `no_verify` override + `repeated InferenceModelEntry models`, with verification enabled by default +- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `version` + `repeated InferenceModelEntry models` +- `GetClusterInferenceResponse` -- `provider_name` + `model_id` + `version` + `repeated InferenceModelEntry models` - `GetInferenceBundleResponse` -- `repeated ResolvedRoute routes` + `revision` + `generated_at_ms` - `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type` +When `models` is non-empty in a set request, the gateway uses `upsert_multi_model_route()` and ignores the legacy `provider_name`/`model_id` fields. When `models` is empty, the legacy single-model path is used. + ## Data Plane (Sandbox) Files: @@ -117,7 +135,7 @@ When a `CONNECT inference.local:443` arrives: 1. Proxy responds `200 Connection Established`. 2. `handle_inference_interception()` TLS-terminates the client connection using the sandbox CA (MITM). 3. Raw HTTP requests are parsed from the TLS tunnel using `try_parse_http_request()` (supports Content-Length and chunked transfer encoding). -4. Each parsed request is passed to `route_inference_request()`. +4. Each parsed request is passed to `route_inference_request()`. Before routing, the proxy extracts a `model_hint` from the JSON request body's `model` field (if present). This hint is passed to the router for alias-based route selection. 5. The tunnel supports HTTP keep-alive: multiple requests can be processed sequentially. 
6. Buffer starts at 64 KiB (`INITIAL_INFERENCE_BUF`) and grows up to 10 MiB (`MAX_INFERENCE_BUF`). Requests exceeding the max get `413 Payload Too Large`. @@ -133,10 +151,16 @@ Supported built-in patterns: | `POST` | `/v1/completions` | `openai_completions` | `completion` | | `POST` | `/v1/responses` | `openai_responses` | `responses` | | `POST` | `/v1/messages` | `anthropic_messages` | `messages` | +| `POST` | `/v1/codex/*` | `openai_responses` | `codex_responses` | | `GET` | `/v1/models` | `model_discovery` | `models_list` | | `GET` | `/v1/models/*` | `model_discovery` | `models_get` | +| `POST` | `/api/chat` | `ollama_chat` | `ollama_chat` | +| `GET` | `/api/tags` | `ollama_model_discovery` | `ollama_tags` | +| `POST` | `/api/show` | `ollama_model_discovery` | `ollama_show` | + +Query strings are stripped before matching. Path matching is exact for most patterns; `/v1/models/*` and `/v1/codex/*` match any sub-path (e.g. `/v1/models/gpt-4.1`, `/v1/codex/responses`). Absolute-form URIs (e.g. `https://inference.local/v1/chat/completions`) are normalized to path-only form by `normalize_inference_path()` before detection. -Query strings are stripped before matching. Path matching is exact for most patterns; `/v1/models/*` matches any sub-path (e.g. `/v1/models/gpt-4.1`). Absolute-form URIs (e.g. `https://inference.local/v1/chat/completions`) are normalized to path-only form by `normalize_inference_path()` before detection. +Ollama patterns use `/api/` paths (no `/v1/` prefix), matching Ollama's native API. This allows agents to use the Ollama client library directly against `inference.local`. If no pattern matches, the proxy returns `403 Forbidden` with `{"error": "connection not allowed by policy"}`. @@ -161,7 +185,16 @@ Files: ### Route selection -`proxy_with_candidates()` finds the first route whose `protocols` list contains the detected source protocol (normalized to lowercase). If no route matches, returns `RouterError::NoCompatibleRoute`. 
+`select_route()` picks the best route from the candidate list using a two-phase strategy: + +1. **Alias match (preferred)**: If a `model_hint` is provided (extracted from the request body's `model` field), select the first candidate whose `name` matches the hint (compared case-insensitively, after trimming surrounding whitespace) AND whose `protocols` list contains the detected source protocol. +2. **Protocol fallback**: If no alias matches, fall back to the first candidate whose `protocols` list contains the source protocol. + +This enables multi-route configurations where the agent selects a backend by setting the `model` field to an alias name (e.g. `"model": "my-gpt"` routes to the aliased provider). If the model field is absent, not a known alias, or parsing fails, routing falls back to protocol-based selection. + +If no route matches either phase, returns `RouterError::NoCompatibleRoute`. + +`proxy_with_candidates()` and `proxy_with_candidates_streaming()` both accept an optional `model_hint: Option<&str>` parameter, passed through from the sandbox proxy. ### Request rewriting @@ -171,7 +204,7 @@ Files: 2. **Header stripping**: Removes `authorization`, `x-api-key`, `host`, and any header names that will be set from route defaults. 3. **Default headers**: Applies route-level default headers (e.g. `anthropic-version: 2023-06-01`) unless the client already sent them. 4. **Model rewrite**: Parses the request body as JSON and replaces the `model` field with the route's configured model. Non-JSON bodies are forwarded unchanged. -5. **URL construction**: `build_backend_url()` appends the request path to the route endpoint. If the endpoint already ends with `/v1` and the request path starts with `/v1/`, the duplicate prefix is deduplicated. +5. **URL construction**: `build_backend_url()` appends the request path to the route endpoint. If the endpoint already ends with `/v1` and the request path is exactly `/v1` or starts with `/v1/`, the duplicate `/v1` prefix is stripped before appending. This handles both `/v1`-suffixed endpoints (e.g.
`api.openai.com/v1`) and non-versioned endpoints (e.g. `chatgpt.com/backend-api` for Codex) uniformly. ### Header sanitization @@ -297,12 +330,24 @@ The system route is stored as a separate `InferenceRoute` record in the gateway Cluster inference commands: -- `openshell inference set --provider <name> --model <model>` -- configures user-facing cluster inference +- `openshell inference set --provider <name> --model <model>` -- configures user-facing cluster inference (single model) +- `openshell inference set --model-alias ALIAS=PROVIDER/MODEL [--model-alias ...]` -- configures multi-model cluster inference - `openshell inference set --system --provider <name> --model <model>` -- configures system inference - `openshell inference get` -- displays both user and system inference configuration - `openshell inference get --system` -- displays only the system inference configuration -The `--provider` flag references a provider record name (not a provider type). The provider must already exist in the cluster and have a supported inference type (`openai`, `anthropic`, or `nvidia`). +The `--provider` flag references a provider record name (not a provider type). The provider must already exist in the cluster and have a supported inference type (`openai`, `anthropic`, `nvidia`, or `ollama`). + +`--model-alias` can be repeated to configure multiple providers simultaneously. It conflicts with `--provider` and `--model` -- the two modes are mutually exclusive. Example: + +```bash +openshell inference set \ + --model-alias my-gpt=openai-dev/gpt-4o \ + --model-alias my-claude=anthropic-dev/claude-sonnet-4-20250514 \ + --model-alias my-llama=ollama-local/llama3 +``` + +Agents select a backend by setting the `model` field in their inference request to the alias name (e.g. `"model": "my-gpt"`). Inference writes verify by default. `--no-verify` is the explicit opt-out for endpoints that are not up yet.
diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 5de31c79..e1d58994 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -6,7 +6,7 @@ use clap::{CommandFactory, Parser, Subcommand, ValueEnum, ValueHint}; use clap_complete::engine::ArgValueCompleter; use clap_complete::env::CompleteEnv; -use miette::Result; +use miette::{Result, miette}; use owo_colors::OwoColorize; use std::io::Write; @@ -286,6 +286,7 @@ const GATEWAY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m const INFERENCE_EXAMPLES: &str = "\x1b[1mEXAMPLES\x1b[0m $ openshell inference set --provider openai --model gpt-4 + $ openshell inference set --model-alias gpt=openai/gpt-4 --model-alias claude=anthropic/claude-sonnet-4-20250514 $ openshell inference get $ openshell inference update --model gpt-4-turbo "; @@ -918,15 +919,26 @@ enum GatewayCommands { #[derive(Subcommand, Debug)] enum InferenceCommands { /// Set gateway-level inference provider and model. + /// + /// Use --provider/--model for single-model mode, or --model-alias for + /// multi-model mode (multiple providers routed by alias). #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Set { - /// Provider name. - #[arg(long, add = ArgValueCompleter::new(completers::complete_provider_names))] - provider: String, + /// Provider name (single-model mode). + #[arg(long, required_unless_present = "model_alias", add = ArgValueCompleter::new(completers::complete_provider_names))] + provider: Option<String>, - /// Model identifier to force for generation calls. - #[arg(long)] - model: String, + /// Model identifier to force for generation calls (single-model mode). + #[arg(long, required_unless_present = "model_alias")] + model: Option<String>, + + /// Add a model alias in the form ALIAS=PROVIDER/MODEL. + /// Can be repeated to configure multiple providers simultaneously. + /// Not supported with --system. 
+ /// + /// Example: --model-alias my-gpt=openai-dev/gpt-4o --model-alias my-claude=anthropic-dev/claude-sonnet-4-20250514 + #[arg(long, conflicts_with_all = ["provider", "model", "system"])] + model_alias: Vec<String>, /// Configure the system inference route instead of the user-facing /// route. System inference is used by platform functions (e.g. the @@ -2024,14 +2036,32 @@ async fn main() -> Result<()> { InferenceCommands::Set { provider, model, + model_alias, system, no_verify, } => { let route_name = if system { "sandbox-system" } else { "" }; - run::gateway_inference_set( - endpoint, &provider, &model, route_name, no_verify, &tls, - ) - .await?; + if !model_alias.is_empty() { + run::gateway_inference_set_multi( + endpoint, + &model_alias, + route_name, + no_verify, + &tls, + ) + .await?; + } else { + let provider = provider.as_deref().ok_or_else(|| { + miette!("--provider is required in single-model mode") + })?; + let model = model + .as_deref() + .ok_or_else(|| miette!("--model is required in single-model mode"))?; + run::gateway_inference_set( + endpoint, provider, model, route_name, no_verify, &tls, + ) + .await?; + } } InferenceCommands::Update { provider, diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index e32eec2a..f20f43ef 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -26,8 +26,8 @@ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, DeleteProviderRequest, DeleteSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ListProvidersRequest, - ListSandboxPoliciesRequest, ListSandboxesRequest, PolicyStatus, Provider, + GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, InferenceModelEntry, + ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxesRequest, 
PolicyStatus, Provider, RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, setting_value, @@ -3504,6 +3504,7 @@ pub async fn gateway_inference_set( route_name: route_name.to_string(), verify: false, no_verify, + models: vec![], }) .await; @@ -3534,6 +3535,101 @@ pub async fn gateway_inference_set( Ok(()) } +pub async fn gateway_inference_set_multi( + server: &str, + model_aliases: &[String], + route_name: &str, + no_verify: bool, + tls: &TlsOptions, +) -> Result<()> { + let mut models = Vec::with_capacity(model_aliases.len()); + for entry in model_aliases { + let (alias, rest) = entry.split_once('=').ok_or_else(|| { + miette!("invalid --model-alias format: {entry:?}. Expected ALIAS=PROVIDER/MODEL") + })?; + let (provider, model) = rest.split_once('/').ok_or_else(|| { + miette!("invalid --model-alias value after '=': {rest:?}. Expected PROVIDER/MODEL") + })?; + if alias.trim().is_empty() { + return Err(miette!("empty alias in --model-alias {entry:?}")); + } + if provider.trim().is_empty() { + return Err(miette!("empty provider in --model-alias {entry:?}")); + } + if model.trim().is_empty() { + return Err(miette!("empty model in --model-alias {entry:?}")); + } + models.push(InferenceModelEntry { + alias: alias.to_string(), + provider_name: provider.to_string(), + model_id: model.to_string(), + }); + } + + let progress = if std::io::stdout().is_terminal() { + let spinner = ProgressBar::new_spinner(); + spinner.set_style( + ProgressStyle::with_template("{spinner:.cyan} {msg} ({elapsed})") + .unwrap_or_else(|_| ProgressStyle::default_spinner()), + ); + spinner.set_message("Configuring multi-model inference..."); + spinner.enable_steady_tick(Duration::from_millis(120)); + Some(spinner) + } else { + None + }; + + let mut client = grpc_inference_client(server, tls).await?; + let response = client + 
.set_cluster_inference(SetClusterInferenceRequest { + provider_name: String::new(), + model_id: String::new(), + route_name: route_name.to_string(), + verify: false, + no_verify, + models, + }) + .await; + + if let Some(progress) = &progress { + progress.finish_and_clear(); + } + + let response = response.map_err(format_inference_status)?; + let configured = response.into_inner(); + let label = if configured.route_name == "sandbox-system" { + "System multi-model inference configured:" + } else { + "Gateway multi-model inference configured:" + }; + println!("{}", label.cyan().bold()); + println!(); + println!(" {} {}", "Route:".dimmed(), configured.route_name); + println!(" {} {}", "Version:".dimmed(), configured.version); + if configured.models.is_empty() { + println!(" {} {}", "Provider:".dimmed(), configured.provider_name); + println!(" {} {}", "Model:".dimmed(), configured.model_id); + } else { + println!(" {}", "Models:".dimmed()); + for m in &configured.models { + println!( + " {} {} {}/{}", + "-".dimmed(), + m.alias.bold(), + m.provider_name, + m.model_id + ); + } + } + if configured.validation_performed { + println!(" {}", "Validated Endpoints:".dimmed()); + for endpoint in configured.validated_endpoints { + println!(" - {} ({})", endpoint.url, endpoint.protocol); + } + } + Ok(()) +} + pub async fn gateway_inference_update( server: &str, provider_name: Option<&str>, @@ -3582,6 +3678,7 @@ pub async fn gateway_inference_update( route_name: route_name.to_string(), verify: false, no_verify, + models: vec![], }) .await; @@ -3636,9 +3733,23 @@ pub async fn gateway_inference_get( }; println!("{}", label.cyan().bold()); println!(); - println!(" {} {}", "Provider:".dimmed(), configured.provider_name); - println!(" {} {}", "Model:".dimmed(), configured.model_id); - println!(" {} {}", "Version:".dimmed(), configured.version); + if !configured.models.is_empty() { + println!(" {} {}", "Version:".dimmed(), configured.version); + println!(" {}", "Models:".dimmed()); + for 
m in &configured.models { + println!( + " {} {} {}/{}", + "-".dimmed(), + m.alias.bold(), + m.provider_name, + m.model_id + ); + } + } else { + println!(" {} {}", "Provider:".dimmed(), configured.provider_name); + println!(" {} {}", "Model:".dimmed(), configured.model_id); + println!(" {} {}", "Version:".dimmed(), configured.version); + } } else { // Show both routes by default. print_inference_route(&mut client, "Gateway inference", "").await; @@ -3663,9 +3774,23 @@ async fn print_inference_route( let configured = response.into_inner(); println!("{}", format!("{label}:").cyan().bold()); println!(); - println!(" {} {}", "Provider:".dimmed(), configured.provider_name); - println!(" {} {}", "Model:".dimmed(), configured.model_id); - println!(" {} {}", "Version:".dimmed(), configured.version); + if !configured.models.is_empty() { + println!(" {} {}", "Version:".dimmed(), configured.version); + println!(" {}", "Models:".dimmed()); + for m in &configured.models { + println!( + " {} {} {}/{}", + "-".dimmed(), + m.alias.bold(), + m.provider_name, + m.model_id + ); + } + } else { + println!(" {} {}", "Provider:".dimmed(), configured.provider_name); + println!(" {} {}", "Model:".dimmed(), configured.model_id); + println!(" {} {}", "Version:".dimmed(), configured.version); + } } Err(e) if e.code() == Code::NotFound => { println!("{}", format!("{label}:").cyan().bold()); diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index 6eb01a43..d10ac17c 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -234,6 +234,7 @@ fn validation_probe(route: &ResolvedRoute) -> Result String { let base = endpoint.trim_end_matches('/'); + // When the endpoint already contains /v1 (e.g. api.openai.com/v1) + // and the proxy path also starts with /v1/, strip the duplicate + // prefix so the resulting URL is correct. 
if base.ends_with("/v1") && (path == "/v1" || path.starts_with("/v1/")) { return format!("{base}{}", &path[3..]); } @@ -458,10 +462,18 @@ mod tests { } #[test] - fn build_backend_url_preserves_non_versioned_base() { + fn build_backend_url_preserves_v1_for_plain_endpoint() { assert_eq!( - build_backend_url("https://api.anthropic.com", "/v1/messages"), - "https://api.anthropic.com/v1/messages" + build_backend_url("https://my-proxy.example.com", "/v1/chat/completions"), + "https://my-proxy.example.com/v1/chat/completions" + ); + } + + #[test] + fn build_backend_url_codex_path() { + assert_eq!( + build_backend_url("https://api.openai.com/v1", "/v1/codex/responses"), + "https://api.openai.com/v1/codex/responses" ); } @@ -488,8 +500,9 @@ mod tests { #[tokio::test] async fn verify_backend_endpoint_uses_route_auth_and_shape() { let mock_server = MockServer::start().await; + // Use endpoint with /v1 suffix to match real Anthropic endpoint layout. let route = test_route( - &mock_server.uri(), + &format!("{}/v1", mock_server.uri()), &["anthropic_messages"], AuthHeader::Custom("x-api-key"), ); @@ -519,7 +532,7 @@ mod tests { #[tokio::test] async fn verify_backend_endpoint_accepts_mock_routes() { let route = test_route( - "mock://test-backend", + "mock://test-backend/v1", &["openai_chat_completions"], AuthHeader::Bearer, ); diff --git a/crates/openshell-router/src/lib.rs b/crates/openshell-router/src/lib.rs index 0fcb470e..a3c62ab5 100644 --- a/crates/openshell-router/src/lib.rs +++ b/crates/openshell-router/src/lib.rs @@ -47,8 +47,10 @@ fn select_route<'a>( model_hint: Option<&str>, ) -> Option<&'a ResolvedRoute> { if let Some(hint) = model_hint { + let normalized_hint = hint.trim().to_ascii_lowercase(); if let Some(r) = candidates.iter().find(|r| { - r.name == hint && r.protocols.iter().any(|p| p == protocol) + r.name.trim().to_ascii_lowercase() == normalized_hint + && r.protocols.iter().any(|p| p == protocol) }) { return Some(r); } @@ -236,7 +238,10 @@ mod tests { fn 
select_route_alias_match_takes_priority() { let routes = vec![ make_route("ollama-local", vec!["openai_chat_completions"]), - make_route("openai-prod", vec!["openai_chat_completions", "openai_responses"]), + make_route( + "openai-prod", + vec!["openai_chat_completions", "openai_responses"], + ), ]; // Both support openai_chat_completions, but hint selects the second one. let r = select_route(&routes, "openai_chat_completions", Some("openai-prod")).unwrap(); @@ -257,9 +262,17 @@ #[test] fn select_route_no_match_returns_none() { + let routes = vec![make_route("ollama-local", vec!["openai_chat_completions"])]; + assert!(select_route(&routes, "anthropic_messages", None).is_none()); + } + + #[test] + fn select_route_alias_match_is_case_insensitive() { let routes = vec![ - make_route("ollama-local", vec!["openai_chat_completions"]), + make_route("My-GPT", vec!["openai_chat_completions"]), + make_route("anthropic-prod", vec!["anthropic_messages"]), ]; - assert!(select_route(&routes, "anthropic_messages", None).is_none()); + let r = select_route(&routes, "openai_chat_completions", Some("my-gpt")).unwrap(); + assert_eq!(r.name, "My-GPT"); } } diff --git a/crates/openshell-router/tests/backend_integration.rs b/crates/openshell-router/tests/backend_integration.rs index 28b12a21..5abc32ab 100644 --- a/crates/openshell-router/tests/backend_integration.rs +++ b/crates/openshell-router/tests/backend_integration.rs @@ -9,7 +9,7 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; fn mock_candidates(base_url: &str) -> Vec<ResolvedRoute> { vec![ResolvedRoute { name: "inference.local".to_string(), - endpoint: base_url.to_string(), + endpoint: format!("{base_url}/v1"), model: "meta/llama-3.1-8b-instruct".to_string(), api_key: "test-api-key".to_string(), protocols: vec!["openai_chat_completions".to_string()], @@ -313,7 +313,7 @@ async fn proxy_uses_x_api_key_for_anthropic_route() { let router = Router::new().unwrap(); let candidates = vec![ResolvedRoute { name: 
"inference.local".to_string(), - endpoint: mock_server.uri(), + endpoint: format!("{}/v1", mock_server.uri()), model: "claude-sonnet-4-20250514".to_string(), api_key: "test-anthropic-key".to_string(), protocols: vec!["anthropic_messages".to_string()], @@ -372,7 +372,7 @@ async fn proxy_anthropic_does_not_send_bearer_auth() { let router = Router::new().unwrap(); let candidates = vec![ResolvedRoute { name: "inference.local".to_string(), - endpoint: mock_server.uri(), + endpoint: format!("{}/v1", mock_server.uri()), model: "claude-sonnet-4-20250514".to_string(), api_key: "anthropic-key".to_string(), protocols: vec!["anthropic_messages".to_string()], @@ -417,7 +417,7 @@ async fn proxy_forwards_client_anthropic_version_header() { let router = Router::new().unwrap(); let candidates = vec![ResolvedRoute { name: "inference.local".to_string(), - endpoint: mock_server.uri(), + endpoint: format!("{}/v1", mock_server.uri()), model: "claude-sonnet-4-20250514".to_string(), api_key: "test-anthropic-key".to_string(), protocols: vec!["anthropic_messages".to_string()], diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index ac63a4a4..3dda7346 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -37,6 +37,12 @@ pub fn default_patterns() -> Vec { protocol: "openai_responses".to_string(), kind: "responses".to_string(), }, + InferenceApiPattern { + method: "POST".to_string(), + path_glob: "/v1/codex/*".to_string(), + protocol: "openai_responses".to_string(), + kind: "codex_responses".to_string(), + }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/messages".to_string(), diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 7004d45c..5c759434 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -249,6 +249,10 @@ async fn upsert_multi_model_route( return 
Err(Status::invalid_argument("models list is empty")); } + // Names reserved for internal route partitioning (sandbox uses + // `name == "sandbox-system"` to split user vs system caches). + const RESERVED_ALIASES: &[&str] = &["sandbox-system", "inference.local"]; + // Validate aliases are unique. let mut seen_aliases = std::collections::HashSet::new(); for entry in models { @@ -257,6 +261,13 @@ async fn upsert_multi_model_route( "each model entry must have a non-empty alias", )); } + let alias_lower = entry.alias.trim().to_ascii_lowercase(); + if RESERVED_ALIASES.iter().any(|r| *r == alias_lower) { + return Err(Status::invalid_argument(format!( + "alias '{}' is reserved and cannot be used", + entry.alias, + ))); + } if entry.provider_name.trim().is_empty() { return Err(Status::invalid_argument(format!( "model entry '{}' is missing provider_name", @@ -295,9 +306,8 @@ async fn upsert_multi_model_route( let resolved = resolve_provider_route(&provider)?; if verify { - validation.push( - verify_provider_endpoint(&provider.name, &entry.model_id, &resolved).await?, - ); + validation + .push(verify_provider_endpoint(&provider.name, &entry.model_id, &resolved).await?); } } @@ -1038,7 +1048,7 @@ mod tests { "OPENAI_API_KEY", "sk-test", "OPENAI_BASE_URL", - &mock_server.uri(), + &format!("{}/v1", mock_server.uri()), ); store .put_message(&provider) @@ -1079,7 +1089,7 @@ mod tests { "OPENAI_API_KEY", "sk-test", "OPENAI_BASE_URL", - &mock_server.uri(), + &format!("{}/v1", mock_server.uri()), ); store .put_message(&provider) @@ -1286,18 +1296,16 @@ mod tests { model_id: "gpt-4o-mini".to_string(), }]; - let result = upsert_multi_model_route( - &store, - CLUSTER_INFERENCE_ROUTE_NAME, - &models, - false, - ) - .await - .expect("upsert should succeed"); + let result = upsert_multi_model_route(&store, CLUSTER_INFERENCE_ROUTE_NAME, &models, false) + .await + .expect("upsert should succeed"); assert_eq!(result.route.version, 1); 
assert_eq!(result.route.config.as_ref().unwrap().models.len(), 1); - assert_eq!(result.route.config.as_ref().unwrap().models[0].alias, "fast-gpt"); + assert_eq!( + result.route.config.as_ref().unwrap().models[0].alias, + "fast-gpt" + ); // Legacy fields should be populated from first entry assert_eq!( @@ -1309,4 +1317,26 @@ mod tests { "gpt-4o-mini" ); } + + #[tokio::test] + async fn upsert_multi_model_route_rejects_reserved_alias() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let openai = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-openai"); + store.put_message(&openai).await.expect("persist"); + + let models = vec![InferenceModelEntry { + alias: "sandbox-system".to_string(), + provider_name: "openai-dev".to_string(), + model_id: "gpt-4o".to_string(), + }]; + + let err = upsert_multi_model_route(&store, CLUSTER_INFERENCE_ROUTE_NAME, &models, false) + .await + .expect_err("should reject reserved alias"); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("reserved")); + } }