diff --git a/tests/api.py b/tests/api.py index 4be57cc..5b4ea7b 100644 --- a/tests/api.py +++ b/tests/api.py @@ -10,7 +10,7 @@ def setUp(self): @patch("yeti.api.requests.Session.post") def test_auth_api_key(self, mock_post): mock_response = MagicMock() - mock_response.text = '{"access_token": "fake_token"}' + mock_response.bytes = b'{"access_token": "fake_token"}' mock_post.return_value = mock_response self.api.auth_api_key("fake_apikey") @@ -23,7 +23,7 @@ def test_auth_api_key(self, mock_post): @patch("yeti.api.requests.Session.post") def test_search_indicators(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"indicators": [{"name": "test"}]} + mock_response.bytes = b'{"indicators": [{"name": "test"}]}' mock_post.return_value = mock_response result = self.api.search_indicators(name="test") @@ -36,7 +36,7 @@ def test_search_indicators(self, mock_post): @patch("yeti.api.requests.Session.post") def test_search_entities(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"entities": [{"name": "test_entity"}]} + mock_response.bytes = b'{"entities": [{"name": "test_entity"}]}' mock_post.return_value = mock_response result = self.api.search_entities(name="test_entity") @@ -49,7 +49,7 @@ def test_search_entities(self, mock_post): @patch("yeti.api.requests.Session.post") def test_search_observables(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"observables": [{"value": "test_value"}]} + mock_response.bytes = b'{"observables": [{"value": "test_value"}]}' mock_post.return_value = mock_response result = self.api.search_observables(value="test_value") @@ -62,7 +62,8 @@ def test_search_observables(self, mock_post): @patch("yeti.api.requests.Session.post") def test_new_entity(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "new_entity"} + mock_response.bytes = b'{"id": "new_entity"}' + mock_response.bytes = b'{"id": "new_entity"}' mock_post.return_value = mock_response result = self.api.new_entity({"name": "test_entity"}) @@ -75,7 +76,7 @@ def test_new_entity(self, mock_post): @patch("yeti.api.requests.Session.post") def test_new_indicator(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "new_indicator"} + mock_response.bytes = b'{"id": "new_indicator"}' mock_post.return_value = mock_response result = self.api.new_indicator({"name": "test_indicator"}) @@ -88,7 +89,7 @@ def test_new_indicator(self, mock_post): @patch("yeti.api.requests.Session.patch") def test_patch_indicator(self, mock_patch): mock_response = MagicMock() - mock_response.json.return_value = {"id": "patched_indicator"} + mock_response.bytes = b'{"id": "patched_indicator"}' mock_patch.return_value = mock_response result = self.api.patch_indicator(1, {"name": "patched_indicator"}) @@ -101,7 +102,7 @@ def test_patch_indicator(self, mock_patch): @patch("yeti.api.requests.Session.post") def test_search_dfiq(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"dfiq": [{"name": "test_dfiq"}]} + mock_response.bytes = b'{"dfiq": [{"name": "test_dfiq"}]}' mock_post.return_value = mock_response result = self.api.search_dfiq(name="test_dfiq") @@ -114,7 +115,7 @@ def test_search_dfiq(self, mock_post): @patch("yeti.api.requests.Session.post") def test_new_dfiq_from_yaml(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "new_dfiq"} + mock_response.bytes = b'{"id": "new_dfiq"}' mock_post.return_value = mock_response result = self.api.new_dfiq_from_yaml("type", "yaml_content") @@ -131,7 +132,7 @@ def test_new_dfiq_from_yaml(self, mock_post): @patch("yeti.api.requests.Session.patch") def test_patch_dfiq_from_yaml(self, mock_patch): mock_response = MagicMock() - mock_response.json.return_value = {"id": "patched_dfiq"} + mock_response.bytes = b'{"id": "patched_dfiq"}' mock_patch.return_value = mock_response result = self.api.patch_dfiq_from_yaml("type", "yaml_content", 1) @@ -161,7 +162,7 @@ def test_download_dfiq_archive(self, mock_post): @patch("yeti.api.requests.Session.post") def test_upload_dfiq_archive(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"uploaded": 1} + mock_response.bytes = b'{"uploaded": 1}' mock_post.return_value = mock_response with patch("builtins.open", unittest.mock.mock_open(read_data=b"data")): @@ -171,14 +172,14 @@ def test_upload_dfiq_archive(self, mock_post): mock_post.call_args[0][0], "http://fake-url/api/v2/dfiq/from_archive" ) self.assertRegex( - mock_post.call_args[1]["extra_headers"]["Content-Type"], + mock_post.call_args[1]["headers"]["Content-Type"], "multipart/form-data; boundary=[a-f0-9]{32}", ) @patch("yeti.api.requests.Session.post") def test_add_observable(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "new_observable"} + mock_response.bytes = b'{"id": "new_observable"}' mock_post.return_value = mock_response result = self.api.add_observable("value", "type") @@ -191,7 +192,7 @@ def test_add_observable(self, mock_post): @patch("yeti.api.requests.Session.post") def test_add_observables_bulk(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"added": [], "failed": []} + mock_response.bytes = b'{"added": [], "failed": []}' mock_post.return_value = mock_response result = self.api.add_observables_bulk([{"value": "value", "type": "type"}]) @@ -204,7 +205,7 @@ def test_add_observables_bulk(self, mock_post): @patch("yeti.api.requests.Session.post") def test_tag_object(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "tagged_object"} + mock_response.bytes = b'{"id": "tagged_object"}' mock_post.return_value = mock_response result = self.api.tag_object({"id": "1", "root_type": "indicator"}, ["tag1"]) @@ -217,7 +218,7 @@ def test_tag_object(self, mock_post): @patch("yeti.api.requests.Session.post") def test_link_objects(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"id": "link"} + mock_response.bytes = b'{"id": "link"}' mock_post.return_value = mock_response result = self.api.link_objects( @@ -239,7 +240,7 @@ def test_link_objects(self, mock_post): @patch("yeti.api.requests.Session.post") def test_search_graph(self, mock_post): mock_response = MagicMock() - mock_response.json.return_value = {"graph": "data"} + mock_response.bytes = b'{"graph": "data"}' mock_post.return_value = mock_response result = self.api.search_graph("source", "graph", ["type"]) diff --git a/yeti/api.py b/yeti/api.py index 33513eb..b2f1a72 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -39,15 +39,64 @@ def __init__(self, url_root: str): } self._url_root = url_root + def do_request( + self, + method: str, + url: str, + json_data: dict[str, Any] | None = None, + body: bytes | None = None, + headers: dict[str, Any] | None = None, + ) -> bytes: + """Issues a request to the given URL. + + Args: + method: The HTTP method to use. + url: The URL to issue the request to. + json: The JSON payload to include in the request. + body: The body to include in the request. + headers: Extra headers to include in the request. + + Returns: + The response from the API; a bytes object. + + """ + + if json_data and body: + raise ValueError("You must provide either json or body, not both.") + + request_kwargs = {} + + if headers: + request_kwargs["headers"] = headers + if json_data: + request_kwargs["json"] = json_data + if body: + request_kwargs["body"] = body + + if method == "POST": + response = self.client.post(url, **request_kwargs) + elif method == "PATCH": + response = self.client.patch(url, **request_kwargs) + elif method == "GET": + response = self.client.get(url, **request_kwargs) + else: + raise ValueError(f"Unsupported method: {method}") + return response.bytes + def auth_api_key(self, apikey: str) -> None: """Authenticates a session using an API key.""" # Use long-term refresh API token to get an access token - response = self.client.post( + response = self.do_request( + "POST", f"{self._url_root}{API_TOKEN_ENDPOINT}", headers={"x-yeti-apikey": apikey}, ) - access_token = json.loads(response.text).get("access_token") + access_token = json.loads(response).get("access_token") + if not access_token: + raise RuntimeError( + f"Failed to find access token in the response: {response}" + ) authd_session = requests.Session() authd_session.headers.update({"authorization": f"Bearer {access_token}"}) self.client = authd_session @@ -88,19 +137,21 @@ def search_indicators( if tags: query["tags"] = tags params = {"query": query, "count": 0} - response = self.client.post( + response = self.do_request( + "POST", f"{self._url_root}/api/v2/indicators/search", - json=params, + json_data=params, ) - return response.json()["indicators"] + return json.loads(response)["indicators"] def search_entities(self, name: str) -> list[YetiObject]: params = {"query": {"name": name}, "count": 0} - response = self.client.post( + response = self.do_request( + "POST", f"{self._url_root}/api/v2/entities/search", - json=params, + json_data=params, ) - return response.json()["entities"] + return json.loads(response)["entities"] def search_observables(self, value: str) -> list[YetiObject]: """Searches for an observable in Yeti. @@ -112,10 +163,10 @@ def search_observables(self, value: str) -> list[YetiObject]: The response from the API; a dict representing the observable. """ params = {"query": {"value": value}, "count": 0} - response = self.client.post( - f"{self._url_root}/api/v2/observables/search", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/observables/search", json_data=params ) - return response.json()["observables"] + return json.loads(response)["observables"] def new_entity( self, entity: dict[str, Any], tags: list[str] | None = None @@ -132,8 +183,12 @@ def new_entity( params = {"entity": entity} if tags: params["tags"] = tags - response = self.client.post(f"{self._url_root}/api/v2/entities/", json=params) - return response.json() + response = self.do_request( + "POST", + f"{self._url_root}/api/v2/entities/", + json_data=params, + ) + return json.loads(response) def new_indicator( self, @@ -150,12 +205,16 @@ def new_indicator( The response from the API; a dict representing the indicator. """ params = {"indicator": indicator} - response = self.client.post(f"{self._url_root}/api/v2/indicators/", json=params) - indicator = response.json() + response = self.do_request( + "POST", f"{self._url_root}/api/v2/indicators/", json_data=params + ) + indicator = json.loads(response) if tags: params = {"tags": tags, "ids": [indicator["id"]]} - self.client.post(f"{self._url_root}/api/v2/indicators/tag", json=params) + self.do_request( + "POST", f"{self._url_root}/api/v2/indicators/tag", json_data=params + ) return indicator @@ -166,10 +225,10 @@ def patch_indicator( ) -> YetiObject: """Patches an indicator in Yeti.""" params = {"indicator": indicator_object} - response = self.client.patch( - f"{self._url_root}/api/v2/indicators/{yeti_id}", json=params + response = self.do_request( + "PATCH", f"{self._url_root}/api/v2/indicators/{yeti_id}", json_data=params ) - return response.json() + return json.loads(response) def search_dfiq(self, name: str, dfiq_type: str | None = None) -> list[YetiObject]: """Searches for a DFIQ in Yeti. @@ -186,8 +245,10 @@ def search_dfiq(self, name: str, dfiq_type: str | None = None) -> list[YetiObjec if dfiq_type: query["type"] = dfiq_type params = {"query": query, "count": 0} - response = self.client.post(f"{self._url_root}/api/v2/dfiq/search", json=params) - return response.json()["dfiq"] + response = self.do_request( + "POST", f"{self._url_root}/api/v2/dfiq/search", json_data=params + ) + return json.loads(response)["dfiq"] def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject: """Creates a new DFIQ object in Yeti from a YAML string.""" @@ -196,10 +257,10 @@ def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject: "dfiq_yaml": dfiq_yaml, "update_indicators": True, } - response = self.client.post( - f"{self._url_root}/api/v2/dfiq/from_yaml", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/dfiq/from_yaml", json_data=params ) - return response.json() + return json.loads(response) def patch_dfiq_from_yaml( self, @@ -213,10 +274,10 @@ def patch_dfiq_from_yaml( "dfiq_yaml": dfiq_yaml, "update_indicators": True, } - response = self.client.patch( - f"{self._url_root}/api/v2/dfiq/{yeti_id}", json=params + response = self.do_request( + "PATCH", f"{self._url_root}/api/v2/dfiq/{yeti_id}", json_data=params ) - return response.json() + return json.loads(response) def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes: """Downloads an archive containing all DFIQ data from Yeti. @@ -231,10 +292,10 @@ def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes: params = {"count": 0} if dfiq_type: params["query"] = {"type": dfiq_type} - response = self.client.post( - f"{self._url_root}/api/v2/dfiq/to_archive", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/dfiq/to_archive", json_data=params ) - return response.bytes + return response def upload_dfiq_archive(self, archive_path: str) -> dict[str, int]: """Uploads a DFIQ archive to Yeti. @@ -253,12 +314,13 @@ def upload_dfiq_archive(self, archive_path: str) -> dict[str, int]: fields={"archive": ("archive.zip", data, "application/zip")} ) headers = {"Content-Type": encoded_data.content_type} - response = self.client.post( + response = self.do_request( + "POST", f"{self._url_root}/api/v2/dfiq/from_archive", - extra_headers=headers, + headers=headers, body=encoded_data.to_string(), ) - return response.json() + return json.loads(response) def add_observable( self, value: str, observable_type: str, tags: list[str] | None = None @@ -274,10 +336,10 @@ def add_observable( The response from the API; a dict representing the observable. """ params = {"value": value, "type": observable_type, "tags": tags} - response = self.client.post( - f"{self._url_root}/api/v2/observables/", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/observables/", json_data=params ) - return response.json() + return json.loads(response) def add_observables_bulk( self, observables: list[dict[str, Any]], tags: list[str] | None = None @@ -306,10 +368,10 @@ def add_observables_bulk( "observables": observables, } - response = self.client.post( - f"{self._url_root}/api/v2/observables/bulk", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/observables/bulk", json_data=params ) - return response.json() + return json.loads(response) def tag_object( self, yeti_object: dict[str, Any], tags: Sequence[str] @@ -317,8 +379,10 @@ def tag_object( """Tags an object in Yeti.""" params = {"tags": list(tags), "ids": [yeti_object["id"]]} endpoint = TYPE_TO_ENDPOINT[yeti_object["root_type"]] - result = self.client.post(f"{self._url_root}{endpoint}/tag", json=params) - return result.json() + response = self.do_request( + "POST", f"{self._url_root}{endpoint}/tag", json_data=params + ) + return json.loads(response) def link_objects( self, @@ -347,8 +411,10 @@ def link_objects( "link_type": link_type, "description": description, } - response = self.client.post(f"{self._url_root}/api/v2/graph/add", json=params) - return response.json() + response = self.do_request( + "POST", f"{self._url_root}/api/v2/graph/add", json_data=params + ) + return json.loads(response) def search_graph( self, @@ -389,7 +455,7 @@ def search_graph( "include_original": include_original, "target_types": target_types, } - response = self.client.post( - f"{self._url_root}/api/v2/graph/search", json=params + response = self.do_request( + "POST", f"{self._url_root}/api/v2/graph/search", json_data=params ) - return response.json() + return json.loads(response)