Skip to content

Add support for MCP's Streamable HTTP transport #1716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

import base64
import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import Any

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.shared.message import SessionMessage
from mcp.types import (
BlobResourceContents,
EmbeddedResource,
ImageContent,
JSONRPCMessage,
LoggingLevel,
TextContent,
TextResourceContents,
Expand All @@ -28,8 +30,8 @@

try:
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
except ImportError as _import_error:
raise ImportError(
'Please install the `mcp` package to use the MCP server, '
Expand All @@ -48,16 +50,16 @@ class MCPServer(ABC):
is_running: bool = False

_client: ClientSession
_read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
_write_stream: MemoryObjectSendStream[JSONRPCMessage]
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
_write_stream: MemoryObjectSendStream[SessionMessage]
_exit_stack: AsyncExitStack

@abstractmethod
@asynccontextmanager
async def client_streams(
self,
) -> AsyncIterator[
tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]:
"""Create the streams for the MCP server."""
raise NotImplementedError('MCP Server subclasses must implement this method.')
Expand Down Expand Up @@ -227,7 +229,7 @@ async def main():
async def client_streams(
self,
) -> AsyncIterator[
tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]:
server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd)
async with stdio_client(server=server) as (read_stream, write_stream):
Expand All @@ -241,11 +243,8 @@ def _get_log_level(self) -> LoggingLevel | None:
class MCPServerHTTP(MCPServer):
"""An MCP server that connects over streamable HTTP connections.

This class implements the SSE transport from the MCP specification.
See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.

The name "HTTP" is used since this implemented will be adapted in future to use the new
[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
This class implements the Streamable HTTP transport from the MCP specification.
See <https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http> for more information.

!!! note
Using this class as an async context manager will create a new pool of HTTP connections to connect
Expand All @@ -256,7 +255,7 @@ class MCPServerHTTP(MCPServer):
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerHTTP

server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
server = MCPServerHTTP('http://localhost:3001/mcp') # (1)!
agent = Agent('openai:gpt-4o', mcp_servers=[server])

async def main():
Expand All @@ -269,27 +268,27 @@ async def main():
"""

url: str
"""The URL of the SSE endpoint on the MCP server.
"""The URL of the SSE or MCP endpoint on the MCP server.

For example for a server running locally, this might be `http://localhost:3001/sse`.
For example for a server running locally, this might be `http://localhost:3001/mcp`.
"""

headers: dict[str, Any] | None = None
"""Optional HTTP headers to be sent with each request to the SSE endpoint.
"""Optional HTTP headers to be sent with each request to the endpoint.

These headers will be passed directly to the underlying `httpx.AsyncClient`.
Useful for authentication, custom headers, or other HTTP-specific configurations.
"""

timeout: float = 5
"""Initial connection timeout in seconds for establishing the SSE connection.
timeout: timedelta | float = timedelta(seconds=5)
"""Initial connection timeout as a timedelta for establishing the connection.

This timeout applies to the initial connection setup and handshake.
If the connection cannot be established within this time, the operation will fail.
"""

sse_read_timeout: float = 60 * 5
"""Maximum time in seconds to wait for new SSE messages before timing out.
sse_read_timeout: timedelta | float = timedelta(minutes=5)
"""Maximum time as a timedelta to wait for new SSE messages before timing out.

This timeout applies to the long-lived SSE connection after it's established.
If no new messages are received within this time, the connection will be considered stale
Expand All @@ -303,15 +302,48 @@ async def main():
If `None`, no log level will be set.
"""

def __post_init__(self):
    """Convert legacy numeric timeouts to `timedelta`, warning that they are deprecated.

    `timeout` and `sse_read_timeout` are each checked independently: any value
    that is not already a `timedelta` is interpreted as a number of seconds,
    a `DeprecationWarning` is emitted, and the attribute is replaced in place
    by the equivalent `timedelta`.
    """
    # NOTE(review): stacklevel=3 assumes this hook is invoked by the
    # dataclass-generated `__init__` (frames: user code -> __init__ ->
    # __post_init__). With stacklevel=2 the warning is attributed to the
    # generated `__init__` rather than to the code that passed the float, so
    # Python's default filters (which only surface DeprecationWarning when it
    # is attributed to `__main__`) would hide it from the user it targets.
    if not isinstance(self.timeout, timedelta):
        warnings.warn(
            'Passing timeout as a float has been deprecated, please use a timedelta instead.',
            DeprecationWarning,
            stacklevel=3,
        )
        self.timeout = timedelta(seconds=self.timeout)

    if not isinstance(self.sse_read_timeout, timedelta):
        warnings.warn(
            'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
            DeprecationWarning,
            stacklevel=3,
        )
        self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)

