From 1d821b20fc86195271162b6e7237563046a5789c Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 10:39:30 -0700 Subject: [PATCH 1/6] feat: switch to new auth approach Leverages a new header for sending API key auth, to support authentication for storage requests too --- tws/base/client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tws/base/client.py b/tws/base/client.py index 0c8d739..2ebbd79 100644 --- a/tws/base/client.py +++ b/tws/base/client.py @@ -8,6 +8,9 @@ from tws.utils import is_valid_jwt +# Constant for the TWS API key header name +TWS_API_KEY_HEADER = "X-TWS-API-KEY" + class ClientException(Exception): def __init__(self, message: str): @@ -46,9 +49,9 @@ def __init__( base_url = api_url.rstrip("/") headers = { - "Authorization": secret_key, + "Authorization": f"Bearer {public_key}", "apikey": public_key, - "Content-Type": "application/json", + TWS_API_KEY_HEADER: secret_key, } self.session = self.create_session(base_url, headers) From e85d6609a6cbe3898f9040f5e566f2e7039e6090 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 11:17:26 -0700 Subject: [PATCH 2/6] feat: support files as workflow arguments Uploads the files to the storage service and injects them into the workflow arguments by file path --- poetry.lock | 13 +++++- pyproject.toml | 1 + tws/_async/client.py | 95 ++++++++++++++++++++++++++++++++++++++++++-- tws/_sync/client.py | 78 ++++++++++++++++++++++++++++++++++-- tws/base/client.py | 39 +++++++++++++++++- 5 files changed, 218 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index e925466..253980f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "anyio" version = "4.8.0" @@ -431,4 +442,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0cad654a3a95aeb2eabaaf9a2e1edfb9a87ef300b96d79f62f13c6c85d0e89ac" +content-hash = "c1c0f98276102a1ae18e71363caef0d08485c03f3427f7cdcc0af8f7ede3017b" diff --git a/pyproject.toml b/pyproject.toml index 382eb4a..77ae75d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" httpx = {extras = ["http2"], version = ">=0.26,<0.29"} +aiofiles = "^24.1.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.4" diff --git a/tws/_async/client.py b/tws/_async/client.py index edaec11..ade5653 100644 --- a/tws/_async/client.py +++ b/tws/_async/client.py @@ -1,11 +1,14 @@ import asyncio +import mimetypes +import os import time from typing import cast, Dict, Optional +import aiofiles import httpx from httpx import AsyncClient as AsyncHttpClient -from tws.base.client import TWSClient, ClientException +from tws.base.client import TWS_API_KEY_HEADER, TWSClient, ClientException class AsyncClient(TWSClient): @@ -53,12 +56,35 @@ async def __aexit__(self, exc_type, exc, tb) -> None: # Close the underlying HTTP session await self.session.aclose() + async def _lookup_user_id(self) -> str: + """Look up the user ID associated with the API key. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + if self.user_id is None: + params = {"select": "user_id", "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}"} + try: + response = await self._make_request("GET", "users_private", params=params) + if not response or len(response) == 0: + raise ClientException("User ID not found, is your API key correct?") + self.user_id = response[0]["user_id"] + except Exception as e: + raise ClientException(f"Failed to look up user ID: {e}") + + return self.user_id + async def _make_request( self, method: str, uri: str, payload: Optional[dict] = None, params: Optional[dict] = None, + files: Optional[dict] = None, + service: str = "rest", ): """Make a HTTP request to the TWS API. @@ -76,7 +102,7 @@ async def _make_request( """ try: response = await self.session.request( - method, f"/rest/v1/{uri}", json=payload, params=params + method, f"/{service}/v1/{uri}", json=payload, params=params, files=files ) response.raise_for_status() return response.json() @@ -96,6 +122,56 @@ async def _make_rpc_request( Parsed JSON response from the API """ return await self._make_request("POST", f"rpc/{function_name}", payload) + + async def _upload_file(self, file_path: str) -> str: + """Upload a file to the TWS API asynchronously. + + Args: + file_path: Path to the file to upload + + Returns: + File path that can be used in workflow arguments + + Raises: + ClientException: If the file upload fails + """ + try: + if not os.path.exists(file_path): + raise ClientException(f"File not found: {file_path}") + + filename = os.path.basename(file_path) + unique_filename = f"{int(time.time())}-{filename}" + + # Detect MIME type based on file extension + content_type, _ = mimetypes.guess_type(file_path) + if content_type is None: + content_type = "application/octet-stream" + + async with aiofiles.open(file_path, 'rb') as file_obj: + file_content = await file_obj.read() + user_id = await self._lookup_user_id() + + # Since httpx can't handle the aiofiles file object, we have to + # explicitly construct the tuple so it sends the MIME type + files = { + "upload-file": (filename, file_content, content_type) + } + + response = await self._make_request( + "POST", + f"object/documents/{user_id}/{unique_filename}", + files=files, + service="storage" + ) + + file_url = response["Key"] + if file_url.startswith("documents/"): + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/"):] + + return file_url + except Exception as e: + raise ClientException(f"File upload failed: {e}") async def run_workflow( self, @@ -104,13 +180,26 @@ async def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) self._validate_tags(tags) + self._validate_files(files) + + # Create a copy of workflow_args to avoid modifying the original + merged_args = workflow_args.copy() + + # Handle file uploads if provided + if files: + for arg_name, file_path in files.items(): + # Upload the file and get a file ID + file_url = await self._upload_file(file_path) + # Merge the file ID into the workflow arguments + merged_args[arg_name] = file_url payload = { "workflow_definition_id": workflow_definition_id, - "request_body": workflow_args, + "request_body": merged_args, } if tags is not None: payload["tags"] = tags diff --git a/tws/_sync/client.py b/tws/_sync/client.py index c9b97e8..cef8e65 100644 --- a/tws/_sync/client.py +++ b/tws/_sync/client.py @@ -1,10 +1,11 @@ +import os import time from typing import Dict, cast, Optional import httpx from httpx import Client as SyncHttpClient -from tws.base.client import TWSClient, ClientException +from tws.base.client import TWS_API_KEY_HEADER, TWSClient, ClientException class SyncClient(TWSClient): @@ -52,12 +53,35 @@ def __exit__(self, exc_type, exc, tb) -> None: # Close the underlying HTTP session self.session.close() + def _lookup_user_id(self) -> str: + """Look up the user ID associated with the API key. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + if self.user_id is None: + params = {"select": "user_id", "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}"} + try: + response = self._make_request("GET", "users_private", params=params) + if not response or len(response) == 0: + raise ClientException("User ID not found, is your API key correct?") + self.user_id = response[0]["user_id"] + except Exception as e: + raise ClientException(f"Failed to look up user ID: {e}") + + return self.user_id + def _make_request( self, method: str, uri: str, payload: Optional[dict] = None, params: Optional[dict] = None, + files: Optional[dict] = None, + service: str = "rest", ): """Make a HTTP request to the TWS API. @@ -75,7 +99,7 @@ def _make_request( """ try: response = self.session.request( - method, f"/rest/v1/{uri}", json=payload, params=params + method, f"/{service}/v1/{uri}", json=payload, params=params, files=files ) response.raise_for_status() return response.json() @@ -94,6 +118,41 @@ def _make_rpc_request(self, function_name: str, payload: Optional[dict] = None): """ return self._make_request("POST", f"rpc/{function_name}", payload) + + def _upload_file(self, file_path: str) -> str: + """Upload a file to the TWS API. + + Args: + file_path: Path to the file to upload + + Returns: + File path that can be used in workflow arguments + + Raises: + ClientException: If the file upload fails + """ + try: + if not os.path.exists(file_path): + raise ClientException(f"File not found: {file_path}") + filename = os.path.basename(file_path) + + with open(file_path, 'rb') as file_obj: + # Upload the file to get a file URL + unique_filename = f"{int(time.time())}-{filename}" + user_id = self._lookup_user_id() + response = self._make_request("POST", f"object/documents/{user_id}/{unique_filename}", files={ + "upload-file": file_obj + }, service="storage") + + file_url = response["Key"] + if file_url.startswith("documents/"): + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/"):] + + return file_url + except Exception as e: + raise ClientException(f"File upload failed: {e}") + def run_workflow( self, workflow_definition_id: str, @@ -101,13 +160,26 @@ def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ): self._validate_workflow_params(timeout, retry_delay) self._validate_tags(tags) + self._validate_files(files) + + # Create a copy of workflow_args to avoid modifying the original + merged_args = workflow_args.copy() + + # Handle file uploads if provided + if files: + for arg_name, file_path in files.items(): + # Upload the file and get a file ID + file_url = self._upload_file(file_path) + # Merge the file ID into the workflow arguments + merged_args[arg_name] = file_url payload = { "workflow_definition_id": workflow_definition_id, - "request_body": workflow_args, + "request_body": merged_args, } if tags is not None: payload["tags"] = tags diff --git a/tws/base/client.py b/tws/base/client.py index 2ebbd79..b2e8f48 100644 --- a/tws/base/client.py +++ b/tws/base/client.py @@ -8,7 +8,6 @@ from tws.utils import is_valid_jwt -# Constant for the TWS API key header name TWS_API_KEY_HEADER = "X-TWS-API-KEY" @@ -54,6 +53,7 @@ def __init__( TWS_API_KEY_HEADER: secret_key, } self.session = self.create_session(base_url, headers) + self.user_id = None @abstractmethod def create_session( @@ -109,6 +109,41 @@ def _validate_tags(tags: Optional[Dict[str, str]]) -> None: raise ClientException( "Tag keys and values must be <= 255 characters" ) + + @abstractmethod + def _lookup_user_id(self) -> Union[str, Coroutine[Any, Any, str]]: + """Look up the user ID associated with the API key. + + Lazily fetches and caches the user ID if it hasn't been retrieved yet. + + Returns: + The user ID string + + Raises: + ClientException: If the user ID cannot be found + """ + raise NotImplementedError() + + @staticmethod + def _validate_files(files: Optional[Dict[str, str]]) -> None: + """Validate file upload parameters. + + Args: + files: Dictionary mapping argument names to file paths + + Raises: + ClientException: If files parameter is invalid + """ + if files is not None: + if not isinstance(files, dict): + raise ClientException("Files must be a dictionary") + for key, value in files.items(): + if not isinstance(key, str): + raise ClientException("File keys must be strings") + + # Validate that the value is a string (file path) + if not isinstance(value, str): + raise ClientException("File values must be file paths (strings)") @abstractmethod def run_workflow( @@ -118,6 +153,7 @@ def run_workflow( timeout=600, retry_delay=1, tags: Optional[Dict[str, str]] = None, + files: Optional[Dict[str, str]] = None, ) -> Union[dict, Coroutine[Any, Any, dict]]: """Execute a workflow and wait for it to complete or fail. @@ -127,6 +163,7 @@ def run_workflow( timeout: Maximum time in seconds to wait for workflow completion (1-3600) retry_delay: Time in seconds between status checks (1-60) tags: Optional dictionary of tag key-value pairs to attach to the workflow + files: Optional dictionary mapping workflow argument names to file paths Returns: The workflow execution result as a dictionary From 9149af52c0dea28e0013de249eec7e457e6b94f3 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 11:19:33 -0700 Subject: [PATCH 3/6] style: ruff format --- tws/_async/client.py | 55 +++++++++++++++++++++++--------------------- tws/_sync/client.py | 41 ++++++++++++++++++--------------- tws/base/client.py | 16 ++++++------- 3 files changed, 60 insertions(+), 52 deletions(-) diff --git a/tws/_async/client.py b/tws/_async/client.py index ade5653..3a9a286 100644 --- a/tws/_async/client.py +++ b/tws/_async/client.py @@ -58,25 +58,30 @@ async def __aexit__(self, exc_type, exc, tb) -> None: async def _lookup_user_id(self) -> str: """Look up the user ID associated with the API key. - + Returns: The user ID string - + Raises: ClientException: If the user ID cannot be found """ if self.user_id is None: - params = {"select": "user_id", "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}"} + params = { + "select": "user_id", + "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}", + } try: - response = await self._make_request("GET", "users_private", params=params) + response = await self._make_request( + "GET", "users_private", params=params + ) if not response or len(response) == 0: raise ClientException("User ID not found, is your API key correct?") self.user_id = response[0]["user_id"] except Exception as e: raise ClientException(f"Failed to look up user ID: {e}") - + return self.user_id - + async def _make_request( self, method: str, @@ -122,16 +127,16 @@ async def _make_rpc_request( Parsed JSON response from the API """ return await self._make_request("POST", f"rpc/{function_name}", payload) - + async def _upload_file(self, file_path: str) -> str: """Upload a file to the TWS API asynchronously. - + Args: file_path: Path to the file to upload - + Returns: File path that can be used in workflow arguments - + Raises: ClientException: If the file upload fails """ @@ -141,34 +146,32 @@ async def _upload_file(self, file_path: str) -> str: filename = os.path.basename(file_path) unique_filename = f"{int(time.time())}-{filename}" - + # Detect MIME type based on file extension content_type, _ = mimetypes.guess_type(file_path) if content_type is None: content_type = "application/octet-stream" - - async with aiofiles.open(file_path, 'rb') as file_obj: + + async with aiofiles.open(file_path, "rb") as file_obj: file_content = await file_obj.read() user_id = await self._lookup_user_id() - + # Since httpx can't handle the aiofiles file object, we have to # explicitly construct the tuple so it sends the MIME type - files = { - "upload-file": (filename, file_content, content_type) - } - + files = {"upload-file": (filename, file_content, content_type)} + response = await self._make_request( - "POST", - f"object/documents/{user_id}/{unique_filename}", - files=files, - service="storage" + "POST", + f"object/documents/{user_id}/{unique_filename}", + files=files, + service="storage", ) - + file_url = response["Key"] if file_url.startswith("documents/"): # Strip the prefix, as the workflow automatically looks in the bucket - return file_url[len("documents/"):] - + return file_url[len("documents/") :] + return file_url except Exception as e: raise ClientException(f"File upload failed: {e}") @@ -188,7 +191,7 @@ async def run_workflow( # Create a copy of workflow_args to avoid modifying the original merged_args = workflow_args.copy() - + # Handle file uploads if provided if files: for arg_name, file_path in files.items(): diff --git a/tws/_sync/client.py b/tws/_sync/client.py index cef8e65..641d183 100644 --- a/tws/_sync/client.py +++ b/tws/_sync/client.py @@ -55,15 +55,18 @@ def __exit__(self, exc_type, exc, tb) -> None: def _lookup_user_id(self) -> str: """Look up the user ID associated with the API key. - + Returns: The user ID string - + Raises: ClientException: If the user ID cannot be found """ if self.user_id is None: - params = {"select": "user_id", "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}"} + params = { + "select": "user_id", + "api_key": f"eq.{self.session.headers[TWS_API_KEY_HEADER]}", + } try: response = self._make_request("GET", "users_private", params=params) if not response or len(response) == 0: @@ -71,9 +74,9 @@ def _lookup_user_id(self) -> str: self.user_id = response[0]["user_id"] except Exception as e: raise ClientException(f"Failed to look up user ID: {e}") - + return self.user_id - + def _make_request( self, method: str, @@ -118,16 +121,15 @@ def _make_rpc_request(self, function_name: str, payload: Optional[dict] = None): """ return self._make_request("POST", f"rpc/{function_name}", payload) - def _upload_file(self, file_path: str) -> str: """Upload a file to the TWS API. - + Args: file_path: Path to the file to upload - + Returns: File path that can be used in workflow arguments - + Raises: ClientException: If the file upload fails """ @@ -135,20 +137,23 @@ def _upload_file(self, file_path: str) -> str: if not os.path.exists(file_path): raise ClientException(f"File not found: {file_path}") filename = os.path.basename(file_path) - - with open(file_path, 'rb') as file_obj: + + with open(file_path, "rb") as file_obj: # Upload the file to get a file URL unique_filename = f"{int(time.time())}-{filename}" user_id = self._lookup_user_id() - response = self._make_request("POST", f"object/documents/{user_id}/{unique_filename}", files={ - "upload-file": file_obj - }, service="storage") - + response = self._make_request( + "POST", + f"object/documents/{user_id}/{unique_filename}", + files={"upload-file": file_obj}, + service="storage", + ) + file_url = response["Key"] if file_url.startswith("documents/"): # Strip the prefix, as the workflow automatically looks in the bucket - return file_url[len("documents/"):] - + return file_url[len("documents/") :] + return file_url except Exception as e: raise ClientException(f"File upload failed: {e}") @@ -168,7 +173,7 @@ def run_workflow( # Create a copy of workflow_args to avoid modifying the original merged_args = workflow_args.copy() - + # Handle file uploads if provided if files: for arg_name, file_path in files.items(): diff --git a/tws/base/client.py b/tws/base/client.py index b2e8f48..e982dde 100644 --- a/tws/base/client.py +++ b/tws/base/client.py @@ -109,28 +109,28 @@ def _validate_tags(tags: Optional[Dict[str, str]]) -> None: raise ClientException( "Tag keys and values must be <= 255 characters" ) - + @abstractmethod def _lookup_user_id(self) -> Union[str, Coroutine[Any, Any, str]]: """Look up the user ID associated with the API key. - + Lazily fetches and caches the user ID if it hasn't been retrieved yet. - + Returns: The user ID string - + Raises: ClientException: If the user ID cannot be found """ raise NotImplementedError() - + @staticmethod def _validate_files(files: Optional[Dict[str, str]]) -> None: """Validate file upload parameters. - + Args: files: Dictionary mapping argument names to file paths - + Raises: ClientException: If files parameter is invalid """ @@ -140,7 +140,7 @@ def _validate_files(files: Optional[Dict[str, str]]) -> None: for key, value in files.items(): if not isinstance(key, str): raise ClientException("File keys must be strings") - + # Validate that the value is a string (file path) if not isinstance(value, str): raise ClientException("File values must be file paths (strings)") From 272a5239094e8e57825932115a16eb9f700c55e5 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 11:21:15 -0700 Subject: [PATCH 4/6] tests: fix existing unit tests --- tests/test_async_client.py | 1 + tests/test_sync_client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 3dac696..a80e8a2 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -249,6 +249,7 @@ def mock_json(): "/rest/v1/test/endpoint", json={"param": "value"}, params={"query": "param"}, + files=None, ) assert result == {"data": "test"} diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py index 4f96382..fc58031 100644 --- a/tests/test_sync_client.py +++ b/tests/test_sync_client.py @@ -243,6 +243,7 @@ def mock_json(): "/rest/v1/test/endpoint", json={"param": "value"}, params={"query": "param"}, + files=None, ) assert result == {"data": "test"} From 9ff86d87be47c4d16c7a083f0e6e5c486562b6a4 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 11:42:36 -0700 Subject: [PATCH 5/6] tests: add unit tests for new file upload logic --- tests/test_async_client.py | 284 +++++++++++++++++++++++++++++++++++++ tests/test_sync_client.py | 157 ++++++++++++++++++++ tws/_async/client.py | 9 +- tws/_sync/client.py | 7 +- 4 files changed, 445 insertions(+), 12 deletions(-) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index a80e8a2..9dd874a 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -88,6 +88,42 @@ async def test_run_workflow_tag_validation(good_async_client, tags, exception_me assert exception_message in str(exc_info.value) +@pytest.mark.parametrize( + "files,exception_message", + [ + [{"key": 123}, "File values must be file paths (strings)"], + [{"key": "value", "bad_key": 123}, "File values must be file paths (strings)"], + [{123: "value"}, "File keys must be strings"], + ["not_a_dict", "Files must be a dictionary"], + [{}, None], # Empty dict is valid + ], +) +async def test_run_workflow_files_validation( + good_async_client, files, exception_message +): + if exception_message: + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + assert exception_message in str(exc_info.value) + else: + # This should not raise an exception + async with good_async_client: + with patch("tws._async.client.AsyncClient._make_rpc_request") as mock_rpc: + with patch( + "tws._async.client.AsyncClient._make_request" + ) as mock_request: + mock_rpc.return_value = {"workflow_instance_id": "123"} + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success"}} + ] + await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + @patch("tws._async.client.AsyncClient._make_rpc_request") @patch("tws._async.client.AsyncClient._make_request") async def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_async_client): @@ -118,6 +154,51 @@ async def test_run_workflow_with_valid_tags(mock_request, mock_rpc, good_async_c assert result == {"output": "success"} +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_files( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock file upload + mock_upload.return_value = "user-123/timestamp-test_file.txt" + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with file"}} + ] + + files = {"input_file": str(test_file)} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + # Verify file was uploaded + mock_upload.assert_called_once_with(str(test_file)) + + # Verify the file path was merged into workflow args + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file": "user-123/timestamp-test_file.txt", + }, + }, + ) + assert result == {"output": "success with file"} + + @patch("tws._async.client.AsyncClient._make_rpc_request") async def test_run_workflow_not_found(mock_rpc, good_async_client): mock_request = Request("POST", "http://example.com") @@ -265,6 +346,63 @@ async def test_make_request_error(mock_request, good_async_client): assert "Request error occurred: Network error" in str(exc_info.value) +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_success(mock_request, good_async_client): + # Mock successful user ID lookup + mock_request.return_value = [{"user_id": "test-user-123"}] + + async with good_async_client: + user_id = await good_async_client._lookup_user_id() + + # Verify the request was made correctly + mock_request.assert_called_once_with( + "GET", + "users_private", + params={ + "select": "user_id", + "api_key": f"eq.{good_async_client.session.headers['X-TWS-API-KEY']}", + }, + ) + + # Verify the user ID was returned and cached + assert user_id == "test-user-123" + assert good_async_client.user_id == "test-user-123" + + # Reset the mock and call again to verify caching + mock_request.reset_mock() + + # Second call should use cached value + user_id_again = await good_async_client._lookup_user_id() + assert user_id_again == "test-user-123" + + # Verify no additional request was made + mock_request.assert_not_called() + + +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_empty_response(mock_request, good_async_client): + # Mock empty response (no user found) + mock_request.return_value = [] + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._lookup_user_id() + + assert "User ID not found, is your API key correct?" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._make_request") +async def test_lookup_user_id_request_error(mock_request, good_async_client): + # Mock request error + mock_request.side_effect = Exception("Database connection error") + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._lookup_user_id() + + assert "Failed to look up user ID: Database connection error" in str(exc_info.value) + + @patch("tws._async.client.AsyncClient._make_request") async def test_make_rpc_request_success(mock_request, good_async_client): mock_request.return_value = {"result": "success"} @@ -312,3 +450,149 @@ async def test_run_workflow_timeout( "workflow-id", {"arg": "value"}, timeout=600 ) assert "Workflow execution timed out after 600 seconds" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._lookup_user_id") +@patch("tws._async.client.AsyncClient._make_request") +async def test_upload_file_success( + mock_request, mock_lookup_user_id, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-456" + + # Mock successful file upload + mock_request.return_value = { + "Key": "documents/test-user-456/timestamp-test_file.txt" + } + + async with good_async_client: + file_path = await good_async_client._upload_file(str(test_file)) + + # Verify the correct path is returned (without the documents/ prefix) + assert file_path == "test-user-456/timestamp-test_file.txt" + + # Verify the file upload request was made with the correct parameters + assert mock_request.call_count == 1 + # We can't check the exact file content in the call args because it's dynamic, + # but we can verify the endpoint and service + call_args = mock_request.call_args + assert call_args[0][0] == "POST" # HTTP method + assert "object/documents/test-user-456/" in call_args[0][1] # URI + assert call_args[1]["service"] == "storage" # service parameter + + +@patch("tws._async.client.AsyncClient._lookup_user_id") +async def test_upload_file_nonexistent_file(mock_lookup_user_id, good_async_client): + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-456" + + with pytest.raises(ClientException) as exc_info: + async with good_async_client: + await good_async_client._upload_file("/nonexistent/file.txt") + + assert "File not found: /nonexistent/file.txt" in str(exc_info.value) + + +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_multiple_files( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create temporary test files + test_file1 = tmp_path / "test_file1.txt" + test_file1.write_text("test content 1") + + test_file2 = tmp_path / "test_file2.txt" + test_file2.write_text("test content 2") + + # Mock file uploads with different return values for each file + mock_upload.side_effect = [ + "user-123/timestamp-test_file1.txt", + "user-123/timestamp-test_file2.txt", + ] + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with multiple files"}} + ] + + files = {"input_file1": str(test_file1), "input_file2": str(test_file2)} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files + ) + + # Verify both files were uploaded + assert mock_upload.call_count == 2 + mock_upload.assert_any_call(str(test_file1)) + mock_upload.assert_any_call(str(test_file2)) + + # Verify the file paths were merged into workflow args + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file1": "user-123/timestamp-test_file1.txt", + "input_file2": "user-123/timestamp-test_file2.txt", + }, + }, + ) + assert result == {"output": "success with multiple files"} + + +@patch("tws._async.client.AsyncClient._upload_file") +@patch("tws._async.client.AsyncClient._make_rpc_request") +@patch("tws._async.client.AsyncClient._make_request") +async def test_run_workflow_with_files_and_tags( + mock_request, mock_rpc, mock_upload, good_async_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock file upload + mock_upload.return_value = "user-123/timestamp-test_file.txt" + + # Mock successful workflow start + mock_rpc.return_value = {"workflow_instance_id": "123"} + + # Mock successful completion + mock_request.return_value = [ + {"status": "COMPLETED", "result": {"output": "success with file and tags"}} + ] + + files = {"input_file": str(test_file)} + tags = {"tag1": "value1", "tag2": "value2"} + + async with good_async_client: + result = await good_async_client.run_workflow( + "workflow-id", {"arg": "value"}, files=files, tags=tags + ) + + # Verify file was uploaded + mock_upload.assert_called_once_with(str(test_file)) + + # Verify the file path was merged into workflow args and tags were included + mock_rpc.assert_called_once_with( + "start_workflow", + { + "workflow_definition_id": "workflow-id", + "request_body": { + "arg": "value", + "input_file": "user-123/timestamp-test_file.txt", + }, + "tags": tags, + }, + ) + assert result == {"output": "success with file and tags"} diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py index fc58031..f7315ba 100644 --- a/tests/test_sync_client.py +++ b/tests/test_sync_client.py @@ -283,6 +283,57 @@ def test_make_rpc_request_without_payload(mock_request, good_client): assert result == {"result": "success"} +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_success(mock_request, good_client): + # Mock successful user ID lookup + mock_request.return_value = [{"user_id": "test-user-123"}] + + with good_client: + # First call should make the request + user_id = good_client._lookup_user_id() + # Second call should use cached value + user_id_cached = good_client._lookup_user_id() + + # Verify request was made with correct parameters + mock_request.assert_called_once_with( + "GET", + "users_private", + params={ + "select": "user_id", + "api_key": f"eq.{good_client.session.headers['X-TWS-API-KEY']}", + }, + ) + + assert user_id == "test-user-123" + assert user_id_cached == "test-user-123" + assert good_client.user_id == "test-user-123" + assert mock_request.call_count == 1 # Should only be called once due to caching + + +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_empty_response(mock_request, good_client): + # Mock empty response + mock_request.return_value = [] + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._lookup_user_id() + + assert "User ID not found, is your API key correct?" in str(exc_info.value) + + +@patch("tws._sync.client.SyncClient._make_request") +def test_lookup_user_id_request_error(mock_request, good_client): + # Mock request error + mock_request.side_effect = Exception("Network error") + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._lookup_user_id() + + assert "Failed to look up user ID: Network error" in str(exc_info.value) + + @patch("tws._sync.client.SyncClient._make_rpc_request") @patch("tws._sync.client.SyncClient._make_request") @patch("time.time") @@ -300,3 +351,109 @@ def test_run_workflow_timeout(mock_time, mock_request, mock_rpc, good_client): with good_client: good_client.run_workflow("workflow-id", {"arg": "value"}, timeout=600) assert "Workflow execution timed out after 600 seconds" in str(exc_info.value) + + +@patch("tws._sync.client.SyncClient._lookup_user_id") +@patch("tws._sync.client.SyncClient._make_request") +def test_run_workflow_with_file_upload( + mock_request, mock_lookup_user_id, good_client, tmp_path +): + # Create a temporary test file + test_file = tmp_path / "test_file.txt" + test_file.write_text("test content") + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-123" + + # Mock file upload response + mock_request.side_effect = [ + # First call for file upload + {"Key": "documents/test-user-123/timestamp-test_file.txt"}, + # Second call for workflow status + [{"status": "COMPLETED", "result": {"output": "success"}}], + ] + + # Mock RPC request for workflow start + with patch("tws._sync.client.SyncClient._make_rpc_request") as mock_rpc: + mock_rpc.return_value = {"workflow_instance_id": "123"} + + with good_client: + result = good_client.run_workflow( + "workflow-id", {"arg": "value"}, files={"file_arg": str(test_file)} + ) + + # Verify the file was uploaded and the file path was included in workflow args + mock_rpc.assert_called_once() + workflow_args = mock_rpc.call_args[0][1]["request_body"] + assert "file_arg" in workflow_args + assert workflow_args["file_arg"] == "test-user-123/timestamp-test_file.txt" + assert result == {"output": "success"} + + +@patch("os.path.exists") +def test_upload_file_not_found(mock_exists, good_client): + # Mock file not found + mock_exists.return_value = False + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/nonexistent/file.txt") + + assert "File not found: /path/to/nonexistent/file.txt" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("builtins.open", side_effect=IOError("Permission denied")) +def test_upload_file_open_error(mock_open, mock_exists, good_client): + # Mock file exists but can't be opened + mock_exists.return_value = True + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/file.txt") + + assert "File upload failed: Permission denied" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("builtins.open") +@patch("tws._sync.client.SyncClient._lookup_user_id") +@patch("tws._sync.client.SyncClient._make_request") +def test_upload_file_api_error( + mock_request, mock_lookup_user_id, mock_open, mock_exists, good_client +): + # Mock file exists and can be opened + mock_exists.return_value = True + mock_file = mock_open.return_value.__enter__.return_value + + # Mock user ID lookup + mock_lookup_user_id.return_value = "test-user-123" + + # Mock API error during upload + mock_request.side_effect = Exception("API error") + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client._upload_file("/path/to/file.txt") + + assert "File upload failed: API error" in str(exc_info.value) + + +@patch("os.path.exists") +@patch("time.time") +def test_run_workflow_with_file_upload_error(mock_time, mock_exists, good_client): + # Mock time for filename generation + mock_time.return_value = 1234567890 + + # Mock file not found + mock_exists.return_value = False + + with pytest.raises(ClientException) as exc_info: + with good_client: + good_client.run_workflow( + "workflow-id", + {"arg": "value"}, + files={"file_arg": "/path/to/nonexistent/file.txt"}, + ) + + assert "File not found: /path/to/nonexistent/file.txt" in str(exc_info.value) diff --git a/tws/_async/client.py b/tws/_async/client.py index 3a9a286..fbed8d4 100644 --- a/tws/_async/client.py +++ b/tws/_async/client.py @@ -149,8 +149,6 @@ async def _upload_file(self, file_path: str) -> str: # Detect MIME type based on file extension content_type, _ = mimetypes.guess_type(file_path) - if content_type is None: - content_type = "application/octet-stream" async with aiofiles.open(file_path, "rb") as file_obj: file_content = await file_obj.read() @@ -168,11 +166,8 @@ async def _upload_file(self, file_path: str) -> str: ) file_url = response["Key"] - if file_url.startswith("documents/"): - # Strip the prefix, as the workflow automatically looks in the bucket - return file_url[len("documents/") :] - - return file_url + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/") :] except Exception as e: raise ClientException(f"File upload failed: {e}") diff --git a/tws/_sync/client.py b/tws/_sync/client.py index 641d183..a9e75dc 100644 --- a/tws/_sync/client.py +++ b/tws/_sync/client.py @@ -150,11 +150,8 @@ def _upload_file(self, file_path: str) -> str: ) file_url = response["Key"] - if file_url.startswith("documents/"): - # Strip the prefix, as the workflow automatically looks in the bucket - return file_url[len("documents/") :] - - return file_url + # Strip the prefix, as the workflow automatically looks in the bucket + return file_url[len("documents/") :] except Exception as e: raise ClientException(f"File upload failed: {e}") From 61ca1f973d00c3e4ddcaf73e2478731a87e76426 Mon Sep 17 00:00:00 2001 From: Sean Schaefer Date: Thu, 6 Mar 2025 11:49:41 -0700 Subject: [PATCH 6/6] style: lint fix --- tests/test_sync_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py index f7315ba..9461a28 100644 --- a/tests/test_sync_client.py +++ b/tests/test_sync_client.py @@ -424,7 +424,6 @@ def test_upload_file_api_error( ): # Mock file exists and can be opened mock_exists.return_value = True - mock_file = mock_open.return_value.__enter__.return_value # Mock user ID lookup mock_lookup_user_id.return_value = "test-user-123"