diff --git a/examples/example_s2_server.py b/examples/example_s2_server.py index f38d0d4..9146ea8 100644 --- a/examples/example_s2_server.py +++ b/examples/example_s2_server.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta from typing import Any -from s2python.authorization.default_server import S2DefaultServer +from s2python.authorization.default_server import S2DefaultHTTPServer from s2python.generated.gen_s2_pairing import ( S2NodeDescription, Deployment, @@ -77,7 +77,7 @@ def signal_handler(sig: int, frame: Any) -> None: ) # Create and configure the server - server = S2DefaultServer( + server = S2DefaultHTTPServer( host=args.host, http_port=args.http_port, ws_port=args.ws_port, diff --git a/examples/mock_s2_server.py b/examples/mock_s2_server.py index c085b63..4f4caa8 100644 --- a/examples/mock_s2_server.py +++ b/examples/mock_s2_server.py @@ -1,4 +1,3 @@ -import http.server import socketserver import json from typing import Any @@ -7,6 +6,8 @@ import random import string +from s2python.authorization.default_server import S2DefaultHTTPHandler + # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("mock_s2_server") @@ -35,7 +36,7 @@ def generate_token() -> str: HTTP_PORT = 8000 -class MockS2Handler(http.server.BaseHTTPRequestHandler): +class MockS2Handler(S2DefaultHTTPHandler): def do_POST(self) -> None: # pylint: disable=C0103 content_length = int(self.headers.get("Content-Length", 0)) post_data = self.rfile.read(content_length).decode("utf-8") @@ -60,10 +61,6 @@ def do_POST(self) -> None: # pylint: disable=C0103 logger.info('Expected token: %s', PAIRING_TOKEN) if request_token_string == PAIRING_TOKEN: - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - # Create pairing response response = { "s2ServerNodeId": SERVER_NODE_ID, @@ -78,22 +75,13 @@ def do_POST(self) -> None: # pylint: disable=C0103 }, "requestConnectionUri": f"http://localhost:{HTTP_PORT}/requestConnection", } - - self.wfile.write(json.dumps(response).encode()) + self._send_json_response(200, response) logger.info("Pairing request successful") else: - self.send_response(401) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": "Invalid token"}).encode()) + self._send_json_response(401, {"error": "Invalid token"}) logger.error("Invalid pairing token") elif self.path == "/requestConnection": - # Handle connection request - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - # Create challenge (normally would be a JWE) challenge = "mock_challenge_string" @@ -104,21 +92,16 @@ def do_POST(self) -> None: # pylint: disable=C0103 "selectedProtocol": "WebSocketSecure", } - self.wfile.write(json.dumps(response).encode()) + # Handle connection request + self._send_json_response(200, response) logger.info("Connection request successful") else: - self.send_response(404) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": "Endpoint not found"}).encode()) + self._send_json_response(404, {"error": "Endpoint not found"}) logger.error('Unknown endpoint: %s', self.path) except Exception as e: - self.send_response(500) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": str(e)}).encode()) + self._send_json_response(500, {"error": str(e)}) logger.error('Error handling request: %s', e) raise e diff --git a/src/s2python/authorization/default_server.py b/src/s2python/authorization/default_server.py index 4f823cb..d0c0648 100644 --- a/src/s2python/authorization/default_server.py +++ b/src/s2python/authorization/default_server.py @@ -10,7 +10,7 @@ import asyncio import uuid from datetime import datetime, timezone -from typing import Dict, Any, Tuple, Optional +from typing import Dict, Any, Tuple, Optional, Union from jwskate import Jwk, Jwt from jwskate.jwe.compact import JweCompact @@ -30,6 +30,15 @@ logger = logging.getLogger("S2DefaultServer") +class S2DefaultWebSocketHandler(websockets.WebSocketServerProtocol): + """Default WebSocket handler for S2 protocol server.""" + + def __init__(self, *args: Any, server_instance: Any = None, **kwargs: Any) -> None: + """Initialize the handler with server instance.""" + self.server_instance = server_instance + super().__init__(*args, **kwargs) + + class S2DefaultHTTPHandler(http.server.BaseHTTPRequestHandler): """Default HTTP handler for S2 protocol server.""" @@ -53,20 +62,29 @@ def do_POST(self) -> None: # pylint: disable=C0103 elif self.path == "/requestConnection": self._handle_connection_request(request_json) else: - self.send_response(404) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": "Endpoint not found"}).encode()) + self._send_json_response(404, {"error": "Endpoint not found"}) logger.error("Unknown endpoint: %s", self.path) except Exception as e: - self.send_response(500) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": str(e)}).encode()) + self._send_json_response(500, {"error": str(e)}) logger.error("Error handling request: %s", e) raise e + def _send_json_response(self, status_code: int, response_body: Union[dict, str]) -> None: + """ + Helper function to send a JSON response. + :param handler: The HTTP handler instance (self). + :param status_code: HTTP status code. + :param response_body: Dictionary or JSON string containing the response body. + """ + self.send_response(status_code) + self.send_header("Content-Type", "application/json") + self.end_headers() + if isinstance(response_body, str): + self.wfile.write(response_body.encode()) + else: + self.wfile.write(json.dumps(response_body).encode()) + def _handle_pairing_request(self, request_json: Dict[str, Any]) -> None: """Handle a pairing request. @@ -81,17 +99,11 @@ def _handle_pairing_request(self, request_json: Dict[str, Any]) -> None: response = self.server_instance.handle_pairing_request(pairing_request) # Send response - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(response.model_dump_json().encode()) + self._send_json_response(200, response.model_dump_json()) logger.info("Pairing request successful") except ValueError as e: - self.send_response(400) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": str(e)}).encode()) + self._send_json_response(400, {"error": str(e)}) logger.error("Invalid pairing request: %s", e) def _handle_connection_request(self, request_json: Dict[str, Any]) -> None: @@ -108,17 +120,11 @@ def _handle_connection_request(self, request_json: Dict[str, Any]) -> None: response = self.server_instance.handle_connection_request(connection_request) # Send response - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(response.model_dump_json().encode()) + self._send_json_response(200, response.model_dump_json()) logger.info("Connection request successful") except ValueError as e: - self.send_response(400) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"error": str(e)}).encode()) + self._send_json_response(400, {"error": str(e)}) logger.error("Invalid connection request: %s", e) def log_message(self, format: str, *args: Any) -> None: # pylint: disable=W0622 @@ -299,6 +305,9 @@ def _create_encrypted_challenge( # try to decrypt the JWE return str(jwe) + +class S2DefaultHTTPServer(S2DefaultServer): + def start_server(self) -> None: """Start the HTTP server.""" if self.instance == "http": @@ -328,7 +337,7 @@ def stop_server(self) -> None: # self._httpd.shutdown() self._httpd.server_close() self._httpd = None - + def start_ws_server(self) -> None: """Start the WebSocket server.""" self._ws_server = websockets.serve(self._handle_websocket_connection, self.host, self.ws_port) diff --git a/tox.ini b/tox.ini index fbef9e5..bacf5f9 100644 --- a/tox.ini +++ b/tox.ini @@ -42,7 +42,7 @@ changedir = {toxinidir} deps = -r dev-requirements.txt commands = - pylint src/ tests/unit/ + pylint src/ tests/unit/ --fail-under=9.8 [testenv:typecheck] description = Typecheck the source code using mypy.