diff --git a/.editorconfig b/.editorconfig index 71ec778..4e50a0b 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,16 +10,16 @@ indent_style = space insert_final_newline = true trim_trailing_whitespace = true -[*.md] -indent_size = 2 -indent_style = space - [LICENSE.txt] insert_final_newline = false [*.{diff,patch}] trim_trailing_whitespace = false -[*.{json,yaml,yml}] +[*.{json,md,yaml,yml}] +indent_size = 2 +indent_style = space + +[.{prettierrc,yamllint}] indent_size = 2 indent_style = space diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ee4740..524de52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - --allow-missing-credentials - id: detect-private-key - id: end-of-file-fixer - exclude: '.bumpversion.cfg' + exclude: ".bumpversion.cfg" - id: mixed-line-ending - id: name-tests-test args: diff --git a/cert_manager/__init__.py b/cert_manager/__init__.py index 38d3037..e2b7820 100644 --- a/cert_manager/__init__.py +++ b/cert_manager/__init__.py @@ -3,7 +3,8 @@ from ._helpers import PendingError from .acme import ACMEAccount from .admin import Admin -from .client import Client +from .client import Client, OAuth2Client +from .dcv import DomainControlValidation from .domain import Domain from .organization import Organization from .person import Person @@ -12,5 +13,16 @@ from .ssl import SSL __all__ = [ - "ACMEAccount", "Admin", "Client", "Domain", "Organization", "PendingError", "Person", "Report", "SMIME", "SSL" + "ACMEAccount", + "Admin", + "Client", + "Domain", + "DomainControlValidation", + "OAuth2Client", + "Organization", + "PendingError", + "Person", + "Report", + "SMIME", + "SSL", ] diff --git a/cert_manager/_certificates.py b/cert_manager/_certificates.py index e8b3d17..134e039 100644 --- a/cert_manager/_certificates.py +++ b/cert_manager/_certificates.py @@ -39,9 +39,9 @@ def __init__(self, client, endpoint, api_version="v1"): super().__init__(client=client, endpoint=endpoint, api_version=api_version) # Set to None initially. Will be filled in by methods later. - self.__cert_types = None - self.__custom_fields = None - self.__reason_maxlen = 512 + self._cert_types = None + self._custom_fields = None + self._reason_maxlen = 512 @property def types(self): @@ -51,19 +51,19 @@ def types(self): """ # Only go to the API if we haven't done the API call yet, or if someone # specifically wants to refresh the internal cache - if not self.__cert_types: + if not self._cert_types: url = self._url("/types") result = self._client.get(url) # Build a dictionary instead of a flat list of dictionaries - self.__cert_types = {} + self._cert_types = {} for res in result.json(): name = res["name"] - self.__cert_types[name] = {} - self.__cert_types[name]["id"] = res["id"] - self.__cert_types[name]["terms"] = res["terms"] + self._cert_types[name] = {} + self._cert_types[name]["id"] = res["id"] + self._cert_types[name]["terms"] = res["terms"] - return self.__cert_types + return self._cert_types @property def custom_fields(self): @@ -73,13 +73,13 @@ def custom_fields(self): """ # Only go to the API if we haven't done the API call yet, or if someone # specifically wants to refresh the internal cache - if not self.__custom_fields: + if not self._custom_fields: url = self._url("/customFields") result = self._client.get(url) - self.__custom_fields = result.json() + self._custom_fields = result.json() - return self.__custom_fields + return self._custom_fields def _validate_custom_fields(self, custom_fields): """Check the structure and contents of a list of custom fields dicts. Raise exceptions if validation fails. @@ -230,8 +230,8 @@ def revoke(self, cert_id, reason=""): url = self._url(f"/revoke/{cert_id}") # Sectigo has a 512 character limit on the "reason" message, so catch that here. - if not reason or len(reason) >= self.__reason_maxlen: - raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self.__reason_maxlen} characters") + if not reason or len(reason) >= self._reason_maxlen: + raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self._reason_maxlen} characters") data = {"reason": reason} diff --git a/cert_manager/acme.py b/cert_manager/acme.py index bfadb8c..45ad262 100644 --- a/cert_manager/acme.py +++ b/cert_manager/acme.py @@ -35,7 +35,7 @@ def __init__(self, client, api_version="v1"): """ super().__init__(client=client, endpoint="/acme", api_version=api_version) self._api_url = self._url("/account") - self.__acme_accounts = None + self._acme_accounts = None def all(self, org_id, force=False): """Return a list of acme accounts from Sectigo. @@ -45,15 +45,15 @@ def all(self, org_id, force=False): :return list: A list of dictionaries representing the acme accounts """ - if (self.__acme_accounts) and (not force): - return self.__acme_accounts + if (self._acme_accounts) and (not force): + return self._acme_accounts - self.__acme_accounts = [] + self._acme_accounts = [] result = self.find(org_id) for acct in result: - self.__acme_accounts.append(acct) + self._acme_accounts.append(acct) - return self.__acme_accounts + return self._acme_accounts @paginate def find(self, org_id, **kwargs): diff --git a/cert_manager/admin.py b/cert_manager/admin.py index 4ef2a24..0ccf136 100644 --- a/cert_manager/admin.py +++ b/cert_manager/admin.py @@ -25,7 +25,7 @@ def __init__(self, client, api_version="v1"): """ super().__init__(client=client, endpoint="/admin", api_version=api_version) - self.__admins = None + self._admins = None self.all() def all(self, force=False): @@ -35,14 +35,14 @@ def all(self, force=False): :return list: A list of dictionaries representing the admins """ - if (self.__admins) and (not force): - return self.__admins + if (self._admins) and (not force): + return self._admins result = self._client.get(self._api_url) - self.__admins = result.json() + self._admins = result.json() - return self.__admins + return self._admins def create(self, login, email, forename, surname, password, credentials, **kwargs): # noqa: PLR0913 """Create a new administrator. diff --git a/cert_manager/client.py b/cert_manager/client.py index bf4634d..8a68786 100644 --- a/cert_manager/client.py +++ b/cert_manager/client.py @@ -37,44 +37,44 @@ def __init__(self, **kwargs): :param string user_key_file: The path to the key file if using client cert auth """ # These options are required, so raise a KeyError if they are not provided. - self.__login_uri = kwargs["login_uri"] - self.__username = kwargs["username"] + self._login_uri = kwargs["login_uri"] + self._username = kwargs["username"] # Using get for consistency and to allow defaults to be easily set - self.__base_url = kwargs.get("base_url", "https://cert-manager.com/api") - self.__cert_auth = kwargs.get("cert_auth", False) - self.__session = requests.Session() + self._base_url = kwargs.get("base_url", "https://cert-manager.com/api") + self._cert_auth = kwargs.get("cert_auth", False) + self._session = requests.Session() - self.__user_crt_file = kwargs.get("user_crt_file") - self.__user_key_file = kwargs.get("user_key_file") + self._user_crt_file = kwargs.get("user_crt_file") + self._user_key_file = kwargs.get("user_key_file") # Set the default HTTP headers - self.__headers = { - "login": self.__username, - "customerUri": self.__login_uri, + self._headers = { + "login": self._username, + "customerUri": self._login_uri, "Accept": "application/json", "User-Agent": self.user_agent, } # Setup the Session for certificate auth - if self.__cert_auth: + if self._cert_auth: # Require keys if cert_auth is True or raise a KeyError - self.__user_crt_file = kwargs["user_crt_file"] - self.__user_key_file = kwargs["user_key_file"] - self.__session.cert = (self.__user_crt_file, self.__user_key_file) + self._user_crt_file = kwargs["user_crt_file"] + self._user_key_file = kwargs["user_key_file"] + self._session.cert = (self._user_crt_file, self._user_key_file) # Warn about using /api instead of /private/api if doing certificate auth - if not re.search("/private", self.__base_url): - cert_uri = re.sub("/api", "/private/api", self.__base_url) + if not re.search("/private", self._base_url): + cert_uri = re.sub("/api", "/private/api", self._base_url) LOGGER.warning("base URI should probably be %s due to certificate auth", cert_uri) else: # If we're not doing certificate auth, we need a password, so make sure an exception is raised if # a password was not passed as an argument - self.__password = kwargs["password"] - self.__headers["password"] = self.__password + self._password = kwargs["password"] + self._headers["password"] = self._password - self.__session.headers.update(self.__headers) + self._session.headers.update(self._headers) @property def user_agent(self): @@ -87,18 +87,18 @@ def user_agent(self): @property def base_url(self): - """Return the internal __base_url value.""" - return self.__base_url + """Return the internal _base_url value.""" + return self._base_url @property def headers(self): - """Return the internal __headers value.""" - return self.__headers + """Return the internal _headers value.""" + return self._headers @property def session(self): - """Return the setup internal __session requests.Session object.""" - return self.__session + """Return the setup internal _session requests.Session object.""" + return self._session def add_headers(self, headers=None): """Add the provided headers to the internally stored headers. @@ -109,10 +109,10 @@ def add_headers(self, headers=None): :param dict headers: A dictionary where key is the header with its value being the setting for that header. """ if headers: - head = self.__headers.copy() + head = self._headers.copy() head.update(headers) - self.__headers = head - self.__session.headers.update(self.__headers) + self._headers = head + self._session.headers.update(self._headers) def remove_headers(self, headers=None): """Remove the requested header keys from the internally stored headers. @@ -124,9 +124,9 @@ def remove_headers(self, headers=None): """ if headers: for head in headers: - if head in self.__headers: - del self.__headers[head] - del self.__session.headers[head] + if head in self._headers: + del self._headers[head] + del self._session.headers[head] @traffic_log(traffic_logger=LOGGER) def head(self, url, headers=None, params=None): @@ -137,7 +137,7 @@ def head(self, url, headers=None, params=None): :param dict params: A dictionary with any parameters to add to the request URL :return obj: A requests.Response object received as a response """ - result = self.__session.head(url, headers=headers, params=params) + result = self._session.head(url, headers=headers, params=params) # Raise an exception if the return code is in an error range result.raise_for_status() @@ -152,7 +152,7 @@ def get(self, url, headers=None, params=None): :param dict params: A dictionary with any parameters to add to the request URL :return obj: A requests.Response object received as a response """ - result = self.__session.get(url, headers=headers, params=params) + result = self._session.get(url, headers=headers, params=params) # Raise an exception if the return code is in an error range result.raise_for_status() @@ -167,7 +167,7 @@ def post(self, url, headers=None, data=None): :param dict data: A dictionary with the data to use for the body of the POST :return obj: A requests.Response object received as a response """ - result = self.__session.post(url, json=data, headers=headers) + result = self._session.post(url, json=data, headers=headers) # Raise an exception if the return code is in an error range result.raise_for_status() @@ -182,7 +182,7 @@ def put(self, url, headers=None, data=None): :param dict data: A dictionary with the data to use for the body of the PUT :return obj: A requests.Response object received as a response """ - result = self.__session.put(url, json=data, headers=headers) + result = self._session.put(url, json=data, headers=headers) # Raise an exception if the return code is in an error range result.raise_for_status() @@ -197,8 +197,46 @@ def delete(self, url, headers=None, data=None): :param dict data: A dictionary with the data to use for the body of the DELETE :return obj: A requests.Response object received as a response """ - result = self.__session.delete(url, json=data, headers=headers) + result = self._session.delete(url, json=data, headers=headers) # Raise an exception if the return code is in an error range result.raise_for_status() return result + + +class OAuth2Client(Client): + """Serve as a Base class for calls to the Sectigo Cert Manager APIs using OAuth2.""" + + def __init__(self, client_id, client_secret, auth_url="https://auth.sso.sectigo.com/auth/realms/apiclients/protocol/openid-connect/token", + base_url="https://admin.enterprise.sectigo.com/api"): + """Initialize the class. + + :param string auth_url: The full URL to the Sectigo OAuth2 token endpoint; the default is "https://auth.sso.sectigo.com/auth/realms/apiclients/protocol/openid-connect/token" + :param string client_id: The Client ID to use for OAuth2 authentication + :param string client_secret: The Client Secret to use for OAuth2 authentication + :param string base_url: The base URL for the Sectigo API; the default is "https://admin.enterprise.sectigo.com/api" + """ + self._base_url = base_url + # Using get for consistency and to allow defaults to be easily set + self._session = requests.Session() + + payload = { + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "client_credentials" + } + headers = { + "accept": "application/json", + "content-type": "application/x-www-form-urlencoded" + } + + response = requests.post(auth_url, data=payload, headers=headers) + + # Set the default HTTP headers + self._headers = { + "Authorization": f"Bearer {response.json()['access_token']}", + "Accept": "application/json", + "User-Agent": self.user_agent, + } + + self._session.headers.update(self._headers) diff --git a/cert_manager/domain.py b/cert_manager/domain.py index 512d3fa..e393530 100644 --- a/cert_manager/domain.py +++ b/cert_manager/domain.py @@ -25,7 +25,7 @@ def __init__(self, client, api_version="v1"): """ super().__init__(client=client, endpoint="/domain", api_version=api_version) - self.__domains = None + self._domains = None def all(self, force=False): """Return a list of domains from Sectigo. @@ -34,14 +34,14 @@ def all(self, force=False): :return list: A list of dictionaries representing the domains """ - if (self.__domains) and (not force): - return self.__domains + if (self._domains) and (not force): + return self._domains result = self._client.get(self._api_url) - self.__domains = result.json() + self._domains = result.json() - return self.__domains + return self._domains def find(self, **kwargs): """Return a list of domains matching the given parameters from Sectigo. diff --git a/cert_manager/organization.py b/cert_manager/organization.py index 5ccc7b0..20356e0 100644 --- a/cert_manager/organization.py +++ b/cert_manager/organization.py @@ -20,7 +20,7 @@ def __init__(self, client, api_version="v1"): """ super().__init__(client=client, endpoint="/organization", api_version=api_version) - self.__orgs = None + self._orgs = None self.all() def all(self, force=False): @@ -30,14 +30,14 @@ def all(self, force=False): :return list: A list of dictionaries representing the organizations """ - if (self.__orgs) and (not force): - return self.__orgs + if (self._orgs) and (not force): + return self._orgs result = self._client.get(self._api_url) - self.__orgs = result.json() + self._orgs = result.json() - return self.__orgs + return self._orgs def find(self, org_name=None, dept_name=None): """Return a dictionary of organization information. diff --git a/cert_manager/smime.py b/cert_manager/smime.py index 0aab7dd..a35e469 100644 --- a/cert_manager/smime.py +++ b/cert_manager/smime.py @@ -19,7 +19,7 @@ def __init__(self, client, api_version="v1"): :param string api_version: The API version to use; the default is "v1" """ super().__init__(client=client, endpoint="/smime", api_version=api_version) - self.__reason_maxlen = 512 + self._reason_maxlen = 512 @paginate def list(self, **kwargs): @@ -195,8 +195,8 @@ def revoke(self, cert_id, reason=""): raise ValueError("Argument 'cert_id' can't be None") # Sectigo has a 512 character limit on the "reason" message, so catch that here. - if not reason or len(reason) >= self.__reason_maxlen: - raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self.__reason_maxlen} characters") + if not reason or len(reason) >= self._reason_maxlen: + raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self._reason_maxlen} characters") data = {"reason": reason} self._client.post(url, data=data) @@ -214,8 +214,8 @@ def revoke_by_email(self, email, reason=""): raise ValueError("Argument 'email' can't be empty or None") # Sectigo has a 512 character limit on the "reason" message, so catch that here. - if not reason or len(reason) >= self.__reason_maxlen: - raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self.__reason_maxlen} characters") + if not reason or len(reason) >= self._reason_maxlen: + raise ValueError(f"Sectigo limit: reason must be > 0 character and < {self._reason_maxlen} characters") data = {"email": email, "reason": reason} self._client.post(url, data=data) diff --git a/docs/README.md b/docs/README.md index 1f386d8..1201221 100644 --- a/docs/README.md +++ b/docs/README.md @@ -36,16 +36,63 @@ Other endpoints we hope to add in the near future: You can use pip to install cert_manager: -```Shell +```shell pip install cert_manager ``` +## Authentication + +Originally, Certificate Manager only allowed username and password, or client +certificate and key as methods of authenticating to the REST API. However, +OAuth2 is now supported via a completely different URL structure. This new model +can be used by swapping out the `Client` class with `OAuth2Client`. You need to +provide a Client ID and Client Secret to `OAuth2Client`, which you create via +the UI in [Sectigo][2]. Information on how to create the Client ID and Client +Secret can be found on +[How Do You Implement OAuth 2.0 for SCM](https://www.sectigo.com/knowledge-base/detail/implement-oauth-2-0-for-scm) +page. + +The following is an example of creating the `OAuth2Client` object: + +```python +from cert_manager import OAuth2Client +client = OAuth2Client( + client_id="client-id-from-sectigo", client_secret="client-secret-from-sectigo", +) +``` + +This client should then be usable by all of the existing classes in the library +in place of the original. However, pay particular attention to the version of +the API you are using for each class. Typically the version is `v1`, but there +isn't a consistent version across all endpoints. You may need to add +`api_version` to many of the object instantiations for the API to work +correctly. A complete API reference can be found on the +[SCM DevX](https://scm.devx.sectigo.com/reference/) site. This, for example, is +the first code snippet in [Examples](#examples) rewritten for the new API +infrastructure: + +```python +from cert_manager import Organization +from cert_manager import OAuth2Client +from cert_manager import SSL + +client = OAuth2Client( + client_id="client-id-from-sectigo", client_secret="client-secret-from-sectigo", +) + +org = Organization(client=client) +ssl = SSL(client=client, api_version="v2") + +print(ssl.types) +print(org.all()) +``` + ## Examples This is a simple example that just shows initializing the `Client` object and using it to query the `Organization` and `SSL` endpoints: -```Python +```python from cert_manager import Organization from cert_manager import Client from cert_manager import SSL @@ -67,7 +114,7 @@ print(org.all()) The most common process you would do, however, is enroll and then collect a certificate you want to order from the Certificate Manager: -```Python +```python from time import sleep from cert_manager import Organization @@ -128,7 +175,7 @@ To start a development environment, you should be able to just run the to build a [Docker][4] container with all the dependencies for development installed using [Poetry][3]. -```Shell +```shell ./dev.bash ``` @@ -147,7 +194,7 @@ access token. We currently use for this purpose. The following should generate the file using information from GitHub: -```Shell +```shell docker run -it --rm \ -e CHANGELOG_GITHUB_TOKEN='yourtokenhere' \ -v "$(pwd)":/working \ @@ -158,7 +205,7 @@ docker run -it --rm \ To generate the log for an upcoming release that has not yet been tagged, you can run a command to include the upcoming release version. For example, `2.0.0`: -```Shell +```shell docker run -it --rm \ -e CHANGELOG_GITHUB_TOKEN='yourtokenhere' \ -v "$(pwd)":/working \ @@ -179,14 +226,14 @@ version bumps as part of a PR, so you don't want to have [bump2version][6] tag the version at the same time it does the commit as commit hashes may change. Therefore, to bump the version a patch level, one would run the command: -```Shell +```shell bump2version --verbose --no-tag patch ``` Once the PR is merged, you can then checkout the new `main` branch and tag it using the new version number that is now in `.bumpversion.cfg`: -```Shell +```shell git checkout main git pull --rebase git tag 1.0.0 -m 'Bump version: 0.1.0 → 1.0.0' diff --git a/tests/test_admin.py b/tests/test_admin.py index 94200a9..8856029 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -73,7 +73,7 @@ def test_defaults(self): self.assertEqual(len(responses.calls), 1) self.assertEqual(responses.calls[0].request.url, self.api_url) - self.assertEqual(admin._Admin__admins, self.valid_response) + self.assertEqual(admin._admins, self.valid_response) @responses.activate def test_param(self): @@ -91,7 +91,7 @@ def test_param(self): self.assertEqual(len(responses.calls), 1) self.assertEqual(responses.calls[0].request.url, api_url) - self.assertEqual(admin._Admin__admins, self.valid_response) + self.assertEqual(admin._admins, self.valid_response) def test_need_client(self): """Raise an exception if called without a client parameter.""" diff --git a/tests/test_client.py b/tests/test_client.py index 2833668..1910a7f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -57,20 +57,20 @@ def test_defaults(self): # Use the hackity object mangling when dealing with double-underscore values in an object # This hard-coded test is to test that the default base_url is used when none is provided - self.assertEqual(client._Client__base_url, "https://cert-manager.com/api") - self.assertEqual(client._Client__login_uri, self.cfixt.login_uri) - self.assertEqual(client._Client__username, self.cfixt.username) - self.assertEqual(client._Client__password, self.cfixt.password) - self.assertEqual(client._Client__cert_auth, False) + self.assertEqual(client._base_url, "https://cert-manager.com/api") + self.assertEqual(client._login_uri, self.cfixt.login_uri) + self.assertEqual(client._username, self.cfixt.username) + self.assertEqual(client._password, self.cfixt.password) + self.assertEqual(client._cert_auth, False) # Make sure all the headers make their way into the internal requests.Session object for head, headdata in self.cfixt.headers.items(): - self.assertTrue(head in client._Client__session.headers) - self.assertEqual(client._Client__session.headers[head], headdata) + self.assertTrue(head in client._session.headers) + self.assertEqual(client._session.headers[head], headdata) # Because password was used and cert_auth was False, a password header should exist - self.assertTrue("password" in client._Client__session.headers) - self.assertEqual(self.cfixt.password, client._Client__session.headers["password"]) + self.assertTrue("password" in client._session.headers) + self.assertEqual(self.cfixt.password, client._session.headers["password"]) def test_params(self): """Set parameters correctly inside the class using all parameters.""" @@ -81,21 +81,21 @@ def test_params(self): ) # Use the hackity object mangling when dealing with double-underscore values in an object - self.assertEqual(client._Client__base_url, self.cfixt.base_url) - self.assertEqual(client._Client__login_uri, self.cfixt.login_uri) - self.assertEqual(client._Client__username, self.cfixt.username) - self.assertEqual(client._Client__cert_auth, True) - self.assertEqual(client._Client__user_crt_file, self.cfixt.user_crt_file) - self.assertEqual(client._Client__user_key_file, self.cfixt.user_key_file) - self.assertEqual(client._Client__session.cert, (self.cfixt.user_crt_file, self.cfixt.user_key_file)) + self.assertEqual(client._base_url, self.cfixt.base_url) + self.assertEqual(client._login_uri, self.cfixt.login_uri) + self.assertEqual(client._username, self.cfixt.username) + self.assertEqual(client._cert_auth, True) + self.assertEqual(client._user_crt_file, self.cfixt.user_crt_file) + self.assertEqual(client._user_key_file, self.cfixt.user_key_file) + self.assertEqual(client._session.cert, (self.cfixt.user_crt_file, self.cfixt.user_key_file)) # Make sure all the headers make their way into the internal requests.Session object for head, headdata in self.cfixt.headers.items(): - self.assertTrue(head in client._Client__session.headers) - self.assertEqual(client._Client__session.headers[head], headdata) + self.assertTrue(head in client._session.headers) + self.assertEqual(client._session.headers[head], headdata) # If cert_auth is True, make sure a password header does not exist - self.assertFalse("password" in client._Client__session.headers) + self.assertFalse("password" in client._session.headers) def test_no_pass_with_certs(self): """Set parameters correctly inside the class certificate auth without a password.""" @@ -105,21 +105,21 @@ def test_no_pass_with_certs(self): ) # Use the hackity object mangling when dealing with double-underscore values in an object - self.assertEqual(client._Client__base_url, self.cfixt.base_url) - self.assertEqual(client._Client__login_uri, self.cfixt.login_uri) - self.assertEqual(client._Client__username, self.cfixt.username) - self.assertEqual(client._Client__cert_auth, True) - self.assertEqual(client._Client__user_crt_file, self.cfixt.user_crt_file) - self.assertEqual(client._Client__user_key_file, self.cfixt.user_key_file) - self.assertEqual(client._Client__session.cert, (self.cfixt.user_crt_file, self.cfixt.user_key_file)) + self.assertEqual(client._base_url, self.cfixt.base_url) + self.assertEqual(client._login_uri, self.cfixt.login_uri) + self.assertEqual(client._username, self.cfixt.username) + self.assertEqual(client._cert_auth, True) + self.assertEqual(client._user_crt_file, self.cfixt.user_crt_file) + self.assertEqual(client._user_key_file, self.cfixt.user_key_file) + self.assertEqual(client._session.cert, (self.cfixt.user_crt_file, self.cfixt.user_key_file)) # Make sure all the headers make their way into the internal requests.Session object for head, headdata in self.cfixt.headers.items(): - self.assertTrue(head in client._Client__session.headers) - self.assertEqual(client._Client__session.headers[head], headdata) + self.assertTrue(head in client._session.headers) + self.assertEqual(client._session.headers[head], headdata) # If cert_auth is True, make sure a password header does not exist - self.assertFalse("password" in client._Client__session.headers) + self.assertFalse("password" in client._session.headers) def test_versioning(self): """Change the user-agent header if the version number changes.""" @@ -135,7 +135,7 @@ def test_versioning(self): # Make sure the user-agent header is correct in the class and the internal requests.Session object self.assertEqual(client.headers["User-Agent"], user_agent) - self.assertEqual(client._Client__session.headers["User-Agent"], user_agent) + self.assertEqual(client._session.headers["User-Agent"], user_agent) def test_need_crt(self): """Raise an exception without a cert file if cert_auth=True.""" @@ -193,7 +193,7 @@ def test_headers(self): def test_session(self): """The session property should return the correct value.""" - self.assertEqual(self.client._Client__session, self.client.session) + self.assertEqual(self.client._session, self.client.session) class TestAddHeaders(TestClient): @@ -207,13 +207,13 @@ def test_add(self): # Make sure the new headers make their way into the internal requests.Session object for header, hval in headers.items(): - self.assertTrue(header in self.client._Client__session.headers) - self.assertEqual(hval, self.client._Client__session.headers[header]) + self.assertTrue(header in self.client._session.headers) + self.assertEqual(hval, self.client._session.headers[header]) # Make sure the original headers are still in the internal requests.Session object for head, headdata in self.cfixt.headers.items(): - self.assertTrue(head in self.client._Client__session.headers) - self.assertEqual(self.client._Client__session.headers[head], headdata) + self.assertTrue(head in self.client._session.headers) + self.assertEqual(self.client._session.headers[head], headdata) def test_replace(self): """The already existing header should be modified.""" @@ -223,15 +223,15 @@ def test_replace(self): # Make sure the new headers make their way into the internal requests.Session object for header, hval in headers.items(): - self.assertTrue(header in self.client._Client__session.headers) - self.assertEqual(hval, self.client._Client__session.headers[header]) + self.assertTrue(header in self.client._session.headers) + self.assertEqual(hval, self.client._session.headers[header]) # Removed the modified header from the check as it was checked above del self.cfixt.headers["User-Agent"] # Make sure the original headers are still in the internal requests.Session object for head, headdata in self.cfixt.headers.items(): - self.assertTrue(head in self.client._Client__session.headers) - self.assertEqual(self.client._Client__session.headers[head], headdata) + self.assertTrue(head in self.client._session.headers) + self.assertEqual(self.client._session.headers[head], headdata) def test_not_dictionary(self): """Raise an exception when not passed a dictionary.""" @@ -250,13 +250,13 @@ def test_remove(self): # Make sure the headers are removed from the requests.Session object for head in headers: - self.assertFalse(head in self.client._Client__session.headers) + self.assertFalse(head in self.client._session.headers) # Make sure the rest of the headers we added before are still there for head, headdata in self.cfixt.headers.items(): if head not in headers: - self.assertTrue(head in self.client._Client__session.headers) - self.assertEqual(self.client._Client__session.headers[head], headdata) + self.assertTrue(head in self.client._session.headers) + self.assertEqual(self.client._session.headers[head], headdata) def test_dictionary(self): """Remove headers correctly if passed a dictionary.""" @@ -266,13 +266,13 @@ def test_dictionary(self): # Make sure the headers are removed from the requests.Session object for head in headers: - self.assertFalse(head in self.client._Client__session.headers) + self.assertFalse(head in self.client._session.headers) # Make sure the rest of the headers we added before are still there for head, headdata in self.cfixt.headers.items(): if head not in headers: - self.assertTrue(head in self.client._Client__session.headers) - self.assertEqual(self.client._Client__session.headers[head], headdata) + self.assertTrue(head in self.client._session.headers) + self.assertEqual(self.client._session.headers[head], headdata) class TestGet(TestClient): diff --git a/tests/test_oauth2_client.py b/tests/test_oauth2_client.py new file mode 100644 index 0000000..33c4723 --- /dev/null +++ b/tests/test_oauth2_client.py @@ -0,0 +1,119 @@ +"""Define the cert_manager.client.OAuth2Client unit tests.""" +# Don't warn about things that happen as that is part of unit testing +# pylint: disable=protected-access +# pylint: disable=no-member +import sys +from unittest import mock + +from testtools import TestCase + +from cert_manager import __version__ +from cert_manager.client import OAuth2Client + + +class TestOAuth2Client(TestCase): + """Serve as a Base class for all tests of the OAuth2Client class.""" + + def setUp(self): + """Initialize the class.""" + super().setUp() + + self.client_id = "test_client_id" + self.client_secret = "test_client_secret" + self.auth_url = "https://auth.example.com/token" + self.base_url = "https://api.example.com" + self.access_token = "test_access_token" + + # Calculate expected user agent + ver_info = list(map(str, sys.version_info)) + pyver = ".".join(ver_info[:3]) + self.user_agent = f"cert_manager/{__version__.__version__} (Python {pyver})" + + def tearDown(self): + """Test tear down method.""" + super().tearDown() + mock.patch.stopall() + + @mock.patch("requests.post") + def test_init(self, mock_post): + """Test initialization of OAuth2Client.""" + # Setup mock response + mock_response = mock.Mock() + mock_response.json.return_value = {"access_token": self.access_token} + mock_post.return_value = mock_response + + # Initialize client + client = OAuth2Client( + client_id=self.client_id, + client_secret=self.client_secret, + auth_url=self.auth_url, + base_url=self.base_url + ) + + # Verify post called with correct args + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + self.assertEqual(args[0], self.auth_url) + self.assertEqual(kwargs["data"], { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "client_credentials" + }) + self.assertEqual(kwargs["headers"], { + "accept": "application/json", + "content-type": "application/x-www-form-urlencoded" + }) + + # Verify client properties + self.assertEqual(client._base_url, self.base_url) + + # Verify headers + expected_headers = { + "Authorization": f"Bearer {self.access_token}", + "Accept": "application/json", + "User-Agent": self.user_agent, + } + + for head, value in expected_headers.items(): + self.assertIn(head, client._headers) + self.assertEqual(client._headers[head], value) + self.assertIn(head, client.session.headers) + self.assertEqual(client.session.headers[head], value) + + @mock.patch("requests.post") + def test_defaults(self, mock_post): + """Test initialization with default values.""" + # Setup mock response + mock_response = mock.Mock() + mock_response.json.return_value = {"access_token": self.access_token} + mock_post.return_value = mock_response + + # Initialize client with defaults + client = OAuth2Client( + client_id=self.client_id, + client_secret=self.client_secret + ) + + # Verify defaults + expected_auth_url = "https://auth.sso.sectigo.com/auth/realms/apiclients/protocol/openid-connect/token" + expected_base_url = "https://admin.enterprise.sectigo.com/api" + + self.assertEqual(client._base_url, expected_base_url) + + mock_post.assert_called_once() + args, _ = mock_post.call_args + self.assertEqual(args[0], expected_auth_url) + + @mock.patch("requests.post") + def test_auth_failure(self, mock_post): + """Test initialization when authentication fails.""" + # Setup mock response for failure + mock_response = mock.Mock() + mock_response.json.return_value = {"error": "invalid_client"} + mock_post.return_value = mock_response + + # Expect KeyError because access_token is missing + self.assertRaises(KeyError, OAuth2Client, + client_id=self.client_id, + client_secret=self.client_secret + ) diff --git a/tests/test_organization.py b/tests/test_organization.py index d6dbc71..76f6135 100644 --- a/tests/test_organization.py +++ b/tests/test_organization.py @@ -57,7 +57,7 @@ def test_defaults(self): self.assertEqual(len(responses.calls), 1) self.assertEqual(responses.calls[0].request.url, self.api_url) - self.assertEqual(org._Organization__orgs, self.valid_response) + self.assertEqual(org._orgs, self.valid_response) @responses.activate def test_param(self): @@ -75,7 +75,7 @@ def test_param(self): self.assertEqual(len(responses.calls), 1) self.assertEqual(responses.calls[0].request.url, api_url) - self.assertEqual(org._Organization__orgs, self.valid_response) + self.assertEqual(org._orgs, self.valid_response) def test_need_client(self): """Raise an exception if called without a client parameter."""