diff --git a/tests/api.py b/tests/api.py index 84038cd..f8912c7 100644 --- a/tests/api.py +++ b/tests/api.py @@ -128,7 +128,6 @@ def test_new_dfiq_from_yaml(self, mock_post): json={ "dfiq_type": "type", "dfiq_yaml": "yaml_content", - "update_indicators": True, }, ) @@ -145,7 +144,24 @@ def test_patch_dfiq_from_yaml(self, mock_patch): json={ "dfiq_type": "type", "dfiq_yaml": "yaml_content", - "update_indicators": True, + }, + ) + + @patch("yeti.api.requests.Session.patch") + def test_patch_dfiq(self, mock_patch): + mock_response = MagicMock() + mock_response.content = b'{"id": "patched_dfiq"}' + mock_patch.return_value = mock_response + + result = self.api.patch_dfiq( + {"name": "patched_dfiq", "id": 1, "type": "question"} + ) + self.assertEqual(result, {"id": "patched_dfiq"}) + mock_patch.assert_called_with( + "http://fake-url/api/v2/dfiq/1", + json={ + "dfiq_object": {"name": "patched_dfiq", "type": "question", "id": 1}, + "dfiq_type": "question", }, ) diff --git a/yeti/api.py b/yeti/api.py index 1533a91..aee183f 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -1,6 +1,7 @@ """Python client for the Yeti API.""" import json +import logging from typing import Any, Sequence import requests @@ -24,6 +25,13 @@ YetiLinkObject = dict[str, Any] +logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + class YetiApi: """API object to interact with the Yeti API. @@ -40,6 +48,13 @@ def __init__(self, url_root: str): } self._url_root = url_root + self._auth_function = "" + self._auth_function_map = { + "auth_api_key": self.auth_api_key, + } + + self._apikey = None + def do_request( self, method: str, @@ -47,6 +62,7 @@ def do_request( json_data: dict[str, Any] | None = None, body: bytes | None = None, headers: dict[str, Any] | None = None, + retries: int = 3, ) -> bytes: """Issues a request to the given URL. @@ -56,6 +72,7 @@ def do_request( json: The JSON payload to include in the request. body: The body to include in the request. headers: Extra headers to include in the request. + retries: The number of times to retry the request. Returns: The response from the API; a bytes object. @@ -85,17 +102,30 @@ def do_request( raise ValueError(f"Unsupported method: {method}") response.raise_for_status() except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + if retries == 0: + raise errors.YetiAuthError(str(e)) from e + self.refresh_auth() + return self.do_request( + method, url, json_data, body, headers, retries - 1 + ) + raise errors.YetiApiError(e.response.status_code, e.response.text) return response.content - def auth_api_key(self, apikey: str) -> None: + def auth_api_key(self, apikey: str | None = None) -> None: """Authenticates a session using an API key.""" # Use long-term refresh API token to get an access token + if apikey is not None: + self._apikey = apikey + if not self._apikey: + raise ValueError("No API key provided.") + response = self.do_request( "POST", f"{self._url_root}{API_TOKEN_ENDPOINT}", - headers={"x-yeti-apikey": apikey}, + headers={"x-yeti-apikey": self._apikey}, ) access_token = json.loads(response).get("access_token") @@ -107,6 +137,14 @@ def auth_api_key(self, apikey: str) -> None: authd_session.headers.update({"authorization": f"Bearer {access_token}"}) self.client = authd_session + self._auth_function = "auth_api_key" + + def refresh_auth(self): + if self._auth_function: + self._auth_function_map[self._auth_function]() + else: + logger.warning("No auth function set, cannot refresh auth.") + def search_indicators( self, name: str | None = None, @@ -261,7 +299,6 @@ def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject: params = { "dfiq_type": dfiq_type, "dfiq_yaml": dfiq_yaml, - "update_indicators": True, } response = self.do_request( "POST", f"{self._url_root}/api/v2/dfiq/from_yaml", json_data=params @@ -278,13 +315,25 @@ def patch_dfiq_from_yaml( params = { "dfiq_type": dfiq_type, "dfiq_yaml": dfiq_yaml, - "update_indicators": True, } response = self.do_request( "PATCH", f"{self._url_root}/api/v2/dfiq/{yeti_id}", json_data=params ) return json.loads(response) + def patch_dfiq(self, dfiq_object: dict[str, Any]) -> YetiObject: + """Patches a DFIQ object in Yeti.""" + params = { + "dfiq_type": dfiq_object["type"], + "dfiq_object": dfiq_object, + } + response = self.do_request( + "PATCH", + f"{self._url_root}/api/v2/dfiq/{dfiq_object['id']}", + json_data=params, + ) + return json.loads(response) + def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes: """Downloads an archive containing all DFIQ data from Yeti. diff --git a/yeti/client.py b/yeti/client.py index c06fb06..441f04d 100644 --- a/yeti/client.py +++ b/yeti/client.py @@ -13,7 +13,10 @@ def __init__(self): @click.group() @click.option("--api-key", envvar="YETI_API_KEY", required=True, help="Your API key.") @click.option( - "--endpoint", envvar="YETI_WEB_ROOT", required=True, help="The Yeti endpoint." + "--endpoint", + envvar="YETI_WEB_ROOT", + required=True, + help="The Yeti endpoint, e.g. http://localhost:3000/", ) @pass_context # Add this to pass the context to subcommands def cli(ctx, api_key, endpoint): diff --git a/yeti/errors.py b/yeti/errors.py index e09c86a..b6f6cf4 100644 --- a/yeti/errors.py +++ b/yeti/errors.py @@ -1,4 +1,8 @@ -class YetiApiError(RuntimeError): +class YetiError(RuntimeError): + """Base class for errors in the Yeti package.""" + + +class YetiApiError(YetiError): """Base class for errors in the Yeti API.""" status_code: int @@ -7,3 +11,10 @@ class YetiApiError(RuntimeError): def __init__(self, status_code: int, message: str): super().__init__(message) self.status_code = status_code + + +class YetiAuthError(YetiError): + """Error authenticating with the Yeti API.""" + + def __init__(self, message: str): + super().__init__(message)