From 05e204604ed445934bab6c66188a9e0926e28465 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 12:37:44 +0900 Subject: [PATCH 01/16] Run conformance client tests in parallel Signed-off-by: Anuraag Agrawal --- conformance/test/client.py | 50 ++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/conformance/test/client.py b/conformance/test/client.py index e621496..d44e9d0 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -2,6 +2,7 @@ import argparse import asyncio +import multiprocessing import ssl import sys import time @@ -543,32 +544,45 @@ async def send_unary_request( class Args(argparse.Namespace): mode: Literal["sync", "async"] + parallel: int async def main() -> None: parser = argparse.ArgumentParser(description="Conformance client") parser.add_argument("--mode", choices=["sync", "async"]) + parser.add_argument("--parallel", type=int, default=multiprocessing.cpu_count() * 4) args = parser.parse_args(namespace=Args()) stdin, stdout = await create_standard_streams() - while True: - try: - size_buf = await stdin.readexactly(4) - except asyncio.IncompleteReadError: - return - size = int.from_bytes(size_buf, byteorder="big") - # Allow to raise even on EOF since we always should have a message - request_buf = await stdin.readexactly(size) - request = ClientCompatRequest() - request.ParseFromString(request_buf) - - response = await _run_test(args.mode, request) - - response_buf = response.SerializeToString() - size_buf = len(response_buf).to_bytes(4, byteorder="big") - stdout.write(size_buf) - stdout.write(response_buf) - await stdout.drain() + sema = asyncio.Semaphore(args.parallel) + stdout_lock = asyncio.Lock() + tasks: list[asyncio.Task] = [] + try: + while True: + try: + size_buf = await stdin.readexactly(4) + except asyncio.IncompleteReadError: + return + size = int.from_bytes(size_buf, byteorder="big") + # Allow to raise even on EOF since we always should have a message + request_buf = await stdin.readexactly(size) + request = ClientCompatRequest() + request.ParseFromString(request_buf) + + async def task(request: ClientCompatRequest) -> None: + async with sema: + response = await _run_test(args.mode, request) + + response_buf = response.SerializeToString() + size_buf = len(response_buf).to_bytes(4, byteorder="big") + async with stdout_lock: + stdout.write(size_buf) + stdout.write(response_buf) + await stdout.drain() + + tasks.append(asyncio.create_task(task(request))) + finally: + await asyncio.gather(*tasks) if __name__ == "__main__": From 8fc25597067de934bc2e75b972a2b58d4297ebe7 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 13:32:13 +0900 Subject: [PATCH 02/16] shield httpx async cleanup from cancellation Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 60 +++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index d8fcfc3..e3b93f9 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -2,6 +2,7 @@ import asyncio import functools +import sys from asyncio import CancelledError, sleep, wait_for from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -365,39 +366,46 @@ async def _send_request_bidi_stream( request, self._codec, self._send_compression ) - async with ( - asyncio_timeout(timeout_s), - self._session.stream( + async with asyncio_timeout(timeout_s): + stream = self._session.stream( method="POST", url=url, headers=request_headers, content=request_data, timeout=timeout, - ) as resp, - ): - compression = _client_shared.validate_response_content_encoding( - resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") ) - _client_shared.validate_stream_response_content_type( - self._codec.name(), resp.headers.get("content-type", "") - ) - handle_response_headers(resp.headers) - - if resp.status_code == 200: - reader = EnvelopeReader( - ctx.method().output, - self._codec, - compression, - self._read_max_bytes, + resp = await stream.__aenter__() + try: + compression = _client_shared.validate_response_content_encoding( + resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") + ) + _client_shared.validate_stream_response_content_type( + self._codec.name(), resp.headers.get("content-type", "") ) - async for chunk in resp.aiter_bytes(): - for message in reader.feed(chunk): - yield message - # Check for cancellation each message. While this seems heavyweight, - # conformance tests require it. - await sleep(0) - else: - raise ConnectWireError.from_response(resp).to_exception() + handle_response_headers(resp.headers) + + if resp.status_code == 200: + reader = EnvelopeReader( + ctx.method().output, + self._codec, + compression, + self._read_max_bytes, + ) + async for chunk in resp.aiter_bytes(): + for message in reader.feed(chunk): + yield message + # Check for cancellation each message. While this seems heavyweight, + # conformance tests require it. + await sleep(0) + else: + raise ConnectWireError.from_response(resp).to_exception() + finally: + # We always need response cleanup to run even during cancellation, which is only + # possible if shielding it with manual invocation. Besides potential cleanup issues, + # a symptom of not doing this is the cancellation error getting replaced by one + # during the httpx cleanup and not getting mapped to the correct connect error. + exc_type, exc_val, exc_tb = sys.exc_info() + await asyncio.shield(stream.__aexit__(exc_type, exc_val, exc_tb)) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: From e17c3544ca569186ff94ef07b76f6031e2de1914 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 13:32:42 +0900 Subject: [PATCH 03/16] Cleanup Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index e3b93f9..cbf56e7 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -38,7 +38,6 @@ from ._asyncio_timeout import timeout as asyncio_timeout if TYPE_CHECKING: - import sys from collections.abc import AsyncIterator, Iterable, Mapping from types import TracebackType From d63c78eee205140676848361627734faad09d6bb Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 14:21:11 +0900 Subject: [PATCH 04/16] Avoid async context manager entirely Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index cbf56e7..eb2bae5 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -2,7 +2,6 @@ import asyncio import functools -import sys from asyncio import CancelledError, sleep, wait_for from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -38,6 +37,7 @@ from ._asyncio_timeout import timeout as asyncio_timeout if TYPE_CHECKING: + import sys from collections.abc import AsyncIterator, Iterable, Mapping from types import TracebackType @@ -366,15 +366,18 @@ async def _send_request_bidi_stream( ) async with asyncio_timeout(timeout_s): - stream = self._session.stream( - method="POST", - url=url, - headers=request_headers, - content=request_data, - timeout=timeout, - ) - resp = await stream.__aenter__() + resp = None try: + # Use build_request + send to avoid AsyncContextManager which + # has issues in cleanup during cancellation. + httpx_request = self._session.build_request( + method="POST", + url=url, + headers=request_headers, + content=request_data, + timeout=timeout, + ) + resp = await self._session.send(httpx_request, stream=True) compression = _client_shared.validate_response_content_encoding( resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") ) @@ -399,12 +402,8 @@ async def _send_request_bidi_stream( else: raise ConnectWireError.from_response(resp).to_exception() finally: - # We always need response cleanup to run even during cancellation, which is only - # possible if shielding it with manual invocation. Besides potential cleanup issues, - # a symptom of not doing this is the cancellation error getting replaced by one - # during the httpx cleanup and not getting mapped to the correct connect error. - exc_type, exc_val, exc_tb = sys.exc_info() - await asyncio.shield(stream.__aexit__(exc_type, exc_val, exc_tb)) + if resp is not None: + await asyncio.shield(resp.aclose()) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: From d9f23a65a523448c2d741ef6fa0f43db3a3dd134 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 14:36:22 +0900 Subject: [PATCH 05/16] Change cancellation check approach Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index eb2bae5..a1a6c35 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -2,7 +2,7 @@ import asyncio import functools -from asyncio import CancelledError, sleep, wait_for +from asyncio import CancelledError, wait_for from typing import TYPE_CHECKING, Any, Protocol, TypeVar import httpx @@ -395,10 +395,10 @@ async def _send_request_bidi_stream( ) async for chunk in resp.aiter_bytes(): for message in reader.feed(chunk): - yield message # Check for cancellation each message. While this seems heavyweight, # conformance tests require it. - await sleep(0) + if not asyncio.current_task().cancelled(): + yield message else: raise ConnectWireError.from_response(resp).to_exception() finally: From 98933773389da6e5ed2564049f8348cf6828a3fd Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 14:45:05 +0900 Subject: [PATCH 06/16] Fix typing Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index a1a6c35..49272d6 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -397,7 +397,9 @@ async def _send_request_bidi_stream( for message in reader.feed(chunk): # Check for cancellation each message. While this seems heavyweight, # conformance tests require it. - if not asyncio.current_task().cancelled(): + if ( + task := asyncio.current_task() + ) and not task.cancelled(): yield message else: raise ConnectWireError.from_response(resp).to_exception() From e40b7deb75117f0f2958780b7176c9fce1ab3247 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 7 Jan 2026 14:55:21 +0900 Subject: [PATCH 07/16] One more try Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 49272d6..9ce6d34 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -397,6 +397,7 @@ async def _send_request_bidi_stream( for message in reader.feed(chunk): # Check for cancellation each message. While this seems heavyweight, # conformance tests require it. + await asyncio.sleep(0) if ( task := asyncio.current_task() ) and not task.cancelled(): From e9bee915209cb683127c039be20c8f4fbf81028a Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 11:18:37 +0900 Subject: [PATCH 08/16] Fix cancellation detection Signed-off-by: Anuraag Agrawal --- conformance/test/client.py | 129 +++++++++++++++++++++++++------- conformance/test/test_client.py | 7 +- src/connectrpc/_client_async.py | 28 +++++-- 3 files changed, 130 insertions(+), 34 deletions(-) diff --git a/conformance/test/client.py b/conformance/test/client.py index d44e9d0..7b3eeb5 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -8,7 +8,7 @@ import time import traceback from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar import httpx from _util import create_standard_streams @@ -107,6 +107,45 @@ def _unpack_request(message: Any, request: T) -> T: return request +def _build_tls_context( + server_cert: bytes, + client_cert: bytes | None = None, + client_key: bytes | None = None, +) -> ssl.SSLContext: + ctx = ssl.create_default_context( + purpose=ssl.Purpose.SERVER_AUTH, cadata=server_cert.decode() + ) + if client_cert is None or client_key is None: + return ctx + with NamedTemporaryFile() as cert_file, NamedTemporaryFile() as key_file: + cert_file.write(client_cert) + cert_file.flush() + key_file.write(client_key) + key_file.flush() + ctx.load_cert_chain(certfile=cert_file.name, keyfile=key_file.name) + return ctx + + +def _schedule_cancel(task: asyncio.Task, delay_s: float) -> asyncio.Handle: + loop = asyncio.get_running_loop() + + def _cancel() -> None: + task.cancel() + + return loop.call_later(delay_s, _cancel) + + +def _schedule_cancel_after_close_send( + task: asyncio.Task, delay_s: float, close_send_event: asyncio.Event +) -> asyncio.Task: + async def _run() -> None: + await close_send_event.wait() + await asyncio.sleep(delay_s) + task.cancel() + + return asyncio.create_task(_run()) + + async def _run_test( mode: Literal["sync", "async"], test_request: ClientCompatRequest ) -> ClientCompatResponse: @@ -126,7 +165,8 @@ async def _run_test( request_headers.add(header.name, value) payloads: list[ConformancePayload] = [] - + close_send_event = asyncio.Event() + loop = asyncio.get_running_loop() with ResponseMetadata() as meta: try: task: asyncio.Task @@ -141,22 +181,17 @@ async def _run_test( scheme = "http" if test_request.server_tls_cert: scheme = "https" - ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH, - cadata=test_request.server_tls_cert.decode(), - ) if test_request.HasField("client_tls_creds"): - with ( - NamedTemporaryFile() as cert_file, - NamedTemporaryFile() as key_file, - ): - cert_file.write(test_request.client_tls_creds.cert) - cert_file.flush() - key_file.write(test_request.client_tls_creds.key) - key_file.flush() - ctx.load_cert_chain( - certfile=cert_file.name, keyfile=key_file.name - ) + ctx = await asyncio.to_thread( + _build_tls_context, + test_request.server_tls_cert, + test_request.client_tls_creds.cert, + test_request.client_tls_creds.key, + ) + else: + ctx = await asyncio.to_thread( + _build_tls_context, test_request.server_tls_cert + ) session_kwargs["verify"] = ctx match mode: case "sync": @@ -189,7 +224,7 @@ def send_bidi_stream_request_sync( num := test_request.cancel.after_num_responses ) and len(payloads) >= num: - task.cancel() + loop.call_soon_threadsafe(task.cancel) def bidi_request_stream_sync(): for message in test_request.request_messages: @@ -200,6 +235,15 @@ def bidi_request_stream_sync(): yield _unpack_request( message, BidiStreamRequest() ) + if test_request.cancel.HasField( + "before_close_send" + ): + loop.call_soon_threadsafe(task.cancel) + time.sleep(600) + else: + loop.call_soon_threadsafe( + close_send_event.set + ) task = asyncio.create_task( asyncio.to_thread( @@ -231,6 +275,15 @@ def request_stream_sync(): yield _unpack_request( message, ClientStreamRequest() ) + if test_request.cancel.HasField( + "before_close_send" + ): + loop.call_soon_threadsafe(task.cancel) + time.sleep(600) + else: + loop.call_soon_threadsafe( + close_send_event.set + ) task = asyncio.create_task( asyncio.to_thread( @@ -263,6 +316,7 @@ def send_idempotent_unary_request_sync( ), ) ) + close_send_event.set() case "ServerStream": def send_server_stream_request_sync( @@ -279,7 +333,7 @@ def send_server_stream_request_sync( num := test_request.cancel.after_num_responses ) and len(payloads) >= num: - task.cancel() + loop.call_soon_threadsafe(task.cancel) task = asyncio.create_task( asyncio.to_thread( @@ -291,6 +345,7 @@ def send_server_stream_request_sync( ), ) ) + close_send_event.set() case "Unary": def send_unary_request_sync( @@ -314,6 +369,7 @@ def send_unary_request_sync( ), ) ) + close_send_event.set() case "Unimplemented": task = asyncio.create_task( asyncio.to_thread( @@ -326,15 +382,21 @@ def send_unary_request_sync( timeout_ms=timeout_ms, ) ) + close_send_event.set() case _: msg = f"Unrecognized method: {test_request.method}" raise ValueError(msg) + cancel_task: asyncio.Task | None = None if test_request.cancel.after_close_send_ms: - await asyncio.sleep( - test_request.cancel.after_close_send_ms / 1000.0 + delay = test_request.cancel.after_close_send_ms / 1000.0 + cancel_task = _schedule_cancel_after_close_send( + task, delay, close_send_event ) - task.cancel() - await task + try: + await task + finally: + if cancel_task is not None: + cancel_task.cancel() case "async": async with ( httpx.AsyncClient(**session_kwargs) as session, @@ -385,6 +447,8 @@ async def bidi_stream_request(): # a long time. We won't end up sleeping for long since we # cancelled. await asyncio.sleep(600) + else: + close_send_event.set() task = asyncio.create_task( send_bidi_stream_request( @@ -422,6 +486,8 @@ async def client_stream_request(): # a long time. We won't end up sleeping for long since we # cancelled. await asyncio.sleep(600) + else: + close_send_event.set() task = asyncio.create_task( send_client_stream_request( @@ -451,6 +517,7 @@ async def send_idempotent_unary_request( ), ) ) + close_send_event.set() case "ServerStream": async def send_server_stream_request( @@ -478,6 +545,7 @@ async def send_server_stream_request( ), ) ) + close_send_event.set() case "Unary": async def send_unary_request( @@ -500,6 +568,7 @@ async def send_unary_request( ), ) ) + close_send_event.set() case "Unimplemented": task = asyncio.create_task( client.unimplemented( @@ -511,15 +580,21 @@ async def send_unary_request( timeout_ms=timeout_ms, ) ) + close_send_event.set() case _: msg = f"Unrecognized method: {test_request.method}" raise ValueError(msg) + cancel_task: asyncio.Task | None = None if test_request.cancel.after_close_send_ms: - await asyncio.sleep( - test_request.cancel.after_close_send_ms / 1000.0 + delay = test_request.cancel.after_close_send_ms / 1000.0 + cancel_task = _schedule_cancel_after_close_send( + task, delay, close_send_event ) - task.cancel() - await task + try: + await task + finally: + if cancel_task is not None: + cancel_task.cancel() except ConnectError as e: test_response.response.error.code = _convert_code(e.code) test_response.response.error.message = e.message diff --git a/conformance/test/test_client.py b/conformance/test/test_client.py index 22426ca..9a69751 100644 --- a/conformance/test/test_client.py +++ b/conformance/test/test_client.py @@ -71,6 +71,10 @@ def test_client_async() -> None: args = maybe_patch_args_with_debug( [sys.executable, _client_py_path, "--mode", "async"] ) + flaky_tests = [] + if sys.version_info < (3, 11): + # Python 3.11 is required to reliably detect cancellation. + flaky_tests = ["--known-flaky", "Client Cancellation/**"] result = subprocess.run( [ "go", @@ -81,8 +85,7 @@ def test_client_async() -> None: "--mode", "client", *_skipped_tests_async, - "--known-flaky", - "Client Cancellation/**", + *flaky_tests, "--", *args, ], diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 9ce6d34..09400e1 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -325,6 +325,8 @@ async def _send_request_unary( response = ctx.method().output() self._codec.decode(resp.content, response) + if _task_cancelled(): + raise CancelledError return response raise ConnectWireError.from_response(resp).to_exception() except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: @@ -397,11 +399,9 @@ async def _send_request_bidi_stream( for message in reader.feed(chunk): # Check for cancellation each message. While this seems heavyweight, # conformance tests require it. - await asyncio.sleep(0) - if ( - task := asyncio.current_task() - ) and not task.cancelled(): - yield message + if _task_cancelled(): + raise CancelledError + yield message else: raise ConnectWireError.from_response(resp).to_exception() finally: @@ -427,6 +427,24 @@ def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: return Timeout(None) +def _task_cancelled() -> bool: + task = asyncio.current_task() + if task is None: + return False + if task.cancelled(): + return True + # Only available in Python 3.11+. If httpx squashes cancellation, we can't + # know about the cancellation and can return messages even after cancellation. + # cancelling cannot be squashed and is the reliable way to detect this case. + cancelling = getattr(task, "cancelling", None) + if callable(cancelling): + try: + return cancelling() > 0 + except Exception: + return False + return False + + async def _streaming_request_content( msgs: AsyncIterator[Any], codec: Codec, compression: Compression | None ) -> AsyncIterator[bytes]: From 2ccce6dc2a723c6fa4384b80e710ccf05e16d2b9 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 11:31:18 +0900 Subject: [PATCH 09/16] Format Signed-off-by: Anuraag Agrawal --- conformance/test/client.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/conformance/test/client.py b/conformance/test/client.py index 7b3eeb5..cd6d27c 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -241,9 +241,7 @@ def bidi_request_stream_sync(): loop.call_soon_threadsafe(task.cancel) time.sleep(600) else: - loop.call_soon_threadsafe( - close_send_event.set - ) + loop.call_soon_threadsafe(close_send_event.set) task = asyncio.create_task( asyncio.to_thread( @@ -281,9 +279,7 @@ def request_stream_sync(): loop.call_soon_threadsafe(task.cancel) time.sleep(600) else: - loop.call_soon_threadsafe( - close_send_event.set - ) + loop.call_soon_threadsafe(close_send_event.set) task = asyncio.create_task( asyncio.to_thread( From d6394c3a470b8da630deecc4e474fc541b1c320b Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 11:36:30 +0900 Subject: [PATCH 10/16] Fix type check Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 09400e1..f2ddcc7 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -3,7 +3,7 @@ import asyncio import functools from asyncio import CancelledError, wait_for -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast import httpx from httpx import USE_CLIENT_DEFAULT, Timeout @@ -38,7 +38,7 @@ if TYPE_CHECKING: import sys - from collections.abc import AsyncIterator, Iterable, Mapping + from collections.abc import AsyncIterator, Callable, Iterable, Mapping from types import TracebackType from ._compression import Compression @@ -439,7 +439,7 @@ def _task_cancelled() -> bool: cancelling = getattr(task, "cancelling", None) if callable(cancelling): try: - return cancelling() > 0 + return cast("Callable[[], int]", cancelling)() > 0 except Exception: return False return False From 4819861587950f5380597c01b55fceedb44a7630 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 12:01:24 +0900 Subject: [PATCH 11/16] Accont for squashed cancels Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index f2ddcc7..0699af5 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -276,6 +276,7 @@ async def _send_request_unary( timeout_s = None timeout = USE_CLIENT_DEFAULT + cancel_count = _task_cancel_count() try: request_data = self._codec.encode(request) if self._send_compression: @@ -325,7 +326,7 @@ async def _send_request_unary( response = ctx.method().output() self._codec.decode(resp.content, response) - if _task_cancelled(): + if _task_cancelled_since(cancel_count): raise CancelledError return response raise ConnectWireError.from_response(resp).to_exception() @@ -362,6 +363,7 @@ async def _send_request_bidi_stream( timeout_s = None timeout = USE_CLIENT_DEFAULT + cancel_count = _task_cancel_count() try: request_data = _streaming_request_content( request, self._codec, self._send_compression @@ -399,7 +401,7 @@ async def _send_request_bidi_stream( for message in reader.feed(chunk): # Check for cancellation each message. While this seems heavyweight, # conformance tests require it. - if _task_cancelled(): + if _task_cancelled_since(cancel_count): raise CancelledError yield message else: @@ -427,22 +429,32 @@ def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: return Timeout(None) -def _task_cancelled() -> bool: +# cancelling count always goes up, regardless of if CancelledError is squashed or not. +# To detect the user cancelled the specific request task, we need to compare the cancel +# count before the operation and after each message. +def _task_cancel_count() -> int: task = asyncio.current_task() if task is None: - return False - if task.cancelled(): - return True + return 0 # Only available in Python 3.11+. If httpx squashes cancellation, we can't # know about the cancellation and can return messages even after cancellation. # cancelling cannot be squashed and is the reliable way to detect this case. cancelling = getattr(task, "cancelling", None) if callable(cancelling): try: - return cast("Callable[[], int]", cancelling)() > 0 + return cast("Callable[[], int]", cancelling)() except Exception: - return False - return False + return 0 + return 0 + + +def _task_cancelled_since(start_count: int) -> bool: + task = asyncio.current_task() + if task is None: + return False + if task.cancelled(): + return True + return _task_cancel_count() > start_count async def _streaming_request_content( From 701c8d7c51e5aa8086cd9d083867235dced25602 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 13:52:51 +0900 Subject: [PATCH 12/16] WIP Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 187 +++++++++++++++++--------------- 1 file changed, 101 insertions(+), 86 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 0699af5..9d9d80a 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import functools from asyncio import CancelledError, wait_for from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast @@ -38,7 +39,7 @@ if TYPE_CHECKING: import sys - from collections.abc import AsyncIterator, Callable, Iterable, Mapping + from collections.abc import AsyncIterator, Iterable, Mapping from types import TracebackType from ._compression import Compression @@ -276,68 +277,87 @@ async def _send_request_unary( timeout_s = None timeout = USE_CLIENT_DEFAULT - cancel_count = _task_cancel_count() - try: + result_queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1) + + async def _do_request() -> None: request_data = self._codec.encode(request) if self._send_compression: request_data = self._send_compression.compress(request_data) - if ctx.http_method() == "GET": - params = _client_shared.prepare_get_params( - self._codec, request_data, request_headers - ) - request_headers.pop("content-type", None) - resp = await wait_for( - self._session.get( - url=url, headers=request_headers, params=params, timeout=timeout - ), - timeout_s, - ) - else: - resp = await wait_for( - self._session.post( + try: + if ctx.http_method() == "GET": + params = _client_shared.prepare_get_params( + self._codec, request_data, request_headers + ) + request_headers.pop("content-type", None) + httpx_request = self._session.build_request( + method="GET", + url=url, + headers=request_headers, + params=params, + timeout=timeout, + ) + else: + httpx_request = self._session.build_request( + method="POST", url=url, headers=request_headers, content=request_data, timeout=timeout, - ), - timeout_s, - ) - - _client_shared.validate_response_content_encoding( - resp.headers.get("content-encoding", "") - ) - _client_shared.validate_response_content_type( - self._codec.name(), - resp.status_code, - resp.headers.get("content-type", ""), - ) - handle_response_headers(resp.headers) - - if resp.status_code == 200: - if ( - self._read_max_bytes is not None - and len(resp.content) > self._read_max_bytes - ): - raise ConnectError( - Code.RESOURCE_EXHAUSTED, - f"message is larger than configured max {self._read_max_bytes}", ) - response = ctx.method().output() - self._codec.decode(resp.content, response) - if _task_cancelled_since(cancel_count): - raise CancelledError - return response - raise ConnectWireError.from_response(resp).to_exception() + resp = await wait_for(self._session.send(httpx_request), timeout_s) + + _client_shared.validate_response_content_encoding( + resp.headers.get("content-encoding", "") + ) + _client_shared.validate_response_content_type( + self._codec.name(), + resp.status_code, + resp.headers.get("content-type", ""), + ) + handle_response_headers(resp.headers) + + if resp.status_code == 200: + if ( + self._read_max_bytes is not None + and len(resp.content) > self._read_max_bytes + ): + raise ConnectError( + Code.RESOURCE_EXHAUSTED, + f"message is larger than configured max {self._read_max_bytes}", + ) + + response = ctx.method().output() + self._codec.decode(resp.content, response) + result_queue.put_nowait(response) + return + raise ConnectWireError.from_response(resp).to_exception() + except BaseException as exc: + if result_queue.empty(): + result_queue.put_nowait(exc) + raise + + task = asyncio.create_task(_do_request()) + task.add_done_callback(_consume_task_result) + try: + item = await result_queue.get() + if isinstance(item, BaseException): + raise item + return cast("RES", item) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: raise except CancelledError as e: + if not task.done(): + task.cancel() raise ConnectError(Code.CANCELED, "Request was cancelled") from e except Exception as e: raise ConnectError(Code.UNAVAILABLE, str(e)) from e + finally: + if not task.done(): + task.cancel() async def _send_request_client_stream( self, request: AsyncIterator[REQ], ctx: RequestContext[REQ, RES] @@ -363,15 +383,17 @@ async def _send_request_bidi_stream( timeout_s = None timeout = USE_CLIENT_DEFAULT - cancel_count = _task_cancel_count() - try: - request_data = _streaming_request_content( - request, self._codec, self._send_compression - ) + queue: asyncio.Queue[object] = asyncio.Queue() + sentinel = object() + + async def _produce() -> None: + resp = None + try: + request_data = _streaming_request_content( + request, self._codec, self._send_compression + ) - async with asyncio_timeout(timeout_s): - resp = None - try: + async with asyncio_timeout(timeout_s): # Use build_request + send to avoid AsyncContextManager which # has issues in cleanup during cancellation. httpx_request = self._session.build_request( @@ -399,24 +421,39 @@ async def _send_request_bidi_stream( ) async for chunk in resp.aiter_bytes(): for message in reader.feed(chunk): - # Check for cancellation each message. While this seems heavyweight, - # conformance tests require it. - if _task_cancelled_since(cancel_count): - raise CancelledError - yield message + await queue.put(message) else: raise ConnectWireError.from_response(resp).to_exception() - finally: - if resp is not None: + except Exception as exc: + queue.put_nowait(exc) + finally: + if resp is not None: + with contextlib.suppress(Exception): await asyncio.shield(resp.aclose()) + queue.put_nowait(sentinel) + + producer = asyncio.create_task(_produce()) + producer.add_done_callback(_consume_task_result) + try: + while True: + item = await queue.get() + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield cast("RES", item) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: raise except CancelledError as e: + producer.cancel() raise ConnectError(Code.CANCELED, "Request was cancelled") from e except Exception as e: raise ConnectError(Code.UNAVAILABLE, str(e)) from e + finally: + if not producer.done(): + producer.cancel() def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: @@ -429,32 +466,10 @@ def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: return Timeout(None) -# cancelling count always goes up, regardless of if CancelledError is squashed or not. -# To detect the user cancelled the specific request task, we need to compare the cancel -# count before the operation and after each message. -def _task_cancel_count() -> int: - task = asyncio.current_task() - if task is None: - return 0 - # Only available in Python 3.11+. If httpx squashes cancellation, we can't - # know about the cancellation and can return messages even after cancellation. - # cancelling cannot be squashed and is the reliable way to detect this case. - cancelling = getattr(task, "cancelling", None) - if callable(cancelling): - try: - return cast("Callable[[], int]", cancelling)() - except Exception: - return 0 - return 0 - - -def _task_cancelled_since(start_count: int) -> bool: - task = asyncio.current_task() - if task is None: - return False - if task.cancelled(): - return True - return _task_cancel_count() > start_count +def _consume_task_result(task: asyncio.Task[Any]) -> None: + # Task completion can raise CancelledError (BaseException) in shutdown/cancel paths. + with contextlib.suppress(BaseException): + task.result() async def _streaming_request_content( From 7c6f8ef93eabf2c0a270c12a361b8f7dd3a08492 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 14:21:15 +0900 Subject: [PATCH 13/16] WIP Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 83 +++++++++++++++++---------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 9d9d80a..92aab90 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -10,7 +10,6 @@ from httpx import USE_CLIENT_DEFAULT, Timeout from . import _client_shared -from ._asyncio_timeout import timeout as asyncio_timeout from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec from ._envelope import EnvelopeReader from ._interceptor_async import ( @@ -30,13 +29,6 @@ from .code import Code from .errors import ConnectError -try: - from asyncio import ( - timeout as asyncio_timeout, # pyright: ignore[reportAttributeAccessIssue] - ) -except ImportError: - from ._asyncio_timeout import timeout as asyncio_timeout - if TYPE_CHECKING: import sys from collections.abc import AsyncIterator, Iterable, Mapping @@ -383,6 +375,9 @@ async def _send_request_bidi_stream( timeout_s = None timeout = USE_CLIENT_DEFAULT + loop = asyncio.get_running_loop() + deadline = None if timeout_s is None else loop.time() + timeout_s + queue: asyncio.Queue[object] = asyncio.Queue() sentinel = object() @@ -393,37 +388,36 @@ async def _produce() -> None: request, self._codec, self._send_compression ) - async with asyncio_timeout(timeout_s): - # Use build_request + send to avoid AsyncContextManager which - # has issues in cleanup during cancellation. - httpx_request = self._session.build_request( - method="POST", - url=url, - headers=request_headers, - content=request_data, - timeout=timeout, - ) - resp = await self._session.send(httpx_request, stream=True) - compression = _client_shared.validate_response_content_encoding( - resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") - ) - _client_shared.validate_stream_response_content_type( - self._codec.name(), resp.headers.get("content-type", "") + # Use build_request + send to avoid AsyncContextManager which + # has issues in cleanup during cancellation. + httpx_request = self._session.build_request( + method="POST", + url=url, + headers=request_headers, + content=request_data, + timeout=timeout, + ) + resp = await self._session.send(httpx_request, stream=True) + compression = _client_shared.validate_response_content_encoding( + resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") + ) + _client_shared.validate_stream_response_content_type( + self._codec.name(), resp.headers.get("content-type", "") + ) + handle_response_headers(resp.headers) + + if resp.status_code == 200: + reader = EnvelopeReader( + ctx.method().output, + self._codec, + compression, + self._read_max_bytes, ) - handle_response_headers(resp.headers) - - if resp.status_code == 200: - reader = EnvelopeReader( - ctx.method().output, - self._codec, - compression, - self._read_max_bytes, - ) - async for chunk in resp.aiter_bytes(): - for message in reader.feed(chunk): - await queue.put(message) - else: - raise ConnectWireError.from_response(resp).to_exception() + async for chunk in resp.aiter_bytes(): + for message in reader.feed(chunk): + await queue.put(message) + else: + raise ConnectWireError.from_response(resp).to_exception() except Exception as exc: queue.put_nowait(exc) finally: @@ -436,7 +430,18 @@ async def _produce() -> None: producer.add_done_callback(_consume_task_result) try: while True: - item = await queue.get() + try: + if deadline is None: + item = await queue.get() + else: + remaining = deadline - loop.time() + if remaining <= 0: + raise asyncio.TimeoutError + item = await asyncio.wait_for(queue.get(), remaining) + except asyncio.TimeoutError: + if not producer.done(): + producer.cancel() + raise if item is sentinel: break if isinstance(item, Exception): From 664977359a7778b96d49f40de0d773d61ccf2158 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 14:31:40 +0900 Subject: [PATCH 14/16] mark cancel-after-responses flaky Signed-off-by: Anuraag Agrawal --- conformance/test/test_client.py | 7 ++----- src/connectrpc/_client_async.py | 1 - 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/conformance/test/test_client.py b/conformance/test/test_client.py index 9a69751..46012ad 100644 --- a/conformance/test/test_client.py +++ b/conformance/test/test_client.py @@ -71,10 +71,6 @@ def test_client_async() -> None: args = maybe_patch_args_with_debug( [sys.executable, _client_py_path, "--mode", "async"] ) - flaky_tests = [] - if sys.version_info < (3, 11): - # Python 3.11 is required to reliably detect cancellation. - flaky_tests = ["--known-flaky", "Client Cancellation/**"] result = subprocess.run( [ "go", @@ -85,7 +81,8 @@ def test_client_async() -> None: "--mode", "client", *_skipped_tests_async, - *flaky_tests, + "--known-flaky", + "**/cancel-after-responses", "--", *args, ], diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 92aab90..227ac84 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -472,7 +472,6 @@ def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: def _consume_task_result(task: asyncio.Task[Any]) -> None: - # Task completion can raise CancelledError (BaseException) in shutdown/cancel paths. with contextlib.suppress(BaseException): task.result() From 49562fc7f256f1519f0c0cfa9e7e0d31b19de945 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 14:42:13 +0900 Subject: [PATCH 15/16] Unary timeout Signed-off-by: Anuraag Agrawal --- src/connectrpc/_client_async.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 227ac84..a520682 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -333,7 +333,15 @@ async def _do_request() -> None: task = asyncio.create_task(_do_request()) task.add_done_callback(_consume_task_result) try: - item = await result_queue.get() + try: + if timeout_s is None: + item = await result_queue.get() + else: + item = await asyncio.wait_for(result_queue.get(), timeout_s) + except asyncio.TimeoutError: + if not task.done(): + task.cancel() + raise if isinstance(item, BaseException): raise item return cast("RES", item) From de88601f9972348f020c9ce7484eb9bcaeebfac7 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 14:57:32 +0900 Subject: [PATCH 16/16] Restore old timeout test Signed-off-by: Anuraag Agrawal --- test/test_errors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_errors.py b/test/test_errors.py index 03b21b3..9b48253 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -353,7 +353,7 @@ def make_hat(self, request, ctx) -> NoReturn: @pytest.mark.parametrize( - ("client_timeout_ms", "call_timeout_ms"), [(200, None), (None, 200)] + ("client_timeout_ms", "call_timeout_ms"), [(50, None), (None, 50)] ) def test_sync_client_timeout( client_timeout_ms, call_timeout_ms, timeout_server: str @@ -385,12 +385,12 @@ def modify_timeout_header(request: Request) -> None: assert exc_info.value.code == Code.DEADLINE_EXCEEDED assert exc_info.value.message == "Request timed out" - assert recorded_timeout_header == "200" + assert recorded_timeout_header == "50" @pytest.mark.asyncio @pytest.mark.parametrize( - ("client_timeout_ms", "call_timeout_ms"), [(200, None), (None, 200)] + ("client_timeout_ms", "call_timeout_ms"), [(50, None), (None, 50)] ) async def test_async_client_timeout( client_timeout_ms, call_timeout_ms, timeout_server: str @@ -416,4 +416,4 @@ async def modify_timeout_header(request: Request) -> None: assert exc_info.value.code == Code.DEADLINE_EXCEEDED assert exc_info.value.message == "Request timed out" - assert recorded_timeout_header == "200" + assert recorded_timeout_header == "50"