diff --git a/src/ghoshell_moss/core/concepts/channel.py b/src/ghoshell_moss/core/concepts/channel.py index 8f7180f..50a0c07 100644 --- a/src/ghoshell_moss/core/concepts/channel.py +++ b/src/ghoshell_moss/core/concepts/channel.py @@ -642,12 +642,9 @@ async def recursive_start(_chan: Channel) -> None: async def recursive_close(_chan: Channel) -> None: children = _chan.children() - if len(children) == 0: - return group_stop = [] for child in children.values(): - if not child.is_running(): - group_stop.append(recursive_close(child)) + group_stop.append(recursive_close(child)) await asyncio.gather(*group_stop) if _chan.is_running(): await _chan.broker.close() diff --git a/tests/channels/test_thread_channel.py b/tests/channels/test_thread_channel.py index 050f821..33fe1dd 100644 --- a/tests/channels/test_thread_channel.py +++ b/tests/channels/test_thread_channel.py @@ -17,6 +17,22 @@ async def test_thread_channel_start_and_close(): assert not provider.is_running() +@pytest.mark.asyncio +async def test_channel_run_in_ctx_closes_child_channels(): + root = PyChannel(name="root") + child = root.new_child("child") + grandchild = child.new_child("grandchild") + + async with root.run_in_ctx(): + assert root.is_running() + assert child.is_running() + assert grandchild.is_running() + + assert not root.is_running() + assert not child.is_running() + assert not grandchild.is_running() + + @pytest.mark.asyncio async def test_thread_channel_raise_in_proxy(): provider, proxy = create_thread_channel("client") @@ -33,13 +49,25 @@ async def test_thread_channel_run_in_thread(): provider, proxy = create_thread_channel("client") chan = PyChannel(name="provider") provider.run_in_thread(chan) - await provider.aclose() await provider.wait_closed() assert not chan.is_running() assert not provider.is_running() +@pytest.mark.asyncio +async def test_thread_channel_run_in_ctx(): + provider, proxy = create_thread_channel("client") + chan = PyChannel(name="provider") + async with provider.run_in_ctx(chan): + async with proxy.run_in_ctx(): + await proxy.broker.wait_connected() + assert chan.is_running() + assert provider.is_running() + assert not chan.is_running() + assert not provider.is_running() + + @pytest.mark.asyncio async def test_thread_channel_run_in_tasks(): provider, proxy = create_thread_channel("client")