diff --git a/.gitignore b/.gitignore index 344dde3f..a6b36ced 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ vagrant/.vagrant .vscode/ *.iml .pytest_cache/ +*.so diff --git a/aredis/__init__.py b/aredis/__init__.py index 864aec51..a1fab611 100644 --- a/aredis/__init__.py +++ b/aredis/__init__.py @@ -1,33 +1,70 @@ -from aredis.client import (StrictRedis, StrictRedisCluster) -from aredis.connection import ( - Connection, - UnixDomainSocketConnection, - ClusterConnection -) +from aredis.client import StrictRedis, StrictRedisCluster +from aredis.connection import Connection, UnixDomainSocketConnection, ClusterConnection from aredis.pool import ConnectionPool, ClusterConnectionPool from aredis.exceptions import ( - AuthenticationError, BusyLoadingError, ConnectionError, - DataError, InvalidResponse, PubSubError, ReadOnlyError, - RedisError, ResponseError, TimeoutError, WatchError, - CompressError, ClusterDownException, ClusterCrossSlotError, - CacheError, ClusterDownError, ClusterError, RedisClusterException, - RedisClusterError, ExecAbortError, LockError, NoScriptError + AuthenticationFailureError, + AuthenticationRequiredError, + NoPermissionError, + BusyLoadingError, + ConnectionError, + DataError, + InvalidResponse, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, + CompressError, + ClusterDownException, + ClusterCrossSlotError, + CacheError, + ClusterDownError, + ClusterError, + RedisClusterException, + RedisClusterError, + ExecAbortError, + LockError, + NoScriptError, ) -__version__ = '1.1.8' +__version__ = "1.1.8" + -VERSION = tuple(map(int, __version__.split('.'))) +VERSION = tuple(map(int, __version__.split("."))) __all__ = [ - 'StrictRedis', 'StrictRedisCluster', - 'Connection', 'UnixDomainSocketConnection', 'ClusterConnection', - 'ConnectionPool', 'ClusterConnectionPool', - 'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError', - 'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError', - 'ResponseError', 'TimeoutError', 'WatchError', - 'CompressError', 'ClusterDownException', 'ClusterCrossSlotError', - 'CacheError', 'ClusterDownError', 'ClusterError', 'RedisClusterException', - 'RedisClusterError', 'ExecAbortError', 'LockError', 'NoScriptError' + "StrictRedis", + "StrictRedisCluster", + "Connection", + "UnixDomainSocketConnection", + "ClusterConnection", + "ConnectionPool", + "ClusterConnectionPool", + "AuthenticationFailureError", + "AuthenticationRequiredError", + "NoPermissionError", + "BusyLoadingError", + "ConnectionError", + "DataError", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", + "RedisError", + "ResponseError", + "TimeoutError", + "WatchError", + "CompressError", + "ClusterDownException", + "ClusterCrossSlotError", + "CacheError", + "ClusterDownError", + "ClusterError", + "RedisClusterException", + "RedisClusterError", + "ExecAbortError", + "LockError", + "NoScriptError", ] diff --git a/aredis/client.py b/aredis/client.py index 91973db5..017be733 100644 --- a/aredis/client.py +++ b/aredis/client.py @@ -92,7 +92,7 @@ def from_url(cls, url, db=None, **kwargs): return cls(connection_pool=connection_pool) def __init__(self, host='localhost', port=6379, - db=0, password=None, stream_timeout=None, + db=0, username=None, password=None, stream_timeout=None, connect_timeout=None, connection_pool=None, unix_socket_path=None, encoding='utf-8', decode_responses=False, ssl=False, ssl_context=None, @@ -100,10 +100,12 @@ def __init__(self, host='localhost', port=6379, ssl_cert_reqs=None, ssl_ca_certs=None, max_connections=None, retry_on_timeout=False, max_idle_time=0, idle_check_interval=1, + client_name=None, loop=None, **kwargs): if not connection_pool: kwargs = { 'db': db, + 'username': username, 'password': password, 'encoding': encoding, 'stream_timeout': stream_timeout, @@ -113,6 +115,7 @@ def __init__(self, host='localhost', port=6379, 'decode_responses': decode_responses, 'max_idle_time': max_idle_time, 'idle_check_interval': idle_check_interval, + 'client_name': client_name, 'loop': loop } # based on input, setup appropriate connection args diff --git a/aredis/connection.py b/aredis/connection.py index 6dc79050..efda4bbb 100755 --- a/aredis/connection.py +++ b/aredis/connection.py @@ -10,7 +10,8 @@ from io import BytesIO import aredis.compat -from aredis.exceptions import (ConnectionError, TimeoutError, +from aredis.exceptions import (AuthenticationFailureError, AuthenticationRequiredError, + NoPermissionError, ConnectionError, TimeoutError, RedisError, ExecAbortError, BusyLoadingError, NoScriptError, ReadOnlyError, ResponseError, @@ -154,6 +155,9 @@ class BaseParser: 'MOVED': MovedError, 'CLUSTERDOWN': ClusterDownError, 'CROSSSLOT': ClusterCrossSlotError, + 'WRONGPASS': AuthenticationFailureError, + 'NOAUTH': AuthenticationRequiredError, + 'NOPERM': NoPermissionError, } def parse_error(self, response): @@ -367,11 +371,12 @@ class BaseConnection: def __init__(self, retry_on_timeout=False, stream_timeout=None, parser_class=DefaultParser, reader_read_size=65535, encoding='utf-8', decode_responses=False, - *, loop=None): + *, client_name=None, loop=None): self._parser = parser_class(reader_read_size) self._stream_timeout = stream_timeout self._reader = None self._writer = None + self.username = '' self.password = '' self.db = '' self.pid = os.getpid() @@ -381,6 +386,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None, self.encoding = encoding self.decode_responses = decode_responses self.loop = loop + self.client_name = client_name # flag to show if a connection is waiting for response self.awaiting_response = False self.last_active_at = time.time() @@ -433,17 +439,27 @@ async def _connect(self): async def on_connect(self): self._parser.on_connect(self) + # if a username and a password is specified, authenticate + if self.username and self.password: + await self.send_command('AUTH', self.username, self.password) + if nativestr(await self.read_response()) != 'OK': + raise ConnectionError('Failed to set username or password') # if a password is specified, authenticate - if self.password: + elif self.password: await self.send_command('AUTH', self.password) if nativestr(await self.read_response()) != 'OK': - raise ConnectionError('Invalid Password') + raise ConnectionError('Failed to set password') # if a database is specified, switch to it if self.db: await self.send_command('SELECT', self.db) if nativestr(await self.read_response()) != 'OK': raise ConnectionError('Invalid Database') + + if self.client_name is not None: + await self.send_command('CLIENT SETNAME', self.client_name) + if nativestr(await self.read_response()) != 'OK': + raise ConnectionError('Failed to set client name: {}'.format(self.client_name)) self.last_active_at = time.time() async def read_response(self): @@ -569,17 +585,18 @@ def pack_commands(self, commands): class Connection(BaseConnection): description = 'Connection' - def __init__(self, host='127.0.0.1', port=6379, password=None, + def __init__(self, host='127.0.0.1', port=6379, username=None, password=None, db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None, ssl_context=None, parser_class=DefaultParser, reader_read_size=65535, encoding='utf-8', decode_responses=False, socket_keepalive=None, - socket_keepalive_options=None, *, loop=None): + socket_keepalive_options=None, *, client_name=None, loop=None): super(Connection, self).__init__(retry_on_timeout, stream_timeout, parser_class, reader_read_size, encoding, decode_responses, - loop=loop) + client_name=client_name, loop=loop) self.host = host self.port = port + self.username = username self.password = password self.db = db self.ssl_context = ssl_context @@ -623,16 +640,17 @@ async def _connect(self): class UnixDomainSocketConnection(BaseConnection): description = "UnixDomainSocketConnection" - def __init__(self, path='', password=None, + def __init__(self, path='', username=None, password=None, db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None, ssl_context=None, parser_class=DefaultParser, reader_read_size=65535, - encoding='utf-8', decode_responses=False, *, loop=None): + encoding='utf-8', decode_responses=False, *, client_name=None, loop=None): super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout, parser_class, reader_read_size, encoding, decode_responses, - loop=loop) + client_name=client_name, loop=loop) self.path = path self.db = db + self.username = username self.password = password self.ssl_context = ssl_context self._connect_timeout = connect_timeout diff --git a/aredis/exceptions.py b/aredis/exceptions.py index 775efe68..484f76ee 100644 --- a/aredis/exceptions.py +++ b/aredis/exceptions.py @@ -2,7 +2,15 @@ class RedisError(Exception): pass -class AuthenticationError(RedisError): +class AuthenticationFailureError(RedisError): + pass + + +class AuthenticationRequiredError(RedisError): + pass + + +class NoPermissionError(RedisError): pass diff --git a/aredis/pool.py b/aredis/pool.py index 1fbb2cc7..ef1be1cc 100644 --- a/aredis/pool.py +++ b/aredis/pool.py @@ -103,10 +103,12 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): url_options[name] = value[0] if decode_components: + username = unquote(url.username) if url.username else None password = unquote(url.password) if url.password else None path = unquote(url.path) if url.path else None hostname = unquote(url.hostname) if url.hostname else None else: + username = url.username password = url.password path = url.path hostname = url.hostname @@ -114,6 +116,7 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): # We only support redis:// and unix:// schemes. if url.scheme == 'unix': url_options.update({ + 'username': username, 'password': password, 'path': path, 'connection_class': UnixDomainSocketConnection, @@ -123,6 +126,7 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): url_options.update({ 'host': hostname, 'port': int(url.port or 6379), + 'username': username, 'password': password, }) diff --git a/tests/client/conftest.py b/tests/client/conftest.py index afc4d3ae..e58aff6a 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -37,7 +37,7 @@ def skip_python_vsersion_lt(min_version): @pytest.fixture() def r(event_loop): - return aredis.StrictRedis(loop=event_loop) + return aredis.StrictRedis(client_name='test', loop=event_loop) class AsyncMock(Mock): diff --git a/tests/client/test_commands.py b/tests/client/test_commands.py index 6c47c53d..0097461e 100644 --- a/tests/client/test_commands.py +++ b/tests/client/test_commands.py @@ -64,7 +64,7 @@ async def test_client_list_after_client_setname(self, r): @skip_if_server_version_lt('2.6.9') @pytest.mark.asyncio(forbid_global_loop=True) async def test_client_getname(self, r): - assert await r.client_getname() is None + assert await r.client_getname() == 'test' @skip_if_server_version_lt('2.6.9') @pytest.mark.asyncio(forbid_global_loop=True)