From 290dbdd0dbf034d17880501c8e2d2123509b8199 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Fri, 27 Mar 2026 20:25:10 +0800 Subject: [PATCH 01/10] =?UTF-8?q?feat(tool):=20=E6=B7=BB=E5=8A=A0=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E7=9B=B8=E5=85=B3=E7=9A=84=E6=96=B0=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E5=92=8CAPI=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 此提交添加了一系列新的工具相关的模块和API接口,包括客户端模板、模型定义以及各种控制和OpenAPI接口等组件。同时更新了依赖版本并完善了初始化配置。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/__init__.py | 8 + agentrun/tool/__client_async_template.py | 53 ++ agentrun/tool/__init__.py | 41 ++ agentrun/tool/__tool_async_template.py | 311 ++++++++++ agentrun/tool/api/__init__.py | 1 + agentrun/tool/api/control.py | 128 ++++ agentrun/tool/api/mcp.py | 167 +++++ agentrun/tool/api/openapi.py | 337 ++++++++++ agentrun/tool/client.py | 87 +++ agentrun/tool/model.py | 406 +++++++++++++ agentrun/tool/tool.py | 440 ++++++++++++++ pyproject.toml | 2 +- tests/unittests/tool/__init__.py | 5 + tests/unittests/tool/test_mcp.py | 308 ++++++++++ tests/unittests/tool/test_model.py | 661 ++++++++++++++++++++ tests/unittests/tool/test_openapi.py | 694 +++++++++++++++++++++ tests/unittests/tool/test_tool.py | 743 +++++++++++++++++++++++ 17 files changed, 4391 insertions(+), 1 deletion(-) create mode 100644 agentrun/tool/__client_async_template.py create mode 100644 agentrun/tool/__init__.py create mode 100644 agentrun/tool/__tool_async_template.py create mode 100644 agentrun/tool/api/__init__.py create mode 100644 agentrun/tool/api/control.py create mode 100644 agentrun/tool/api/mcp.py create mode 100644 agentrun/tool/api/openapi.py create mode 100644 agentrun/tool/client.py create mode 100644 agentrun/tool/model.py create mode 100644 agentrun/tool/tool.py create mode 100644 tests/unittests/tool/__init__.py create mode 100644 tests/unittests/tool/test_mcp.py create mode 100644 tests/unittests/tool/test_model.py create mode 100644 tests/unittests/tool/test_openapi.py create mode 100644 tests/unittests/tool/test_tool.py diff --git a/agentrun/__init__.py b/agentrun/__init__.py index 316a2e5..24961ac 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -109,6 +109,10 @@ SandboxClient, Template, ) +# Tool +from agentrun.tool import Tool as ToolResource +from agentrun.tool import ToolClient as ToolResourceClient +from agentrun.tool import ToolControlAPI as ToolResourceControlAPI # ToolSet from agentrun.toolset import ToolSet, ToolSetClient from agentrun.utils.config import Config @@ -247,6 +251,10 @@ "AioSandbox", "CustomSandbox", "Template", + ######## Tool ######## + "ToolResource", + "ToolResourceClient", + "ToolResourceControlAPI", ######## ToolSet ######## "ToolSetClient", "ToolSet", diff --git a/agentrun/tool/__client_async_template.py b/agentrun/tool/__client_async_template.py new file mode 100644 index 0000000..2f504a9 --- /dev/null +++ b/agentrun/tool/__client_async_template.py @@ -0,0 +1,53 @@ +"""Tool 客户端 / Tool Client + +此模块提供工具的客户端 API。 +This module provides the client API for tools. +""" + +from typing import Any, Dict, List, Optional + +from agentrun.tool.api.control import ToolControlAPI +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .tool import Tool + + +class ToolClient: + """Tool 客户端 / Tool Client + + 提供工具的获取功能。 + Provides get function for tools. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = ToolControlAPI(config) + + async def get_async( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """异步获取工具 / Get tool asynchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = await self.__control_api.get_tool_async( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) diff --git a/agentrun/tool/__init__.py b/agentrun/tool/__init__.py new file mode 100644 index 0000000..cce0e04 --- /dev/null +++ b/agentrun/tool/__init__.py @@ -0,0 +1,41 @@ +"""Tool 模块 / Tool Module + +此模块提供工具管理功能。 +This module provides tool management functionality. +""" + +from .api.control import ToolControlAPI +from .api.mcp import ToolMCPSession +from .api.openapi import ToolOpenAPIClient +from .client import ToolClient +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNASConfig, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) +from .tool import Tool + +__all__ = [ + "ToolControlAPI", + "ToolMCPSession", + "ToolOpenAPIClient", + "ToolClient", + "Tool", + "ToolType", + "McpConfig", + "ToolCodeConfiguration", + "ToolContainerConfiguration", + "ToolInfo", + "ToolLogConfiguration", + "ToolNASConfig", + "ToolNetworkConfiguration", + "ToolOSSMountConfig", + "ToolSchema", +] diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py new file mode 100644 index 0000000..41f5bca --- /dev/null +++ b/agentrun/tool/__tool_async_template.py @@ -0,0 +1,311 @@ +"""Tool 资源类 / Tool Resource Class + +提供工具资源的面向对象封装和完整生命周期管理。 +Provides object-oriented wrapper and complete lifecycle management for tool resources. +""" + +from typing import Any, Dict, List, Optional + +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import BaseModel + +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class Tool(BaseModel): + """工具资源 / Tool Resource + + 提供工具的查询、调用等功能。 + Provides query, invocation and other functionality for tools. + + Attributes: + code_configuration: 代码包配置 / Code configuration + container_configuration: 容器配置 / Container configuration + created_time: 创建时间 / Creation time + data_endpoint: 数据链路端点 / Data endpoint + description: 描述 / Description + environment_variables: 环境变量 / Environment variables + gpu: GPU 配置 / GPU configuration + internet_access: 是否允许公网访问 / Whether internet access is allowed + last_modified_time: 最后修改时间 / Last modified time + log_configuration: 日志配置 / Log configuration + mcp_config: MCP 配置 / MCP configuration + memory: 内存大小(MB) / Memory size in MB + name: 工具名称 / Tool name + network_config: 网络配置 / Network configuration + oss_mount_config: OSS 挂载配置 / OSS mount configuration + protocol_spec: 协议规格(OpenAPI JSON) / Protocol spec (OpenAPI JSON) + protocol_type: 协议类型 / Protocol type + status: 状态 / Status + timeout: 超时时间(秒) / Timeout in seconds + tool_id: 工具 ID / Tool ID + tool_name: 工具名称 / Tool name + tool_type: 工具类型(MCP/FUNCTIONCALL) / Tool type + version_id: 版本 ID / Version ID + """ + + code_configuration: Optional[ToolCodeConfiguration] = None + """代码包配置 / Code configuration""" + + container_configuration: Optional[ToolContainerConfiguration] = None + """容器配置 / Container configuration""" + + created_time: Optional[str] = None + """创建时间 / Creation time""" + + data_endpoint: Optional[str] = None + """数据链路端点 / Data endpoint""" + + description: Optional[str] = None + """描述 / Description""" + + environment_variables: Optional[Dict[str, str]] = None + """环境变量 / Environment variables""" + + gpu: Optional[str] = None + """GPU 配置 / GPU configuration""" + + internet_access: Optional[bool] = None + """是否允许公网访问 / Whether internet access is allowed""" + + last_modified_time: Optional[str] = None + """最后修改时间 / Last modified time""" + + log_configuration: Optional[ToolLogConfiguration] = None + """日志配置 / Log configuration""" + + mcp_config: Optional[McpConfig] = None + """MCP 配置 / MCP configuration""" + + memory: Optional[int] = None + """内存大小(MB) / Memory size in MB""" + + name: Optional[str] = None + """工具名称 / Tool name""" + + network_config: Optional[ToolNetworkConfiguration] = None + """网络配置 / Network configuration""" + + oss_mount_config: Optional[ToolOSSMountConfig] = None + """OSS 挂载配置 / OSS mount configuration""" + + protocol_spec: Optional[str] = None + """协议规格(OpenAPI JSON 字符串) / Protocol spec (OpenAPI JSON string)""" + + protocol_type: Optional[str] = None + """协议类型 / Protocol type""" + + status: Optional[str] = None + """状态 / Status""" + + timeout: Optional[int] = None + """超时时间(秒) / Timeout in seconds""" + + tool_id: Optional[str] = None + """工具 ID / Tool ID""" + + tool_name: Optional[str] = None + """工具名称 / Tool name""" + + tool_type: Optional[str] = None + """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + + version_id: Optional[str] = None + """版本 ID / Version ID""" + + @classmethod + def __get_client(cls, config: Optional[Config] = None): + from .client import ToolClient + + return ToolClient(config) + + @classmethod + async def get_by_name_async( + cls, name: str, config: Optional[Config] = None + ) -> "Tool": + """异步通过名称获取工具 / Get tool by name asynchronously""" + cli = cls.__get_client(config) + return await cli.get_async(name=name) + + async def get_async(self, config: Optional[Config] = None) -> "Tool": + """异步刷新工具信息 / Refresh tool info asynchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = await self.get_by_name_async( + name=effective_name, config=config + ) + return self.update_self(result) + + def _get_functioncall_server_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 FunctionCall 工具的 fallback server URL / Get fallback server URL for FunctionCall tools + + 当 OpenAPI spec 中没有 servers 字段时,使用 data_endpoint 构造 URL。 + Constructs URL from data_endpoint when servers is not present in OpenAPI spec. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}" + + def _get_tool_type(self) -> Optional[ToolType]: + """获取工具类型 / Get tool type""" + raw_type = self.tool_type + if raw_type: + try: + return ToolType(raw_type) + except ValueError: + return None + return None + + def _get_mcp_endpoint( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 MCP 数据链路 URL / Get MCP data endpoint URL + + 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 + 如果 self.data_endpoint 为空,则从 Config 中获取。 + Determines /mcp or /sse path based on session_affinity. + Falls back to Config if self.data_endpoint is not set. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + if session_affinity == "MCP_STREAMABLE": + return f"{data_endpoint}/tools/{effective_name}/mcp" + return f"{data_endpoint}/tools/{effective_name}/sse" + + async def list_tools_async( + self, config: Optional[Config] = None + ) -> List[ToolInfo]: + """异步获取子工具列表 / Get sub-tool list asynchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return await session.list_tools_async() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return await openapi_client.list_tools_async() + + return [] + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """异步调用子工具 / Call sub-tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = await session.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = await openapi_client.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") diff --git a/agentrun/tool/api/__init__.py b/agentrun/tool/api/__init__.py new file mode 100644 index 0000000..fd1b7de --- /dev/null +++ b/agentrun/tool/api/__init__.py @@ -0,0 +1 @@ +"""Tool API 模块 / Tool API Module""" diff --git a/agentrun/tool/api/control.py b/agentrun/tool/api/control.py new file mode 100644 index 0000000..b2630ba --- /dev/null +++ b/agentrun/tool/api/control.py @@ -0,0 +1,128 @@ +"""Tool 管控链路 API / Tool Control API + +通过底层 agentrun20250910 SDK 与平台交互,获取 Tool 资源。 +Interacts with the platform via the agentrun20250910 SDK to get Tool resources. +""" + +from typing import Dict, Optional + +from alibabacloud_agentrun20250910.models import Tool as InnerTool +from alibabacloud_tea_openapi.exceptions._client import ClientException +from alibabacloud_tea_openapi.exceptions._server import ServerException +from darabonba.runtime import RuntimeOptions +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.control_api import ControlAPI +from agentrun.utils.exception import ClientError, ServerError +from agentrun.utils.log import logger + + +class ToolControlAPI(ControlAPI): + """Tool 管控链路 API / Tool Control API""" + + def __init__(self, config: Optional[Config] = None): + """初始化 API 客户端 / Initialize API client + + Args: + config: 全局配置对象 / Global configuration object + """ + super().__init__(config) + + def get_tool( + self, + name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> InnerTool: + """获取工具 / Get tool + + Args: + name: Tool 名称 / Tool name + headers: 请求头 / Request headers + config: 配置 / Configuration + + Returns: + InnerTool: 底层 SDK 的 Tool 对象 / Inner SDK Tool object + + Raises: + ClientError: 客户端错误 / Client error + ServerError: 服务器错误 / Server error + """ + try: + client = self._get_client(config) + response = client.get_tool_with_options( + name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_tool, request Request ID:" + f" {response.headers['x-acs-request-id'] if response.headers else ''}\n" + f" request: {[name]}\n response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[name], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def get_tool_async( + self, + name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> InnerTool: + """异步获取工具 / Get tool asynchronously + + Args: + name: Tool 名称 / Tool name + headers: 请求头 / Request headers + config: 配置 / Configuration + + Returns: + InnerTool: 底层 SDK 的 Tool 对象 / Inner SDK Tool object + + Raises: + ClientError: 客户端错误 / Client error + ServerError: 服务器错误 / Server error + """ + try: + client = self._get_client(config) + response = await client.get_tool_with_options_async( + name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_tool, request Request ID:" + f" {response.headers['x-acs-request-id'] if response.headers else ''}\n" + f" request: {[name]}\n response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[name], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e diff --git a/agentrun/tool/api/mcp.py b/agentrun/tool/api/mcp.py new file mode 100644 index 0000000..a0cef61 --- /dev/null +++ b/agentrun/tool/api/mcp.py @@ -0,0 +1,167 @@ +"""Tool MCP 数据链路 / Tool MCP Data API + +通过 MCP 协议与 Tool 的数据链路交互,支持 SSE 和 Streamable HTTP 两种传输方式。 +Interacts with Tool data endpoints via MCP protocol, supporting SSE and Streamable HTTP transports. +""" + +import asyncio +from typing import Any, Dict, List, Optional + +from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.log import logger + + +class ToolMCPSession: + """Tool MCP 会话管理 / Tool MCP Session Manager + + 独立实现的 MCP 会话管理,支持 SSE 和 Streamable HTTP 两种传输方式。 + Independent MCP session manager supporting SSE and Streamable HTTP transports. + """ + + def __init__( + self, + endpoint: str, + session_affinity: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): + """初始化 MCP 会话 / Initialize MCP session + + Args: + endpoint: MCP 数据链路 URL / MCP data endpoint URL + session_affinity: 会话亲和性策略 / Session affinity strategy + headers: 请求头 / Request headers + """ + self.endpoint = endpoint + self.session_affinity = session_affinity + self.headers = headers or {} + + @property + def is_streamable(self) -> bool: + """是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport""" + return self.session_affinity == "MCP_STREAMABLE" + + async def list_tools_async(self) -> List[ToolInfo]: + """异步获取工具列表 / Get tool list asynchronously + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + try: + from mcp import ClientSession + + if self.is_streamable: + from mcp.client.streamable_http import streamablehttp_client + + async with streamablehttp_client( + self.endpoint, headers=self.headers + ) as (read_stream, write_stream, _): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.list_tools() + return [ + ToolInfo.from_mcp_tool(tool) + for tool in result.tools + ] + else: + from mcp.client.sse import sse_client + + async with sse_client(self.endpoint, headers=self.headers) as ( + read_stream, + write_stream, + ): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.list_tools() + return [ + ToolInfo.from_mcp_tool(tool) + for tool in result.tools + ] + except ImportError: + logger.warning( + "mcp package is not installed. Install it with: pip install mcp" + ) + return [] + + def list_tools(self) -> List[ToolInfo]: + """同步获取工具列表 / Get tool list synchronously + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + return asyncio.get_event_loop().run_until_complete( + self.list_tools_async() + ) + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """异步调用工具 / Call tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + + Returns: + Any: 工具执行结果 / Tool execution result + """ + try: + from mcp import ClientSession + + if self.is_streamable: + from mcp.client.streamable_http import streamablehttp_client + + async with streamablehttp_client( + self.endpoint, headers=self.headers + ) as (read_stream, write_stream, _): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.call_tool( + name, arguments=arguments or {} + ) + return result + else: + from mcp.client.sse import sse_client + + async with sse_client(self.endpoint, headers=self.headers) as ( + read_stream, + write_stream, + ): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.call_tool( + name, arguments=arguments or {} + ) + return result + except ImportError: + raise ImportError( + "mcp package is required for MCP tool calls. " + "Install it with: pip install mcp" + ) + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """同步调用工具 / Call tool synchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + + Returns: + Any: 工具执行结果 / Tool execution result + """ + return asyncio.get_event_loop().run_until_complete( + self.call_tool_async(name, arguments) + ) diff --git a/agentrun/tool/api/openapi.py b/agentrun/tool/api/openapi.py new file mode 100644 index 0000000..5873c7f --- /dev/null +++ b/agentrun/tool/api/openapi.py @@ -0,0 +1,337 @@ +"""Tool OpenAPI 数据链路 / Tool OpenAPI Data API + +解析 FunctionCall 类型 Tool 的 protocol_spec(OpenAPI JSON), +提取 operations 转换为 ToolInfo 列表,并通过 Server URL 发起 HTTP 调用。 +Parses protocol_spec (OpenAPI JSON) for FunctionCall type Tools, +extracts operations as ToolInfo list, and makes HTTP calls via Server URL. +""" + +import json +from typing import Any, Dict, List, Optional + +import httpx + +from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.log import logger + + +class ToolOpenAPIClient: + """FunctionCall 类型 Tool 的 OpenAPI 客户端 / OpenAPI Client for FunctionCall Tools + + 解析 protocol_spec 中的 OpenAPI Schema,提供 list_tools 和 call_tool 能力。 + Parses OpenAPI Schema from protocol_spec, provides list_tools and call_tool capabilities. + """ + + def __init__( + self, + protocol_spec: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + fallback_server_url: Optional[str] = None, + ): + """初始化 OpenAPI 客户端 / Initialize OpenAPI client + + Args: + protocol_spec: OpenAPI JSON 字符串 / OpenAPI JSON string + headers: 请求头 / Request headers + fallback_server_url: 当 OpenAPI spec 中没有 servers 时的备用 URL / + Fallback URL when servers is not present in OpenAPI spec + """ + self.headers = headers or {} + self._fallback_server_url = fallback_server_url + self._spec: Optional[Dict[str, Any]] = None + self._operations: Optional[List[Dict[str, Any]]] = None + + if protocol_spec: + try: + self._spec = json.loads(protocol_spec) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse protocol_spec as JSON") + + @property + def server_url(self) -> Optional[str]: + """获取 OpenAPI Schema 中的 Server URL / Get Server URL from OpenAPI Schema + + 优先从 spec.servers 获取,如果不存在则使用 fallback_server_url。 + Prefers spec.servers, falls back to fallback_server_url if not present. + """ + if self._spec: + servers = self._spec.get("servers", []) + if servers and isinstance(servers, list): + url = servers[0].get("url") + if url: + return url + return self._fallback_server_url + + def _resolve_ref(self, ref: str) -> Dict[str, Any]: + """解析 $ref 引用 / Resolve $ref reference + + 支持 JSON Pointer 格式的本地引用,如 #/components/schemas/WeatherRequest。 + Supports local JSON Pointer references like #/components/schemas/WeatherRequest. + + Args: + ref: $ref 字符串 / $ref string + + Returns: + 解析后的 schema 字典 / Resolved schema dict + """ + if not self._spec or not ref.startswith("#/"): + return {} + + parts = ref[2:].split("/") + current: Any = self._spec + for part in parts: + if isinstance(current, dict): + current = current.get(part, {}) + else: + return {} + return current if isinstance(current, dict) else {} + + def _resolve_schema( + self, schema: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """递归解析 schema 中的所有 $ref 引用 / Recursively resolve all $ref in schema + + Args: + schema: 可能包含 $ref 的 schema / Schema that may contain $ref + + Returns: + 解析后的完整 schema / Fully resolved schema + """ + if not schema or not isinstance(schema, dict): + return schema + + if "$ref" in schema: + resolved = self._resolve_ref(schema["$ref"]) + return self._resolve_schema(resolved) + + result = {} + for key, value in schema.items(): + if key == "properties" and isinstance(value, dict): + result[key] = { + prop_name: self._resolve_schema(prop_schema) or prop_schema + for prop_name, prop_schema in value.items() + } + elif key in ("items", "additionalProperties") and isinstance( + value, dict + ): + result[key] = self._resolve_schema(value) or value + elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list): + result[key] = [ + self._resolve_schema(item) or item for item in value + ] + else: + result[key] = value + + return result + + def _parse_operations(self) -> List[Dict[str, Any]]: + """解析 OpenAPI Schema 中的所有 operations / Parse all operations from OpenAPI Schema""" + if self._operations is not None: + return self._operations + + self._operations = [] + if not self._spec: + return self._operations + + paths = self._spec.get("paths", {}) + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + for method in ("get", "post", "put", "delete", "patch"): + operation = path_item.get(method) + if not operation or not isinstance(operation, dict): + continue + + operation_id = operation.get("operationId", f"{method}_{path}") + summary = operation.get("summary", "") + description = operation.get("description", "") + + request_body_schema = None + request_body = operation.get("requestBody", {}) + if isinstance(request_body, dict): + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + raw_schema = json_content.get("schema") + request_body_schema = self._resolve_schema(raw_schema) + + parameters_schema = None + parameters = operation.get("parameters", []) + if parameters and isinstance(parameters, list): + props = {} + required_params = [] + for param in parameters: + if not isinstance(param, dict): + continue + param_name = param.get("name", "") + param_schema = param.get("schema", {"type": "string"}) + param_schema["description"] = param.get( + "description", "" + ) + props[param_name] = param_schema + if param.get("required"): + required_params.append(param_name) + if props: + parameters_schema = { + "type": "object", + "properties": props, + } + if required_params: + parameters_schema["required"] = required_params + + input_schema = request_body_schema or parameters_schema + + self._operations.append({ + "operation_id": operation_id, + "summary": summary, + "description": description, + "method": method.upper(), + "path": path, + "input_schema": input_schema, + }) + + return self._operations + + def list_tools(self) -> List[ToolInfo]: + """获取工具列表 / Get tool list + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + operations = self._parse_operations() + tools = [] + for operation in operations: + parameters = None + if operation.get("input_schema"): + parameters = ToolSchema.from_any_openapi_schema( + operation["input_schema"] + ) + + tool_description = operation["summary"] or operation["description"] + tools.append( + ToolInfo( + name=operation["operation_id"], + description=tool_description, + parameters=parameters + or ToolSchema(type="object", properties={}), + ) + ) + return tools + + async def list_tools_async(self) -> List[ToolInfo]: + """异步获取工具列表 / Get tool list asynchronously""" + return self.list_tools() + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """调用工具 / Call tool + + Args: + name: operationId / Operation ID + arguments: 调用参数 / Call arguments + + Returns: + Any: 调用结果 / Call result + + Raises: + ValueError: operation 不存在 / Operation not found + """ + operations = self._parse_operations() + target_operation = None + for operation in operations: + if operation["operation_id"] == name: + target_operation = operation + break + + if not target_operation: + raise ValueError( + f"Operation '{name}' not found in OpenAPI spec. Available" + f" operations: {[op['operation_id'] for op in operations]}" + ) + + base_url = self.server_url + if not base_url: + raise ValueError("No server URL found in OpenAPI spec") + + url = f"{base_url.rstrip('/')}{target_operation['path']}" + method = target_operation["method"] + + logger.debug( + f"Calling FunctionCall tool: {method} {url} with args={arguments}" + ) + + with httpx.Client(headers=self.headers, timeout=30.0) as client: + if method in ("POST", "PUT", "PATCH"): + response = client.request(method, url, json=arguments or {}) + else: + response = client.request(method, url, params=arguments or {}) + + response.raise_for_status() + + content_type = response.headers.get("content-type", "") + if "application/json" in content_type: + return response.json() + return response.text + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """异步调用工具 / Call tool asynchronously + + Args: + name: operationId / Operation ID + arguments: 调用参数 / Call arguments + + Returns: + Any: 调用结果 / Call result + + Raises: + ValueError: operation 不存在 / Operation not found + """ + operations = self._parse_operations() + target_operation = None + for operation in operations: + if operation["operation_id"] == name: + target_operation = operation + break + + if not target_operation: + raise ValueError( + f"Operation '{name}' not found in OpenAPI spec. Available" + f" operations: {[op['operation_id'] for op in operations]}" + ) + + base_url = self.server_url + if not base_url: + raise ValueError("No server URL found in OpenAPI spec") + + url = f"{base_url.rstrip('/')}{target_operation['path']}" + method = target_operation["method"] + + logger.debug( + f"Calling FunctionCall tool async: {method} {url} with" + f" args={arguments}" + ) + + async with httpx.AsyncClient( + headers=self.headers, timeout=30.0 + ) as client: + if method in ("POST", "PUT", "PATCH"): + response = await client.request( + method, url, json=arguments or {} + ) + else: + response = await client.request( + method, url, params=arguments or {} + ) + + response.raise_for_status() + + content_type = response.headers.get("content-type", "") + if "application/json" in content_type: + return response.json() + return response.text diff --git a/agentrun/tool/client.py b/agentrun/tool/client.py new file mode 100644 index 0000000..048de3b --- /dev/null +++ b/agentrun/tool/client.py @@ -0,0 +1,87 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/tool/__client_async_template.py + +Tool 客户端 / Tool Client + +此模块提供工具的客户端 API。 +This module provides the client API for tools. +""" + +from typing import Any, Dict, List, Optional + +from agentrun.tool.api.control import ToolControlAPI +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .tool import Tool + + +class ToolClient: + """Tool 客户端 / Tool Client + + 提供工具的获取功能。 + Provides get function for tools. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = ToolControlAPI(config) + + async def get_async( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """异步获取工具 / Get tool asynchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = await self.__control_api.get_tool_async( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) + + def get( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """同步获取工具 / Get tool synchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = self.__control_api.get_tool( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) diff --git a/agentrun/tool/model.py b/agentrun/tool/model.py new file mode 100644 index 0000000..994bc51 --- /dev/null +++ b/agentrun/tool/model.py @@ -0,0 +1,406 @@ +"""Tool 模型定义 / Tool Model Definitions + +定义工具相关的数据模型和枚举。 +Defines data models and enumerations related to tools. +""" + +from enum import Enum +from typing import Any, Dict, List, Optional + +from agentrun.utils.model import BaseModel + + +class ToolType(str, Enum): + """工具类型 / Tool Type""" + + MCP = "MCP" + """MCP 协议工具 / MCP Protocol Tool""" + FUNCTIONCALL = "FUNCTIONCALL" + """函数调用工具 / Function Call Tool""" + + +class McpConfig(BaseModel): + """MCP 工具配置 / MCP Tool Configuration + + 包含 MCP 工具的会话亲和性、代理配置等信息。 + Contains session affinity, proxy configuration, etc. for MCP tools. + """ + + session_affinity: Optional[str] = None + """会话亲和性策略 / Session affinity strategy + NONE: 无亲和性 / No affinity + MCP_SSE: 基于 SSE 的会话亲和性 / SSE-based session affinity + MCP_STREAMABLE: 基于流式 HTTP 的会话亲和性 / Streamable HTTP-based session affinity + """ + + session_affinity_config: Optional[str] = None + """会话亲和性的详细配置,JSON 格式字符串 / Session affinity config, JSON string""" + + proxy_enabled: Optional[bool] = None + """是否启用 MCP 代理 / Whether MCP proxy is enabled""" + + bound_configuration: Optional[Dict[str, Any]] = None + """工具的绑定配置 / Tool binding configuration""" + + mcp_proxy_configuration: Optional[Dict[str, Any]] = None + """MCP 代理的详细配置 / MCP proxy detailed configuration""" + + +class ToolCodeConfiguration(BaseModel): + """代码包配置 / Code Configuration + + 代码包类型工具的配置信息。 + Configuration for code-package type tools. + """ + + code_checksum: Optional[str] = None + """代码校验和 / Code checksum""" + + code_size: Optional[int] = None + """代码大小(字节)/ Code size in bytes""" + + command: Optional[List[str]] = None + """启动命令 / Startup command""" + + language: Optional[str] = None + """编程语言 / Programming language""" + + oss_bucket_name: Optional[str] = None + """OSS 存储桶名称 / OSS bucket name""" + + oss_object_name: Optional[str] = None + """OSS 对象名称 / OSS object name""" + + +class ToolContainerConfiguration(BaseModel): + """容器配置 / Container Configuration + + 容器类型工具的配置信息。 + Configuration for container type tools. + """ + + args: Optional[List[str]] = None + """容器启动参数 / Container startup arguments""" + + command: Optional[List[str]] = None + """容器启动命令 / Container startup command""" + + image: Optional[str] = None + """容器镜像地址 / Container image URL""" + + port: Optional[int] = None + """容器端口 / Container port""" + + +class ToolLogConfiguration(BaseModel): + """日志配置 / Log Configuration + + 工具的日志配置信息。 + Log configuration for tools. + """ + + log_store: Optional[str] = None + """SLS 日志库 / SLS log store""" + + project: Optional[str] = None + """SLS 项目 / SLS project""" + + +class ToolNASConfig(BaseModel): + """NAS 文件存储配置 / NAS Configuration + + 工具访问 NAS 文件系统的配置。 + Configuration for tool access to NAS file system. + """ + + group_id: Optional[int] = None + """组 ID / Group ID""" + + mount_points: Optional[List[Dict[str, Any]]] = None + """挂载点列表 / Mount points list""" + + user_id: Optional[int] = None + """用户 ID / User ID""" + + +class ToolNetworkConfiguration(BaseModel): + """网络配置 / Network Configuration + + 工具的网络配置信息。 + Network configuration for tools. + """ + + security_group_id: Optional[str] = None + """安全组 ID / Security group ID""" + + vpc_id: Optional[str] = None + """VPC ID""" + + vswitch_ids: Optional[List[str]] = None + """交换机 ID 列表 / VSwitch IDs""" + + +class ToolOSSMountConfig(BaseModel): + """OSS 挂载配置 / OSS Mount Configuration + + 工具访问 OSS 存储的挂载配置。 + Configuration for tool access to OSS storage. + """ + + mount_points: Optional[List[Dict[str, Any]]] = None + """挂载点列表 / Mount points list""" + + +class ToolSchema(BaseModel): + """JSON Schema 兼容的工具参数描述 / JSON Schema Compatible Tool Parameter Description + + 支持完整的 JSON Schema 字段,能够描述复杂的嵌套数据结构。 + Supports full JSON Schema fields for describing complex nested data structures. + """ + + type: Optional[str] = None + """数据类型 / Data type""" + + description: Optional[str] = None + """描述信息 / Description""" + + title: Optional[str] = None + """标题 / Title""" + + properties: Optional[Dict[str, "ToolSchema"]] = None + """对象属性 / Object properties""" + + required: Optional[List[str]] = None + """必填字段 / Required fields""" + + additional_properties: Optional[bool] = None + """是否允许额外属性 / Whether additional properties are allowed""" + + items: Optional["ToolSchema"] = None + """数组元素类型 / Array item type""" + + min_items: Optional[int] = None + """数组最小长度 / Minimum array length""" + + max_items: Optional[int] = None + """数组最大长度 / Maximum array length""" + + pattern: Optional[str] = None + """字符串正则模式 / String regex pattern""" + + min_length: Optional[int] = None + """字符串最小长度 / Minimum string length""" + + max_length: Optional[int] = None + """字符串最大长度 / Maximum string length""" + + format: Optional[str] = None + """字符串格式 / String format (date, date-time, email, uri, etc.)""" + + enum: Optional[List[Any]] = None + """枚举值 / Enum values""" + + minimum: Optional[float] = None + """数值最小值 / Minimum numeric value""" + + maximum: Optional[float] = None + """数值最大值 / Maximum numeric value""" + + exclusive_minimum: Optional[float] = None + """数值排他最小值 / Exclusive minimum numeric value""" + + exclusive_maximum: Optional[float] = None + """数值排他最大值 / Exclusive maximum numeric value""" + + any_of: Optional[List["ToolSchema"]] = None + """任一匹配 / Any of""" + + one_of: Optional[List["ToolSchema"]] = None + """唯一匹配 / One of""" + + all_of: Optional[List["ToolSchema"]] = None + """全部匹配 / All of""" + + default: Optional[Any] = None + """默认值 / Default value""" + + @classmethod + def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": + """从任意 OpenAPI/JSON Schema 创建 ToolSchema / Create ToolSchema from any OpenAPI/JSON Schema + + 递归解析所有嵌套结构,保留完整的 schema 信息。 + Recursively parses all nested structures, preserving complete schema information. + """ + if not schema or not isinstance(schema, dict): + return cls(type="string") + + from pydash import get as pydash_get + + properties_raw = pydash_get(schema, "properties", {}) + properties = ( + { + key: cls.from_any_openapi_schema(value) + for key, value in properties_raw.items() + } + if properties_raw + else None + ) + + items_raw = pydash_get(schema, "items") + items = cls.from_any_openapi_schema(items_raw) if items_raw else None + + any_of_raw = pydash_get(schema, "anyOf") + any_of = ( + [cls.from_any_openapi_schema(s) for s in any_of_raw] + if any_of_raw + else None + ) + + one_of_raw = pydash_get(schema, "oneOf") + one_of = ( + [cls.from_any_openapi_schema(s) for s in one_of_raw] + if one_of_raw + else None + ) + + all_of_raw = pydash_get(schema, "allOf") + all_of = ( + [cls.from_any_openapi_schema(s) for s in all_of_raw] + if all_of_raw + else None + ) + + return cls( + type=pydash_get(schema, "type"), + description=pydash_get(schema, "description"), + title=pydash_get(schema, "title"), + properties=properties, + required=pydash_get(schema, "required"), + additional_properties=pydash_get(schema, "additionalProperties"), + items=items, + min_items=pydash_get(schema, "minItems"), + max_items=pydash_get(schema, "maxItems"), + pattern=pydash_get(schema, "pattern"), + min_length=pydash_get(schema, "minLength"), + max_length=pydash_get(schema, "maxLength"), + format=pydash_get(schema, "format"), + enum=pydash_get(schema, "enum"), + minimum=pydash_get(schema, "minimum"), + maximum=pydash_get(schema, "maximum"), + exclusive_minimum=pydash_get(schema, "exclusiveMinimum"), + exclusive_maximum=pydash_get(schema, "exclusiveMaximum"), + any_of=any_of, + one_of=one_of, + all_of=all_of, + default=pydash_get(schema, "default"), + ) + + def to_json_schema(self) -> Dict[str, Any]: + """转换为标准 JSON Schema 格式 / Convert to standard JSON Schema format""" + result: Dict[str, Any] = {} + + if self.type: + result["type"] = self.type + if self.description: + result["description"] = self.description + if self.title: + result["title"] = self.title + + if self.properties: + result["properties"] = { + k: v.to_json_schema() for k, v in self.properties.items() + } + if self.required: + result["required"] = self.required + if self.additional_properties is not None: + result["additionalProperties"] = self.additional_properties + + if self.items: + result["items"] = self.items.to_json_schema() + if self.min_items is not None: + result["minItems"] = self.min_items + if self.max_items is not None: + result["maxItems"] = self.max_items + + if self.pattern: + result["pattern"] = self.pattern + if self.min_length is not None: + result["minLength"] = self.min_length + if self.max_length is not None: + result["maxLength"] = self.max_length + if self.format: + result["format"] = self.format + if self.enum: + result["enum"] = self.enum + + if self.minimum is not None: + result["minimum"] = self.minimum + if self.maximum is not None: + result["maximum"] = self.maximum + if self.exclusive_minimum is not None: + result["exclusiveMinimum"] = self.exclusive_minimum + if self.exclusive_maximum is not None: + result["exclusiveMaximum"] = self.exclusive_maximum + + if self.any_of: + result["anyOf"] = [s.to_json_schema() for s in self.any_of] + if self.one_of: + result["oneOf"] = [s.to_json_schema() for s in self.one_of] + if self.all_of: + result["allOf"] = [s.to_json_schema() for s in self.all_of] + + if self.default is not None: + result["default"] = self.default + + return result + + +class ToolInfo(BaseModel): + """工具信息 / Tool Information + + 描述单个工具的名称、描述和参数 schema。 + Describes a single tool's name, description, and parameter schema. + """ + + name: Optional[str] = None + """工具名称 / Tool name""" + + description: Optional[str] = None + """工具描述 / Tool description""" + + parameters: Optional[ToolSchema] = None + """工具参数 schema / Tool parameter schema""" + + @classmethod + def from_mcp_tool(cls, tool: Any) -> "ToolInfo": + """从 MCP tool 创建 ToolInfo / Create ToolInfo from MCP tool""" + if hasattr(tool, "name"): + tool_name = tool.name + tool_description = getattr(tool, "description", None) + input_schema = getattr(tool, "inputSchema", None) or getattr( + tool, "input_schema", None + ) + elif isinstance(tool, dict): + tool_name = tool.get("name") + tool_description = tool.get("description") + input_schema = tool.get("inputSchema") or tool.get("input_schema") + else: + raise ValueError(f"Unsupported MCP tool format: {type(tool)}") + + if not tool_name: + raise ValueError("MCP tool must have a name") + + parameters = None + if input_schema: + if isinstance(input_schema, dict): + parameters = ToolSchema.from_any_openapi_schema(input_schema) + elif hasattr(input_schema, "model_dump"): + parameters = ToolSchema.from_any_openapi_schema( + input_schema.model_dump() + ) + + return cls( + name=tool_name, + description=tool_description, + parameters=parameters or ToolSchema(type="object", properties={}), + ) diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py new file mode 100644 index 0000000..bdfaa5d --- /dev/null +++ b/agentrun/tool/tool.py @@ -0,0 +1,440 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/tool/__tool_async_template.py + +Tool 资源类 / Tool Resource Class + +提供工具资源的面向对象封装和完整生命周期管理。 +Provides object-oriented wrapper and complete lifecycle management for tool resources. +""" + +from typing import Any, Dict, List, Optional + +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import BaseModel + +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class Tool(BaseModel): + """工具资源 / Tool Resource + + 提供工具的查询、调用等功能。 + Provides query, invocation and other functionality for tools. + + Attributes: + code_configuration: 代码包配置 / Code configuration + container_configuration: 容器配置 / Container configuration + created_time: 创建时间 / Creation time + data_endpoint: 数据链路端点 / Data endpoint + description: 描述 / Description + environment_variables: 环境变量 / Environment variables + gpu: GPU 配置 / GPU configuration + internet_access: 是否允许公网访问 / Whether internet access is allowed + last_modified_time: 最后修改时间 / Last modified time + log_configuration: 日志配置 / Log configuration + mcp_config: MCP 配置 / MCP configuration + memory: 内存大小(MB) / Memory size in MB + name: 工具名称 / Tool name + network_config: 网络配置 / Network configuration + oss_mount_config: OSS 挂载配置 / OSS mount configuration + protocol_spec: 协议规格(OpenAPI JSON) / Protocol spec (OpenAPI JSON) + protocol_type: 协议类型 / Protocol type + status: 状态 / Status + timeout: 超时时间(秒) / Timeout in seconds + tool_id: 工具 ID / Tool ID + tool_name: 工具名称 / Tool name + tool_type: 工具类型(MCP/FUNCTIONCALL) / Tool type + version_id: 版本 ID / Version ID + """ + + code_configuration: Optional[ToolCodeConfiguration] = None + """代码包配置 / Code configuration""" + + container_configuration: Optional[ToolContainerConfiguration] = None + """容器配置 / Container configuration""" + + created_time: Optional[str] = None + """创建时间 / Creation time""" + + data_endpoint: Optional[str] = None + """数据链路端点 / Data endpoint""" + + description: Optional[str] = None + """描述 / Description""" + + environment_variables: Optional[Dict[str, str]] = None + """环境变量 / Environment variables""" + + gpu: Optional[str] = None + """GPU 配置 / GPU configuration""" + + internet_access: Optional[bool] = None + """是否允许公网访问 / Whether internet access is allowed""" + + last_modified_time: Optional[str] = None + """最后修改时间 / Last modified time""" + + log_configuration: Optional[ToolLogConfiguration] = None + """日志配置 / Log configuration""" + + mcp_config: Optional[McpConfig] = None + """MCP 配置 / MCP configuration""" + + memory: Optional[int] = None + """内存大小(MB) / Memory size in MB""" + + name: Optional[str] = None + """工具名称 / Tool name""" + + network_config: Optional[ToolNetworkConfiguration] = None + """网络配置 / Network configuration""" + + oss_mount_config: Optional[ToolOSSMountConfig] = None + """OSS 挂载配置 / OSS mount configuration""" + + protocol_spec: Optional[str] = None + """协议规格(OpenAPI JSON 字符串) / Protocol spec (OpenAPI JSON string)""" + + protocol_type: Optional[str] = None + """协议类型 / Protocol type""" + + status: Optional[str] = None + """状态 / Status""" + + timeout: Optional[int] = None + """超时时间(秒) / Timeout in seconds""" + + tool_id: Optional[str] = None + """工具 ID / Tool ID""" + + tool_name: Optional[str] = None + """工具名称 / Tool name""" + + tool_type: Optional[str] = None + """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + + version_id: Optional[str] = None + """版本 ID / Version ID""" + + @classmethod + def __get_client(cls, config: Optional[Config] = None): + from .client import ToolClient + + return ToolClient(config) + + @classmethod + async def get_by_name_async( + cls, name: str, config: Optional[Config] = None + ) -> "Tool": + """异步通过名称获取工具 / Get tool by name asynchronously""" + cli = cls.__get_client(config) + return await cli.get_async(name=name) + + @classmethod + def get_by_name(cls, name: str, config: Optional[Config] = None) -> "Tool": + """同步通过名称获取工具 / Get tool by name synchronously""" + cli = cls.__get_client(config) + return cli.get(name=name) + + async def get_async(self, config: Optional[Config] = None) -> "Tool": + """异步刷新工具信息 / Refresh tool info asynchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = await self.get_by_name_async( + name=effective_name, config=config + ) + return self.update_self(result) + + def get(self, config: Optional[Config] = None) -> "Tool": + """同步刷新工具信息 / Refresh tool info synchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = self.get_by_name(name=effective_name, config=config) + return self.update_self(result) + + def _get_functioncall_server_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 FunctionCall 工具的 fallback server URL / Get fallback server URL for FunctionCall tools + + 当 OpenAPI spec 中没有 servers 字段时,使用 data_endpoint 构造 URL。 + Constructs URL from data_endpoint when servers is not present in OpenAPI spec. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}" + + def _get_tool_type(self) -> Optional[ToolType]: + """获取工具类型 / Get tool type""" + raw_type = self.tool_type + if raw_type: + try: + return ToolType(raw_type) + except ValueError: + return None + return None + + def _get_mcp_endpoint( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 MCP 数据链路 URL / Get MCP data endpoint URL + + 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 + 如果 self.data_endpoint 为空,则从 Config 中获取。 + Determines /mcp or /sse path based on session_affinity. + Falls back to Config if self.data_endpoint is not set. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + if session_affinity == "MCP_STREAMABLE": + return f"{data_endpoint}/tools/{effective_name}/mcp" + return f"{data_endpoint}/tools/{effective_name}/sse" + + async def list_tools_async( + self, config: Optional[Config] = None + ) -> List[ToolInfo]: + """异步获取子工具列表 / Get sub-tool list asynchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return await session.list_tools_async() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return await openapi_client.list_tools_async() + + return [] + + def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: + """同步获取子工具列表 / Get sub-tool list synchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return session.list_tools() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return openapi_client.list_tools() + + return [] + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """异步调用子工具 / Call sub-tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = await session.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = await openapi_client.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """同步调用子工具 / Call sub-tool synchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = session.call_tool(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = openapi_client.call_tool(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") diff --git a/pyproject.toml b/pyproject.toml index b3950bb..2c5c075 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "litellm>=1.79.3", "alibabacloud-devs20230714>=2.4.1", "pydash>=8.0.5", - "alibabacloud-agentrun20250910>=5.3.1", + "alibabacloud-agentrun20250910>=5.6.0", "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", diff --git a/tests/unittests/tool/__init__.py b/tests/unittests/tool/__init__.py new file mode 100644 index 0000000..745294f --- /dev/null +++ b/tests/unittests/tool/__init__.py @@ -0,0 +1,5 @@ +"""Tool 模块单元测试 / Tool Module Unit Tests + +测试 tool 模块中数据模型、API 客户端和资源类的相关功能。 +Tests data models, API clients, and resource classes in the tool module. +""" diff --git a/tests/unittests/tool/test_mcp.py b/tests/unittests/tool/test_mcp.py new file mode 100644 index 0000000..83dc3fa --- /dev/null +++ b/tests/unittests/tool/test_mcp.py @@ -0,0 +1,308 @@ +"""Tool MCP 会话单元测试 / Tool MCP Session Unit Tests + +测试 ToolMCPSession 的 MCP 协议交互功能。 +Tests MCP protocol interaction functionality of ToolMCPSession. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from agentrun.tool.api.mcp import ToolMCPSession +from agentrun.tool.model import ToolInfo + + +class TestToolMCPSessionInit: + """测试 ToolMCPSession 初始化 / Test ToolMCPSession initialization""" + + def test_init_with_defaults(self): + """测试使用默认参数初始化""" + session = ToolMCPSession(endpoint="http://example.com/mcp") + assert session.endpoint == "http://example.com/mcp" + assert session.session_affinity is None + assert session.headers == {} + + def test_init_with_all_parameters(self): + """测试使用所有参数初始化""" + headers = {"Authorization": "Bearer token"} + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + headers=headers, + ) + assert session.endpoint == "http://example.com/mcp" + assert session.session_affinity == "MCP_STREAMABLE" + assert session.headers == headers + + +class TestToolMCPSessionIsStreamable: + """测试 is_streamable 属性""" + + def test_is_streamable_returns_true_for_mcp_streamable(self): + """测试 MCP_STREAMABLE 返回 True""" + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + ) + assert session.is_streamable is True + + def test_is_streamable_returns_false_for_other_values(self): + """测试其他值返回 False""" + for value in [None, "MCP_SSE", "OTHER", ""]: + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity=value, + ) + assert session.is_streamable is False + + +def _make_mock_mcp_tool(name: str, description: str) -> MagicMock: + """创建 mock MCP tool 对象""" + tool = MagicMock() + tool.name = name + tool.description = description + tool.inputSchema = {"type": "object", "properties": {}} + return tool + + +def _setup_mock_mcp_modules( + mock_session: AsyncMock, +) -> dict: + """设置 mock mcp 模块,返回需要注入到 sys.modules 的字典""" + mock_client_session_cls = MagicMock() + mock_session_ctx = AsyncMock() + mock_session_ctx.__aenter__.return_value = mock_session + mock_session_ctx.__aexit__.return_value = None + mock_client_session_cls.return_value = mock_session_ctx + + # mock streamablehttp_client + mock_streamable_fn = MagicMock() + mock_streamable_ctx = AsyncMock() + mock_streamable_ctx.__aenter__.return_value = ( + AsyncMock(), + AsyncMock(), + MagicMock(), + ) + mock_streamable_ctx.__aexit__.return_value = None + mock_streamable_fn.return_value = mock_streamable_ctx + + # mock sse_client + mock_sse_fn = MagicMock() + mock_sse_ctx = AsyncMock() + mock_sse_ctx.__aenter__.return_value = (AsyncMock(), AsyncMock()) + mock_sse_ctx.__aexit__.return_value = None + mock_sse_fn.return_value = mock_sse_ctx + + mock_mcp = MagicMock() + mock_mcp.ClientSession = mock_client_session_cls + + mock_mcp_client_streamable = MagicMock() + mock_mcp_client_streamable.streamablehttp_client = mock_streamable_fn + + mock_mcp_client_sse = MagicMock() + mock_mcp_client_sse.sse_client = mock_sse_fn + + return { + "mcp": mock_mcp, + "mcp.client": MagicMock(), + "mcp.client.streamable_http": mock_mcp_client_streamable, + "mcp.client.sse": mock_mcp_client_sse, + } + + +class TestToolMCPSessionListToolsAsync: + """测试 list_tools_async 方法""" + + @pytest.mark.asyncio + async def test_list_tools_async_streamable_mode(self): + """测试 Streamable 模式下获取工具列表""" + mock_tool = _make_mock_mcp_tool("tool1", "Test tool 1") + mock_result = MagicMock() + mock_result.tools = [mock_tool] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=mock_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + headers={"Authorization": "Bearer token"}, + ) + tools = await session.list_tools_async() + + assert len(tools) == 1 + assert isinstance(tools[0], ToolInfo) + assert tools[0].name == "tool1" + + @pytest.mark.asyncio + async def test_list_tools_async_sse_mode(self): + """测试 SSE 模式下获取工具列表""" + mock_tool = _make_mock_mcp_tool("tool1", "Test tool 1") + mock_result = MagicMock() + mock_result.tools = [mock_tool] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=mock_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_SSE", + ) + tools = await session.list_tools_async() + + assert len(tools) == 1 + assert isinstance(tools[0], ToolInfo) + + @pytest.mark.asyncio + async def test_list_tools_async_import_error(self): + """测试 mcp 未安装时返回空列表""" + saved_modules = {} + modules_to_remove = [ + k for k in sys.modules if k == "mcp" or k.startswith("mcp.") + ] + for key in modules_to_remove: + saved_modules[key] = sys.modules.pop(key) + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ # type: ignore + + def mock_import(name, *args, **kwargs): + if name == "mcp" or name.startswith("mcp."): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + session = ToolMCPSession(endpoint="http://example.com/mcp") + tools = await session.list_tools_async() + + sys.modules.update(saved_modules) + assert tools == [] + + +class TestToolMCPSessionListTools: + """测试 list_tools 同步方法""" + + def test_list_tools_synchronous(self): + """测试同步获取工具列表""" + expected_tools = [ToolInfo(name="tool1", description="Test")] + + with patch.object( + ToolMCPSession, + "list_tools_async", + new_callable=AsyncMock, + return_value=expected_tools, + ): + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = MagicMock() + mock_loop.run_until_complete.return_value = expected_tools + mock_get_loop.return_value = mock_loop + + session = ToolMCPSession(endpoint="http://example.com/mcp") + tools = session.list_tools() + + assert tools == expected_tools + mock_loop.run_until_complete.assert_called_once() + + +class TestToolMCPSessionCallToolAsync: + """测试 call_tool_async 方法""" + + @pytest.mark.asyncio + async def test_call_tool_async_streamable_mode(self): + """测试 Streamable 模式下调用工具""" + mock_call_result = MagicMock() + mock_call_result.content = [{"type": "text", "text": "result"}] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_call_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + ) + result = await session.call_tool_async( + "test_tool", {"param": "value"} + ) + + assert result == mock_call_result + + @pytest.mark.asyncio + async def test_call_tool_async_sse_mode(self): + """测试 SSE 模式下调用工具""" + mock_call_result = MagicMock() + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_call_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_SSE", + ) + result = await session.call_tool_async("test_tool", {"key": "val"}) + + assert result == mock_call_result + + @pytest.mark.asyncio + async def test_call_tool_async_import_error(self): + """测试 mcp 未安装时抛出 ImportError""" + saved_modules = {} + modules_to_remove = [ + k for k in sys.modules if k == "mcp" or k.startswith("mcp.") + ] + for key in modules_to_remove: + saved_modules[key] = sys.modules.pop(key) + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ # type: ignore + + def mock_import(name, *args, **kwargs): + if name == "mcp" or name.startswith("mcp."): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + session = ToolMCPSession(endpoint="http://example.com/mcp") + with pytest.raises(ImportError): + await session.call_tool_async("test_tool") + + sys.modules.update(saved_modules) + + +class TestToolMCPSessionCallTool: + """测试 call_tool 同步方法""" + + def test_call_tool_synchronous(self): + """测试同步调用工具""" + expected_result = {"result": "success"} + + with patch.object( + ToolMCPSession, + "call_tool_async", + new_callable=AsyncMock, + return_value=expected_result, + ): + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = MagicMock() + mock_loop.run_until_complete.return_value = expected_result + mock_get_loop.return_value = mock_loop + + session = ToolMCPSession(endpoint="http://example.com/mcp") + result = session.call_tool("test_tool", {"param": "value"}) + + assert result == expected_result + mock_loop.run_until_complete.assert_called_once() diff --git a/tests/unittests/tool/test_model.py b/tests/unittests/tool/test_model.py new file mode 100644 index 0000000..b195d0e --- /dev/null +++ b/tests/unittests/tool/test_model.py @@ -0,0 +1,661 @@ +"""Tool 模型单元测试 / Tool Model Unit Tests + +测试 tool 模块中数据模型和工具 schema 的相关功能。 +Tests data models and tool schema functionality in the tool module. +""" + +import pytest + +from agentrun.tool.model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNASConfig, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class TestToolType: + """测试 ToolType 枚举""" + + def test_mcp_type(self): + """测试 MCP 类型""" + assert ToolType.MCP == "MCP" + assert ToolType.MCP.value == "MCP" + + def test_functioncall_type(self): + """测试 FUNCTIONCALL 类型""" + assert ToolType.FUNCTIONCALL == "FUNCTIONCALL" + assert ToolType.FUNCTIONCALL.value == "FUNCTIONCALL" + + +class TestMcpConfig: + """测试 McpConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = McpConfig() + assert config.session_affinity is None + assert config.session_affinity_config is None + assert config.proxy_enabled is None + assert config.bound_configuration is None + assert config.mcp_proxy_configuration is None + + def test_with_values(self): + """测试带值创建""" + config = McpConfig( + session_affinity="MCP_SSE", + session_affinity_config='{"key": "value"}', + proxy_enabled=True, + bound_configuration={"key": "value"}, + mcp_proxy_configuration={"proxy": "config"}, + ) + assert config.session_affinity == "MCP_SSE" + assert config.session_affinity_config == '{"key": "value"}' + assert config.proxy_enabled is True + assert config.bound_configuration == {"key": "value"} + assert config.mcp_proxy_configuration == {"proxy": "config"} + + +class TestToolCodeConfiguration: + """测试 ToolCodeConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolCodeConfiguration() + assert config.code_checksum is None + assert config.code_size is None + assert config.command is None + assert config.language is None + assert config.oss_bucket_name is None + assert config.oss_object_name is None + + def test_with_values(self): + """测试带值创建""" + config = ToolCodeConfiguration( + code_checksum="abc123", + code_size=1024, + command=["python", "app.py"], + language="python3.10", + oss_bucket_name="my-bucket", + oss_object_name="code.zip", + ) + assert config.code_checksum == "abc123" + assert config.code_size == 1024 + assert config.command == ["python", "app.py"] + assert config.language == "python3.10" + assert config.oss_bucket_name == "my-bucket" + assert config.oss_object_name == "code.zip" + + +class TestToolContainerConfiguration: + """测试 ToolContainerConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolContainerConfiguration() + assert config.args is None + assert config.command is None + assert config.image is None + assert config.port is None + + def test_with_values(self): + """测试带值创建""" + config = ToolContainerConfiguration( + args=["--arg1", "value1"], + command=["python", "app.py"], + image="registry.example.com/tool:latest", + port=8080, + ) + assert config.args == ["--arg1", "value1"] + assert config.command == ["python", "app.py"] + assert config.image == "registry.example.com/tool:latest" + assert config.port == 8080 + + +class TestToolLogConfiguration: + """测试 ToolLogConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolLogConfiguration() + assert config.log_store is None + assert config.project is None + + def test_with_values(self): + """测试带值创建""" + config = ToolLogConfiguration( + log_store="my-log-store", + project="my-project", + ) + assert config.log_store == "my-log-store" + assert config.project == "my-project" + + +class TestToolNASConfig: + """测试 ToolNASConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolNASConfig() + assert config.group_id is None + assert config.mount_points is None + assert config.user_id is None + + def test_with_values(self): + """测试带值创建""" + config = ToolNASConfig( + group_id=1001, + mount_points=[{"path": "/mnt/nas", "nas_id": "nas-123"}], + user_id=1000, + ) + assert config.group_id == 1001 + assert config.mount_points == [ + {"path": "/mnt/nas", "nas_id": "nas-123"} + ] + assert config.user_id == 1000 + + +class TestToolNetworkConfiguration: + """测试 ToolNetworkConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolNetworkConfiguration() + assert config.security_group_id is None + assert config.vpc_id is None + assert config.vswitch_ids is None + + def test_with_values(self): + """测试带值创建""" + config = ToolNetworkConfiguration( + security_group_id="sg-123", + vpc_id="vpc-456", + vswitch_ids=["vsw-789", "vsw-012"], + ) + assert config.security_group_id == "sg-123" + assert config.vpc_id == "vpc-456" + assert config.vswitch_ids == ["vsw-789", "vsw-012"] + + +class TestToolOSSMountConfig: + """测试 ToolOSSMountConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolOSSMountConfig() + assert config.mount_points is None + + def test_with_values(self): + """测试带值创建""" + config = ToolOSSMountConfig( + mount_points=[{ + "bucket": "my-bucket", + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + }] + ) + assert config.mount_points == [ + {"bucket": "my-bucket", "endpoint": "oss-cn-hangzhou.aliyuncs.com"} + ] + + +class TestToolSchema: + """测试 ToolSchema 模型""" + + def test_default_values(self): + """测试默认值""" + schema = ToolSchema() + assert schema.type is None + assert schema.description is None + assert schema.properties is None + assert schema.required is None + assert schema.items is None + assert schema.any_of is None + assert schema.one_of is None + assert schema.all_of is None + + def test_from_any_openapi_schema_simple(self): + """测试从简单 OpenAPI Schema 创建""" + openapi_schema = { + "type": "object", + "description": "A simple object", + "properties": { + "name": {"type": "string", "description": "Name field"}, + "age": {"type": "integer", "description": "Age field"}, + }, + "required": ["name"], + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "object" + assert schema.description == "A simple object" + assert schema.properties is not None + assert "name" in schema.properties + assert "age" in schema.properties + assert schema.required == ["name"] + + def test_from_any_openapi_schema_nested(self): + """测试从嵌套 OpenAPI Schema 创建""" + openapi_schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"}, + }, + } + }, + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "object" + assert schema.properties is not None + assert "user" in schema.properties + assert schema.properties["user"].type == "object" + assert schema.properties["user"].properties is not None + assert "name" in schema.properties["user"].properties + + def test_from_any_openapi_schema_empty(self): + """测试从空 Schema 创建""" + schema = ToolSchema.from_any_openapi_schema(None) + assert schema.type == "string" + + # Empty dict creates a ToolSchema with None type (pydash_get returns None for missing keys) + schema = ToolSchema.from_any_openapi_schema({}) + # Actually returns "string" due to the check at the beginning of the method + assert schema.type == "string" + + def test_from_any_openapi_schema_non_dict(self): + """测试从非 dict 输入创建""" + schema = ToolSchema.from_any_openapi_schema("invalid") + assert schema.type == "string" + + schema = ToolSchema.from_any_openapi_schema(123) + assert schema.type == "string" + + def test_from_any_openapi_schema_array(self): + """测试从数组 Schema 创建""" + openapi_schema = { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 10, + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "array" + assert schema.items is not None + assert schema.items.type == "string" + assert schema.min_items == 1 + assert schema.max_items == 10 + + def test_from_any_openapi_schema_anyof(self): + """测试 anyOf 支持""" + openapi_schema = { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.any_of is not None + assert len(schema.any_of) == 2 + assert schema.any_of[0].type == "string" + assert schema.any_of[1].type == "integer" + + def test_from_any_openapi_schema_oneof(self): + """测试 oneOf 支持""" + openapi_schema = { + "oneOf": [ + {"type": "string"}, + {"type": "boolean"}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.one_of is not None + assert len(schema.one_of) == 2 + assert schema.one_of[0].type == "string" + assert schema.one_of[1].type == "boolean" + + def test_from_any_openapi_schema_allof(self): + """测试 allOf 支持""" + openapi_schema = { + "allOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}}, + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.all_of is not None + assert len(schema.all_of) == 2 + assert schema.all_of[0].type == "object" + assert schema.all_of[1].type == "object" + + def test_to_json_schema_simple(self): + """测试转换为 JSON Schema - 简单情况""" + schema = ToolSchema( + type="string", + description="A string field", + min_length=1, + max_length=100, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["description"] == "A string field" + assert json_schema["minLength"] == 1 + assert json_schema["maxLength"] == 100 + + def test_to_json_schema_nested(self): + """测试转换为 JSON Schema - 嵌套情况""" + schema = ToolSchema( + type="object", + properties={ + "user": ToolSchema( + type="object", + properties={ + "name": ToolSchema(type="string"), + }, + ) + }, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert "properties" in json_schema + assert "user" in json_schema["properties"] + assert json_schema["properties"]["user"]["type"] == "object" + assert "properties" in json_schema["properties"]["user"] + + def test_to_json_schema_roundtrip(self): + """测试完整往返转换""" + openapi_schema = { + "type": "object", + "description": "Test schema", + "properties": { + "name": { + "type": "string", + "description": "Name", + "minLength": 1, + }, + "age": { + "type": "integer", + "minimum": 0, + "maximum": 150, + }, + }, + "required": ["name"], + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["description"] == "Test schema" + assert "properties" in json_schema + assert "name" in json_schema["properties"] + assert "age" in json_schema["properties"] + assert json_schema["required"] == ["name"] + + def test_to_json_schema_with_anyof(self): + """测试转换包含 anyOf 的 schema""" + schema = ToolSchema( + any_of=[ + ToolSchema(type="string"), + ToolSchema(type="integer"), + ] + ) + json_schema = schema.to_json_schema() + assert "anyOf" in json_schema + assert len(json_schema["anyOf"]) == 2 + assert json_schema["anyOf"][0]["type"] == "string" + assert json_schema["anyOf"][1]["type"] == "integer" + + def test_recursive_properties(self): + """测试递归嵌套 properties""" + schema = ToolSchema( + type="object", + properties={ + "level1": ToolSchema( + type="object", + properties={ + "level2": ToolSchema( + type="object", + properties={ + "level3": ToolSchema(type="string"), + }, + ), + }, + ), + }, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["properties"]["level1"]["type"] == "object" + assert ( + json_schema["properties"]["level1"]["properties"]["level2"]["type"] + == "object" + ) + assert ( + json_schema["properties"]["level1"]["properties"]["level2"][ + "properties" + ]["level3"]["type"] + == "string" + ) + + def test_to_json_schema_with_string_constraints(self): + """测试 pattern, min_length, max_length, format""" + schema = ToolSchema( + type="string", + pattern="^[a-zA-Z]+$", + min_length=1, + max_length=100, + format="email", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["pattern"] == "^[a-zA-Z]+$" + assert json_schema["minLength"] == 1 + assert json_schema["maxLength"] == 100 + assert json_schema["format"] == "email" + + def test_to_json_schema_with_number_constraints(self): + """测试 minimum, maximum, exclusive_minimum, exclusive_maximum""" + schema = ToolSchema( + type="number", + minimum=0, + maximum=100, + exclusive_minimum=0, + exclusive_maximum=100, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "number" + assert json_schema["minimum"] == 0 + assert json_schema["maximum"] == 100 + assert json_schema["exclusiveMinimum"] == 0 + assert json_schema["exclusiveMaximum"] == 100 + + def test_to_json_schema_with_enum(self): + """测试 enum 字段""" + schema = ToolSchema( + type="string", + enum=["red", "green", "blue"], + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["enum"] == ["red", "green", "blue"] + + def test_to_json_schema_with_additional_properties(self): + """测试 additionalProperties""" + schema = ToolSchema( + type="object", + additional_properties=True, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["additionalProperties"] is True + + def test_to_json_schema_with_default(self): + """测试 default 字段""" + schema = ToolSchema( + type="string", + default="default_value", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["default"] == "default_value" + + def test_to_json_schema_with_title(self): + """测试 title 字段""" + schema = ToolSchema( + type="string", + title="String Field", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["title"] == "String Field" + + def test_to_json_schema_with_one_of(self): + """测试 oneOf 序列化""" + schema = ToolSchema( + one_of=[ + ToolSchema(type="string"), + ToolSchema(type="integer"), + ], + ) + json_schema = schema.to_json_schema() + assert "oneOf" in json_schema + assert len(json_schema["oneOf"]) == 2 + assert json_schema["oneOf"][0]["type"] == "string" + assert json_schema["oneOf"][1]["type"] == "integer" + + def test_to_json_schema_with_all_of(self): + """测试 allOf 序列化""" + schema = ToolSchema( + all_of=[ + ToolSchema( + type="object", + properties={"name": ToolSchema(type="string")}, + ), + ToolSchema( + type="object", + properties={"age": ToolSchema(type="integer")}, + ), + ], + ) + json_schema = schema.to_json_schema() + assert "allOf" in json_schema + assert len(json_schema["allOf"]) == 2 + assert json_schema["allOf"][0]["type"] == "object" + assert json_schema["allOf"][1]["type"] == "object" + + +class TestToolInfo: + """测试 ToolInfo 模型""" + + def test_default_values(self): + """测试默认值""" + info = ToolInfo() + assert info.name is None + assert info.description is None + assert info.parameters is None + + def test_with_values(self): + """测试带值创建""" + info = ToolInfo( + name="test_tool", + description="A test tool", + parameters=ToolSchema( + type="object", + properties={ + "input": ToolSchema(type="string"), + }, + ), + ) + assert info.name == "test_tool" + assert info.description == "A test tool" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_object(self): + """测试从 MCP 工具对象创建""" + mcp_tool = { + "name": "mcp_tool", + "description": "An MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "param": {"type": "string"}, + }, + }, + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "mcp_tool" + assert info.description == "An MCP tool" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_dict(self): + """测试从 MCP 工具字典创建""" + mcp_tool = { + "name": "dict_tool", + "description": "Dict tool", + "inputSchema": { + "type": "string", + }, + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "dict_tool" + assert info.description == "Dict tool" + assert info.parameters is not None + assert info.parameters.type == "string" + + def test_from_mcp_tool_without_name(self): + """测试从没有 name 的 MCP 工具创建""" + mcp_tool = { + "description": "Tool without name", + "inputSchema": {"type": "string"}, + } + with pytest.raises(ValueError, match="name"): + ToolInfo.from_mcp_tool(mcp_tool) + + def test_from_mcp_tool_with_empty_schema(self): + """测试从空 schema 的 MCP 工具创建""" + mcp_tool = { + "name": "empty_schema_tool", + "description": "Tool with empty schema", + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "empty_schema_tool" + assert info.description == "Tool with empty schema" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_model_dump(self): + """测试 from_mcp_tool 当 input_schema 有 model_dump 方法时""" + + class MockInputSchema: + + def model_dump(self): + return { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + }, + "required": ["param1"], + } + + mcp_tool = { + "name": "tool_with_model_dump", + "description": "Tool with model_dump input schema", + "inputSchema": MockInputSchema(), + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "tool_with_model_dump" + assert info.description == "Tool with model_dump input schema" + assert info.parameters is not None + assert info.parameters.type == "object" + assert "param1" in info.parameters.properties + assert "param2" in info.parameters.properties + assert info.parameters.required == ["param1"] diff --git a/tests/unittests/tool/test_openapi.py b/tests/unittests/tool/test_openapi.py new file mode 100644 index 0000000..995c40d --- /dev/null +++ b/tests/unittests/tool/test_openapi.py @@ -0,0 +1,694 @@ +"""Tool OpenAPI 客户端单元测试 / Tool OpenAPI Client Unit Tests + +测试 ToolOpenAPIClient 的 OpenAPI Schema 解析和 HTTP 调用功能。 +Tests OpenAPI Schema parsing and HTTP call functionality of ToolOpenAPIClient. +""" + +import json +from unittest.mock import Mock, patch + +import httpx +import pytest + +from agentrun.tool.api.openapi import ToolOpenAPIClient +from agentrun.tool.model import ToolInfo, ToolSchema + + +class TestToolOpenAPIClient: + """测试 ToolOpenAPIClient""" + + @pytest.fixture + def sample_openapi_spec(self): + """示例 OpenAPI Spec""" + return json.dumps({ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + }, + "servers": [ + {"url": "https://api.example.com/v1"}, + ], + "paths": { + "/users": { + "get": { + "operationId": "listUsers", + "summary": "List all users", + "description": "Get a list of users", + "parameters": [ + { + "name": "limit", + "in": "query", + "required": False, + "schema": {"type": "integer"}, + }, + ], + }, + "post": { + "operationId": "createUser", + "summary": "Create a user", + "description": "Create a new user", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"}, + }, + "required": ["name"], + }, + } + } + }, + }, + }, + "/users/{id}": { + "get": { + "operationId": "getUser", + "summary": "Get user by ID", + }, + "put": { + "operationId": "updateUser", + "summary": "Update user", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + }, + } + } + }, + }, + "delete": { + "operationId": "deleteUser", + "summary": "Delete user", + }, + }, + }, + }) + + @pytest.fixture + def sample_openapi_spec_no_servers(self): + """没有 servers 的 OpenAPI Spec""" + return json.dumps({ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + }, + "paths": {}, + }) + + def test_init_with_valid_json(self, sample_openapi_spec): + """测试使用有效 JSON 初始化""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + assert client._spec is not None + assert client._spec["openapi"] == "3.0.0" + + def test_init_with_invalid_json(self): + """测试使用无效 JSON 初始化""" + client = ToolOpenAPIClient(protocol_spec="invalid json") + assert client._spec is None + + def test_init_with_none(self): + """测试使用 None 初始化""" + client = ToolOpenAPIClient(protocol_spec=None) + assert client._spec is None + + def test_init_with_headers(self, sample_openapi_spec): + """测试带 headers 初始化""" + headers = {"Authorization": "Bearer token"} + client = ToolOpenAPIClient( + protocol_spec=sample_openapi_spec, + headers=headers, + ) + assert client.headers == headers + + def test_server_url(self, sample_openapi_spec): + """测试获取 server URL""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + assert client.server_url == "https://api.example.com/v1" + + def test_server_url_no_servers(self, sample_openapi_spec_no_servers): + """测试没有 servers 时的 server URL""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec_no_servers) + assert client.server_url is None + + def test_server_url_no_spec(self): + """测试没有 spec 时的 server URL""" + client = ToolOpenAPIClient(protocol_spec=None) + assert client.server_url is None + + def test_parse_operations_get_method(self, sample_openapi_spec): + """测试解析 GET 方法""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + get_operation = next( + (op for op in operations if op["operation_id"] == "listUsers"), + None, + ) + assert get_operation is not None + assert get_operation["method"] == "GET" + assert get_operation["path"] == "/users" + assert get_operation["summary"] == "List all users" + assert get_operation["input_schema"] is not None + assert "properties" in get_operation["input_schema"] + + def test_parse_operations_post_method_with_request_body( + self, sample_openapi_spec + ): + """测试解析 POST 方法(带 requestBody)""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + post_operation = next( + (op for op in operations if op["operation_id"] == "createUser"), + None, + ) + assert post_operation is not None + assert post_operation["method"] == "POST" + assert post_operation["path"] == "/users" + assert post_operation["input_schema"] is not None + assert post_operation["input_schema"]["type"] == "object" + + def test_parse_operations_multiple_methods(self, sample_openapi_spec): + """测试解析多个 HTTP 方法""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + operation_ids = [op["operation_id"] for op in operations] + assert "listUsers" in operation_ids + assert "createUser" in operation_ids + assert "getUser" in operation_ids + assert "updateUser" in operation_ids + assert "deleteUser" in operation_ids + + def test_parse_operations_parameters(self, sample_openapi_spec): + """测试解析 parameters""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + list_users_op = next( + (op for op in operations if op["operation_id"] == "listUsers"), + None, + ) + assert list_users_op is not None + assert list_users_op["input_schema"] is not None + assert "properties" in list_users_op["input_schema"] + assert "limit" in list_users_op["input_schema"]["properties"] + + def test_parse_operations_no_spec(self): + """测试没有 spec 时的解析""" + client = ToolOpenAPIClient(protocol_spec=None) + operations = client._parse_operations() + assert operations == [] + + def test_list_tools(self, sample_openapi_spec): + """测试获取工具列表""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + tools = client.list_tools() + + assert len(tools) > 0 + assert all(isinstance(tool, ToolInfo) for tool in tools) + + # 检查特定工具 + list_users_tool = next( + (t for t in tools if t.name == "listUsers"), + None, + ) + assert list_users_tool is not None + assert list_users_tool.description == "List all users" + assert list_users_tool.parameters is not None + + def test_list_tools_empty_spec(self): + """测试空 spec 时的工具列表""" + client = ToolOpenAPIClient(protocol_spec='{"paths": {}}') + tools = client.list_tools() + assert tools == [] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_post_method( + self, mock_client_class, sample_openapi_spec + ): + """测试调用 POST 方法""" + # Mock httpx response + mock_response = Mock() + mock_response.json.return_value = {"id": 123, "name": "Test User"} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool( + "createUser", {"name": "Test User", "email": "test@example.com"} + ) + + assert result == {"id": 123, "name": "Test User"} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "POST" + assert "https://api.example.com/v1/users" in call_args[0][1] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_get_method(self, mock_client_class, sample_openapi_spec): + """测试调用 GET 方法""" + # Mock httpx response + mock_response = Mock() + mock_response.json.return_value = {"id": 123, "name": "Test User"} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("listUsers", {"limit": 10}) + + assert result == {"id": 123, "name": "Test User"} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "GET" + assert "limit" in call_args[1]["params"] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_text_response( + self, mock_client_class, sample_openapi_spec + ): + """测试调用工具返回文本响应""" + # Mock httpx response + mock_response = Mock() + mock_response.text = "Plain text response" + mock_response.headers = {"content-type": "text/plain"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("listUsers", {}) + + assert result == "Plain text response" + + def test_call_tool_operation_not_found(self, sample_openapi_spec): + """测试调用不存在的 operation""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + with pytest.raises( + ValueError, match="Operation 'nonExistent' not found" + ): + client.call_tool("nonExistent", {}) + + def test_call_tool_no_server_url(self): + """测试没有 server URL 时调用工具""" + spec_without_server = json.dumps({ + "openapi": "3.0.0", + "paths": { + "/test": { + "get": { + "operationId": "testOp", + }, + }, + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec_without_server) + with pytest.raises(ValueError, match="No server URL found"): + client.call_tool("testOp", {}) + + @patch("httpx.AsyncClient") + async def test_call_tool_async_post_method( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用 POST 方法""" + # Mock async httpx response + mock_response = Mock() + mock_response.json = Mock(return_value={"id": 123}) + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + # Create a proper async context manager mock + mock_client_instance = AsyncMock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("createUser", {"name": "Test"}) + + assert result == {"id": 123} + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_operation_not_found( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用不存在的 operation""" + mock_client_instance = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + with pytest.raises( + ValueError, match="Operation 'nonExistent' not found" + ): + await client.call_tool_async("nonExistent", {}) + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_no_server_url(self, mock_async_client_class): + """测试异步调用没有 server URL""" + spec_without_server = json.dumps({ + "openapi": "3.0.0", + "paths": { + "/test": { + "get": { + "operationId": "testOp", + }, + }, + }, + }) + mock_client_instance = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=spec_without_server) + with pytest.raises(ValueError, match="No server URL found"): + await client.call_tool_async("testOp", {}) + + async def test_list_tools_async(self, sample_openapi_spec): + """测试异步获取工具列表""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + tools = await client.list_tools_async() + + assert len(tools) > 0 + assert all(isinstance(tool, ToolInfo) for tool in tools) + + def test_resolve_ref(self): + """测试 _resolve_ref 解析 $ref 引用""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "User": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + ref = "#/components/schemas/User" + resolved = client._resolve_ref(ref) + assert resolved is not None + assert resolved["type"] == "object" + assert "name" in resolved["properties"] + + def test_resolve_ref_invalid(self): + """测试 _resolve_ref 无效引用""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": {"schemas": {"User": {"type": "object"}}}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + ref = "#/components/schemas/NonExistent" + resolved = client._resolve_ref(ref) + assert resolved == {} + + def test_resolve_ref_no_spec(self): + """测试 _resolve_ref 没有 spec""" + client = ToolOpenAPIClient(protocol_spec=None) + ref = "#/components/schemas/User" + resolved = client._resolve_ref(ref) + assert resolved == {} + + def test_resolve_schema_with_ref(self): + """测试 _resolve_schema 递归解析 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "User": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = {"$ref": "#/components/schemas/User"} + resolved = client._resolve_schema(schema) + assert resolved is not None + assert resolved["type"] == "object" + + def test_resolve_schema_none(self): + """测试 _resolve_schema 传入 None""" + spec = json.dumps({"openapi": "3.0.0"}) + client = ToolOpenAPIClient(protocol_spec=spec) + resolved = client._resolve_schema(None) + assert resolved is None + + def test_resolve_schema_with_items(self): + """测试 _resolve_schema 解析 items 中的 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": {"schemas": {"Item": {"type": "string"}}}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = { + "type": "array", + "items": {"$ref": "#/components/schemas/Item"}, + } + resolved = client._resolve_schema(schema) + assert resolved is not None + assert resolved["type"] == "array" + assert resolved["items"]["type"] == "string" + + def test_resolve_schema_with_anyof(self): + """测试 _resolve_schema 解析 anyOf 中的 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "StringType": {"type": "string"}, + "NumberType": {"type": "number"}, + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = { + "anyOf": [ + {"$ref": "#/components/schemas/StringType"}, + {"$ref": "#/components/schemas/NumberType"}, + ] + } + resolved = client._resolve_schema(schema) + assert resolved is not None + assert "anyOf" in resolved + assert len(resolved["anyOf"]) == 2 + assert resolved["anyOf"][0]["type"] == "string" + assert resolved["anyOf"][1]["type"] == "number" + + def test_server_url_fallback(self): + """测试 server_url 使用 fallback_server_url""" + spec = json.dumps( + {"openapi": "3.0.0", "info": {"title": "Test API"}, "paths": {}} + ) + client = ToolOpenAPIClient( + protocol_spec=spec, + fallback_server_url="https://fallback.example.com", + ) + assert client.server_url == "https://fallback.example.com" + + def test_server_url_empty_servers_list(self): + """测试 servers 为空列表时使用 fallback""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [], + "paths": {}, + }) + client = ToolOpenAPIClient( + protocol_spec=spec, + fallback_server_url="https://fallback.example.com", + ) + assert client.server_url == "https://fallback.example.com" + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_put_method(self, mock_client_class, sample_openapi_spec): + """测试 PUT 方法调用(走 POST/PUT/PATCH 分支)""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("updateUser", {"name": "Updated Name"}) + + assert result == {"success": True} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "PUT" + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_delete_method( + self, mock_client_class, sample_openapi_spec + ): + """测试 DELETE 方法调用(走 GET/DELETE 分支)""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("deleteUser", {}) + + assert result == {"success": True} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "DELETE" + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_get_method( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步 GET 方法调用""" + mock_response = Mock() + mock_response.json.return_value = {"id": 123} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("getUser", {"id": "123"}) + + assert result == {"id": 123} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "GET" + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_text_response( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用返回 text 响应""" + mock_response = Mock() + mock_response.text = "plain text response" + mock_response.headers = {"content-type": "text/plain"} + mock_response.json.side_effect = ValueError("No JSON") + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("listUsers", {"limit": 10}) + + assert result == "plain text response" + + def test_parse_operations_no_operation_id(self): + """测试没有 operationId 时使用默认值""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/test": {"get": {"summary": "Test without operationId"}} + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert len(operations) == 1 + assert operations[0]["operation_id"] is not None + assert operations[0]["method"] == "GET" + + def test_parse_operations_invalid_path_item(self): + """测试无效的 path_item(非 dict)""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": {"/test": "invalid"}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert operations == [] + + def test_parse_operations_required_parameters(self): + """测试 required 参数的解析""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users/{id}": { + "get": { + "operationId": "getUserById", + "parameters": [{ + "name": "id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert len(operations) == 1 + op = operations[0] + assert op["operation_id"] == "getUserById" + assert op["input_schema"] is not None + assert "id" in op["input_schema"]["properties"] + assert "id" in op["input_schema"]["required"] + + +class AsyncMock(Mock): + """Async mock helper""" + + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py new file mode 100644 index 0000000..e1445ac --- /dev/null +++ b/tests/unittests/tool/test_tool.py @@ -0,0 +1,743 @@ +"""Tool 资源类和客户端单元测试 / Tool Resource Class and Client Unit Tests + +测试 Tool 资源类和 ToolClient 的功能。 +Tests functionality of Tool resource class and ToolClient. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from agentrun.tool.client import ToolClient +from agentrun.tool.model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) +from agentrun.tool.tool import Tool + + +class TestTool: + """测试 Tool 资源类""" + + def test_tool_attributes_default(self): + """测试 Tool 默认属性""" + tool = Tool() + assert tool.tool_id is None + assert tool.name is None + assert tool.tool_name is None + assert tool.description is None + assert tool.tool_type is None + assert tool.status is None + assert tool.code_configuration is None + assert tool.container_configuration is None + assert tool.mcp_config is None + assert tool.log_configuration is None + assert tool.network_config is None + assert tool.oss_mount_config is None + assert tool.data_endpoint is None + assert tool.protocol_spec is None + assert tool.protocol_type is None + assert tool.memory is None + assert tool.gpu is None + assert tool.timeout is None + assert tool.internet_access is None + assert tool.environment_variables is None + assert tool.created_time is None + assert tool.last_modified_time is None + assert tool.version_id is None + + def test_tool_attributes_with_values(self): + """测试 Tool 带值创建""" + tool = Tool( + tool_id="tool-123", + name="my-tool", + tool_name="my-tool", + description="A test tool", + tool_type="MCP", + status="READY", + data_endpoint="https://example.com/data", + memory=1024, + gpu="T4", + timeout=60, + internet_access=True, + environment_variables={"KEY": "value"}, + ) + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_name == "my-tool" + assert tool.description == "A test tool" + assert tool.tool_type == "MCP" + assert tool.status == "READY" + assert tool.data_endpoint == "https://example.com/data" + assert tool.memory == 1024 + assert tool.gpu == "T4" + assert tool.timeout == 60 + assert tool.internet_access is True + assert tool.environment_variables == {"KEY": "value"} + + def test_get_tool_type_mcp(self): + """测试获取 MCP 工具类型""" + tool = Tool(tool_type="MCP") + assert tool._get_tool_type() == ToolType.MCP + + def test_get_tool_type_functioncall(self): + """测试获取 FUNCTIONCALL 工具类型""" + tool = Tool(tool_type="FUNCTIONCALL") + assert tool._get_tool_type() == ToolType.FUNCTIONCALL + + def test_get_tool_type_invalid(self): + """测试获取无效工具类型""" + tool = Tool(tool_type="INVALID") + assert tool._get_tool_type() is None + + def test_get_tool_type_none(self): + """测试获取 None 工具类型""" + tool = Tool() + assert tool._get_tool_type() is None + + def test_get_mcp_endpoint_sse(self): + """测试获取 MCP SSE endpoint""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/sse" + + def test_get_mcp_endpoint_streamable(self): + """测试获取 MCP Streamable endpoint""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_STREAMABLE"), + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/mcp" + + def test_get_mcp_endpoint_default(self): + """测试获取 MCP endpoint(默认 SSE)""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/sse" + + def test_get_mcp_endpoint_no_name(self): + """测试没有 name 时获取 MCP endpoint""" + tool = Tool( + data_endpoint="https://example.com", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint is None + + def test_get_mcp_endpoint_no_data_endpoint(self): + """测试没有 data_endpoint 时获取 MCP endpoint""" + tool = Tool( + tool_name="my-tool", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint is None + + def test_from_inner_object(self): + """测试从内部对象创建 Tool""" + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = "T4" + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = {"KEY": "value"} + inner_tool.created_time = "2024-01-01T00:00:00Z" + inner_tool.last_modified_time = "2024-01-02T00:00:00Z" + inner_tool.version_id = "version-123" + inner_tool.protocol_spec = '{"openapi": "3.0.0"}' + inner_tool.protocol_type = "openapi" + + # Mock configurations + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "gpu": "T4", + "timeout": 60, + "internetAccess": True, + "environmentVariables": {"KEY": "value"}, + "createdTime": "2024-01-01T00:00:00Z", + "lastModifiedTime": "2024-01-02T00:00:00Z", + "versionId": "version-123", + "protocolSpec": '{"openapi": "3.0.0"}', + "protocolType": "openapi", + } + ) + + tool = Tool.from_inner_object(inner_tool) + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.description == "Test tool" + assert tool.tool_type == "MCP" + assert tool.status == "READY" + assert tool.data_endpoint == "https://example.com/data" + assert tool.memory == 1024 + assert tool.gpu == "T4" + assert tool.timeout == 60 + assert tool.internet_access is True + assert tool.environment_variables == {"KEY": "value"} + assert tool.created_time == "2024-01-01T00:00:00Z" + assert tool.last_modified_time == "2024-01-02T00:00:00Z" + assert tool.version_id == "version-123" + assert tool.protocol_spec == '{"openapi": "3.0.0"}' + assert tool.protocol_type == "openapi" + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp(self, mock_config_class, mock_mcp_session_class): + """测试获取 MCP 工具列表""" + mock_session = Mock() + mock_session.list_tools.return_value = [ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + tools = tool.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + def test_list_tools_functioncall(self, mock_openapi_client_class): + """测试获取 FUNCTIONCALL 工具列表""" + mock_client = Mock() + mock_client.list_tools.return_value = [ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + mock_openapi_client_class.return_value = mock_client + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + tools = tool.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + def test_list_tools_no_type(self): + """测试没有工具类型时获取工具列表""" + tool = Tool() + tools = tool.list_tools() + assert tools == [] + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp(self, mock_config_class, mock_mcp_session_class): + """测试调用 MCP 工具""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "success"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试调用 FUNCTIONCALL 工具""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "success"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + def test_call_tool_unsupported_type(self): + """测试调用不支持的类型工具""" + tool = Tool(tool_type="UNSUPPORTED") + with pytest.raises(ValueError, match="Unsupported tool type"): + tool.call_tool("tool1", {}) + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + async def test_list_tools_async_mcp( + self, mock_config_class, mock_mcp_session_class + ): + """测试异步获取 MCP 工具列表""" + mock_session = Mock() + mock_session.list_tools_async = AsyncMock( + return_value=[ + ToolInfo(name="tool1", description="Tool 1"), + ] + ) + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + tools = await tool.list_tools_async() + + assert len(tools) == 1 + assert tools[0].name == "tool1" + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_mcp( + self, mock_config_class, mock_mcp_session_class + ): + """测试异步调用 MCP 工具""" + mock_session = Mock() + mock_session.call_tool_async = AsyncMock( + return_value={"result": "success"} + ) + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + result = await tool.call_tool_async("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + +class TestToolClient: + """测试 ToolClient""" + + def test_client_init(self): + """测试客户端初始化""" + client = ToolClient() + assert client is not None + + @patch("agentrun.tool.client.ToolControlAPI") + def test_get(self, mock_control_api_class): + """测试获取工具""" + # Mock inner tool + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = None + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = None + inner_tool.created_time = None + inner_tool.last_modified_time = None + inner_tool.version_id = None + inner_tool.protocol_spec = None + inner_tool.protocol_type = None + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "timeout": 60, + "internetAccess": True, + } + ) + + mock_api = Mock() + mock_api.get_tool.return_value = inner_tool + mock_control_api_class.return_value = mock_api + + client = ToolClient() + tool = client.get(name="my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_type == "MCP" + mock_api.get_tool.assert_called_once_with(name="my-tool", config=None) + + @patch("agentrun.tool.client.ToolControlAPI") + async def test_get_async(self, mock_control_api_class): + """测试异步获取工具""" + # Mock inner tool + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = None + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = None + inner_tool.created_time = None + inner_tool.last_modified_time = None + inner_tool.version_id = None + inner_tool.protocol_spec = None + inner_tool.protocol_type = None + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "timeout": 60, + "internetAccess": True, + } + ) + + mock_api = Mock() + mock_api.get_tool_async = AsyncMock(return_value=inner_tool) + mock_control_api_class.return_value = mock_api + + client = ToolClient() + tool = await client.get_async(name="my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_type == "MCP" + mock_api.get_tool_async.assert_called_once_with( + name="my-tool", config=None + ) + + @patch("agentrun.tool.client.ToolControlAPI") + def test_get_http_error(self, mock_control_api_class): + """测试 get() 遇到 HTTPError 时的异常转换""" + from agentrun.utils.exception import HTTPError + + mock_resource_error = Exception("Resource not found") + mock_resource_error.message = "Resource not found" # type: ignore + mock_resource_error.error_code = "ResourceNotFound" # type: ignore + + mock_http_error = HTTPError.__new__(HTTPError) + mock_http_error.to_resource_error = Mock(return_value=mock_resource_error) # type: ignore + + mock_api = Mock() + mock_api.get_tool.side_effect = mock_http_error + mock_control_api_class.return_value = mock_api + + client = ToolClient() + + with pytest.raises(Exception) as exc_info: + client.get(name="my-tool") + assert exc_info.value.message == "Resource not found" # type: ignore + + @patch("agentrun.tool.client.ToolControlAPI") + async def test_get_async_http_error(self, mock_control_api_class): + """测试 get_async() 遇到 HTTPError 时的异常转换""" + from agentrun.utils.exception import HTTPError + + mock_resource_error = Exception("Resource not found") + mock_resource_error.message = "Resource not found" # type: ignore + mock_resource_error.error_code = "ResourceNotFound" # type: ignore + + mock_http_error = HTTPError.__new__(HTTPError) + mock_http_error.to_resource_error = Mock(return_value=mock_resource_error) # type: ignore + + mock_api = Mock() + mock_api.get_tool_async = AsyncMock(side_effect=mock_http_error) + mock_control_api_class.return_value = mock_api + + client = ToolClient() + + with pytest.raises(Exception) as exc_info: + await client.get_async(name="my-tool") + assert exc_info.value.message == "Resource not found" # type: ignore + + @patch("agentrun.tool.tool.Tool._Tool__get_client") + def test_get_by_name(self, mock_get_client): + """测试类方法 get_by_name""" + mock_client = Mock() + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_client.get.return_value = mock_tool + mock_get_client.return_value = mock_client + + tool = Tool.get_by_name("my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + mock_client.get.assert_called_once_with(name="my-tool") + + @patch("agentrun.tool.tool.Tool._Tool__get_client") + async def test_get_by_name_async(self, mock_get_client): + """测试类方法 get_by_name_async""" + mock_client = Mock() + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_client.get_async = AsyncMock(return_value=mock_tool) + mock_get_client.return_value = mock_client + + tool = await Tool.get_by_name_async("my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + mock_client.get_async.assert_called_once_with(name="my-tool") + + @patch("agentrun.tool.tool.Tool.get_by_name") + def test_get_sync(self, mock_get_by_name): + """测试实例方法 get()""" + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_get_by_name.return_value = mock_tool + + tool = Tool(tool_name="my-tool") + result = tool.get() + + assert result.tool_id == "tool-123" + mock_get_by_name.assert_called_once_with(name="my-tool", config=None) + + def test_get_sync_no_name(self): + """测试 get() 没有 name 时抛出 ValueError""" + tool = Tool() + + with pytest.raises(ValueError, match="Tool name is required"): + tool.get() + + @patch("agentrun.tool.tool.Tool.get_by_name_async") + async def test_get_async_method(self, mock_get_by_name_async): + """测试实例方法 get_async()""" + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_get_by_name_async.return_value = mock_tool + + tool = Tool(tool_name="my-tool") + result = await tool.get_async() + + assert result.tool_id == "tool-123" + mock_get_by_name_async.assert_called_once_with( + name="my-tool", config=None + ) + + def test_get_async_no_name(self): + """测试 get_async() 没有 name 时抛出 ValueError""" + tool = Tool() + + with pytest.raises(ValueError, match="Tool name is required"): + import asyncio + + asyncio.run(tool.get_async()) + + def test_get_functioncall_server_url(self): + """测试 _get_functioncall_server_url 有 data_endpoint""" + tool = Tool( + tool_name="my-tool", data_endpoint="https://example.com/data" + ) + url = tool._get_functioncall_server_url() + + assert url == "https://example.com/data/tools/my-tool" + + def test_get_functioncall_server_url_no_endpoint(self): + """测试 _get_functioncall_server_url 没有 data_endpoint 和 name 时返回 None""" + tool = Tool() + url = tool._get_functioncall_server_url() + + assert url is None + + @patch("agentrun.utils.config.Config") + async def test_list_tools_async_mcp_no_endpoint(self, mock_config_class): + """测试 MCP 类型但没有 endpoint 时返回空列表""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + tools = await tool.list_tools_async() + + assert tools == [] + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + async def test_list_tools_async_functioncall( + self, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 list_tools_async""" + mock_client = Mock() + mock_client.list_tools_async = AsyncMock( + return_value=[ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + ) + mock_openapi_client_class.return_value = mock_client + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + tools = await tool.list_tools_async() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + async def test_list_tools_async_no_type(self): + """测试没有类型时 list_tools_async 返回空列表""" + tool = Tool() + tools = await tool.list_tools_async() + assert tools == [] + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 call_tool_async""" + mock_client = Mock() + mock_client.call_tool_async = AsyncMock( + return_value={"result": "success"} + ) + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = await tool.call_tool_async("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + async def test_call_tool_async_mcp_no_endpoint(self): + """测试 MCP 类型但没有 endpoint 时 call_tool_async 抛出 ValueError""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + with pytest.raises(ValueError, match="MCP endpoint not available"): + await tool.call_tool_async("tool1", {"param": "value"}) + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 call_tool(同步)""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "success"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + def test_call_tool_mcp_no_endpoint(self): + """测试 MCP 类型但没有 endpoint 时 call_tool 抛出 ValueError""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + with pytest.raises(ValueError, match="MCP endpoint not available"): + tool.call_tool("tool1", {"param": "value"}) + + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp_no_endpoint(self, mock_config_class): + """测试 MCP 类型但没有 endpoint 时 list_tools 返回空列表""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + tools = tool.list_tools() + + assert tools == [] From 5952fff3018a0c1615e722471323a84533829d63 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Sun, 29 Mar 2026 17:00:40 +0800 Subject: [PATCH 02/10] =?UTF-8?q?feat(tool):=20=E6=B7=BB=E5=8A=A0Skill?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E4=B8=8B=E8=BD=BD=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 此提交添加了对SKILL类型工具的支持,包括获取下载URL和异步下载解压的功能。同时更新了相关单元测试。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/tool/__tool_async_template.py | 85 ++++++++ agentrun/tool/model.py | 2 + agentrun/tool/tool.py | 143 +++++++++++++ tests/unittests/tool/test_tool.py | 275 +++++++++++++++++++++++++ 4 files changed, 505 insertions(+) diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py index 41f5bca..74dd049 100644 --- a/agentrun/tool/__tool_async_template.py +++ b/agentrun/tool/__tool_async_template.py @@ -4,8 +4,13 @@ Provides object-oriented wrapper and complete lifecycle management for tool resources. """ +import io +import os +import shutil from typing import Any, Dict, List, Optional +import zipfile +import httpx import pydash from agentrun.utils.config import Config @@ -309,3 +314,83 @@ async def call_tool_async( return result raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def _get_skill_download_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 Skill 工具的下载 URL / Get download URL for Skill tools + + 根据 data_endpoint 和 tool_name 构造下载地址。 + Constructs download URL from data_endpoint and tool_name. + + Returns: + Optional[str]: 下载 URL / Download URL + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}/download" + + async def download_skill_async( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """异步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory asynchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + async with httpx.AsyncClient( + timeout=300, follow_redirects=True + ) as http_client: + response = await http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir diff --git a/agentrun/tool/model.py b/agentrun/tool/model.py index 994bc51..2d8cc81 100644 --- a/agentrun/tool/model.py +++ b/agentrun/tool/model.py @@ -17,6 +17,8 @@ class ToolType(str, Enum): """MCP 协议工具 / MCP Protocol Tool""" FUNCTIONCALL = "FUNCTIONCALL" """函数调用工具 / Function Call Tool""" + SKILL = "SKILL" + """技能工具 / Skill Tool""" class McpConfig(BaseModel): diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py index bdfaa5d..b9eb565 100644 --- a/agentrun/tool/tool.py +++ b/agentrun/tool/tool.py @@ -14,8 +14,13 @@ Provides object-oriented wrapper and complete lifecycle management for tool resources. """ +import io +import os +import shutil from typing import Any, Dict, List, Optional +import zipfile +import httpx import pydash from agentrun.utils.config import Config @@ -438,3 +443,141 @@ def call_tool( return result raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def _get_skill_download_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 Skill 工具的下载 URL / Get download URL for Skill tools + + 根据 data_endpoint 和 tool_name 构造下载地址。 + Constructs download URL from data_endpoint and tool_name. + + Returns: + Optional[str]: 下载 URL / Download URL + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}/download" + + async def download_skill_async( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """异步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory asynchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + async with httpx.AsyncClient( + timeout=300, follow_redirects=True + ) as http_client: + response = await http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir + + def download_skill( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """同步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory synchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + with httpx.Client(timeout=300, follow_redirects=True) as http_client: + response = http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py index e1445ac..e3e2bf2 100644 --- a/tests/unittests/tool/test_tool.py +++ b/tests/unittests/tool/test_tool.py @@ -380,6 +380,281 @@ async def test_call_tool_async_mcp( assert result == {"result": "success"} + # ==================== SKILL 相关测试 ==================== + + def test_get_tool_type_skill(self): + """测试获取 SKILL 工具类型""" + tool = Tool(tool_type="SKILL") + assert tool._get_tool_type() == ToolType.SKILL + + def test_get_skill_download_url_with_data_endpoint(self): + """测试使用 data_endpoint 构造 skill 下载 URL""" + tool = Tool( + tool_name="my-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/my-skill/download" + + def test_get_skill_download_url_uses_name_fallback(self): + """测试 tool_name 为空时使用 name 作为 fallback""" + tool = Tool( + name="fallback-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/fallback-skill/download" + + def test_get_skill_download_url_tool_name_takes_priority(self): + """测试 tool_name 优先于 name""" + tool = Tool( + tool_name="primary-skill", + name="fallback-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/primary-skill/download" + + @patch("agentrun.tool.tool.Config") + def test_get_skill_download_url_config_fallback(self, mock_config_class): + """测试 data_endpoint 为空时从 Config 获取""" + mock_config = Mock() + mock_config._data_endpoint = "https://config-endpoint.com" + mock_config_class.with_configs.return_value = mock_config + + tool = Tool(tool_name="my-skill") + url = tool._get_skill_download_url() + assert url == "https://config-endpoint.com/tools/my-skill/download" + + def test_get_skill_download_url_no_name(self): + """测试没有 name 时返回 None""" + tool = Tool(data_endpoint="https://example.com") + url = tool._get_skill_download_url() + assert url is None + + @patch("agentrun.tool.tool.Config") + def test_get_skill_download_url_no_endpoint(self, mock_config_class): + """测试没有 data_endpoint 且 Config 也没有时返回 None""" + mock_config = Mock() + mock_config._data_endpoint = None + mock_config_class.with_configs.return_value = mock_config + + tool = Tool(tool_name="my-skill") + url = tool._get_skill_download_url() + assert url is None + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_success( + self, mock_config_class, mock_async_client_class + ): + """测试成功下载并解压 skill 包""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {"Authorization": "Bearer token"} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("SKILL.md", "# Test Skill") + zf.writestr("main.py", "print('hello')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = await tool.download_skill_async(target_dir=tmp_dir) + + expected_dir = os.path.join(tmp_dir, "test-skill") + assert result == expected_dir + assert os.path.exists(expected_dir) + assert os.path.isfile(os.path.join(expected_dir, "SKILL.md")) + assert os.path.isfile(os.path.join(expected_dir, "main.py")) + + with open(os.path.join(expected_dir, "SKILL.md")) as f: + assert f.read() == "# Test Skill" + finally: + shutil.rmtree(tmp_dir) + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_overwrites_existing( + self, mock_config_class, mock_async_client_class + ): + """测试下载 skill 时覆盖已存在的目录""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("new_file.txt", "new content") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + existing_dir = os.path.join(tmp_dir, "test-skill") + os.makedirs(existing_dir) + with open(os.path.join(existing_dir, "old_file.txt"), "w") as f: + f.write("old content") + + result = await tool.download_skill_async(target_dir=tmp_dir) + + assert os.path.isfile(os.path.join(result, "new_file.txt")) + assert not os.path.exists(os.path.join(result, "old_file.txt")) + finally: + shutil.rmtree(tmp_dir) + + async def test_download_skill_async_wrong_type(self): + """测试非 SKILL 类型调用 download_skill_async 抛出 ValueError""" + tool = Tool(tool_type="MCP", tool_name="my-tool") + + with pytest.raises(ValueError, match="only available for SKILL"): + await tool.download_skill_async() + + async def test_download_skill_async_no_url(self): + """测试无法构造下载 URL 时抛出 ValueError""" + tool = Tool(tool_type="SKILL") + + with pytest.raises(ValueError, match="Cannot construct download URL"): + await tool.download_skill_async() + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_http_error( + self, mock_config_class, mock_async_client_class + ): + """测试下载失败时抛出 HTTPStatusError""" + import httpx + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", + request=Mock(), + response=Mock(status_code=404), + ) + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + with pytest.raises(httpx.HTTPStatusError): + await tool.download_skill_async() + + @patch("agentrun.tool.tool.httpx.Client") + @patch("agentrun.tool.tool.Config") + def test_download_skill_sync_success( + self, mock_config_class, mock_client_class + ): + """测试同步版本 download_skill 成功""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("skill.py", "print('skill')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="sync-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = tool.download_skill(target_dir=tmp_dir) + + expected_dir = os.path.join(tmp_dir, "sync-skill") + assert result == expected_dir + assert os.path.isfile(os.path.join(expected_dir, "skill.py")) + finally: + shutil.rmtree(tmp_dir) + + def test_download_skill_sync_wrong_type(self): + """测试同步版本非 SKILL 类型抛出 ValueError""" + tool = Tool(tool_type="FUNCTIONCALL", tool_name="my-tool") + + with pytest.raises(ValueError, match="only available for SKILL"): + tool.download_skill() + class TestToolClient: """测试 ToolClient""" From 7d33985c2fa9cbecb5d6d722d2f7dc97eb116601 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Sun, 29 Mar 2026 18:24:35 +0800 Subject: [PATCH 03/10] =?UTF-8?q?feat(integrations):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20tool=5Fresource=20=E5=87=BD=E6=95=B0=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E6=A1=86=E6=9E=B6=E9=9B=86=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在多个集成模块中添加了 `tool_resource` 函数的支持,允许用户将 ToolResource 封装为不同 AI 框架所需的工具格式,包括 CrewAI、LangChain、PydanticAI、LangGraph 和 AgentScope。同时更新了相关初始化文件以导出新功能。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/integration/agentscope/__init__.py | 9 +++- agentrun/integration/agentscope/builtin.py | 20 ++++++++ agentrun/integration/builtin/__init__.py | 2 + agentrun/integration/crewai/__init__.py | 8 ++- agentrun/integration/crewai/builtin.py | 20 ++++++++ agentrun/integration/google_adk/__init__.py | 9 +++- agentrun/integration/google_adk/builtin.py | 20 ++++++++ agentrun/integration/langchain/__init__.py | 9 +++- agentrun/integration/langchain/builtin.py | 20 ++++++++ agentrun/integration/langgraph/__init__.py | 9 +++- agentrun/integration/langgraph/builtin.py | 20 ++++++++ agentrun/integration/pydantic_ai/__init__.py | 9 +++- agentrun/integration/pydantic_ai/builtin.py | 20 ++++++++ agentrun/integration/utils/tool.py | 52 ++++++++++++++++++++ 14 files changed, 221 insertions(+), 6 deletions(-) diff --git a/agentrun/integration/agentscope/__init__.py b/agentrun/integration/agentscope/__init__.py index d9e108f..b59b649 100644 --- a/agentrun/integration/agentscope/__init__.py +++ b/agentrun/integration/agentscope/__init__.py @@ -3,11 +3,18 @@ 提供 AgentRun 模型与沙箱工具的 AgentScope 适配入口。 / 提供 AgentRun 模型with沙箱工具的 AgentScope 适配入口。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/agentscope/builtin.py b/agentrun/integration/agentscope/builtin.py index 1a94e7f..5f271c4 100644 --- a/agentrun/integration/agentscope/builtin.py +++ b/agentrun/integration/agentscope/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 AgentScope 工具列表。 / AgentScope Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_agentscope( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/builtin/__init__.py b/agentrun/integration/builtin/__init__.py index 49f4258..aa914c6 100644 --- a/agentrun/integration/builtin/__init__.py +++ b/agentrun/integration/builtin/__init__.py @@ -7,12 +7,14 @@ from .knowledgebase import knowledgebase_toolset from .model import model, ModelArgs from .sandbox import sandbox_toolset +from .tool_resource import tool_resource from .toolset import toolset __all__ = [ "model", "ModelArgs", "toolset", + "tool_resource", "sandbox_toolset", "knowledgebase_toolset", ] diff --git a/agentrun/integration/crewai/__init__.py b/agentrun/integration/crewai/__init__.py index 46ab61d..bd0743f 100644 --- a/agentrun/integration/crewai/__init__.py +++ b/agentrun/integration/crewai/__init__.py @@ -4,10 +4,16 @@ CrewAI 与 LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 / CrewAI with LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, +) __all__ = [ "model", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/crewai/builtin.py b/agentrun/integration/crewai/builtin.py index beda5a7..5ee5013 100644 --- a/agentrun/integration/crewai/builtin.py +++ b/agentrun/integration/crewai/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 CrewAI 工具列表。 / CrewAI Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_crewai( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/google_adk/__init__.py b/agentrun/integration/google_adk/__init__.py index 372f64d..9028cb1 100644 --- a/agentrun/integration/google_adk/__init__.py +++ b/agentrun/integration/google_adk/__init__.py @@ -3,11 +3,18 @@ 提供与 Google Agent Development Kit 的模型与沙箱工具集成。 / 提供with Google Agent Development Kit 的模型with沙箱工具集成。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/google_adk/builtin.py b/agentrun/integration/google_adk/builtin.py index e655f8f..8642e27 100644 --- a/agentrun/integration/google_adk/builtin.py +++ b/agentrun/integration/google_adk/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 Google ADK 工具列表。 / Google ADK Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_google_adk( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/langchain/__init__.py b/agentrun/integration/langchain/__init__.py index 3e48086..b703d56 100644 --- a/agentrun/integration/langchain/__init__.py +++ b/agentrun/integration/langchain/__init__.py @@ -20,7 +20,13 @@ AgentRunConverter, ) # 向后兼容 -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, + toolset, +) __all__ = [ "AgentRunConverter", @@ -28,4 +34,5 @@ "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/langchain/builtin.py b/agentrun/integration/langchain/builtin.py index c18e479..98fbe69 100644 --- a/agentrun/integration/langchain/builtin.py +++ b/agentrun/integration/langchain/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 LangChain ``StructuredTool`` 列表。 / LangChain Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_langchain( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/langgraph/__init__.py b/agentrun/integration/langgraph/__init__.py index 71fa409..9a3115c 100644 --- a/agentrun/integration/langgraph/__init__.py +++ b/agentrun/integration/langgraph/__init__.py @@ -25,7 +25,13 @@ """ from .agent_converter import AgentRunConverter -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, + toolset, +) __all__ = [ "AgentRunConverter", @@ -33,4 +39,5 @@ "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/langgraph/builtin.py b/agentrun/integration/langgraph/builtin.py index a9efaae..83153c5 100644 --- a/agentrun/integration/langgraph/builtin.py +++ b/agentrun/integration/langgraph/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 LangGraph 工具列表。 / LangGraph Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_langgraph( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/pydantic_ai/__init__.py b/agentrun/integration/pydantic_ai/__init__.py index 5a04376..34fdc38 100644 --- a/agentrun/integration/pydantic_ai/__init__.py +++ b/agentrun/integration/pydantic_ai/__init__.py @@ -3,11 +3,18 @@ 提供 AgentRun 模型与沙箱工具的 PydanticAI 适配入口。 / 提供 AgentRun 模型with沙箱工具的 PydanticAI 适配入口。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", ] diff --git a/agentrun/integration/pydantic_ai/builtin.py b/agentrun/integration/pydantic_ai/builtin.py index a5e5b05..952130d 100644 --- a/agentrun/integration/pydantic_ai/builtin.py +++ b/agentrun/integration/pydantic_ai/builtin.py @@ -14,10 +14,12 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +52,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 PydanticAI 工具列表。 / PydanticAI Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_pydantic_ai( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, diff --git a/agentrun/integration/utils/tool.py b/agentrun/integration/utils/tool.py index bde72a5..c479846 100644 --- a/agentrun/integration/utils/tool.py +++ b/agentrun/integration/utils/tool.py @@ -47,6 +47,7 @@ ) if TYPE_CHECKING: + from agentrun.tool.tool import Tool as ToolResource from agentrun.toolset import ToolSet from agentrun.utils.log import logger @@ -859,6 +860,57 @@ def from_agentrun_toolset( return CommonToolSet(integration_tools) + @classmethod + def from_agentrun_tool( + cls, + tool_resource: "ToolResource", + config: Optional[Any] = None, + refresh: bool = False, + ) -> "CommonToolSet": + """从 AgentRun ToolResource 创建通用工具集 / Create CommonToolSet from AgentRun ToolResource + + Args: + tool_resource: agentrun.tool.tool.Tool (ToolResource) 实例 / ToolResource instance + config: 额外的请求配置,调用工具时会自动合并 / Extra request config, merged automatically when calling tools + refresh: 是否先刷新最新信息 / Whether to refresh latest info first + + Returns: + 通用 ToolSet 实例,可直接调用 .to_openai_function()、.to_langchain() 等 + CommonToolSet instance, can directly call .to_openai_function(), .to_langchain(), etc. + + Example: + >>> from agentrun import ToolResource, ToolResourceClient + >>> from agentrun.integration.utils.tool import CommonToolSet + >>> + >>> client = ToolResourceClient() + >>> tool = client.get(name="my-tool") + >>> common_toolset = CommonToolSet.from_agentrun_tool(tool) + >>> + >>> openai_tools = common_toolset.to_openai_function() + >>> langchain_tools = common_toolset.to_langchain() + """ + + if refresh: + tool_resource = tool_resource.get(config=config) + + tools_meta = tool_resource.list_tools(config=config) or [] + integration_tools: List[Tool] = [] + seen_names: set = set() + + for meta in tools_meta: + tool = _build_tool_from_meta(tool_resource, meta, config) + if tool: + if tool.name in seen_names: + logger.warning( + f"Duplicate tool name '{tool.name}' detected, " + "second occurrence will be skipped" + ) + continue + seen_names.add(tool.name) + integration_tools.append(tool) + + return CommonToolSet(integration_tools) + def to_openai_function( self, prefix: Optional[str] = None, From 2757f2ff178924770ebb95265280014219fad513 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Sun, 29 Mar 2026 20:10:48 +0800 Subject: [PATCH 04/10] =?UTF-8?q?feat(integrations):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20skill=5Ftools=20=E9=9B=86=E6=88=90=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在多个集成框架中添加了 `skill_tools` 函数,用于将 Skill 封装为不同 AI 框架的工具列表,包括 CrewAI、LangChain、PydanticAI、LangGraph、AgentScope 和 Google ADK。同时更新了相关的初始化文件和工具加载器。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/integration/agentscope/__init__.py | 2 + agentrun/integration/agentscope/builtin.py | 16 + agentrun/integration/builtin/__init__.py | 2 + agentrun/integration/builtin/skill.py | 16 + agentrun/integration/builtin/tool_resource.py | 48 + agentrun/integration/crewai/__init__.py | 2 + agentrun/integration/crewai/builtin.py | 16 + agentrun/integration/google_adk/__init__.py | 2 + agentrun/integration/google_adk/builtin.py | 16 + agentrun/integration/langchain/__init__.py | 2 + agentrun/integration/langchain/builtin.py | 20 + agentrun/integration/langgraph/__init__.py | 6 +- agentrun/integration/langgraph/builtin.py | 20 + agentrun/integration/pydantic_ai/__init__.py | 2 + agentrun/integration/pydantic_ai/builtin.py | 16 + agentrun/integration/utils/skill_loader.py | 459 +++++++++ .../integration/test_skill_loader.py | 911 ++++++++++++++++++ 17 files changed, 1554 insertions(+), 2 deletions(-) create mode 100644 agentrun/integration/builtin/skill.py create mode 100644 agentrun/integration/builtin/tool_resource.py create mode 100644 agentrun/integration/utils/skill_loader.py create mode 100644 tests/unittests/integration/test_skill_loader.py diff --git a/agentrun/integration/agentscope/__init__.py b/agentrun/integration/agentscope/__init__.py index b59b649..91ba476 100644 --- a/agentrun/integration/agentscope/__init__.py +++ b/agentrun/integration/agentscope/__init__.py @@ -7,6 +7,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, toolset, ) @@ -17,4 +18,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/agentscope/builtin.py b/agentrun/integration/agentscope/builtin.py index 5f271c4..17f7aff 100644 --- a/agentrun/integration/agentscope/builtin.py +++ b/agentrun/integration/agentscope/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -106,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 AgentScope 工具列表。 / AgentScope Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_agentscope( + prefix=prefix, + ) diff --git a/agentrun/integration/builtin/__init__.py b/agentrun/integration/builtin/__init__.py index aa914c6..8b7f084 100644 --- a/agentrun/integration/builtin/__init__.py +++ b/agentrun/integration/builtin/__init__.py @@ -7,6 +7,7 @@ from .knowledgebase import knowledgebase_toolset from .model import model, ModelArgs from .sandbox import sandbox_toolset +from .skill import skill_tools from .tool_resource import tool_resource from .toolset import toolset @@ -17,4 +18,5 @@ "tool_resource", "sandbox_toolset", "knowledgebase_toolset", + "skill_tools", ] diff --git a/agentrun/integration/builtin/skill.py b/agentrun/integration/builtin/skill.py new file mode 100644 index 0000000..d0f057c --- /dev/null +++ b/agentrun/integration/builtin/skill.py @@ -0,0 +1,16 @@ +"""内置 Skill 集成函数 / Built-in Skill Integration Functions + +提供快速创建 Skill 工具集对象的便捷函数。 +Provides convenient functions for quickly creating Skill toolset objects. +""" + +from typing import List, Optional, Union + +from agentrun.integration.utils.skill_loader import skill_tools as _skill_tools +from agentrun.integration.utils.tool import CommonToolSet +from agentrun.utils.config import Config + +# Re-export for convenience +skill_tools = _skill_tools + +__all__ = ["skill_tools"] diff --git a/agentrun/integration/builtin/tool_resource.py b/agentrun/integration/builtin/tool_resource.py new file mode 100644 index 0000000..18d01b8 --- /dev/null +++ b/agentrun/integration/builtin/tool_resource.py @@ -0,0 +1,48 @@ +"""内置 ToolResource 集成函数 / Built-in ToolResource Integration Functions + +提供快速创建通用工具集对象的便捷函数(基于新版 Tool 模块)。 +Provides convenient functions for quickly creating common toolset objects (based on new Tool module). +""" + +from typing import Optional, Union + +from agentrun.integration.utils.tool import CommonToolSet +from agentrun.tool.client import ToolClient +from agentrun.tool.tool import Tool as ToolResourceType +from agentrun.utils.config import Config + + +def tool_resource( + input: Union[str, ToolResourceType], config: Optional[Config] = None +) -> CommonToolSet: + """将 ToolResource 封装为通用工具集 / Wrap ToolResource as CommonToolSet + + 支持从工具名称或 ToolResource 实例创建通用工具集。 + Supports creating CommonToolSet from tool name or ToolResource instance. + + Args: + input: 工具名称或 ToolResource 实例 / Tool name or ToolResource instance + config: 配置对象 / Configuration object + + Returns: + CommonToolSet: 通用工具集实例 / CommonToolSet instance + + Examples: + >>> # 从工具名称创建 / Create from tool name + >>> ts = tool_resource("my-tool") + >>> + >>> # 从 ToolResource 实例创建 / Create from ToolResource instance + >>> tool = ToolClient().get(name="my-tool") + >>> ts = tool_resource(tool) + >>> + >>> # 转换为 LangChain 工具 / Convert to LangChain tools + >>> lc_tools = ts.to_langchain() + """ + + resource = ( + input + if isinstance(input, ToolResourceType) + else ToolClient().get(name=input, config=config) + ) + + return CommonToolSet.from_agentrun_tool(resource, config=config) diff --git a/agentrun/integration/crewai/__init__.py b/agentrun/integration/crewai/__init__.py index bd0743f..f2e581e 100644 --- a/agentrun/integration/crewai/__init__.py +++ b/agentrun/integration/crewai/__init__.py @@ -8,6 +8,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, ) @@ -16,4 +17,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/crewai/builtin.py b/agentrun/integration/crewai/builtin.py index 5ee5013..1c8aadb 100644 --- a/agentrun/integration/crewai/builtin.py +++ b/agentrun/integration/crewai/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -106,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 CrewAI 工具列表。 / CrewAI Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_crewai( + prefix=prefix, + ) diff --git a/agentrun/integration/google_adk/__init__.py b/agentrun/integration/google_adk/__init__.py index 9028cb1..fad29ba 100644 --- a/agentrun/integration/google_adk/__init__.py +++ b/agentrun/integration/google_adk/__init__.py @@ -7,6 +7,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, toolset, ) @@ -17,4 +18,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/google_adk/builtin.py b/agentrun/integration/google_adk/builtin.py index 8642e27..9622565 100644 --- a/agentrun/integration/google_adk/builtin.py +++ b/agentrun/integration/google_adk/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -106,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 Google ADK 工具列表。 / Google ADK Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_google_adk( + prefix=prefix, + ) diff --git a/agentrun/integration/langchain/__init__.py b/agentrun/integration/langchain/__init__.py index b703d56..9cad7e6 100644 --- a/agentrun/integration/langchain/__init__.py +++ b/agentrun/integration/langchain/__init__.py @@ -24,6 +24,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, toolset, ) @@ -35,4 +36,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/langchain/builtin.py b/agentrun/integration/langchain/builtin.py index 98fbe69..9c6b9ab 100644 --- a/agentrun/integration/langchain/builtin.py +++ b/agentrun/integration/langchain/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -112,3 +113,22 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 LangChain ``StructuredTool`` 列表。 / LangChain Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_langchain( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) diff --git a/agentrun/integration/langgraph/__init__.py b/agentrun/integration/langgraph/__init__.py index 9a3115c..141e6c6 100644 --- a/agentrun/integration/langgraph/__init__.py +++ b/agentrun/integration/langgraph/__init__.py @@ -15,8 +15,8 @@ >>> from agentrun.integration.langgraph import AgentRunConverter >>> >>> async for event in agent.astream_events(input_data, version="v2"): - ... for item in AgentRunConverter.to_agui_events(event): - ... yield item + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item 支持多种调用方式: - agent.astream_events(input, version="v2") - 支持 token by token @@ -29,6 +29,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, toolset, ) @@ -40,4 +41,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/langgraph/builtin.py b/agentrun/integration/langgraph/builtin.py index 83153c5..5b06979 100644 --- a/agentrun/integration/langgraph/builtin.py +++ b/agentrun/integration/langgraph/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -106,3 +107,22 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 LangGraph 工具列表。 / LangGraph Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_langgraph( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) diff --git a/agentrun/integration/pydantic_ai/__init__.py b/agentrun/integration/pydantic_ai/__init__.py index 34fdc38..0af972e 100644 --- a/agentrun/integration/pydantic_ai/__init__.py +++ b/agentrun/integration/pydantic_ai/__init__.py @@ -7,6 +7,7 @@ knowledgebase_toolset, model, sandbox_toolset, + skill_tools, tool_resource, toolset, ) @@ -17,4 +18,5 @@ "sandbox_toolset", "knowledgebase_toolset", "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/pydantic_ai/builtin.py b/agentrun/integration/pydantic_ai/builtin.py index 952130d..eb235f9 100644 --- a/agentrun/integration/pydantic_ai/builtin.py +++ b/agentrun/integration/pydantic_ai/builtin.py @@ -14,6 +14,7 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool @@ -106,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 PydanticAI 工具列表。 / PydanticAI Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_pydantic_ai( + prefix=prefix, + ) diff --git a/agentrun/integration/utils/skill_loader.py b/agentrun/integration/utils/skill_loader.py new file mode 100644 index 0000000..ca178ef --- /dev/null +++ b/agentrun/integration/utils/skill_loader.py @@ -0,0 +1,459 @@ +"""Skill 加载器模块 / Skill Loader Module + +提供从本地 .skills 目录加载 Skill 包的能力,并构造 load_skills 工具供 Agent 运行时调用。 +Provides the ability to load Skill packages from a local .skills directory +and construct a load_skills tool for Agent runtime invocation. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +import os +import re +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +from agentrun.integration.utils.tool import CommonToolSet, Tool, ToolParameter +from agentrun.utils.log import logger + +if TYPE_CHECKING: + from agentrun.tool.tool import Tool as ToolResource + from agentrun.utils.config import Config + + +@dataclass +class SkillInfo: + """Skill 摘要信息 / Skill summary information + + Attributes: + name: skill 名称 / skill name + description: skill 描述 / skill description + version: skill 版本 / skill version + path: 本地目录路径 / local directory path + """ + + name: str + description: str = "" + version: str = "" + path: str = "" + + +@dataclass +class SkillDetail(SkillInfo): + """Skill 详细信息 / Skill detail information + + Attributes: + instruction: SKILL.md 全文内容 / full content of SKILL.md + files: 目录下的文件/文件夹列表 / list of files/folders in the directory + """ + + instruction: str = "" + files: List[str] = field(default_factory=list) + + +def _parse_frontmatter(content: str) -> Dict[str, str]: + """解析 SKILL.md 的 YAML frontmatter / Parse YAML frontmatter from SKILL.md + + 使用简单的正则解析,避免引入 PyYAML 依赖。 + Uses simple regex parsing to avoid introducing PyYAML dependency. + + Args: + content: SKILL.md 文件内容 / SKILL.md file content + + Returns: + 解析出的 key-value 字典 / parsed key-value dictionary + """ + match = re.match(r"^---\s*\n(.*?)\n---", content, re.DOTALL) + if not match: + return {} + result: Dict[str, str] = {} + for line in match.group(1).split("\n"): + line = line.strip() + if not line or ":" not in line: + continue + key, _, value = line.partition(":") + key = key.strip() + value = value.strip().strip('"').strip("'") + if key: + result[key] = value + return result + + +class SkillLoader: + """Skill 加载器 / Skill Loader + + 负责扫描本地 .skills 目录、解析 skill 元信息、读取 skill 指令内容, + 并构造 load_skills 工具供 Agent 运行时调用。 + + Responsible for scanning the local .skills directory, parsing skill metadata, + reading skill instruction content, and constructing the load_skills tool + for Agent runtime invocation. + + Args: + skills_dir: 本地 skill 目录路径 / local skill directory path + remote_skill_names: 需要从远程下载的 skill 名称列表 / list of remote skill names to download + config: 配置对象 / configuration object + """ + + def __init__( + self, + skills_dir: str = ".skills", + remote_skill_names: Optional[List[str]] = None, + config: Optional["Config"] = None, + ): + self._skills_dir = skills_dir + self._remote_skill_names = remote_skill_names or [] + self._config = config + self._skills_cache: Optional[List[SkillInfo]] = None + + def _ensure_skills_available(self) -> None: + """确保远程 skill 已下载到本地 / Ensure remote skills are downloaded locally + + 对每个 remote_skill_name,检查本地是否已存在对应目录, + 不存在则通过 ToolClient 下载。 + + For each remote_skill_name, check if the local directory exists, + download via ToolClient if not. + """ + if not self._remote_skill_names: + return + + from agentrun.tool.client import ToolClient + + for skill_name in self._remote_skill_names: + skill_path = os.path.join(self._skills_dir, skill_name) + if os.path.isdir(skill_path): + logger.debug( + f"Skill '{skill_name}' already exists at {skill_path}, " + "skipping download" + ) + continue + logger.info( + f"Downloading remote skill '{skill_name}' to {self._skills_dir}" + ) + tool_resource = ToolClient().get( + name=skill_name, config=self._config + ) + tool_resource.download_skill( + target_dir=self._skills_dir, config=self._config + ) + + def _parse_skill_metadata(self, skill_dir: str) -> SkillInfo: + """解析 skill 元信息 / Parse skill metadata + + 按以下优先级获取 skill 的 name 和 description: + 1. SKILL.md 的 YAML frontmatter + 2. package.json + 3. 目录名作为 name,description 为空字符串 + + Priority for getting skill name and description: + 1. SKILL.md YAML frontmatter + 2. package.json + 3. Directory name as name, empty string as description + + Args: + skill_dir: skill 目录的完整路径 / full path to skill directory + + Returns: + SkillInfo 实例 / SkillInfo instance + """ + dir_name = os.path.basename(skill_dir) + name = dir_name + description = "" + version = "" + + skill_md_path = os.path.join(skill_dir, "SKILL.md") + if os.path.isfile(skill_md_path): + try: + with open(skill_md_path, "r", encoding="utf-8") as file_handle: + content = file_handle.read() + frontmatter = _parse_frontmatter(content) + if frontmatter.get("name"): + name = frontmatter["name"] + if frontmatter.get("description"): + description = frontmatter["description"] + if frontmatter.get("version"): + version = frontmatter["version"] + if name != dir_name or description or version: + return SkillInfo( + name=name, + description=description, + version=version, + path=skill_dir, + ) + except (OSError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read SKILL.md in {skill_dir}: {error}" + ) + + package_json_path = os.path.join(skill_dir, "package.json") + if os.path.isfile(package_json_path): + try: + with open( + package_json_path, "r", encoding="utf-8" + ) as file_handle: + package_data = json.load(file_handle) + if package_data.get("name"): + name = package_data["name"] + if package_data.get("description"): + description = package_data["description"] + if package_data.get("version"): + version = package_data["version"] + except (OSError, json.JSONDecodeError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read package.json in {skill_dir}: {error}" + ) + + return SkillInfo( + name=name, description=description, version=version, path=skill_dir + ) + + def scan_skills(self) -> List[SkillInfo]: + """扫描 .skills/ 目录,返回所有 skill 的摘要信息 / Scan .skills/ directory and return all skill summaries + + Returns: + SkillInfo 列表 / list of SkillInfo + """ + if self._skills_cache is not None: + return self._skills_cache + + self._ensure_skills_available() + + if not os.path.isdir(self._skills_dir): + self._skills_cache = [] + return self._skills_cache + + skills: List[SkillInfo] = [] + try: + entries = sorted(os.listdir(self._skills_dir)) + except OSError as error: + logger.warning( + f"Failed to list skills directory {self._skills_dir}: {error}" + ) + self._skills_cache = [] + return self._skills_cache + + for entry in entries: + entry_path = os.path.join(self._skills_dir, entry) + if os.path.isdir(entry_path) and not entry.startswith("."): + skill_info = self._parse_skill_metadata(entry_path) + skills.append(skill_info) + + self._skills_cache = skills + return self._skills_cache + + def load_skill(self, name: str) -> Optional[SkillDetail]: + """加载指定 skill 的详细信息 / Load detailed information for a specific skill + + Args: + name: skill 名称 / skill name + + Returns: + SkillDetail 实例,如果 skill 不存在则返回 None / + SkillDetail instance, or None if skill does not exist + """ + skills = self.scan_skills() + target_skill: Optional[SkillInfo] = None + for skill in skills: + if skill.name == name: + target_skill = skill + break + + if target_skill is None: + return None + + instruction = "" + skill_md_path = os.path.join(target_skill.path, "SKILL.md") + if os.path.isfile(skill_md_path): + try: + with open(skill_md_path, "r", encoding="utf-8") as file_handle: + instruction = file_handle.read() + except (OSError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read SKILL.md for skill '{name}': {error}" + ) + + files: List[str] = [] + try: + for entry in sorted(os.listdir(target_skill.path)): + if not entry.startswith("."): + entry_path = os.path.join(target_skill.path, entry) + if os.path.isdir(entry_path): + files.append(entry + "/") + else: + files.append(entry) + except OSError as error: + logger.warning(f"Failed to list files for skill '{name}': {error}") + + return SkillDetail( + name=target_skill.name, + description=target_skill.description, + version=target_skill.version, + path=target_skill.path, + instruction=instruction, + files=files, + ) + + def _build_tool_description(self, skills: List[SkillInfo]) -> str: + """构建 load_skills 工具的 description / Build the description for the load_skills tool + + 将所有可用 skill 的名称和描述写入工具描述中。 + Writes all available skill names and descriptions into the tool description. + + Args: + skills: skill 摘要列表 / list of skill summaries + + Returns: + 工具描述字符串 / tool description string + """ + if not skills: + return ( + "Load skill instructions for the agent. " + "No skills available in the configured directory." + ) + + skill_lines = [] + for skill in skills: + desc_part = f": {skill.description}" if skill.description else "" + skill_lines.append(f"- {skill.name}{desc_part}") + + skills_list = "\n".join(skill_lines) + return ( + "Load skill instructions for the agent. " + "Call without arguments to list all skills, " + "or with a skill name to get detailed instructions.\n\n" + f"Available skills:\n{skills_list}" + ) + + def _load_skills_func(self, name: Optional[str] = None) -> str: + """load_skills 工具的执行函数 / Execution function for the load_skills tool + + Args: + name: skill 名称(可选)/ skill name (optional) + + Returns: + JSON 字符串 / JSON string + """ + if name is None or name == "": + skills = self.scan_skills() + result: Dict[str, Any] = { + "skills": [ + {"name": skill.name, "description": skill.description} + for skill in skills + ] + } + return json.dumps(result, ensure_ascii=False) + + detail = self.load_skill(name) + if detail is None: + available = [skill.name for skill in self.scan_skills()] + available_str = ", ".join(available) if available else "none" + error_result: Dict[str, str] = { + "error": ( + f"Skill '{name}' not found. " + f"Available skills: {available_str}" + ) + } + return json.dumps(error_result, ensure_ascii=False) + + detail_result: Dict[str, Any] = { + "name": detail.name, + "description": detail.description, + "instruction": detail.instruction, + "files": detail.files, + } + return json.dumps(detail_result, ensure_ascii=False) + + def to_common_toolset(self) -> CommonToolSet: + """构造包含 load_skills 工具的 CommonToolSet / Construct CommonToolSet with load_skills tool + + Returns: + CommonToolSet 实例 / CommonToolSet instance + """ + skills = self.scan_skills() + description = self._build_tool_description(skills) + + load_skills_tool = Tool( + name="load_skills", + description=description, + parameters=[ + ToolParameter( + name="name", + param_type="string", + description=( + "The name of the skill to load. " + "If omitted, returns a list of all available skills." + ), + required=False, + ), + ], + func=self._load_skills_func, + ) + + return CommonToolSet(tools_list=[load_skills_tool]) + + +def skill_tools( + name: Optional[Union[str, List[str], "ToolResource"]] = None, + *, + skills_dir: str = ".skills", + config: Optional["Config"] = None, +) -> CommonToolSet: + """将 Skill 封装为通用工具集 / Wrap Skills as CommonToolSet + + 支持从工具名称、名称列表或 ToolResource 实例创建通用工具集。 + Supports creating CommonToolSet from tool name, name list, or ToolResource instance. + + Args: + name: 远程 skill 名称、名称列表或 ToolResource 实例(可选)/ + Remote skill name, name list, or ToolResource instance (optional). + 如果提供,会先下载到 skills_dir 再加载 / + If provided, downloads to skills_dir before loading. + 如果不提供,仅从 skills_dir 加载本地已有的 skill / + If not provided, only loads local skills from skills_dir. + skills_dir: 本地 skill 目录,默认 ".skills" / Local skill directory, default ".skills" + config: 配置对象 / Configuration object + + Returns: + CommonToolSet: 包含 load_skills 工具的通用工具集 / + CommonToolSet containing the load_skills tool + + Examples: + >>> # 仅加载本地 skill / Load local skills only + >>> ts = skill_tools(skills_dir=".skills") + >>> + >>> # 下载远程 skill 后加载 / Download remote skill then load + >>> ts = skill_tools("my-remote-skill") + >>> + >>> # 下载多个远程 skill / Download multiple remote skills + >>> ts = skill_tools(["skill-a", "skill-b"]) + >>> + >>> # 转换为 LangChain 工具 / Convert to LangChain tools + >>> lc_tools = ts.to_langchain() + """ + remote_names: List[str] = [] + + if name is not None: + if isinstance(name, str): + remote_names = [name] + elif isinstance(name, list): + remote_names = name + else: + # ToolResource instance — extract its name and download + tool_resource_instance = name + resource_name = getattr( + tool_resource_instance, "name", None + ) or getattr(tool_resource_instance, "tool_name", None) + if resource_name: + skill_path = os.path.join(skills_dir, resource_name) + if not os.path.isdir(skill_path): + tool_resource_instance.download_skill( + target_dir=skills_dir, config=config + ) + + loader = SkillLoader( + skills_dir=skills_dir, + remote_skill_names=remote_names, + config=config, + ) + return loader.to_common_toolset() diff --git a/tests/unittests/integration/test_skill_loader.py b/tests/unittests/integration/test_skill_loader.py new file mode 100644 index 0000000..830bba1 --- /dev/null +++ b/tests/unittests/integration/test_skill_loader.py @@ -0,0 +1,911 @@ +"""SkillLoader 单元测试 / SkillLoader Unit Tests + +测试 Skill 加载器的核心功能: +- _parse_frontmatter() 函数 +- SkillLoader 类(scan_skills / load_skill / to_common_toolset) +- skill_tools() 入口函数 +- builtin/skill.py 导出 +- 各框架 builtin 中的 skill_tools() 函数 +""" + +import json +import os +import sys +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import pytest + +import agentrun.integration.builtin.skill as _builtin_skill_mod +from agentrun.integration.utils.skill_loader import ( + _parse_frontmatter, + skill_tools, + SkillDetail, + SkillInfo, + SkillLoader, +) +from agentrun.integration.utils.tool import CommonToolSet + +# ============================================================================= +# Helper: 创建临时 skill 目录结构 +# ============================================================================= + + +def _create_skill_dir( + base_dir: str, + skill_name: str, + *, + skill_md_content: Optional[str] = None, + package_json: Optional[Dict[str, Any]] = None, + extra_files: Optional[Dict[str, str]] = None, +) -> str: + """在 base_dir 下创建一个 skill 子目录,写入可选的 SKILL.md / package.json / 其他文件。""" + skill_path = os.path.join(base_dir, skill_name) + os.makedirs(skill_path, exist_ok=True) + + if skill_md_content is not None: + with open( + os.path.join(skill_path, "SKILL.md"), "w", encoding="utf-8" + ) as fh: + fh.write(skill_md_content) + + if package_json is not None: + with open( + os.path.join(skill_path, "package.json"), "w", encoding="utf-8" + ) as fh: + json.dump(package_json, fh) + + if extra_files: + for filename, content in extra_files.items(): + file_path = os.path.join(skill_path, filename) + sub_dir = os.path.dirname(file_path) + if sub_dir and not os.path.isdir(sub_dir): + os.makedirs(sub_dir, exist_ok=True) + with open(file_path, "w", encoding="utf-8") as fh: + fh.write(content) + + return skill_path + + +# ============================================================================= +# 1. _parse_frontmatter 测试 +# ============================================================================= + + +class TestParseFrontmatter: + """测试 YAML frontmatter 解析函数""" + + def test_valid_frontmatter(self) -> None: + content = ( + "---\nname: my-skill\ndescription: A test skill\nversion:" + " 1.0.0\n---\n# Body" + ) + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "A test skill" + assert result["version"] == "1.0.0" + + def test_no_frontmatter(self) -> None: + content = "# Just a markdown file\nNo frontmatter here." + result = _parse_frontmatter(content) + assert result == {} + + def test_empty_string(self) -> None: + result = _parse_frontmatter("") + assert result == {} + + def test_quoted_values(self) -> None: + content = ( + "---\nname: \"quoted-name\"\ndescription: 'single-quoted'\n---\n" + ) + result = _parse_frontmatter(content) + assert result["name"] == "quoted-name" + assert result["description"] == "single-quoted" + + def test_empty_value(self) -> None: + content = "---\nname: my-skill\ndescription:\n---\n" + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "" + + def test_colon_in_value(self) -> None: + content = ( + "---\nname: my-skill\ndescription: A skill: does things\n---\n" + ) + result = _parse_frontmatter(content) + assert result["description"] == "A skill: does things" + + def test_blank_lines_in_frontmatter(self) -> None: + content = "---\nname: my-skill\n\ndescription: test\n---\n" + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "test" + + def test_no_closing_delimiter(self) -> None: + content = "---\nname: my-skill\ndescription: test\n" + result = _parse_frontmatter(content) + assert result == {} + + +# ============================================================================= +# 2. SkillLoader.scan_skills 测试 +# ============================================================================= + + +class TestScanSkills: + """测试 SkillLoader.scan_skills()""" + + def test_empty_directory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert result == [] + + def test_nonexistent_directory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "nonexistent") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert result == [] + + def test_skill_with_frontmatter(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content=( + "---\nname: custom-name\ndescription: A great skill\nversion:" + " 2.0\n---\n# Skill" + ), + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "custom-name" + assert result[0].description == "A great skill" + assert result[0].version == "2.0" + + def test_skill_with_package_json_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "pkg-skill", + package_json={ + "name": "pkg-name", + "description": "From package.json", + "version": "3.0", + }, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "pkg-name" + assert result[0].description == "From package.json" + assert result[0].version == "3.0" + + def test_skill_with_no_metadata(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "bare-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "bare-skill" + assert result[0].description == "" + + def test_frontmatter_takes_priority_over_package_json( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "priority-skill", + skill_md_content=( + "---\nname: from-frontmatter\ndescription: FM desc\n---\n" + ), + package_json={"name": "from-pkg", "description": "PKG desc"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "from-frontmatter" + assert result[0].description == "FM desc" + + def test_multiple_skills_sorted(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "beta-skill") + _create_skill_dir(skills_dir, "alpha-skill") + _create_skill_dir(skills_dir, "gamma-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 3 + assert [s.name for s in result] == [ + "alpha-skill", + "beta-skill", + "gamma-skill", + ] + + def test_hidden_directories_skipped(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, ".hidden-skill") + _create_skill_dir(skills_dir, "visible-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "visible-skill" + + def test_files_in_root_are_skipped(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with open(os.path.join(skills_dir, "not-a-skill.txt"), "w") as fh: + fh.write("just a file") + _create_skill_dir(skills_dir, "real-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "real-skill" + + def test_cache_is_used(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "cached-skill") + loader = SkillLoader(skills_dir=skills_dir) + first_result = loader.scan_skills() + # Add another skill after first scan + _create_skill_dir(skills_dir, "new-skill") + second_result = loader.scan_skills() + # Should return cached result + assert first_result is second_result + assert len(second_result) == 1 + + def test_malformed_package_json(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + skill_path = os.path.join(skills_dir, "bad-pkg") + os.makedirs(skill_path) + with open(os.path.join(skill_path, "package.json"), "w") as fh: + fh.write("{invalid json") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "bad-pkg" + + def test_skill_md_without_frontmatter_falls_to_package_json( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "fallback-skill", + skill_md_content="# No frontmatter here\nJust content.", + package_json={"name": "pkg-fallback", "description": "From pkg"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "pkg-fallback" + assert result[0].description == "From pkg" + + +# ============================================================================= +# 3. SkillLoader.load_skill 测试 +# ============================================================================= + + +class TestLoadSkill: + """测试 SkillLoader.load_skill()""" + + def test_load_existing_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + md_content = ( + "---\nname: test-skill\ndescription: Test\n---\n# Instructions\nDo" + " stuff." + ) + _create_skill_dir( + skills_dir, + "test-skill", + skill_md_content=md_content, + extra_files={"helper.py": "print('hello')"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("test-skill") + assert detail is not None + assert detail.name == "test-skill" + assert detail.description == "Test" + assert detail.instruction == md_content + assert "SKILL.md" in detail.files + assert "helper.py" in detail.files + + def test_load_nonexistent_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("nonexistent") + assert detail is None + + def test_load_skill_with_subdirectory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + skill_path = _create_skill_dir( + skills_dir, + "dir-skill", + skill_md_content="---\nname: dir-skill\n---\n", + ) + sub_dir = os.path.join(skill_path, "scripts") + os.makedirs(sub_dir) + with open(os.path.join(sub_dir, "run.sh"), "w") as fh: + fh.write("#!/bin/bash") + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("dir-skill") + assert detail is not None + assert "scripts/" in detail.files + + def test_load_skill_hidden_files_excluded(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "hidden-files-skill", + skill_md_content="---\nname: hidden-files-skill\n---\n", + extra_files={".hidden": "secret", "visible.txt": "public"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("hidden-files-skill") + assert detail is not None + assert ".hidden" not in detail.files + assert "visible.txt" in detail.files + + def test_load_skill_without_skill_md(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "no-md-skill", + extra_files={"readme.txt": "hello"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("no-md-skill") + assert detail is not None + assert detail.instruction == "" + assert "readme.txt" in detail.files + + +# ============================================================================= +# 4. SkillLoader._build_tool_description 测试 +# ============================================================================= + + +class TestBuildToolDescription: + """测试 load_skills 工具描述的构建""" + + def test_no_skills(self) -> None: + loader = SkillLoader(skills_dir="/nonexistent") + desc = loader._build_tool_description([]) + assert "No skills available" in desc + + def test_with_skills(self) -> None: + loader = SkillLoader(skills_dir="/nonexistent") + skills = [ + SkillInfo(name="alpha", description="Alpha skill"), + SkillInfo(name="beta", description=""), + ] + desc = loader._build_tool_description(skills) + assert "alpha: Alpha skill" in desc + assert "- beta" in desc + assert "Available skills:" in desc + + +# ============================================================================= +# 5. SkillLoader._load_skills_func 测试 +# ============================================================================= + + +class TestLoadSkillsFunc: + """测试 load_skills 工具的执行函数""" + + def test_list_all_skills(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "skill-a", + skill_md_content="---\nname: skill-a\ndescription: Skill A\n---\n", + ) + _create_skill_dir( + skills_dir, + "skill-b", + skill_md_content="---\nname: skill-b\ndescription: Skill B\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name=None) + result = json.loads(result_json) + assert "skills" in result + assert len(result["skills"]) == 2 + names = [s["name"] for s in result["skills"]] + assert "skill-a" in names + assert "skill-b" in names + + def test_list_with_empty_string(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "only-skill") + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="") + result = json.loads(result_json) + assert "skills" in result + + def test_load_specific_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + md_content = ( + "---\nname: target\ndescription: Target skill\n---\n# Instructions" + ) + _create_skill_dir( + skills_dir, + "target", + skill_md_content=md_content, + ) + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="target") + result = json.loads(result_json) + assert result["name"] == "target" + assert result["description"] == "Target skill" + assert "instruction" in result + assert "files" in result + + def test_load_nonexistent_skill_returns_error(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "existing") + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="missing") + result = json.loads(result_json) + assert "error" in result + assert "missing" in result["error"] + assert "existing" in result["error"] + + +# ============================================================================= +# 6. SkillLoader.to_common_toolset 测试 +# ============================================================================= + + +class TestToCommonToolset: + """测试 to_common_toolset() 返回的 CommonToolSet""" + + def test_returns_common_toolset(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "test-skill", + skill_md_content="---\nname: test-skill\ndescription: Test\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + assert isinstance(toolset, CommonToolSet) + + def test_toolset_has_load_skills_tool(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "test-skill") + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tools_list = toolset.tools() + assert len(tools_list) == 1 + assert tools_list[0].name == "load_skills" + + def test_tool_description_contains_skill_names(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "alpha", + skill_md_content="---\nname: alpha\ndescription: Alpha desc\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + assert "alpha" in tool.description + assert "Alpha desc" in tool.description + + def test_tool_has_name_parameter(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "test-skill") + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + # CanonicalTool.parameters is a JSON schema dict + assert "properties" in tool.parameters + assert "name" in tool.parameters["properties"] + name_prop = tool.parameters["properties"]["name"] + assert name_prop["type"] == "string" + + def test_tool_func_is_callable(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "callable-skill", + skill_md_content="---\nname: callable-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + result_json = tool.func() + result = json.loads(result_json) + assert "skills" in result + + def test_empty_skills_dir_still_returns_toolset( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + assert isinstance(toolset, CommonToolSet) + tools_list = toolset.tools() + assert len(tools_list) == 1 + assert "No skills available" in tools_list[0].description + + +# ============================================================================= +# 7. skill_tools() 入口函数测试 +# ============================================================================= + + +class TestSkillToolsFunction: + """测试 skill_tools() 入口函数""" + + def test_local_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "local-skill", + skill_md_content=( + "---\nname: local-skill\ndescription: Local\n---\n" + ), + ) + toolset = skill_tools(skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + assert len(toolset.tools()) == 1 + + def test_with_string_name_triggers_remote_download( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_tool_client = MagicMock() + mock_tool_resource = MagicMock() + mock_tool_client.return_value.get.return_value = mock_tool_resource + + with patch( + "agentrun.integration.utils.skill_loader.SkillLoader._ensure_skills_available" + ) as mock_ensure: + toolset = skill_tools(name="remote-skill", skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + + def test_with_list_of_names(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + with patch( + "agentrun.integration.utils.skill_loader.SkillLoader._ensure_skills_available" + ): + toolset = skill_tools( + name=["skill-a", "skill-b"], skills_dir=skills_dir + ) + assert isinstance(toolset, CommonToolSet) + + def test_with_tool_resource_instance(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_resource = MagicMock() + mock_resource.name = "resource-skill" + + toolset = skill_tools(name=mock_resource, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + mock_resource.download_skill.assert_called_once() + + def test_with_tool_resource_already_downloaded(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + # Pre-create the skill directory so download is skipped + _create_skill_dir(skills_dir, "existing-resource") + + mock_resource = MagicMock() + mock_resource.name = "existing-resource" + + toolset = skill_tools(name=mock_resource, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + mock_resource.download_skill.assert_not_called() + + def test_none_name_loads_local_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "local-only") + toolset = skill_tools(name=None, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + + +# ============================================================================= +# 8. _ensure_skills_available 测试 +# ============================================================================= + + +class TestEnsureSkillsAvailable: + """测试远程 skill 下载逻辑""" + + def test_no_remote_names_does_nothing(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir, remote_skill_names=[]) + # Should not raise + loader._ensure_skills_available() + + def test_existing_skill_skips_download(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "already-here") + + loader = SkillLoader( + skills_dir=skills_dir, remote_skill_names=["already-here"] + ) + with patch("agentrun.tool.client.ToolClient") as mock_client: + loader._ensure_skills_available() + mock_client.assert_not_called() + + def test_missing_skill_triggers_download(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_tool_resource = MagicMock() + mock_client_instance = MagicMock() + mock_client_instance.get.return_value = mock_tool_resource + + loader = SkillLoader( + skills_dir=skills_dir, remote_skill_names=["new-skill"] + ) + with patch( + "agentrun.tool.client.ToolClient", + return_value=mock_client_instance, + ): + loader._ensure_skills_available() + mock_client_instance.get.assert_called_once_with( + name="new-skill", config=None + ) + mock_tool_resource.download_skill.assert_called_once_with( + target_dir=skills_dir, config=None + ) + + +# ============================================================================= +# 9. builtin/skill.py 导出测试 +# ============================================================================= + + +class TestBuiltinSkillExport: + """测试 builtin/skill.py 的导出""" + + def test_skill_tools_is_exported(self) -> None: + assert hasattr(_builtin_skill_mod, "skill_tools") + assert callable(_builtin_skill_mod.skill_tools) + + def test_skill_tools_in_all(self) -> None: + assert "skill_tools" in _builtin_skill_mod.__all__ + + def test_import_from_builtin_init(self) -> None: + from agentrun.integration.builtin import skill_tools as imported_fn + + assert callable(imported_fn) + + +# ============================================================================= +# 10. 各框架 builtin skill_tools 测试 +# ============================================================================= + + +class TestFrameworkBuiltinSkillTools: + """测试各框架 builtin 中的 skill_tools() 函数""" + + def _run_framework_test(self, framework_module_path: str) -> None: + """通用框架测试:mock builtin skill_tools 返回 CommonToolSet, + 验证框架 skill_tools 调用了正确的转换方法。""" + mock_toolset = MagicMock(spec=CommonToolSet) + mock_toolset.to_langchain.return_value = [MagicMock()] + mock_toolset.to_google_adk.return_value = [MagicMock()] + mock_toolset.to_crewai.return_value = [MagicMock()] + mock_toolset.to_langgraph.return_value = [MagicMock()] + mock_toolset.to_pydantic_ai.return_value = [MagicMock()] + mock_toolset.to_agentscope.return_value = [MagicMock()] + + with patch( + f"{framework_module_path}._skill_tools", + return_value=mock_toolset, + ): + module = sys.modules.get(framework_module_path) + if module is None: + import importlib + + module = importlib.import_module(framework_module_path) + result = module.skill_tools(skills_dir=".test-skills") + assert isinstance(result, list) + assert len(result) == 1 + + def test_langchain_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.langchain.builtin") + + def test_google_adk_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.google_adk.builtin") + + def test_crewai_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.crewai.builtin") + + def test_langgraph_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.langgraph.builtin") + + def test_pydantic_ai_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.pydantic_ai.builtin") + + def test_agentscope_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.agentscope.builtin") + + def test_framework_import_from_init(self) -> None: + """验证各框架 __init__.py 正确导出 skill_tools""" + from agentrun.integration.agentscope import skill_tools as as_fn + from agentrun.integration.crewai import skill_tools as crew_fn + from agentrun.integration.google_adk import skill_tools as adk_fn + from agentrun.integration.langchain import skill_tools as lc_fn + from agentrun.integration.langgraph import skill_tools as lg_fn + from agentrun.integration.pydantic_ai import skill_tools as pai_fn + + assert callable(lc_fn) + assert callable(adk_fn) + assert callable(crew_fn) + assert callable(lg_fn) + assert callable(pai_fn) + assert callable(as_fn) + + +# ============================================================================= +# 11. SkillInfo / SkillDetail 数据类测试 +# ============================================================================= + + +class TestDataClasses: + """测试 SkillInfo 和 SkillDetail 数据类""" + + def test_skill_info_defaults(self) -> None: + info = SkillInfo(name="test") + assert info.name == "test" + assert info.description == "" + assert info.version == "" + assert info.path == "" + + def test_skill_info_with_all_fields(self) -> None: + info = SkillInfo( + name="full", description="desc", version="1.0", path="/path" + ) + assert info.name == "full" + assert info.description == "desc" + assert info.version == "1.0" + assert info.path == "/path" + + def test_skill_detail_defaults(self) -> None: + detail = SkillDetail(name="test") + assert detail.name == "test" + assert detail.instruction == "" + assert detail.files == [] + + def test_skill_detail_inherits_skill_info(self) -> None: + detail = SkillDetail( + name="full", + description="desc", + version="1.0", + path="/path", + instruction="# Do stuff", + files=["a.py", "b.py"], + ) + assert isinstance(detail, SkillInfo) + assert detail.instruction == "# Do stuff" + assert detail.files == ["a.py", "b.py"] + + +# ============================================================================= +# 12. 端到端集成测试 +# ============================================================================= + + +class TestEndToEnd: + """端到端测试:从创建 skill 目录到调用 load_skills 工具""" + + def test_full_workflow(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + md_content = ( + "---\n" + "name: e2e-skill\n" + "description: End-to-end test skill\n" + "version: 1.0.0\n" + "---\n" + "\n" + "# E2E Skill\n" + "\n" + "Follow these instructions to use the skill.\n" + ) + _create_skill_dir( + skills_dir, + "e2e-skill", + skill_md_content=md_content, + extra_files={"scripts/run.sh": "#!/bin/bash\necho hello"}, + ) + + toolset = skill_tools(skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + tools_list = toolset.tools() + assert len(tools_list) == 1 + + tool = tools_list[0] + assert tool.name == "load_skills" + assert "e2e-skill" in tool.description + + # List all skills + list_result = json.loads(tool.func()) + assert len(list_result["skills"]) == 1 + assert list_result["skills"][0]["name"] == "e2e-skill" + + # Load specific skill + detail_result = json.loads(tool.func(name="e2e-skill")) + assert detail_result["name"] == "e2e-skill" + assert detail_result["description"] == "End-to-end test skill" + assert "Follow these instructions" in detail_result["instruction"] + assert "SKILL.md" in detail_result["files"] + assert "scripts/" in detail_result["files"] + + # Load nonexistent skill + error_result = json.loads(tool.func(name="nonexistent")) + assert "error" in error_result + assert "e2e-skill" in error_result["error"] + + def test_multiple_skills_workflow(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + _create_skill_dir( + skills_dir, + "skill-alpha", + skill_md_content=( + "---\nname: skill-alpha\ndescription: Alpha\n---\n# Alpha" + ), + ) + _create_skill_dir( + skills_dir, + "skill-beta", + package_json={"name": "skill-beta", "description": "Beta"}, + ) + + toolset = skill_tools(skills_dir=skills_dir) + tool = toolset.tools()[0] + + list_result = json.loads(tool.func()) + assert len(list_result["skills"]) == 2 + + alpha = json.loads(tool.func(name="skill-alpha")) + assert alpha["name"] == "skill-alpha" + assert "# Alpha" in alpha["instruction"] + + beta = json.loads(tool.func(name="skill-beta")) + assert beta["name"] == "skill-beta" + assert beta["instruction"] == "" From b9d9c6fea22d6511322ae799113e3faf571a5a6f Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Tue, 31 Mar 2026 11:19:55 +0800 Subject: [PATCH 05/10] =?UTF-8?q?refactor(tool):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=96=B0=E7=89=88=E5=B7=A5=E5=85=B7ram=E9=93=BE=E8=B7=AF?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加了ToolCreateMethod枚举类,定义了多种工具创建和部署方式,并相应地更新了相关的导入和导出语句。同时增加了对异步事件循环处理的单元测试。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/tool/__init__.py | 2 + agentrun/tool/__tool_async_template.py | 123 +++++- agentrun/tool/api/mcp.py | 158 ++++++- agentrun/tool/api/openapi.py | 128 +++++- agentrun/tool/model.py | 19 + agentrun/tool/tool.py | 156 ++++++- .../test_tool_resource_integration.py | 386 ++++++++++++++++++ tests/unittests/tool/test_mcp.py | 58 ++- tests/unittests/tool/test_tool.py | 345 ++++++++++++++++ 9 files changed, 1361 insertions(+), 14 deletions(-) create mode 100644 tests/unittests/integration/test_tool_resource_integration.py diff --git a/agentrun/tool/__init__.py b/agentrun/tool/__init__.py index cce0e04..f22f2b6 100644 --- a/agentrun/tool/__init__.py +++ b/agentrun/tool/__init__.py @@ -12,6 +12,7 @@ McpConfig, ToolCodeConfiguration, ToolContainerConfiguration, + ToolCreateMethod, ToolInfo, ToolLogConfiguration, ToolNASConfig, @@ -29,6 +30,7 @@ "ToolClient", "Tool", "ToolType", + "ToolCreateMethod", "McpConfig", "ToolCodeConfiguration", "ToolContainerConfiguration", diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py index 74dd049..7f8b8a8 100644 --- a/agentrun/tool/__tool_async_template.py +++ b/agentrun/tool/__tool_async_template.py @@ -8,6 +8,7 @@ import os import shutil from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import zipfile import httpx @@ -16,6 +17,7 @@ from agentrun.utils.config import Config from agentrun.utils.log import logger from agentrun.utils.model import BaseModel +from agentrun.utils.ram_signature import get_agentrun_signed_headers from .model import ( McpConfig, @@ -128,9 +130,20 @@ class Tool(BaseModel): tool_type: Optional[str] = None """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + create_method: Optional[str] = None + """工具创建方式 / Tool create method + MCP_REMOTE: 远程 MCP 服务器 / Remote MCP server + MCP_LOCAL: 本地 MCP 标准输入输出 / Local MCP stdio + MCP_BUNDLE: MCP 打包部署 / MCP bundle deployment + CODE_PACKAGE: 代码包部署 / Code package deployment + OPENAPI_IMPORT: OpenAPI 导入 / OpenAPI import + """ + version_id: Optional[str] = None """版本 ID / Version ID""" + _RAM_DATA_DOMAINS = ("agentrun-data", "funagent-data-pre") + @classmethod def __get_client(cls, config: Optional[Config] = None): from .client import ToolClient @@ -238,20 +251,36 @@ async def list_tools_async( self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) return await session.list_tools_async() elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) return await openapi_client.list_tools_async() @@ -290,11 +319,20 @@ async def call_tool_async( self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) result = await session.call_tool_async(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -303,11 +341,17 @@ async def call_tool_async( elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, headers=cfg.get_headers(), fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) result = await openapi_client.call_tool_async(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -315,6 +359,83 @@ async def call_tool_async( raise ValueError(f"Unsupported tool type: {self.tool_type}") + def _use_ram_auth(self, config: Optional[Config] = None) -> bool: + """是否使用 RAM 签名鉴权(配置了 AK/SK 时使用)。 + Whether to use RAM signature authentication (when AK/SK is configured). + """ + cfg = Config.with_configs(config) + return bool(cfg.get_access_key_id() and cfg.get_access_key_secret()) + + def _get_ram_data_endpoint( + self, url: str, config: Optional[Config] = None + ) -> str: + """返回 RAM 鉴权用的 data endpoint(仅当 agentrun-data / funagent-data-pre 域名时在 host 前加 -ram)。 + Return RAM-authenticated endpoint (add -ram prefix for agentrun-data / funagent-data-pre domains). + """ + parsed = urlparse(url) + if not parsed.netloc or not any( + f".{domain}." in parsed.netloc for domain in self._RAM_DATA_DOMAINS + ): + return url + parts = parsed.netloc.split(".", 1) + if len(parts) != 2: + return url + ram_netloc = parts[0] + "-ram." + parts[1] + + from urllib.parse import urlunparse + + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + )) + + def _get_auth_headers( + self, url: str, config: Optional[Config] = None + ) -> Dict[str, str]: + """获取认证请求头,支持 RAM 签名。 + Get authentication headers with RAM signature support. + + Args: + url: 请求 URL / Request URL + config: 配置对象 / Configuration object + + Returns: + Dict[str, str]: 包含认证信息的请求头 / Headers with authentication + """ + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + if self._use_ram_auth(cfg): + # 使用 RAM 端点 + ram_url = self._get_ram_data_endpoint(url, cfg) + try: + signed = get_agentrun_signed_headers( + url=ram_url, + method="GET", + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token() or None, + region=cfg.get_region_id(), + product="agentrun", + body=None, + ) + headers = { + **signed, + **headers, + } + logger.debug( + "using RAM signature for skill download to %s", + ram_url[:80] + "..." if len(ram_url) > 80 else ram_url, + ) + except ValueError as e: + logger.warning("RAM signing skipped (missing AK/SK): %s", e) + + return headers + def _get_skill_download_url( self, config: Optional[Config] = None ) -> Optional[str]: @@ -376,7 +497,7 @@ async def download_skill_async( logger.debug("downloading skill from %s to %s", download_url, skill_dir) cfg = Config.with_configs(config) - headers = cfg.get_headers() + headers = self._get_auth_headers(download_url, cfg) async with httpx.AsyncClient( timeout=300, follow_redirects=True diff --git a/agentrun/tool/api/mcp.py b/agentrun/tool/api/mcp.py index a0cef61..9038d17 100644 --- a/agentrun/tool/api/mcp.py +++ b/agentrun/tool/api/mcp.py @@ -5,12 +5,108 @@ """ import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, List, Optional +from urllib.parse import urlparse, urlunparse + +import httpx from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.config import Config from agentrun.utils.log import logger +def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: + """获取当前线程的事件循环,如果不存在则创建一个新的。 + Get the event loop for the current thread, creating a new one if none exists. + + Python 3.10+ 在非主线程中调用 asyncio.get_event_loop() 时, + 如果该线程没有事件循环会抛出 RuntimeError。此函数安全地处理该情况。 + """ + try: + return asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +from agentrun.utils.ram_signature import get_agentrun_signed_headers + + +class _AgentrunRamAuth(httpx.Auth): + """httpx Auth handler:为每次请求动态生成 RAM 签名。 + + SSE 场景下同一个 httpx.AsyncClient 会发出 GET(SSE 连接)和 + POST(消息发送)请求,URL / method / body 各不相同,因此必须 + per-request 计算签名,不能在 client 初始化时一次性设置 headers。 + """ + + def __init__( + self, + access_key_id: str, + access_key_secret: str, + region: str, + security_token: Optional[str] = None, + ): + self._ak = access_key_id + self._sk = access_key_secret + self._region = region + self._security_token = security_token + + def auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + url = str(request.url) + method = request.method + + body: Optional[bytes] = None + if request.content: + body = request.content + + content_type: Optional[str] = request.headers.get("content-type") + + try: + signed = get_agentrun_signed_headers( + url=url, + method=method, + access_key_id=self._ak, + access_key_secret=self._sk, + security_token=self._security_token, + region=self._region, + product="agentrun", + body=body, + content_type=content_type, + ) + for k, v in signed.items(): + request.headers[k] = v + logger.debug( + "applied RAM signature for MCP %s request to %s", + method, + url[:80] + ("..." if len(url) > 80 else ""), + ) + except ValueError as e: + logger.warning("RAM signing skipped for MCP request: %s", e) + + yield request + + +def _rewrite_to_ram_url(url: str) -> str: + """将 agentrun-data 域名改写为 -ram 端点。""" + parsed = urlparse(url) + parts = parsed.netloc.split(".", 1) + if len(parts) == 2: + ram_netloc = parts[0] + "-ram." + parts[1] + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path or "", + parsed.params, + parsed.query, + parsed.fragment, + )) + return url + + class ToolMCPSession: """Tool MCP 会话管理 / Tool MCP Session Manager @@ -23,6 +119,8 @@ def __init__( endpoint: str, session_affinity: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + use_ram_auth: bool = True, ): """初始化 MCP 会话 / Initialize MCP session @@ -30,16 +128,56 @@ def __init__( endpoint: MCP 数据链路 URL / MCP data endpoint URL session_affinity: 会话亲和性策略 / Session affinity strategy headers: 请求头 / Request headers + config: 配置对象 / Configuration object + use_ram_auth: 是否启用 RAM 签名鉴权 / Whether to enable RAM signature auth. + MCP_REMOTE + proxy_enabled=false 时设为 False(直连外部服务)。 + Set to False for MCP_REMOTE with proxy disabled (direct external connection). """ self.endpoint = endpoint self.session_affinity = session_affinity self.headers = headers or {} + self.config = Config.with_configs(config) + self.use_ram_auth = use_ram_auth @property def is_streamable(self) -> bool: """是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport""" return self.session_affinity == "MCP_STREAMABLE" + def _build_ram_auth(self, url: str) -> tuple: + """当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。 + + Returns: + (rewritten_url, auth_or_none) + """ + # MCP_REMOTE + proxy_enabled=false 时不走 RAM 鉴权 + # Skip RAM auth for MCP_REMOTE with proxy disabled + if not self.use_ram_auth: + return url, None + + parsed = urlparse(url) + # 只对 agentrun-data 和 funagent-data-pre 域名应用 RAM 签名 + if ".agentrun-data." not in ( + parsed.netloc or "" + ) and ".funagent-data-pre." not in (parsed.netloc or ""): + return url, None + + cfg = self.config + ak = cfg.get_access_key_id() + sk = cfg.get_access_key_secret() + if not ak or not sk: + return url, None + + url = _rewrite_to_ram_url(url) + + auth = _AgentrunRamAuth( + access_key_id=ak, + access_key_secret=sk, + region=cfg.get_region_id(), + security_token=cfg.get_security_token() or None, + ) + return url, auth + async def list_tools_async(self) -> List[ToolInfo]: """异步获取工具列表 / Get tool list asynchronously @@ -49,11 +187,14 @@ async def list_tools_async(self) -> List[ToolInfo]: try: from mcp import ClientSession + # 应用 RAM 签名 + url, auth = self._build_ram_auth(self.endpoint) + if self.is_streamable: from mcp.client.streamable_http import streamablehttp_client async with streamablehttp_client( - self.endpoint, headers=self.headers + url, headers=self.headers, auth=auth ) as (read_stream, write_stream, _): async with ClientSession( read_stream, write_stream @@ -67,7 +208,7 @@ async def list_tools_async(self) -> List[ToolInfo]: else: from mcp.client.sse import sse_client - async with sse_client(self.endpoint, headers=self.headers) as ( + async with sse_client(url, headers=self.headers, auth=auth) as ( read_stream, write_stream, ): @@ -92,7 +233,7 @@ def list_tools(self) -> List[ToolInfo]: Returns: List[ToolInfo]: 工具信息列表 / List of tool information """ - return asyncio.get_event_loop().run_until_complete( + return _get_or_create_event_loop().run_until_complete( self.list_tools_async() ) @@ -113,11 +254,14 @@ async def call_tool_async( try: from mcp import ClientSession + # 应用 RAM 签名 + url, auth = self._build_ram_auth(self.endpoint) + if self.is_streamable: from mcp.client.streamable_http import streamablehttp_client async with streamablehttp_client( - self.endpoint, headers=self.headers + url, headers=self.headers, auth=auth ) as (read_stream, write_stream, _): async with ClientSession( read_stream, write_stream @@ -130,7 +274,7 @@ async def call_tool_async( else: from mcp.client.sse import sse_client - async with sse_client(self.endpoint, headers=self.headers) as ( + async with sse_client(url, headers=self.headers, auth=auth) as ( read_stream, write_stream, ): @@ -162,6 +306,6 @@ def call_tool( Returns: Any: 工具执行结果 / Tool execution result """ - return asyncio.get_event_loop().run_until_complete( + return _get_or_create_event_loop().run_until_complete( self.call_tool_async(name, arguments) ) diff --git a/agentrun/tool/api/openapi.py b/agentrun/tool/api/openapi.py index 5873c7f..5202dc9 100644 --- a/agentrun/tool/api/openapi.py +++ b/agentrun/tool/api/openapi.py @@ -7,12 +7,84 @@ """ import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, List, Optional +from urllib.parse import urlparse, urlunparse import httpx from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.config import Config from agentrun.utils.log import logger +from agentrun.utils.ram_signature import get_agentrun_signed_headers + + +class _AgentrunRamAuth(httpx.Auth): + """httpx Auth handler:为每次请求动态生成 RAM 签名。""" + + def __init__( + self, + access_key_id: str, + access_key_secret: str, + region: str, + security_token: Optional[str] = None, + ): + self._ak = access_key_id + self._sk = access_key_secret + self._region = region + self._security_token = security_token + + def auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + url = str(request.url) + method = request.method + + body: Optional[bytes] = None + if request.content: + body = request.content + + content_type: Optional[str] = request.headers.get("content-type") + + try: + signed = get_agentrun_signed_headers( + url=url, + method=method, + access_key_id=self._ak, + access_key_secret=self._sk, + security_token=self._security_token, + region=self._region, + product="agentrun", + body=body, + content_type=content_type, + ) + for k, v in signed.items(): + request.headers[k] = v + logger.debug( + "applied RAM signature for OpenAPI %s request to %s", + method, + url[:80] + ("..." if len(url) > 80 else ""), + ) + except ValueError as e: + logger.warning("RAM signing skipped for OpenAPI request: %s", e) + + yield request + + +def _rewrite_to_ram_url(url: str) -> str: + """将 agentrun-data 域名改写为 -ram 端点。""" + parsed = urlparse(url) + parts = parsed.netloc.split(".", 1) + if len(parts) == 2: + ram_netloc = parts[0] + "-ram." + parts[1] + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path or "", + parsed.params, + parsed.query, + parsed.fragment, + )) + return url class ToolOpenAPIClient: @@ -27,6 +99,8 @@ def __init__( protocol_spec: Optional[str] = None, headers: Optional[Dict[str, str]] = None, fallback_server_url: Optional[str] = None, + config: Optional[Config] = None, + use_ram_auth: bool = True, ): """初始化 OpenAPI 客户端 / Initialize OpenAPI client @@ -35,11 +109,17 @@ def __init__( headers: 请求头 / Request headers fallback_server_url: 当 OpenAPI spec 中没有 servers 时的备用 URL / Fallback URL when servers is not present in OpenAPI spec + config: 配置对象 / Configuration object + use_ram_auth: 是否启用 RAM 签名鉴权 / Whether to enable RAM signature auth. + OPENAPI_IMPORT 时设为 False(server 是外部服务)。 + Set to False for OPENAPI_IMPORT (server is an external service). """ self.headers = headers or {} self._fallback_server_url = fallback_server_url self._spec: Optional[Dict[str, Any]] = None self._operations: Optional[List[Dict[str, Any]]] = None + self.config = Config.with_configs(config) + self.use_ram_auth = use_ram_auth if protocol_spec: try: @@ -258,11 +338,16 @@ def call_tool( url = f"{base_url.rstrip('/')}{target_operation['path']}" method = target_operation["method"] + # 应用 RAM 签名 + url, auth = self._build_ram_auth(url) + logger.debug( f"Calling FunctionCall tool: {method} {url} with args={arguments}" ) - with httpx.Client(headers=self.headers, timeout=30.0) as client: + with httpx.Client( + headers=self.headers, timeout=30.0, auth=auth + ) as client: if method in ("POST", "PUT", "PATCH"): response = client.request(method, url, json=arguments or {}) else: @@ -312,13 +397,16 @@ async def call_tool_async( url = f"{base_url.rstrip('/')}{target_operation['path']}" method = target_operation["method"] + # 应用 RAM 签名 + url, auth = self._build_ram_auth(url) + logger.debug( f"Calling FunctionCall tool async: {method} {url} with" f" args={arguments}" ) async with httpx.AsyncClient( - headers=self.headers, timeout=30.0 + headers=self.headers, timeout=30.0, auth=auth ) as client: if method in ("POST", "PUT", "PATCH"): response = await client.request( @@ -335,3 +423,37 @@ async def call_tool_async( if "application/json" in content_type: return response.json() return response.text + + def _build_ram_auth(self, url: str) -> tuple: + """当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。 + + Returns: + (rewritten_url, auth_or_none) + """ + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + if not self.use_ram_auth: + return url, None + + parsed = urlparse(url) + # 只对 agentrun-data 和 funagent-data-pre 域名应用 RAM 签名 + if ".agentrun-data." not in ( + parsed.netloc or "" + ) and ".funagent-data-pre." not in (parsed.netloc or ""): + return url, None + + cfg = self.config + ak = cfg.get_access_key_id() + sk = cfg.get_access_key_secret() + if not ak or not sk: + return url, None + + url = _rewrite_to_ram_url(url) + + auth = _AgentrunRamAuth( + access_key_id=ak, + access_key_secret=sk, + region=cfg.get_region_id(), + security_token=cfg.get_security_token() or None, + ) + return url, auth diff --git a/agentrun/tool/model.py b/agentrun/tool/model.py index 2d8cc81..0bdb001 100644 --- a/agentrun/tool/model.py +++ b/agentrun/tool/model.py @@ -21,6 +21,25 @@ class ToolType(str, Enum): """技能工具 / Skill Tool""" +class ToolCreateMethod(str, Enum): + """工具创建方式 / Tool Create Method + + 描述工具的创建和部署方式,用于数据链路鉴权策略判断。 + Describes how a tool is created and deployed, used for data-plane auth policy decisions. + """ + + MCP_REMOTE = "MCP_REMOTE" + """远程 MCP 服务器 / Remote MCP server""" + MCP_LOCAL = "MCP_LOCAL" + """本地 MCP 标准输入输出 / Local MCP stdio""" + MCP_BUNDLE = "MCP_BUNDLE" + """MCP 打包部署 / MCP bundle deployment""" + CODE_PACKAGE = "CODE_PACKAGE" + """代码包部署 / Code package deployment""" + OPENAPI_IMPORT = "OPENAPI_IMPORT" + """OpenAPI 导入 / OpenAPI import""" + + class McpConfig(BaseModel): """MCP 工具配置 / MCP Tool Configuration diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py index b9eb565..42439f0 100644 --- a/agentrun/tool/tool.py +++ b/agentrun/tool/tool.py @@ -18,6 +18,7 @@ import os import shutil from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import zipfile import httpx @@ -26,6 +27,7 @@ from agentrun.utils.config import Config from agentrun.utils.log import logger from agentrun.utils.model import BaseModel +from agentrun.utils.ram_signature import get_agentrun_signed_headers from .model import ( McpConfig, @@ -138,9 +140,20 @@ class Tool(BaseModel): tool_type: Optional[str] = None """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + create_method: Optional[str] = None + """工具创建方式 / Tool create method + MCP_REMOTE: 远程 MCP 服务器 / Remote MCP server + MCP_LOCAL: 本地 MCP 标准输入输出 / Local MCP stdio + MCP_BUNDLE: MCP 打包部署 / MCP bundle deployment + CODE_PACKAGE: 代码包部署 / Code package deployment + OPENAPI_IMPORT: OpenAPI 导入 / OpenAPI import + """ + version_id: Optional[str] = None """版本 ID / Version ID""" + _RAM_DATA_DOMAINS = ("agentrun-data", "funagent-data-pre") + @classmethod def __get_client(cls, config: Optional[Config] = None): from .client import ToolClient @@ -263,20 +276,36 @@ async def list_tools_async( self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) return await session.list_tools_async() elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) return await openapi_client.list_tools_async() @@ -309,20 +338,36 @@ def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) return session.list_tools() elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) return openapi_client.list_tools() @@ -361,11 +406,20 @@ async def call_tool_async( self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) result = await session.call_tool_async(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -374,11 +428,17 @@ async def call_tool_async( elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, headers=cfg.get_headers(), fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) result = await openapi_client.call_tool_async(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -419,11 +479,20 @@ def call_tool( self, "mcp_config.session_affinity", "MCP_SSE" ) + # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 + # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + cfg = Config.with_configs(config) session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, headers=cfg.get_headers(), + config=cfg, + use_ram_auth=not is_mcp_remote_without_proxy, ) result = session.call_tool(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -432,11 +501,17 @@ def call_tool( elif tool_type == ToolType.FUNCTIONCALL: from .api.openapi import ToolOpenAPIClient + # OPENAPI_IMPORT 时 server 是外部服务,不走 RAM 鉴权 + # Skip RAM auth for OPENAPI_IMPORT since the server is an external service + is_openapi_import = self.create_method == "OPENAPI_IMPORT" + cfg = Config.with_configs(config) openapi_client = ToolOpenAPIClient( protocol_spec=self.protocol_spec, headers=cfg.get_headers(), fallback_server_url=self._get_functioncall_server_url(config), + config=cfg, + use_ram_auth=not is_openapi_import, ) result = openapi_client.call_tool(name, arguments) logger.debug("invoke tool %s got result %s", name, result) @@ -444,6 +519,83 @@ def call_tool( raise ValueError(f"Unsupported tool type: {self.tool_type}") + def _use_ram_auth(self, config: Optional[Config] = None) -> bool: + """是否使用 RAM 签名鉴权(配置了 AK/SK 时使用)。 + Whether to use RAM signature authentication (when AK/SK is configured). + """ + cfg = Config.with_configs(config) + return bool(cfg.get_access_key_id() and cfg.get_access_key_secret()) + + def _get_ram_data_endpoint( + self, url: str, config: Optional[Config] = None + ) -> str: + """返回 RAM 鉴权用的 data endpoint(仅当 agentrun-data / funagent-data-pre 域名时在 host 前加 -ram)。 + Return RAM-authenticated endpoint (add -ram prefix for agentrun-data / funagent-data-pre domains). + """ + parsed = urlparse(url) + if not parsed.netloc or not any( + f".{domain}." in parsed.netloc for domain in self._RAM_DATA_DOMAINS + ): + return url + parts = parsed.netloc.split(".", 1) + if len(parts) != 2: + return url + ram_netloc = parts[0] + "-ram." + parts[1] + + from urllib.parse import urlunparse + + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + )) + + def _get_auth_headers( + self, url: str, config: Optional[Config] = None + ) -> Dict[str, str]: + """获取认证请求头,支持 RAM 签名。 + Get authentication headers with RAM signature support. + + Args: + url: 请求 URL / Request URL + config: 配置对象 / Configuration object + + Returns: + Dict[str, str]: 包含认证信息的请求头 / Headers with authentication + """ + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + if self._use_ram_auth(cfg): + # 使用 RAM 端点 + ram_url = self._get_ram_data_endpoint(url, cfg) + try: + signed = get_agentrun_signed_headers( + url=ram_url, + method="GET", + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token() or None, + region=cfg.get_region_id(), + product="agentrun", + body=None, + ) + headers = { + **signed, + **headers, + } + logger.debug( + "using RAM signature for skill download to %s", + ram_url[:80] + "..." if len(ram_url) > 80 else ram_url, + ) + except ValueError as e: + logger.warning("RAM signing skipped (missing AK/SK): %s", e) + + return headers + def _get_skill_download_url( self, config: Optional[Config] = None ) -> Optional[str]: @@ -505,7 +657,7 @@ async def download_skill_async( logger.debug("downloading skill from %s to %s", download_url, skill_dir) cfg = Config.with_configs(config) - headers = cfg.get_headers() + headers = self._get_auth_headers(download_url, cfg) async with httpx.AsyncClient( timeout=300, follow_redirects=True @@ -565,7 +717,7 @@ def download_skill( logger.debug("downloading skill from %s to %s", download_url, skill_dir) cfg = Config.with_configs(config) - headers = cfg.get_headers() + headers = self._get_auth_headers(download_url, cfg) with httpx.Client(timeout=300, follow_redirects=True) as http_client: response = http_client.get(download_url, headers=headers) diff --git a/tests/unittests/integration/test_tool_resource_integration.py b/tests/unittests/integration/test_tool_resource_integration.py new file mode 100644 index 0000000..3b33da5 --- /dev/null +++ b/tests/unittests/integration/test_tool_resource_integration.py @@ -0,0 +1,386 @@ +"""ToolResource Integration 单元测试 / ToolResource Integration Unit Tests + +测试新版 ToolResource 到 integration 层的桥接功能: +- CommonToolSet.from_agentrun_tool() 类方法 +- builtin/tool_resource.py 入口函数 +- 各框架 builtin 中的 tool_resource() 函数 +""" + +import sys +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import pytest + +# 获取 builtin.tool_resource 模块的真正模块对象 +# __init__.py 中 from .tool_resource import tool_resource 会让 +# "agentrun.integration.builtin.tool_resource" 在 patch 字符串路径中 +# 解析为函数而非模块,所以必须用 sys.modules 获取真正的模块对象 +# 再配合 patch.object 使用 +import agentrun.integration.builtin.tool_resource # noqa: F401 +from agentrun.integration.utils.tool import CommonToolSet + +_tool_resource_mod = sys.modules["agentrun.integration.builtin.tool_resource"] + + +# ============================================================================= +# Helper: 构建 mock ToolResource 和 ToolInfo +# ============================================================================= + + +class FakeToolInfo: + """模拟 ToolInfo 对象,支持 model_dump() 返回真实字典。 + + _to_dict() 内部会调用 obj.model_dump(exclude_none=True), + 所以 mock 必须返回真实的 dict 而不是 MagicMock。 + """ + + def __init__( + self, + name: str, + description: str = "", + input_schema: Optional[Dict[str, Any]] = None, + ): + self.name = name + self.description = description + self.input_schema = input_schema + + def model_dump(self, **kwargs) -> Dict[str, Any]: + result: Dict[str, Any] = { + "name": self.name, + "description": self.description, + } + if self.input_schema is not None: + result["input_schema"] = self.input_schema + return result + + +def _make_tool_info( + name: str, + description: str = "", + input_schema: Optional[Dict[str, Any]] = None, +) -> FakeToolInfo: + """创建一个 FakeToolInfo 对象""" + return FakeToolInfo( + name=name, description=description, input_schema=input_schema + ) + + +def _make_mock_tool_resource( + tool_infos: Optional[List[FakeToolInfo]] = None, + tool_name: str = "test-tool", +) -> MagicMock: + """创建一个 mock ToolResource 实例 + + 模拟 agentrun.tool.tool.Tool 的接口: + - list_tools(config) -> List[ToolInfo] + - call_tool(name, arguments, config) -> Any + - get(config) -> ToolResource + """ + resource = MagicMock() + resource.tool_name = tool_name + resource.list_tools.return_value = ( + tool_infos if tool_infos is not None else [] + ) + resource.call_tool.return_value = {"result": "ok"} + resource.get.return_value = resource + return resource + + +# ============================================================================= +# Tests: CommonToolSet.from_agentrun_tool() +# ============================================================================= + + +class TestFromAgentrunTool: + """测试 CommonToolSet.from_agentrun_tool() 类方法""" + + def test_empty_tool_list(self): + """空工具列表返回空 CommonToolSet""" + resource = _make_mock_tool_resource(tool_infos=[]) + result = CommonToolSet.from_agentrun_tool(resource) + assert isinstance(result, CommonToolSet) + assert len(result.tools()) == 0 + resource.list_tools.assert_called_once_with(config=None) + + def test_single_tool(self): + """单个工具正确桥接""" + info = _make_tool_info("search", "Search the web") + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert len(tools) == 1 + assert tools[0].name == "search" + + def test_multiple_tools(self): + """多个工具正确桥接""" + infos = [ + _make_tool_info("tool_a", "Tool A"), + _make_tool_info("tool_b", "Tool B"), + _make_tool_info("tool_c", "Tool C"), + ] + resource = _make_mock_tool_resource(tool_infos=infos) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert len(tools) == 3 + names = {t.name for t in tools} + assert names == {"tool_a", "tool_b", "tool_c"} + + def test_duplicate_tool_names_skipped(self): + """重复工具名被跳过""" + infos = [ + _make_tool_info("dup_tool", "First"), + _make_tool_info("dup_tool", "Second"), + ] + resource = _make_mock_tool_resource(tool_infos=infos) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert len(tools) == 1 + assert tools[0].name == "dup_tool" + + def test_with_config(self): + """config 参数正确传递""" + config = MagicMock() + resource = _make_mock_tool_resource(tool_infos=[]) + CommonToolSet.from_agentrun_tool(resource, config=config) + resource.list_tools.assert_called_once_with(config=config) + + def test_with_refresh(self): + """refresh=True 时先调用 get()""" + config = MagicMock() + resource = _make_mock_tool_resource(tool_infos=[]) + CommonToolSet.from_agentrun_tool(resource, config=config, refresh=True) + resource.get.assert_called_once_with(config=config) + + def test_without_refresh(self): + """refresh=False 时不调用 get()""" + resource = _make_mock_tool_resource(tool_infos=[]) + CommonToolSet.from_agentrun_tool(resource, refresh=False) + resource.get.assert_not_called() + + def test_none_tool_list(self): + """list_tools 返回 None 时返回空 CommonToolSet""" + resource = _make_mock_tool_resource() + resource.list_tools.return_value = None + result = CommonToolSet.from_agentrun_tool(resource) + assert len(result.tools()) == 0 + + def test_to_openai_function_conversion(self): + """桥接后的工具可以转换为 OpenAI 格式(无 function 包装)""" + info = _make_tool_info("weather", "Get weather info") + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + openai_tools = result.to_openai_function() + assert len(openai_tools) == 1 + assert openai_tools[0]["name"] == "weather" + + def test_to_anthropic_tool_conversion(self): + """桥接后的工具可以转换为 Anthropic 格式""" + info = _make_tool_info("calculator", "Calculate things") + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + anthropic_tools = result.to_anthropic_tool() + assert len(anthropic_tools) == 1 + assert anthropic_tools[0]["name"] == "calculator" + + def test_tool_description_preserved(self): + """桥接后的工具描述被保留""" + info = _make_tool_info("echo", "Echo back the input") + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert tools[0].description == "Echo back the input" + + def test_prefix_filter(self): + """filter 参数正常工作""" + infos = [ + _make_tool_info("get_weather", "Get weather"), + _make_tool_info("set_alarm", "Set alarm"), + ] + resource = _make_mock_tool_resource(tool_infos=infos) + result = CommonToolSet.from_agentrun_tool(resource) + + filtered = result.to_openai_function( + filter_tools_by_name=lambda name: name.startswith("get_") + ) + assert len(filtered) == 1 + assert filtered[0]["name"] == "get_weather" + + def test_tool_with_input_schema(self): + """带 input_schema 的工具正确解析参数""" + schema = { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + } + info = _make_tool_info("weather", "Get weather", input_schema=schema) + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert len(tools) == 1 + assert tools[0].name == "weather" + + def test_tool_call_forwards_to_resource(self): + """桥接后的工具调用会转发到 ToolResource.call_tool()""" + info = _make_tool_info("echo", "Echo tool") + resource = _make_mock_tool_resource(tool_infos=[info]) + result = CommonToolSet.from_agentrun_tool(resource) + tools = result.tools() + assert len(tools) == 1 + # Tool 对象有 func 属性,调用 func 而非直接调用 Tool + if hasattr(tools[0], "func") and tools[0].func is not None: + tools[0].func(message="hello") + resource.call_tool.assert_called_once() + + +# ============================================================================= +# Tests: builtin/tool_resource.py 入口函数 +# ============================================================================= + + +class TestBuiltinToolResource: + """测试 agentrun.integration.builtin.tool_resource 入口函数""" + + def test_from_string_name(self): + """通过字符串名称创建""" + mock_client = MagicMock() + mock_resource = _make_mock_tool_resource(tool_infos=[]) + mock_client.return_value.get.return_value = mock_resource + + with patch.object(_tool_resource_mod, "ToolClient", mock_client): + from agentrun.integration.builtin.tool_resource import ( + tool_resource as builtin_tool_resource, + ) + + result = builtin_tool_resource("my-tool") + assert isinstance(result, CommonToolSet) + mock_client.return_value.get.assert_called_once_with( + name="my-tool", config=None + ) + + def test_from_string_name_with_config(self): + """通过字符串名称 + config 创建""" + config = MagicMock() + mock_client = MagicMock() + mock_resource = _make_mock_tool_resource(tool_infos=[]) + mock_client.return_value.get.return_value = mock_resource + + with patch.object(_tool_resource_mod, "ToolClient", mock_client): + from agentrun.integration.builtin.tool_resource import ( + tool_resource as builtin_tool_resource, + ) + + result = builtin_tool_resource("my-tool", config=config) + assert isinstance(result, CommonToolSet) + mock_client.return_value.get.assert_called_once_with( + name="my-tool", config=config + ) + + def test_from_tool_resource_instance(self): + """通过 ToolResource 实例创建""" + from agentrun.integration.builtin.tool_resource import ( + tool_resource as builtin_tool_resource, + ) + from agentrun.tool.tool import Tool as ToolResourceType + + mock_resource = MagicMock(spec=ToolResourceType) + mock_resource.list_tools.return_value = [] + + result = builtin_tool_resource(mock_resource) + assert isinstance(result, CommonToolSet) + mock_resource.list_tools.assert_called_once() + + +# ============================================================================= +# Tests: 各框架 builtin tool_resource() 函数 +# ============================================================================= + + +class TestFrameworkBuiltinToolResource: + """测试各框架 builtin 中的 tool_resource() 函数""" + + def _run_framework_test(self, framework_module_path: str): + """通用框架测试辅助方法""" + import importlib + + module = importlib.import_module(framework_module_path) + framework_tool_resource = getattr(module, "tool_resource") + + mock_client = MagicMock() + info = _make_tool_info("test_tool", "A test tool") + mock_resource = _make_mock_tool_resource(tool_infos=[info]) + mock_client.return_value.get.return_value = mock_resource + + with patch.object(_tool_resource_mod, "ToolClient", mock_client): + result = framework_tool_resource("my-tool") + assert isinstance(result, list) + assert len(result) >= 1 + + def test_langchain_tool_resource(self): + """LangChain tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.langchain.builtin") + + def test_google_adk_tool_resource(self): + """Google ADK tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.google_adk.builtin") + + def test_langgraph_tool_resource(self): + """LangGraph tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.langgraph.builtin") + + def test_agentscope_tool_resource(self): + """AgentScope tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.agentscope.builtin") + + def test_crewai_tool_resource(self): + """CrewAI tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.crewai.builtin") + + def test_pydantic_ai_tool_resource(self): + """PydanticAI tool_resource() 返回列表""" + self._run_framework_test("agentrun.integration.pydantic_ai.builtin") + + def test_framework_with_filter(self): + """框架 tool_resource() 支持 filter_tools_by_name 参数""" + from agentrun.integration.langchain.builtin import ( + tool_resource as lc_tool_resource, + ) + + mock_client = MagicMock() + infos = [ + _make_tool_info("get_data", "Get data"), + _make_tool_info("set_data", "Set data"), + ] + mock_resource = _make_mock_tool_resource(tool_infos=infos) + mock_client.return_value.get.return_value = mock_resource + + with patch.object(_tool_resource_mod, "ToolClient", mock_client): + result = lc_tool_resource( + "my-tool", + filter_tools_by_name=lambda name: name.startswith("get_"), + ) + assert isinstance(result, list) + assert len(result) == 1 + + +# ============================================================================= +# Tests: builtin __init__ 导出 +# ============================================================================= + + +class TestBuiltinInit: + """测试 builtin __init__.py 正确导出 tool_resource""" + + def test_import_from_builtin(self): + """可以从 builtin 导入 tool_resource""" + from agentrun.integration.builtin import tool_resource as imported_func + + assert callable(imported_func) + + def test_in_all(self): + """tool_resource 在 __all__ 中""" + import agentrun.integration.builtin as builtin_module + + assert "tool_resource" in builtin_module.__all__ diff --git a/tests/unittests/tool/test_mcp.py b/tests/unittests/tool/test_mcp.py index 83dc3fa..907c007 100644 --- a/tests/unittests/tool/test_mcp.py +++ b/tests/unittests/tool/test_mcp.py @@ -296,7 +296,9 @@ def test_call_tool_synchronous(self): new_callable=AsyncMock, return_value=expected_result, ): - with patch("asyncio.get_event_loop") as mock_get_loop: + with patch( + "agentrun.tool.api.mcp._get_or_create_event_loop" + ) as mock_get_loop: mock_loop = MagicMock() mock_loop.run_until_complete.return_value = expected_result mock_get_loop.return_value = mock_loop @@ -306,3 +308,57 @@ def test_call_tool_synchronous(self): assert result == expected_result mock_loop.run_until_complete.assert_called_once() + + +class TestGetOrCreateEventLoop: + """测试 _get_or_create_event_loop 辅助函数""" + + def test_returns_existing_event_loop(self): + """测试在有事件循环的线程中返回现有循环""" + from agentrun.tool.api.mcp import _get_or_create_event_loop + + with patch("agentrun.tool.api.mcp.asyncio.get_event_loop") as mock_get: + mock_loop = MagicMock() + mock_get.return_value = mock_loop + + result = _get_or_create_event_loop() + + assert result is mock_loop + mock_get.assert_called_once() + + def test_creates_new_event_loop_when_none_exists(self): + """测试在无事件循环的线程中创建新循环(模拟 Python 3.10+ 非主线程行为)""" + from agentrun.tool.api.mcp import _get_or_create_event_loop + + with patch( + "agentrun.tool.api.mcp.asyncio.get_event_loop", + side_effect=RuntimeError("no event loop"), + ): + with patch( + "agentrun.tool.api.mcp.asyncio.new_event_loop" + ) as mock_new: + with patch( + "agentrun.tool.api.mcp.asyncio.set_event_loop" + ) as mock_set: + mock_loop = MagicMock() + mock_new.return_value = mock_loop + + result = _get_or_create_event_loop() + + assert result is mock_loop + mock_new.assert_called_once() + mock_set.assert_called_once_with(mock_loop) + + def test_works_in_thread_pool_executor(self): + """测试在 ThreadPoolExecutor 线程中能正常工作""" + import concurrent.futures + + from agentrun.tool.api.mcp import _get_or_create_event_loop + + def get_loop_in_thread(): + loop = _get_or_create_event_loop() + return loop is not None + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(get_loop_in_thread) + assert future.result() is True diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py index e3e2bf2..da129b0 100644 --- a/tests/unittests/tool/test_tool.py +++ b/tests/unittests/tool/test_tool.py @@ -615,6 +615,11 @@ def test_download_skill_sync_success( mock_config = Mock() mock_config.get_headers.return_value = {} + # 添加 AK/SK 的 Mock 返回值,避免 RAM 签名 + mock_config.get_access_key_id.return_value = None + mock_config.get_access_key_secret.return_value = None + mock_config.get_security_token.return_value = None + mock_config.get_region_id.return_value = "cn-hangzhou" mock_config_class.with_configs.return_value = mock_config zip_buffer = io.BytesIO() @@ -655,6 +660,346 @@ def test_download_skill_sync_wrong_type(self): with pytest.raises(ValueError, match="only available for SKILL"): tool.download_skill() + @patch("agentrun.tool.tool.get_agentrun_signed_headers") + @patch("agentrun.tool.tool.httpx.Client") + @patch("agentrun.tool.tool.Config") + def test_download_skill_with_ram_auth( + self, mock_config_class, mock_client_class, mock_signed_headers + ): + """测试预发环境使用 RAM 签名认证""" + import io + import os + import shutil + import tempfile + import zipfile + + # 模拟配置了 AK/SK 的情况 + mock_config = Mock() + mock_config.get_access_key_id.return_value = "test-ak" + mock_config.get_access_key_secret.return_value = "test-sk" + mock_config.get_security_token.return_value = None + mock_config.get_region_id.return_value = "cn-hangzhou" + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + # 模拟 RAM 签名 + mock_signed_headers.return_value = { + "Agentrun-Authorization": ( + "AGENTRUN4-HMAC-SHA256 Credential=test-ak" + ), + "x-acs-date": "20260330T000000Z", + "x-acs-content-sha256": "UNSIGNED-PAYLOAD", + } + + # 创建测试用的 zip 文件 + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("skill.py", "print('skill')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + # 测试预发环境 URL + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://1760720386195983.funagent-data-pre.cn-hangzhou.aliyuncs.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = tool.download_skill(target_dir=tmp_dir) + + # 验证 RAM 签名被调用 + assert mock_signed_headers.called + # 验证使用的是 RAM 端点 + call_args = mock_signed_headers.call_args + assert "-ram.funagent-data-pre" in call_args[1]["url"] + + expected_dir = os.path.join(tmp_dir, "test-skill") + assert result == expected_dir + finally: + shutil.rmtree(tmp_dir) + + @patch("agentrun.tool.tool.get_agentrun_signed_headers") + @patch("agentrun.tool.tool.httpx.Client") + @patch("agentrun.tool.tool.Config") + def test_download_skill_without_ram_auth( + self, mock_config_class, mock_client_class, mock_signed_headers + ): + """测试没有 AK/SK 时不使用 RAM 签名""" + import io + import os + import shutil + import tempfile + import zipfile + + # 模拟没有配置 AK/SK 的情况 + mock_config = Mock() + mock_config.get_access_key_id.return_value = None + mock_config.get_access_key_secret.return_value = None + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("skill.py", "print('skill')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = tool.download_skill(target_dir=tmp_dir) + + # 验证 RAM 签名没有被调用 + assert not mock_signed_headers.called + + expected_dir = os.path.join(tmp_dir, "test-skill") + assert result == expected_dir + finally: + shutil.rmtree(tmp_dir) + + # ==================== create_method 鉴权策略测试 ==================== + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp_remote_without_proxy_skips_ram( + self, mock_config_class, mock_mcp_session_class + ): + """测试 MCP_REMOTE + proxy_enabled=false 时不走 RAM 鉴权""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "ok"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + data_endpoint="https://example.agentrun-data.aliyuncs.com", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + ) + + tool.call_tool("tool1", {}) + + # 验证 ToolMCPSession 被调用时 use_ram_auth=False + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["use_ram_auth"] is False + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp_remote_with_proxy_uses_ram( + self, mock_config_class, mock_mcp_session_class + ): + """测试 MCP_REMOTE + proxy_enabled=true 时走 RAM 鉴权""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "ok"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + data_endpoint="https://example.agentrun-data.aliyuncs.com", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=True + ), + ) + + tool.call_tool("tool1", {}) + + # 验证 ToolMCPSession 被调用时 use_ram_auth=True + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["use_ram_auth"] is True + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp_bundle_always_uses_ram( + self, mock_config_class, mock_mcp_session_class + ): + """测试 MCP_BUNDLE 类型始终走 RAM 鉴权""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "ok"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_BUNDLE", + data_endpoint="https://example.agentrun-data.aliyuncs.com", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + ) + + tool.call_tool("tool1", {}) + + # MCP_BUNDLE 即使 proxy_enabled=False 也要走 RAM + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["use_ram_auth"] is True + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall_openapi_import_skips_ram( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL + OPENAPI_IMPORT 时不走 RAM 鉴权""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "ok"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="FUNCTIONCALL", + create_method="OPENAPI_IMPORT", + protocol_spec=( + '{"openapi": "3.0.0", "servers": [{"url":' + ' "https://external.example.com"}]}' + ), + ) + + tool.call_tool("tool1", {}) + + # 验证 ToolOpenAPIClient 被调用时 use_ram_auth=False + call_kwargs = mock_openapi_client_class.call_args[1] + assert call_kwargs["use_ram_auth"] is False + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall_code_package_uses_ram( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL + CODE_PACKAGE 时走 RAM 鉴权""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "ok"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="FUNCTIONCALL", + create_method="CODE_PACKAGE", + data_endpoint="https://example.agentrun-data.aliyuncs.com", + ) + + tool.call_tool("tool1", {}) + + # 验证 ToolOpenAPIClient 被调用时 use_ram_auth=True + call_kwargs = mock_openapi_client_class.call_args[1] + assert call_kwargs["use_ram_auth"] is True + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_mcp_remote_without_proxy_skips_ram( + self, mock_config_class, mock_mcp_session_class + ): + """测试异步调用:MCP_REMOTE + proxy_enabled=false 时不走 RAM 鉴权""" + mock_session = Mock() + mock_session.call_tool_async = AsyncMock(return_value={"result": "ok"}) + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + data_endpoint="https://example.agentrun-data.aliyuncs.com", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + ) + + await tool.call_tool_async("tool1", {}) + + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["use_ram_auth"] is False + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_functioncall_openapi_import_skips_ram( + self, mock_config_class, mock_openapi_client_class + ): + """测试异步调用:FUNCTIONCALL + OPENAPI_IMPORT 时不走 RAM 鉴权""" + mock_client = Mock() + mock_client.call_tool_async = AsyncMock(return_value={"result": "ok"}) + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="FUNCTIONCALL", + create_method="OPENAPI_IMPORT", + protocol_spec=( + '{"openapi": "3.0.0", "servers": [{"url":' + ' "https://external.example.com"}]}' + ), + ) + + await tool.call_tool_async("tool1", {}) + + call_kwargs = mock_openapi_client_class.call_args[1] + assert call_kwargs["use_ram_auth"] is False + + def test_tool_create_method_field(self): + """测试 Tool 的 create_method 字段""" + tool = Tool(create_method="MCP_REMOTE") + assert tool.create_method == "MCP_REMOTE" + + tool2 = Tool(create_method="OPENAPI_IMPORT") + assert tool2.create_method == "OPENAPI_IMPORT" + + tool3 = Tool() + assert tool3.create_method is None + class TestToolClient: """测试 ToolClient""" From c177925a4d4c6ce56993f4bbc49469c1455a67b9 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Tue, 31 Mar 2026 14:13:21 +0800 Subject: [PATCH 06/10] =?UTF-8?q?refactor(tool):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E8=AE=BF=E9=97=AE=E6=96=B9=E5=BC=8F=E5=B9=B6?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=B5=8B=E8=AF=95=E8=A6=86=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本提交更新了 `Tool` 类中的配置访问方法,将直接属性访问改为调用统一的方法接口,并相应地增强了单元测试以更好地模拟各种边界情况下的行为。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/tool/__tool_async_template.py | 6 +- agentrun/tool/tool.py | 6 +- tests/unittests/tool/test_tool.py | 126 ++++++++++++++++++++----- 3 files changed, 110 insertions(+), 28 deletions(-) diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py index 7f8b8a8..a070ad8 100644 --- a/agentrun/tool/__tool_async_template.py +++ b/agentrun/tool/__tool_async_template.py @@ -181,7 +181,7 @@ def _get_functioncall_server_url( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None return f"{data_endpoint}/tools/{effective_name}" @@ -210,7 +210,7 @@ def _get_mcp_endpoint( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None @@ -451,7 +451,7 @@ def _get_skill_download_url( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None return f"{data_endpoint}/tools/{effective_name}/download" diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py index 42439f0..ae6c530 100644 --- a/agentrun/tool/tool.py +++ b/agentrun/tool/tool.py @@ -206,7 +206,7 @@ def _get_functioncall_server_url( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None return f"{data_endpoint}/tools/{effective_name}" @@ -235,7 +235,7 @@ def _get_mcp_endpoint( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None @@ -611,7 +611,7 @@ def _get_skill_download_url( data_endpoint = self.data_endpoint if not data_endpoint: cfg = Config.with_configs(config) - data_endpoint = cfg._data_endpoint + data_endpoint = cfg.get_data_endpoint() if not data_endpoint or not effective_name: return None return f"{data_endpoint}/tools/{effective_name}/download" diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py index da129b0..865570c 100644 --- a/tests/unittests/tool/test_tool.py +++ b/tests/unittests/tool/test_tool.py @@ -139,13 +139,20 @@ def test_get_mcp_endpoint_no_name(self): endpoint = tool._get_mcp_endpoint() assert endpoint is None - def test_get_mcp_endpoint_no_data_endpoint(self): - """测试没有 data_endpoint 时获取 MCP endpoint""" + @patch("agentrun.tool.tool.Config") + def test_get_mcp_endpoint_no_data_endpoint(self, mock_config_class): + """测试没有 data_endpoint 时从 Config.get_data_endpoint() 兜底""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config_class.with_configs.return_value = mock_config + tool = Tool( tool_name="my-tool", ) endpoint = tool._get_mcp_endpoint() - assert endpoint is None + assert endpoint == "https://fallback.example.com/tools/my-tool/sse" def test_from_inner_object(self): """测试从内部对象创建 Tool""" @@ -244,8 +251,11 @@ def test_list_tools_mcp(self, mock_config_class, mock_mcp_session_class): assert tools[0].name == "tool1" assert tools[1].name == "tool2" + @patch("agentrun.tool.tool.Config") @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - def test_list_tools_functioncall(self, mock_openapi_client_class): + def test_list_tools_functioncall( + self, mock_openapi_client_class, mock_config_class + ): """测试获取 FUNCTIONCALL 工具列表""" mock_client = Mock() mock_client.list_tools.return_value = [ @@ -254,6 +264,13 @@ def test_list_tools_functioncall(self, mock_openapi_client_class): ] mock_openapi_client_class.return_value = mock_client + mock_config = Mock() + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + tool = Tool( tool_type="FUNCTIONCALL", protocol_spec='{"openapi": "3.0.0"}', @@ -295,7 +312,7 @@ def test_call_tool_mcp(self, mock_config_class, mock_mcp_session_class): assert result == {"result": "success"} @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") def test_call_tool_functioncall( self, mock_config_class, mock_openapi_client_class ): @@ -306,6 +323,11 @@ def test_call_tool_functioncall( mock_config = Mock() mock_config.get_headers.return_value = {} + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_access_key_id.return_value = "" + mock_config.get_access_key_secret.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool( @@ -417,9 +439,11 @@ def test_get_skill_download_url_tool_name_takes_priority(self): @patch("agentrun.tool.tool.Config") def test_get_skill_download_url_config_fallback(self, mock_config_class): - """测试 data_endpoint 为空时从 Config 获取""" + """测试 data_endpoint 为空时从 Config.get_data_endpoint() 获取""" mock_config = Mock() - mock_config._data_endpoint = "https://config-endpoint.com" + mock_config.get_data_endpoint.return_value = ( + "https://config-endpoint.com" + ) mock_config_class.with_configs.return_value = mock_config tool = Tool(tool_name="my-skill") @@ -434,9 +458,9 @@ def test_get_skill_download_url_no_name(self): @patch("agentrun.tool.tool.Config") def test_get_skill_download_url_no_endpoint(self, mock_config_class): - """测试没有 data_endpoint 且 Config 也没有时返回 None""" + """测试没有 data_endpoint 且 Config.get_data_endpoint() 返回空时返回 None""" mock_config = Mock() - mock_config._data_endpoint = None + mock_config.get_data_endpoint.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool(tool_name="my-skill") @@ -558,8 +582,13 @@ async def test_download_skill_async_wrong_type(self): with pytest.raises(ValueError, match="only available for SKILL"): await tool.download_skill_async() - async def test_download_skill_async_no_url(self): - """测试无法构造下载 URL 时抛出 ValueError""" + @patch("agentrun.tool.tool.Config") + async def test_download_skill_async_no_url(self, mock_config_class): + """测试无法构造下载 URL 时抛出 ValueError(无 name 且 get_data_endpoint 返回空)""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = "" + mock_config_class.with_configs.return_value = mock_config + tool = Tool(tool_type="SKILL") with pytest.raises(ValueError, match="Cannot construct download URL"): @@ -875,7 +904,7 @@ def test_call_tool_mcp_bundle_always_uses_ram( assert call_kwargs["use_ram_auth"] is True @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") def test_call_tool_functioncall_openapi_import_skips_ram( self, mock_config_class, mock_openapi_client_class ): @@ -886,6 +915,11 @@ def test_call_tool_functioncall_openapi_import_skips_ram( mock_config = Mock() mock_config.get_headers.return_value = {} + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_access_key_id.return_value = "" + mock_config.get_access_key_secret.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool( @@ -961,7 +995,7 @@ async def test_call_tool_async_mcp_remote_without_proxy_skips_ram( assert call_kwargs["use_ram_auth"] is False @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") async def test_call_tool_async_functioncall_openapi_import_skips_ram( self, mock_config_class, mock_openapi_client_class ): @@ -972,6 +1006,11 @@ async def test_call_tool_async_functioncall_openapi_import_skips_ram( mock_config = Mock() mock_config.get_headers.return_value = {} + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_access_key_id.return_value = "" + mock_config.get_access_key_secret.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool( @@ -1244,25 +1283,37 @@ def test_get_functioncall_server_url(self): assert url == "https://example.com/data/tools/my-tool" - def test_get_functioncall_server_url_no_endpoint(self): + @patch("agentrun.tool.tool.Config") + def test_get_functioncall_server_url_no_endpoint(self, mock_config_class): """测试 _get_functioncall_server_url 没有 data_endpoint 和 name 时返回 None""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config_class.with_configs.return_value = mock_config + tool = Tool() url = tool._get_functioncall_server_url() assert url is None - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") async def test_list_tools_async_mcp_no_endpoint(self, mock_config_class): - """测试 MCP 类型但没有 endpoint 时返回空列表""" + """测试 MCP 类型但没有 endpoint 时,使用 Config.get_data_endpoint() 兜底""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = "" + mock_config_class.with_configs.return_value = mock_config + tool = Tool(tool_name="my-tool", tool_type="MCP") tools = await tool.list_tools_async() assert tools == [] + @patch("agentrun.tool.tool.Config") @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") async def test_list_tools_async_functioncall( - self, mock_openapi_client_class + self, mock_openapi_client_class, mock_config_class ): """测试 FUNCTIONCALL 类型的 list_tools_async""" mock_client = Mock() @@ -1274,6 +1325,13 @@ async def test_list_tools_async_functioncall( ) mock_openapi_client_class.return_value = mock_client + mock_config = Mock() + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + tool = Tool( tool_type="FUNCTIONCALL", protocol_spec='{"openapi": "3.0.0"}', @@ -1292,7 +1350,7 @@ async def test_list_tools_async_no_type(self): assert tools == [] @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") async def test_call_tool_async_functioncall( self, mock_config_class, mock_openapi_client_class ): @@ -1305,6 +1363,11 @@ async def test_call_tool_async_functioncall( mock_config = Mock() mock_config.get_headers.return_value = {} + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_access_key_id.return_value = "" + mock_config.get_access_key_secret.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool( @@ -1316,15 +1379,20 @@ async def test_call_tool_async_functioncall( assert result == {"result": "success"} - async def test_call_tool_async_mcp_no_endpoint(self): + @patch("agentrun.tool.tool.Config") + async def test_call_tool_async_mcp_no_endpoint(self, mock_config_class): """测试 MCP 类型但没有 endpoint 时 call_tool_async 抛出 ValueError""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = "" + mock_config_class.with_configs.return_value = mock_config + tool = Tool(tool_name="my-tool", tool_type="MCP") with pytest.raises(ValueError, match="MCP endpoint not available"): await tool.call_tool_async("tool1", {"param": "value"}) @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") def test_call_tool_functioncall( self, mock_config_class, mock_openapi_client_class ): @@ -1335,6 +1403,11 @@ def test_call_tool_functioncall( mock_config = Mock() mock_config.get_headers.return_value = {} + mock_config.get_data_endpoint.return_value = ( + "https://fallback.example.com" + ) + mock_config.get_access_key_id.return_value = "" + mock_config.get_access_key_secret.return_value = "" mock_config_class.with_configs.return_value = mock_config tool = Tool( @@ -1346,16 +1419,25 @@ def test_call_tool_functioncall( assert result == {"result": "success"} - def test_call_tool_mcp_no_endpoint(self): + @patch("agentrun.tool.tool.Config") + def test_call_tool_mcp_no_endpoint(self, mock_config_class): """测试 MCP 类型但没有 endpoint 时 call_tool 抛出 ValueError""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = "" + mock_config_class.with_configs.return_value = mock_config + tool = Tool(tool_name="my-tool", tool_type="MCP") with pytest.raises(ValueError, match="MCP endpoint not available"): tool.call_tool("tool1", {"param": "value"}) - @patch("agentrun.utils.config.Config") + @patch("agentrun.tool.tool.Config") def test_list_tools_mcp_no_endpoint(self, mock_config_class): """测试 MCP 类型但没有 endpoint 时 list_tools 返回空列表""" + mock_config = Mock() + mock_config.get_data_endpoint.return_value = "" + mock_config_class.with_configs.return_value = mock_config + tool = Tool(tool_name="my-tool", tool_type="MCP") tools = tool.list_tools() From 7124dd305e68df906e20661b65ce33929d4b76fd Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Wed, 1 Apr 2026 20:57:36 +0800 Subject: [PATCH 07/10] =?UTF-8?q?refactor(skill=5Floader):=20=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=8A=80=E8=83=BD=E5=8A=A0=E8=BD=BD=E5=99=A8=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E5=B9=B6=E5=A2=9E=E5=8A=A0=E5=8D=95=E5=85=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E8=A6=86=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次更新改进了skill_loader模块的功能,并添加了相应的单元测试以提高代码质量和稳定性。 Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/integration/utils/skill_loader.py | 344 ++++++++++++- .../integration/test_skill_loader.py | 450 +++++++++++++++++- 2 files changed, 777 insertions(+), 17 deletions(-) diff --git a/agentrun/integration/utils/skill_loader.py b/agentrun/integration/utils/skill_loader.py index ca178ef..7eed1f5 100644 --- a/agentrun/integration/utils/skill_loader.py +++ b/agentrun/integration/utils/skill_loader.py @@ -11,7 +11,8 @@ import json import os import re -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +import subprocess +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union from agentrun.integration.utils.tool import CommonToolSet, Tool, ToolParameter from agentrun.utils.log import logger @@ -20,6 +21,10 @@ from agentrun.tool.tool import Tool as ToolResource from agentrun.utils.config import Config +# Maximum output size for execute_command (bytes) +# execute_command 输出大小限制(字节) +MAX_OUTPUT_SIZE = 102400 # 100KB + @dataclass class SkillInfo: @@ -100,11 +105,15 @@ def __init__( skills_dir: str = ".skills", remote_skill_names: Optional[List[str]] = None, config: Optional["Config"] = None, + command_approval: Optional[Callable[[str, str], bool]] = None, + command_timeout: int = 300, ): self._skills_dir = skills_dir self._remote_skill_names = remote_skill_names or [] self._config = config self._skills_cache: Optional[List[SkillInfo]] = None + self._command_approval = command_approval + self._command_timeout = command_timeout def _ensure_skills_available(self) -> None: """确保远程 skill 已下载到本地 / Ensure remote skills are downloaded locally @@ -364,8 +373,237 @@ def _load_skills_func(self, name: Optional[str] = None) -> str: } return json.dumps(detail_result, ensure_ascii=False) + def _read_skill_file_func(self, name: str, relative_path: str) -> str: + """read_skill_file 工具的执行函数 / Execution function for the read_skill_file tool + + 读取指定 skill 目录内的文件内容,带路径穿越保护。 + Reads file content within a specific skill directory with path traversal protection. + + Args: + name: skill 名称 / skill name + relative_path: skill 目录内的相对路径 / relative path within skill directory + + Returns: + JSON 字符串 / JSON string + """ + skills = self.scan_skills() + target_skill: Optional[SkillInfo] = None + for skill in skills: + if skill.name == name: + target_skill = skill + break + + if target_skill is None: + available = [s.name for s in skills] + available_str = ", ".join(available) if available else "none" + return json.dumps( + { + "error": ( + f"Skill '{name}' not found. " + f"Available skills: {available_str}" + ) + }, + ensure_ascii=False, + ) + + # Path traversal protection / 路径穿越保护 + skill_real_dir = os.path.realpath(target_skill.path) + target_path = os.path.realpath( + os.path.join(target_skill.path, relative_path) + ) + if ( + not target_path.startswith(skill_real_dir + os.sep) + and target_path != skill_real_dir + ): + return json.dumps( + { + "error": ( + f"Path '{relative_path}' is outside the skill" + " directory. Access denied." + ) + }, + ensure_ascii=False, + ) + + if not os.path.exists(target_path): + return json.dumps( + { + "error": ( + f"File '{relative_path}' not found in skill '{name}'." + ) + }, + ensure_ascii=False, + ) + + # Directory listing / 目录列表 + if os.path.isdir(target_path): + try: + entries: List[str] = [] + for entry in sorted(os.listdir(target_path)): + if not entry.startswith("."): + entry_full = os.path.join(target_path, entry) + if os.path.isdir(entry_full): + entries.append(entry + "/") + else: + entries.append(entry) + return json.dumps({"files": entries}, ensure_ascii=False) + except OSError as error: + return json.dumps( + {"error": f"Failed to list directory: {error}"}, + ensure_ascii=False, + ) + + # File reading / 文件读取 + try: + with open(target_path, "r", encoding="utf-8") as file_handle: + content = file_handle.read() + return json.dumps({"content": content}, ensure_ascii=False) + except UnicodeDecodeError: + return json.dumps( + { + "error": ( + f"File '{relative_path}' cannot be read as text. " + "It may be a binary file." + ) + }, + ensure_ascii=False, + ) + except OSError as error: + return json.dumps( + {"error": f"Failed to read file: {error}"}, + ensure_ascii=False, + ) + + def _truncate_output(self, output: str) -> str: + """截断过大的输出 / Truncate oversized output + + Args: + output: 原始输出 / original output + + Returns: + 截断后的输出 / truncated output + """ + if len(output.encode("utf-8", errors="replace")) <= MAX_OUTPUT_SIZE: + return output + # Truncate by bytes then decode safely + truncated = output.encode("utf-8", errors="replace")[:MAX_OUTPUT_SIZE] + return truncated.decode("utf-8", errors="replace") + ( + f"\n... [output truncated, exceeded {MAX_OUTPUT_SIZE} bytes]" + ) + + def _execute_command_func( + self, + command: str, + cwd: Optional[str] = None, + timeout: Optional[int] = None, + ) -> str: + """execute_command 工具的执行函数 / Execution function for the execute_command tool + + 在本地机器上执行 shell 命令。 + Executes a shell command on the local machine. + + Args: + command: 要执行的命令 / command to execute + cwd: 工作目录(可选,默认 skills_dir)/ working directory (optional, defaults to skills_dir) + timeout: 超时秒数(可选,默认使用 command_timeout)/ timeout in seconds (optional) + + Returns: + JSON 字符串 / JSON string + """ + resolved_cwd = cwd if cwd else self._skills_dir + resolved_timeout = ( + timeout if timeout is not None else self._command_timeout + ) + + # Validate cwd exists / 验证工作目录存在 + if not os.path.isdir(resolved_cwd): + return json.dumps( + { + "error": ( + f"Working directory '{resolved_cwd}' does not exist." + ) + }, + ensure_ascii=False, + ) + + # Command approval callback / 命令确认回调 + if self._command_approval is not None: + try: + approved = self._command_approval(command, resolved_cwd) + except Exception as approval_error: + logger.warning( + "Command approval callback raised an error:" + f" {approval_error}" + ) + return json.dumps( + { + "error": ( + "Command approval callback failed: " + f"{approval_error}" + ) + }, + ensure_ascii=False, + ) + if not approved: + return json.dumps( + {"error": "Command execution rejected by user."}, + ensure_ascii=False, + ) + + logger.info( + f"Executing command: {command!r} in cwd={resolved_cwd!r} " + f"timeout={resolved_timeout}s" + ) + + try: + completed = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + cwd=resolved_cwd, + timeout=resolved_timeout, + ) + stdout = self._truncate_output(completed.stdout) + stderr = self._truncate_output(completed.stderr) + + logger.info(f"Command finished: exit_code={completed.returncode}") + + return json.dumps( + { + "stdout": stdout, + "stderr": stderr, + "exit_code": completed.returncode, + "timed_out": False, + }, + ensure_ascii=False, + ) + except subprocess.TimeoutExpired: + logger.warning( + f"Command timed out after {resolved_timeout}s: {command!r}" + ) + return json.dumps( + { + "stdout": "", + "stderr": ( + f"Command timed out after {resolved_timeout} seconds." + ), + "exit_code": -1, + "timed_out": True, + }, + ensure_ascii=False, + ) + except OSError as error: + logger.error(f"Failed to execute command: {error}") + return json.dumps( + {"error": f"Failed to execute command: {error}"}, + ensure_ascii=False, + ) + def to_common_toolset(self) -> CommonToolSet: - """构造包含 load_skills 工具的 CommonToolSet / Construct CommonToolSet with load_skills tool + """构造包含 load_skills、read_skill_file、execute_command 工具的 CommonToolSet + + Construct CommonToolSet with load_skills, read_skill_file, and execute_command tools. Returns: CommonToolSet 实例 / CommonToolSet instance @@ -390,7 +628,82 @@ def to_common_toolset(self) -> CommonToolSet: func=self._load_skills_func, ) - return CommonToolSet(tools_list=[load_skills_tool]) + read_skill_file_tool = Tool( + name="read_skill_file", + description=( + "Read a file from a skill's directory. " + "Returns the file content as text, or lists directory contents " + "if the path points to a directory. " + "Only files within the skill directory can be accessed." + ), + parameters=[ + ToolParameter( + name="name", + param_type="string", + description="The name of the skill containing the file.", + required=True, + ), + ToolParameter( + name="relative_path", + param_type="string", + description=( + "Relative path to the file within the skill directory " + "(e.g., 'scripts/run.sh', 'requirements.txt')." + ), + required=True, + ), + ], + func=self._read_skill_file_func, + ) + + execute_command_tool = Tool( + name="execute_command", + description=( + "Execute a shell command on the local machine. " + "Use this to run scripts, install dependencies, or perform " + "file operations as instructed by skill documentation. " + "Returns stdout, stderr, exit_code, and timeout status.\n\n" + "⚠️ IMPORTANT: Before calling this tool, you MUST first " + "display the exact command to the user and ask for explicit " + "confirmation. Only proceed if the user approves. " + "Never execute commands without user approval." + ), + parameters=[ + ToolParameter( + name="command", + param_type="string", + description="The shell command to execute.", + required=True, + ), + ToolParameter( + name="cwd", + param_type="string", + description=( + "Working directory for the command. " + "Defaults to the skills directory if not specified." + ), + required=False, + ), + ToolParameter( + name="timeout", + param_type="integer", + description=( + "Maximum execution time in seconds. " + f"Defaults to {self._command_timeout}." + ), + required=False, + ), + ], + func=self._execute_command_func, + ) + + return CommonToolSet( + tools_list=[ + load_skills_tool, + read_skill_file_tool, + execute_command_tool, + ] + ) def skill_tools( @@ -398,6 +711,8 @@ def skill_tools( *, skills_dir: str = ".skills", config: Optional["Config"] = None, + command_approval: Optional[Callable[[str, str], bool]] = None, + command_timeout: int = 300, ) -> CommonToolSet: """将 Skill 封装为通用工具集 / Wrap Skills as CommonToolSet @@ -413,10 +728,16 @@ def skill_tools( If not provided, only loads local skills from skills_dir. skills_dir: 本地 skill 目录,默认 ".skills" / Local skill directory, default ".skills" config: 配置对象 / Configuration object + command_approval: 命令执行前的确认回调函数(可选)/ + Optional approval callback invoked before executing commands. + 接收 (command, cwd) 参数,返回 True 允许执行,False 拒绝 / + Receives (command, cwd), returns True to allow, False to reject. + command_timeout: execute_command 的默认超时秒数,默认 30 / + Default timeout in seconds for execute_command, default 30. Returns: - CommonToolSet: 包含 load_skills 工具的通用工具集 / - CommonToolSet containing the load_skills tool + CommonToolSet: 包含 load_skills、read_skill_file、execute_command 工具的通用工具集 / + CommonToolSet containing load_skills, read_skill_file, and execute_command tools Examples: >>> # 仅加载本地 skill / Load local skills only @@ -425,11 +746,14 @@ def skill_tools( >>> # 下载远程 skill 后加载 / Download remote skill then load >>> ts = skill_tools("my-remote-skill") >>> - >>> # 下载多个远程 skill / Download multiple remote skills - >>> ts = skill_tools(["skill-a", "skill-b"]) + >>> # 带命令确认回调 / With command approval callback + >>> ts = skill_tools( + ... skills_dir=".skills", + ... command_approval=lambda cmd, cwd: input(f"Execute '{cmd}'? [y/N]: ").lower() == "y", + ... ) >>> - >>> # 转换为 LangChain 工具 / Convert to LangChain tools - >>> lc_tools = ts.to_langchain() + >>> # 自定义超时 / Custom timeout + >>> ts = skill_tools(skills_dir=".skills", command_timeout=120) """ remote_names: List[str] = [] @@ -455,5 +779,7 @@ def skill_tools( skills_dir=skills_dir, remote_skill_names=remote_names, config=config, + command_approval=command_approval, + command_timeout=command_timeout, ) return loader.to_common_toolset() diff --git a/tests/unittests/integration/test_skill_loader.py b/tests/unittests/integration/test_skill_loader.py index 830bba1..eb8bd10 100644 --- a/tests/unittests/integration/test_skill_loader.py +++ b/tests/unittests/integration/test_skill_loader.py @@ -499,8 +499,11 @@ def test_toolset_has_load_skills_tool(self, tmp_path: Any) -> None: loader = SkillLoader(skills_dir=skills_dir) toolset = loader.to_common_toolset() tools_list = toolset.tools() - assert len(tools_list) == 1 - assert tools_list[0].name == "load_skills" + assert len(tools_list) == 3 + tool_names = [t.name for t in tools_list] + assert "load_skills" in tool_names + assert "read_skill_file" in tool_names + assert "execute_command" in tool_names def test_tool_description_contains_skill_names(self, tmp_path: Any) -> None: skills_dir = str(tmp_path / "skills") @@ -553,8 +556,9 @@ def test_empty_skills_dir_still_returns_toolset( toolset = loader.to_common_toolset() assert isinstance(toolset, CommonToolSet) tools_list = toolset.tools() - assert len(tools_list) == 1 - assert "No skills available" in tools_list[0].description + assert len(tools_list) == 3 + load_skills_tool = [t for t in tools_list if t.name == "load_skills"][0] + assert "No skills available" in load_skills_tool.description # ============================================================================= @@ -577,7 +581,7 @@ def test_local_only(self, tmp_path: Any) -> None: ) toolset = skill_tools(skills_dir=skills_dir) assert isinstance(toolset, CommonToolSet) - assert len(toolset.tools()) == 1 + assert len(toolset.tools()) == 3 def test_with_string_name_triggers_remote_download( self, tmp_path: Any @@ -855,10 +859,10 @@ def test_full_workflow(self, tmp_path: Any) -> None: toolset = skill_tools(skills_dir=skills_dir) assert isinstance(toolset, CommonToolSet) tools_list = toolset.tools() - assert len(tools_list) == 1 + assert len(tools_list) == 3 - tool = tools_list[0] - assert tool.name == "load_skills" + tool_map = {t.name: t for t in tools_list} + tool = tool_map["load_skills"] assert "e2e-skill" in tool.description # List all skills @@ -909,3 +913,433 @@ def test_multiple_skills_workflow(self, tmp_path: Any) -> None: beta = json.loads(tool.func(name="skill-beta")) assert beta["name"] == "skill-beta" assert beta["instruction"] == "" + + +# ============================================================================= +# 13. read_skill_file 工具测试 +# ============================================================================= + + +class TestReadSkillFile: + """测试 _read_skill_file_func 方法""" + + def _get_read_skill_file_tool(self, skills_dir: str, **kwargs: Any) -> Any: + loader = SkillLoader(skills_dir=skills_dir, **kwargs) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + return tool_map["read_skill_file"] + + def test_read_existing_file(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n# Hello", + extra_files={"config.json": '{"key": "value"}'}, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("my-skill", "config.json") + ) + assert "content" in result + assert '"key": "value"' in result["content"] + + def test_file_not_found(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("my-skill", "nonexistent.txt") + ) + assert "error" in result + assert "not found" in result["error"] + + def test_skill_not_found(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "existing-skill", + skill_md_content="---\nname: existing-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("no-such-skill", "file.txt") + ) + assert "error" in result + assert "not found" in result["error"] + assert "existing-skill" in result["error"] + + def test_path_traversal_with_dotdot(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n", + ) + # Create a file outside the skill dir + with open(tmp_path / "secret.txt", "w") as fh: + fh.write("secret data") + + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("my-skill", "../../secret.txt") + ) + assert "error" in result + assert ( + "outside" in result["error"].lower() + or "denied" in result["error"].lower() + ) + + def test_path_traversal_with_absolute_path(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("my-skill", "/etc/passwd") + ) + assert "error" in result + assert ( + "outside" in result["error"].lower() + or "denied" in result["error"].lower() + ) + + def test_binary_file(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + skill_path = _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n", + ) + # Write a binary file + with open(os.path.join(skill_path, "data.bin"), "wb") as fh: + fh.write(bytes(range(256))) + + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._read_skill_file_func("my-skill", "data.bin") + ) + assert "error" in result + assert "binary" in result["error"].lower() + + def test_directory_listing(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content="---\nname: my-skill\n---\n", + extra_files={ + "scripts/run.sh": "#!/bin/bash\necho hi", + "scripts/setup.py": "print('setup')", + }, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._read_skill_file_func("my-skill", "scripts")) + assert "files" in result + assert "run.sh" in result["files"] + assert "setup.py" in result["files"] + + def test_read_skill_file_via_tool(self, tmp_path: Any) -> None: + """Test read_skill_file via the Tool object's func""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "tool-skill", + skill_md_content="---\nname: tool-skill\n---\n", + extra_files={"readme.txt": "Hello from tool"}, + ) + tool = self._get_read_skill_file_tool(skills_dir) + result = json.loads( + tool.func(name="tool-skill", relative_path="readme.txt") + ) + assert "content" in result + assert "Hello from tool" in result["content"] + + +# ============================================================================= +# 14. execute_command 工具测试 +# ============================================================================= + + +class TestExecuteCommand: + """测试 _execute_command_func 方法""" + + def test_normal_execution(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._execute_command_func("echo hello")) + assert result["exit_code"] == 0 + assert "hello" in result["stdout"] + assert result["timed_out"] is False + + def test_nonzero_exit_code(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._execute_command_func("exit 42")) + assert result["exit_code"] == 42 + assert result["timed_out"] is False + + def test_stderr_output(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._execute_command_func("echo error_msg >&2")) + assert "error_msg" in result["stderr"] + + def test_custom_cwd(self, tmp_path: Any) -> None: + custom_dir = str(tmp_path / "custom") + os.makedirs(custom_dir) + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._execute_command_func("pwd", cwd=custom_dir)) + assert result["exit_code"] == 0 + assert custom_dir in result["stdout"] + + def test_nonexistent_cwd(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads( + loader._execute_command_func( + "echo hi", cwd="/nonexistent/path/12345" + ) + ) + assert "error" in result + assert "does not exist" in result["error"] + + def test_timeout(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir, command_timeout=1) + result = json.loads(loader._execute_command_func("sleep 10", timeout=1)) + assert result["timed_out"] is True + assert result["exit_code"] == -1 + + def test_output_truncation(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + # Generate output larger than 100KB + large_output_cmd = "python3 -c \"print('A' * 200000)\"" + result = json.loads(loader._execute_command_func(large_output_cmd)) + assert result["exit_code"] == 0 + assert "truncated" in result["stdout"] + + def test_command_approval_approved(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + approval_calls: List[tuple[str, str]] = [] + + def approve(command: str, cwd: str) -> bool: + approval_calls.append((command, cwd)) + return True + + loader = SkillLoader(skills_dir=skills_dir, command_approval=approve) + result = json.loads(loader._execute_command_func("echo approved")) + assert result["exit_code"] == 0 + assert "approved" in result["stdout"] + assert len(approval_calls) == 1 + assert approval_calls[0][0] == "echo approved" + + def test_command_approval_rejected(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + def reject(command: str, cwd: str) -> bool: + return False + + loader = SkillLoader(skills_dir=skills_dir, command_approval=reject) + result = json.loads(loader._execute_command_func("echo should_not_run")) + assert "error" in result + assert "rejected" in result["error"].lower() + + def test_no_command_approval(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir, command_approval=None) + result = json.loads(loader._execute_command_func("echo no_approval")) + assert result["exit_code"] == 0 + assert "no_approval" in result["stdout"] + + def test_command_approval_raises_exception(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + def broken_approval(command: str, cwd: str) -> bool: + raise RuntimeError("approval callback broken") + + loader = SkillLoader( + skills_dir=skills_dir, command_approval=broken_approval + ) + result = json.loads(loader._execute_command_func("echo should_not_run")) + assert "error" in result + assert ( + "approval callback" in result["error"].lower() + or "broken" in result["error"].lower() + ) + + def test_default_cwd_is_skills_dir(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = json.loads(loader._execute_command_func("pwd")) + assert result["exit_code"] == 0 + # The resolved real path should match + assert os.path.realpath(skills_dir) in os.path.realpath( + result["stdout"].strip() + ) + + def test_execute_command_via_tool(self, tmp_path: Any) -> None: + """Test execute_command via the Tool object's func""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + tool = tool_map["execute_command"] + result = json.loads(tool.func(command="echo via_tool")) + assert result["exit_code"] == 0 + assert "via_tool" in result["stdout"] + + +# ============================================================================= +# 15. skill_tools() 新参数测试 +# ============================================================================= + + +class TestSkillToolsNewParams: + """测试 skill_tools() 的 command_approval 和 command_timeout 参数""" + + def test_command_approval_passed_through(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + approval_called = False + + def approval(command: str, cwd: str) -> bool: + nonlocal approval_called + approval_called = True + return True + + toolset = skill_tools(skills_dir=skills_dir, command_approval=approval) + tool_map = {t.name: t for t in toolset.tools()} + exec_tool = tool_map["execute_command"] + result = json.loads(exec_tool.func(command="echo test")) + assert approval_called + assert result["exit_code"] == 0 + + def test_command_timeout_passed_through(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + toolset = skill_tools(skills_dir=skills_dir, command_timeout=1) + tool_map = {t.name: t for t in toolset.tools()} + exec_tool = tool_map["execute_command"] + result = json.loads(exec_tool.func(command="sleep 10")) + assert result["timed_out"] is True + + def test_default_values(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + toolset = skill_tools(skills_dir=skills_dir) + tool_map = {t.name: t for t in toolset.tools()} + exec_tool = tool_map["execute_command"] + # Default timeout is 30, command should succeed quickly + result = json.loads(exec_tool.func(command="echo default")) + assert result["exit_code"] == 0 + assert "default" in result["stdout"] + + +# ============================================================================= +# 16. to_common_toolset() 返回 3 个工具测试 +# ============================================================================= + + +class TestToCommonToolsetThreeTools: + """测试 to_common_toolset() 返回包含 3 个工具的 CommonToolSet""" + + def test_returns_three_tools(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "test-skill", + skill_md_content="---\nname: test-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tools_list = toolset.tools() + assert len(tools_list) == 3 + tool_names = {t.name for t in tools_list} + assert tool_names == { + "load_skills", + "read_skill_file", + "execute_command", + } + + def test_load_skills_tool_has_correct_params(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + load_tool = tool_map["load_skills"] + assert "name" in load_tool.parameters["properties"] + + def test_read_skill_file_tool_has_correct_params( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + read_tool = tool_map["read_skill_file"] + assert "name" in read_tool.parameters["properties"] + assert "relative_path" in read_tool.parameters["properties"] + required = read_tool.parameters.get("required", []) + assert "name" in required + assert "relative_path" in required + + def test_execute_command_tool_has_correct_params( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + exec_tool = tool_map["execute_command"] + assert "command" in exec_tool.parameters["properties"] + assert "cwd" in exec_tool.parameters["properties"] + assert "timeout" in exec_tool.parameters["properties"] + required = exec_tool.parameters.get("required", []) + assert "command" in required + + def test_execute_command_description_has_safety_warning( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_map = {t.name: t for t in toolset.tools()} + exec_tool = tool_map["execute_command"] + assert "IMPORTANT" in exec_tool.description + assert "approval" in exec_tool.description.lower() From 39b32f8d067639cc4bc261e88c40aebc09cec157 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Mon, 6 Apr 2026 23:50:24 +0800 Subject: [PATCH 08/10] refactor(tool): update MCP endpoint handling and add protocol spec parsing This change updates the `_get_mcp_endpoint` method to return both the endpoint URL and session affinity as a tuple, and introduces a new method `_parse_protocol_spec_mcp_url` to handle parsing of MCP URLs from protocol specifications when using `MCP_REMOTE` without proxy enabled. This improves consistency across synchronous and asynchronous implementations of the tool class. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/tool/__tool_async_template.py | 107 ++++++++++--- agentrun/tool/tool.py | 123 +++++++++++---- tests/unittests/tool/test_tool.py | 205 ++++++++++++++++++++++++- 3 files changed, 381 insertions(+), 54 deletions(-) diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py index a070ad8..70ccf72 100644 --- a/agentrun/tool/__tool_async_template.py +++ b/agentrun/tool/__tool_async_template.py @@ -5,9 +5,10 @@ """ import io +import json import os import shutil -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse import zipfile @@ -196,16 +197,85 @@ def _get_tool_type(self) -> Optional[ToolType]: return None return None + def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: + """从 protocol_spec 解析 MCP 服务器 URL 和 session_affinity / Parse MCP server URL and session_affinity from protocol_spec + + 用于 MCP_REMOTE + proxy_enabled=false 场景,从 protocol_spec JSON 中提取 + 第一个 mcpServers entry 的 url 和 transportType。 + Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url and + transportType from the first mcpServers entry in protocol_spec JSON. + + Returns: + Tuple[str, str]: (mcp_url, session_affinity) + + Raises: + ValueError: protocol_spec 为空、格式不合法或缺少必要字段时抛出 + """ + if not self.protocol_spec: + raise ValueError( + "protocol_spec is required for MCP_REMOTE tool with proxy" + " disabled, but it is empty for tool" + f" '{self.tool_name or self.name}'" + ) + + try: + spec = json.loads(self.protocol_spec) + except (json.JSONDecodeError, TypeError) as exc: + raise ValueError( + "Failed to parse protocol_spec for tool" + f" '{self.tool_name or self.name}': {exc}" + ) from exc + + mcp_servers = spec.get("mcpServers") + if not mcp_servers or not isinstance(mcp_servers, dict): + raise ValueError( + "mcpServers not found or invalid in protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + first_server = next(iter(mcp_servers.values()), None) + if not first_server or not isinstance(first_server, dict): + raise ValueError( + "No MCP server entry found in protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + url = first_server.get("url") + if not url: + raise ValueError( + "url not found in MCP server entry of protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + transport_type = first_server.get("transportType", "sse") + if transport_type == "streamable-http": + session_affinity = "MCP_STREAMABLE" + else: + session_affinity = "MCP_SSE" + + return url, session_affinity + def _get_mcp_endpoint( self, config: Optional[Config] = None - ) -> Optional[str]: - """获取 MCP 数据链路 URL / Get MCP data endpoint URL + ) -> Optional[Tuple[str, str]]: + """获取 MCP 数据链路 URL 和 session_affinity / Get MCP data endpoint URL and session_affinity + + MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL 和 session_affinity。 + 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取。 + For MCP_REMOTE with proxy disabled, parses URL and session_affinity from protocol_spec. + Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config. - 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 - 如果 self.data_endpoint 为空,则从 Config 中获取。 - Determines /mcp or /sse path based on session_affinity. - Falls back to Config if self.data_endpoint is not set. + Returns: + Optional[Tuple[str, str]]: (endpoint_url, session_affinity) 或 None """ + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + + if is_mcp_remote_without_proxy: + return self._parse_protocol_spec_mcp_url() + effective_name = self.tool_name or self.name data_endpoint = self.data_endpoint if not data_endpoint: @@ -219,8 +289,11 @@ def _get_mcp_endpoint( ) if session_affinity == "MCP_STREAMABLE": - return f"{data_endpoint}/tools/{effective_name}/mcp" - return f"{data_endpoint}/tools/{effective_name}/sse" + return ( + f"{data_endpoint}/tools/{effective_name}/mcp", + session_affinity, + ) + return f"{data_endpoint}/tools/{effective_name}/sse", session_affinity async def list_tools_async( self, config: Optional[Config] = None @@ -240,16 +313,14 @@ async def list_tools_async( if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: logger.warning( "MCP endpoint not available for tool %s", self.name ) return [] - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -309,15 +380,13 @@ async def call_tool_async( if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: raise ValueError( f"MCP endpoint not available for tool {self.name}" ) - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py index ae6c530..397f4de 100644 --- a/agentrun/tool/tool.py +++ b/agentrun/tool/tool.py @@ -15,9 +15,10 @@ """ import io +import json import os import shutil -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse import zipfile @@ -221,16 +222,85 @@ def _get_tool_type(self) -> Optional[ToolType]: return None return None + def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: + """从 protocol_spec 解析 MCP 服务器 URL 和 session_affinity / Parse MCP server URL and session_affinity from protocol_spec + + 用于 MCP_REMOTE + proxy_enabled=false 场景,从 protocol_spec JSON 中提取 + 第一个 mcpServers entry 的 url 和 transportType。 + Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url and + transportType from the first mcpServers entry in protocol_spec JSON. + + Returns: + Tuple[str, str]: (mcp_url, session_affinity) + + Raises: + ValueError: protocol_spec 为空、格式不合法或缺少必要字段时抛出 + """ + if not self.protocol_spec: + raise ValueError( + "protocol_spec is required for MCP_REMOTE tool with proxy" + " disabled, but it is empty for tool" + f" '{self.tool_name or self.name}'" + ) + + try: + spec = json.loads(self.protocol_spec) + except (json.JSONDecodeError, TypeError) as exc: + raise ValueError( + "Failed to parse protocol_spec for tool" + f" '{self.tool_name or self.name}': {exc}" + ) from exc + + mcp_servers = spec.get("mcpServers") + if not mcp_servers or not isinstance(mcp_servers, dict): + raise ValueError( + "mcpServers not found or invalid in protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + first_server = next(iter(mcp_servers.values()), None) + if not first_server or not isinstance(first_server, dict): + raise ValueError( + "No MCP server entry found in protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + url = first_server.get("url") + if not url: + raise ValueError( + "url not found in MCP server entry of protocol_spec for tool" + f" '{self.tool_name or self.name}'" + ) + + transport_type = first_server.get("transportType", "sse") + if transport_type == "streamable-http": + session_affinity = "MCP_STREAMABLE" + else: + session_affinity = "MCP_SSE" + + return url, session_affinity + def _get_mcp_endpoint( self, config: Optional[Config] = None - ) -> Optional[str]: - """获取 MCP 数据链路 URL / Get MCP data endpoint URL + ) -> Optional[Tuple[str, str]]: + """获取 MCP 数据链路 URL 和 session_affinity / Get MCP data endpoint URL and session_affinity - 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 - 如果 self.data_endpoint 为空,则从 Config 中获取。 - Determines /mcp or /sse path based on session_affinity. - Falls back to Config if self.data_endpoint is not set. + MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL 和 session_affinity。 + 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取。 + For MCP_REMOTE with proxy disabled, parses URL and session_affinity from protocol_spec. + Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config. + + Returns: + Optional[Tuple[str, str]]: (endpoint_url, session_affinity) 或 None """ + is_mcp_remote_without_proxy = ( + self.create_method == "MCP_REMOTE" + and not pydash.get(self, "mcp_config.proxy_enabled", False) + ) + + if is_mcp_remote_without_proxy: + return self._parse_protocol_spec_mcp_url() + effective_name = self.tool_name or self.name data_endpoint = self.data_endpoint if not data_endpoint: @@ -244,8 +314,11 @@ def _get_mcp_endpoint( ) if session_affinity == "MCP_STREAMABLE": - return f"{data_endpoint}/tools/{effective_name}/mcp" - return f"{data_endpoint}/tools/{effective_name}/sse" + return ( + f"{data_endpoint}/tools/{effective_name}/mcp", + session_affinity, + ) + return f"{data_endpoint}/tools/{effective_name}/sse", session_affinity async def list_tools_async( self, config: Optional[Config] = None @@ -265,16 +338,14 @@ async def list_tools_async( if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: logger.warning( "MCP endpoint not available for tool %s", self.name ) return [] - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -327,16 +398,14 @@ def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: logger.warning( "MCP endpoint not available for tool %s", self.name ) return [] - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -396,15 +465,13 @@ async def call_tool_async( if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: raise ValueError( f"MCP endpoint not available for tool {self.name}" ) - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -469,15 +536,13 @@ def call_tool( if tool_type == ToolType.MCP: from .api.mcp import ToolMCPSession - mcp_endpoint = self._get_mcp_endpoint(config) - if not mcp_endpoint: + endpoint_result = self._get_mcp_endpoint(config) + if not endpoint_result: raise ValueError( f"MCP endpoint not available for tool {self.name}" ) - session_affinity = pydash.get( - self, "mcp_config.session_affinity", "MCP_SSE" - ) + mcp_endpoint, session_affinity = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py index 865570c..2404d67 100644 --- a/tests/unittests/tool/test_tool.py +++ b/tests/unittests/tool/test_tool.py @@ -110,7 +110,7 @@ def test_get_mcp_endpoint_sse(self): mcp_config=McpConfig(session_affinity="MCP_SSE"), ) endpoint = tool._get_mcp_endpoint() - assert endpoint == "https://example.com/tools/my-tool/sse" + assert endpoint == ("https://example.com/tools/my-tool/sse", "MCP_SSE") def test_get_mcp_endpoint_streamable(self): """测试获取 MCP Streamable endpoint""" @@ -120,7 +120,10 @@ def test_get_mcp_endpoint_streamable(self): mcp_config=McpConfig(session_affinity="MCP_STREAMABLE"), ) endpoint = tool._get_mcp_endpoint() - assert endpoint == "https://example.com/tools/my-tool/mcp" + assert endpoint == ( + "https://example.com/tools/my-tool/mcp", + "MCP_STREAMABLE", + ) def test_get_mcp_endpoint_default(self): """测试获取 MCP endpoint(默认 SSE)""" @@ -129,7 +132,7 @@ def test_get_mcp_endpoint_default(self): data_endpoint="https://example.com", ) endpoint = tool._get_mcp_endpoint() - assert endpoint == "https://example.com/tools/my-tool/sse" + assert endpoint == ("https://example.com/tools/my-tool/sse", "MCP_SSE") def test_get_mcp_endpoint_no_name(self): """测试没有 name 时获取 MCP endpoint""" @@ -152,7 +155,10 @@ def test_get_mcp_endpoint_no_data_endpoint(self, mock_config_class): tool_name="my-tool", ) endpoint = tool._get_mcp_endpoint() - assert endpoint == "https://fallback.example.com/tools/my-tool/sse" + assert endpoint == ( + "https://fallback.example.com/tools/my-tool/sse", + "MCP_SSE", + ) def test_from_inner_object(self): """测试从内部对象创建 Tool""" @@ -831,10 +837,10 @@ def test_call_tool_mcp_remote_without_proxy_skips_ram( tool_name="my-tool", tool_type="MCP", create_method="MCP_REMOTE", - data_endpoint="https://example.agentrun-data.aliyuncs.com", mcp_config=McpConfig( session_affinity="MCP_SSE", proxy_enabled=False ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"sse","url":"https://my-mcp-server.com/sse"}}}', ) tool.call_tool("tool1", {}) @@ -983,10 +989,10 @@ async def test_call_tool_async_mcp_remote_without_proxy_skips_ram( tool_name="my-tool", tool_type="MCP", create_method="MCP_REMOTE", - data_endpoint="https://example.agentrun-data.aliyuncs.com", mcp_config=McpConfig( session_affinity="MCP_SSE", proxy_enabled=False ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"sse","url":"https://my-mcp-server.com/sse"}}}', ) await tool.call_tool_async("tool1", {}) @@ -1028,6 +1034,193 @@ async def test_call_tool_async_functioncall_openapi_import_skips_ram( call_kwargs = mock_openapi_client_class.call_args[1] assert call_kwargs["use_ram_auth"] is False + # ==================== _parse_protocol_spec_mcp_url 测试 ==================== + + def test_parse_protocol_spec_mcp_url_sse(self): + """测试从 protocol_spec 解析 SSE 类型的 MCP URL""" + tool = Tool( + tool_name="my-tool", + protocol_spec='{"mcpServers":{"server1":{"transportType":"sse","url":"https://my-server.com/sse"}}}', + ) + url, session_affinity = tool._parse_protocol_spec_mcp_url() + assert url == "https://my-server.com/sse" + assert session_affinity == "MCP_SSE" + + def test_parse_protocol_spec_mcp_url_streamable_http(self): + """测试从 protocol_spec 解析 Streamable HTTP 类型的 MCP URL""" + tool = Tool( + tool_name="my-tool", + protocol_spec='{"mcpServers":{"server1":{"transportType":"streamable-http","url":"https://my-server.com/mcp"}}}', + ) + url, session_affinity = tool._parse_protocol_spec_mcp_url() + assert url == "https://my-server.com/mcp" + assert session_affinity == "MCP_STREAMABLE" + + def test_parse_protocol_spec_mcp_url_unknown_transport_defaults_sse(self): + """测试 transportType 未知时默认使用 SSE""" + tool = Tool( + tool_name="my-tool", + protocol_spec='{"mcpServers":{"server1":{"transportType":"unknown","url":"https://my-server.com/path"}}}', + ) + url, session_affinity = tool._parse_protocol_spec_mcp_url() + assert url == "https://my-server.com/path" + assert session_affinity == "MCP_SSE" + + def test_parse_protocol_spec_mcp_url_empty_protocol_spec(self): + """测试 protocol_spec 为空时抛出 ValueError""" + tool = Tool(tool_name="my-tool", protocol_spec=None) + with pytest.raises(ValueError, match="protocol_spec is required"): + tool._parse_protocol_spec_mcp_url() + + def test_parse_protocol_spec_mcp_url_invalid_json(self): + """测试 protocol_spec JSON 格式不合法时抛出 ValueError""" + tool = Tool(tool_name="my-tool", protocol_spec="invalid json") + with pytest.raises(ValueError, match="Failed to parse protocol_spec"): + tool._parse_protocol_spec_mcp_url() + + def test_parse_protocol_spec_mcp_url_missing_mcp_servers(self): + """测试 protocol_spec 缺少 mcpServers 字段时抛出 ValueError""" + tool = Tool(tool_name="my-tool", protocol_spec='{"other":"data"}') + with pytest.raises(ValueError, match="mcpServers"): + tool._parse_protocol_spec_mcp_url() + + def test_parse_protocol_spec_mcp_url_empty_mcp_servers(self): + """测试 mcpServers 为空时抛出 ValueError""" + tool = Tool(tool_name="my-tool", protocol_spec='{"mcpServers":{}}') + with pytest.raises(ValueError, match="mcpServers not found or invalid"): + tool._parse_protocol_spec_mcp_url() + + def test_parse_protocol_spec_mcp_url_missing_url(self): + """测试 server entry 缺少 url 字段时抛出 ValueError""" + tool = Tool( + tool_name="my-tool", + protocol_spec='{"mcpServers":{"s1":{"transportType":"sse"}}}', + ) + with pytest.raises(ValueError, match="url"): + tool._parse_protocol_spec_mcp_url() + + # ==================== _get_mcp_endpoint 直连模式测试 ==================== + + def test_get_mcp_endpoint_mcp_remote_without_proxy(self): + """测试 MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL""" + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"sse","url":"https://external-mcp.com/sse"}}}', + ) + result = tool._get_mcp_endpoint() + assert result == ("https://external-mcp.com/sse", "MCP_SSE") + + def test_get_mcp_endpoint_mcp_remote_without_proxy_streamable(self): + """测试 MCP_REMOTE + proxy_enabled=false + streamable-http 时从 protocol_spec 解析""" + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"streamable-http","url":"https://external-mcp.com/mcp"}}}', + ) + result = tool._get_mcp_endpoint() + assert result == ("https://external-mcp.com/mcp", "MCP_STREAMABLE") + + def test_get_mcp_endpoint_mcp_remote_with_proxy_uses_data_endpoint(self): + """测试 MCP_REMOTE + proxy_enabled=true 时使用 data_endpoint 拼接""" + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + data_endpoint="https://example.com", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=True + ), + ) + result = tool._get_mcp_endpoint() + assert result == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + + def test_get_mcp_endpoint_mcp_bundle_uses_data_endpoint(self): + """测试 MCP_BUNDLE 类型使用 data_endpoint 拼接""" + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_BUNDLE", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + result = tool._get_mcp_endpoint() + assert result == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + + # ==================== list_tools / call_tool 直连模式 session_affinity 测试 ==================== + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp_remote_direct_connect_session_affinity( + self, mock_config_class, mock_mcp_session_class + ): + """测试 list_tools 在 MCP_REMOTE 直连模式下使用 protocol_spec 中的 session_affinity""" + mock_session = Mock() + mock_session.list_tools.return_value = [ + ToolInfo(name="tool1", description="Tool 1"), + ] + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"streamable-http","url":"https://external.com/mcp"}}}', + ) + + tool.list_tools() + + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["endpoint"] == "https://external.com/mcp" + assert call_kwargs["session_affinity"] == "MCP_STREAMABLE" + assert call_kwargs["use_ram_auth"] is False + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp_remote_direct_connect_session_affinity( + self, mock_config_class, mock_mcp_session_class + ): + """测试 call_tool 在 MCP_REMOTE 直连模式下使用 protocol_spec 中的 session_affinity""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "ok"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig( + session_affinity="MCP_SSE", proxy_enabled=False + ), + protocol_spec='{"mcpServers":{"s1":{"transportType":"streamable-http","url":"https://external.com/mcp"}}}', + ) + + tool.call_tool("tool1", {}) + + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["endpoint"] == "https://external.com/mcp" + assert call_kwargs["session_affinity"] == "MCP_STREAMABLE" + assert call_kwargs["use_ram_auth"] is False + def test_tool_create_method_field(self): """测试 Tool 的 create_method 字段""" tool = Tool(create_method="MCP_REMOTE") From ca41adba0190f4b9acb9c8df2966e87db91d4598 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Tue, 7 Apr 2026 13:55:22 +0800 Subject: [PATCH 09/10] refactor(tool): enhance MCP endpoint parsing and header handling This change extends the `_parse_protocol_spec_mcp_url` method to also extract headers from the protocol specification, updates related methods to handle this new information, and merges headers appropriately with higher precedence given to those specified in the protocol spec. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/tool/__tool_async_template.py | 57 +++++--- agentrun/tool/tool.py | 73 +++++++--- tests/unittests/tool/test_tool.py | 193 +++++++++++++++++++++++-- 3 files changed, 271 insertions(+), 52 deletions(-) diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py index 70ccf72..a05cee5 100644 --- a/agentrun/tool/__tool_async_template.py +++ b/agentrun/tool/__tool_async_template.py @@ -197,16 +197,16 @@ def _get_tool_type(self) -> Optional[ToolType]: return None return None - def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: - """从 protocol_spec 解析 MCP 服务器 URL 和 session_affinity / Parse MCP server URL and session_affinity from protocol_spec + def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str, Dict[str, str]]: + """从 protocol_spec 解析 MCP 服务器 URL、session_affinity 和 headers / Parse MCP server URL, session_affinity and headers from protocol_spec 用于 MCP_REMOTE + proxy_enabled=false 场景,从 protocol_spec JSON 中提取 - 第一个 mcpServers entry 的 url 和 transportType。 - Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url and - transportType from the first mcpServers entry in protocol_spec JSON. + 第一个 mcpServers entry 的 url、transportType 和 headers。 + Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url, + transportType and headers from the first mcpServers entry in protocol_spec JSON. Returns: - Tuple[str, str]: (mcp_url, session_affinity) + Tuple[str, str, Dict[str, str]]: (mcp_url, session_affinity, headers) Raises: ValueError: protocol_spec 为空、格式不合法或缺少必要字段时抛出 @@ -253,20 +253,26 @@ def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: else: session_affinity = "MCP_SSE" - return url, session_affinity + # 解析 headers(可选字段)/ Parse headers (optional field) + raw_headers = first_server.get("headers") + spec_headers: Dict[str, str] = {} + if raw_headers and isinstance(raw_headers, dict): + spec_headers = {str(k): str(v) for k, v in raw_headers.items()} + + return url, session_affinity, spec_headers def _get_mcp_endpoint( self, config: Optional[Config] = None - ) -> Optional[Tuple[str, str]]: - """获取 MCP 数据链路 URL 和 session_affinity / Get MCP data endpoint URL and session_affinity + ) -> Optional[Tuple[str, str, Dict[str, str]]]: + """获取 MCP 数据链路 URL、session_affinity 和 spec headers / Get MCP data endpoint URL, session_affinity and spec headers - MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL 和 session_affinity。 - 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取。 - For MCP_REMOTE with proxy disabled, parses URL and session_affinity from protocol_spec. - Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config. + MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL、session_affinity 和 headers。 + 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取,headers 为空。 + For MCP_REMOTE with proxy disabled, parses URL, session_affinity and headers from protocol_spec. + Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config, headers empty. Returns: - Optional[Tuple[str, str]]: (endpoint_url, session_affinity) 或 None + Optional[Tuple[str, str, Dict[str, str]]]: (endpoint_url, session_affinity, spec_headers) 或 None """ is_mcp_remote_without_proxy = ( self.create_method == "MCP_REMOTE" @@ -292,8 +298,13 @@ def _get_mcp_endpoint( return ( f"{data_endpoint}/tools/{effective_name}/mcp", session_affinity, + {}, ) - return f"{data_endpoint}/tools/{effective_name}/sse", session_affinity + return ( + f"{data_endpoint}/tools/{effective_name}/sse", + session_affinity, + {}, + ) async def list_tools_async( self, config: Optional[Config] = None @@ -320,7 +331,7 @@ async def list_tools_async( ) return [] - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -329,11 +340,15 @@ async def list_tools_async( and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) @@ -386,7 +401,7 @@ async def call_tool_async( f"MCP endpoint not available for tool {self.name}" ) - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -395,11 +410,15 @@ async def call_tool_async( and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py index 397f4de..f044368 100644 --- a/agentrun/tool/tool.py +++ b/agentrun/tool/tool.py @@ -222,16 +222,16 @@ def _get_tool_type(self) -> Optional[ToolType]: return None return None - def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: - """从 protocol_spec 解析 MCP 服务器 URL 和 session_affinity / Parse MCP server URL and session_affinity from protocol_spec + def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str, Dict[str, str]]: + """从 protocol_spec 解析 MCP 服务器 URL、session_affinity 和 headers / Parse MCP server URL, session_affinity and headers from protocol_spec 用于 MCP_REMOTE + proxy_enabled=false 场景,从 protocol_spec JSON 中提取 - 第一个 mcpServers entry 的 url 和 transportType。 - Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url and - transportType from the first mcpServers entry in protocol_spec JSON. + 第一个 mcpServers entry 的 url、transportType 和 headers。 + Used for MCP_REMOTE + proxy_enabled=false scenario, extracts url, + transportType and headers from the first mcpServers entry in protocol_spec JSON. Returns: - Tuple[str, str]: (mcp_url, session_affinity) + Tuple[str, str, Dict[str, str]]: (mcp_url, session_affinity, headers) Raises: ValueError: protocol_spec 为空、格式不合法或缺少必要字段时抛出 @@ -278,20 +278,26 @@ def _parse_protocol_spec_mcp_url(self) -> Tuple[str, str]: else: session_affinity = "MCP_SSE" - return url, session_affinity + # 解析 headers(可选字段)/ Parse headers (optional field) + raw_headers = first_server.get("headers") + spec_headers: Dict[str, str] = {} + if raw_headers and isinstance(raw_headers, dict): + spec_headers = {str(k): str(v) for k, v in raw_headers.items()} + + return url, session_affinity, spec_headers def _get_mcp_endpoint( self, config: Optional[Config] = None - ) -> Optional[Tuple[str, str]]: - """获取 MCP 数据链路 URL 和 session_affinity / Get MCP data endpoint URL and session_affinity + ) -> Optional[Tuple[str, str, Dict[str, str]]]: + """获取 MCP 数据链路 URL、session_affinity 和 spec headers / Get MCP data endpoint URL, session_affinity and spec headers - MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL 和 session_affinity。 - 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取。 - For MCP_REMOTE with proxy disabled, parses URL and session_affinity from protocol_spec. - Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config. + MCP_REMOTE + proxy_enabled=false 时从 protocol_spec 解析 URL、session_affinity 和 headers。 + 其他场景使用 data_endpoint 拼接,session_affinity 从 mcp_config 获取,headers 为空。 + For MCP_REMOTE with proxy disabled, parses URL, session_affinity and headers from protocol_spec. + Otherwise constructs URL from data_endpoint and gets session_affinity from mcp_config, headers empty. Returns: - Optional[Tuple[str, str]]: (endpoint_url, session_affinity) 或 None + Optional[Tuple[str, str, Dict[str, str]]]: (endpoint_url, session_affinity, spec_headers) 或 None """ is_mcp_remote_without_proxy = ( self.create_method == "MCP_REMOTE" @@ -317,8 +323,13 @@ def _get_mcp_endpoint( return ( f"{data_endpoint}/tools/{effective_name}/mcp", session_affinity, + {}, ) - return f"{data_endpoint}/tools/{effective_name}/sse", session_affinity + return ( + f"{data_endpoint}/tools/{effective_name}/sse", + session_affinity, + {}, + ) async def list_tools_async( self, config: Optional[Config] = None @@ -345,7 +356,7 @@ async def list_tools_async( ) return [] - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -354,11 +365,15 @@ async def list_tools_async( and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) @@ -405,7 +420,7 @@ def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: ) return [] - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -414,11 +429,15 @@ def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) @@ -471,7 +490,7 @@ async def call_tool_async( f"MCP endpoint not available for tool {self.name}" ) - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -480,11 +499,15 @@ async def call_tool_async( and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) @@ -542,7 +565,7 @@ def call_tool( f"MCP endpoint not available for tool {self.name}" ) - mcp_endpoint, session_affinity = endpoint_result + mcp_endpoint, session_affinity, spec_headers = endpoint_result # MCP_REMOTE + proxy_enabled=false 时直连外部服务,不走 RAM 鉴权 # Only skip RAM auth for MCP_REMOTE with proxy disabled (direct external connection) @@ -551,11 +574,15 @@ def call_tool( and not pydash.get(self, "mcp_config.proxy_enabled", False) ) + # 合并 headers:protocol_spec 中的 headers 优先级更高 + # Merge headers: protocol_spec headers take precedence cfg = Config.with_configs(config) + merged_headers = {**(cfg.get_headers() or {}), **spec_headers} + session = ToolMCPSession( endpoint=mcp_endpoint, session_affinity=session_affinity, - headers=cfg.get_headers(), + headers=merged_headers, config=cfg, use_ram_auth=not is_mcp_remote_without_proxy, ) diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py index 2404d67..adcaa72 100644 --- a/tests/unittests/tool/test_tool.py +++ b/tests/unittests/tool/test_tool.py @@ -4,7 +4,8 @@ Tests functionality of Tool resource class and ToolClient. """ -from unittest.mock import AsyncMock, Mock, patch +import json +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -110,7 +111,11 @@ def test_get_mcp_endpoint_sse(self): mcp_config=McpConfig(session_affinity="MCP_SSE"), ) endpoint = tool._get_mcp_endpoint() - assert endpoint == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + assert endpoint == ( + "https://example.com/tools/my-tool/sse", + "MCP_SSE", + {}, + ) def test_get_mcp_endpoint_streamable(self): """测试获取 MCP Streamable endpoint""" @@ -123,6 +128,7 @@ def test_get_mcp_endpoint_streamable(self): assert endpoint == ( "https://example.com/tools/my-tool/mcp", "MCP_STREAMABLE", + {}, ) def test_get_mcp_endpoint_default(self): @@ -132,7 +138,11 @@ def test_get_mcp_endpoint_default(self): data_endpoint="https://example.com", ) endpoint = tool._get_mcp_endpoint() - assert endpoint == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + assert endpoint == ( + "https://example.com/tools/my-tool/sse", + "MCP_SSE", + {}, + ) def test_get_mcp_endpoint_no_name(self): """测试没有 name 时获取 MCP endpoint""" @@ -158,6 +168,7 @@ def test_get_mcp_endpoint_no_data_endpoint(self, mock_config_class): assert endpoint == ( "https://fallback.example.com/tools/my-tool/sse", "MCP_SSE", + {}, ) def test_from_inner_object(self): @@ -1042,9 +1053,10 @@ def test_parse_protocol_spec_mcp_url_sse(self): tool_name="my-tool", protocol_spec='{"mcpServers":{"server1":{"transportType":"sse","url":"https://my-server.com/sse"}}}', ) - url, session_affinity = tool._parse_protocol_spec_mcp_url() + url, session_affinity, headers = tool._parse_protocol_spec_mcp_url() assert url == "https://my-server.com/sse" assert session_affinity == "MCP_SSE" + assert headers == {} def test_parse_protocol_spec_mcp_url_streamable_http(self): """测试从 protocol_spec 解析 Streamable HTTP 类型的 MCP URL""" @@ -1052,9 +1064,10 @@ def test_parse_protocol_spec_mcp_url_streamable_http(self): tool_name="my-tool", protocol_spec='{"mcpServers":{"server1":{"transportType":"streamable-http","url":"https://my-server.com/mcp"}}}', ) - url, session_affinity = tool._parse_protocol_spec_mcp_url() + url, session_affinity, headers = tool._parse_protocol_spec_mcp_url() assert url == "https://my-server.com/mcp" assert session_affinity == "MCP_STREAMABLE" + assert headers == {} def test_parse_protocol_spec_mcp_url_unknown_transport_defaults_sse(self): """测试 transportType 未知时默认使用 SSE""" @@ -1062,9 +1075,10 @@ def test_parse_protocol_spec_mcp_url_unknown_transport_defaults_sse(self): tool_name="my-tool", protocol_spec='{"mcpServers":{"server1":{"transportType":"unknown","url":"https://my-server.com/path"}}}', ) - url, session_affinity = tool._parse_protocol_spec_mcp_url() + url, session_affinity, headers = tool._parse_protocol_spec_mcp_url() assert url == "https://my-server.com/path" assert session_affinity == "MCP_SSE" + assert headers == {} def test_parse_protocol_spec_mcp_url_empty_protocol_spec(self): """测试 protocol_spec 为空时抛出 ValueError""" @@ -1113,7 +1127,7 @@ def test_get_mcp_endpoint_mcp_remote_without_proxy(self): protocol_spec='{"mcpServers":{"s1":{"transportType":"sse","url":"https://external-mcp.com/sse"}}}', ) result = tool._get_mcp_endpoint() - assert result == ("https://external-mcp.com/sse", "MCP_SSE") + assert result == ("https://external-mcp.com/sse", "MCP_SSE", {}) def test_get_mcp_endpoint_mcp_remote_without_proxy_streamable(self): """测试 MCP_REMOTE + proxy_enabled=false + streamable-http 时从 protocol_spec 解析""" @@ -1127,7 +1141,7 @@ def test_get_mcp_endpoint_mcp_remote_without_proxy_streamable(self): protocol_spec='{"mcpServers":{"s1":{"transportType":"streamable-http","url":"https://external-mcp.com/mcp"}}}', ) result = tool._get_mcp_endpoint() - assert result == ("https://external-mcp.com/mcp", "MCP_STREAMABLE") + assert result == ("https://external-mcp.com/mcp", "MCP_STREAMABLE", {}) def test_get_mcp_endpoint_mcp_remote_with_proxy_uses_data_endpoint(self): """测试 MCP_REMOTE + proxy_enabled=true 时使用 data_endpoint 拼接""" @@ -1141,7 +1155,11 @@ def test_get_mcp_endpoint_mcp_remote_with_proxy_uses_data_endpoint(self): ), ) result = tool._get_mcp_endpoint() - assert result == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + assert result == ( + "https://example.com/tools/my-tool/sse", + "MCP_SSE", + {}, + ) def test_get_mcp_endpoint_mcp_bundle_uses_data_endpoint(self): """测试 MCP_BUNDLE 类型使用 data_endpoint 拼接""" @@ -1153,7 +1171,83 @@ def test_get_mcp_endpoint_mcp_bundle_uses_data_endpoint(self): mcp_config=McpConfig(session_affinity="MCP_SSE"), ) result = tool._get_mcp_endpoint() - assert result == ("https://example.com/tools/my-tool/sse", "MCP_SSE") + assert result == ( + "https://example.com/tools/my-tool/sse", + "MCP_SSE", + {}, + ) + + def test_parse_protocol_spec_mcp_url_with_headers(self): + """测试 protocol_spec 中包含 headers 时能正确解析""" + spec = json.dumps({ + "mcpServers": { + "server": { + "url": "https://mcp.example.com/mcp", + "transportType": "streamable-http", + "headers": { + "Authorization": "Bearer sk-xxx", + "X-Custom": "value", + }, + } + } + }) + tool = Tool(tool_name="my-tool", protocol_spec=spec) + url, affinity, headers = tool._parse_protocol_spec_mcp_url() + assert url == "https://mcp.example.com/mcp" + assert affinity == "MCP_STREAMABLE" + assert headers == { + "Authorization": "Bearer sk-xxx", + "X-Custom": "value", + } + + def test_parse_protocol_spec_mcp_url_without_headers(self): + """测试 protocol_spec 中没有 headers 字段时返回空 dict""" + spec = json.dumps( + {"mcpServers": {"server": {"url": "https://mcp.example.com/sse"}}} + ) + tool = Tool(tool_name="my-tool", protocol_spec=spec) + url, affinity, headers = tool._parse_protocol_spec_mcp_url() + assert url == "https://mcp.example.com/sse" + assert affinity == "MCP_SSE" + assert headers == {} + + def test_parse_protocol_spec_mcp_url_headers_non_dict_ignored(self): + """测试 headers 不是 dict 时被忽略""" + spec = json.dumps({ + "mcpServers": { + "server": { + "url": "https://mcp.example.com/sse", + "headers": "not-a-dict", + } + } + }) + tool = Tool(tool_name="my-tool", protocol_spec=spec) + url, affinity, headers = tool._parse_protocol_spec_mcp_url() + assert headers == {} + + def test_get_mcp_endpoint_mcp_remote_without_proxy_with_headers(self): + """测试直连模式下 headers 从 protocol_spec 传递""" + spec = json.dumps({ + "mcpServers": { + "server": { + "url": "https://mcp.example.com/mcp", + "transportType": "streamable-http", + "headers": {"Authorization": "Bearer sk-xxx"}, + } + } + }) + tool = Tool( + tool_name="my-tool", + create_method="MCP_REMOTE", + mcp_config=McpConfig(proxy_enabled=False), + protocol_spec=spec, + ) + result = tool._get_mcp_endpoint() + assert result == ( + "https://mcp.example.com/mcp", + "MCP_STREAMABLE", + {"Authorization": "Bearer sk-xxx"}, + ) # ==================== list_tools / call_tool 直连模式 session_affinity 测试 ==================== @@ -1221,6 +1315,85 @@ def test_call_tool_mcp_remote_direct_connect_session_affinity( assert call_kwargs["session_affinity"] == "MCP_STREAMABLE" assert call_kwargs["use_ram_auth"] is False + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp_remote_direct_connect_with_spec_headers( + self, mock_config_class, mock_mcp_session_class + ): + """测试 list_tools 直连模式下 spec_headers 被合并到 ToolMCPSession 的 headers中""" + mock_config = MagicMock() + mock_config.get_headers.return_value = {"X-Existing": "old-value"} + mock_config_class.with_configs.return_value = mock_config + + mock_session = MagicMock() + mock_session.list_tools_async = AsyncMock(return_value=[]) + mock_mcp_session_class.return_value = mock_session + + spec = json.dumps({ + "mcpServers": { + "server": { + "url": "https://mcp.example.com/mcp", + "transportType": "streamable-http", + "headers": { + "Authorization": "Bearer sk-xxx", + "X-Existing": "new-value", + }, + } + } + }) + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig(proxy_enabled=False), + protocol_spec=spec, + ) + tool.list_tools() + + call_kwargs = mock_mcp_session_class.call_args[1] + # spec_headers should override cfg headers for same key + assert call_kwargs["headers"] == { + "X-Existing": "new-value", + "Authorization": "Bearer sk-xxx", + } + assert call_kwargs["use_ram_auth"] is False + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp_remote_direct_connect_with_spec_headers( + self, mock_config_class, mock_mcp_session_class + ): + """测试 call_tool 直连模式下 spec_headers 被合并到 ToolMCPSession 的 headers中""" + mock_config = MagicMock() + mock_config.get_headers.return_value = {"X-Existing": "old-value"} + mock_config_class.with_configs.return_value = mock_config + + mock_session = MagicMock() + mock_session.call_tool_async = AsyncMock(return_value={"result": "ok"}) + mock_mcp_session_class.return_value = mock_session + + spec = json.dumps({ + "mcpServers": { + "server": { + "url": "https://mcp.example.com/mcp", + "transportType": "streamable-http", + "headers": {"Authorization": "Bearer sk-xxx"}, + } + } + }) + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + create_method="MCP_REMOTE", + mcp_config=McpConfig(proxy_enabled=False), + protocol_spec=spec, + ) + tool.call_tool("sub-tool", {"arg": "val"}) + + call_kwargs = mock_mcp_session_class.call_args[1] + assert call_kwargs["headers"] == {"Authorization": "Bearer sk-xxx"} + assert call_kwargs["use_ram_auth"] is False + def test_tool_create_method_field(self): """测试 Tool 的 create_method 字段""" tool = Tool(create_method="MCP_REMOTE") From 525172c5a50b22c519dd03f6595a5e44a31221f1 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Tue, 7 Apr 2026 16:08:36 +0800 Subject: [PATCH 10/10] fix(skill_loader): add conditional loading of execute_command based on env var This change introduces an environment variable `ALLOW_EXECUTE_COMMAND` that controls whether the `execute_command` tool is loaded into the CommonToolSet. By default, the tool is included unless explicitly disabled via the environment variable being set to `"false"` (case-insensitive). This provides better control over security-sensitive functionality. Tests have been added to verify all scenarios including default behavior, true/false values, and case insensitivity. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/integration/utils/skill_loader.py | 112 +++++++++++------- .../integration/test_skill_loader.py | 99 ++++++++++++++++ 2 files changed, 165 insertions(+), 46 deletions(-) diff --git a/agentrun/integration/utils/skill_loader.py b/agentrun/integration/utils/skill_loader.py index 7eed1f5..74b2b4d 100644 --- a/agentrun/integration/utils/skill_loader.py +++ b/agentrun/integration/utils/skill_loader.py @@ -600,10 +600,32 @@ def _execute_command_func( ensure_ascii=False, ) + @staticmethod + def _is_execute_command_allowed() -> bool: + """检查环境变量是否允许加载 execute_command 工具 + + Check whether the ALLOW_EXECUTE_COMMAND environment variable permits + loading the execute_command tool. + + The variable is read from ``os.environ``. When it is absent or set to + any value other than a case-insensitive ``"false"``, the tool is + allowed (default **True**). + + Returns: + True 表示允许 / True means allowed + """ + value = os.environ.get("ALLOW_EXECUTE_COMMAND", "true") + return value.lower() != "false" + def to_common_toolset(self) -> CommonToolSet: - """构造包含 load_skills、read_skill_file、execute_command 工具的 CommonToolSet + """构造包含 load_skills、read_skill_file 以及可选的 execute_command 工具的 CommonToolSet + + Construct CommonToolSet with load_skills, read_skill_file, and + optionally execute_command tools. - Construct CommonToolSet with load_skills, read_skill_file, and execute_command tools. + The execute_command tool is included only when the environment variable + ``ALLOW_EXECUTE_COMMAND`` is not set to ``"false"`` (case-insensitive). + When the variable is absent, it defaults to ``"true"`` (included). Returns: CommonToolSet 实例 / CommonToolSet instance @@ -656,54 +678,52 @@ def to_common_toolset(self) -> CommonToolSet: func=self._read_skill_file_func, ) - execute_command_tool = Tool( - name="execute_command", - description=( - "Execute a shell command on the local machine. " - "Use this to run scripts, install dependencies, or perform " - "file operations as instructed by skill documentation. " - "Returns stdout, stderr, exit_code, and timeout status.\n\n" - "⚠️ IMPORTANT: Before calling this tool, you MUST first " - "display the exact command to the user and ask for explicit " - "confirmation. Only proceed if the user approves. " - "Never execute commands without user approval." - ), - parameters=[ - ToolParameter( - name="command", - param_type="string", - description="The shell command to execute.", - required=True, + tools_list: List[Tool] = [load_skills_tool, read_skill_file_tool] + + if self._is_execute_command_allowed(): + execute_command_tool = Tool( + name="execute_command", + description=( + "Execute a shell command on the local machine. Use this to" + " run scripts, install dependencies, or perform file" + " operations as instructed by skill documentation. Returns" + " stdout, stderr, exit_code, and timeout status.\n\n⚠️" + " IMPORTANT: Before calling this tool, you MUST first" + " display the exact command to the user and ask for" + " explicit confirmation. Only proceed if the user approves." + " Never execute commands without user approval." ), - ToolParameter( - name="cwd", - param_type="string", - description=( - "Working directory for the command. " - "Defaults to the skills directory if not specified." + parameters=[ + ToolParameter( + name="command", + param_type="string", + description="The shell command to execute.", + required=True, ), - required=False, - ), - ToolParameter( - name="timeout", - param_type="integer", - description=( - "Maximum execution time in seconds. " - f"Defaults to {self._command_timeout}." + ToolParameter( + name="cwd", + param_type="string", + description=( + "Working directory for the command. " + "Defaults to the skills directory if not specified." + ), + required=False, ), - required=False, - ), - ], - func=self._execute_command_func, - ) + ToolParameter( + name="timeout", + param_type="integer", + description=( + "Maximum execution time in seconds. " + f"Defaults to {self._command_timeout}." + ), + required=False, + ), + ], + func=self._execute_command_func, + ) + tools_list.append(execute_command_tool) - return CommonToolSet( - tools_list=[ - load_skills_tool, - read_skill_file_tool, - execute_command_tool, - ] - ) + return CommonToolSet(tools_list=tools_list) def skill_tools( diff --git a/tests/unittests/integration/test_skill_loader.py b/tests/unittests/integration/test_skill_loader.py index eb8bd10..8605db1 100644 --- a/tests/unittests/integration/test_skill_loader.py +++ b/tests/unittests/integration/test_skill_loader.py @@ -1343,3 +1343,102 @@ def test_execute_command_description_has_safety_warning( exec_tool = tool_map["execute_command"] assert "IMPORTANT" in exec_tool.description assert "approval" in exec_tool.description.lower() + + +# ============================================================================= +# 17. ALLOW_EXECUTE_COMMAND 环境变量控制测试 +# ============================================================================= + + +class TestAllowExecuteCommandEnvVar: + """测试 ALLOW_EXECUTE_COMMAND 环境变量对 execute_command 工具加载的控制""" + + def test_default_includes_execute_command(self, tmp_path: Any) -> None: + """未设置环境变量时,默认包含 execute_command 工具""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("ALLOW_EXECUTE_COMMAND", None) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_names = [t.name for t in toolset.tools()] + assert "execute_command" in tool_names + assert len(toolset.tools()) == 3 + + def test_env_true_includes_execute_command(self, tmp_path: Any) -> None: + """ALLOW_EXECUTE_COMMAND=true 时,包含 execute_command 工具""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "true"}): + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_names = [t.name for t in toolset.tools()] + assert "execute_command" in tool_names + assert len(toolset.tools()) == 3 + + def test_env_false_excludes_execute_command(self, tmp_path: Any) -> None: + """ALLOW_EXECUTE_COMMAND=false 时,不包含 execute_command 工具""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "false"}): + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_names = [t.name for t in toolset.tools()] + assert "execute_command" not in tool_names + assert len(toolset.tools()) == 2 + assert "load_skills" in tool_names + assert "read_skill_file" in tool_names + + def test_env_false_case_insensitive(self, tmp_path: Any) -> None: + """ALLOW_EXECUTE_COMMAND=False / FALSE 等大小写变体均生效""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + for value in ("False", "FALSE", "fAlSe"): + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": value}): + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_names = [t.name for t in toolset.tools()] + assert ( + "execute_command" not in tool_names + ), f"execute_command should be excluded for value={value!r}" + + def test_env_non_false_includes_execute_command( + self, tmp_path: Any + ) -> None: + """ALLOW_EXECUTE_COMMAND 设置为非 false 的值时,包含 execute_command""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + for value in ("True", "TRUE", "1", "yes", "anything"): + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": value}): + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool_names = [t.name for t in toolset.tools()] + assert ( + "execute_command" in tool_names + ), f"execute_command should be included for value={value!r}" + + def test_is_execute_command_allowed_static_method(self) -> None: + """直接测试 _is_execute_command_allowed 静态方法""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("ALLOW_EXECUTE_COMMAND", None) + assert SkillLoader._is_execute_command_allowed() is True + + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "true"}): + assert SkillLoader._is_execute_command_allowed() is True + + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "false"}): + assert SkillLoader._is_execute_command_allowed() is False + + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "False"}): + assert SkillLoader._is_execute_command_allowed() is False + + def test_skill_tools_func_respects_env_var(self, tmp_path: Any) -> None: + """测试顶层 skill_tools() 函数也受环境变量控制""" + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with patch.dict(os.environ, {"ALLOW_EXECUTE_COMMAND": "false"}): + toolset = skill_tools(skills_dir=skills_dir) + tool_names = [t.name for t in toolset.tools()] + assert "execute_command" not in tool_names + assert "load_skills" in tool_names + assert "read_skill_file" in tool_names