diff --git a/poetry.lock b/poetry.lock index e925466..253980f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "anyio" version = "4.8.0" @@ -431,4 +442,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0cad654a3a95aeb2eabaaf9a2e1edfb9a87ef300b96d79f62f13c6c85d0e89ac" +content-hash = "c1c0f98276102a1ae18e71363caef0d08485c03f3427f7cdcc0af8f7ede3017b" diff --git a/pyproject.toml b/pyproject.toml index 382eb4a..77ae75d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" httpx = {extras = ["http2"], version = ">=0.26,<0.29"} +aiofiles = "^24.1.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.4" diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 3dac696..9dd874a 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -88,6 +88,42 @@ async def test_run_workflow_tag_validation(good_async_client, tags, exception_me assert exception_message in str(exc_info.value) +@pytest.mark.parametrize( + "files,exception_message", + [ + [{"key": 123}, "File values must be file paths (strings)"], + [{"key": "value", "bad_key": 123}, "File values must be file paths (strings)"], + [{123: "value"}, "File keys must be strings"], + ["not_a_dict", "Files must be a dictionary"], + [{}, None], # Empty dict is valid + ], +) +async def test_run_workflow_files_validation( + good_async_client, files, exception_message +): + if exception_message: + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + assert exception_message in str(exc_info.value) + else: + # This should not raise an exception + async with good_async_client: + with patch("tws._async.client.AsyncClient._make_rpc_request") as mock_rpc: + with patch( + "tws._async.client.AsyncClient._make_request" + ) as mock_request: + mock_rpc.return_value = {"workflow_instance_id": "123"} + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success"}} + ] + await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + @patch("tws._async.client.AsyncClient._make_rpc_request") @patch("tws._async.client.AsyncClient._make_request") async def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_async_client): @@ -118,6 +154,51 @@ async def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_async_c assert result == {"output": "success"} +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_files( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock file upload + mock_upload.return_value = "user-123/timestamp-test_file.txt" + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with file"}} + ] + + files = {"input_file": str(test_file)} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + # Verify file was uploaded + mock_upload.assert_called_once_with(str(test_file)) + + # Verify the file path was merged into workflow args + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file": "user-123/timestamp-test_file.txt", + }, + }, + ) + assert result == {"output": "success with file"} + + @patch("tws._async.client.AsyncClient._make_rpc_request") async def test_run_workflow_not_found(mock_rpc, good_async_client): mock_request = Request("POST", "http://example.com") @@ -249,6 +330,7 @@ def mock_json(): "/rest/v1/test/endpoint", json={"param": "value"}, params={"query": "param"}, + files=None, ) assert result == {"data": "test"} @@ -264,6 +346,63 @@ async def test_make_request_error(mock_request, good_async_client): assert "Request error occurred: Network error" in str(exc_info.value) +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_success(mock_request, good_async_client): + # Mock successful user ID lookup + mock_request.return_value = [{"user_id": "test-user-123"}] + + async with good_async_client: + user_id = await good_async_client._lookup_user_id() + + # Verify the request was made correctly + mock_request.assert_called_once_with( + "GET", + "users_private", + params={ + "select": "user_id", + "api_key": f"eq.{good_async_client.session.headers['X-TWS-API-KEY']}", + }, + ) + + # Verify the user ID was returned and cached + assert user_id == "test-user-123" + assert good_async_client.user_id == "test-user-123" + + # Reset the mock and call again to verify caching + mock_request.reset_mock() + + # Second call should use cached value + user_id_again = await good_async_client._lookup_user_id() + assert user_id_again == "test-user-123" + + # Verify no additional request was made + mock_request.assert_not_called() + + +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_empty_response(mock_request, good_async_client): + # Mock empty response (no user found) + mock_request.return_value = [] + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._lookup_user_id() + + assert "User ID not found, is your API key correct?" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_request_error(mock_request, good_async_client): + # Mock request error + mock_request.side_effect = Exception("Database connection error") + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._lookup_user_id() + + assert "Failed to look up user ID: Database connection error" in str(exc_info.value) + + @patch("tws._async.client.AsyncClient._make_request") async def test_make_rpc_request_success(mock_request, good_async_client): mock_request.return_value = {"result": "success"} @@ -311,3 +450,149 @@ async def test_run_workflow_timeout( "workflow-id", {"arg": "value"}, timeout=600 ) assert "Workflow execution timed out after 600 seconds" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._lookup_user_id") +@patch("tws._async.client.AsyncClient._make_request") +async def test_upload_file_success( + mock_request, mock_lookup_user_id, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-456" + + # Mock successful file upload + mock_request.return_value = { + "Key": "documents/test-user-456/timestamp-test_file.txt" + } + + async with good_async_client: + file_path = await good_async_client._upload_file(str(test_file)) + + # Verify the correct path is returned (without the documents/ prefix) + assert file_path == "test-user-456/timestamp-test_file.txt" + + # Verify the file upload request was made with the correct parameters + assert mock_request.call_count == 1 + # We can't check the exact file content in the call args because it's dynamic, + # but we can verify the endpoint and service + call_args = mock_request.call_args + assert call_args[0][0] == "POST" # HTTP method + assert "object/documents/test-user-456/" in call_args[0][1] # URI + assert call_args[1]["service"] == "storage" # service parameter + + +@patch("tws._async.client.AsyncClient._lookup_user_id") +async def test_upload_file_nonexistent_file(mock_lookup_user_id, good_async_client): + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-456" + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._upload_file("/nonexistent/file.txt") + + assert "File not found: /nonexistent/file.txt" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_multiple_files( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create temporary test files + test_file1 = tmp_path / "test_file1.txt" + test_file1.write_text("test content 1") + + test_file2 = tmp_path / "test_file2.txt" + test_file2.write_text("test content 2") + + # Mock file uploads with different return values for each file + mock_upload.side_effect = [ + "user-123/timestamp-test_file1.txt", + "user-123/timestamp-test_file2.txt", + ] + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with multiple files"}} + ] + + files = {"input_file1": str(test_file1), "input_file2": str(test_file2)} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + # Verify both files were uploaded + assert mock_upload.call_count == 2 + mock_upload.assert_any_call(str(test_file1)) + mock_upload.assert_any_call(str(test_file2)) + + # Verify the file paths were merged into workflow args + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file1": "user-123/timestamp-test_file1.txt", + "input_file2": "user-123/timestamp-test_file2.txt", + }, + }, + ) + assert result == {"output": "success with multiple files"} + + +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_files_and_tags( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock file upload + mock_upload.return_value = "user-123/timestamp-test_file.txt" + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with file and tags"}} + ] + + files = {"input_file": str(test_file)} + tags = {"tag1": "value1", "tag2": "value2"} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files, tags=tags + ) + + # Verify file was uploaded + mock_upload.assert_called_once_with(str(test_file)) + + # Verify the file path was merged into workflow args and tags were included + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file": "user-123/timestamp-test_file.txt", + }, + "tags": tags, + }, + ) + assert result == {"output": "success with file and tags"} diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py index 4f96382..9461a28 100644 --- a/tests/test_sync_client.py +++ b/tests/test_sync_client.py @@ -243,6 +243,7 @@ def mock_json(): "/rest/v1/test/endpoint", json={"param": "value"}, params={"query": "param"}, + files=None, ) assert result == {"data": "test"} @@ -282,6 +283,57 @@ def test_make_rpc_request_without_payload(mock_request, good_client): assert result == {"result": "success"} +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_success(mock_request, good_client): + # Mock successful user ID lookup + mock_request.return_value = [{"user_id": "test-user-123"}] + + with good_client: + # First call should make the request + user_id = good_client._lookup_user_id() + # Second call should use cached value + user_id_cached = good_client._lookup_user_id() + + # Verify request was made with correct parameters + mock_request.assert_called_once_with( + "GET", + "users_private", + params={ + "select": "user_id", + "api_key": f"eq.{good_client.session.headers['X-TWS-API-KEY']}", + }, + ) + + assert user_id == "test-user-123" + assert user_id_cached == "test-user-123" + assert good_client.user_id == "test-user-123" + assert mock_request.call_count == 1 # Should only be called once due to caching + + +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_empty_response(mock_request, good_client): + # Mock empty response + mock_request.return_value = [] + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._lookup_user_id() + + assert "User ID not found, is your API key correct?" in str(exc_info.value) + + +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_request_error(mock_request, good_client): + # Mock request error + mock_request.side_effect = Exception("Network error") + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._lookup_user_id() + + assert "Failed to look up user ID: Network error" in str(exc_info.value) + + @patch("tws._sync.client.SyncClient._make_rpc_request") @patch("tws._sync.client.SyncClient._make_request") @patch("time.time") @@ -299,3 +351,108 @@ def test_run_workflow_timeout(mock_time, mock_request, mock_rpc, good_client): with good_client: good_client.run_workflow("workflow-id", {"arg": "value"}, timeout=600) assert "Workflow execution timed out after 600 seconds" in str(exc_info.value) + + +@patch("tws._sync.client.SyncClient._lookup_user_id") +@patch("tws._sync.client.SyncClient._make_request") +def test_run_workflow_with_file_upload( + mock_request, mock_lookup_user_id, good_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-123" + + # Mock file upload response + mock_request.side_effect = [ + # First call for file upload + {"Key": "documents/test-user-123/timestamp-test_file.txt"}, + # Second call for workflow status + [{"status": "COMPLETED", "result": {"output": "success"}}], + ] + + # Mock RPC request for workflow start + with patch("tws._sync.client.SyncClient._make_rpc_request") as mock_rpc: + mock_rpc.return_value = {"workflow_instance_id": "123"} + + with good_client: + result = good_client.run_workflow( + "workflow-id", {"arg": "value"}, files={"file_arg": str(test_file)} + ) + + # Verify the file was uploaded and the file path was included in workflow args + mock_rpc.assert_called_once() + workflow_args = mock_rpc.call_args[0][1]["request_body"] + assert "file_arg" in workflow_args + assert workflow_args["file_arg"] == "test-user-123/timestamp-test_file.txt" + assert result == {"output": "success"} + + +@patch("os.path.exists") +def test_upload_file_not_found(mock_exists, good_client): + # Mock file not found + mock_exists.return_value = False + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/nonexistent/file.txt") + + assert "File not found: /path/to/nonexistent/file.txt" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("builtins.open", side_effect=IOError("Permission denied")) +def test_upload_file_open_error(mock_open, mock_exists, good_client): + # Mock file exists but can't be opened + mock_exists.return_value = True + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/file.txt") + + assert "File upload failed: Permission denied" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("builtins.open") +@patch("tws._sync.client.SyncClient._lookup_user_id") +@patch("tws._sync.client.SyncClient._make_request") +def test_upload_file_api_error( + mock_request, mock_lookup_user_id, mock_open, mock_exists, good_client +): + # Mock file exists and can be opened + mock_exists.return_value = True + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-123" + + # Mock API error during upload + mock_request.side_effect = Exception("API error") + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/file.txt") + + assert "File upload failed: API error" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("time.time") +def test_run_workflow_with_file_upload_error(mock_time, mock_exists, good_client): + # Mock time for filename generation + mock_time.return_value = 1234567890 + + # Mock file not found + mock_exists.return_value = False + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client.run_workflow( + "workflow-id", + {"arg": "value"}, + files={"file_arg": "/path/to/nonexistent/file.txt"}, + ) + + assert "File not found: /path/to/nonexistent/file.txt" in str(exc_info.value) diff --git a/tws/_async/client.py b/tws/_async/client.py index edaec11..fbed8d4 100644 --- a/tws/_async/client.py +++ b/tws/_async/client.py @@ -1,11 +1,14 @@ import asyncio +import mimetypes +import os import time from typing import cast, Dict, Optional +import aiofiles import httpx from httpx import AsyncClient as AsyncHttpClient -from tws.base.client import TWSClient, ClientException +from tws.base.client import TWS_API_KEY_HEADER, TWSClient, ClientException class AsyncClient(TWSClient): @@ -53,12 +56,40 @@ async def __aexit__(self, exc_type, exc, tb) -> None: # Close the underlying HTTP session await self.session.aclose() + async def _lookup_user_id(self) -> str: + """Look up the user ID associated with the API key. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + if self.user_id is None: + params = { + "select": "user_id", + "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}", + } + try: + response = await self._make_request( + "GET", "users_private", params=params + ) + if not response or len(response) == 0: + raise ClientException("User ID not found, is your API key correct?") + self.user_id = response[0]["user_id"] + except Exception as e: + raise ClientException(f"Failed to look up user ID: {e}") + + return self.user_id + async def _make_request( self, method: str, uri: str, payload: Optional[dict] = None, params: Optional[dict] = None, + files: Optional[dict] = None, + service: str = "rest", ): """Make a HTTP request to the TWS API. @@ -76,7 +107,7 @@ async def _make_request( """ try: response = await self.session.request( - method, f"/rest/v1/{uri}", json=payload, params=params + method, f"/{service}/v1/{uri}", json=payload, params=params, files=files ) response.raise_for_status() return response.json() @@ -97,6 +128,49 @@ async def _make_rpc_request( """ return await self._make_request("POST", f"rpc/{function_name}", payload) + async def _upload_file(self, file_path: str) -> str: + """Upload a file to the TWS API asynchronously. + + Args: + file_path: Path to the file to upload + + Returns: + File path that can be used in workflow arguments + + Raises: + ClientException: If the file upload fails + """ + try: + if not os.path.exists(file_path): + raise ClientException(f"File not found: {file_path}") + + filename = os.path.basename(file_path) + unique_filename = f"{int(time.time())}-{filename}" + + # Detect MIME type based on file extension + content_type, _ = mimetypes.guess_type(file_path) + + async with aiofiles.open(file_path, "rb") as file_obj: + file_content = await file_obj.read() + user_id = await self._lookup_user_id() + + # Since httpx can't handle the aiofiles file object, we have to + # explicitly construct the tuple so it sends the MIME type + files = {"upload-file": (filename, file_content, content_type)} + + response = await self._make_request( + "POST", + f"object/documents/{user_id}/{unique_filename}", + files=files, + service="storage", + ) + + file_url = response["Key"] + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/") :] + except Exception as e: + raise ClientException(f"File upload failed: {e}") + async def run_workflow( self, workflow_definition_id: str, @@ -104,13 +178,26 @@ async def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) self._validate_tags(tags) + self._validate_files(files) + + # Create a copy of workflow_args to avoid modifying the original + merged_args = workflow_args.copy() + + # Handle file uploads if provided + if files: + for arg_name, file_path in files.items(): + # Upload the file and get a file ID + file_url = await self._upload_file(file_path) + # Merge the file ID into the workflow arguments + merged_args[arg_name] = file_url payload = { "workflow_definition_id": workflow_definition_id, - "request_body": workflow_args, + "request_body": merged_args, } if tags is not None: payload["tags"] = tags diff --git a/tws/_sync/client.py b/tws/_sync/client.py index c9b97e8..a9e75dc 100644 --- a/tws/_sync/client.py +++ b/tws/_sync/client.py @@ -1,10 +1,11 @@ +import os import time from typing import Dict, cast, Optional import httpx from httpx import Client as SyncHttpClient -from tws.base.client import TWSClient, ClientException +from tws.base.client import TWS_API_KEY_HEADER, TWSClient, ClientException class SyncClient(TWSClient): @@ -52,12 +53,38 @@ def __exit__(self, exc_type, exc, tb) -> None: # Close the underlying HTTP session self.session.close() + def _lookup_user_id(self) -> str: + """Look up the user ID associated with the API key. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + if self.user_id is None: + params = { + "select": "user_id", + "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}", + } + try: + response = self._make_request("GET", "users_private", params=params) + if not response or len(response) == 0: + raise ClientException("User ID not found, is your API key correct?") + self.user_id = response[0]["user_id"] + except Exception as e: + raise ClientException(f"Failed to look up user ID: {e}") + + return self.user_id + def _make_request( self, method: str, uri: str, payload: Optional[dict] = None, params: Optional[dict] = None, + files: Optional[dict] = None, + service: str = "rest", ): """Make a HTTP request to the TWS API. @@ -75,7 +102,7 @@ def _make_request( """ try: response = self.session.request( - method, f"/rest/v1/{uri}", json=payload, params=params + method, f"/{service}/v1/{uri}", json=payload, params=params, files=files ) response.raise_for_status() return response.json() @@ -94,6 +121,40 @@ def _make_rpc_request(self, function_name: str, payload: Optional[dict] = None): """ return self._make_request("POST", f"rpc/{function_name}", payload) + def _upload_file(self, file_path: str) -> str: + """Upload a file to the TWS API. + + Args: + file_path: Path to the file to upload + + Returns: + File path that can be used in workflow arguments + + Raises: + ClientException: If the file upload fails + """ + try: + if not os.path.exists(file_path): + raise ClientException(f"File not found: {file_path}") + filename = os.path.basename(file_path) + + with open(file_path, "rb") as file_obj: + # Upload the file to get a file URL + unique_filename = f"{int(time.time())}-{filename}" + user_id = self._lookup_user_id() + response = self._make_request( + "POST", + f"object/documents/{user_id}/{unique_filename}", + files={"upload-file": file_obj}, + service="storage", + ) + + file_url = response["Key"] + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/") :] + except Exception as e: + raise ClientException(f"File upload failed: {e}") + def run_workflow( self, workflow_definition_id: str, @@ -101,13 +162,26 @@ def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) self._validate_tags(tags) + self._validate_files(files) + + # Create a copy of workflow_args to avoid modifying the original + merged_args = workflow_args.copy() + + # Handle file uploads if provided + if files: + for arg_name, file_path in files.items(): + # Upload the file and get a file ID + file_url = self._upload_file(file_path) + # Merge the file ID into the workflow arguments + merged_args[arg_name] = file_url payload = { "workflow_definition_id": workflow_definition_id, - "request_body": workflow_args, + "request_body": merged_args, } if tags is not None: payload["tags"] = tags diff --git a/tws/base/client.py b/tws/base/client.py index 0c8d739..e982dde 100644 --- a/tws/base/client.py +++ b/tws/base/client.py @@ -8,6 +8,8 @@ from tws.utils import is_valid_jwt +TWS_API_KEY_HEADER = "X-TWS-API-KEY" + class ClientException(Exception): def __init__(self, message: str): @@ -46,11 +48,12 @@ def __init__( base_url = api_url.rstrip("/") headers = { - "Authorization": secret_key, + "Authorization": f"Bearer {public_key}", "apikey": public_key, - "Content-Type": "application/json", + TWS_API_KEY_HEADER: secret_key, } self.session = self.create_session(base_url, headers) + self.user_id = None @abstractmethod def create_session( @@ -107,6 +110,41 @@ def _validate_tags(tags: Optional[Dict[str, str]]) -> None: "Tag keys and values must be <= 255 characters" ) + @abstractmethod + def _lookup_user_id(self) -> Union[str, Coroutine[Any, Any, str]]: + """Look up the user ID associated with the API key. + + Lazily fetches and caches the user ID if it hasn't been retrieved yet. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + raise NotImplementedError() + + @staticmethod + def _validate_files(files: Optional[Dict[str, str]]) -> None: + """Validate file upload parameters. + + Args: + files: Dictionary mapping argument names to file paths + + Raises: + ClientException: If files parameter is invalid + """ + if files is not None: + if not isinstance(files, dict): + raise ClientException("Files must be a dictionary") + for key, value in files.items(): + if not isinstance(key, str): + raise ClientException("File keys must be strings") + + # Validate that the value is a string (file path) + if not isinstance(value, str): + raise ClientException("File values must be file paths (strings)") + @abstractmethod def run_workflow( self, @@ -115,6 +153,7 @@ def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ) -> Union[dict, Coroutine[Any, Any, dict]]: """Execute a workflow and wait for it to complete or fail. @@ -124,6 +163,7 @@ def run_workflow( timeout: Maximum time in seconds to wait for workflow completion (1-3600) retry_delay: Time in seconds between status checks (1-60) tags: Optional dictionary of tag key-value pairs to attach to the workflow + files: Optional dictionary mapping workflow argument names to file paths Returns: The workflow execution result as a dictionary