From a7bf339953bd3c7aada50744ac672a43f2058fe0 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Mon, 3 Nov 2025 21:41:37 +0530 Subject: [PATCH 01/11] properly propagate contextvars for server protocols --- Lib/asyncio/base_events.py | 9 ++++-- Lib/asyncio/selector_events.py | 50 +++++++++++++++++----------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 8cbb71f708537f..a4afd6e7a5e1f6 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -14,6 +14,7 @@ """ import collections +import contextvars import collections.abc import concurrent.futures import errno @@ -289,6 +290,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None + self._context = contextvars.copy_context() def __repr__(self): return f'<{self.__class__.__name__} sockets={self.sockets!r}>' @@ -318,7 +320,7 @@ def _start_serving(self): self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, self, self._backlog, self._ssl_handshake_timeout, - self._ssl_shutdown_timeout) + self._ssl_shutdown_timeout, context=self._context) def get_loop(self): return self._loop @@ -1211,9 +1213,10 @@ async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_shutdown_timeout=None, context=None): sock.setblocking(False) + context = context if context is not None else contextvars.copy_context() protocol = protocol_factory() waiter = self.create_future() @@ -1225,7 +1228,7 @@ async def _create_connection_transport( ssl_handshake_timeout=ssl_handshake_timeout, ssl_shutdown_timeout=ssl_shutdown_timeout) else: - transport = self._make_socket_transport(sock, protocol, waiter) + transport = self._make_socket_transport(sock, protocol, waiter, context=context) try: await waiter diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index ff7e16df3c6273..85bbf4c7f360c1 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -67,10 +67,10 @@ def __init__(self, selector=None): self._transports = weakref.WeakValueDictionary() def _make_socket_transport(self, sock, protocol, waiter=None, *, - extra=None, server=None): + extra=None, server=None, context=None): self._ensure_fd_no_transport(sock) return _SelectorSocketTransport(self, sock, protocol, waiter, - extra, server) + extra, server, context=context) def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, @@ -159,16 +159,16 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): self._add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, sslcontext, server, backlog, - ssl_handshake_timeout, ssl_shutdown_timeout) + ssl_handshake_timeout, ssl_shutdown_timeout, context) def _accept_connection( self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -204,21 +204,21 @@ def _accept_connection( self._start_serving, protocol_factory, sock, sslcontext, server, backlog, ssl_handshake_timeout, - ssl_shutdown_timeout) + ssl_shutdown_timeout, context) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} accept = self._accept_connection2( protocol_factory, conn, extra, sslcontext, server, - ssl_handshake_timeout, ssl_shutdown_timeout) - self.create_task(accept) + ssl_handshake_timeout, ssl_shutdown_timeout, context=context) + self.create_task(accept, context=context) async def _accept_connection2( self, protocol_factory, conn, extra, sslcontext=None, server=None, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): protocol = None transport = None try: @@ -233,7 +233,7 @@ async def _accept_connection2( else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, - server=server) + server=server, context=context) try: await waiter @@ -275,9 +275,9 @@ def _ensure_fd_no_transport(self, fd): f'File descriptor {fd!r} is used by transport ' f'{transport!r}') - def _add_reader(self, fd, callback, *args): + def _add_reader(self, fd, callback, *args, context=None): self._check_closed() - handle = events.Handle(callback, args, self, None) + handle = events.Handle(callback, args, self, context=context) key = self._selector.get_map().get(fd) if key is None: self._selector.register(fd, selectors.EVENT_READ, @@ -770,7 +770,7 @@ class _SelectorTransport(transports._FlowControlMixin, # exception) _sock = None - def __init__(self, loop, sock, protocol, extra=None, server=None): + def __init__(self, loop, sock, protocol, extra=None, server=None, context=None): super().__init__(extra, loop) self._extra['socket'] = trsock.TransportSocket(sock) try: @@ -784,7 +784,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._extra['peername'] = None self._sock = sock self._sock_fd = sock.fileno() - + self._context = context self._protocol_connected = False self.set_protocol(protocol) @@ -866,7 +866,7 @@ def close(self): if not self._buffer: self._conn_lost += 1 self._loop._remove_writer(self._sock_fd) - self._loop.call_soon(self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None, context=self._context) def __del__(self, _warn=warnings.warn): if self._sock is not None: @@ -899,7 +899,7 @@ def _force_close(self, exc): self._closing = True self._loop._remove_reader(self._sock_fd) self._conn_lost += 1 - self._loop.call_soon(self._call_connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc, context=self._context) def _call_connection_lost(self, exc): try: @@ -921,7 +921,7 @@ def get_write_buffer_size(self): def _add_reader(self, fd, callback, *args): if not self.is_reading(): return - self._loop._add_reader(fd, callback, *args) + self._loop._add_reader(fd, callback, *args, context=self._context) class _SelectorSocketTransport(_SelectorTransport): @@ -930,10 +930,10 @@ class _SelectorSocketTransport(_SelectorTransport): _sendfile_compatible = constants._SendfileMode.TRY_NATIVE def __init__(self, loop, sock, protocol, waiter=None, - extra=None, server=None): - + extra=None, server=None, context=None): + assert context is not None self._read_ready_cb = None - super().__init__(loop, sock, protocol, extra, server) + super().__init__(loop, sock, protocol, extra, server, context) self._eof = False self._empty_waiter = None if _HAS_SENDMSG: @@ -945,14 +945,14 @@ def __init__(self, loop, sock, protocol, waiter=None, # decreases the latency (in some cases significantly.) base_events._set_nodelay(self._sock) - self._loop.call_soon(self._protocol.connection_made, self) + self._loop.call_soon(self._protocol.connection_made, self, context=context) # only start reading when connection_made() has been called self._loop.call_soon(self._add_reader, - self._sock_fd, self._read_ready) + self._sock_fd, self._read_ready, context=context) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, - waiter, None) + waiter, None, context=context) def set_protocol(self, protocol): if isinstance(protocol, protocols.BufferedProtocol): @@ -1081,7 +1081,7 @@ def write(self, data): if not data: return # Not all was written; register write handler. - self._loop._add_writer(self._sock_fd, self._write_ready) + self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) # Add it to the buffer. self._buffer.append(data) @@ -1185,7 +1185,7 @@ def writelines(self, list_of_data): self._write_ready() # If the entire buffer couldn't be written, register a write handler if self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) + self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) self._maybe_pause_protocol() def can_write_eof(self): From 40a8f7c151d78e46bf48e48c5510677d18bafaa6 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Tue, 2 Dec 2025 11:50:42 +0530 Subject: [PATCH 02/11] fix _add_writer --- Lib/asyncio/selector_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 85bbf4c7f360c1..67b2f07c32e4ff 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -309,9 +309,9 @@ def _remove_reader(self, fd): else: return False - def _add_writer(self, fd, callback, *args): + def _add_writer(self, fd, callback, *args, context=None): self._check_closed() - handle = events.Handle(callback, args, self, None) + handle = events.Handle(callback, args, self, context=context) key = self._selector.get_map().get(fd) if key is None: self._selector.register(fd, selectors.EVENT_WRITE, From f845484fc689f6b6d6eb295660ffdfca82c76204 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Tue, 13 Jan 2026 19:48:26 +0530 Subject: [PATCH 03/11] refactor 1 --- Lib/asyncio/selector_events.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 67b2f07c32e4ff..4c620a254285f9 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -866,7 +866,7 @@ def close(self): if not self._buffer: self._conn_lost += 1 self._loop._remove_writer(self._sock_fd) - self._loop.call_soon(self._call_connection_lost, None, context=self._context) + self._call_soon(self._call_connection_lost, None) def __del__(self, _warn=warnings.warn): if self._sock is not None: @@ -899,7 +899,7 @@ def _force_close(self, exc): self._closing = True self._loop._remove_reader(self._sock_fd) self._conn_lost += 1 - self._loop.call_soon(self._call_connection_lost, exc, context=self._context) + self._call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: @@ -923,6 +923,11 @@ def _add_reader(self, fd, callback, *args): return self._loop._add_reader(fd, callback, *args, context=self._context) + def _add_writer(self, fd, callback, *args): + self._loop._add_writer(fd, callback, *args, context=self._context) + + def _call_soon(self, callback, *args): + self._loop.call_soon(callback, *args, context=self._context) class _SelectorSocketTransport(_SelectorTransport): @@ -945,14 +950,12 @@ def __init__(self, loop, sock, protocol, waiter=None, # decreases the latency (in some cases significantly.) base_events._set_nodelay(self._sock) - self._loop.call_soon(self._protocol.connection_made, self, context=context) + self._call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._add_reader, - self._sock_fd, self._read_ready, context=context) + self._call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called - self._loop.call_soon(futures._set_result_unless_cancelled, - waiter, None, context=context) + self._call_soon(futures._set_result_unless_cancelled, waiter, None) def set_protocol(self, protocol): if isinstance(protocol, protocols.BufferedProtocol): @@ -1081,7 +1084,7 @@ def write(self, data): if not data: return # Not all was written; register write handler. - self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) + self._add_writer(self._sock_fd, self._write_ready) # Add it to the buffer. self._buffer.append(data) @@ -1185,7 +1188,7 @@ def writelines(self, list_of_data): self._write_ready() # If the entire buffer couldn't be written, register a write handler if self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) + self._add_writer(self._sock_fd, self._write_ready) self._maybe_pause_protocol() def can_write_eof(self): @@ -1226,14 +1229,12 @@ def __init__(self, loop, sock, protocol, address=None, super().__init__(loop, sock, protocol, extra) self._address = address self._buffer_size = 0 - self._loop.call_soon(self._protocol.connection_made, self) + self._call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._add_reader, - self._sock_fd, self._read_ready) + self._call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called - self._loop.call_soon(futures._set_result_unless_cancelled, - waiter, None) + self._call_soon(futures._set_result_unless_cancelled, waiter, None) def get_write_buffer_size(self): return self._buffer_size @@ -1280,7 +1281,7 @@ def sendto(self, data, addr=None): self._sock.sendto(data, addr) return except (BlockingIOError, InterruptedError): - self._loop._add_writer(self._sock_fd, self._sendto_ready) + self._add_writer(self._sock_fd, self._sendto_ready) except OSError as exc: self._protocol.error_received(exc) return From f556a848abb8b97c3db54bee70748eaf0638e8bb Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Tue, 13 Jan 2026 19:56:13 +0530 Subject: [PATCH 04/11] all tests pass on mac --- Lib/asyncio/selector_events.py | 1 - Lib/test/test_asyncio/test_base_events.py | 2 +- Lib/test/test_asyncio/utils.py | 8 ++++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 4c620a254285f9..6ee0a878f49c78 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -936,7 +936,6 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None, context=None): - assert context is not None self._read_ready_cb = None super().__init__(loop, sock, protocol, extra, server, context) self._eof = False diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 8c02de77c24740..3a409d6f1cdcdd 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -2104,7 +2104,7 @@ def test_accept_connection_exception(self, m_log): constants.ACCEPT_RETRY_DELAY, # self.loop._start_serving mock.ANY, - MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) + MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY, mock.ANY) def test_call_coroutine(self): async def simple_coroutine(): diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index a480e16e81bb91..62cfcf8ceb5f2a 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -388,8 +388,8 @@ def close(self): else: # pragma: no cover raise AssertionError("Time generator is not finished") - def _add_reader(self, fd, callback, *args): - self.readers[fd] = events.Handle(callback, args, self, None) + def _add_reader(self, fd, callback, *args, context=None): + self.readers[fd] = events.Handle(callback, args, self, context) def _remove_reader(self, fd): self.remove_reader_count[fd] += 1 @@ -414,8 +414,8 @@ def assert_no_reader(self, fd): if fd in self.readers: raise AssertionError(f'fd {fd} is registered') - def _add_writer(self, fd, callback, *args): - self.writers[fd] = events.Handle(callback, args, self, None) + def _add_writer(self, fd, callback, *args, context=None): + self.writers[fd] = events.Handle(callback, args, self, context) def _remove_writer(self, fd): self.remove_writer_count[fd] += 1 From cbbf0c61441e188a3de2d324899bd14dd1c5c79e Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Tue, 13 Jan 2026 20:32:56 +0530 Subject: [PATCH 05/11] temp fix windows --- Lib/asyncio/proactor_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 3fa93b14a6787f..3a54ac7931981c 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -642,7 +642,7 @@ def __init__(self, proactor): signal.set_wakeup_fd(self._csock.fileno()) def _make_socket_transport(self, sock, protocol, waiter=None, - extra=None, server=None): + extra=None, server=None, context=None): return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) @@ -837,7 +837,7 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_shutdown_timeout=None, context=None): def loop(f=None): try: From b344e60dab74dfa71f7cefb3f542602cfb3c9294 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 11:13:31 +0530 Subject: [PATCH 06/11] add tests --- Lib/asyncio/selector_events.py | 5 +- Lib/test/test_asyncio/test_server_context.py | 264 +++++++++++++++++++ 2 files changed, 267 insertions(+), 2 deletions(-) create mode 100644 Lib/test/test_asyncio/test_server_context.py diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 6ee0a878f49c78..f695cd2f975cac 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -209,10 +209,11 @@ def _accept_connection( raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} + conn_context = context.copy() if context is not None else None accept = self._accept_connection2( protocol_factory, conn, extra, sslcontext, server, - ssl_handshake_timeout, ssl_shutdown_timeout, context=context) - self.create_task(accept, context=context) + ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context) + self.create_task(accept, context=conn_context) async def _accept_connection2( self, protocol_factory, conn, extra, diff --git a/Lib/test/test_asyncio/test_server_context.py b/Lib/test/test_asyncio/test_server_context.py new file mode 100644 index 00000000000000..9dc210aebfa50c --- /dev/null +++ b/Lib/test/test_asyncio/test_server_context.py @@ -0,0 +1,264 @@ + +import asyncio +import contextvars +import unittest + +from unittest import TestCase + +def tearDownModule(): + asyncio.events._set_event_loop_policy(None) + +class ServerContextvarsTestCase: + loop_factory = None # To be defined in subclasses + + def run_coro(self, coro): + return asyncio.run(coro, loop_factory=self.loop_factory) + + def test_start_server1(self): + # Test that asyncio.start_server captures the context at the time of server creation + async def test(): + var = contextvars.ContextVar("var", default="default") + + async def handle_client(reader, writer): + value = var.get() + writer.write(value.encode()) + await writer.drain() + writer.close() + + server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + # change the value + var.set("after_server") + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + data = await reader.read(100) + writer.close() + await writer.wait_closed() + return data.decode() + + async with server: + addr = server.sockets[0].getsockname() + self.assertEqual(await client(addr), "default") + + self.assertEqual(var.get(), "after_server") + + self.run_coro(test()) + + def test_start_server2(self): + # Test that mutations to the context in one handler don't affect other handlers or the server's context + async def test(): + var = contextvars.ContextVar("var", default="default") + + async def handle_client(reader, writer): + value = var.get() + writer.write(value.encode()) + var.set("in_handler") + await writer.drain() + writer.close() + + server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + var.set("after_server") + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + data = await reader.read(100) + writer.close() + await writer.wait_closed() + return data.decode() + + async with server: + addr = server.sockets[0].getsockname() + self.assertEqual(await client(addr), "default") + self.assertEqual(await client(addr), "default") + self.assertEqual(await client(addr), "default") + + self.assertEqual(var.get(), "after_server") + + self.run_coro(test()) + + def test_start_server3(self): + # Test that mutations to context in concurrent handlers don't affect each other or the server's context + async def test(): + var = contextvars.ContextVar("var", default="default") + var.set("before_server") + + async def handle_client(reader, writer): + writer.write(var.get().encode()) + await writer.drain() + writer.close() + + server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + var.set("after_server") + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + data = await reader.read(100) + self.assertEqual(data.decode(), "before_server") + writer.close() + await writer.wait_closed() + + async with server: + addr = server.sockets[0].getsockname() + async with asyncio.TaskGroup() as tg: + for _ in range(100): + tg.create_task(client(addr)) + + self.assertEqual(var.get(), "after_server") + + self.run_coro(test()) + + def test_create_server1(self): + # Test that loop.create_server captures the context at the time of server creation + # and that mutations to the context in protocol callbacks don't affect the server's context + async def test(): + var = contextvars.ContextVar("var", default="default") + + class EchoProtocol(asyncio.Protocol): + def connection_made(self, transport): + self.transport = transport + value = var.get() + var.set("in_handler") + self.transport.write(value.encode()) + self.transport.close() + + server = await asyncio.get_running_loop().create_server( + lambda: EchoProtocol(), '127.0.0.1', 0) + var.set("after_server") + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + data = await reader.read(100) + self.assertEqual(data.decode(), "default") + writer.close() + await writer.wait_closed() + + async with server: + addr = server.sockets[0].getsockname() + await client(addr) + + self.assertEqual(var.get(), "after_server") + + self.run_coro(test()) + + def test_create_server2(self): + # Test that mutations to context in one protocol instance don't affect other instances or the server's context + async def test(): + var = contextvars.ContextVar("var", default="default") + + class EchoProtocol(asyncio.Protocol): + def __init__(self): + super().__init__() + assert var.get() == "default", var.get() + def connection_made(self, transport): + self.transport = transport + value = var.get() + var.set("in_handler") + self.transport.write(value.encode()) + self.transport.close() + + server = await asyncio.get_running_loop().create_server( + lambda: EchoProtocol(), '127.0.0.1', 0) + + var.set("after_server") + + async def client(addr, expected): + reader, writer = await asyncio.open_connection(*addr) + data = await reader.read(100) + self.assertEqual(data.decode(), expected) + writer.close() + await writer.wait_closed() + + async with server: + addr = server.sockets[0].getsockname() + await client(addr, "default") + await client(addr, "default") + + self.assertEqual(var.get(), "after_server") + + self.run_coro(test()) + + def test_gh140947(self): + # See https://github.com/python/cpython/issues/140947 + + cvar1 = contextvars.ContextVar("cvar1") + cvar2 = contextvars.ContextVar("cvar2") + cvar3 = contextvars.ContextVar("cvar3") + results = {} + + def capture_context(meth): + result = [] + for k,v in contextvars.copy_context().items(): + result.append((k.name, v)) + results[meth] = sorted(result) + + class DemoProtocol(asyncio.Protocol): + def __init__(self, on_conn_lost): + self.transport = None + self.on_conn_lost = on_conn_lost + self.tasks = set() + + def connection_made(self, transport): + capture_context("connection_made") + self.transport = transport + + def data_received(self, data): + capture_context("data_received") + + task = asyncio.create_task(self.asgi()) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) + + self.transport.pause_reading() + + def connection_lost(self, exc): + capture_context("connection_lost") + if not self.on_conn_lost.done(): + self.on_conn_lost.set_result(True) + + async def asgi(self): + capture_context("asgi start") + + cvar1.set(True) + + # make sure that we only resume after the pause + # otherwise the resume does nothing + while not self.transport._paused: + await asyncio.sleep(0.1) + + cvar2.set(True) + + self.transport.resume_reading() + + cvar3.set(True) + + capture_context("asgi end") + + + async def main(): + loop = asyncio.get_running_loop() + on_conn_lost = loop.create_future() + + host, port = "127.0.0.1", 8888 + + async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port): + reader, writer = await asyncio.open_connection(host, port) + writer.write(b"anything") + await writer.drain() + writer.close() + await writer.wait_closed() + await on_conn_lost + self.run_coro(main()) + self.assertDictEqual(results, { + "connection_made": [], + "data_received": [], + "asgi start": [], + "asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)], + "connection_lost": [], + }) + + +class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase): + loop_factory = staticmethod(asyncio.new_event_loop) + +if __name__ == "__main__": + unittest.main() From 6a8e2bdfad5fde24a43fe3d137b4c50a458c45d3 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 11:28:53 +0530 Subject: [PATCH 07/11] fix tests --- Lib/test/test_asyncio/test_server_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_server_context.py b/Lib/test/test_asyncio/test_server_context.py index 9dc210aebfa50c..9001f9bc9fd08c 100644 --- a/Lib/test/test_asyncio/test_server_context.py +++ b/Lib/test/test_asyncio/test_server_context.py @@ -188,7 +188,8 @@ def test_gh140947(self): def capture_context(meth): result = [] for k,v in contextvars.copy_context().items(): - result.append((k.name, v)) + if k.name.startswith("cvar"): + result.append((k.name, v)) results[meth] = sorted(result) class DemoProtocol(asyncio.Protocol): From ec0ed9b8fe24b904685841b4a5bf3ceb90e900fd Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 12:19:21 +0530 Subject: [PATCH 08/11] add more windows tests --- Lib/test/test_asyncio/test_server_context.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_server_context.py b/Lib/test/test_asyncio/test_server_context.py index 9001f9bc9fd08c..80e7dda3823957 100644 --- a/Lib/test/test_asyncio/test_server_context.py +++ b/Lib/test/test_asyncio/test_server_context.py @@ -1,7 +1,7 @@ - import asyncio import contextvars import unittest +import sys from unittest import TestCase @@ -261,5 +261,12 @@ async def main(): class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase): loop_factory = staticmethod(asyncio.new_event_loop) +if sys.platform == "win32": + class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase): + loop_factory = asyncio.ProactorEventLoop + + class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase): + loop_factory = asyncio.SelectorEventLoop + if __name__ == "__main__": unittest.main() From 09df55ed3ef1ead05d86cc8e52da145336775c9d Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 13:08:14 +0530 Subject: [PATCH 09/11] fix ssl part --- Lib/asyncio/base_events.py | 6 +- Lib/asyncio/selector_events.py | 8 +- Lib/test/test_asyncio/test_server_context.py | 82 ++++++++++++++------ 3 files changed, 68 insertions(+), 28 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index ae07ef1789ad0f..7a6837546d930f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -511,7 +511,8 @@ def _make_ssl_transport( extra=None, server=None, ssl_handshake_timeout=None, ssl_shutdown_timeout=None, - call_connection_made=True): + call_connection_made=True, + context=None): """Create SSL transport.""" raise NotImplementedError @@ -1228,7 +1229,8 @@ async def _create_connection_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_shutdown_timeout=ssl_shutdown_timeout, + context=context) else: transport = self._make_socket_transport(sock, protocol, waiter, context=context) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index f695cd2f975cac..9685e7fc05d241 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -78,16 +78,17 @@ def _make_ssl_transport( extra=None, server=None, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, + context=None, ): self._ensure_fd_no_transport(rawsock) ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout + ssl_shutdown_timeout=ssl_shutdown_timeout, ) _SelectorSocketTransport(self, rawsock, ssl_protocol, - extra=extra, server=server) + extra=extra, server=server, context=context) return ssl_protocol._app_transport def _make_datagram_transport(self, sock, protocol, @@ -230,7 +231,8 @@ async def _accept_connection2( conn, protocol, sslcontext, waiter=waiter, server_side=True, extra=extra, server=server, ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_shutdown_timeout=ssl_shutdown_timeout, + context=context) else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, diff --git a/Lib/test/test_asyncio/test_server_context.py b/Lib/test/test_asyncio/test_server_context.py index 80e7dda3823957..bc1a2119eef70e 100644 --- a/Lib/test/test_asyncio/test_server_context.py +++ b/Lib/test/test_asyncio/test_server_context.py @@ -5,11 +5,20 @@ from unittest import TestCase +try: + import ssl +except ImportError: + ssl = None + +from test.test_asyncio import utils as test_utils + def tearDownModule(): asyncio.events._set_event_loop_policy(None) class ServerContextvarsTestCase: loop_factory = None # To be defined in subclasses + server_ssl_context = None # To be defined in subclasses for SSL tests + client_ssl_context = None # To be defined in subclasses for SSL tests def run_coro(self, coro): return asyncio.run(coro, loop_factory=self.loop_factory) @@ -25,12 +34,14 @@ async def handle_client(reader, writer): await writer.drain() writer.close() - server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + server = await asyncio.start_server(handle_client, '127.0.0.1', 0, + ssl=self.server_ssl_context) # change the value var.set("after_server") async def client(addr): - reader, writer = await asyncio.open_connection(*addr) + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) data = await reader.read(100) writer.close() await writer.wait_closed() @@ -56,11 +67,13 @@ async def handle_client(reader, writer): await writer.drain() writer.close() - server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + server = await asyncio.start_server(handle_client, '127.0.0.1', 0, + ssl=self.server_ssl_context) var.set("after_server") async def client(addr): - reader, writer = await asyncio.open_connection(*addr) + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) data = await reader.read(100) writer.close() await writer.wait_closed() @@ -87,11 +100,13 @@ async def handle_client(reader, writer): await writer.drain() writer.close() - server = await asyncio.start_server(handle_client, '127.0.0.1', 0) + server = await asyncio.start_server(handle_client, '127.0.0.1', 0, + ssl=self.server_ssl_context) var.set("after_server") async def client(addr): - reader, writer = await asyncio.open_connection(*addr) + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) data = await reader.read(100) self.assertEqual(data.decode(), "before_server") writer.close() @@ -122,11 +137,13 @@ def connection_made(self, transport): self.transport.close() server = await asyncio.get_running_loop().create_server( - lambda: EchoProtocol(), '127.0.0.1', 0) + lambda: EchoProtocol(), '127.0.0.1', 0, + ssl=self.server_ssl_context) var.set("after_server") async def client(addr): - reader, writer = await asyncio.open_connection(*addr) + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) data = await reader.read(100) self.assertEqual(data.decode(), "default") writer.close() @@ -157,12 +174,14 @@ def connection_made(self, transport): self.transport.close() server = await asyncio.get_running_loop().create_server( - lambda: EchoProtocol(), '127.0.0.1', 0) + lambda: EchoProtocol(), '127.0.0.1', 0, + ssl=self.server_ssl_context) var.set("after_server") async def client(addr, expected): - reader, writer = await asyncio.open_connection(*addr) + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) data = await reader.read(100) self.assertEqual(data.decode(), expected) writer.close() @@ -184,6 +203,7 @@ def test_gh140947(self): cvar2 = contextvars.ContextVar("cvar2") cvar3 = contextvars.ContextVar("cvar3") results = {} + is_ssl = self.server_ssl_context is not None def capture_context(meth): result = [] @@ -218,36 +238,37 @@ def connection_lost(self, exc): async def asgi(self): capture_context("asgi start") - cvar1.set(True) - # make sure that we only resume after the pause # otherwise the resume does nothing - while not self.transport._paused: - await asyncio.sleep(0.1) - + if is_ssl: + while not self.transport._ssl_protocol._app_reading_paused: + await asyncio.sleep(0.01) + else: + while not self.transport._paused: + await asyncio.sleep(0.01) cvar2.set(True) - self.transport.resume_reading() - cvar3.set(True) - capture_context("asgi end") - async def main(): loop = asyncio.get_running_loop() on_conn_lost = loop.create_future() - host, port = "127.0.0.1", 8888 - - async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port): - reader, writer = await asyncio.open_connection(host, port) + server = await loop.create_server( + lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0, + ssl=self.server_ssl_context) + async with server: + addr = server.sockets[0].getsockname() + reader, writer = await asyncio.open_connection(*addr, + ssl=self.client_ssl_context) writer.write(b"anything") await writer.drain() writer.close() await writer.wait_closed() await on_conn_lost + self.run_coro(main()) self.assertDictEqual(results, { "connection_made": [], @@ -261,6 +282,11 @@ async def main(): class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase): loop_factory = staticmethod(asyncio.new_event_loop) +@unittest.skipUnless(ssl, "SSL not available") +class AsyncioEventLoopSSLTests(AsyncioEventLoopTests): + server_ssl_context = test_utils.simple_server_sslcontext() + client_ssl_context = test_utils.simple_client_sslcontext() + if sys.platform == "win32": class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase): loop_factory = asyncio.ProactorEventLoop @@ -268,5 +294,15 @@ class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase): class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase): loop_factory = asyncio.SelectorEventLoop + @unittest.skipUnless(ssl, "SSL not available") + class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests): + server_ssl_context = test_utils.simple_server_sslcontext() + client_ssl_context = test_utils.simple_client_sslcontext() + + @unittest.skipUnless(ssl, "SSL not available") + class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests): + server_ssl_context = test_utils.simple_server_sslcontext() + client_ssl_context = test_utils.simple_client_sslcontext() + if __name__ == "__main__": unittest.main() From edcedbb538bfc14b9fd8c1107b528bf2121c741c Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 13:14:21 +0530 Subject: [PATCH 10/11] fix tests --- Lib/test/test_asyncio/test_base_events.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 3a409d6f1cdcdd..e59bc25668b4cb 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1696,7 +1696,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, server_side=False, server_hostname='python.org', ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_shutdown_timeout=shutdown_timeout, + context=ANY) # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( @@ -1711,7 +1712,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, server_side=False, server_hostname='perl.com', ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_shutdown_timeout=shutdown_timeout, + context=ANY) # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( @@ -1726,7 +1728,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, server_side=False, server_hostname='', ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_shutdown_timeout=shutdown_timeout, + context=ANY) def test_create_connection_no_ssl_server_hostname_errors(self): # When not using ssl, server_hostname must be None. From e1ebe671cc033f0e93ddded272b0e1bb7ecb8dfe Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 21 Mar 2026 13:22:51 +0530 Subject: [PATCH 11/11] fix windows tests --- Lib/asyncio/proactor_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 3a54ac7931981c..2dc1569d780791 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -651,7 +651,7 @@ def _make_ssl_transport( *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_shutdown_timeout=None, context=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname,