Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions retry/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
import random
import time
import traceback
from functools import partial

Expand All @@ -9,7 +9,7 @@
logging_logger = logging.getLogger(__name__)


def __retry_internal(f, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0,
async def __retry_internal(f, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0,
logger=logging_logger, log_traceback=False, on_exception=None):
"""
Executes a function and retries it if it failed.
Expand All @@ -32,7 +32,7 @@ def __retry_internal(f, exceptions=Exception, tries=-1, delay=0, max_delay=None,
_tries, _delay = tries, delay
while _tries:
try:
return f()
return await f()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to conditionally await by calling inspect.isawaitable. Not sure if you can keep the signature without the await, or do an equivalent of an await call using some library function:

  is_async = False
  value = f()
  if inspect.isawaitable(value):
     is_async = True
     value = await value
  return value

is_async could potentially be used in the except block to determine which sleep method to use

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find a way to keep the signature intact. And also went with inspect module to find out whether it's dealing with an async function: my implementation

except exceptions as e:
if on_exception is not None:
if on_exception(e):
Expand All @@ -52,7 +52,7 @@ def __retry_internal(f, exceptions=Exception, tries=-1, delay=0, max_delay=None,
if log_traceback:
logger.warning(traceback.format_exc())

time.sleep(_delay)
await asyncio.sleep(_delay)
_delay *= backoff

if isinstance(jitter, tuple):
Expand Down