From 174ebb8d6dc51f77ac270f6a0998bfe98ca3e659 Mon Sep 17 00:00:00 2001 From: tom-pilotschool Date: Sat, 7 Jun 2025 17:03:13 +1000 Subject: [PATCH 1/3] Updated tests for websockets event loop --- tests/test_websocket_event_loop.py | 811 +++++++++++++++++++++++++++++ 1 file changed, 811 insertions(+) create mode 100644 tests/test_websocket_event_loop.py diff --git a/tests/test_websocket_event_loop.py b/tests/test_websocket_event_loop.py new file mode 100644 index 0000000..02772cc --- /dev/null +++ b/tests/test_websocket_event_loop.py @@ -0,0 +1,811 @@ +#!/usr/bin/env python3 +""" +WebSocket Event Loop Stability Test Suite + +This test suite is designed to validate that the Daebus WebSocket implementation +can handle various stress conditions without encountering "no running event loop" errors. + +The tests cover: +1. Rapid connection/disconnection cycles +2. Concurrent message handling +3. Server shutdown during active connections +4. Rate limiting stress testing +5. Broadcast queue stress testing +6. Thread safety of WebSocket operations +""" + +import pytest +import asyncio +import threading +import time +import json +import websockets +import socket +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from unittest.mock import patch, MagicMock +import logging + +# Configure logging to capture errors +logging.basicConfig(level=logging.DEBUG) + +# Import the modules we need to test +from daebus import Daebus, DaebusHttp, DaebusWebSocket + + +def find_free_port(): + """Find and return a free port number by opening a temporary socket.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + + +class WebSocketEventLoopTests: + """Test suite for WebSocket event loop stability and error handling.""" + + def setup_method(self): + """Set up test environment before each test.""" + # Use a random free port to avoid conflicts + self.test_port = find_free_port() + + # Create the app with mocked Redis + with patch('daebus.modules.daemon.Redis'), \ + patch('daebus.modules.daemon.BackgroundScheduler'): + + self.app = Daebus("test_event_loop_app") + self.http = DaebusHttp(port=self.test_port) + self.ws = DaebusWebSocket() + + self.app.attach(self.http) + self.app.attach(self.ws) + + # Storage for test results and errors + self.test_errors = [] + self.connection_events = [] + self.message_events = [] + self.server_ready_event = threading.Event() + self.stop_event = threading.Event() + + # Set up event handlers + @self.ws.socket_connect() + def on_connect(data, client_id): + self.connection_events.append(('connect', client_id, time.time())) + return {'status': 'connected', 'client_id': client_id} + + @self.ws.socket_disconnect() + def on_disconnect(data, client_id): + self.connection_events.append(('disconnect', client_id, time.time())) + + @self.ws.socket('test_message') + def handle_test_message(data, client_id): + self.message_events.append(('message', client_id, data, time.time())) + return {'echo': data, 'client_id': client_id} + + @self.ws.socket('error_message') + def handle_error_message(data, client_id): + # Deliberately cause an error to test error handling + raise Exception("Test error in message handler") + + @self.ws.socket('async_message') + async def handle_async_message(data, client_id): + # Test async message handling + await asyncio.sleep(0.1) + return {'async_response': data, 'client_id': client_id} + + @self.ws.socket('broadcast_trigger') + def handle_broadcast_trigger(data, client_id): + # Test broadcasting from message handler + try: + self.ws.broadcast_to_all({'broadcast_data': data}, 'broadcast') + return {'broadcast_sent': True} + except Exception as e: + self.test_errors.append(f"Broadcast error: {e}") + raise + + def teardown_method(self): + """Clean up after each test.""" + try: + # Signal server to stop + self.stop_event.set() + + # Stop the app if it's running + if hasattr(self.app, '_running') and self.app._running: + self.app.stop() + + # Stop WebSocket server if running + if hasattr(self.ws, 'is_running') and self.ws.is_running: + self.ws.is_running = False + + # Give some time for cleanup + time.sleep(0.5) + except Exception as e: + print(f"Cleanup error: {e}") + + def get_websocket_url(self): + """Get the WebSocket URL for the test server.""" + return f"ws://127.0.0.1:{self.test_port}" + + async def create_websocket_client(self, timeout=5): + """Create an async WebSocket client connection.""" + url = self.get_websocket_url() + + try: + # Create WebSocket connection with proper timeout + websocket = await websockets.connect( + url, + close_timeout=timeout, + ping_interval=None, # Disable ping/pong + ping_timeout=None + ) + return websocket + except Exception as e: + print(f"Failed to create WebSocket connection: {e}") + return None + + async def send_message(self, websocket, message_type, data): + """Send a message through the WebSocket.""" + message = { + 'type': message_type, + 'data': data + } + await websocket.send(json.dumps(message)) + + async def receive_message(self, websocket, timeout=1): + """Receive a message from the WebSocket.""" + try: + message = await asyncio.wait_for(websocket.recv(), timeout=timeout) + return json.loads(message) + except asyncio.TimeoutError: + return None + except Exception as e: + print(f"Error receiving message: {e}") + return None + + def start_server_in_thread(self): + """Start the WebSocket server in a separate thread.""" + def run_server(): + # Start HTTP server + self.app.http.start() + + # Start WebSocket server directly + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def handler_wrapper(websocket): + # Call with a mock path parameter + await self.ws._handle_connection(websocket, "/") + + async def setup_ws_server(): + try: + ws_server = await websockets.serve( + handler_wrapper, + "127.0.0.1", + self.test_port + ) + self.ws.is_running = True + self.ws.server = ws_server + self.server_ready_event.set() + return ws_server + except Exception as e: + print(f"Error starting WebSocket server: {e}") + self.server_ready_event.set() + raise + + try: + ws_server = loop.run_until_complete(setup_ws_server()) + + # Wait for stop signal + while not self.stop_event.is_set(): + loop.run_until_complete(asyncio.sleep(0.1)) + + except Exception as e: + print(f"Server error: {e}") + finally: + # Cleanup + if 'ws_server' in locals(): + ws_server.close() + loop.run_until_complete(ws_server.wait_closed()) + self.ws.is_running = False + loop.close() + self.app.http.stop() + + # Start server thread + self.server_thread = threading.Thread(target=run_server) + self.server_thread.daemon = True + self.server_thread.start() + + # Wait for server to be ready + if not self.server_ready_event.wait(10.0): + raise RuntimeError("Timed out waiting for server to start") + + @pytest.mark.asyncio + async def test_rapid_connection_cycles(self): + """Test rapid connect/disconnect cycles that might stress the event loop.""" + print("Testing rapid connection cycles...") + + # Start the server + self.start_server_in_thread() + + async def connection_worker(worker_id): + """Worker function for rapid connections.""" + errors = [] + connections = 0 + + for i in range(10): # 10 rapid connections per worker + try: + websocket = await self.create_websocket_client(timeout=2) + if websocket: + connections += 1 + + # Send a test message + await self.send_message(websocket, 'test_message', { + 'worker': worker_id, + 'iteration': i + }) + + # Try to receive response + response = await self.receive_message(websocket, timeout=2) + if response is None: + errors.append(f"Worker {worker_id}, iteration {i}: No response received") + + await websocket.close() + + # Small delay between connections + await asyncio.sleep(0.05) + else: + errors.append(f"Worker {worker_id}, iteration {i}: Failed to connect") + + except Exception as e: + errors.append(f"Worker {worker_id}, iteration {i}: {e}") + + return {'worker_id': worker_id, 'errors': errors, 'connections': connections} + + # Run multiple workers simultaneously + tasks = [connection_worker(i) for i in range(5)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Analyze results + total_errors = [] + for result in results: + if isinstance(result, Exception): + total_errors.append(f"Task exception: {result}") + else: + total_errors.extend(result['errors']) + + print(f"Rapid connection test completed. Total errors: {len(total_errors)}") + if total_errors: + print("Errors encountered:") + for error in total_errors[:10]: # Show first 10 errors + print(f" - {error}") + + # Assert no critical event loop errors + event_loop_errors = [e for e in total_errors if 'event loop' in str(e).lower()] + assert len(event_loop_errors) == 0, f"Event loop errors detected: {event_loop_errors}" + + @pytest.mark.asyncio + async def test_concurrent_message_handling(self): + """Test concurrent message handling that might cause event loop issues.""" + print("Testing concurrent message handling...") + + # Start the server + self.start_server_in_thread() + + # Create multiple persistent connections + clients = [] + for i in range(3): + websocket = await self.create_websocket_client() + if websocket: + clients.append(websocket) + + assert len(clients) > 0, "Failed to create any client connections" + + async def message_worker(websocket, worker_id): + """Send messages concurrently from a client.""" + errors = [] + responses = [] + + try: + for i in range(10): # 10 messages per client + # Mix different message types + if i % 3 == 0: + await self.send_message(websocket, 'test_message', { + 'worker': worker_id, 'msg': i + }) + elif i % 3 == 1: + await self.send_message(websocket, 'async_message', { + 'worker': worker_id, 'msg': i + }) + else: + await self.send_message(websocket, 'broadcast_trigger', { + 'worker': worker_id, 'msg': i + }) + + # Try to receive response + response = await self.receive_message(websocket, timeout=1) + if response: + responses.append(response) + + await asyncio.sleep(0.01) # Small delay between messages + + except Exception as e: + errors.append(f"Worker {worker_id}: {e}") + + return {'worker_id': worker_id, 'errors': errors, 'responses': len(responses)} + + # Send messages concurrently from all clients + tasks = [message_worker(client, i) for i, client in enumerate(clients)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Clean up clients + for client in clients: + await client.close() + + # Analyze results + total_errors = [] + for result in results: + if isinstance(result, Exception): + total_errors.append(f"Task exception: {result}") + else: + total_errors.extend(result['errors']) + + print(f"Concurrent message test completed. Total errors: {len(total_errors)}") + if total_errors: + print("Errors encountered:") + for error in total_errors[:10]: + print(f" - {error}") + + # Check for event loop specific errors + event_loop_errors = [e for e in total_errors if 'event loop' in str(e).lower()] + assert len(event_loop_errors) == 0, f"Event loop errors detected: {event_loop_errors}" + + @pytest.mark.asyncio + async def test_server_shutdown_during_activity(self): + """Test server shutdown while clients are active - this often triggers event loop issues.""" + print("Testing server shutdown during activity...") + + # Start the server + self.start_server_in_thread() + + # Create active connections + clients = [] + for i in range(2): + websocket = await self.create_websocket_client() + if websocket: + clients.append(websocket) + + async def keep_sending_messages(websocket, client_id): + """Keep sending messages until the server shuts down.""" + message_count = 0 + errors = [] + + try: + while message_count < 50: # Try to send many messages + try: + await self.send_message(websocket, 'test_message', { + 'client': client_id, + 'msg_num': message_count, + 'timestamp': time.time() + }) + + # Try to receive response (but don't wait long) + await self.receive_message(websocket, timeout=0.1) + + message_count += 1 + await asyncio.sleep(0.05) # 50ms between messages + + except Exception as e: + errors.append(f"Client {client_id}, message {message_count}: {e}") + break # Stop on error + + except Exception as e: + errors.append(f"Client {client_id} general error: {e}") + + return {'client_id': client_id, 'messages_sent': message_count, 'errors': errors} + + # Start sending messages from multiple clients + tasks = [keep_sending_messages(client, i) for i, client in enumerate(clients)] + + # Let them run for a bit + await asyncio.sleep(1) + + # Now shut down the server while messages are being sent + print("Shutting down server during active messaging...") + try: + self.stop_event.set() + if hasattr(self.ws, 'is_running'): + self.ws.is_running = False + except Exception as e: + self.test_errors.append(f"Server shutdown error: {e}") + + # Wait for all message sending to complete or timeout + try: + results = await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=5) + except asyncio.TimeoutError: + print("Message sending tasks timed out during shutdown") + results = [] + + # Clean up clients + for client in clients: + try: + await client.close() + except: + pass + + # Analyze results + total_errors = [] + total_messages = 0 + for result in results: + if isinstance(result, Exception): + total_errors.append(f"Task exception: {result}") + else: + total_errors.extend(result['errors']) + total_messages += result['messages_sent'] + + print(f"Shutdown test completed. Messages sent: {total_messages}, Errors: {len(total_errors)}") + + # Look for event loop errors specifically + event_loop_errors = [e for e in total_errors + self.test_errors if 'event loop' in str(e).lower()] + + if event_loop_errors: + print("Event loop errors found:") + for error in event_loop_errors: + print(f" - {error}") + + # This is the key assertion - no event loop errors should occur + assert len(event_loop_errors) == 0, f"Event loop errors detected during shutdown: {event_loop_errors}" + + @pytest.mark.asyncio + async def test_thread_safety_operations(self): + """Test thread safety by accessing WebSocket methods from different contexts.""" + print("Testing thread safety operations...") + + # Start the server + self.start_server_in_thread() + + # Create a client + websocket = await self.create_websocket_client() + assert websocket is not None, "Failed to connect client" + + # Wait for connection to be established + await asyncio.sleep(0.2) + + def thread_worker(thread_id): + """Worker that tries to use WebSocket methods from different threads.""" + errors = [] + + try: + # Try different WebSocket operations from this thread + + # 1. Check client count + count = self.ws.get_client_count() + + # 2. Get client list + clients = self.ws.get_clients() + + # 3. Try broadcasting (this is most likely to trigger event loop issues) + try: + sent_count = self.ws.broadcast_to_all({ + 'from_thread': thread_id, + 'broadcast_time': time.time() + }, 'thread_broadcast') + except Exception as e: + # Capture any event loop related errors + errors.append(f"Thread {thread_id} broadcast error: {e}") + + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + return {'thread_id': thread_id, 'errors': errors} + + # Run operations from multiple threads simultaneously + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(thread_worker, i) for i in range(5)] + results = [future.result() for future in as_completed(futures)] + + await websocket.close() + + # Analyze results + total_errors = [] + for result in results: + total_errors.extend(result['errors']) + + print(f"Thread safety test completed. Total errors: {len(total_errors)}") + + if total_errors: + print("Errors encountered:") + for error in total_errors[:10]: + print(f" - {error}") + + # Check specifically for event loop errors + event_loop_errors = [e for e in total_errors if 'event loop' in str(e).lower()] + assert len(event_loop_errors) == 0, f"Event loop thread safety errors: {event_loop_errors}" + + @pytest.mark.asyncio + async def test_broadcast_from_message_handlers(self): + """Test broadcasting from within message handlers - a common source of event loop issues.""" + print("Testing broadcasts triggered from message handlers...") + + # Start the server + self.start_server_in_thread() + + # Create multiple clients to test broadcasting + clients = [] + for i in range(3): + websocket = await self.create_websocket_client() + if websocket: + clients.append(websocket) + + assert len(clients) > 0, "Failed to create client connections" + + # Send broadcast trigger messages rapidly from multiple clients + async def trigger_broadcasts(websocket, client_id): + """Trigger broadcasts from message handlers.""" + errors = [] + + try: + for i in range(5): # Multiple broadcast triggers + await self.send_message(websocket, 'broadcast_trigger', { + 'client': client_id, + 'trigger_num': i, + 'timestamp': time.time() + }) + + # Don't wait for response to simulate rapid triggers + await asyncio.sleep(0.01) + + except Exception as e: + errors.append(f"Client {client_id}: {e}") + + return {'client_id': client_id, 'errors': errors} + + # Trigger broadcasts from all clients simultaneously + tasks = [trigger_broadcasts(client, i) for i, client in enumerate(clients)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Let broadcasts propagate + await asyncio.sleep(0.5) + + # Clean up clients + for client in clients: + await client.close() + + # Analyze results + total_errors = [] + for result in results: + if isinstance(result, Exception): + total_errors.append(f"Task exception: {result}") + else: + total_errors.extend(result['errors']) + + print(f"Broadcast handler test completed. Total errors: {len(total_errors)}") + + if total_errors: + print("Errors encountered:") + for error in total_errors[:10]: + print(f" - {error}") + + # Check for event loop errors + event_loop_errors = [e for e in total_errors + self.test_errors if 'event loop' in str(e).lower()] + assert len(event_loop_errors) == 0, f"Event loop errors in broadcast handlers: {event_loop_errors}" + + @pytest.mark.asyncio + async def test_event_loop_edge_cases(self): + """Test edge cases that could trigger 'no running event loop' errors.""" + print("Testing event loop edge cases...") + + # Start the server + self.start_server_in_thread() + + # Create a client + websocket = await self.create_websocket_client() + assert websocket is not None, "Failed to connect client" + + # Wait for connection + await asyncio.sleep(0.1) + + # Test 1: Rapid send operations from different thread contexts + def rapid_operations(): + """Perform rapid WebSocket operations that might hit edge cases.""" + errors = [] + + for i in range(10): + try: + # These operations use the event loop detection pattern that could fail + count = self.ws.get_client_count() + clients = self.ws.get_clients() + + # Try sending to client (if any exist) + if clients: + client_id = clients[0] + success = self.ws.send_to_client(client_id, { + 'edge_case_test': i, + 'timestamp': time.time() + }) + + # Try broadcasting + self.ws.broadcast_to_all({ + 'edge_case_broadcast': i, + 'timestamp': time.time() + }) + + except Exception as e: + errors.append(f"Operation {i}: {e}") + + return errors + + # Run rapid operations in a separate thread + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(rapid_operations) + + # While operations are running, also send messages from the async context + for i in range(5): + await self.send_message(websocket, 'test_message', { + 'async_test': i, + 'timestamp': time.time() + }) + await asyncio.sleep(0.02) + + # Get results from thread operations + thread_errors = future.result() + + await websocket.close() + + print(f"Edge case test completed. Thread errors: {len(thread_errors)}") + + if thread_errors: + print("Thread operation errors:") + for error in thread_errors[:5]: + print(f" - {error}") + + # Check for event loop errors specifically + event_loop_errors = [e for e in thread_errors if 'event loop' in str(e).lower()] + assert len(event_loop_errors) == 0, f"Event loop edge case errors: {event_loop_errors}" + + +@pytest.mark.skipif( + "os.environ.get('CI', 'false').lower() == 'true'", + reason="WebSocket event loop tests are unstable in CI environments" +) +class TestWebSocketEventLoopStability: + """Test class that can be run with pytest to validate WebSocket event loop stability.""" + + def test_rapid_connections(self): + """Run the rapid connection cycles test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_rapid_connection_cycles()) + finally: + test_suite.teardown_method() + + def test_concurrent_messaging(self): + """Run the concurrent message handling test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_concurrent_message_handling()) + finally: + test_suite.teardown_method() + + def test_shutdown_stress(self): + """Run the server shutdown during activity test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_server_shutdown_during_activity()) + finally: + test_suite.teardown_method() + + def test_thread_safety(self): + """Run the thread safety operations test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_thread_safety_operations()) + finally: + test_suite.teardown_method() + + def test_broadcast_from_message_handlers(self): + """Run the broadcast from message handlers test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_broadcast_from_message_handlers()) + finally: + test_suite.teardown_method() + + def test_event_loop_edge_cases(self): + """Run the event loop edge cases test.""" + test_suite = WebSocketEventLoopTests() + test_suite.setup_method() + try: + import asyncio + asyncio.run(test_suite.test_event_loop_edge_cases()) + finally: + test_suite.teardown_method() + + +if __name__ == "__main__": + """Run tests manually for debugging.""" + import os + + print("="*80) + print("WebSocket Event Loop Stability Test Suite") + print("="*80) + print("Note: Run with 'pytest test_websocket_event_loop.py' for better test management") + print("="*80) + + test_suite = WebSocketEventLoopTests() + + tests = [ + ("Rapid Connection Cycles", test_suite.test_rapid_connection_cycles), + ("Concurrent Message Handling", test_suite.test_concurrent_message_handling), + ("Server Shutdown During Activity", test_suite.test_server_shutdown_during_activity), + ("Thread Safety Operations", test_suite.test_thread_safety_operations), + ("Broadcast from Message Handlers", test_suite.test_broadcast_from_message_handlers), + ("Event Loop Edge Cases", test_suite.test_event_loop_edge_cases), + ] + + results = {} + + for test_name, test_func in tests: + print(f"\n{'-'*60}") + print(f"Running: {test_name}") + print(f"{'-'*60}") + + try: + # Set up fresh test environment + test_suite.setup_method() + + # Run the test + start_time = time.time() + asyncio.run(test_func()) + end_time = time.time() + + results[test_name] = { + 'status': 'PASSED', + 'duration': end_time - start_time, + 'error': None + } + print(f"✅ {test_name} PASSED ({end_time - start_time:.2f}s)") + + except Exception as e: + results[test_name] = { + 'status': 'FAILED', + 'duration': time.time() - start_time if 'start_time' in locals() else 0, + 'error': str(e) + } + print(f"❌ {test_name} FAILED: {e}") + + finally: + # Clean up + try: + test_suite.teardown_method() + except Exception as e: + print(f"⚠️ Cleanup error for {test_name}: {e}") + + # Print summary + print(f"\n{'='*80}") + print("TEST SUMMARY") + print(f"{'='*80}") + + passed = sum(1 for r in results.values() if r['status'] == 'PASSED') + failed = sum(1 for r in results.values() if r['status'] == 'FAILED') + + print(f"Total tests: {len(results)}") + print(f"Passed: {passed}") + print(f"Failed: {failed}") + + if failed > 0: + print(f"\nFAILED TESTS:") + for test_name, result in results.items(): + if result['status'] == 'FAILED': + print(f" - {test_name}: {result['error']}") + + # Exit with error code if any tests failed + exit(failed) \ No newline at end of file From 42fe870a2ccbba33deedea7dfc901fbdd85d1deb Mon Sep 17 00:00:00 2001 From: tom-pilotschool Date: Sat, 7 Jun 2025 17:11:56 +1000 Subject: [PATCH 2/3] Fixes to documentation --- docs/src/content/docs/guides/blueprint.md | 14 +- docs/src/content/docs/guides/websockets.md | 255 +++++++++++++++------ docs/src/content/docs/overview.md | 6 +- 3 files changed, 189 insertions(+), 86 deletions(-) diff --git a/docs/src/content/docs/guides/blueprint.md b/docs/src/content/docs/guides/blueprint.md index f68e430..30fb6cb 100644 --- a/docs/src/content/docs/guides/blueprint.md +++ b/docs/src/content/docs/guides/blueprint.md @@ -126,16 +126,16 @@ Register WebSocket message handlers: ```python @blueprint.socket("user_update") -def handle_user_update(req, sid): +def handle_user_update(data, client_id): # Handle WebSocket message return {"status": "processed"} @blueprint.socket_connect() -def on_connect(req, sid): +def on_connect(data, client_id): return {"status": "connected"} @blueprint.socket_disconnect() -def on_disconnect(req, sid): +def on_disconnect(data, client_id): # Clean up resources pass ``` @@ -379,7 +379,7 @@ def http_create_user(req): # WebSocket handlers @users_bp.socket("get_users") -def ws_get_users(req, sid): +def ws_get_users(data, client_id): return { "users": list(users.values()) } @@ -509,9 +509,9 @@ def http_login(req): # WebSocket authentication @auth_bp.socket_connect() -def ws_authenticate(req, sid): - # Extract token from query parameters or headers - token = req.data.get("token") +def ws_authenticate(data, client_id): + # Extract token from connection data + token = data.get("token") if not token: # No token provided, allow connection but mark as unauthenticated diff --git a/docs/src/content/docs/guides/websockets.md b/docs/src/content/docs/guides/websockets.md index 7493899..2f09bbc 100644 --- a/docs/src/content/docs/guides/websockets.md +++ b/docs/src/content/docs/guides/websockets.md @@ -32,19 +32,55 @@ app.run(service="realtime_service") ## Message Handlers +### Understanding Handler Signatures + +**Important**: WebSocket message handlers in Daebus use a specific signature that differs from some other WebSocket libraries: + +```python +@app.socket("message_type") +def handler(data, client_id): + # data: Contents of the 'data' field from the client message + # client_id: Unique identifier for the WebSocket connection + pass +``` + +**What the client sends vs. what your handler receives:** + +```javascript +// Client sends this complete message: +{ + "type": "chat_message", // Used to route to the correct handler + "data": { // This object becomes the 'data' parameter + "message": "Hello!", + "room": "general" + } +} +``` + +```python +# Your handler receives: +@app.socket("chat_message") # ← Matches the 'type' field +def handle_chat(data, client_id): + # data = {"message": "Hello!", "room": "general"} + # client_id = "user_abc123..." (unique session ID) + + message = data.get("message") # ← Direct access to message data + room = data.get("room") +``` + ### Handling Message Types Use the `@app.socket()` decorator to handle specific message types: ```python @app.socket("chat_message") -def handle_chat(req, sid): +def handle_chat(data, client_id): """Handle incoming chat messages""" - message = req.data.get("message", "") - sender = req.data.get("sender", "Anonymous") + message = data.get("message", "") + sender = data.get("sender", "Anonymous") # Log the message - logger.info(f"Received chat message from {sender} (client {sid}): {message}") + logger.info(f"Received chat message from {sender} (client {client_id}): {message}") # Broadcast to all clients app.websocket.broadcast_to_all({ @@ -61,8 +97,8 @@ def handle_chat(req, sid): ``` The handler function receives two parameters: -- `req`: The WebSocketRequest object containing message data -- `sid`: The client's session ID (a unique identifier for the connection) +- `data`: The contents of the `data` field from the client's message +- `client_id`: The client's session ID (a unique identifier for the connection) ### Connection Events @@ -70,25 +106,25 @@ Handle client connections and disconnections: ```python @app.socket_connect() -def on_connect(req, sid): +def on_connect(data, client_id): """Handle new client connection""" - logger.info(f"Client {sid} connected") + logger.info(f"Client {client_id} connected") # You can return data that will be sent to the client return { "status": "connected", - "client_id": sid, + "client_id": client_id, "server_time": time.time() } @app.socket_disconnect() -def on_disconnect(req, sid): +def on_disconnect(data, client_id): """Handle client disconnection""" - logger.info(f"Client {sid} disconnected") + logger.info(f"Client {client_id} disconnected") # Clean up any client-specific resources - if sid in user_sessions: - del user_sessions[sid] + if client_id in user_sessions: + del user_sessions[client_id] ``` ### Client Registration @@ -97,19 +133,19 @@ Handle client registration with custom data: ```python @app.socket_register() -def on_register(req, sid): +def on_register(data, client_id): """Handle client registration""" - user_data = req.data.get("user", {}) - username = user_data.get("username", f"Guest-{sid[:8]}") + user_data = data.get("user", {}) + username = user_data.get("username", f"Guest-{client_id[:8]}") # Store the user information - user_sessions[sid] = { + user_sessions[client_id] = { "username": username, "registered_at": time.time(), "is_active": True } - logger.info(f"Client {sid} registered as {username}") + logger.info(f"Client {client_id} registered as {username}") # Notify others about the new user app.websocket.broadcast_to_all({ @@ -131,8 +167,8 @@ Send a response to the client who sent the message: ```python @app.socket("get_data") -def handle_data_request(req, sid): - data_id = req.data.get("id") +def handle_data_request(data, client_id): + data_id = data.get("id") try: # Fetch the requested data @@ -291,33 +327,78 @@ def disconnect_client(): }) ``` -## Working with Request Data +## Working with Message Data -The `req` object in WebSocket handlers provides: +WebSocket handlers receive the message data directly from the client's `data` field: ```python @app.socket("example_message") -def handle_example(req, sid): - # Message type (from the 'type' field in the client message) - message_type = req.message_type - - # Full message payload - full_payload = req.payload - - # Convenience access to the 'data' field of the message - message_data = req.data +def handle_example(data, client_id): + # Direct access to the message data (from the 'data' field in the client message) + username = data.get("username", "Anonymous") + action = data.get("action", "view") - # Access specific fields with defaults - username = message_data.get("username", "Anonymous") - action = message_data.get("action", "view") + # The client_id parameter provides the unique identifier for this connection + logger.info(f"Processing {action} request from {username} (client: {client_id})") - # You can also access the WebSocket connection directly - websocket = req.websocket + # If you need access to the full request context, use the request proxy + from daebus.modules.context import request + message_type = request.message_type # The 'type' field from the client message + websocket_connection = request.websocket # The underlying WebSocket connection # Process the message... return {"status": "processed"} ``` +**Client Message Structure:** +```javascript +// Client sends this structure +{ + "type": "example_message", // Determines which handler is called + "data": { // This object is passed as 'data' parameter + "username": "JohnDoe", + "action": "view" + } +} +``` + +**When you need the full request context:** + +If you need access to the complete message structure, WebSocket connection, or other request details, use the request proxy: + +```python +@app.socket("advanced_handler") +def handle_advanced(data, client_id): + # Access message data directly (recommended for most cases) + username = data.get("username") + + # Access full request context when needed + from daebus.modules.context import request + + message_type = request.message_type # The 'type' field from client + full_payload = request.payload # Complete client message + websocket_conn = request.websocket # Raw WebSocket connection + + # Access client metadata + metadata = app.websocket.get_client_metadata(client_id) + connected_at = metadata.get("connected_at") + + return {"processed": True} +``` + +**Important Notes:** + +1. **Return values**: Anything you return from a handler is automatically sent to the client as a response message +2. **Async handlers**: You can make handlers async if you need to perform async operations: + ```python + @app.socket("async_operation") + async def handle_async(data, client_id): + result = await some_async_operation() + return {"result": result} + ``` +3. **Error handling**: Exceptions in handlers are caught and sent as error messages to the client +4. **No return value**: If your handler doesn't return anything, no response is sent (useful for fire-and-forget messages) + ## Client-Side Implementation Here's a basic JavaScript client implementation: @@ -392,7 +473,7 @@ Enable rate limiting to prevent abuse: # Set up rate limiting when creating the WebSocket component websocket = DaebusWebSocket() websocket.enable_rate_limiting( - messages_per_minute=60, # Maximum messages per minute + max_messages=60, # Maximum messages per minute window_seconds=60 # Time window for counting messages ) app.attach(websocket) @@ -444,16 +525,20 @@ Use blueprints to organize WebSocket handlers: ```python from daebus import Daebus, DaebusHttp, DaebusWebSocket, Blueprint +# Global storage for the blueprint example +chat_rooms = {} +authenticated_users = {} + # Create a blueprint for chat functionality chat_bp = Blueprint("chat") @chat_bp.socket("send_message") -def handle_chat_message(req, sid): +def handle_chat_message(data, client_id): # Chat message handling logic return {"received": True} @chat_bp.socket("join_room") -def handle_join_room(req, sid): +def handle_join_room(data, client_id): # Room joining logic return {"joined": True} @@ -461,7 +546,7 @@ def handle_join_room(req, sid): user_bp = Blueprint("users") @user_bp.socket_connect() -def handle_connect(req, sid): +def handle_connect(data, client_id): # Connection handling return {"welcome": True} @@ -489,9 +574,11 @@ Implement authentication for WebSocket connections: ```python @app.socket_connect() -def on_connect(req, sid): - # Extract authentication token from request path or headers - token = extract_token_from_request(req) +def on_connect(data, client_id): + # Extract authentication token from the connection request + # You can access the full request context using the request proxy + from daebus.modules.context import request + token = extract_token_from_request(request) if not token or not validate_token(token): # Return False to reject the connection @@ -499,7 +586,12 @@ def on_connect(req, sid): # Store authenticated user information user_id = get_user_id_from_token(token) - app.websocket.set_client_data(sid, "user_id", user_id) + + # Store user info in your own session storage + authenticated_users[client_id] = { + "user_id": user_id, + "authenticated_at": time.time() + } logger.info(f"Authenticated connection from user {user_id}") return {"authenticated": True, "user_id": user_id} @@ -511,9 +603,9 @@ Always validate incoming messages: ```python @app.socket("update_profile") -def handle_profile_update(req, sid): +def handle_profile_update(data, client_id): # Get user data - profile_data = req.data.get("profile", {}) + profile_data = data.get("profile", {}) # Validate required fields if not profile_data.get("name"): @@ -559,25 +651,25 @@ users_lock = threading.Lock() # Connection handler @app.socket_connect() -def on_connect(req, sid): - direct_logger.info(f"Client connected: {sid}") - return {"status": "connected", "client_id": sid} +def on_connect(data, client_id): + direct_logger.info(f"Client connected: {client_id}") + return {"status": "connected", "client_id": client_id} # Disconnection handler @app.socket_disconnect() -def on_disconnect(req, sid): +def on_disconnect(data, client_id): # Remove user from rooms with rooms_lock: for room_name, room in list(rooms.items()): - if sid in room["members"]: - room["members"].remove(sid) + if client_id in room["members"]: + room["members"].remove(client_id) # Notify others in the room if room["members"]: app.websocket.broadcast_to_clients( room["members"], { - "user": users.get(sid, {}).get("username", "Anonymous"), + "user": users.get(client_id, {}).get("username", "Anonymous"), "action": "left", "room": room_name }, @@ -586,28 +678,28 @@ def on_disconnect(req, sid): # Remove user with users_lock: - if sid in users: - del users[sid] + if client_id in users: + del users[client_id] - direct_logger.info(f"Client disconnected: {sid}") + direct_logger.info(f"Client disconnected: {client_id}") # User registration @app.socket("register") -def register_user(req, sid): - username = req.data.get("username") +def register_user(data, client_id): + username = data.get("username") if not username: return {"error": "Username is required", "status": "error"} # Store user information with users_lock: - users[sid] = { + users[client_id] = { "username": username, "registered_at": time.time(), "rooms": [] } - direct_logger.info(f"User registered: {username} ({sid})") + direct_logger.info(f"User registered: {username} ({client_id})") return { "status": "registered", @@ -617,39 +709,39 @@ def register_user(req, sid): # Create or join room @app.socket("join_room") -def join_room(req, sid): - room_name = req.data.get("room") +def join_room(data, client_id): + room_name = data.get("room") if not room_name: return {"error": "Room name is required", "status": "error"} # Get username - username = users.get(sid, {}).get("username", "Anonymous") + username = users.get(client_id, {}).get("username", "Anonymous") with rooms_lock: # Create room if it doesn't exist if room_name not in rooms: rooms[room_name] = { "created_at": time.time(), - "created_by": sid, + "created_by": client_id, "members": set(), "messages": [] } direct_logger.info(f"Room created: {room_name} by {username}") # Add user to room - rooms[room_name]["members"].add(sid) + rooms[room_name]["members"].add(client_id) # Add room to user's list with users_lock: - if sid in users and "rooms" in users[sid]: - if room_name not in users[sid]["rooms"]: - users[sid]["rooms"].append(room_name) + if client_id in users and "rooms" in users[client_id]: + if room_name not in users[client_id]["rooms"]: + users[client_id]["rooms"].append(room_name) # Notify others in the room with rooms_lock: room = rooms[room_name] - others = room["members"] - {sid} + others = room["members"] - {client_id} if others: app.websocket.broadcast_to_clients( @@ -674,9 +766,9 @@ def join_room(req, sid): # Send message to room @app.socket("chat_message") -def send_message(req, sid): - room_name = req.data.get("room") - message = req.data.get("message", "").strip() +def send_message(data, client_id): + room_name = data.get("room") + message = data.get("message", "").strip() if not room_name or not message: return {"error": "Room and message are required", "status": "error"} @@ -686,15 +778,15 @@ def send_message(req, sid): if room_name not in rooms: return {"error": "Room does not exist", "status": "error"} - if sid not in rooms[room_name]["members"]: + if client_id not in rooms[room_name]["members"]: return {"error": "Not a member of this room", "status": "error"} # Get username - username = users.get(sid, {}).get("username", "Anonymous") + username = users.get(client_id, {}).get("username", "Anonymous") # Create message object msg = { - "id": f"msg_{time.time()}_{sid[:8]}", + "id": f"msg_{time.time()}_{client_id[:8]}", "room": room_name, "sender": username, "text": message, @@ -709,7 +801,7 @@ def send_message(req, sid): rooms[room_name]["messages"] = rooms[room_name]["messages"][-100:] # Get all members except sender - recipients = list(rooms[room_name]["members"] - {sid}) + recipients = list(rooms[room_name]["members"] - {client_id}) # Broadcast to other room members if recipients: @@ -759,6 +851,17 @@ if __name__ == "__main__": 6. **Reconnection**: Implement reconnection logic on the client side 7. **Testing**: Test with multiple simultaneous connections to ensure scalability +## Deprecation Warnings + +You may see deprecation warnings related to the `websockets` library: + +``` +DeprecationWarning: websockets.server.WebSocketServerProtocol is deprecated +DeprecationWarning: websockets.legacy is deprecated +``` + +These warnings are related to the underlying `websockets` library and do not affect functionality. They will be addressed in future versions of Daebus. You can safely ignore them for now. + ## Troubleshooting ### Connection Issues diff --git a/docs/src/content/docs/overview.md b/docs/src/content/docs/overview.md index 629150e..0c3bc15 100644 --- a/docs/src/content/docs/overview.md +++ b/docs/src/content/docs/overview.md @@ -162,8 +162,8 @@ WebSocket handlers process real-time messages: ```python @app.socket("chat_message") -def handle_chat(req, sid): - message = req.data.get("message", "") +def handle_chat(data, client_id): + message = data.get("message", "") # Process the message return { @@ -265,7 +265,7 @@ def http_get_metrics(req): # WebSocket message handler @app.socket("subscribe_metrics") -def socket_subscribe_metrics(req, sid): +def socket_subscribe_metrics(data, client_id): # Register this client for metric updates # (implementation details omitted) From eb5c10d887f317faf706ba3fefbef819045ca2c6 Mon Sep 17 00:00:00 2001 From: tom-pilotschool Date: Sat, 7 Jun 2025 17:14:23 +1000 Subject: [PATCH 3/3] fixes to docs site --- docs/src/content/docs/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/content/docs/index.mdx b/docs/src/content/docs/index.mdx index cda7082..54b8583 100644 --- a/docs/src/content/docs/index.mdx +++ b/docs/src/content/docs/index.mdx @@ -11,7 +11,7 @@ hero: link: /installation icon: right-arrow - text: Read the Daebus docs - link: /guides/getting-started + link: /getting-started icon: external variant: minimal ---