From 37011ed3d1283ca175d7ff2aad624fff554777c9 Mon Sep 17 00:00:00 2001 From: tomchop Date: Tue, 11 Mar 2025 23:33:20 +0000 Subject: [PATCH 1/3] Add DFIQ patch object endpoint --- tests/api.py | 18 ++++++++++++++++++ yeti/api.py | 14 +++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/api.py b/tests/api.py index 84038cd..9f3e91d 100644 --- a/tests/api.py +++ b/tests/api.py @@ -149,6 +149,24 @@ def test_patch_dfiq_from_yaml(self, mock_patch): }, ) + @patch("yeti.api.requests.Session.patch") + def test_patch_dfiq(self, mock_patch): + mock_response = MagicMock() + mock_response.content = b'{"id": "patched_dfiq"}' + mock_patch.return_value = mock_response + + result = self.api.patch_dfiq( + {"name": "patched_dfiq", "id": 1, "type": "question"} + ) + self.assertEqual(result, {"id": "patched_dfiq"}) + mock_patch.assert_called_with( + "http://fake-url/api/v2/dfiq/1", + json={ + "dfiq_object": {"name": "patched_dfiq", "type": "question", "id": 1}, + "dfiq_type": "question", + }, + ) + @patch("yeti.api.requests.Session.post") def test_download_dfiq_archive(self, mock_post): mock_response = MagicMock() diff --git a/yeti/api.py b/yeti/api.py index 1533a91..0aa7a9b 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -278,13 +278,25 @@ def patch_dfiq_from_yaml( params = { "dfiq_type": dfiq_type, "dfiq_yaml": dfiq_yaml, - "update_indicators": True, } response = self.do_request( "PATCH", f"{self._url_root}/api/v2/dfiq/{yeti_id}", json_data=params ) return json.loads(response) + def patch_dfiq(self, dfiq_object: dict[str, Any]) -> YetiObject: + """Patches a DFIQ object in Yeti.""" + params = { + "dfiq_type": dfiq_object["type"], + "dfiq_object": dfiq_object, + } + response = self.do_request( + "PATCH", + f"{self._url_root}/api/v2/dfiq/{dfiq_object['id']}", + json_data=params, + ) + return json.loads(response) + def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes: """Downloads an archive containing all DFIQ data from Yeti. From f30048f52b288dc60e0868fa0e595ae5fb9bb247 Mon Sep 17 00:00:00 2001 From: tomchop Date: Tue, 11 Mar 2025 23:40:50 +0000 Subject: [PATCH 2/3] Add reauth function --- tests/api.py | 2 -- yeti/api.py | 24 +++++++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/api.py b/tests/api.py index 9f3e91d..f8912c7 100644 --- a/tests/api.py +++ b/tests/api.py @@ -128,7 +128,6 @@ def test_new_dfiq_from_yaml(self, mock_post): json={ "dfiq_type": "type", "dfiq_yaml": "yaml_content", - "update_indicators": True, }, ) @@ -145,7 +144,6 @@ def test_patch_dfiq_from_yaml(self, mock_patch): json={ "dfiq_type": "type", "dfiq_yaml": "yaml_content", - "update_indicators": True, }, ) diff --git a/yeti/api.py b/yeti/api.py index 0aa7a9b..2b9f549 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -39,6 +39,10 @@ def __init__(self, url_root: str): "Content-Type": "application/json", } self._url_root = url_root + self._auth_method = "" + self._auth_functions = { + "auth_api_key": self.auth_api_key, + } def do_request( self, @@ -89,13 +93,18 @@ def do_request( return response.content - def auth_api_key(self, apikey: str) -> None: + def auth_api_key(self, apikey: str | None = None) -> None: """Authenticates a session using an API key.""" # Use long-term refresh API token to get an access token + if apikey is not None: + self._apikey = apikey + if not self._apikey: + raise ValueError("No API key provided.") + response = self.do_request( "POST", f"{self._url_root}{API_TOKEN_ENDPOINT}", - headers={"x-yeti-apikey": apikey}, + headers={"x-yeti-apikey": self._apikey}, ) access_token = json.loads(response).get("access_token") @@ -107,6 +116,16 @@ def auth_api_key(self, apikey: str) -> None: authd_session.headers.update({"authorization": f"Bearer {access_token}"}) self.client = authd_session + self._auth_method = "auth_api_key" + + def refresh_auth(self): + if self._auth_method: + self._auth_functions[self._auth_method]() + else: + raise RuntimeError( + "No authentication method set. You might have to authenticate first." + ) + def search_indicators( self, name: str | None = None, @@ -261,7 +280,6 @@ def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject: params = { "dfiq_type": dfiq_type, "dfiq_yaml": dfiq_yaml, - "update_indicators": True, } response = self.do_request( "POST", f"{self._url_root}/api/v2/dfiq/from_yaml", json_data=params From 511c4c95d13d7783204fcbaab98ca8221f31bce8 Mon Sep 17 00:00:00 2001 From: tomchop Date: Wed, 12 Mar 2025 01:48:16 +0000 Subject: [PATCH 3/3] Add a auth refresh handler --- yeti/api.py | 35 +++++++++++++++++++++++++++-------- yeti/client.py | 5 ++++- yeti/errors.py | 13 ++++++++++++- 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/yeti/api.py b/yeti/api.py index 2b9f549..aee183f 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -1,6 +1,7 @@ """Python client for the Yeti API.""" import json +import logging from typing import Any, Sequence import requests @@ -24,6 +25,13 @@ YetiLinkObject = dict[str, Any] +logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + class YetiApi: """API object to interact with the Yeti API. @@ -39,11 +47,14 @@ def __init__(self, url_root: str): "Content-Type": "application/json", } self._url_root = url_root - self._auth_method = "" - self._auth_functions = { + + self._auth_function = "" + self._auth_function_map = { "auth_api_key": self.auth_api_key, } + self._apikey = None + def do_request( self, method: str, @@ -51,6 +62,7 @@ def do_request( json_data: dict[str, Any] | None = None, body: bytes | None = None, headers: dict[str, Any] | None = None, + retries: int = 3, ) -> bytes: """Issues a request to the given URL. @@ -60,6 +72,7 @@ def do_request( json: The JSON payload to include in the request. body: The body to include in the request. headers: Extra headers to include in the request. + retries: The number of times to retry the request. Returns: The response from the API; a bytes object. @@ -89,6 +102,14 @@ def do_request( raise ValueError(f"Unsupported method: {method}") response.raise_for_status() except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + if retries == 0: + raise errors.YetiAuthError(str(e)) from e + self.refresh_auth() + return self.do_request( + method, url, json_data, body, headers, retries - 1 + ) + raise errors.YetiApiError(e.response.status_code, e.response.text) return response.content @@ -116,15 +137,13 @@ def auth_api_key(self, apikey: str | None = None) -> None: authd_session.headers.update({"authorization": f"Bearer {access_token}"}) self.client = authd_session - self._auth_method = "auth_api_key" + self._auth_function = "auth_api_key" def refresh_auth(self): - if self._auth_method: - self._auth_functions[self._auth_method]() + if self._auth_function: + self._auth_function_map[self._auth_function]() else: - raise RuntimeError( - "No authentication method set. You might have to authenticate first." - ) + logger.warning("No auth function set, cannot refresh auth.") def search_indicators( self, diff --git a/yeti/client.py b/yeti/client.py index c06fb06..441f04d 100644 --- a/yeti/client.py +++ b/yeti/client.py @@ -13,7 +13,10 @@ def __init__(self): @click.group() @click.option("--api-key", envvar="YETI_API_KEY", required=True, help="Your API key.") @click.option( - "--endpoint", envvar="YETI_WEB_ROOT", required=True, help="The Yeti endpoint." + "--endpoint", + envvar="YETI_WEB_ROOT", + required=True, + help="The Yeti endpoint, e.g. http://localhost:3000/", ) @pass_context # Add this to pass the context to subcommands def cli(ctx, api_key, endpoint): diff --git a/yeti/errors.py b/yeti/errors.py index e09c86a..b6f6cf4 100644 --- a/yeti/errors.py +++ b/yeti/errors.py @@ -1,4 +1,8 @@ -class YetiApiError(RuntimeError): +class YetiError(RuntimeError): + """Base class for errors in the Yeti package.""" + + +class YetiApiError(YetiError): """Base class for errors in the Yeti API.""" status_code: int @@ -7,3 +11,10 @@ class YetiApiError(RuntimeError): def __init__(self, status_code: int, message: str): super().__init__(message) self.status_code = status_code + + +class YetiAuthError(YetiError): + """Error authenticating with the Yeti API.""" + + def __init__(self, message: str): + super().__init__(message)