diff --git a/bunny_storm/async_connection.py b/bunny_storm/async_connection.py index 808c58e..3912b2d 100644 --- a/bunny_storm/async_connection.py +++ b/bunny_storm/async_connection.py @@ -105,10 +105,12 @@ async def _connect(self) -> RobustConnection: connection = await connect_robust(url=uri, loop=self._loop, timeout=self._timeout, - client_properties=self._properties) + client_properties=self._properties, + ssl=self._rabbitmq_connection_data.scheme == "amqps", + ssl_options=self._rabbitmq_connection_data.ssl_options) return connection except (asyncio.TimeoutError, ConnectionError): - self.logger.error(f"Connection attempt {attempt_num} / {self._connection_attempts} failed") + self.logger.error(f"Connection attempt {attempt_num} / {self._connection_attempts} failed", exc_info=1) if attempt_num < self._connection_attempts: self.logger.debug(f"Going to sleep for {self._attempt_backoff} seconds") await asyncio.sleep(self._attempt_backoff) diff --git a/bunny_storm/rabbitmq_connection_data.py b/bunny_storm/rabbitmq_connection_data.py index e5b6cc8..0e13399 100644 --- a/bunny_storm/rabbitmq_connection_data.py +++ b/bunny_storm/rabbitmq_connection_data.py @@ -14,6 +14,8 @@ class RabbitMQConnectionData: port: int = 5672 virtual_host: str = "/" connection_name: str = "" + scheme: str = "amqp" + ssl_options: dict = None def uri(self) -> str: """ @@ -21,5 +23,15 @@ def uri(self) -> str: :return: Connection URI """ vhost = "" if self.virtual_host == "/" else self.virtual_host - name_query = f"?name={self.connection_name}" if self.connection_name else "" - return f"amqp://{self.username}:{self.password}@{self.host}:{self.port}/{vhost}{name_query}" + + query = "" + query_list = [] + if self.connection_name: + query_list.append(f"name={self.connection_name}") + if self.ssl_options: + for option, value in self.ssl_options.items(): + query_list.append(f"{option}={value}") + if query_list: + query = "?" + "&".join(query_list) + + return f"{self.scheme}://{self.username}:{self.password}@{self.host}:{self.port}/{vhost}{query}" diff --git a/requirements.txt b/requirements.txt index 5a41365..04e5214 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,9 +8,9 @@ coverage>=4.5.4 Sphinx>=1.8.5 twine>=1.14.0 -pytest>=4.6.5 -pytest-asyncio>=0.15.1 -pytest-runner>=5.1 +pytest>=4.6.5,<6 +pytest-asyncio==0.15.1 +pytest-runner>=5.1,<6 aiohttp>=3.7.4.post0 setuptools>=57.2.0 diff --git a/tests/conftest.py b/tests/conftest.py index cef32c5..efd83f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,11 @@ def rabbitmq_port() -> int: return int(os.getenv("RABBITMQ_PORT", "5672")) +@pytest.fixture(scope="session") +def rabbitmq_ssl_port() -> int: + return int(os.getenv("RABBITMQ_SSL_PORT", "5671")) + + @pytest.fixture(scope="session") def rabbitmq_virtual_host() -> str: return os.getenv("RABBITMQ_VIRTUAL_HOST", "vhost") @@ -49,6 +54,19 @@ def rabbitmq_connection_data(rabbitmq_user: str, rabbitmq_password: str, rabbitm connection_name="test_runner") return connection_data +@pytest.fixture(scope="function") +def rabbitmq_ssl_connection_data(rabbitmq_user: str, rabbitmq_password: str, rabbitmq_host: str, rabbitmq_ssl_port: int, + rabbitmq_virtual_host: str) -> RabbitMQConnectionData: + connection_data = RabbitMQConnectionData(username=rabbitmq_user, + password=rabbitmq_password, + host=rabbitmq_host, + port=rabbitmq_ssl_port, + scheme="amqps", + ssl_options={"no_verify_ssl": "1"}, + virtual_host=rabbitmq_virtual_host, + connection_name="test_runner_ssl") + return connection_data + @pytest.fixture(scope="session") def configuration() -> dict: diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 58ecce4..8711689 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -20,6 +20,18 @@ async def test_async_connection_get_connection(event_loop: asyncio.AbstractEvent # Assert assert isinstance(connection, RobustConnection) and async_connection.is_connected() +@pytest.mark.asyncio +async def test_async_connection_get_ssl_connection(event_loop: asyncio.AbstractEventLoop, + logger: logging.Logger, + rabbitmq_ssl_connection_data: RabbitMQConnectionData) -> None: + # Arrange + async_connection = AsyncConnection(rabbitmq_ssl_connection_data, logger, event_loop) + + # Act + connection = await async_connection.get_connection() + + # Assert + assert isinstance(connection, RobustConnection) and async_connection.is_connected() @pytest.mark.asyncio async def test_async_connection_connection_failure(event_loop: asyncio.AbstractEventLoop,