From ebc2f0ebd449db5971f1826bfdeb6629d485b4ac Mon Sep 17 00:00:00 2001 From: Thomas Chopitea Date: Wed, 5 Feb 2025 00:55:32 +0000 Subject: [PATCH 1/2] Add some error handling --- tests/api.py | 20 ++++++++++++++++++++ yeti/api.py | 31 ++++++++++++++++++++----------- yeti/errors.py | 9 +++++++++ 3 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 yeti/errors.py diff --git a/tests/api.py b/tests/api.py index 5b4ea7b..3bc8e4b 100644 --- a/tests/api.py +++ b/tests/api.py @@ -1,6 +1,9 @@ import unittest from unittest.mock import patch, MagicMock from yeti.api import YetiApi +from yeti import errors + +import requests class TestYetiApi(unittest.TestCase): @@ -259,6 +262,23 @@ def test_search_graph(self, mock_post): }, ) + @patch("yeti.api.requests.Session.post") + def test_error_message(self, mock_post): + # create mock requests response that raises an requests.exceptions.HTTPError for status + mock_response = MagicMock() + mock_exception_with_status_code = requests.exceptions.HTTPError() + mock_exception_with_status_code.response = MagicMock() + mock_exception_with_status_code.response.status_code = 400 + mock_exception_with_status_code.response.text = "error_message" + mock_response.raise_for_status.side_effect = mock_exception_with_status_code + mock_post.return_value = mock_response + + with self.assertRaises(errors.YetiApiError) as raised: + self.api.new_indicator({"name": "test_indicator"}) + + self.assertEqual(str(raised.exception), "error_message") + self.assertEqual(raised.exception.status_code, 400) + if __name__ == "__main__": unittest.main() diff --git a/yeti/api.py b/yeti/api.py index b2f1a72..d365625 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -1,11 +1,11 @@ """Python client for the Yeti API.""" -import requests -import requests_toolbelt.multipart.encoder as encoder - import json from typing import Any, Sequence +import yeti.errors as errors +import requests +import requests_toolbelt.multipart.encoder as encoder TYPE_TO_ENDPOINT = { "indicator": "/api/v2/indicators", @@ -73,14 +73,23 @@ def do_request( if body: request_kwargs["body"] = body - if method == "POST": - response = self.client.post(url, **request_kwargs) - elif method == "PATCH": - response = self.client.patch(url, **request_kwargs) - elif method == "GET": - response = self.client.get(url, **request_kwargs) - else: - raise ValueError(f"Unsupported method: {method}") + try: + if method == "POST": + response = self.client.post(url, **request_kwargs) + elif method == "PATCH": + response = self.client.patch(url, **request_kwargs) + elif method == "GET": + response = self.client.get(url, **request_kwargs) + else: + raise ValueError(f"Unsupported method: {method}") + except requests.exceptions.HTTPError as e: + raise errors.YetiApiError(e.response.status_code, e.response.text) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + raise errors.YetiApiError(e.response.status_code, e.response.text) + return response.bytes def auth_api_key(self, apikey: str) -> None: diff --git a/yeti/errors.py b/yeti/errors.py new file mode 100644 index 0000000..e09c86a --- /dev/null +++ b/yeti/errors.py @@ -0,0 +1,9 @@ +class YetiApiError(RuntimeError): + """Base class for errors in the Yeti API.""" + + status_code: int + message: str + + def __init__(self, status_code: int, message: str): + super().__init__(message) + self.status_code = status_code From b62cbefb3b7cf0137abf4b368cdad176fde7ab40 Mon Sep 17 00:00:00 2001 From: Thomas Chopitea Date: Wed, 5 Feb 2025 00:56:28 +0000 Subject: [PATCH 2/2] Spurious exception --- yeti/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/yeti/api.py b/yeti/api.py index d365625..0cc64b4 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -82,10 +82,6 @@ def do_request( response = self.client.get(url, **request_kwargs) else: raise ValueError(f"Unsupported method: {method}") - except requests.exceptions.HTTPError as e: - raise errors.YetiApiError(e.response.status_code, e.response.text) - - try: response.raise_for_status() except requests.exceptions.HTTPError as e: raise errors.YetiApiError(e.response.status_code, e.response.text)