diff --git a/mcp-run-python/deno.json b/mcp-run-python/deno.jsonc similarity index 98% rename from mcp-run-python/deno.json rename to mcp-run-python/deno.jsonc index cbe71d74a..9e079919e 100644 --- a/mcp-run-python/deno.json +++ b/mcp-run-python/deno.jsonc @@ -32,7 +32,7 @@ "src/*.ts", "src/prepareEnvCode.ts", // required to override gitignore "README.md", - "deno.json" + "deno.jsonc" ] } } diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index 6eb051f93..ba46f7d72 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -15,26 +15,29 @@ const VERSION = '0.0.13' export async function main() { const { args } = Deno - if (args.length === 1 && args[0] === 'stdio') { - await runStdio() - } else if (args.length >= 1 && args[0] === 'sse') { - const flags = parseArgs(Deno.args, { - string: ['port'], - default: { port: '3001' }, - }) - const port = parseInt(flags.port) - runSse(port) - } else if (args.length === 1 && args[0] === 'warmup') { - await warmup() + const flags = parseArgs(args, { + string: ['port', 'callbacks'], + default: { port: '3001' }, + }) + const { _: [task], callbacks, port } = flags + if (task === 'stdio') { + await runStdio(callbacks) + } else if (task === 'sse' || task === 'http') { + runSse(parseInt(port), callbacks) + } else if (task === 'warmup') { + await warmup(callbacks) } else { console.error( `\ Invalid arguments. -Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|sse|warmup] +Usage: + deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto \\ + jsr:@pydantic/mcp-run-python [stdio|sse|warmup] options: - --port <port> Port to run the SSE server on (default: 3001)`, + --port <port> Port to run the SSE server on (default: 3001). + --callbacks <python-signatures> Python code representing the signatures of client functions the server can call.`, ) Deno.exit(1) } @@ -43,7 +46,8 @@ options: /* * Create an MCP server with the `run_python_code` tool registered. */ -function createServer(): McpServer { +function createServer(callbacks?: string): McpServer { + const functions = _extractFunctions(callbacks) const server = new McpServer( { name: 'MCP Run Python', @@ -57,20 +61,39 @@ function createServer(): McpServer { }, ) - const toolDescription = `Tool to execute Python code and return stdout, stderr, and return value. + let toolDescription = `Tool to execute Python code and return stdout, stderr, and return value. -The code may be async, and the value on the last line will be returned as the return value. +The code may be async, and the value on the last line will be returned as the return. The code will be executed with Python 3.12. -Dependencies may be defined via PEP 723 script metadata, e.g. to install "pydantic", the script should start -with a comment of the form: +Dependencies may be defined via PEP 723 script metadata. + +To make HTTP requests, you must use the "httpx" library in async mode. +For example: + +\`\`\`python # /// script -# dependencies = ['pydantic'] +# dependencies = ['httpx'] # /// -print('python code here') +import httpx + +async with httpx.AsyncClient() as client: + response = await client.get('https://example.com') +# return the text of the page +response.text +\`\`\` ` + if (callbacks) { + toolDescription += ` +The following functions are already defined globally and available to call from within your code: + +\`\`\`python +${callbacks} +\`\`\` + ` + } let setLogLevel: LoggingLevel = 'emergency' @@ -85,29 +108,46 @@ print('python code here') { python_code: z.string().describe('Python code to run') }, async ({ python_code }: { python_code: string }) => { const logPromises: Promise<void>[] = [] - const result = await runCode([{ + const mainPy = { name: 'main.py', content: python_code, active: true, - }], (level, data) => { + } + const codeLog = (level: LoggingLevel, data: string) => { if (LogLevels.indexOf(level) >= LogLevels.indexOf(setLogLevel)) { logPromises.push(server.server.sendLoggingMessage({ level, data })) } - }) + } + async function clientCallback(func: string, args?: string, kwargs?: string) { + const { content } = await server.server.createMessage({ + messages: [], + maxTokens: 0, + systemPrompt: '', + metadata: { pydantic_custom_use: '__python_function_call__', func, args, kwargs }, + }) + if (content.type !== 'text') { + throw new Error('Expected return content type to be "text"') + } else { + return content.text + } + } + + const result = await runCode([mainPy], codeLog, functions, clientCallback) await Promise.all(logPromises) return { content: [{ type: 'text', text: asXml(result) }], } }, ) + return server } /* * Run the MCP server using the SSE transport, e.g. over HTTP. */ -function runSse(port: number) { - const mcpServer = createServer() +function runSse(port: number, callbacks?: string) { + const mcpServer = createServer(callbacks) const transports: { [sessionId: string]: SSEServerTransport } = {} const server = http.createServer(async (req, res) => { @@ -162,8 +202,8 @@ function runSse(port: number) { /* * Run the MCP server using the Stdio transport. */ -async function runStdio() { - const mcpServer = createServer() +async function runStdio(callbacks?: string) { + const mcpServer = createServer(callbacks) const transport = new StdioServerTransport() await mcpServer.connect(transport) } @@ -171,7 +211,11 @@ async function runStdio() { /* * Run pyodide to download packages which can otherwise interrupt the server */ -async function warmup() { +async function warmup(callbacks?: string) { + if (callbacks) { + const functions = _extractFunctions(callbacks) + console.error(`Functions extracted from callbacks: ${JSON.stringify(functions)}`) + } console.error( `Running warmup script for MCP Run Python version ${VERSION}...`, ) @@ -193,6 +237,10 @@ a console.log('\nwarmup successful 🎉') } +function _extractFunctions(callbacks?: string): string[] { + return callbacks ? [...callbacks.matchAll(/^async def (\w+)/g).map(([, f]) => f)] : [] +} + // list of log levels to use for level comparison const LogLevels: LoggingLevel[] = [ 'debug', diff --git a/mcp-run-python/src/prepare_env.py b/mcp-run-python/src/prepare_env.py index e22db9ca7..32422c296 100644 --- a/mcp-run-python/src/prepare_env.py +++ b/mcp-run-python/src/prepare_env.py @@ -10,15 +10,17 @@ import re import sys import traceback -from collections.abc import Iterable, Iterator +from collections.abc import Awaitable, Iterable, Iterator from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, TypedDict +from typing import Any, Callable, Literal, TypedDict import micropip import pyodide_js import tomllib +from pydantic import ConfigDict, TypeAdapter +from pydantic_core import to_json from pyodide.code import find_imports __all__ = 'prepare_env', 'dump_json' @@ -31,18 +33,18 @@ class File(TypedDict): @dataclass -class Success: +class PrepSuccess: dependencies: list[str] | None kind: Literal['success'] = 'success' @dataclass -class Error: +class PrepError: message: str kind: Literal['error'] = 'error' -async def prepare_env(files: list[File]) -> Success | Error: +async def prepare_env(files: list[File]) -> PrepSuccess | PrepError: sys.setrecursionlimit(400) cwd = Path.cwd() @@ -68,14 +70,12 @@ async def prepare_env(files: list[File]) -> Success | Error: except Exception: with open(logs_filename) as f: logs = f.read() - return Error(message=f'{logs} {traceback.format_exc()}') + return PrepError(message=f'{logs} {traceback.format_exc()}') - return Success(dependencies=dependencies) + return PrepSuccess(dependencies=dependencies) def dump_json(value: Any) -> str | None: - from pydantic_core import to_json - if value is None: return None if isinstance(value, str): @@ -84,6 +84,50 @@ def dump_json(value: Any) -> str | None: return to_json(value, indent=2, fallback=_json_fallback).decode() +class CallSuccess(TypedDict): + return_value: Any + kind: Literal['success'] + + +class CallError(TypedDict): + exc_type: str + message: str + kind: Literal['error'] + + +call_result_ta: TypeAdapter[CallSuccess | CallError] = TypeAdapter( + CallSuccess | CallError, config=ConfigDict(defer_build=True) +) + + +@dataclass(slots=True) +class RegisterFunction: + _func_name: str + _callback: Callable[[str, str | None, str | None], Awaitable[str]] + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: + result_json = await self._callback(self._func_name, _dump_args(args), _dump_args(kwargs)) + result = call_result_ta.validate_json(result_json) + if result['kind'] == 'success': + return result['return_value'] + + exc_type, message = result['exc_type'], result['message'] + try: + exc_type_ = __builtins__[exc_type] + except KeyError: + raise Exception(f'{message}\n(Raised exception type: {exc_type})') + else: + raise exc_type_(message) + + def __repr__(self) -> str: + return f'<client callback {self._func_name}>' + + +def _dump_args(value: Any) -> str | None: + if value: + return to_json(value, fallback=_json_fallback).decode() + + def _json_fallback(value: Any) -> Any: tp: Any = type(value) module = tp.__module__ @@ -95,7 +139,7 @@ def _json_fallback(value: Any) -> Any: elif module == 'pyodide.ffi': return value.to_py() else: - return repr(value) + return str(value) def _add_extra_dependencies(dependencies: list[str]) -> list[str]: diff --git a/mcp-run-python/src/runCode.ts b/mcp-run-python/src/runCode.ts index fdb1d7084..92cc5248e 100644 --- a/mcp-run-python/src/runCode.ts +++ b/mcp-run-python/src/runCode.ts @@ -1,4 +1,3 @@ -/* eslint @typescript-eslint/no-explicit-any: off */ import { loadPyodide } from 'pyodide' import { preparePythonCode } from './prepareEnvCode.ts' import type { LoggingLevel } from '@modelcontextprotocol/sdk/types.js' @@ -12,6 +11,8 @@ export interface CodeFile { export async function runCode( files: CodeFile[], log: (level: LoggingLevel, data: string) => void, + functionNames?: string[], + clientCallback?: (func: string, args?: string, kwargs?: string) => Promise<string>, ): Promise<RunSuccess | RunError> { // remove once https://github.com/pyodide/pyodide/pull/5514 is released const realConsoleLog = console.log @@ -58,6 +59,14 @@ export async function runCode( const prepareStatus = await preparePyEnv.prepare_env(pyodide.toPy(files)) + const globals: Record<string, unknown> = { __name__: '__main__' } + + if (functionNames && clientCallback) { + for (const functionName of functionNames) { + globals[functionName] = preparePyEnv.RegisterFunction(functionName, clientCallback) + } + } + let runResult: RunSuccess | RunError if (prepareStatus.kind == 'error') { runResult = { @@ -70,7 +79,7 @@ export async function runCode( const activeFile = files.find((f) => f.active)! || files[0] try { const rawValue = await pyodide.runPythonAsync(activeFile.content, { - globals: pyodide.toPy({ __name__: '__main__' }), + globals: pyodide.toPy(globals), filename: activeFile.name, }) runResult = { @@ -99,7 +108,7 @@ interface RunSuccess { // we could record stdout and stderr separately, but I suspect simplicity is more important output: string[] dependencies: string[] - returnValueJson: string | null + returnValueJson: string | undefined } interface RunError { @@ -153,6 +162,13 @@ function formatError(err: any): string { / {2}File "\/lib\/python\d+\.zip\/_pyodide\/.*\n {4}.*\n(?: {4,}\^+\n)?/g, '', ) + // remove frames from _prepare_env.py + errStr = errStr.replace( + / {2}File "\/tmp\/mcp_run_python\/_prepare_env.py".*\n {4,}.+\n/g, + '', + ) + // remove trailing newlines + errStr = errStr.replace(/\n+$/, '') return errStr } @@ -164,8 +180,12 @@ interface PrepareError { kind: 'error' message: string } + interface PreparePyEnv { prepare_env: (files: CodeFile[]) => Promise<PrepareSuccess | PrepareError> - // deno-lint-ignore no-explicit-any - dump_json: (value: any) => string | null + RegisterFunction: ( + func_name: string, + callback: (func_name: string, args?: string, kwargs?: string) => Promise<string>, + ) => unknown + dump_json: (value: unknown) => string | undefined } diff --git a/mcp-run-python/test_mcp_servers.py b/mcp-run-python/test_mcp_servers.py index a11a6d5ba..8886bc08e 100644 --- a/mcp-run-python/test_mcp_servers.py +++ b/mcp-run-python/test_mcp_servers.py @@ -19,9 +19,9 @@ DENO_ARGS = [ 'run', '-N', + '--node-modules-dir=auto', '-R=mcp-run-python/node_modules', '-W=mcp-run-python/node_modules', - '--node-modules-dir=auto', 'mcp-run-python/src/main.ts', ] @@ -147,7 +147,6 @@ async def test_list_tools(mcp_session: ClientSession) -> None: print(unknown) ^^^^^^^ NameError: name 'unknown' is not defined - </error>\ """), id='undefined-variable', diff --git a/mcp-run-python/uprev.py b/mcp-run-python/uprev.py new file mode 100644 index 000000000..f66cb0d64 --- /dev/null +++ b/mcp-run-python/uprev.py @@ -0,0 +1,33 @@ +import re +import sys +from pathlib import Path + +if len(sys.argv) != 2: + print('Usage: python uprev.py <new_version>') + sys.exit(1) + +new_version = sys.argv[1] +this_dir = Path(__file__).parent +root_dir = (this_dir / '..').resolve() + +path_regexes = [ + (this_dir / 'deno.jsonc', r'^\s+"version": "(.+?)"'), + (this_dir / 'src/main.ts', "^const VERSION = '(.+?)'"), + (root_dir / 'pydantic_ai_slim/pydantic_ai/mcp_run_python.py', "^MCP_RUN_PYTHON_VERSION = '(.+?)'"), +] + + +if __name__ == '__main__': + for path, regex in path_regexes: + path_pretty = path.relative_to(root_dir) + + def replace_version(m: re.Match[str]) -> str: + version = m.group(1) + print(f'Updated version from {version} to {new_version} in {path_pretty}') + return m.group(0).replace(version, new_version) + + content = path.read_text() + content, count = re.subn(regex, replace_version, content, count=1, flags=re.M) + if count != 1: + raise ValueError(f'Failed to update version in {path}') + path.write_text(content) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index c5dd95f2c..1fff18ea8 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -1,12 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Sequence from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass from pathlib import Path from types import TracebackType -from typing import Any +from typing import Any, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.types import JSONRPCMessage, LoggingLevel @@ -15,10 +15,11 @@ from pydantic_ai.tools import ToolDefinition try: + from mcp import types as mcp_types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client - from mcp.types import CallToolResult + from mcp.shared.context import RequestContext except ImportError as _import_error: raise ImportError( 'Please install the `mcp` package to use the MCP server, ' @@ -27,6 +28,11 @@ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP' +OptionalSamplingFunction = Callable[ + [RequestContext['ClientSession', Any], mcp_types.CreateMessageRequestParams], + Awaitable[mcp_types.CreateMessageResult | mcp_types.ErrorData | None], +] + class MCPServer(ABC): """Base class for attaching agents to MCP servers. @@ -55,8 +61,22 @@ async def client_streams( @abstractmethod def _get_log_level(self) -> LoggingLevel | None: """Get the log level for the MCP server.""" + + @abstractmethod + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + """Maybe get a sampling callback function for this server definition.""" raise NotImplementedError('MCP Server subclasses must implement this method.') + async def _sampling_callback( + self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams + ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: + """MCP sampling callback.""" + if custom_sampling_callback := self._custom_sampling_callback(): + if result := await custom_sampling_callback(context, params): + return result + + raise NotImplementedError('MCP Sampling not yet implemented, except for custom sampling callbacks') + async def list_tools(self) -> list[ToolDefinition]: """Retrieve tools that are currently active on the server. @@ -74,7 +94,7 @@ async def list_tools(self) -> list[ToolDefinition]: for tool in tools.tools ] - async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> mcp_types.CallToolResult: """Call a tool on the server. Args: @@ -90,7 +110,9 @@ async def __aenter__(self) -> Self: self._exit_stack = AsyncExitStack() self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) - client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream) + client = ClientSession( + read_stream=self._read_stream, write_stream=self._write_stream, sampling_callback=self._sampling_callback + ) self._client = await self._exit_stack.enter_async_context(client) await self._client.initialize() @@ -168,6 +190,9 @@ async def main(): cwd: str | Path | None = None """The working directory to use when spawning the process.""" + custom_sampling_callback: OptionalSamplingFunction | None = None + """Optional callback function for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -181,6 +206,9 @@ async def client_streams( def _get_log_level(self) -> LoggingLevel | None: return self.log_level + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + return self.custom_sampling_callback + @dataclass class MCPServerHTTP(MCPServer): @@ -248,6 +276,9 @@ async def main(): If `None`, no log level will be set. """ + custom_sampling_callback: OptionalSamplingFunction | None = None + """Optional callback function for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -261,3 +292,6 @@ async def client_streams( def _get_log_level(self) -> LoggingLevel | None: return self.log_level + + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + return self.custom_sampling_callback diff --git a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py new file mode 100644 index 000000000..6b46b7904 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py @@ -0,0 +1,232 @@ +import ast +import inspect +import subprocess +from collections.abc import AsyncIterator, Awaitable, Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass +from time import time +from typing import Any, Callable, Literal, cast, override + +import anyio +import httpx +import pydantic_core +from mcp import ClientSession, types as mcp_types +from mcp.shared.context import RequestContext +from pydantic import BaseModel, Json +from pydantic._internal._validate_call import ValidateCallWrapper # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import TypedDict + +from .mcp import MCPServerHTTP, MCPServerStdio + +__all__ = 'mcp_run_python_stdio', 'MCPRunPythonHTTP' + +MCP_RUN_PYTHON_VERSION = '0.0.13' +Callback = Callable[..., Awaitable[Any]] + + +def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool = False) -> MCPServerStdio: + """Prepare a server server connection using `'stdio'` transport. + + Args: + callbacks: A sequence of callback functions to be register on the server. + local_code: Whether to run local `mcp-run-python` code, this is mostly used for development and testing. + + Returns: + A server connection definition. + """ + return MCPServerStdio( + 'deno', + args=_deno_args('stdio', callbacks, local_code), + cwd='mcp-run-python' if local_code else None, + custom_sampling_callback=_PythonSamplingCallback(callbacks) if callbacks else None, + ) + + +@dataclass +class MCPRunPythonHTTP: + """Setup for `mcp-run-python` running with HTTP transport.""" + + callbacks: Sequence[Callback] = () + """Callbacks to be registered on the server.""" + port: int = 3001 + """Port to run the server on.""" + local_code: bool = False + """Whether to run local `mcp-run-python` code, this is mostly used for development and testing.""" + + @property + def url(self) -> str: + """URL the server will be run on.""" + return f'http://localhost:{self.port}/sse' + + def server_def(self, url: str | None = None) -> MCPServerHTTP: + """Create a server definition to pass to a pydantic-ai [`Agent`][pydantic_ai.Agent].""" + return MCPServerHTTP( + url or self.url, + custom_sampling_callback=_PythonSamplingCallback(self.callbacks) if self.callbacks else None, + ) + + def run(self) -> None: + """Run the server and block until it is terminated.""" + try: + subprocess.run(self._args(), cwd=self._cwd(), check=True) + except KeyboardInterrupt: + pass + + @asynccontextmanager + async def run_context(self, server_wait_timeout: float | None = 2) -> AsyncIterator[None]: + """Run the server as an async context manager. + + Args: + server_wait_timeout: The timeout in seconds to wait for the server to start, or `None` to not wait. + """ + p = await anyio.open_process(self._args(), cwd=self._cwd(), stdout=None, stderr=None) + async with p: + if server_wait_timeout: + await self.wait_for_server(server_wait_timeout) + yield + p.terminate() + + async def wait_for_server(self, timeout: float = 2): + """Wait for the server to be ready.""" + async with httpx.AsyncClient(timeout=0.01) as client: + start = time() + while True: + try: + await client.head(self.url) + except httpx.RequestError: + if time() - start > timeout: + raise TimeoutError(f'Server did not start within {timeout} seconds') + await anyio.sleep(0.1) + else: + break + + def _args(self) -> list[str]: + return ['deno'] + _deno_args('http', self.callbacks, self.local_code) + ['--port', str(self.port)] + + def _cwd(self) -> str | None: + return 'mcp-run-python' if self.local_code else None + + +def _deno_args(mode: Literal['stdio', 'http'], callbacks: Sequence[Callback], local_code: bool) -> list[str]: + args = [ + 'run', + '-N', + '-R=node_modules', + '-W=node_modules', + '--node-modules-dir=auto', + 'src/main.ts' if local_code else f'jsr:@pydantic/mcp-run-python@{MCP_RUN_PYTHON_VERSION}', + mode, + ] + + if callbacks: + sigs = '\n\n'.join(_callback_signature(cb) for cb in callbacks) + args += ['--callbacks', sigs] + return args + + +def _callback_signature(func: Callback) -> str: + """Extract the signature of a function. + + This simply means getting the source code of the function, and removing the body of the function while keeping the docstring. + """ + source = inspect.getsource(func) + ast_mod = ast.parse(source) + assert isinstance(ast_mod, ast.Module), f'Expected Module, got {type(ast_mod)}' + assert len(ast_mod.body) == 1, f'Expected single function definition, got {len(ast_mod.body)}' + f = ast_mod.body[0] + assert isinstance(f, ast.AsyncFunctionDef), f'Expected an async function, got {type(func)}' + lines = source.splitlines() + e = f.body[0] + # if the first expression is a docstring, keep it and no need for an ellipsis as the body + if isinstance(e, ast.Expr) and isinstance(e.value, ast.Constant) and isinstance(e.value.value, str): + e = f.body[1] + lines = lines[: e.lineno - 1] + else: + lines = lines[: e.lineno - 1] + lines.append(e.col_offset * ' ' + '...') + + # if the function has any decorators, this will remove them. + if f.lineno != 1: + lines = lines[f.lineno - 1 :] + + return '\n'.join(lines) + + +class _PythonSamplingCallback: + def __init__(self, callbacks: Sequence[Callback]): + self.function_lookup: dict[str, ValidateCallWrapper] = {} + for callback in callbacks: + name = callback.__name__ + if name in self.function_lookup: + raise ValueError(f'Duplicate callback name: {name}') + self.function_lookup[name] = ValidateCallWrapper( + callback, # pyright: ignore[reportArgumentType] + None, + False, + None, + ) + + async def __call__( + self, + context: RequestContext[ClientSession, Any], + params: mcp_types.CreateMessageRequestParams, + ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData | None: + if not params.metadata or params.metadata.get('pydantic_custom_use') != '__python_function_call__': + return None + + call_metadata = _PythonCallMetadata.model_validate(params.metadata) + if function_wrapper := self.function_lookup.get(call_metadata.func): + content: _CallSuccess | _CallError + try: + return_value = await function_wrapper.__pydantic_validator__.validate_python(call_metadata.args_kwargs) + except ValueError as e: + # special support for ValueError since it's commonly subclassed, and it's the parent of ValidationError + # TODO we should probably have specific support for other common errors + content = _CallError(exc_type='ValueError', message=str(e), kind='error') + except Exception as e: + content = _CallError(exc_type=e.__class__.__name__, message=str(e), kind='error') + else: + content = _CallSuccess(return_value=return_value, kind='success') + + content_text = pydantic_core.to_json(content, fallback=_json_fallback).decode() + return mcp_types.CreateMessageResult( + role='assistant', content=mcp_types.TextContent(type='text', text=content_text), model='python' + ) + else: + raise LookupError(f'Function `{call_metadata.func}` not found') + + @override + def __repr__(self) -> str: + return f'<_PythonSamplingCallback: {", ".join(map(repr, self.function_lookup))}>' + + +class _PythonCallMetadata(BaseModel): + func: str + args: Json[list[Any]] | None = None # JSON + kwargs: Json[dict[str, Any]] | None = None # JSON + + @property + def args_kwargs(self) -> pydantic_core.ArgsKwargs: + return pydantic_core.ArgsKwargs(tuple(self.args or ()), self.kwargs) + + +class _CallSuccess(TypedDict): + return_value: Any + kind: Literal['success'] + + +class _CallError(TypedDict): + exc_type: str + message: str + kind: Literal['error'] + + +def _json_fallback(value: Any) -> Any: + tp = cast(Any, type(value)) + if tp.__module__ == 'numpy': + if tp.__name__ in {'ndarray', 'matrix'}: + return value.tolist() + else: + return value.item() + else: + return repr(value)