Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
21 changes: 14 additions & 7 deletions src/s2python/s2_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
import ssl
from dataclasses import dataclass
from typing import Optional, List, Type, Dict, Callable, Awaitable, Union
from typing import Any, Optional, List, Type, Dict, Callable, Awaitable, Union

import websockets
from websockets.asyncio.client import ClientConnection as WSConnection, connect as ws_connect
Expand Down Expand Up @@ -200,6 +200,7 @@ class S2Connection: # pylint: disable=too-many-instance-attributes
_stop_event: asyncio.Event
_restart_connection_event: asyncio.Event
_verify_certificate: bool
_bearer_token: Optional[str]

def __init__( # pylint: disable=too-many-arguments
self,
Expand All @@ -209,6 +210,7 @@ def __init__( # pylint: disable=too-many-arguments
asset_details: AssetDetails,
reconnect: bool = False,
verify_certificate: bool = True,
bearer_token: Optional[str] = None
) -> None:
self.url = url
self.reconnect = reconnect
Expand All @@ -229,6 +231,7 @@ def __init__( # pylint: disable=too-many-arguments
self._handlers.register_handler(SelectControlType, self.handle_select_control_type_as_rm)
self._handlers.register_handler(Handshake, self.handle_handshake)
self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response_as_rm)
self._bearer_token = bearer_token

def start_as_rm(self) -> None:
self._run_eventloop(self._run_as_rm())
Expand Down Expand Up @@ -323,13 +326,17 @@ async def wait_till_connection_restart() -> None:

async def _connect_ws(self) -> None:
try:
# set up connection arguments for SSL and bearer token, if required
connection_kwargs: Dict[str, Any] = {}
if self.url.startswith("wss://") and not self._verify_certificate:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
self.ws = await ws_connect(uri=self.url, ssl=ssl_context)
else:
self.ws = await ws_connect(uri=self.url)
connection_kwargs['ssl'] = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
connection_kwargs['ssl'].check_hostname = False
connection_kwargs['ssl'].verify_mode = ssl.CERT_NONE

if self._bearer_token:
connection_kwargs['additional_headers'] = {"Authorization": f"Bearer {self._bearer_token}"}

self.ws = await ws_connect(uri=self.url, **connection_kwargs)
except (EOFError, OSError) as e:
logger.info("Could not connect due to: %s", str(e))

Expand Down