diff --git a/aredis/pipeline.py b/aredis/pipeline.py index 44f14cfd..77e35c7c 100644 --- a/aredis/pipeline.py +++ b/aredis/pipeline.py @@ -45,7 +45,7 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.reset() + await self.execute() # also resets def __len__(self): return len(self.command_stack) diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst index 44614ed8..869c4471 100644 --- a/docs/source/pipelines.rst +++ b/docs/source/pipelines.rst @@ -14,7 +14,7 @@ Pipelines are quite simple to use: async with await client.pipeline() as pipe: await pipe.delete('bar') await pipe.set('bar', 'foo') - await pipe.execute() # needs to be called explicitly + # pipe.execute() is called when `with` block is exited Here are more examples: @@ -30,7 +30,7 @@ Here are more examples: await pipe.set('bar', 'foo') # commands will be buffered await pipe.keys('*') - res = await pipe.execute() + res = await pipe.execute() # call explicitly to retrieve results # results should be in order corresponding to your command assert res == [True, True, True, [b'bar', b'foo']] @@ -104,7 +104,7 @@ explicitly calling reset(): ... try: ... await pipe.watch('OUR-SEQUENCE-KEY') ... ... - ... await pipe.execute() + ... await pipe.execute() # trigger any WatchError early ... break ... except WatchError: ... continue diff --git a/tests/client/test_pipeline.py b/tests/client/test_pipeline.py index 1e106c13..be89fd28 100644 --- a/tests/client/test_pipeline.py +++ b/tests/client/test_pipeline.py @@ -48,6 +48,20 @@ async def test_pipeline_length(self, r): assert len(pipe) == 0 assert not pipe + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_pipeline_autoexecute(self, r): + await r.flushdb() + async with await r.pipeline() as pipe: + # Fill 'er up! + await pipe.set('d', 'd1') + await pipe.set('e', 'e1') + await pipe.set('f', 'f1') + assert len(pipe) == 3 + assert pipe + + # exiting with block calls execute() and reset(), so empty once again + assert len(pipe) == 0 + @pytest.mark.asyncio(forbid_global_loop=True) async def test_pipeline_no_transaction(self, r): await r.flushdb()