diff --git a/flexmeasures/app.py b/flexmeasures/app.py index 059eee3eba..891fd51c08 100644 --- a/flexmeasures/app.py +++ b/flexmeasures/app.py @@ -52,6 +52,10 @@ def create( # noqa C901 load_dotenv() app = Flask("flexmeasures") + from flexmeasures.ws import sock + + sock.init_app(app) + if env is not None: # overwrite app.config["FLEXMEASURES_ENV"] = env if app.config.get("FLEXMEASURES_ENV") == "testing": diff --git a/flexmeasures/ws/__init__.py b/flexmeasures/ws/__init__.py new file mode 100644 index 0000000000..f565dea771 --- /dev/null +++ b/flexmeasures/ws/__init__.py @@ -0,0 +1,19 @@ +import importlib +import pkgutil +from flask import Blueprint, current_app +from simple_websocket import Server +from flask_security import auth_token_required + +from flask_sock import Sock + +sock = Sock() + + +def import_all_modules(package_name): + package = importlib.import_module(package_name) + for _, name, _ in pkgutil.iter_modules(package.__path__): + importlib.import_module(f"{package_name}.{name}") + + +# we need to import all the modules to run the route decorators +import_all_modules("flexmeasures.ws") diff --git a/flexmeasures/ws/ping1.py b/flexmeasures/ws/ping1.py new file mode 100644 index 0000000000..02afaacd5c --- /dev/null +++ b/flexmeasures/ws/ping1.py @@ -0,0 +1,23 @@ +import logging +from flexmeasures.ws import sock +from flask import current_app +from flexmeasures import Sensor +from sqlalchemy import select + +logger = logging.getLogger(__name__) + + +@sock.route("/ping1") +def echo1(ws): + while True: + with current_app.app_context(): + data = ws.receive() + + if data == "close": + break + + sensors = current_app.db.session.execute( + select(Sensor).where(Sensor.id == 1) + ).scalar() + + ws.send(str(sensors.__dict__)) diff --git a/flexmeasures/ws/ping2.py b/flexmeasures/ws/ping2.py new file mode 100644 index 0000000000..a072f3f600 --- /dev/null +++ b/flexmeasures/ws/ping2.py @@ -0,0 +1,14 @@ +import logging +from flexmeasures.ws import sock + +logger = logging.Logger(__name__) + + +@sock.route("/ping2") +def echo2(ws): + while True: + data = ws.receive() + logger.error("ping2>" + data) + if data == "close": + break + ws.send(data) diff --git a/flexmeasures/ws/v1.py b/flexmeasures/ws/v1.py new file mode 100644 index 0000000000..eab8a10fee --- /dev/null +++ b/flexmeasures/ws/v1.py @@ -0,0 +1,37 @@ +import logging +from flexmeasures.ws import sock +from flask import current_app +from flexmeasures import Sensor +from sqlalchemy import select, func +import json + +logger = logging.Logger(__name__) + + +@sock.route("/v1") +def header_test(ws): + # Get all headers + all_headers = { + k[5:].lower().replace("_", "-"): v + for k, v in ws.environ.items() + if k.startswith("HTTP_") + } + + # Get specific header if needed + custom_header = ws.environ.get("HTTP_X_CUSTOM_HEADER") + + logger.info(f"All headers: {all_headers}") + logger.info(f"Custom header: {custom_header}") + + # Send initial message with metadata + ws.send( + json.dumps({"type": "metadata", "headers": {"X-Server-Header": "ServerValue"}}) + ) + + while True: + data = ws.receive() + logger.error("v1>" + data) + if data == "close": + break + sensors = current_app.db.session.execute(select(func.count(Sensor.id))).scalar() + ws.send(str(sensors)) diff --git a/notebooks/websocket_analysis.ipynb b/notebooks/websocket_analysis.ipynb new file mode 100644 index 0000000000..03fbbcf178 --- /dev/null +++ b/notebooks/websocket_analysis.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 64, + "id": "2b691e65-0818-438a-b484-9ce439baef44", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "import plotly.offline as pyo\n", + "import plotly.graph_objs as go\n", + "import plotly.io as pio\n", + "\n", + "\n", + "import plotly_express as px\n", + "pio.renderers.default = 'iframe'\n", + "\n", + "data = pd.read_csv(\"results-db-get-sensor-1.csv\", names=[\"time\", \"type\", \"id\", \"delta\"])\n", + "data[\"time\"] = data[\"time\"].apply(lambda x: pd.Timestamp.fromtimestamp(x))\n", + "data = data.dropna()\n", + "data[\"id2\"] = data.apply(lambda x: f\"{x['type']}: {x['id']}\", axis=1)\n", + "fig = px.line(data, x=\"time\", y=\"delta\", color=\"type\", labels={\n", + " \"time\" : \"Time\",\n", + " \"delta\": \"Roundtrip Time (s)\",\n", + " \"type\" : \"Protocol\"\n", + "}, title=\"Roundtrip Time with 1000 concurrent WS connections @ 1Hz and 1000 concurrent API requests @ 1Hz\")\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "3a8f9091-b8bd-4802-af36-bb01a2942f9a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
id
type
API5
WS623
\n", + "
" + ], + "text/plain": [ + " id\n", + "type \n", + "API 5\n", + "WS 623" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[[\"id\", \"type\"]].drop_duplicates().groupby(\"type\").count()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "066e04d7-148d-48a9-8de9-2cd9773b687e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = px.histogram(data[8000:], x=\"delta\", color=\"type\", barmode=\"overlay\", labels={\n", + " \"delta\": \"Roundtrip Time (s)\",\n", + " \"type\" : \"Protocol\"\n", + "})\n", + "fig.update_traces(opacity=.9)\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "12555575-8f72-422e-a030-a5ea0df93a56", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 WS: 0\n", + "1 WS: 1\n", + "2 WS: 2\n", + "3 WS: 3\n", + "4 WS: 4\n", + " ... \n", + "150 WS: 1\n", + "151 WS: 0\n", + "152 WS: 2\n", + "153 WS: 4\n", + "154 WS: 3\n", + "Length: 155, dtype: object" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc46bc48-c6c4-4326-9462-d5a89111865f", + "metadata": {}, + "outputs": [], + "source": [ + "data.gr" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fm", + "language": "python", + "name": "fm" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements/app.in b/requirements/app.in index 77f5df43c7..18678ec362 100644 --- a/requirements/app.in +++ b/requirements/app.in @@ -67,3 +67,4 @@ flask>=1.0 werkzeug vl-convert-python Pillow>=10.0.1 # https://github.com/FlexMeasures/flexmeasures/security/dependabot/91 +flask-sock \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 50aa25de07..4ff2a0f423 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ exclude = .git,__pycache__,documentation max-line-length = 160 max-complexity = 13 select = B,C,E,F,W,B9 -ignore = E501, W503, E203 +ignore = E501, W503, E203, C901 per-file-ignores = flexmeasures/__init__.py:F401 flexmeasures/data/schemas/__init__.py:F401 diff --git a/test_ws_client.py b/test_ws_client.py new file mode 100644 index 0000000000..e1066ed7ca --- /dev/null +++ b/test_ws_client.py @@ -0,0 +1,42 @@ +from simple_websocket import Client, ConnectionClosed # type: ignore +import json +import sys + + +def main(): + headers = { + "X-Custom-Header": "SomeValue", + # 'Authorization': 'Bearer YourToken', + } + ws = Client.connect("ws://127.0.0.1:5000/v1", headers=headers) + try: + print("Connected to the WebSocket server!") + + # Get initial metadata message + initial_msg = json.loads(ws.receive()) + print(initial_msg) + if initial_msg.get("type") != "metadata": + print("ERROR: Server metadata not received!") + ws.close() + sys.exit(1) + + server_header = initial_msg.get("headers", {}).get("X-Server-Header") + if not server_header: + print("ERROR: Server header not found in metadata!") + ws.close() + sys.exit(1) + print(f"Server header received: {server_header}") + + while True: + data = input("> ") + ws.send(data) + data = ws.receive() + print(f"< {data}") + + except (KeyboardInterrupt, EOFError, ConnectionClosed) as e: + print(f"Connection closed: {e}") + ws.close() + + +if __name__ == "__main__": + main()