@asynccontextmanager
async def client_streams(
self,
) -> AsyncIterator[
tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]: # pragma: no cover
async with sse_client(
if not isinstance(self.timeout, timedelta):
warnings.warn(
'Passing timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.timeout = timedelta(seconds=self.timeout)

if not isinstance(self.sse_read_timeout, timedelta):
warnings.warn(
'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)

async with streamablehttp_client(
url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
) as (read_stream, write_stream):
) as (read_stream, write_stream, _):
yield read_stream, write_stream

def _get_log_level(self) -> LoggingLevel | None:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.6.0; python_version >= '3.10'"]
mcp = ["mcp>=1.8.0; python_version >= '3.10'"]
# Evals
evals = ["pydantic-evals=={{ version }}"]

Expand Down
43 changes: 30 additions & 13 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the MCP (Model Context Protocol) server implementation."""

from datetime import timedelta
from pathlib import Path

import pytest
Expand Down Expand Up @@ -62,25 +63,41 @@ async def test_stdio_server_with_cwd():
assert len(tools) == 10


def test_sse_server():
sse_server = MCPServerHTTP(url='http://localhost:8000/sse')
assert sse_server.url == 'http://localhost:8000/sse'
assert sse_server._get_log_level() is None # pyright: ignore[reportPrivateUsage]
def test_http_server():
http_server = MCPServerHTTP(url='http://localhost:8000/sse')
assert http_server.url == 'http://localhost:8000/sse'
assert http_server._get_log_level() is None # pyright: ignore[reportPrivateUsage]


def test_sse_server_with_header_and_timeout():
sse_server = MCPServerHTTP(
def test_http_server_with_header_and_timeout():
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
timeout=timedelta(seconds=10),
sse_read_timeout=timedelta(seconds=100),
log_level='info',
)
assert sse_server.url == 'http://localhost:8000/sse'
assert sse_server.headers is not None and sse_server.headers['my-custom-header'] == 'my-header-value'
assert sse_server.timeout == 10
assert sse_server.sse_read_timeout == 100
assert sse_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]
assert http_server.url == 'http://localhost:8000/sse'
assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value'
assert http_server.timeout == timedelta(seconds=10)
assert http_server.sse_read_timeout == timedelta(seconds=100)
assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]


def test_http_server_with_deprecated_arguments():
    """Numeric timeouts still construct a server, but warn and are stored as timedeltas."""
    endpoint = 'http://localhost:8000/sse'
    custom_headers = {'my-custom-header': 'my-header-value'}

    # Plain numbers for the two timeouts are the deprecated calling convention.
    with pytest.warns(DeprecationWarning):
        server = MCPServerHTTP(
            url=endpoint,
            headers=custom_headers,
            timeout=10,
            sse_read_timeout=100,
            log_level='info',
        )

    assert server.url == 'http://localhost:8000/sse'
    assert server.headers is not None and server.headers['my-custom-header'] == 'my-header-value'
    # The numeric values are interpreted as seconds.
    assert server.timeout == timedelta(seconds=10)
    assert server.sse_read_timeout == timedelta(seconds=100)
    assert server._get_log_level() == 'info'  # pyright: ignore[reportPrivateUsage]


@pytest.mark.vcr()
Expand Down
11 changes: 6 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.