diff --git a/docs/migration.md b/docs/migration.md index 2528f046c..bc79a89eb 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -428,6 +428,34 @@ async def my_tool(x: int, ctx: Context) -> str: The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. +### Tool registration now accepts prebuilt `Tool` objects + +`MCPServer.add_tool()` and `ToolManager.add_tool()` now expect a fully constructed `Tool` instance, matching the resource registration pattern. Build tools with `Tool.from_function(...)` or register them through the `@mcp.tool()` decorator, which still handles construction for you. + +**Before (v1):** + +```python +def add(a: int, b: int) -> int: + return a + b + +mcp.add_tool(add) +``` + +**After (v2):** + +```python +from mcp.server.mcpserver.tools import Tool + + +def add(a: int, b: int) -> int: + return a + b + + +mcp.add_tool(Tool.from_function(add)) +``` + +If you need to customize the tool metadata before registration, build the `Tool` first and then pass it to `add_tool()`. + ### Registering lowlevel handlers on `MCPServer` (workaround) `MCPServer` does not expose public APIs for `subscribe_resource`, `unsubscribe_resource`, or `set_logging_level` handlers. In v1, the workaround was to reach into the private lowlevel server and use its decorator methods: diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 6f9bb0e28..14f9eed52 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -455,45 +455,13 @@ async def read_resource( # If an exception happens when reading the resource, we should not leak the exception to the client. raise ResourceError(f"Error reading resource {uri}") from exc - def add_tool( - self, - fn: Callable[..., Any], - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - icons: list[Icon] | None = None, - meta: dict[str, Any] | None = None, - structured_output: bool | None = None, - ) -> None: + def add_tool(self, tool: Tool) -> None: """Add a tool to the server. - The tool function can optionally request a Context object by adding a parameter - with the Context type annotation. See the @tool decorator for examples. - Args: - fn: The function to register as a tool - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - icons: Optional list of icons for the tool - meta: Optional metadata dictionary for the tool - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool + tool: A Tool instance to add """ - self._tool_manager.add_tool( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - icons=icons, - meta=meta, - structured_output=structured_output, - ) + self._tool_manager.add_tool(tool) def remove_tool(self, name: str) -> None: """Remove a tool from the server by name. @@ -562,7 +530,7 @@ async def async_tool(x: int, context: Context) -> str: ) def decorator(fn: _CallableT) -> _CallableT: - self.add_tool( + tool = Tool.from_function( fn, name=name, title=title, @@ -572,6 +540,7 @@ def decorator(fn: _CallableT) -> _CallableT: meta=meta, structured_output=structured_output, ) + self.add_tool(tool) return fn return decorator diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 754313eb8..b0385c544 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -4,7 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter @@ -36,6 +36,12 @@ class Tool(BaseModel): icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool") + @field_validator("name") + @classmethod + def validate_name(cls, name: str) -> str: + validate_and_warn_tool_name(name) + return name + @cached_property def output_schema(self) -> dict[str, Any] | None: return self.fn_metadata.output_schema @@ -56,8 +62,6 @@ def from_function( """Create a Tool from a function.""" func_name = name or fn.__name__ - validate_and_warn_tool_name(func_name) - if func_name == "": raise ValueError("You must provide a name for lambda functions") diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed54797..1f1c46663 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -1,12 +1,10 @@ from __future__ import annotations -from collections.abc import Callable from typing import TYPE_CHECKING, Any from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools.base import Tool from mcp.server.mcpserver.utilities.logging import get_logger -from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT @@ -25,13 +23,10 @@ def __init__( tools: list[Tool] | None = None, ): self._tools: dict[str, Tool] = {} + self.warn_on_duplicate_tools = warn_on_duplicate_tools if tools is not None: for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool - - self.warn_on_duplicate_tools = warn_on_duplicate_tools + self.add_tool(tool) def get_tool(self, name: str) -> Tool | None: """Get tool by name.""" @@ -43,26 +38,9 @@ def list_tools(self) -> list[Tool]: def add_tool( self, - fn: Callable[..., Any], - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - icons: list[Icon] | None = None, - meta: dict[str, Any] | None = None, - structured_output: bool | None = None, + tool: Tool, ) -> Tool: - """Add a tool to the server.""" - tool = Tool.from_function( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - icons=icons, - meta=meta, - structured_output=structured_output, - ) + """Add a tool to the manager.""" existing = self._tools.get(tool.name) if existing: if self.warn_on_duplicate_tools: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 062b47d0f..5b8e3019b 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -349,6 +349,9 @@ def _try_create_model_and_schema( elif isinstance(type_expr, GenericAlias): origin = get_origin(type_expr) + if origin in (list, tuple, set, frozenset, Sequence) and _annotation_contains_any(type_expr): + return None, None, False + # Special case: dict with string keys can use RootModel if origin is dict: args = get_args(type_expr) @@ -474,6 +477,18 @@ def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]: return create_model(model_name, result=annotation) +def _annotation_contains_any(annotation: Any) -> bool: + """Return True if a type annotation contains `Any` anywhere within it.""" + if annotation is Any: + return True + + origin = get_origin(annotation) + if origin is None: + return False + + return any(_annotation_contains_any(arg) for arg in get_args(annotation) if arg is not Ellipsis) + + def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: """Create a RootModel for dict[str, T] types.""" # TODO(Marcelo): We should not rely on RootModel for this. diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index c57d1ee9f..ea4eb9f5a 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -674,6 +674,9 @@ def func_list_str() -> list[str]: # pragma: no cover def func_dict_str_int() -> dict[str, int]: # pragma: no cover return {"a": 1, "b": 2} + def func_list_any() -> list[Any]: # pragma: no cover + return ["a", "b", "c"] + def func_union() -> str | int: # pragma: no cover return "hello" @@ -689,6 +692,10 @@ def func_optional() -> str | None: # pragma: no cover "title": "func_list_strOutput", } + # Test list[Any] - should stay unstructured because it can contain arbitrary non-serializable values + meta = func_metadata(func_list_any) + assert meta.output_schema is None + # Test dict[str, int] - should NOT be wrapped meta = func_metadata(func_dict_str_int) assert meta.output_schema == { diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 49b6deb4b..b1bc05f2c 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,4 +1,5 @@ import base64 +from contextlib import asynccontextmanager from pathlib import Path from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -16,6 +17,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage from mcp.server.mcpserver.resources import FileResource, FunctionResource +from mcp.server.mcpserver.tools import Tool from mcp.server.mcpserver.utilities.types import Audio, Image from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError @@ -74,6 +76,44 @@ def test_dependencies(self): mcp_no_deps = MCPServer("test") assert mcp_no_deps.dependencies == [] + def test_run_dispatches_to_stdio(self, monkeypatch: pytest.MonkeyPatch): + mcp = MCPServer("test") + captured: dict[str, Any] = {} + + def fake_anyio_run(fn: Any, *args: Any, **kwargs: Any) -> None: + captured["fn"] = fn + captured["args"] = args + captured["kwargs"] = kwargs + + monkeypatch.setattr(MCPServer.run.__globals__["anyio"], "run", fake_anyio_run) + + mcp.run() + + assert captured["fn"] == mcp.run_stdio_async + assert captured["args"] == () + assert captured["kwargs"] == {} + + @pytest.mark.anyio + async def test_run_stdio_async_uses_stdio_server(self, monkeypatch: pytest.MonkeyPatch): + mcp = MCPServer("test") + read_stream = object() + write_stream = object() + init_options = object() + lowlevel_run = AsyncMock() + + mcp._lowlevel_server.run = lowlevel_run # type: ignore[method-assign] + monkeypatch.setattr(mcp._lowlevel_server, "create_initialization_options", lambda: init_options) + + @asynccontextmanager + async def fake_stdio_server(): + yield read_stream, write_stream + + monkeypatch.setitem(MCPServer.run_stdio_async.__globals__, "stdio_server", fake_stdio_server) + + await mcp.run_stdio_async() + + lowlevel_run.assert_awaited_once_with(read_stream, write_stream, init_options) + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") @@ -239,20 +279,21 @@ def mixed_content_tool_fn() -> list[ContentBlock]: class TestServerTools: async def test_add_tool(self): mcp = MCPServer() - mcp.add_tool(tool_fn) - mcp.add_tool(tool_fn) + tool = Tool.from_function(tool_fn) + mcp.add_tool(tool) + mcp.add_tool(tool) assert len(mcp._tool_manager.list_tools()) == 1 async def test_list_tools(self): mcp = MCPServer() - mcp.add_tool(tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 async def test_call_tool(self): mcp = MCPServer() - mcp.add_tool(tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) async with Client(mcp) as client: result = await client.call_tool("my_tool", {"arg1": "value"}) assert not hasattr(result, "error") @@ -260,7 +301,7 @@ async def test_call_tool(self): async def test_tool_exception_handling(self): mcp = MCPServer() - mcp.add_tool(error_tool_fn) + mcp.add_tool(Tool.from_function(error_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) assert len(result.content) == 1 @@ -271,7 +312,7 @@ async def test_tool_exception_handling(self): async def test_tool_error_handling(self): mcp = MCPServer() - mcp.add_tool(error_tool_fn) + mcp.add_tool(Tool.from_function(error_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) assert len(result.content) == 1 @@ -283,7 +324,7 @@ async def test_tool_error_handling(self): async def test_tool_error_details(self): """Test that exception details are properly formatted in the response""" mcp = MCPServer() - mcp.add_tool(error_tool_fn) + mcp.add_tool(Tool.from_function(error_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) content = result.content[0] @@ -294,7 +335,7 @@ async def test_tool_error_details(self): async def test_tool_return_value_conversion(self): mcp = MCPServer() - mcp.add_tool(tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) async with Client(mcp) as client: result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) assert len(result.content) == 1 @@ -311,7 +352,7 @@ async def test_tool_image_helper(self, tmp_path: Path): image_path.write_bytes(b"fake png data") mcp = MCPServer() - mcp.add_tool(image_tool_fn) + mcp.add_tool(Tool.from_function(image_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("image_tool_fn", {"path": str(image_path)}) assert len(result.content) == 1 @@ -331,7 +372,7 @@ async def test_tool_audio_helper(self, tmp_path: Path): audio_path.write_bytes(b"fake wav data") mcp = MCPServer() - mcp.add_tool(audio_tool_fn) + mcp.add_tool(Tool.from_function(audio_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)}) assert len(result.content) == 1 @@ -360,7 +401,7 @@ async def test_tool_audio_helper(self, tmp_path: Path): async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str): """Test that Audio helper correctly detects MIME types from file suffixes""" mcp = MCPServer() - mcp.add_tool(audio_tool_fn) + mcp.add_tool(Tool.from_function(audio_tool_fn)) # Create a test audio file with the specific extension audio_path = tmp_path / filename @@ -379,7 +420,7 @@ async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, async def test_tool_mixed_content(self): mcp = MCPServer() - mcp.add_tool(mixed_content_tool_fn) + mcp.add_tool(Tool.from_function(mixed_content_tool_fn)) async with Client(mcp) as client: result = await client.call_tool("mixed_content_tool_fn", {}) assert len(result.content) == 3 @@ -420,8 +461,8 @@ async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path): # TODO(Marcelo): It seems if we add the proper type hint, it generates an invalid JSON schema. # We need to fix this. - def mixed_list_fn() -> list: # type: ignore - return [ # type: ignore + def mixed_list_fn() -> list[Any]: + return [ "text message", Image(image_path), Audio(audio_path), @@ -430,7 +471,7 @@ def mixed_list_fn() -> list: # type: ignore ] mcp = MCPServer() - mcp.add_tool(mixed_list_fn) # type: ignore + mcp.add_tool(Tool.from_function(mixed_list_fn)) async with Client(mcp) as client: result = await client.call_tool("mixed_list_fn", {}) assert len(result.content) == 5 @@ -472,7 +513,7 @@ def get_user(user_id: int) -> UserOutput: return UserOutput(name="John Doe", age=30) mcp = MCPServer() - mcp.add_tool(get_user) + mcp.add_tool(Tool.from_function(get_user)) async with Client(mcp) as client: # Check that the tool has outputSchema @@ -501,7 +542,7 @@ def calculate_sum(a: int, b: int) -> int: return a + b mcp = MCPServer() - mcp.add_tool(calculate_sum) + mcp.add_tool(Tool.from_function(calculate_sum)) async with Client(mcp) as client: # Check that the tool has outputSchema @@ -527,7 +568,7 @@ def get_numbers() -> list[int]: return [1, 2, 3, 4, 5] mcp = MCPServer() - mcp.add_tool(get_numbers) + mcp.add_tool(Tool.from_function(get_numbers)) async with Client(mcp) as client: result = await client.call_tool("get_numbers", {}) @@ -542,7 +583,7 @@ def get_numbers() -> list[int]: return [1, 2, 3, 4, [5]] # type: ignore mcp = MCPServer() - mcp.add_tool(get_numbers) + mcp.add_tool(Tool.from_function(get_numbers)) async with Client(mcp) as client: result = await client.call_tool("get_numbers", {}) @@ -565,7 +606,7 @@ def get_metadata() -> dict[str, Any]: } mcp = MCPServer() - mcp.add_tool(get_metadata) + mcp.add_tool(Tool.from_function(get_metadata)) async with Client(mcp) as client: # Check schema @@ -600,7 +641,7 @@ def get_settings() -> dict[str, str]: return {"theme": "dark", "language": "en", "timezone": "UTC"} mcp = MCPServer() - mcp.add_tool(get_settings) + mcp.add_tool(Tool.from_function(get_settings)) async with Client(mcp) as client: # Check schema @@ -618,7 +659,7 @@ def get_settings() -> dict[str, str]: async def test_remove_tool(self): """Test removing a tool from the server.""" mcp = MCPServer() - mcp.add_tool(tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) # Verify tool exists assert len(mcp._tool_manager.list_tools()) == 1 @@ -639,8 +680,8 @@ async def test_remove_nonexistent_tool(self): async def test_remove_tool_and_list(self): """Test that a removed tool doesn't appear in list_tools.""" mcp = MCPServer() - mcp.add_tool(tool_fn) - mcp.add_tool(error_tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) + mcp.add_tool(Tool.from_function(error_tool_fn)) # Verify both tools exist async with Client(mcp) as client: @@ -662,7 +703,7 @@ async def test_remove_tool_and_list(self): async def test_remove_tool_and_call(self): """Test that calling a removed tool fails appropriately.""" mcp = MCPServer() - mcp.add_tool(tool_fn) + mcp.add_tool(Tool.from_function(tool_fn)) # Verify tool works before removal async with Client(mcp) as client: @@ -1014,7 +1055,7 @@ async def test_context_detection(self): def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return f"Request {ctx.request_id}: {x}" - tool = mcp._tool_manager.add_tool(tool_with_context) + tool = mcp._tool_manager.add_tool(Tool.from_function(tool_with_context)) assert tool.context_kwarg == "ctx" async def test_context_injection(self): @@ -1025,7 +1066,7 @@ def tool_with_context(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Request {ctx.request_id}: {x}" - mcp.add_tool(tool_with_context) + mcp.add_tool(Tool.from_function(tool_with_context)) async with Client(mcp) as client: result = await client.call_tool("tool_with_context", {"x": 42}) assert len(result.content) == 1 @@ -1042,7 +1083,7 @@ async def async_tool(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Async request {ctx.request_id}: {x}" - mcp.add_tool(async_tool) + mcp.add_tool(Tool.from_function(async_tool)) async with Client(mcp) as client: result = await client.call_tool("async_tool", {"x": 42}) assert len(result.content) == 1 @@ -1062,7 +1103,7 @@ async def logging_tool(msg: str, ctx: Context) -> str: await ctx.error("Error message") return f"Logged messages for {msg}" - mcp.add_tool(logging_tool) + mcp.add_tool(Tool.from_function(logging_tool)) with patch("mcp.server.session.ServerSession.send_log_message") as mock_log: async with Client(mcp) as client: @@ -1085,7 +1126,7 @@ async def test_optional_context(self): def no_context(x: int) -> int: return x * 2 - mcp.add_tool(no_context) + mcp.add_tool(Tool.from_function(no_context)) async with Client(mcp) as client: result = await client.call_tool("no_context", {"x": 21}) assert len(result.content) == 1 diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index e4dfd4ff9..90dbf28d3 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -23,7 +23,7 @@ def sum(a: int, b: int) -> int: # pragma: no cover return a + b manager = ToolManager() - manager.add_tool(sum) + manager.add_tool(Tool.from_function(sum)) tool = manager.get_tool("sum") assert tool is not None @@ -54,13 +54,24 @@ class AddArguments(ArgModelBase): context_kwarg=None, annotations=None, ) - manager = ToolManager(tools=[original_tool]) + duplicate_tool = Tool( + name="sum", + title="Duplicate Tool", + description="Add two numbers.", + fn=sum, + fn_metadata=fn_metadata, + is_async=False, + parameters=AddArguments.model_json_schema(), + context_kwarg=None, + annotations=None, + ) + manager = ToolManager(tools=[original_tool, duplicate_tool]) saved_tool = manager.get_tool("sum") - assert saved_tool == original_tool + assert saved_tool is original_tool # warn on duplicate tools with caplog.at_level(logging.WARNING): - manager = ToolManager(True, tools=[original_tool, original_tool]) + manager = ToolManager(True, tools=[original_tool, duplicate_tool]) assert "Tool already exists: sum" in caplog.text @pytest.mark.anyio @@ -72,7 +83,7 @@ async def fetch_data(url: str) -> str: # pragma: no cover return f"Data from {url}" manager = ToolManager() - manager.add_tool(fetch_data) + manager.add_tool(Tool.from_function(fetch_data)) tool = manager.get_tool("fetch_data") assert tool is not None @@ -93,7 +104,7 @@ def create_user(user: UserInput, flag: bool) -> dict[str, Any]: # pragma: no co return {"id": 1, **user.model_dump()} manager = ToolManager() - manager.add_tool(create_user) + manager.add_tool(Tool.from_function(create_user)) tool = manager.get_tool("create_user") assert tool is not None @@ -115,7 +126,7 @@ def __call__(self, x: int) -> int: # pragma: no cover return x * 2 manager = ToolManager() - tool = manager.add_tool(MyTool()) + tool = manager.add_tool(Tool.from_function(MyTool())) assert tool.name == "MyTool" assert tool.is_async is False assert tool.parameters["properties"]["x"]["type"] == "integer" @@ -132,25 +143,23 @@ async def __call__(self, x: int) -> int: # pragma: no cover return x * 2 manager = ToolManager() - tool = manager.add_tool(MyAsyncTool()) + tool = manager.add_tool(Tool.from_function(MyAsyncTool())) assert tool.name == "MyAsyncTool" assert tool.is_async is True assert tool.parameters["properties"]["x"]["type"] == "integer" def test_add_invalid_tool(self): - manager = ToolManager() with pytest.raises(AttributeError): - manager.add_tool(1) # type: ignore + Tool.from_function(1) # type: ignore[arg-type] def test_add_lambda(self): manager = ToolManager() - tool = manager.add_tool(lambda x: x, name="my_tool") # type: ignore[reportUnknownLambdaType] + tool = manager.add_tool(Tool.from_function(lambda x: x, name="my_tool")) # type: ignore[reportUnknownLambdaType] assert tool.name == "my_tool" def test_add_lambda_with_no_name(self): - manager = ToolManager() with pytest.raises(ValueError, match="You must provide a name for lambda functions"): - manager.add_tool(lambda x: x) # type: ignore[reportUnknownLambdaType] + Tool.from_function(lambda x: x) # type: ignore[reportUnknownLambdaType] def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test warning on duplicate tools.""" @@ -159,9 +168,9 @@ def f(x: int) -> int: # pragma: no cover return x manager = ToolManager() - manager.add_tool(f) + manager.add_tool(Tool.from_function(f)) with caplog.at_level(logging.WARNING): - manager.add_tool(f) + manager.add_tool(Tool.from_function(f)) assert "Tool already exists: f" in caplog.text def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): @@ -171,10 +180,10 @@ def f(x: int) -> int: # pragma: no cover return x manager = ToolManager() - manager.add_tool(f) + manager.add_tool(Tool.from_function(f)) manager.warn_on_duplicate_tools = False with caplog.at_level(logging.WARNING): - manager.add_tool(f) + manager.add_tool(Tool.from_function(f)) assert "Tool already exists: f" not in caplog.text @@ -186,7 +195,7 @@ def sum(a: int, b: int) -> int: return a + b manager = ToolManager() - manager.add_tool(sum) + manager.add_tool(Tool.from_function(sum)) result = await manager.call_tool("sum", {"a": 1, "b": 2}, Context()) assert result == 3 @@ -197,7 +206,7 @@ async def double(n: int) -> int: return n * 2 manager = ToolManager() - manager.add_tool(double) + manager.add_tool(Tool.from_function(double)) result = await manager.call_tool("double", {"n": 5}, Context()) assert result == 10 @@ -211,7 +220,7 @@ def __call__(self, x: int) -> int: return x * 2 manager = ToolManager() - tool = manager.add_tool(MyTool()) + tool = manager.add_tool(Tool.from_function(MyTool())) result = await tool.run({"x": 5}, Context()) assert result == 10 @@ -225,7 +234,7 @@ async def __call__(self, x: int) -> int: return x * 2 manager = ToolManager() - tool = manager.add_tool(MyAsyncTool()) + tool = manager.add_tool(Tool.from_function(MyAsyncTool())) result = await tool.run({"x": 5}, Context()) assert result == 10 @@ -236,7 +245,7 @@ def sum(a: int, b: int = 1) -> int: return a + b manager = ToolManager() - manager.add_tool(sum) + manager.add_tool(Tool.from_function(sum)) result = await manager.call_tool("sum", {"a": 1}, Context()) assert result == 2 @@ -247,7 +256,7 @@ def sum(a: int, b: int) -> int: # pragma: no cover return a + b manager = ToolManager() - manager.add_tool(sum) + manager.add_tool(Tool.from_function(sum)) with pytest.raises(ToolError): await manager.call_tool("sum", {"a": 1}, Context()) @@ -263,7 +272,7 @@ def sum_vals(vals: list[int]) -> int: return sum(vals) manager = ToolManager() - manager.add_tool(sum_vals) + manager.add_tool(Tool.from_function(sum_vals)) # Try both with plain list and with JSON list result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}, Context()) assert result == 6 @@ -276,7 +285,7 @@ def concat_strs(vals: list[str] | str) -> str: return vals if isinstance(vals, str) else "".join(vals) manager = ToolManager() - manager.add_tool(concat_strs) + manager.add_tool(Tool.from_function(concat_strs)) # Try both with plain python object and with JSON list result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}, Context()) assert result == "abc" @@ -300,7 +309,7 @@ def name_shrimp(tank: MyShrimpTank) -> list[str]: return [x.name for x in tank.shrimp] manager = ToolManager() - manager.add_tool(name_shrimp) + manager.add_tool(Tool.from_function(name_shrimp)) result = await manager.call_tool( "name_shrimp", {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}}, @@ -322,7 +331,7 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover return a manager = ToolManager() - tool = manager.add_tool(something) + tool = manager.add_tool(Tool.from_function(something)) assert "ctx" not in json.dumps(tool.parameters) assert "Context" not in json.dumps(tool.parameters) assert "ctx" not in tool.fn_metadata.arg_model.model_fields @@ -339,19 +348,19 @@ def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return str(x) manager = ToolManager() - tool = manager.add_tool(tool_with_context) + tool = manager.add_tool(Tool.from_function(tool_with_context)) assert tool.context_kwarg == "ctx" def tool_without_context(x: int) -> str: # pragma: no cover return str(x) - tool = manager.add_tool(tool_without_context) + tool = manager.add_tool(Tool.from_function(tool_without_context)) assert tool.context_kwarg is None def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, RequestT]) -> str: # pragma: no cover return str(x) - tool = manager.add_tool(tool_with_parametrized_context) + tool = manager.add_tool(Tool.from_function(tool_with_parametrized_context)) assert tool.context_kwarg == "ctx" @pytest.mark.anyio @@ -363,7 +372,7 @@ def tool_with_context(x: int, ctx: Context) -> str: return str(x) manager = ToolManager() - manager.add_tool(tool_with_context) + manager.add_tool(Tool.from_function(tool_with_context)) result = await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) assert result == "42" @@ -377,7 +386,7 @@ async def async_tool(x: int, ctx: Context) -> str: return str(x) manager = ToolManager() - manager.add_tool(async_tool) + manager.add_tool(Tool.from_function(async_tool)) result = await manager.call_tool("async_tool", {"x": 42}, context=Context()) assert result == "42" @@ -390,7 +399,7 @@ def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error") manager = ToolManager() - manager.add_tool(tool_with_context) + manager.add_tool(Tool.from_function(tool_with_context)) with pytest.raises(ToolError, match="Error executing tool tool_with_context"): await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) @@ -411,7 +420,7 @@ def read_data(path: str) -> str: # pragma: no cover ) manager = ToolManager() - tool = manager.add_tool(read_data, annotations=annotations) + tool = manager.add_tool(Tool.from_function(read_data, annotations=annotations)) assert tool.annotations is not None assert tool.annotations.title == "File Reader" @@ -452,7 +461,7 @@ def get_user(user_id: int) -> UserOutput: return UserOutput(name="John", age=30) manager = ToolManager() - manager.add_tool(get_user) + manager.add_tool(Tool.from_function(get_user)) result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @@ -466,7 +475,7 @@ def double_number(n: int) -> int: return 10 manager = ToolManager() - manager.add_tool(double_number) + manager.add_tool(Tool.from_function(double_number)) result = await manager.call_tool("double_number", {"n": 5}, Context()) assert result == 10 result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True) @@ -487,7 +496,7 @@ def get_user_dict(user_id: int) -> UserDict: return UserDict(name="Alice", age=25) manager = ToolManager() - manager.add_tool(get_user_dict) + manager.add_tool(Tool.from_function(get_user_dict)) result = await manager.call_tool("get_user_dict", {"user_id": 1}, Context()) assert result == expected_output @@ -507,7 +516,7 @@ def get_person() -> Person: return Person("Bob", 40) manager = ToolManager() - manager.add_tool(get_person) + manager.add_tool(Tool.from_function(get_person)) result = await manager.call_tool("get_person", {}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == expected_output @@ -524,7 +533,7 @@ def get_numbers() -> list[int]: return expected_list manager = ToolManager() - manager.add_tool(get_numbers) + manager.add_tool(Tool.from_function(get_numbers)) result = await manager.call_tool("get_numbers", {}, Context()) assert result == expected_list result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True) @@ -539,7 +548,7 @@ def get_dict() -> dict[str, Any]: return {"key": "value"} manager = ToolManager() - manager.add_tool(get_dict, structured_output=False) + manager.add_tool(Tool.from_function(get_dict, structured_output=False)) result = await manager.call_tool("get_dict", {}, Context()) assert isinstance(result, dict) assert result == {"key": "value"} @@ -555,7 +564,7 @@ def get_user() -> UserOutput: # pragma: no cover return UserOutput(name="Test", age=25) manager = ToolManager() - tool = manager.add_tool(get_user) + tool = manager.add_tool(Tool.from_function(get_user)) # Test that output_schema is populated expected_schema = { @@ -575,7 +584,7 @@ def get_config() -> dict[str, Any]: return {"debug": True, "port": 8080, "features": ["auth", "logging"]} manager = ToolManager() - tool = manager.add_tool(get_config) + tool = manager.add_tool(Tool.from_function(get_config)) # Check output schema assert tool.output_schema is not None @@ -600,7 +609,7 @@ def get_scores() -> dict[str, int]: return {"alice": 100, "bob": 85, "charlie": 92} manager = ToolManager() - tool = manager.add_tool(get_scores) + tool = manager.add_tool(Tool.from_function(get_scores)) # Check output schema assert tool.output_schema is not None @@ -630,7 +639,7 @@ def process_data(input_data: str) -> str: # pragma: no cover metadata = {"ui": {"type": "form", "fields": ["input"]}, "version": "1.0"} manager = ToolManager() - tool = manager.add_tool(process_data, meta=metadata) + tool = manager.add_tool(Tool.from_function(process_data, meta=metadata)) assert tool.meta is not None assert tool.meta == metadata @@ -645,7 +654,7 @@ def simple_tool(x: int) -> int: # pragma: no cover return x * 2 manager = ToolManager() - tool = manager.add_tool(simple_tool) + tool = manager.add_tool(Tool.from_function(simple_tool)) assert tool.meta is None @@ -746,7 +755,7 @@ def complex_tool(data: str) -> str: # pragma: no cover } manager = ToolManager() - tool = manager.add_tool(complex_tool, meta=metadata) + tool = manager.add_tool(Tool.from_function(complex_tool, meta=metadata)) assert tool.meta is not None assert tool.meta["ui"]["components"][0]["validation"]["minLength"] == 5 @@ -762,7 +771,7 @@ def tool_with_empty_meta(x: int) -> int: # pragma: no cover return x manager = ToolManager() - tool = manager.add_tool(tool_with_empty_meta, meta={}) + tool = manager.add_tool(Tool.from_function(tool_with_empty_meta, meta={})) assert tool.meta is not None assert tool.meta == {} @@ -800,7 +809,7 @@ def add(a: int, b: int) -> int: # pragma: no cover return a + b manager = ToolManager() - manager.add_tool(add) + manager.add_tool(Tool.from_function(add)) # Verify tool exists assert manager.get_tool("add") is not None @@ -836,9 +845,9 @@ def divide(a: int, b: int) -> float: # pragma: no cover return a / b manager = ToolManager() - manager.add_tool(add) - manager.add_tool(multiply) - manager.add_tool(divide) + manager.add_tool(Tool.from_function(add)) + manager.add_tool(Tool.from_function(multiply)) + manager.add_tool(Tool.from_function(divide)) # Verify all tools exist assert len(manager.list_tools()) == 3 @@ -864,7 +873,7 @@ def greet(name: str) -> str: return f"Hello, {name}!" manager = ToolManager() - manager.add_tool(greet) + manager.add_tool(Tool.from_function(greet)) # Verify tool works before removal result = await manager.call_tool("greet", {"name": "World"}, Context()) @@ -885,7 +894,7 @@ def test_func() -> str: # pragma: no cover return "test" manager = ToolManager() - manager.add_tool(test_func) + manager.add_tool(Tool.from_function(test_func)) # Verify tool exists assert manager.get_tool("test_func") is not None