From 21f04ae22a74742d2610802e860e06da7d404f11 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 16 Jan 2025 21:06:33 -0700 Subject: [PATCH 1/3] feat: support tags in run_workflow --- tests/test_async_client.py | 52 +++++++++++++++++++++++++++++++++++++- tests/test_sync_client.py | 52 +++++++++++++++++++++++++++++++++++++- tws/_async/client.py | 4 +++ tws/_sync/client.py | 4 +++ tws/base/client.py | 15 +++++++++++ 5 files changed, 125 insertions(+), 2 deletions(-) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index be27dac..3dac696 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -54,7 +54,7 @@ async def test_async_client_instantiation_exceptions( [600, "not a number", "Retry delay must be between 1 and 60 seconds"], ], ) -async def test_run_workflow_validation( +async def test_run_workflow_timing_validation( good_async_client, timeout, retry_delay, exception_message ): with pytest.raises(ClientException) as exc_info: @@ -68,6 +68,56 @@ async def test_run_workflow_validation( assert exception_message in str(exc_info.value) +@pytest.mark.parametrize( + "tags,exception_message", + [ + [{"key": 123}, "Tag keys and values must be strings"], + [{"key": "value", "bad_key": 123}, "Tag keys and values must be strings"], + [{123: "value"}, "Tag keys and values must be strings"], + ["not_a_dict", "Tags must be a dictionary"], + [{"x" * 256: "value"}, "Tag keys and values must be <= 255 characters"], + [{"key": "x" * 256}, "Tag keys and values must be <= 255 characters"], + ], +) +async def test_run_workflow_tag_validation(good_async_client, tags, exception_message): + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, tags=tags + ) + assert exception_message in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_async_client): + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success"}} + ] + + valid_tags = {"userId": "someUserId", "lessonId": "someLessonId"} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, tags=valid_tags + ) + + # Verify tags were included in the RPC payload + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": {"arg": "value"}, + "tags": valid_tags, + }, + ) + assert result == {"output": "success"} + + @patch("tws._async.client.AsyncClient._make_rpc_request") async def test_run_workflow_not_found(mock_rpc, good_async_client): mock_request = Request("POST", "http://example.com") diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py index ca7302b..4f96382 100644 --- a/tests/test_sync_client.py +++ b/tests/test_sync_client.py @@ -54,7 +54,9 @@ def test_client_instantiation_exceptions( [600, "not a number", "Retry delay must be between 1 and 60 seconds"], ], ) -def test_run_workflow_validation(good_client, timeout, retry_delay, exception_message): +def test_run_workflow_timing_validation( + good_client, timeout, retry_delay, exception_message +): with pytest.raises(ClientException) as exc_info: with good_client: good_client.run_workflow( @@ -66,6 +68,54 @@ def test_run_workflow_validation(good_client, timeout, retry_delay, exception_me assert exception_message in str(exc_info.value) +@pytest.mark.parametrize( + "tags,exception_message", + [ + [{"key": 123}, "Tag keys and values must be strings"], + [{"key": "value", "bad_key": 123}, "Tag keys and values must be strings"], + [{123: "value"}, "Tag keys and values must be strings"], + ["not_a_dict", "Tags must be a dictionary"], + [{"x" * 256: "value"}, "Tag keys and values must be <= 255 characters"], + [{"key": "x" * 256}, "Tag keys and values must be <= 255 characters"], + ], +) +def test_run_workflow_tag_validation(good_client, tags, exception_message): + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client.run_workflow("workflow-id", {"arg": "value"}, tags=tags) + assert exception_message in str(exc_info.value) + + +@patch("tws._sync.client.SyncClient._make_rpc_request") +@patch("tws._sync.client.SyncClient._make_request") +def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_client): + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success"}} + ] + + valid_tags = {"userId": "someUserId", "lessonId": "someLessonId"} + + with good_client: + result = good_client.run_workflow( + "workflow-id", {"arg": "value"}, tags=valid_tags + ) + + # Verify tags were included in the RPC payload + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": {"arg": "value"}, + "tags": valid_tags, + }, + ) + assert result == {"output": "success"} + + @patch("tws._sync.client.SyncClient._make_rpc_request") def test_run_workflow_not_found(mock_rpc, good_client): mock_request = Request("POST", "http://example.com") diff --git a/tws/_async/client.py b/tws/_async/client.py index 88c0828..edaec11 100644 --- a/tws/_async/client.py +++ b/tws/_async/client.py @@ -103,13 +103,17 @@ async def run_workflow( workflow_args: dict, timeout=600, retry_delay=1, + tags: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) + self._validate_tags(tags) payload = { "workflow_definition_id": workflow_definition_id, "request_body": workflow_args, } + if tags is not None: + payload["tags"] = tags try: result = await self._make_rpc_request("start_workflow", payload) diff --git a/tws/_sync/client.py b/tws/_sync/client.py index 341e7b6..c9b97e8 100644 --- a/tws/_sync/client.py +++ b/tws/_sync/client.py @@ -100,13 +100,17 @@ def run_workflow( workflow_args: dict, timeout=600, retry_delay=1, + tags: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) + self._validate_tags(tags) payload = { "workflow_definition_id": workflow_definition_id, "request_body": workflow_args, } + if tags is not None: + payload["tags"] = tags try: result = self._make_rpc_request("start_workflow", payload) diff --git a/tws/base/client.py b/tws/base/client.py index 0f3ba0e..0c8d739 100644 --- a/tws/base/client.py +++ b/tws/base/client.py @@ -94,6 +94,19 @@ def _check_timeout(start_time: float, timeout: Union[int, float]) -> None: f"Workflow execution timed out after {timeout} seconds" ) + @staticmethod + def _validate_tags(tags: Optional[Dict[str, str]]) -> None: + if tags is not None: + if not isinstance(tags, dict): + raise ClientException("Tags must be a dictionary") + for key, value in tags.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ClientException("Tag keys and values must be strings") + if len(key) > 255 or len(value) > 255: + raise ClientException( + "Tag keys and values must be <= 255 characters" + ) + @abstractmethod def run_workflow( self, @@ -101,6 +114,7 @@ def run_workflow( workflow_args: dict, timeout=600, retry_delay=1, + tags: Optional[Dict[str, str]] = None, ) -> Union[dict, Coroutine[Any, Any, dict]]: """Execute a workflow and wait for it to complete or fail. @@ -109,6 +123,7 @@ def run_workflow( workflow_args: Dictionary of arguments to pass to the workflow timeout: Maximum time in seconds to wait for workflow completion (1-3600) retry_delay: Time in seconds between status checks (1-60) + tags: Optional dictionary of tag key-value pairs to attach to the workflow Returns: The workflow execution result as a dictionary From 6cc0faad6b9b1cf8589159205c399d958f0239e9 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 16 Jan 2025 21:12:21 -0700 Subject: [PATCH 2/3] docs: add tag example usage and explanation --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index 74351ef..6e28856 100644 --- a/README.md +++ b/README.md @@ -61,3 +61,27 @@ async def main(): }, ) ``` + +### Tags + +You can specify tags, which are string key-value pairs, when calling the `run_workflow` method. These tags can then +be used when designing workflows in TWS to lookup and filter the results of workflow runs. This allows you to associate +the results of a workflow run with a specific entity or grouping mechanism within your system, such as a user ID or +a lesson ID. + +Provide tags to the `run_workflow` method as a dictionary. Keep in mind that both tag keys and values must be strings +that are at most 255 characters long. + +```python +tws_client.run_workflow( + workflow_definition_id="your_workflow_id", + workflow_args={ + "param1": "value1", + "param2": "value2" + }, + tags={ + "user_id": "12345", + "lesson_id": "67890" + } +) +``` From f355ce9861e28829e189d0d828d0183b281124db Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 16 Jan 2025 21:14:19 -0700 Subject: [PATCH 3/3] chore: bump version to 0.2.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6458ef3..382eb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tws-sdk" -version = "0.2.0" +version = "0.2.1" description = "TWS client for Python." authors = ["Fireline Science "] homepage = "https://github.com/Fireline-Science/tws-py"