Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ async def main():

logger = logging.getLogger(__name__)

DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS = 1.0

LifespanResultT = TypeVar("LifespanResultT", default=Any)


Expand Down Expand Up @@ -347,6 +349,13 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
# When True, treat read EOF as a half-close and allow in-flight handlers
# to drain their responses via the still-open write stream (e.g. stdio
# with bash-redirected stdin).
drain_on_read_close: bool = False,
# Maximum time to wait for in-flight handlers to drain after read EOF.
# Only consulted when drain_on_read_close is True; None means wait
# indefinitely.
read_eof_drain_timeout_seconds: float | None = DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -356,6 +365,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
close_write_stream_on_read_close=not drain_on_read_close,
)
)

Expand All @@ -378,11 +388,19 @@ async def run(
raise_exceptions,
)
finally:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()
if drain_on_read_close:
if read_eof_drain_timeout_seconds is not None:
with anyio.move_on_after(read_eof_drain_timeout_seconds) as drain_scope:
while session.has_in_flight_requests:
await anyio.sleep(0.01)
if drain_scope.cancelled_caught:
tg.cancel_scope.cancel()
else:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()

async def _handle_message(
self,
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_on_read_close=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def __init__(
write_stream: WriteStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
close_write_stream_on_read_close: bool = True,
) -> None:
super().__init__(read_stream, write_stream)
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=close_write_stream_on_read_close)
self._stateless = stateless
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
Expand Down
20 changes: 19 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,30 @@ def __init__(
write_stream: WriteStream[SessionMessage],
# If none, reading will never time out
read_timeout_seconds: float | None = None,
# When True, closing/EOF on the read stream closes the write stream too.
#
# For full-duplex transports (e.g., stdio), an input EOF can be a
# half-close: the peer is done sending requests but still expects
# responses on the output stream. In that case, callers may opt out so
# in-flight handlers can drain their responses before shutdown.
close_write_stream_on_read_close: bool = True,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._session_read_timeout_seconds = read_timeout_seconds
self._close_write_stream_on_read_close = close_write_stream_on_read_close
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()
self._exit_stack.push_async_callback(self._read_stream.aclose)
self._exit_stack.push_async_callback(self._write_stream.aclose)

@property
def has_in_flight_requests(self) -> bool:
    """Report whether at least one received request has not been answered yet.

    Returns:
        True while the in-flight request table holds any entry, False once
        it is empty.
    """
    return len(self._in_flight) > 0

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
Expand Down Expand Up @@ -331,7 +346,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_close:
await stack.enter_async_context(self._write_stream)
try:

async def _handle_session_message(message: SessionMessage) -> None:
Expand Down
211 changes: 196 additions & 15 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
InitializeRequestParams,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Expand Down Expand Up @@ -100,17 +101,79 @@ async def first_request():


@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_on_transport_close():
"""When the transport closes mid-request, server.run() must cancel in-flight
handlers rather than join on them.
async def test_server_drains_in_flight_handlers_on_transport_read_eof():
"""When the transport's read side hits EOF (e.g., stdio stdin closes), the
server must drain already-started handlers so their responses reach the
peer via the still-open write side."""
handler_started = anyio.Event()
handler_allowed_to_finish = anyio.Event()
server_run_returned = anyio.Event()

Without the cancel, the task group waits for the handler, which then tries
to respond through a write stream that _receive_loop already closed,
raising ClosedResourceError and crashing server.run() with exit code 1.
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
await handler_allowed_to_finish.wait()
return CallToolResult(content=[TextContent(type="text", text="ok")])

This drives server.run() with raw memory streams because InMemoryTransport
wraps it in its own finally-cancel (_memory.py) which masks the bug.
"""
server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(
server_read,
server_write,
server.create_initialization_options(),
drain_on_read_close=True,
read_eof_drain_timeout_seconds=None,
)
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(5):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive() # init response
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()

# Close the server's input stream — this is what stdin EOF does.
# With drain_on_read_close=True the write side stays open, so the
# in-flight handler is not cancelled and can still send its response.
await to_server.aclose()

handler_allowed_to_finish.set()

response = await from_server.receive()
assert isinstance(response.message, JSONRPCResponse)
assert response.message.id == 2

await server_run_returned.wait()


@pytest.mark.anyio
async def test_server_bounds_drain_on_read_eof_when_handler_never_finishes():
handler_started = anyio.Event()
handler_cancelled = anyio.Event()
server_run_returned = anyio.Event()
Expand All @@ -121,14 +184,135 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
await anyio.sleep_forever()
finally:
handler_cancelled.set()
# unreachable: sleep_forever only exits via cancellation
raise AssertionError # pragma: no cover

server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(
server_read,
server_write,
server.create_initialization_options(),
drain_on_read_close=True,
read_eof_drain_timeout_seconds=0.05,
)
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(2):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive() # init response
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
await to_server.aclose()

await server_run_returned.wait()

assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
    """Cancelling the task that runs server.run() (e.g. on KeyboardInterrupt)
    also cancels any in-flight request handler. That cancellation must be
    re-raised so the task group can unwind cleanly."""
    handler_entered = anyio.Event()
    run_finished = anyio.Event()
    server_scope = anyio.CancelScope()

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        # Signal the test that the handler is running, then block until cancelled.
        handler_entered.set()
        await anyio.sleep_forever()
        raise AssertionError  # pragma: no cover

    server = Server("test", on_call_tool=handle_call_tool)

    # Raw memory streams stand in for the transport: one pair per direction.
    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        # The dedicated cancel scope lets the test cancel only server.run();
        # the event is set whether run() returns or is cancelled.
        try:
            with server_scope:
                await server.run(server_read, server_write, server.create_initialization_options())
        finally:
            run_finished.set()

    # Messages for the handshake followed by a tool call that never completes.
    init_params = InitializeRequestParams(
        protocol_version=LATEST_PROTOCOL_VERSION,
        capabilities=ClientCapabilities(),
        client_info=Implementation(name="test", version="1.0"),
    )
    init_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=init_params.model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
    call_params = CallToolRequestParams(name="slow", arguments={})
    call_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=2,
        method="tools/call",
        params=call_params.model_dump(by_alias=True, mode="json"),
    )

    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(run_server)

            await to_server.send(SessionMessage(init_req))
            await from_server.receive()  # init response
            await to_server.send(SessionMessage(initialized))
            await to_server.send(SessionMessage(call_req))

            # Cancel the server only once the handler is known to be in flight.
            await handler_entered.wait()
            server_scope.cancel()
            await run_finished.wait()


@pytest.mark.anyio
async def test_server_drops_response_when_write_stream_closes_mid_request():
"""If the write side closes while a handler is in-flight, responding may
raise (ClosedResourceError/BrokenResourceError). The handler task should
exit without crashing the server."""
handler_started = anyio.Event()
allow_finish = anyio.Event()
server_run_returned = anyio.Event()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
await allow_finish.wait()
return CallToolResult(content=[TextContent(type="text", text="ok")])

server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(server_read, server_write, server.create_initialization_options())
server_run_returned.set()
Expand Down Expand Up @@ -161,16 +345,13 @@ async def run_server():
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
await server_write.aclose()

# Close the server's input stream — this is what stdin EOF does.
# server.run()'s incoming_messages loop ends, finally-cancel fires,
# handler gets CancelledError, server.run() returns.
allow_finish.set()
await to_server.aclose()

await server_run_returned.wait()

assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
Expand Down
Loading
Loading