diff --git a/tests/api.py b/tests/api.py index 5b4ea7b..3bc8e4b 100644 --- a/tests/api.py +++ b/tests/api.py @@ -1,6 +1,9 @@ import unittest from unittest.mock import patch, MagicMock from yeti.api import YetiApi +from yeti import errors + +import requests class TestYetiApi(unittest.TestCase): @@ -259,6 +262,23 @@ def test_search_graph(self, mock_post): }, ) + @patch("yeti.api.requests.Session.post") + def test_error_message(self, mock_post): + # create mock requests response that raises an requests.exceptions.HTTPError for status + mock_response = MagicMock() + mock_exception_with_status_code = requests.exceptions.HTTPError() + mock_exception_with_status_code.response = MagicMock() + mock_exception_with_status_code.response.status_code = 400 + mock_exception_with_status_code.response.text = "error_message" + mock_response.raise_for_status.side_effect = mock_exception_with_status_code + mock_post.return_value = mock_response + + with self.assertRaises(errors.YetiApiError) as raised: + self.api.new_indicator({"name": "test_indicator"}) + + self.assertEqual(str(raised.exception), "error_message") + self.assertEqual(raised.exception.status_code, 400) + if __name__ == "__main__": unittest.main() diff --git a/yeti/api.py b/yeti/api.py index b2f1a72..0cc64b4 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -1,11 +1,11 @@ """Python client for the Yeti API.""" -import requests -import requests_toolbelt.multipart.encoder as encoder - import json from typing import Any, Sequence +import yeti.errors as errors +import requests +import requests_toolbelt.multipart.encoder as encoder TYPE_TO_ENDPOINT = { "indicator": "/api/v2/indicators", @@ -73,14 +73,19 @@ def do_request( if body: request_kwargs["body"] = body - if method == "POST": - response = self.client.post(url, **request_kwargs) - elif method == "PATCH": - response = self.client.patch(url, **request_kwargs) - elif method == "GET": - response = self.client.get(url, **request_kwargs) - else: - raise ValueError(f"Unsupported method: {method}") + try: + if method == "POST": + response = self.client.post(url, **request_kwargs) + elif method == "PATCH": + response = self.client.patch(url, **request_kwargs) + elif method == "GET": + response = self.client.get(url, **request_kwargs) + else: + raise ValueError(f"Unsupported method: {method}") + response.raise_for_status() + except requests.exceptions.HTTPError as e: + raise errors.YetiApiError(e.response.status_code, e.response.text) + return response.bytes def auth_api_key(self, apikey: str) -> None: diff --git a/yeti/errors.py b/yeti/errors.py new file mode 100644 index 0000000..e09c86a --- /dev/null +++ b/yeti/errors.py @@ -0,0 +1,9 @@ +class YetiApiError(RuntimeError): + """Base class for errors in the Yeti API.""" + + status_code: int + message: str + + def __init__(self, status_code: int, message: str): + super().__init__(message) + self.status_code = status_code