From 72c1cf61ee7b52d605b011a6aa89f7425358ce1e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 15:52:21 +0000 Subject: [PATCH 1/5] Initial plan From b7fd269c68e3b3d3f4d98cd6829ef9d6b6076778 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 16:09:50 +0000 Subject: [PATCH 2/5] Add MSI v2 (mTLS PoP) support with Windows KeyGuard attestation Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com> --- msal/__init__.py | 1 + msal/managed_identity.py | 45 + msal/msi_v2.py | 1594 ++++++++++++++++++++++++++++++++++++ msal/msi_v2_attestation.py | 182 ++++ msi-v2-sample.spec | 45 + run_msi_v2_once.py | 56 ++ sample/msi_v2_sample.py | 175 ++++ tests/test_msi_v2.py | 321 ++++++++ 8 files changed, 2419 insertions(+) create mode 100644 msal/msi_v2.py create mode 100644 msal/msi_v2_attestation.py create mode 100644 msi-v2-sample.spec create mode 100644 run_msi_v2_once.py create mode 100644 sample/msi_v2_sample.py create mode 100644 tests/test_msi_v2.py diff --git a/msal/__init__.py b/msal/__init__.py index 295e9756..81763bb1 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -38,6 +38,7 @@ SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient, ManagedIdentityError, + MsiV2Error, ArcPlatformNotSupportedError, ) diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 422b76e3..85d77a8b 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -24,6 +24,11 @@ class ManagedIdentityError(ValueError): pass +class MsiV2Error(ManagedIdentityError): + """Raised when the MSI v2 (mTLS PoP) flow fails.""" + pass + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -259,6 +264,8 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + mtls_proof_of_possession: bool = False, + with_attestation_support: bool = False, ): """Acquire token for the managed identity. @@ -278,6 +285,23 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param bool mtls_proof_of_possession: (optional) + When True, use the MSI v2 (mTLS Proof-of-Possession) flow to acquire an + ``mtls_pop`` token bound to a short-lived mTLS certificate issued by the + IMDS ``/issuecredential`` endpoint. + Without this flag the legacy IMDS v1 flow is used. + Defaults to False. + + MSI v2 is used only when both ``mtls_proof_of_possession`` and + ``with_attestation_support`` are True. + + :param bool with_attestation_support: (optional) + When True (and ``mtls_proof_of_possession`` is also True), attempt + KeyGuard / platform attestation before credential issuance. + On Windows this leverages ``AttestationClientLib.dll`` when available; + on other platforms the parameter is silently ignored. + Defaults to False. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -292,6 +316,27 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + # MSI v2 is opt-in: use it only when BOTH mtls_proof_of_possession and + # with_attestation_support are explicitly requested by the caller. + # No auto-fallback: if MSI v2 is requested and fails, the error is raised. + use_msi_v2 = bool(mtls_proof_of_possession and with_attestation_support) + + if with_attestation_support and not mtls_proof_of_possession: + raise ManagedIdentityError( + "attestation_requires_pop", + "with_attestation_support=True requires mtls_proof_of_possession=True (mTLS PoP)." + ) + + if use_msi_v2: + from .msi_v2 import obtain_token as _obtain_token_v2 + result = _obtain_token_v2( + self._http_client, self._managed_identity, resource, + attestation_enabled=True, + ) + if "access_token" in result and "error" not in result: + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP + return result + if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( diff --git a/msal/msi_v2.py b/msal/msi_v2.py new file mode 100644 index 00000000..92016350 --- /dev/null +++ b/msal/msi_v2.py @@ -0,0 +1,1594 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +MSI v2 (IMDSv2) Managed Identity flow — Windows KeyGuard + Attestation + SChannel mTLS PoP. + +This module implements the end-to-end "MSI v2" flow used by Azure Managed Identity on Windows +when *certificate-bound* access tokens are requested (token_type=mtls_pop). + +It is intentionally "Python-only": no pythonnet/.NET interop is required. Instead, it uses +ctypes to call a small set of Windows APIs: + + * CNG/NCrypt (ncrypt.dll) - Create a KeyGuard/VBS isolated RSA key + sign CSR (RSA-PSS/SHA256) + * Crypt32 (crypt32.dll) - Bind the issued certificate to the CNG private key + * WinHTTP (winhttp.dll) - Perform the token request over mTLS using SChannel + +Flow summary (mirrors your working PowerShell implementation): + + 1) GET /metadata/identity/getplatformmetadata?cred-api-version=2.0 + 2) Create KeyGuard RSA key (non-exportable, VBS-isolated) + 3) Build CSR signed with RSA-PSS/SHA256 and include a special CSR attribute: + OID 1.3.6.1.4.1.311.90.2.10 -> DER UTF8String(JSON(cuId)) + 4) Get an attestation JWT for the key (via .msi_v2_attestation.get_attestation_jwt) + 5) POST /metadata/identity/issuecredential?cred-api-version=2.0 with csr + attestation_token + 6) POST {tenant}/oauth2/v2.0/token (mTLS) with the issued certificate and token_type=mtls_pop + +Important design choices: + + * Windows-only: importing on non-Windows platforms is supported, but calling obtain_token() + will raise MsiV2Error. + * No MSI v1 fallback: any failure raises MsiV2Error. + * Defensive certificate-to-key binding: we set multiple certificate context properties so + WinHTTP/SChannel can consistently locate the private key. + +Security notes: + + * Access tokens are secrets. Avoid logging or printing them in production. + * The KeyGuard RSA key is created as persisted, but is deleted in cleanup. + +Public entrypoint: obtain_token(http_client, managed_identity, resource, attestation_enabled=True) +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import sys +import uuid +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +__all__ = [ + "get_cert_thumbprint_sha256", + "verify_cnf_binding", + "obtain_token", +] + +# -------------------------------------------------------------------------------------- +# IMDS / MSI v2 constants +# -------------------------------------------------------------------------------------- + +_IMDS_DEFAULT_BASE = "http://169.254.169.254" +_IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + +_API_VERSION_QUERY_PARAM = "cred-api-version" +_IMDS_V2_API_VERSION = "2.0" + +_CSR_METADATA_PATH = "/metadata/identity/getplatformmetadata" +_ISSUE_CREDENTIAL_PATH = "/metadata/identity/issuecredential" +_ACQUIRE_ENTRA_TOKEN_PATH = "/oauth2/v2.0/token" + +# OID for the special CSR request attribute carrying cuId JSON. +_CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" + +# Flags from ncrypt.h used by the PowerShell reference implementation. +_NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 + +_RSA_KEY_SIZE = 2048 + +# Legacy KeySpec values (CAPI compatibility / CNG interop). +# Used by NCryptCreatePersistedKey.dwLegacyKeySpec and by CRYPT_KEY_PROV_INFO.dwKeySpec. +_AT_KEYEXCHANGE = 1 +_AT_SIGNATURE = 2 + +# CRYPT_KEY_PROV_INFO.dwFlags for CNG keys (best-effort suppression of UI prompts). +_NCRYPT_SILENT_FLAG = 0x40 + +_DEFAULT_KSP_NAME = "Microsoft Software Key Storage Provider" + +# -------------------------------------------------------------------------------------- +# Compatibility helpers (optional; useful for tests or debugging) +# -------------------------------------------------------------------------------------- + + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """ + Compute base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 comparisons. + + Accepts a PEM-encoded certificate string. + + Returns: + Base64url-encoded SHA-256 thumbprint without '=' padding, or "" if cryptography is + unavailable or parsing fails. + """ + try: + # cryptography is optional; keep this helper lightweight. + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"), default_backend()) + der = cert.public_bytes(serialization.Encoding.DER) + digest = hashlib.sha256(der).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + except Exception: + # Fail closed: if we cannot compute the thumbprint, binding verification cannot succeed. + return "" + + +def verify_cnf_binding(token: str, cert_pem: str) -> bool: + """ + Verify that a JWT payload contains cnf.x5t#S256 matching the cert thumbprint. + + This is a *best-effort* helper for validating certificate binding in tests. It does not + validate JWT signature or claims (aud/iss/exp/etc). + + Args: + token: A JWT access token (3-part base64url string). + cert_pem: PEM certificate string. + + Returns: + True if cnf.x5t#S256 exists and equals the SHA-256 certificate thumbprint. + """ + try: + parts = token.split(".") + if len(parts) != 3: + return False + + payload_b64 = parts[1] + payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload_b64.encode("ascii"))) + + cnf = claims.get("cnf", {}) if isinstance(claims, dict) else {} + token_x5t = cnf.get("x5t#S256") + if not token_x5t: + return False + + cert_x5t = get_cert_thumbprint_sha256(cert_pem) + if not cert_x5t: + return False + + return token_x5t == cert_x5t + except Exception: + return False + + +# -------------------------------------------------------------------------------------- +# IMDS helpers +# -------------------------------------------------------------------------------------- + + +def _imds_base() -> str: + """Resolve IMDS base URI (supports pod identity override via env var).""" + return os.getenv(_IMDS_BASE_ENVVAR, _IMDS_DEFAULT_BASE).strip().rstrip("/") + + +def _new_correlation_id() -> str: + """Generate an RFC 4122 correlation id used in x-ms-client-request-id.""" + return str(uuid.uuid4()) + + +def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: + """ + Headers required by IMDS. The Metadata=true header is mandatory. + + We also include x-ms-client-request-id to correlate IMDS and ESTS requests. + """ + return { + "Metadata": "true", + "x-ms-client-request-id": correlation_id or _new_correlation_id(), + } + + +def _resource_to_scope(resource_or_scope: str) -> str: + """ + Convert an ADAL-style 'resource' string into an MSAL v2 scope string. + + IMDS v2 uses MSAL v2 token endpoint semantics (scope=.../.default). + """ + s = (resource_or_scope or "").strip() + if not s: + raise ValueError("resource must be non-empty") + if s.endswith("/.default"): + return s + return s.rstrip("/") + "/.default" + + +def _der_utf8string(value: str) -> bytes: + """ + Minimal DER UTF8String encoder (tag 0x0C). + + Used for the CSR request attribute value (cuId JSON) and for X.500 CN when applicable. + """ + raw = value.encode("utf-8") + n = len(raw) + if n < 0x80: + len_bytes = bytes([n]) + else: + tmp = bytearray() + m = n + while m > 0: + tmp.insert(0, m & 0xFF) + m >>= 8 + len_bytes = bytes([0x80 | len(tmp)]) + bytes(tmp) + return bytes([0x0C]) + len_bytes + raw + + +def _json_loads(text: str, what: str) -> Dict[str, Any]: + """Parse JSON or raise MsiV2Error with context.""" + from .managed_identity import MsiV2Error + + try: + data = json.loads(text) + if isinstance(data, dict): + return data + raise MsiV2Error(f"[msi_v2] Expected JSON object from {what}, got {type(data).__name__}") + except Exception as exc: # pylint: disable=broad-except + raise MsiV2Error(f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc + + +def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: + """ + Fetch the first non-empty string value among several possible keys. + + IMDS field casing can vary (camelCase vs snake_case), so this helper checks: + 1) Exact keys + 2) Case-insensitive matches + """ + # direct keys + for n in names: + if n in obj and obj[n] is not None and str(obj[n]).strip() != "": + return str(obj[n]) + + # case-insensitive + lower = {str(k).lower(): k for k in obj.keys()} + for n in names: + k = lower.get(n.lower()) + if k and obj[k] is not None and str(obj[k]).strip() != "": + return str(obj[k]) + + return None + + +def _mi_query_params(managed_identity: Optional[Dict[str, Any]]) -> Dict[str, str]: + """ + Build IMDS query parameters: + * cred-api-version=2.0 (required) + * optional user-assigned identity selectors + + managed_identity shape (MSAL Python): + {"ManagedIdentityIdType": "ClientId"|"ObjectId"|"ResourceId", "Id": ""} + """ + params: Dict[str, str] = {_API_VERSION_QUERY_PARAM: _IMDS_V2_API_VERSION} + + if not isinstance(managed_identity, dict): + return params + + id_type = managed_identity.get("ManagedIdentityIdType") + identifier = managed_identity.get("Id") + + mapping = {"ClientId": "client_id", "ObjectId": "object_id", "ResourceId": "msi_res_id"} + wire = mapping.get(id_type) + if wire and identifier: + params[wire] = str(identifier) + + return params + + +def _imds_get_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str]) -> Dict[str, Any]: + """ + GET JSON from IMDS with a basic 'Server' header sanity check. + + Note: The "server: IMDS/..." header check is a defense-in-depth measure to reduce the + chance of SSRF misuse. Keep it strict unless you have a concrete reason to loosen it. + """ + from .managed_identity import MsiV2Error + + resp = http_client.get(url, params=params, headers=headers) + + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") + + if resp.status_code != 200: + raise MsiV2Error(f"[msi_v2] IMDSv2 GET {url} failed: HTTP {resp.status_code}: {resp.text}") + + return _json_loads(resp.text, f"GET {url}") + + +def _imds_post_json( + http_client, url: str, params: Dict[str, str], headers: Dict[str, str], body: Dict[str, Any] +) -> Dict[str, Any]: + """POST JSON to IMDS and return JSON response (with same header sanity check).""" + from .managed_identity import MsiV2Error + + resp = http_client.post(url, params=params, headers=headers, data=json.dumps(body, separators=(",", ":"))) + + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") + + if resp.status_code != 200: + raise MsiV2Error(f"[msi_v2] IMDSv2 POST {url} failed: HTTP {resp.status_code}: {resp.text}") + + return _json_loads(resp.text, f"POST {url}") + + +def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: + """ + Determine the token endpoint returned by /issuecredential. + + IMDS can return either: + * token_endpoint + * mtls_authentication_endpoint + tenant_id (compose into {mtls_auth}/{tenant}/oauth2/v2.0/token) + """ + token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") + if token_endpoint: + return token_endpoint + + mtls_auth = _get_first( + cred, + "mtls_authentication_endpoint", + "mtlsAuthenticationEndpoint", + "mtls_authenticationEndpoint", + ) + tenant_id = _get_first(cred, "tenant_id", "tenantId") + if not mtls_auth or not tenant_id: + from .managed_identity import MsiV2Error + + raise MsiV2Error(f"[msi_v2] issuecredential missing mtls_authentication_endpoint/tenant_id: {cred}") + + base = mtls_auth.rstrip("/") + "/" + tenant_id.strip("/") + return base + _ACQUIRE_ENTRA_TOKEN_PATH + + +# -------------------------------------------------------------------------------------- +# Win32 primitives (ctypes) - lazy loaded +# -------------------------------------------------------------------------------------- + +_WIN32: Optional[Dict[str, Any]] = None + + +def _load_win32() -> Dict[str, Any]: + """ + Lazy-load Win32 APIs via ctypes. + + Keeping the import behind a function allows importing this module on non-Windows platforms + without failing at import time. The public obtain_token() function enforces Windows-only. + """ + global _WIN32 + + from .managed_identity import MsiV2Error + + if _WIN32 is not None: + return _WIN32 + + if sys.platform != "win32": + raise MsiV2Error("[msi_v2] KeyGuard + attested mTLS PoP is Windows-only.") + + import ctypes + from ctypes import wintypes + + # DLLs (use_last_error makes ctypes.get_last_error() reliable for BOOL-returning APIs) + ncrypt = ctypes.WinDLL("ncrypt.dll") + crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) + winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + + # --- Types --- + NCRYPT_PROV_HANDLE = ctypes.c_void_p + NCRYPT_KEY_HANDLE = ctypes.c_void_p + SECURITY_STATUS = ctypes.c_long # LONG / NTSTATUS style + + # Crypt32 certificate context + class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ("dwCertEncodingType", wintypes.DWORD), + ("pbCertEncoded", ctypes.POINTER(ctypes.c_ubyte)), + ("cbCertEncoded", wintypes.DWORD), + ("pCertInfo", ctypes.c_void_p), + ("hCertStore", ctypes.c_void_p), + ] + + PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) + + # Padding info for NCryptSignHash (RSA-PSS) + class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): + _fields_ = [ + ("pszAlgId", ctypes.c_wchar_p), + ("cbSalt", wintypes.ULONG), + ] + + # --- Constants (subset) --- + ERROR_SUCCESS = 0 + + # ncrypt.h flags + NCRYPT_OVERWRITE_KEY_FLAG = 0x00000080 + + # key properties (ncrypt.h) + NCRYPT_LENGTH_PROPERTY = "Length" + NCRYPT_EXPORT_POLICY_PROPERTY = "Export Policy" + NCRYPT_KEY_USAGE_PROPERTY = "Key Usage" + + # key usage flags (ncrypt.h) + NCRYPT_ALLOW_SIGNING_FLAG = 0x00000002 + NCRYPT_ALLOW_DECRYPT_FLAG = 0x00000001 + + # bcrypt.h / padding + BCRYPT_PAD_PSS = 0x00000008 + BCRYPT_SHA256_ALGORITHM = "SHA256" + BCRYPT_RSA_ALGORITHM = "RSA" + BCRYPT_RSAPUBLIC_BLOB = "RSAPUBLICBLOB" + BCRYPT_RSAPUBLIC_MAGIC = 0x31415352 # 'RSA1' + + # wincrypt.h + X509_ASN_ENCODING = 0x00000001 + PKCS_7_ASN_ENCODING = 0x00010000 + CERT_NCRYPT_KEY_HANDLE_PROP_ID = 78 + CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG = 0x40000000 + + # WinHTTP constants + WINHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 + WINHTTP_FLAG_SECURE = 0x00800000 + WINHTTP_OPTION_CLIENT_CERT_CONTEXT = 47 + WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT = 161 + WINHTTP_QUERY_STATUS_CODE = 19 + WINHTTP_QUERY_FLAG_NUMBER = 0x20000000 + + # --- Function prototypes (argtypes/restype) --- + # NCrypt + ncrypt.NCryptOpenStorageProvider.argtypes = [ctypes.POINTER(NCRYPT_PROV_HANDLE), ctypes.c_wchar_p, wintypes.DWORD] + ncrypt.NCryptOpenStorageProvider.restype = SECURITY_STATUS + + ncrypt.NCryptCreatePersistedKey.argtypes = [ + NCRYPT_PROV_HANDLE, + ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, # alg id + ctypes.c_wchar_p, # key name + wintypes.DWORD, # legacy keyspec + wintypes.DWORD, # flags + ] + ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS + + ncrypt.NCryptSetProperty.argtypes = [ + ctypes.c_void_p, + ctypes.c_wchar_p, + ctypes.c_void_p, + wintypes.DWORD, + wintypes.DWORD, + ] + ncrypt.NCryptSetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptFinalizeKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] + ncrypt.NCryptFinalizeKey.restype = SECURITY_STATUS + + ncrypt.NCryptGetProperty.argtypes = [ + ctypes.c_void_p, + ctypes.c_wchar_p, + ctypes.c_void_p, + wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), + wintypes.DWORD, + ] + ncrypt.NCryptGetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptExportKey.argtypes = [ + NCRYPT_KEY_HANDLE, + NCRYPT_KEY_HANDLE, + ctypes.c_wchar_p, + ctypes.c_void_p, + ctypes.c_void_p, + wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), + wintypes.DWORD, + ] + ncrypt.NCryptExportKey.restype = SECURITY_STATUS + + ncrypt.NCryptSignHash.argtypes = [ + NCRYPT_KEY_HANDLE, + ctypes.c_void_p, # padding info + ctypes.c_void_p, # hash bytes + wintypes.DWORD, # hash len + ctypes.c_void_p, # sig out + wintypes.DWORD, # sig out len + ctypes.POINTER(wintypes.DWORD), + wintypes.DWORD, # flags + ] + ncrypt.NCryptSignHash.restype = SECURITY_STATUS + + ncrypt.NCryptDeleteKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] + ncrypt.NCryptDeleteKey.restype = SECURITY_STATUS + + ncrypt.NCryptFreeObject.argtypes = [ctypes.c_void_p] + ncrypt.NCryptFreeObject.restype = SECURITY_STATUS + + # Crypt32 + crypt32.CertCreateCertificateContext.argtypes = [wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + crypt32.CertCreateCertificateContext.restype = PCCERT_CONTEXT + + crypt32.CertSetCertificateContextProperty.argtypes = [PCCERT_CONTEXT, wintypes.DWORD, wintypes.DWORD, ctypes.c_void_p] + crypt32.CertSetCertificateContextProperty.restype = wintypes.BOOL + + crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] + crypt32.CertFreeCertificateContext.restype = wintypes.BOOL + + # WinHTTP + winhttp.WinHttpOpen.argtypes = [ + ctypes.c_wchar_p, + wintypes.DWORD, + ctypes.c_wchar_p, + ctypes.c_wchar_p, + wintypes.DWORD, + ] + winhttp.WinHttpOpen.restype = ctypes.c_void_p + + winhttp.WinHttpConnect.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] + winhttp.WinHttpConnect.restype = ctypes.c_void_p + + winhttp.WinHttpOpenRequest.argtypes = [ + ctypes.c_void_p, + ctypes.c_wchar_p, + ctypes.c_wchar_p, + ctypes.c_wchar_p, + ctypes.c_wchar_p, + ctypes.c_void_p, + wintypes.DWORD, + ] + winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p + + winhttp.WinHttpSetOption.argtypes = [ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpSetOption.restype = wintypes.BOOL + + winhttp.WinHttpSendRequest.argtypes = [ + ctypes.c_void_p, + ctypes.c_wchar_p, + wintypes.DWORD, + ctypes.c_void_p, + wintypes.DWORD, + wintypes.DWORD, + ctypes.c_ulonglong, # context + ] + winhttp.WinHttpSendRequest.restype = wintypes.BOOL + + winhttp.WinHttpReceiveResponse.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL + + winhttp.WinHttpQueryHeaders.argtypes = [ + ctypes.c_void_p, + wintypes.DWORD, + ctypes.c_wchar_p, + ctypes.c_void_p, + ctypes.POINTER(wintypes.DWORD), + ctypes.POINTER(wintypes.DWORD), + ] + winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL + + winhttp.WinHttpQueryDataAvailable.argtypes = [ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryDataAvailable.restype = wintypes.BOOL + + winhttp.WinHttpReadData.argtypes = [ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpReadData.restype = wintypes.BOOL + + winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] + winhttp.WinHttpCloseHandle.restype = wintypes.BOOL + + # Kernel32 + kernel32.GetLastError.argtypes = [] + kernel32.GetLastError.restype = wintypes.DWORD + + _WIN32 = { + "ctypes": ctypes, + "wintypes": wintypes, + "ncrypt": ncrypt, + "crypt32": crypt32, + "winhttp": winhttp, + "kernel32": kernel32, + # types + "NCRYPT_PROV_HANDLE": NCRYPT_PROV_HANDLE, + "NCRYPT_KEY_HANDLE": NCRYPT_KEY_HANDLE, + "SECURITY_STATUS": SECURITY_STATUS, + "CERT_CONTEXT": CERT_CONTEXT, + "PCCERT_CONTEXT": PCCERT_CONTEXT, + "BCRYPT_PSS_PADDING_INFO": BCRYPT_PSS_PADDING_INFO, + # constants + "ERROR_SUCCESS": ERROR_SUCCESS, + "NCRYPT_OVERWRITE_KEY_FLAG": NCRYPT_OVERWRITE_KEY_FLAG, + "NCRYPT_LENGTH_PROPERTY": NCRYPT_LENGTH_PROPERTY, + "NCRYPT_EXPORT_POLICY_PROPERTY": NCRYPT_EXPORT_POLICY_PROPERTY, + "NCRYPT_KEY_USAGE_PROPERTY": NCRYPT_KEY_USAGE_PROPERTY, + "NCRYPT_ALLOW_SIGNING_FLAG": NCRYPT_ALLOW_SIGNING_FLAG, + "NCRYPT_ALLOW_DECRYPT_FLAG": NCRYPT_ALLOW_DECRYPT_FLAG, + "BCRYPT_PAD_PSS": BCRYPT_PAD_PSS, + "BCRYPT_SHA256_ALGORITHM": BCRYPT_SHA256_ALGORITHM, + "BCRYPT_RSA_ALGORITHM": BCRYPT_RSA_ALGORITHM, + "BCRYPT_RSAPUBLIC_BLOB": BCRYPT_RSAPUBLIC_BLOB, + "BCRYPT_RSAPUBLIC_MAGIC": BCRYPT_RSAPUBLIC_MAGIC, + "X509_ASN_ENCODING": X509_ASN_ENCODING, + "PKCS_7_ASN_ENCODING": PKCS_7_ASN_ENCODING, + "CERT_NCRYPT_KEY_HANDLE_PROP_ID": CERT_NCRYPT_KEY_HANDLE_PROP_ID, + "CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG": CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG, + "WINHTTP_ACCESS_TYPE_DEFAULT_PROXY": WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, + "WINHTTP_FLAG_SECURE": WINHTTP_FLAG_SECURE, + "WINHTTP_OPTION_CLIENT_CERT_CONTEXT": WINHTTP_OPTION_CLIENT_CERT_CONTEXT, + "WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT": WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT, + "WINHTTP_QUERY_STATUS_CODE": WINHTTP_QUERY_STATUS_CODE, + "WINHTTP_QUERY_FLAG_NUMBER": WINHTTP_QUERY_FLAG_NUMBER, + } + return _WIN32 + + +def _format_win32_error(ctypes_mod, code: int) -> str: + """Format a Win32 error code into a human-readable string (best-effort).""" + try: + return ctypes_mod.FormatError(code).strip() + except Exception: + return "" + + +def _raise_win32_last_error(msg: str) -> None: + """ + Raise MsiV2Error with the current Win32 last-error code. + + Use for WinHTTP/Crypt32 APIs where failure is indicated via BOOL/NULL and details are in GetLastError(). + """ + from .managed_identity import MsiV2Error + + win32 = _load_win32() + ctypes_mod = win32["ctypes"] + err = ctypes_mod.get_last_error() + detail = _format_win32_error(ctypes_mod, err) + if detail: + raise MsiV2Error(f"{msg} (winerror={err} {detail})") + raise MsiV2Error(f"{msg} (winerror={err})") + + +def _check_security_status(status: int, what: str) -> None: + """ + Check SECURITY_STATUS/NTSTATUS-style return codes from NCrypt. + + Most NCrypt APIs return 0 for success; otherwise they return a status code (often an NTSTATUS). + """ + from .managed_identity import MsiV2Error + + if int(status) != 0: + code_u32 = int(status) & 0xFFFFFFFF + raise MsiV2Error(f"[msi_v2] {what} failed: status=0x{code_u32:08X}") + + +# -------------------------------------------------------------------------------------- +# DER helpers (minimal PKCS#10 CSR builder) +# -------------------------------------------------------------------------------------- + +# This is a minimal DER encoder sufficient for: +# * PKCS#10 CertificationRequestInfo (subject, spki, attributes) +# * RSASSA-PSS AlgorithmIdentifier params +# It intentionally avoids general ASN.1 frameworks to keep dependencies low. + + +def _der_len(n: int) -> bytes: + if n < 0: + raise ValueError("DER length cannot be negative") + if n < 0x80: + return bytes([n]) + out = bytearray() + m = n + while m > 0: + out.insert(0, m & 0xFF) + m >>= 8 + return bytes([0x80 | len(out)]) + bytes(out) + + +def _der(tag: int, content: bytes) -> bytes: + return bytes([tag]) + _der_len(len(content)) + content + + +def _der_null() -> bytes: + return b"\x05\x00" + + +def _der_integer(value: int) -> bytes: + if value < 0: + raise ValueError("Only non-negative INTEGER supported") + if value == 0: + raw = b"\x00" + else: + raw = value.to_bytes((value.bit_length() + 7) // 8, "big") + if raw[0] & 0x80: + raw = b"\x00" + raw + return _der(0x02, raw) + + +def _der_oid(oid: str) -> bytes: + parts = [int(x) for x in oid.split(".")] + if len(parts) < 2: + raise ValueError(f"Invalid OID: {oid}") + if parts[0] > 2 or parts[1] >= 40: + raise ValueError(f"Invalid OID: {oid}") + first = 40 * parts[0] + parts[1] + out = bytearray([first]) + for p in parts[2:]: + if p < 0: + raise ValueError(f"Invalid OID component: {oid}") + # base-128 encoding + stack = bytearray() + if p == 0: + stack.append(0) + else: + m = p + while m > 0: + stack.insert(0, m & 0x7F) + m >>= 7 + for i in range(len(stack) - 1): + stack[i] |= 0x80 + out.extend(stack) + return _der(0x06, bytes(out)) + + +def _der_sequence(*items: bytes) -> bytes: + return _der(0x30, b"".join(items)) + + +def _der_set(*items: bytes) -> bytes: + # DER SET requires elements to be sorted by their full DER encoding. + enc = sorted(items) + return _der(0x31, b"".join(enc)) + + +def _der_bitstring(data: bytes) -> bytes: + # 0 unused bits + return _der(0x03, b"\x00" + data) + + +def _der_ia5string(value: str) -> bytes: + raw = value.encode("ascii") + return _der(0x16, raw) + + +def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: + if not 0 <= tagnum <= 30: + raise ValueError("Unsupported tag number") + return _der(0xA0 + tagnum, inner) + + +def _der_context_implicit_constructed(tagnum: int, inner_content: bytes) -> bytes: + """ + Context-specific IMPLICIT, constructed. + + Used for PKCS#10 attributes: + attributes [0] IMPLICIT SET OF Attribute + Where we encode the SET OF's contents without the SET tag (0x31). + """ + if not 0 <= tagnum <= 30: + raise ValueError("Unsupported tag number") + return _der(0xA0 + tagnum, inner_content) + + +def _der_name_cn_dc(cn: str, dc: str) -> bytes: + """ + Encode X.500 Name with CN and DC RDNs. + + CN (2.5.4.3) is encoded as UTF8String. + DC (0.9.2342.19200300.100.1.25) is usually IA5String (ASCII), else UTF8String. + """ + cn_atv = _der_sequence(_der_oid("2.5.4.3"), _der_utf8string(cn)) + cn_rdn = _der_set(cn_atv) + + try: + dc_value = _der_ia5string(dc) + except Exception: + dc_value = _der_utf8string(dc) + dc_atv = _der_sequence(_der_oid("0.9.2342.19200300.100.1.25"), dc_value) + dc_rdn = _der_set(dc_atv) + + # RDNSequence is a SEQUENCE of SETs + return _der_sequence(cn_rdn, dc_rdn) + + +def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: + rsa_pub = _der_sequence(_der_integer(modulus), _der_integer(exponent)) + alg = _der_sequence(_der_oid("1.2.840.113549.1.1.1"), _der_null()) # rsaEncryption + NULL + return _der_sequence(alg, _der_bitstring(rsa_pub)) + + +def _der_algid_rsapss_sha256() -> bytes: + """ + AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), saltLength=32, trailerField=1. + + This matches what .NET / PowerShell emits for the working flow. + """ + sha256 = _der_sequence(_der_oid("2.16.840.1.101.3.4.2.1"), _der_null()) + mgf1 = _der_sequence(_der_oid("1.2.840.113549.1.1.8"), sha256) + salt_len = _der_integer(32) + trailer = _der_integer(1) + + params = _der_sequence( + _der_context_explicit(0, sha256), + _der_context_explicit(1, mgf1), + _der_context_explicit(2, salt_len), + _der_context_explicit(3, trailer), + ) + return _der_sequence(_der_oid("1.2.840.113549.1.1.10"), params) + + +# -------------------------------------------------------------------------------------- +# CNG/NCrypt wrappers +# -------------------------------------------------------------------------------------- + + +def _ncrypt_get_property(win32: Dict[str, Any], handle: Any, name: str) -> bytes: + """Get an NCrypt property value as raw bytes.""" + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + + status = ncrypt.NCryptGetProperty(handle, name, None, 0, ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, f"NCryptGetProperty({name})") + + if cb.value == 0: + return b"" + + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptGetProperty(handle, name, buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, f"NCryptGetProperty({name})") + + return bytes(buf[: cb.value]) + + +def _create_keyguard_rsa_key(win32: Dict[str, Any]) -> Tuple[Any, Any, str]: + """ + Create a non-exportable RSA key protected with VBS/KeyGuard. + + Returns: + (prov_handle, key_handle, key_name) + + key_name is the persisted CNG key name (container name). WinHTTP/SChannel can require it to + re-open the key when doing client-certificate authentication. + """ + from .managed_identity import MsiV2Error + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + prov = win32["NCRYPT_PROV_HANDLE"]() + status = ncrypt.NCryptOpenStorageProvider(ctypes_mod.byref(prov), _DEFAULT_KSP_NAME, 0) + _check_security_status(status, "NCryptOpenStorageProvider") + + key = win32["NCRYPT_KEY_HANDLE"]() + key_name = "MsalMsiV2Key_" + _new_correlation_id() + + flags = win32["NCRYPT_OVERWRITE_KEY_FLAG"] | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG | _NCRYPT_USE_PER_BOOT_KEY_FLAG + + # IMPORTANT: + # When a certificate is bound to a CNG key via CERT_KEY_PROV_INFO_PROP_ID, Schannel/WinHTTP + # re-opens the key using NCryptOpenKey and passes CRYPT_KEY_PROV_INFO.dwKeySpec as the legacy + # keyspec parameter. The CRYPT_KEY_PROV_INFO docs specify that for CNG (dwProvType==0), + # dwKeySpec must be AT_SIGNATURE or AT_KEYEXCHANGE (not CERT_NCRYPT_KEY_SPEC). + # + # We therefore create the key with dwLegacyKeySpec=AT_SIGNATURE so re-open works reliably. + status = ncrypt.NCryptCreatePersistedKey( + prov, + ctypes_mod.byref(key), + win32["BCRYPT_RSA_ALGORITHM"], + key_name, + _AT_SIGNATURE, + flags, + ) + + try: + _check_security_status(status, "NCryptCreatePersistedKey") + + # Length must be DWORD. + length = wintypes.DWORD(int(_RSA_KEY_SIZE)) + status = ncrypt.NCryptSetProperty( + key, + win32["NCRYPT_LENGTH_PROPERTY"], + ctypes_mod.byref(length), + ctypes_mod.sizeof(length), + 0, + ) + _check_security_status(status, "NCryptSetProperty(Length)") + + # Key usage: signing is required; decrypt doesn't hurt for TLS use-cases. + usage = wintypes.DWORD(win32["NCRYPT_ALLOW_SIGNING_FLAG"] | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) + status = ncrypt.NCryptSetProperty( + key, + win32["NCRYPT_KEY_USAGE_PROPERTY"], + ctypes_mod.byref(usage), + ctypes_mod.sizeof(usage), + 0, + ) + _check_security_status(status, "NCryptSetProperty(Key Usage)") + + # Export policy: 0 (disallow export). + export_policy = wintypes.DWORD(0) + status = ncrypt.NCryptSetProperty( + key, + win32["NCRYPT_EXPORT_POLICY_PROPERTY"], + ctypes_mod.byref(export_policy), + ctypes_mod.sizeof(export_policy), + 0, + ) + _check_security_status(status, "NCryptSetProperty(Export Policy)") + + status = ncrypt.NCryptFinalizeKey(key, 0) + _check_security_status(status, "NCryptFinalizeKey") + + # Validate Virtual Iso property (Credential Guard / VBS). Helps fail fast if KeyGuard isn't active. + try: + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if vi is None or len(vi) < 4: + raise MsiV2Error("[msi_v2] Virtual Iso property missing/invalid; Credential Guard likely not active.") + except Exception as exc: + raise MsiV2Error("[msi_v2] Virtual Iso property not available; Credential Guard likely not active.") from exc + + return prov, key, key_name + + except Exception: + # Best-effort cleanup. The caller also cleans up in obtain_token(). + try: + if key: + ncrypt.NCryptDeleteKey(key, 0) + except Exception: + pass + try: + if key: + ncrypt.NCryptFreeObject(key) + except Exception: + pass + try: + if prov: + ncrypt.NCryptFreeObject(prov) + except Exception: + pass + raise + + +def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int]: + """ + Export RSA public key (modulus, exponent) from an NCrypt key handle. + + We export as BCRYPT_RSAPUBLIC_BLOB and parse it. + """ + from .managed_identity import MsiV2Error + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + status = ncrypt.NCryptExportKey(key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, None, 0, ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, "NCryptExportKey(size)") + if cb.value == 0: + raise MsiV2Error("[msi_v2] NCryptExportKey returned empty public blob size") + + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptExportKey(key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, "NCryptExportKey(RSAPUBLICBLOB)") + blob = bytes(buf[: cb.value]) + + # BCRYPT_RSAKEY_BLOB header is 6 DWORDs, little-endian. + if len(blob) < 24: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB too small") + + import struct + + magic, bitlen, cb_exp, cb_mod, cb_p1, cb_p2 = struct.unpack("<6I", blob[:24]) + if magic != win32["BCRYPT_RSAPUBLIC_MAGIC"]: + raise MsiV2Error(f"[msi_v2] RSAPUBLICBLOB magic mismatch: 0x{magic:08X}") + + if cb_p1 != 0 or cb_p2 != 0: + # Public blob should have primes=0; ignore if present (defensive). + logger.debug("[msi_v2] RSAPUBLICBLOB contains primes unexpectedly (ignored).") + + offset = 24 + if len(blob) < offset + cb_exp + cb_mod: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB truncated") + + exp_bytes = blob[offset : offset + cb_exp] + offset += cb_exp + mod_bytes = blob[offset : offset + cb_mod] + + exponent = int.from_bytes(exp_bytes, "big") + modulus = int.from_bytes(mod_bytes, "big") + + # sanity: header bitlen should match modulus bit length (often does). + if bitlen != modulus.bit_length(): + logger.debug("[msi_v2] RSA bit length mismatch: header=%d computed=%d", bitlen, modulus.bit_length()) + + return modulus, exponent + + +def _ncrypt_sign_pss_sha256(win32: Dict[str, Any], key: Any, digest: bytes) -> bytes: + """ + Sign a SHA-256 digest using RSA-PSS via NCryptSignHash. + + NCryptSignHash expects the *hash digest*, not the original message. + """ + from .managed_identity import MsiV2Error + + if len(digest) != 32: + raise MsiV2Error("[msi_v2] Expected SHA-256 digest (32 bytes)") + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + PaddingInfo = win32["BCRYPT_PSS_PADDING_INFO"] + pad = PaddingInfo(win32["BCRYPT_SHA256_ALGORITHM"], 32) + + hash_buf = (ctypes_mod.c_ubyte * len(digest)).from_buffer_copy(digest) + + cb_sig = wintypes.DWORD(0) + status = ncrypt.NCryptSignHash( + key, + ctypes_mod.byref(pad), + hash_buf, + len(digest), + None, + 0, + ctypes_mod.byref(cb_sig), + win32["BCRYPT_PAD_PSS"], + ) + if int(status) != 0 and cb_sig.value == 0: + _check_security_status(status, "NCryptSignHash(size)") + if cb_sig.value == 0: + raise MsiV2Error("[msi_v2] NCryptSignHash returned empty signature size") + + sig_buf = (ctypes_mod.c_ubyte * cb_sig.value)() + status = ncrypt.NCryptSignHash( + key, + ctypes_mod.byref(pad), + hash_buf, + len(digest), + sig_buf, + cb_sig.value, + ctypes_mod.byref(cb_sig), + win32["BCRYPT_PAD_PSS"], + ) + _check_security_status(status, "NCryptSignHash") + + return bytes(sig_buf[: cb_sig.value]) + + +# -------------------------------------------------------------------------------------- +# CSR builder (KeyGuard key handle) +# -------------------------------------------------------------------------------------- + + +def _build_csr_b64(win32: Dict[str, Any], key: Any, client_id: str, tenant_id: str, cu_id: Any) -> str: + """ + Build a PKCS#10 CSR signed by the KeyGuard key (RSA-PSS/SHA256) and return base64. + + The CSR includes a request attribute (OID _CU_ID_OID_STR) whose value is: + DER UTF8String(JSON(cuId)) + """ + modulus, exponent = _ncrypt_export_rsa_public(win32, key) + + subject = _der_name_cn_dc(client_id, tenant_id) + spki = _der_subject_public_key_info_rsa(modulus, exponent) + + cuid_json = json.dumps(cu_id, separators=(",", ":"), ensure_ascii=False) + cuid_val = _der_utf8string(cuid_json) + + # Attribute: SEQUENCE { OID, SET { } } + attr = _der_sequence(_der_oid(_CU_ID_OID_STR), _der_set(cuid_val)) + + # PKCS#10 attributes: [0] IMPLICIT SET OF Attribute + attrs_content = b"".join(sorted([attr])) + attrs = _der_context_implicit_constructed(0, attrs_content) + + cri = _der_sequence(_der_integer(0), subject, spki, attrs) + + digest = hashlib.sha256(cri).digest() + signature = _ncrypt_sign_pss_sha256(win32, key, digest) + + csr = _der_sequence(cri, _der_algid_rsapss_sha256(), _der_bitstring(signature)) + return base64.b64encode(csr).decode("ascii") + + +# -------------------------------------------------------------------------------------- +# Certificate binding + WinHTTP mTLS +# -------------------------------------------------------------------------------------- + + +def _create_cert_context_with_key( + win32: Dict[str, Any], + cert_der: bytes, + key: Any, + key_name: str, + *, + ksp_name: str = _DEFAULT_KSP_NAME, +) -> Tuple[Any, Tuple[Any, ...]]: + """ + Create a CERT_CONTEXT from DER bytes and associate it with the given CNG private key. + + Why set multiple properties? + + WinHTTP/SChannel sometimes fails to locate the private key unless the cert context contains + enough information. We set (best-effort): + + * CERT_NCRYPT_KEY_HANDLE_PROP_ID (78) - direct handle binding + * CERT_KEY_CONTEXT_PROP_ID (5) - CERT_KEY_CONTEXT with hNCryptKey + CERT_NCRYPT_KEY_SPEC + * CERT_KEY_PROV_INFO_PROP_ID (2) - CRYPT_KEY_PROV_INFO referencing the persisted key name + + The returned keepalive tuple MUST remain referenced for as long as the CERT_CONTEXT is used, + because it contains buffers referenced by the cert properties. + + Returns: + (PCCERT_CONTEXT, keepalive) + """ + from .managed_identity import MsiV2Error + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + crypt32 = win32["crypt32"] + + enc = win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"] + + # Keep DER bytes alive to be safe. + cert_buf = ctypes_mod.create_string_buffer(cert_der) + ctx = crypt32.CertCreateCertificateContext(enc, cert_buf, len(cert_der)) + if not ctx: + _raise_win32_last_error("[msi_v2] CertCreateCertificateContext failed") + + keepalive: List[Any] = [cert_buf] + + try: + key_value = int(getattr(key, "value", key) or 0) + if not key_value: + raise MsiV2Error("[msi_v2] Invalid CNG key handle (0)") + + # --- (A) CERT_NCRYPT_KEY_HANDLE_PROP_ID (78) + key_handle = ctypes_mod.c_void_p(key_value) + keepalive.append(key_handle) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + win32["CERT_NCRYPT_KEY_HANDLE_PROP_ID"], + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_handle), + ) + if not ok: + _raise_win32_last_error("[msi_v2] CertSetCertificateContextProperty(CERT_NCRYPT_KEY_HANDLE_PROP_ID) failed") + + # --- (B) CERT_KEY_CONTEXT_PROP_ID (5) - optional but helpful + CERT_KEY_CONTEXT_PROP_ID = 5 + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF # wincrypt.h: CERT_NCRYPT_KEY_SPEC + + class CERT_KEY_CONTEXT(ctypes_mod.Structure): + _fields_ = [ + ("cbSize", wintypes.DWORD), + ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), # union: HCRYPTPROV / NCRYPT_KEY_HANDLE + ("dwKeySpec", wintypes.DWORD), + ] + + key_ctx = CERT_KEY_CONTEXT(ctypes_mod.sizeof(CERT_KEY_CONTEXT), key_handle, wintypes.DWORD(CERT_NCRYPT_KEY_SPEC)) + keepalive.append(key_ctx) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + CERT_KEY_CONTEXT_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_ctx), + ) + if not ok: + # Not fatal in all environments; keep going. + logger.debug("[msi_v2] Failed to set CERT_KEY_CONTEXT_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) + + # --- (C) CERT_KEY_PROV_INFO_PROP_ID (2) - allows Schannel to re-open key by name + CERT_KEY_PROV_INFO_PROP_ID = 2 + + class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): + _fields_ = [ + ("pwszContainerName", wintypes.LPWSTR), + ("pwszProvName", wintypes.LPWSTR), + ("dwProvType", wintypes.DWORD), + ("dwFlags", wintypes.DWORD), + ("cProvParam", wintypes.DWORD), + ("rgProvParam", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + container_buf = ctypes_mod.create_unicode_buffer(str(key_name)) + provider_buf = ctypes_mod.create_unicode_buffer(str(ksp_name)) + keepalive.extend([container_buf, provider_buf]) + + # For CNG keys (dwProvType==0), dwKeySpec is passed as dwLegacyKeySpec to NCryptOpenKey. + # It must be AT_SIGNATURE or AT_KEYEXCHANGE per CRYPT_KEY_PROV_INFO docs. + prov_info = CRYPT_KEY_PROV_INFO( + ctypes_mod.cast(container_buf, wintypes.LPWSTR), + ctypes_mod.cast(provider_buf, wintypes.LPWSTR), + wintypes.DWORD(0), # CNG + wintypes.DWORD(_NCRYPT_SILENT_FLAG), + wintypes.DWORD(0), + None, + wintypes.DWORD(_AT_SIGNATURE), + ) + keepalive.append(prov_info) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + CERT_KEY_PROV_INFO_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(prov_info), + ) + if not ok: + # If this fails, WinHTTP may still work via CERT_NCRYPT_KEY_HANDLE_PROP_ID. + logger.debug("[msi_v2] Failed to set CERT_KEY_PROV_INFO_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) + + return ctx, tuple(keepalive) + + except Exception: + try: + crypt32.CertFreeCertificateContext(ctx) + except Exception: + pass + raise + + +def _winhttp_close(win32: Dict[str, Any], handle: Any) -> None: + """Close a WinHTTP HINTERNET handle (best-effort).""" + try: + if handle: + win32["winhttp"].WinHttpCloseHandle(handle) + except Exception: + pass + + +def _winhttp_set_option_dword(win32: Dict[str, Any], handle: Any, option: int, value: int, *, fatal: bool = False) -> None: + """Set a WinHTTP option that takes a DWORD value.""" + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + winhttp = win32["winhttp"] + + v = wintypes.DWORD(int(value)) + ok = winhttp.WinHttpSetOption(handle, option, ctypes_mod.byref(v), ctypes_mod.sizeof(v)) + if not ok and fatal: + _raise_win32_last_error(f"[msi_v2] WinHttpSetOption({option}) failed") + + +def _winhttp_post(win32: Dict[str, Any], url: str, cert_ctx: Any, body: bytes, headers: Dict[str, str]) -> Tuple[int, bytes]: + """ + POST bytes to an https:// URL using WinHTTP + SChannel, presenting the provided cert context. + + Returns: + (status_code, response_body_bytes) + """ + from .managed_identity import MsiV2Error + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + winhttp = win32["winhttp"] + + from urllib.parse import urlparse + + u = urlparse(url) + if u.scheme.lower() != "https": + raise MsiV2Error(f"[msi_v2] Token endpoint must be https, got: {url!r}") + if not u.hostname: + raise MsiV2Error(f"[msi_v2] Invalid token endpoint: {url!r}") + + host = u.hostname + port = u.port or 443 + path = u.path or "/" + if u.query: + path += "?" + u.query + + # WinHTTP uses UTF-16 wide strings. + user_agent = "msal-python-msi-v2" + + h_session = winhttp.WinHttpOpen( + user_agent, + win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], + None, + None, + 0, + ) + if not h_session: + _raise_win32_last_error("[msi_v2] WinHttpOpen failed") + + try: + # Best-effort: ensure client cert context is honored even when HTTP/2 is negotiated. + # Not all Windows builds support this option; ignore failures. + _winhttp_set_option_dword(win32, h_session, win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], 1, fatal=False) + + h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) + if not h_connect: + _raise_win32_last_error("[msi_v2] WinHttpConnect failed") + try: + h_request = winhttp.WinHttpOpenRequest( + h_connect, + "POST", + path, + None, + None, + None, + win32["WINHTTP_FLAG_SECURE"], + ) + if not h_request: + _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") + try: + # Set client certificate context on request. + CertContext = win32["CERT_CONTEXT"] + ok = winhttp.WinHttpSetOption( + h_request, + win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], + cert_ctx, + ctypes_mod.sizeof(CertContext), + ) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSetOption(WINHTTP_OPTION_CLIENT_CERT_CONTEXT) failed") + + header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) + header_str = header_lines # unicode / wide + + body_buf = ctypes_mod.create_string_buffer(body) + + ok = winhttp.WinHttpSendRequest( + h_request, + header_str, + 0xFFFFFFFF, # -1L (auto compute) + body_buf, + len(body), + len(body), + 0, + ) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") + + # Query status code as DWORD. + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryHeaders( + h_request, + win32["WINHTTP_QUERY_STATUS_CODE"] | win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, + ctypes_mod.byref(status), + ctypes_mod.byref(status_size), + ctypes_mod.byref(index), + ) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryHeaders(WINHTTP_QUERY_STATUS_CODE) failed") + + chunks: List[bytes] = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable(h_request, ctypes_mod.byref(avail)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData(h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[: read.value])) + if read.value == 0: + break + + return int(status.value), b"".join(chunks) + finally: + _winhttp_close(win32, h_request) + finally: + _winhttp_close(win32, h_connect) + finally: + _winhttp_close(win32, h_session) + + +def _acquire_token_mtls_schannel( + win32: Dict[str, Any], + token_endpoint: str, + cert_ctx: Any, + client_id: str, + scope: str, +) -> Dict[str, Any]: + """ + Acquire an mtls_pop token from ESTS using WinHTTP/SChannel with the provided cert context. + """ + from .managed_identity import MsiV2Error + from urllib.parse import urlencode + + form = urlencode( + { + "grant_type": "client_credentials", + "client_id": client_id, + "scope": scope, + "token_type": "mtls_pop", + } + ).encode("utf-8") + + status, resp_body = _winhttp_post( + win32, + token_endpoint, + cert_ctx, + form, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, + ) + + text = resp_body.decode("utf-8", errors="replace") + if status < 200 or status >= 300: + raise MsiV2Error(f"[msi_v2] ESTS token request failed: HTTP {status} Body={text!r}") + + return _json_loads(text, "ESTS token") + + +# -------------------------------------------------------------------------------------- +# Public API +# -------------------------------------------------------------------------------------- + + +def obtain_token( + http_client, + managed_identity: Dict[str, Any], + resource: str, + *, + attestation_enabled: bool = True, +) -> Dict[str, Any]: + """ + Acquire an mtls_pop access token using Windows KeyGuard + attestation. + + Args: + http_client: + Requests-like object that provides .get() and .post() returning responses with + .status_code, .text, .headers. (MSAL passes its own session by default.) + managed_identity: + MSAL-managed identity selector dict (system-assigned or user-assigned). Used only + to set optional IMDS query params for UAMI. + resource: + Resource or scope. If it doesn't end with "/.default", we append "/.default". + attestation_enabled: + Must be True for this KeyGuard flow. If False, we fail closed. + + Returns: + Dict with access_token, expires_in, token_type, and optional resource. + + Raises: + MsiV2Error on any failure (no MSI v1 fallback). + """ + from .managed_identity import MsiV2Error + + win32 = _load_win32() + ncrypt = win32["ncrypt"] + crypt32 = win32["crypt32"] + + base = _imds_base() + params = _mi_query_params(managed_identity) + corr = _new_correlation_id() + + prov = None + key = None + key_name = None + cert_ctx = None + cert_keepalive: Optional[Tuple[Any, ...]] = None + + try: + # 1) Read platform metadata (client_id, tenant_id, cuId, attestation endpoint). + meta_url = base + _CSR_METADATA_PATH + meta = _imds_get_json(http_client, meta_url, params, _imds_headers(corr)) + + client_id = _get_first(meta, "clientId", "client_id") + tenant_id = _get_first(meta, "tenantId", "tenant_id") + cu_id = meta.get("cuId") if "cuId" in meta else meta.get("cu_id") + attestation_endpoint = _get_first(meta, "attestationEndpoint", "attestation_endpoint") + + if not client_id or not tenant_id or cu_id is None: + raise MsiV2Error(f"[msi_v2] getplatformmetadata missing required fields: {meta}") + + # 2) Create KeyGuard RSA key (NCrypt). + prov, key, key_name = _create_keyguard_rsa_key(win32) + + # 3) CSR signed with RSA-PSS/SHA256, includes cuId request attribute. + csr_b64 = _build_csr_b64(win32, key, str(client_id), str(tenant_id), cu_id) + + # 4) Attestation JWT (required in this flow). + if not attestation_enabled: + raise MsiV2Error("[msi_v2] attestation_enabled must be True for this KeyGuard flow.") + if not attestation_endpoint: + raise MsiV2Error("[msi_v2] attestationEndpoint missing from metadata.") + + key_handle_int = int(getattr(key, "value", 0) or 0) + if not key_handle_int: + raise MsiV2Error("[msi_v2] Invalid key handle for attestation") + + from .msi_v2_attestation import get_attestation_jwt + + att_jwt = get_attestation_jwt( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + key_handle=key_handle_int, + ) + if not att_jwt or not str(att_jwt).strip(): + raise MsiV2Error("[msi_v2] Attestation token is missing/empty; refusing to call issuecredential.") + + # 5) Exchange CSR + attestation for an issued certificate (IMDS /issuecredential). + issue_url = base + _ISSUE_CREDENTIAL_PATH + issue_headers = _imds_headers(corr) + issue_headers["Content-Type"] = "application/json" + + cred = _imds_post_json( + http_client, + issue_url, + params, + issue_headers, + {"csr": csr_b64, "attestation_token": att_jwt}, + ) + + cert_b64 = _get_first(cred, "certificate", "Certificate") + if not cert_b64: + raise MsiV2Error(f"[msi_v2] issuecredential missing certificate: {cred}") + + try: + cert_der = base64.b64decode(cert_b64) + except Exception as exc: + raise MsiV2Error("[msi_v2] issuecredential returned invalid base64 certificate") from exc + + canonical_client_id = _get_first(cred, "client_id", "clientId") or str(client_id) + token_endpoint = _token_endpoint_from_credential(cred) + + # 6) Bind KeyGuard key to the issued cert and request token over mTLS (SChannel). + cert_ctx, cert_keepalive = _create_cert_context_with_key(win32, cert_der, key, str(key_name)) + + scope = _resource_to_scope(resource) + token_json = _acquire_token_mtls_schannel(win32, token_endpoint, cert_ctx, canonical_client_id, scope) + + if token_json.get("access_token") and token_json.get("expires_in"): + return { + "access_token": token_json["access_token"], + "expires_in": int(token_json["expires_in"]), + "token_type": token_json.get("token_type") or "mtls_pop", + "resource": token_json.get("resource"), + } + + # Some error shapes could still be JSON; return raw for caller to interpret. + return token_json + + finally: + # Cleanup: cert context (WinHTTP duplicates it internally during request). + try: + if cert_ctx: + crypt32.CertFreeCertificateContext(cert_ctx) + except Exception: + pass + + # Cleanup: key and provider handles. + # The key is persisted, so we delete it explicitly and then free handles. + try: + if key: + ncrypt.NCryptDeleteKey(key, 0) + except Exception: + pass + try: + if key: + ncrypt.NCryptFreeObject(key) + except Exception: + pass + try: + if prov: + ncrypt.NCryptFreeObject(prov) + except Exception: + pass + + # keepalive is intentionally unused; it just keeps buffers alive while cert_ctx existed. + _ = cert_keepalive diff --git a/msal/msi_v2_attestation.py b/msal/msi_v2_attestation.py new file mode 100644 index 00000000..d46f9338 --- /dev/null +++ b/msal/msi_v2_attestation.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +Windows attestation for MSI v2 KeyGuard keys using AttestationClientLib.dll. + +Equivalent to your PowerShell / C# P/Invoke signatures: + + int InitAttestationLib(ref AttestationLogInfo info); + int AttestKeyGuardImportKey(string endpoint, string authToken, string clientPayload, + IntPtr keyHandle, out IntPtr token, string clientId); + void FreeAttestationToken(IntPtr token); + void UninitAttestationLib(); +""" + +from __future__ import annotations + +import ctypes +import logging +import os +import sys +from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p + +logger = logging.getLogger(__name__) + +# keep callback alive +_NATIVE_LOG_CB = None + + +# void LogFunc(void* ctx, const char* tag, int lvl, const char* func, int line, const char* msg); +_LogFunc = ctypes.CFUNCTYPE(None, c_void_p, c_char_p, c_int, c_char_p, c_int, c_char_p) + + +class AttestationLogInfo(Structure): + _fields_ = [("Log", c_void_p), ("Ctx", c_void_p)] + + +def _default_logger(ctx, tag, lvl, func, line, msg): + try: + tag_s = tag.decode("utf-8", errors="replace") if tag else "" + func_s = func.decode("utf-8", errors="replace") if func else "" + msg_s = msg.decode("utf-8", errors="replace") if msg else "" + logger.debug("[Native:%s:%s] %s:%s - %s", tag_s, lvl, func_s, line, msg_s) + except Exception: + pass + + +def _maybe_add_dll_dirs(): + """ + Make DLL resolution more reliable (especially for packaged apps). + """ + if sys.platform != "win32": + return + + add_dir = getattr(os, "add_dll_directory", None) + if not add_dir: + return + + # exe dir + try: + exe_dir = os.path.dirname(sys.executable) + if exe_dir and os.path.isdir(exe_dir): + add_dir(exe_dir) + except Exception: + pass + + # cwd + try: + cwd = os.getcwd() + if cwd and os.path.isdir(cwd): + add_dir(cwd) + except Exception: + pass + + # module dir + try: + mod_dir = os.path.dirname(__file__) + if mod_dir and os.path.isdir(mod_dir): + add_dir(mod_dir) + except Exception: + pass + + +def _load_lib(): + from .managed_identity import MsiV2Error + + if sys.platform != "win32": + raise MsiV2Error("[msi_v2_attestation] AttestationClientLib is Windows-only.") + + _maybe_add_dll_dirs() + + explicit = os.getenv("ATTESTATION_CLIENTLIB_PATH") + try: + if explicit: + return ctypes.CDLL(explicit) + return ctypes.CDLL("AttestationClientLib.dll") + except OSError as exc: + raise MsiV2Error( + "[msi_v2_attestation] Unable to load AttestationClientLib.dll. " + "Place it next to the app/exe or set ATTESTATION_CLIENTLIB_PATH." + ) from exc + + +def get_attestation_jwt( + *, + attestation_endpoint: str, + client_id: str, + key_handle: int, + auth_token: str = "", + client_payload: str = "{}", +) -> str: + """ + Returns attestation JWT string. Raises MsiV2Error on failure. + """ + from .managed_identity import MsiV2Error + + if not attestation_endpoint: + raise MsiV2Error("[msi_v2_attestation] attestation_endpoint must be non-empty") + if not client_id: + raise MsiV2Error("[msi_v2_attestation] client_id must be non-empty") + if not key_handle: + raise MsiV2Error("[msi_v2_attestation] key_handle must be non-zero") + + lib = _load_lib() + + lib.InitAttestationLib.argtypes = [POINTER(AttestationLogInfo)] + lib.InitAttestationLib.restype = c_int + + lib.AttestKeyGuardImportKey.argtypes = [ + c_char_p, # endpoint + c_char_p, # authToken + c_char_p, # clientPayload + c_void_p, # keyHandle + POINTER(c_void_p), # out token (char*) + c_char_p, # clientId + ] + lib.AttestKeyGuardImportKey.restype = c_int + + lib.FreeAttestationToken.argtypes = [c_void_p] + lib.FreeAttestationToken.restype = None + + lib.UninitAttestationLib.argtypes = [] + lib.UninitAttestationLib.restype = None + + global _NATIVE_LOG_CB # pylint: disable=global-statement + _NATIVE_LOG_CB = _LogFunc(_default_logger) + + info = AttestationLogInfo() + info.Log = ctypes.cast(_NATIVE_LOG_CB, c_void_p).value + info.Ctx = c_void_p(0) + + rc = lib.InitAttestationLib(ctypes.byref(info)) + if rc != 0: + raise MsiV2Error(f"[msi_v2_attestation] InitAttestationLib failed: {rc}") + + token_ptr = c_void_p() + try: + rc = lib.AttestKeyGuardImportKey( + attestation_endpoint.encode("utf-8"), + auth_token.encode("utf-8"), + client_payload.encode("utf-8"), + c_void_p(int(key_handle)), + ctypes.byref(token_ptr), + client_id.encode("utf-8"), + ) + if rc != 0: + raise MsiV2Error(f"[msi_v2_attestation] AttestKeyGuardImportKey failed: {rc}") + if not token_ptr.value: + raise MsiV2Error("[msi_v2_attestation] Attestation token pointer is NULL") + + token = ctypes.string_at(token_ptr.value).decode("utf-8", errors="replace") + return token + finally: + try: + if token_ptr.value: + lib.FreeAttestationToken(token_ptr) + finally: + try: + lib.UninitAttestationLib() + except Exception: + pass diff --git a/msi-v2-sample.spec b/msi-v2-sample.spec new file mode 100644 index 00000000..65ba9781 --- /dev/null +++ b/msi-v2-sample.spec @@ -0,0 +1,45 @@ +# -*- mode: python ; coding: utf-8 -*- +from PyInstaller.utils.hooks import collect_all + +datas = [] +binaries = [] +hiddenimports = ['requests'] +tmp_ret = collect_all('cryptography') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] + + +a = Analysis( + ['run_msi_v2_once.py'], + pathex=[], + binaries=binaries, + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, + optimize=0, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.datas, + [], + name='msi-v2-sample', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/run_msi_v2_once.py b/run_msi_v2_once.py new file mode 100644 index 00000000..49165d3f --- /dev/null +++ b/run_msi_v2_once.py @@ -0,0 +1,56 @@ +""" +MSI v2 (mTLS PoP + KeyGuard Attestation) minimal sample for MSAL Python. + +Behavior: +- Requests mtls_pop + attestation +- STRICT: succeeds only if token_type == mtls_pop +- Prints ONLY: "token received" +- No resource call +""" + +import json +import os +import sys +import msal +import requests + + +DEFAULT_RESOURCE = "https://graph.microsoft.com" +RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + token_cache=msal.TokenCache(), +) + + +def acquire_mtls_pop_token_strict(): + result = client.acquire_token_for_client( + resource=RESOURCE, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + + if "access_token" not in result: + print("FAIL: token acquisition failed") + return 2 + + token_type = result.get("token_type", "mtls_pop") + print("SUCCESS: token acquired") + print(" resource =", RESOURCE) + print(" is_mtls_pop =", token_type == "mtls_pop") + + # Minimal proof we got a real JWT-ish token (don't print it) + at = result["access_token"] + print(" token_len =", len(at)) + + +if __name__ == "__main__": + try: + acquire_mtls_pop_token_strict() + print("token received") + sys.exit(0) + except Exception as ex: + print("FAIL:", ex) + sys.exit(2) diff --git a/sample/msi_v2_sample.py b/sample/msi_v2_sample.py new file mode 100644 index 00000000..b17d9978 --- /dev/null +++ b/sample/msi_v2_sample.py @@ -0,0 +1,175 @@ +""" +MSI v2 (mTLS PoP + KeyGuard Attestation) sample for MSAL Python. + +This sample requests an *attested*, certificate-bound access token (token_type=mtls_pop) +using the IMDSv2 /issuecredential endpoint and ESTS mTLS token endpoint. + +Key points (based on our E2E debugging): +- Use a resource that supports certificate-bound tokens. In this environment, Graph mTLS test + resource is supported; ARM typically is NOT (AADSTS392196). +- Run in strict mode: if mtls_pop + attestation is requested, we fail if we receive Bearer. +- Designed for Windows Azure VM where Credential Guard (VBS/KeyGuard) is available and + AttestationClientLib.dll is present. + +Environment variables: +- RESOURCE: defaults to https://mtlstb.graph.microsoft.com +- ENDPOINT: optional URL to call after acquiring token (e.g., Graph mTLS test endpoint) +- VERBOSE_LOGGING: "1"/"true" enables debug logs +- ATTESTATION_CLIENTLIB_PATH: optional absolute path to AttestationClientLib.dll (recommended) +- PYTHONNET_RUNTIME: must be "coreclr" for CSR OtherRequestAttributes (if using pythonnet path) + +Usage: + set RESOURCE=https://mtlstb.graph.microsoft.com + set ENDPOINT=https://mtlstb.graph.microsoft.com/v1.0/applications?$top=1 + python msi_v2_sample.py +""" + +import json +import logging +import os +import sys +import time + +import msal +import requests + + +# ------------------------- Logging ------------------------- + +def _truthy(s: str) -> bool: + return (s or "").strip().lower() in ("1", "true", "yes", "y", "on") + +if _truthy(os.getenv("VERBOSE_LOGGING", "")): + logging.basicConfig(level=logging.DEBUG) + logging.getLogger("msal").setLevel(logging.DEBUG) + +log = logging.getLogger("msi_v2_sample") + + +# ------------------------- Defaults ------------------------- + +# Graph mTLS test resource (known-good for mtls_pop in your environment) +DEFAULT_RESOURCE = "https://mtlstb.graph.microsoft.com" + +# ARM will often fail for mtls_pop with AADSTS392196 +ARM_RESOURCE = "https://management.azure.com/" + +RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") +ENDPOINT = os.getenv("ENDPOINT", "").strip() + +# Token cache is optional; keep it simple for E2E +token_cache = msal.TokenCache() + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + token_cache=token_cache, +) + + +# ------------------------- Helpers ------------------------- + +def _print_env_hints(): + if RESOURCE.lower().startswith(ARM_RESOURCE): + print("NOTE: RESOURCE is ARM. mtls_pop usually fails for ARM with AADSTS392196.") + print(f" Try: set RESOURCE={DEFAULT_RESOURCE}") + + if sys.platform != "win32": + print("NOTE: This sample is designed for Windows KeyGuard + attestation.") + + +def _call_endpoint_bearer(endpoint: str, token_type: str, access_token: str): + """ + Simple HTTP call using Authorization header. + NOTE: If the *resource* requires client cert at the resource layer too, this may not work. + For your current E2E, token acquisition is the primary goal. + """ + headers = {"Authorization": f"{token_type} {access_token}", "Accept": "application/json"} + r = requests.get(endpoint, headers=headers, timeout=30) + try: + return r.status_code, r.headers, r.json() + except Exception: + return r.status_code, r.headers, r.text + + +# ------------------------- Main flow ------------------------- + +def acquire_mtls_pop_token_strict(): + """ + Acquire MSI v2 token in STRICT mode: + - We request mtls_proof_of_possession=True and with_attestation_support=True + - If we don't get token_type=mtls_pop, treat as failure + """ + result = client.acquire_token_for_client( + resource=RESOURCE, + mtls_proof_of_possession=True, # MSI v2 path + with_attestation_support=True, # KeyGuard attestation required for your scenario + ) + + if "access_token" not in result: + raise RuntimeError(f"Token acquisition failed: {json.dumps(result, indent=2)}") + + token_type = (result.get("token_type") or "Bearer").lower() + if token_type != "mtls_pop": + # In strict mode, bearer is a failure + raise RuntimeError( + "Strict MSI v2 requested, but got non-mtls_pop token.\n" + f"token_type={result.get('token_type')}\n" + "This usually means MSI v2 failed or you requested a resource that doesn't support " + "certificate-bound tokens.\n" + f"Try RESOURCE={DEFAULT_RESOURCE}\n" + f"Full result: {json.dumps(result, indent=2)}" + ) + + return result + + +def main_once(): + _print_env_hints() + + # For pythonnet-based CSR attribute support, coreclr is required. + # If you're running via pythonnet and hit OtherRequestAttributes issues, set: + # set PYTHONNET_RUNTIME=coreclr + if os.getenv("PYTHONNET_RUNTIME"): + log.debug("PYTHONNET_RUNTIME=%s", os.getenv("PYTHONNET_RUNTIME")) + + print("Requesting MSI v2 token (mtls_pop + attestation)...") + result = acquire_mtls_pop_token_strict() + + print("SUCCESS: token acquired") + print(" resource =", RESOURCE) + print(" token_len =", len(result["access_token"])) + + if ENDPOINT: + print("\nCalling ENDPOINT (best-effort using Authorization header):") + status, headers, body = _call_endpoint_bearer( + ENDPOINT, result.get("token_type", "mtls_pop"), result["access_token"] + ) + print(" status =", status) + # Print a small response preview + if isinstance(body, (dict, list)): + print(json.dumps(body, indent=2)[:2000]) + else: + print(str(body)[:2000]) + + +if __name__ == "__main__": + # Run once by default (simpler for debugging) + # Set LOOP=1 if you want repeated calls + loop = _truthy(os.getenv("LOOP", "")) + if not loop: + try: + main_once() + sys.exit(0) + except Exception as ex: + print("FAIL:", ex) + sys.exit(2) + + # Optional loop mode + while True: + try: + main_once() + except Exception as ex: + print("FAIL:", ex) + print("Sleeping 10 seconds... (Ctrl-C to stop)") + time.sleep(10) diff --git a/tests/test_msi_v2.py b/tests/test_msi_v2.py new file mode 100644 index 00000000..6785399f --- /dev/null +++ b/tests/test_msi_v2.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +"""Tests for MSI v2 (mTLS PoP) implementation. + +Goals: +- Provide strong unit coverage without depending on pythonnet / KeyGuard / real IMDS. +- Avoid importing optional helpers that may not exist in the KeyGuard implementation. +- Validate: + * x5t#S256 helper correctness (local) + * verify_cnf_binding behavior (msal.msi_v2) + * ManagedIdentityClient strict gating behavior (msi v2 invoked only when explicitly requested) + * Optional IMDSv2 wire-contract helpers when present (skipped if not exposed) +""" + +import base64 +import datetime +import hashlib +import json +import os +import unittest + +try: + from unittest.mock import patch, MagicMock +except ImportError: + from mock import patch, MagicMock + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import msal +from msal import MsiV2Error + + +# Import only stable surface from msal.msi_v2 +from msal.msi_v2 import verify_cnf_binding + +# MinimalResponse is used in other test modules; safe to reuse here +from tests.test_throttled_http_client import MinimalResponse + + +# --------------------------------------------------------------------------- +# Local helpers (do not rely on msal.msi_v2 exporting these) +# --------------------------------------------------------------------------- + +def _make_self_signed_cert(private_key, common_name="test"): + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]) + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=1)) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """x5t#S256 = base64url(SHA256(der(cert))) without padding.""" + cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + return base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()).rstrip(b"=").decode("ascii") + + +def _b64url(s: bytes) -> str: + return base64.urlsafe_b64encode(s).rstrip(b"=").decode("ascii") + + +def _make_jwt(payload_obj, header_obj=None) -> str: + header_obj = header_obj or {"alg": "RS256", "typ": "JWT"} + header = _b64url(json.dumps(header_obj, separators=(",", ":")).encode("utf-8")) + payload = _b64url(json.dumps(payload_obj, separators=(",", ":")).encode("utf-8")) + sig = _b64url(b"sig") + return f"{header}.{payload}.{sig}" + + +# --------------------------------------------------------------------------- +# Thumbprint helper +# --------------------------------------------------------------------------- + +class TestThumbprintHelper(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "thumbprint-test") + + def test_returns_base64url_no_padding(self): + thumb = get_cert_thumbprint_sha256(self.cert_pem) + self.assertIsInstance(thumb, str) + self.assertNotIn("=", thumb) + + decoded = base64.urlsafe_b64decode(thumb + "==") + self.assertEqual(len(decoded), 32) + + def test_same_cert_same_thumbprint(self): + t1 = get_cert_thumbprint_sha256(self.cert_pem) + t2 = get_cert_thumbprint_sha256(self.cert_pem) + self.assertEqual(t1, t2) + + def test_different_certs_different_thumbprints(self): + key2 = rsa.generate_private_key(public_exponent=65537, key_size=2048) + cert2_pem = _make_self_signed_cert(key2, "thumbprint-test-2") + self.assertNotEqual(get_cert_thumbprint_sha256(self.cert_pem), + get_cert_thumbprint_sha256(cert2_pem)) + + def test_matches_manual_sha256_der(self): + cert = x509.load_pem_x509_certificate(self.cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + expected = base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()).rstrip(b"=").decode("ascii") + self.assertEqual(get_cert_thumbprint_sha256(self.cert_pem), expected) + + +# --------------------------------------------------------------------------- +# verify_cnf_binding (more coverage) +# --------------------------------------------------------------------------- + +class TestVerifyCnfBinding(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "cnf-test") + self.thumbprint = get_cert_thumbprint_sha256(self.cert_pem) + + def test_valid_binding_true(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_wrong_thumbprint_false(self): + token = _make_jwt({"cnf": {"x5t#S256": "wrong"}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_cnf_false(self): + token = _make_jwt({"sub": "nobody"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_x5t_false(self): + token = _make_jwt({"cnf": {}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_cnf_not_object_false(self): + token = _make_jwt({"cnf": "not-an-object"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_not_a_jwt_false(self): + self.assertFalse(verify_cnf_binding("notajwt", self.cert_pem)) + + def test_two_part_jwt_false(self): + token = "a.b" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_four_part_jwt_false(self): + token = "a.b.c.d" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_malformed_payload_base64_false(self): + token = "header.!!!.sig" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_payload_not_json_false(self): + header = _b64url(b'{"alg":"none"}') + payload = _b64url(b"not-json") + token = f"{header}.{payload}.sig" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_payload_with_padding_still_works(self): + # Create payload base64 with explicit padding (library should tolerate) + header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').decode("ascii") # includes padding sometimes + payload = base64.urlsafe_b64encode(json.dumps({"cnf": {"x5t#S256": self.thumbprint}}).encode("utf-8")).decode("ascii") + token = f"{header}.{payload}.sig" + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_unicode_in_payload_does_not_break(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}, "msg": "こんにちは"}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + +# --------------------------------------------------------------------------- +# ManagedIdentityClient gating + strict behavior (better coverage) +# --------------------------------------------------------------------------- + +class TestManagedIdentityClientStrictGating(unittest.TestCase): + def _make_client(self): + import requests + return msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + ) + + def test_error_is_exported(self): + self.assertIs(msal.MsiV2Error, MsiV2Error) + + def test_error_is_subclass(self): + self.assertTrue(issubclass(MsiV2Error, msal.ManagedIdentityError)) + + @patch("msal.managed_identity._obtain_token") + def test_default_path_calls_v1(self, mock_v1): + mock_v1.return_value = {"access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R") + self.assertEqual(res["access_token"], "V1") + mock_v1.assert_called_once() + + def test_attestation_requires_pop(self): + client = self._make_client() + with self.assertRaises(msal.ManagedIdentityError): + client.acquire_token_for_client(resource="R", + mtls_proof_of_possession=False, + with_attestation_support=True) + + @patch("msal.msi_v2.obtain_token") + @patch("msal.managed_identity._obtain_token") + def test_pop_without_attestation_does_not_call_v2(self, mock_v1, mock_v2): + # If your implementation requires BOTH flags, v2 must not run here. + mock_v1.return_value = {"access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R", + mtls_proof_of_possession=True, + with_attestation_support=False) + # depending on your design this could either raise or fall back to v1. + # If you changed to "v2 only when both flags", it should use v1. + self.assertEqual(res["token_type"], "Bearer") + mock_v2.assert_not_called() + mock_v1.assert_called_once() + + @patch("msal.msi_v2.obtain_token") + def test_v2_called_when_both_flags_true(self, mock_v2): + mock_v2.return_value = {"access_token": "V2", "expires_in": 3600, "token_type": "mtls_pop"} + client = self._make_client() + res = client.acquire_token_for_client(resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertEqual(res["token_type"], "mtls_pop") + mock_v2.assert_called_once() + # Ensure v2 called with expected signature (resource argument passed through) + args, kwargs = mock_v2.call_args + # obtain_token(http_client, managed_identity, resource, attestation_enabled=...) + self.assertTrue(len(args) >= 3) + self.assertEqual(args[2], "https://mtlstb.graph.microsoft.com") + self.assertIn("attestation_enabled", kwargs) + self.assertTrue(kwargs["attestation_enabled"]) + + @patch("msal.msi_v2.obtain_token", side_effect=MsiV2Error("boom")) + @patch("msal.managed_identity._obtain_token") + def test_strict_v2_failure_raises_no_v1_fallback(self, mock_v1, mock_v2): + client = self._make_client() + with self.assertRaises(MsiV2Error): + client.acquire_token_for_client(resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + mock_v1.assert_not_called() + + +# --------------------------------------------------------------------------- +# Optional: wire contract helper tests (skip if helpers not present) +# --------------------------------------------------------------------------- + +class TestImdsV2OptionalHelpers(unittest.TestCase): + def test_mi_query_params_adds_version_and_uami_selector(self): + if not hasattr(msal.msi_v2, "_mi_query_params"): + self.skipTest("msal.msi_v2._mi_query_params not exposed") + + p = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ClientId", "Id": "abc"}) + self.assertIn("cred-api-version", p) + self.assertEqual(p["cred-api-version"], "2.0") + self.assertEqual(p.get("client_id"), "abc") + + p2 = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ObjectId", "Id": "oid"}) + self.assertEqual(p2.get("object_id"), "oid") + + p3 = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ResourceId", "Id": "/sub/..."}) + self.assertEqual(p3.get("msi_res_id"), "/sub/...") + + def test_issuecredential_body_uses_attestation_token(self): + if not hasattr(msal.msi_v2, "_imds_post_json"): + self.skipTest("msal.msi_v2._imds_post_json not exposed") + if not hasattr(msal.msi_v2, "_issue_credential"): + self.skipTest("msal.msi_v2._issue_credential not exposed") + + http_client = MagicMock() + http_client.post.return_value = MinimalResponse( + status_code=200, + text=json.dumps({ + "certificate": "Zg==", + "mtls_authentication_endpoint": "https://login", + "tenant_id": "t", + "client_id": "c", + }), + ) + + msal.msi_v2._issue_credential( + http_client, + managed_identity={"ManagedIdentityIdType": "SystemAssigned", "Id": None}, + csr_b64="QUJD", + attestation_jwt="fake.jwt", + ) + + _, kwargs = http_client.post.call_args + body = json.loads(kwargs["data"]) + self.assertEqual(body["csr"], "QUJD") + self.assertEqual(body["attestation_token"], "fake.jwt") + + def test_token_endpoint_derived_from_mtls_auth_endpoint(self): + if not hasattr(msal.msi_v2, "_token_endpoint_from_credential"): + self.skipTest("msal.msi_v2._token_endpoint_from_credential not exposed") + + cred = { + "mtls_authentication_endpoint": "https://login.example.com", + "tenant_id": "tenant123", + } + ep = msal.msi_v2._token_endpoint_from_credential(cred) + self.assertTrue(ep.endswith("/tenant123/oauth2/v2.0/token")) + + +if __name__ == "__main__": + unittest.main() From f69f1424e0233f8defd60e0d2682d5e97fa34547 Mon Sep 17 00:00:00 2001 From: Gladwin Johnson Date: Tue, 24 Feb 2026 18:34:16 -0800 Subject: [PATCH 3/5] cachedmaa --- msal/msi_v2.py | 1002 +++++++++++------------------------- msal/msi_v2_attestation.py | 185 +++++-- run_msi_v2_once.py | 56 -- working/run_msi_v2_once.py | 56 ++ 4 files changed, 503 insertions(+), 796 deletions(-) delete mode 100644 run_msi_v2_once.py create mode 100644 working/run_msi_v2_once.py diff --git a/msal/msi_v2.py b/msal/msi_v2.py index 92016350..e6db2b0e 100644 --- a/msal/msi_v2.py +++ b/msal/msi_v2.py @@ -5,40 +5,22 @@ """ MSI v2 (IMDSv2) Managed Identity flow — Windows KeyGuard + Attestation + SChannel mTLS PoP. -This module implements the end-to-end "MSI v2" flow used by Azure Managed Identity on Windows -when *certificate-bound* access tokens are requested (token_type=mtls_pop). - -It is intentionally "Python-only": no pythonnet/.NET interop is required. Instead, it uses -ctypes to call a small set of Windows APIs: - - * CNG/NCrypt (ncrypt.dll) - Create a KeyGuard/VBS isolated RSA key + sign CSR (RSA-PSS/SHA256) - * Crypt32 (crypt32.dll) - Bind the issued certificate to the CNG private key - * WinHTTP (winhttp.dll) - Perform the token request over mTLS using SChannel - -Flow summary (mirrors your working PowerShell implementation): - - 1) GET /metadata/identity/getplatformmetadata?cred-api-version=2.0 - 2) Create KeyGuard RSA key (non-exportable, VBS-isolated) - 3) Build CSR signed with RSA-PSS/SHA256 and include a special CSR attribute: - OID 1.3.6.1.4.1.311.90.2.10 -> DER UTF8String(JSON(cuId)) - 4) Get an attestation JWT for the key (via .msi_v2_attestation.get_attestation_jwt) - 5) POST /metadata/identity/issuecredential?cred-api-version=2.0 with csr + attestation_token - 6) POST {tenant}/oauth2/v2.0/token (mTLS) with the issued certificate and token_type=mtls_pop - -Important design choices: - - * Windows-only: importing on non-Windows platforms is supported, but calling obtain_token() - will raise MsiV2Error. - * No MSI v1 fallback: any failure raises MsiV2Error. - * Defensive certificate-to-key binding: we set multiple certificate context properties so - WinHTTP/SChannel can consistently locate the private key. - -Security notes: - - * Access tokens are secrets. Avoid logging or printing them in production. - * The KeyGuard RSA key is created as persisted, but is deleted in cleanup. - -Public entrypoint: obtain_token(http_client, managed_identity, resource, attestation_enabled=True) +This module implements the MSI v2 token acquisition path using Windows native APIs via ctypes: + - CNG/NCrypt: create/open a KeyGuard-protected per-boot RSA key (non-exportable) + - Minimal DER/PKCS#10: build a CSR signed with RSA-PSS/SHA256 + - IMDS: call getplatformmetadata + issuecredential + - Crypt32: bind the issued certificate to the CNG private key + - WinHTTP/SChannel: acquire access token over mTLS (token_type=mtls_pop) + +Key behavior: + - Uses a *named per-boot key*: opens the key if it already exists for this boot; otherwise creates it. + - No MSI v1 fallback: any MSI v2 failure raises MsiV2Error. + - Production-ready handle management: all WinHTTP / Crypt32 / NCrypt handles are released. + +Environment variables (optional): + - AZURE_POD_IDENTITY_AUTHORITY_HOST: override IMDS base URL (default http://169.254.169.254) + - MSAL_MSI_V2_KEY_NAME: override the per-boot key name (otherwise derived from metadata clientId) + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" to disable MAA JWT caching (implemented in msi_v2_attestation.py) """ from __future__ import annotations @@ -50,20 +32,10 @@ import os import sys import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, List logger = logging.getLogger(__name__) -__all__ = [ - "get_cert_thumbprint_sha256", - "verify_cnf_binding", - "obtain_token", -] - -# -------------------------------------------------------------------------------------- -# IMDS / MSI v2 constants -# -------------------------------------------------------------------------------------- - _IMDS_DEFAULT_BASE = "http://169.254.169.254" _IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" @@ -74,42 +46,35 @@ _ISSUE_CREDENTIAL_PATH = "/metadata/identity/issuecredential" _ACQUIRE_ENTRA_TOKEN_PATH = "/oauth2/v2.0/token" -# OID for the special CSR request attribute carrying cuId JSON. _CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" -# Flags from ncrypt.h used by the PowerShell reference implementation. +# ncrypt.h flags _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 -_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 _RSA_KEY_SIZE = 2048 -# Legacy KeySpec values (CAPI compatibility / CNG interop). -# Used by NCryptCreatePersistedKey.dwLegacyKeySpec and by CRYPT_KEY_PROV_INFO.dwKeySpec. +# Legacy KeySpec values used with NCryptCreatePersistedKey.dwLegacyKeySpec and +# CRYPT_KEY_PROV_INFO.dwKeySpec when dwProvType==0 (CNG/KSP). _AT_KEYEXCHANGE = 1 _AT_SIGNATURE = 2 -# CRYPT_KEY_PROV_INFO.dwFlags for CNG keys (best-effort suppression of UI prompts). +# Flags used by CRYPT_KEY_PROV_INFO.dwFlags for CNG keys (best-effort: no UI prompts) _NCRYPT_SILENT_FLAG = 0x40 -_DEFAULT_KSP_NAME = "Microsoft Software Key Storage Provider" +_KEY_NAME_ENVVAR = "MSAL_MSI_V2_KEY_NAME" -# -------------------------------------------------------------------------------------- -# Compatibility helpers (optional; useful for tests or debugging) -# -------------------------------------------------------------------------------------- +# ---------------------------- +# Compatibility helpers (tests + cross-language parity) +# ---------------------------- def get_cert_thumbprint_sha256(cert_pem: str) -> str: """ - Compute base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 comparisons. - - Accepts a PEM-encoded certificate string. - - Returns: - Base64url-encoded SHA-256 thumbprint without '=' padding, or "" if cryptography is - unavailable or parsing fails. + Return base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 comparisons. + Accepts a PEM certificate string. """ try: - # cryptography is optional; keep this helper lightweight. from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -119,23 +84,13 @@ def get_cert_thumbprint_sha256(cert_pem: str) -> str: digest = hashlib.sha256(der).digest() return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") except Exception: - # Fail closed: if we cannot compute the thumbprint, binding verification cannot succeed. + # If cryptography isn't available, fail closed (binding cannot be verified) return "" def verify_cnf_binding(token: str, cert_pem: str) -> bool: """ - Verify that a JWT payload contains cnf.x5t#S256 matching the cert thumbprint. - - This is a *best-effort* helper for validating certificate binding in tests. It does not - validate JWT signature or claims (aud/iss/exp/etc). - - Args: - token: A JWT access token (3-part base64url string). - cert_pem: PEM certificate string. - - Returns: - True if cnf.x5t#S256 exists and equals the SHA-256 certificate thumbprint. + Verify that JWT payload contains cnf.x5t#S256 matching the cert thumbprint. """ try: parts = token.split(".") @@ -160,39 +115,23 @@ def verify_cnf_binding(token: str, cert_pem: str) -> bool: return False -# -------------------------------------------------------------------------------------- +# ---------------------------- # IMDS helpers -# -------------------------------------------------------------------------------------- - +# ---------------------------- def _imds_base() -> str: - """Resolve IMDS base URI (supports pod identity override via env var).""" return os.getenv(_IMDS_BASE_ENVVAR, _IMDS_DEFAULT_BASE).strip().rstrip("/") def _new_correlation_id() -> str: - """Generate an RFC 4122 correlation id used in x-ms-client-request-id.""" return str(uuid.uuid4()) def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: - """ - Headers required by IMDS. The Metadata=true header is mandatory. - - We also include x-ms-client-request-id to correlate IMDS and ESTS requests. - """ - return { - "Metadata": "true", - "x-ms-client-request-id": correlation_id or _new_correlation_id(), - } + return {"Metadata": "true", "x-ms-client-request-id": correlation_id or _new_correlation_id()} def _resource_to_scope(resource_or_scope: str) -> str: - """ - Convert an ADAL-style 'resource' string into an MSAL v2 scope string. - - IMDS v2 uses MSAL v2 token endpoint semantics (scope=.../.default). - """ s = (resource_or_scope or "").strip() if not s: raise ValueError("resource must be non-empty") @@ -202,11 +141,7 @@ def _resource_to_scope(resource_or_scope: str) -> str: def _der_utf8string(value: str) -> bytes: - """ - Minimal DER UTF8String encoder (tag 0x0C). - - Used for the CSR request attribute value (cuId JSON) and for X.500 CN when applicable. - """ + """DER UTF8String encoder (tag 0x0C).""" raw = value.encode("utf-8") n = len(raw) if n < 0x80: @@ -222,52 +157,34 @@ def _der_utf8string(value: str) -> bytes: def _json_loads(text: str, what: str) -> Dict[str, Any]: - """Parse JSON or raise MsiV2Error with context.""" from .managed_identity import MsiV2Error - try: - data = json.loads(text) - if isinstance(data, dict): - return data - raise MsiV2Error(f"[msi_v2] Expected JSON object from {what}, got {type(data).__name__}") + obj = json.loads(text) + if not isinstance(obj, dict): + raise TypeError("expected JSON object") + return obj except Exception as exc: # pylint: disable=broad-except raise MsiV2Error(f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: - """ - Fetch the first non-empty string value among several possible keys. - - IMDS field casing can vary (camelCase vs snake_case), so this helper checks: - 1) Exact keys - 2) Case-insensitive matches - """ - # direct keys for n in names: if n in obj and obj[n] is not None and str(obj[n]).strip() != "": return str(obj[n]) - - # case-insensitive lower = {str(k).lower(): k for k in obj.keys()} for n in names: k = lower.get(n.lower()) if k and obj[k] is not None and str(obj[k]).strip() != "": return str(obj[k]) - return None def _mi_query_params(managed_identity: Optional[Dict[str, Any]]) -> Dict[str, str]: """ - Build IMDS query parameters: - * cred-api-version=2.0 (required) - * optional user-assigned identity selectors - - managed_identity shape (MSAL Python): - {"ManagedIdentityIdType": "ClientId"|"ObjectId"|"ResourceId", "Id": ""} + Adds cred-api-version=2.0 plus optional UAMI selector params. + managed_identity shape (MSAL python): {"ManagedIdentityIdType": "...", "Id": "..."} """ params: Dict[str, str] = {_API_VERSION_QUERY_PARAM: _IMDS_V2_API_VERSION} - if not isinstance(managed_identity, dict): return params @@ -278,91 +195,55 @@ def _mi_query_params(managed_identity: Optional[Dict[str, Any]]) -> Dict[str, st wire = mapping.get(id_type) if wire and identifier: params[wire] = str(identifier) - return params def _imds_get_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str]) -> Dict[str, Any]: - """ - GET JSON from IMDS with a basic 'Server' header sanity check. - - Note: The "server: IMDS/..." header check is a defense-in-depth measure to reduce the - chance of SSRF misuse. Keep it strict unless you have a concrete reason to loosen it. - """ from .managed_identity import MsiV2Error - resp = http_client.get(url, params=params, headers=headers) - server = (resp.headers or {}).get("server", "") if "imds" not in str(server).lower(): raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") - if resp.status_code != 200: raise MsiV2Error(f"[msi_v2] IMDSv2 GET {url} failed: HTTP {resp.status_code}: {resp.text}") - return _json_loads(resp.text, f"GET {url}") -def _imds_post_json( - http_client, url: str, params: Dict[str, str], headers: Dict[str, str], body: Dict[str, Any] -) -> Dict[str, Any]: - """POST JSON to IMDS and return JSON response (with same header sanity check).""" +def _imds_post_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str], body: Dict[str, Any]) -> Dict[str, Any]: from .managed_identity import MsiV2Error - resp = http_client.post(url, params=params, headers=headers, data=json.dumps(body, separators=(",", ":"))) - server = (resp.headers or {}).get("server", "") if "imds" not in str(server).lower(): raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") - if resp.status_code != 200: raise MsiV2Error(f"[msi_v2] IMDSv2 POST {url} failed: HTTP {resp.status_code}: {resp.text}") - return _json_loads(resp.text, f"POST {url}") def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: - """ - Determine the token endpoint returned by /issuecredential. - - IMDS can return either: - * token_endpoint - * mtls_authentication_endpoint + tenant_id (compose into {mtls_auth}/{tenant}/oauth2/v2.0/token) - """ token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") if token_endpoint: return token_endpoint - mtls_auth = _get_first( - cred, - "mtls_authentication_endpoint", - "mtlsAuthenticationEndpoint", - "mtls_authenticationEndpoint", - ) + mtls_auth = _get_first(cred, "mtls_authentication_endpoint", "mtlsAuthenticationEndpoint", "mtls_authenticationEndpoint") tenant_id = _get_first(cred, "tenant_id", "tenantId") if not mtls_auth or not tenant_id: from .managed_identity import MsiV2Error - raise MsiV2Error(f"[msi_v2] issuecredential missing mtls_authentication_endpoint/tenant_id: {cred}") base = mtls_auth.rstrip("/") + "/" + tenant_id.strip("/") return base + _ACQUIRE_ENTRA_TOKEN_PATH -# -------------------------------------------------------------------------------------- -# Win32 primitives (ctypes) - lazy loaded -# -------------------------------------------------------------------------------------- +# ---------------------------- +# Win32 primitives (ctypes) +# ---------------------------- _WIN32: Optional[Dict[str, Any]] = None def _load_win32() -> Dict[str, Any]: - """ - Lazy-load Win32 APIs via ctypes. - - Keeping the import behind a function allows importing this module on non-Windows platforms - without failing at import time. The public obtain_token() function enforces Windows-only. - """ + """Lazy-load Win32 APIs via ctypes (safe to import on non-Windows).""" global _WIN32 from .managed_identity import MsiV2Error @@ -376,18 +257,15 @@ def _load_win32() -> Dict[str, Any]: import ctypes from ctypes import wintypes - # DLLs (use_last_error makes ctypes.get_last_error() reliable for BOOL-returning APIs) ncrypt = ctypes.WinDLL("ncrypt.dll") crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) - kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) - # --- Types --- + # Types NCRYPT_PROV_HANDLE = ctypes.c_void_p NCRYPT_KEY_HANDLE = ctypes.c_void_p - SECURITY_STATUS = ctypes.c_long # LONG / NTSTATUS style + SECURITY_STATUS = ctypes.c_long - # Crypt32 certificate context class CERT_CONTEXT(ctypes.Structure): _fields_ = [ ("dwCertEncodingType", wintypes.DWORD), @@ -399,36 +277,29 @@ class CERT_CONTEXT(ctypes.Structure): PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) - # Padding info for NCryptSignHash (RSA-PSS) class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): _fields_ = [ ("pszAlgId", ctypes.c_wchar_p), ("cbSalt", wintypes.ULONG), ] - # --- Constants (subset) --- + # Constants ERROR_SUCCESS = 0 - # ncrypt.h flags + # NCRYPT constants NCRYPT_OVERWRITE_KEY_FLAG = 0x00000080 - - # key properties (ncrypt.h) NCRYPT_LENGTH_PROPERTY = "Length" NCRYPT_EXPORT_POLICY_PROPERTY = "Export Policy" NCRYPT_KEY_USAGE_PROPERTY = "Key Usage" - - # key usage flags (ncrypt.h) NCRYPT_ALLOW_SIGNING_FLAG = 0x00000002 NCRYPT_ALLOW_DECRYPT_FLAG = 0x00000001 - - # bcrypt.h / padding BCRYPT_PAD_PSS = 0x00000008 BCRYPT_SHA256_ALGORITHM = "SHA256" BCRYPT_RSA_ALGORITHM = "RSA" BCRYPT_RSAPUBLIC_BLOB = "RSAPUBLICBLOB" BCRYPT_RSAPUBLIC_MAGIC = 0x31415352 # 'RSA1' - # wincrypt.h + # Crypt32 constants X509_ASN_ENCODING = 0x00000001 PKCS_7_ASN_ENCODING = 0x00010000 CERT_NCRYPT_KEY_HANDLE_PROP_ID = 78 @@ -442,74 +313,48 @@ class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): WINHTTP_QUERY_STATUS_CODE = 19 WINHTTP_QUERY_FLAG_NUMBER = 0x20000000 - # --- Function prototypes (argtypes/restype) --- - # NCrypt + # NCrypt prototypes ncrypt.NCryptOpenStorageProvider.argtypes = [ctypes.POINTER(NCRYPT_PROV_HANDLE), ctypes.c_wchar_p, wintypes.DWORD] ncrypt.NCryptOpenStorageProvider.restype = SECURITY_STATUS + ncrypt.NCryptOpenKey.argtypes = [ + NCRYPT_PROV_HANDLE, + ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, # key name + wintypes.DWORD, # legacy keyspec (AT_SIGNATURE/AT_KEYEXCHANGE) + wintypes.DWORD, # flags + ] + ncrypt.NCryptOpenKey.restype = SECURITY_STATUS + ncrypt.NCryptCreatePersistedKey.argtypes = [ NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), ctypes.c_wchar_p, # alg id ctypes.c_wchar_p, # key name - wintypes.DWORD, # legacy keyspec - wintypes.DWORD, # flags + wintypes.DWORD, # legacy keyspec + wintypes.DWORD, # flags ] ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS - ncrypt.NCryptSetProperty.argtypes = [ - ctypes.c_void_p, - ctypes.c_wchar_p, - ctypes.c_void_p, - wintypes.DWORD, - wintypes.DWORD, - ] + ncrypt.NCryptSetProperty.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, wintypes.DWORD] ncrypt.NCryptSetProperty.restype = SECURITY_STATUS ncrypt.NCryptFinalizeKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] ncrypt.NCryptFinalizeKey.restype = SECURITY_STATUS - ncrypt.NCryptGetProperty.argtypes = [ - ctypes.c_void_p, - ctypes.c_wchar_p, - ctypes.c_void_p, - wintypes.DWORD, - ctypes.POINTER(wintypes.DWORD), - wintypes.DWORD, - ] + ncrypt.NCryptGetProperty.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] ncrypt.NCryptGetProperty.restype = SECURITY_STATUS - ncrypt.NCryptExportKey.argtypes = [ - NCRYPT_KEY_HANDLE, - NCRYPT_KEY_HANDLE, - ctypes.c_wchar_p, - ctypes.c_void_p, - ctypes.c_void_p, - wintypes.DWORD, - ctypes.POINTER(wintypes.DWORD), - wintypes.DWORD, - ] + ncrypt.NCryptExportKey.argtypes = [NCRYPT_KEY_HANDLE, NCRYPT_KEY_HANDLE, ctypes.c_wchar_p, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] ncrypt.NCryptExportKey.restype = SECURITY_STATUS - ncrypt.NCryptSignHash.argtypes = [ - NCRYPT_KEY_HANDLE, - ctypes.c_void_p, # padding info - ctypes.c_void_p, # hash bytes - wintypes.DWORD, # hash len - ctypes.c_void_p, # sig out - wintypes.DWORD, # sig out len - ctypes.POINTER(wintypes.DWORD), - wintypes.DWORD, # flags - ] + ncrypt.NCryptSignHash.argtypes = [NCRYPT_KEY_HANDLE, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] ncrypt.NCryptSignHash.restype = SECURITY_STATUS - ncrypt.NCryptDeleteKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] - ncrypt.NCryptDeleteKey.restype = SECURITY_STATUS - ncrypt.NCryptFreeObject.argtypes = [ctypes.c_void_p] ncrypt.NCryptFreeObject.restype = SECURITY_STATUS - # Crypt32 + # Crypt32 prototypes crypt32.CertCreateCertificateContext.argtypes = [wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] crypt32.CertCreateCertificateContext.restype = PCCERT_CONTEXT @@ -519,55 +364,26 @@ class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] crypt32.CertFreeCertificateContext.restype = wintypes.BOOL - # WinHTTP - winhttp.WinHttpOpen.argtypes = [ - ctypes.c_wchar_p, - wintypes.DWORD, - ctypes.c_wchar_p, - ctypes.c_wchar_p, - wintypes.DWORD, - ] + # WinHTTP prototypes + winhttp.WinHttpOpen.argtypes = [ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_wchar_p, wintypes.DWORD] winhttp.WinHttpOpen.restype = ctypes.c_void_p winhttp.WinHttpConnect.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] winhttp.WinHttpConnect.restype = ctypes.c_void_p - winhttp.WinHttpOpenRequest.argtypes = [ - ctypes.c_void_p, - ctypes.c_wchar_p, - ctypes.c_wchar_p, - ctypes.c_wchar_p, - ctypes.c_wchar_p, - ctypes.c_void_p, - wintypes.DWORD, - ] + winhttp.WinHttpOpenRequest.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD] winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p winhttp.WinHttpSetOption.argtypes = [ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] winhttp.WinHttpSetOption.restype = wintypes.BOOL - winhttp.WinHttpSendRequest.argtypes = [ - ctypes.c_void_p, - ctypes.c_wchar_p, - wintypes.DWORD, - ctypes.c_void_p, - wintypes.DWORD, - wintypes.DWORD, - ctypes.c_ulonglong, # context - ] + winhttp.WinHttpSendRequest.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD, wintypes.DWORD, ctypes.c_ulonglong] winhttp.WinHttpSendRequest.restype = wintypes.BOOL winhttp.WinHttpReceiveResponse.argtypes = [ctypes.c_void_p, ctypes.c_void_p] winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL - winhttp.WinHttpQueryHeaders.argtypes = [ - ctypes.c_void_p, - wintypes.DWORD, - ctypes.c_wchar_p, - ctypes.c_void_p, - ctypes.POINTER(wintypes.DWORD), - ctypes.POINTER(wintypes.DWORD), - ] + winhttp.WinHttpQueryHeaders.argtypes = [ctypes.c_void_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD), ctypes.POINTER(wintypes.DWORD)] winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL winhttp.WinHttpQueryDataAvailable.argtypes = [ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] @@ -579,25 +395,18 @@ class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] winhttp.WinHttpCloseHandle.restype = wintypes.BOOL - # Kernel32 - kernel32.GetLastError.argtypes = [] - kernel32.GetLastError.restype = wintypes.DWORD - _WIN32 = { "ctypes": ctypes, "wintypes": wintypes, "ncrypt": ncrypt, "crypt32": crypt32, "winhttp": winhttp, - "kernel32": kernel32, - # types "NCRYPT_PROV_HANDLE": NCRYPT_PROV_HANDLE, "NCRYPT_KEY_HANDLE": NCRYPT_KEY_HANDLE, "SECURITY_STATUS": SECURITY_STATUS, "CERT_CONTEXT": CERT_CONTEXT, "PCCERT_CONTEXT": PCCERT_CONTEXT, "BCRYPT_PSS_PADDING_INFO": BCRYPT_PSS_PADDING_INFO, - # constants "ERROR_SUCCESS": ERROR_SUCCESS, "NCRYPT_OVERWRITE_KEY_FLAG": NCRYPT_OVERWRITE_KEY_FLAG, "NCRYPT_LENGTH_PROPERTY": NCRYPT_LENGTH_PROPERTY, @@ -624,54 +433,49 @@ class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): return _WIN32 -def _format_win32_error(ctypes_mod, code: int) -> str: - """Format a Win32 error code into a human-readable string (best-effort).""" - try: - return ctypes_mod.FormatError(code).strip() - except Exception: - return "" - - def _raise_win32_last_error(msg: str) -> None: - """ - Raise MsiV2Error with the current Win32 last-error code. - - Use for WinHTTP/Crypt32 APIs where failure is indicated via BOOL/NULL and details are in GetLastError(). - """ + """Raise MsiV2Error with the current Win32 last-error code.""" from .managed_identity import MsiV2Error - win32 = _load_win32() ctypes_mod = win32["ctypes"] err = ctypes_mod.get_last_error() - detail = _format_win32_error(ctypes_mod, err) + detail = "" + try: + detail = ctypes_mod.FormatError(err).strip() + except Exception: + pass if detail: raise MsiV2Error(f"{msg} (winerror={err} {detail})") raise MsiV2Error(f"{msg} (winerror={err})") def _check_security_status(status: int, what: str) -> None: - """ - Check SECURITY_STATUS/NTSTATUS-style return codes from NCrypt. - - Most NCrypt APIs return 0 for success; otherwise they return a status code (often an NTSTATUS). - """ + """Check SECURITY_STATUS return codes from NCrypt (0 == success).""" from .managed_identity import MsiV2Error - if int(status) != 0: code_u32 = int(status) & 0xFFFFFFFF raise MsiV2Error(f"[msi_v2] {what} failed: status=0x{code_u32:08X}") -# -------------------------------------------------------------------------------------- -# DER helpers (minimal PKCS#10 CSR builder) -# -------------------------------------------------------------------------------------- +def _status_u32(status: int) -> int: + return int(status) & 0xFFFFFFFF + -# This is a minimal DER encoder sufficient for: -# * PKCS#10 CertificationRequestInfo (subject, spki, attributes) -# * RSASSA-PSS AlgorithmIdentifier params -# It intentionally avoids general ASN.1 frameworks to keep dependencies low. +# Common NCRYPT "not found" style statuses (NTE_*). +_NTE_BAD_KEYSET = 0x80090016 +_NTE_NO_KEY = 0x8009000D +_NTE_NOT_FOUND = 0x80090011 +_NTE_EXISTS = 0x8009000F +def _is_key_not_found(status: int) -> bool: + return _status_u32(status) in (_NTE_BAD_KEYSET, _NTE_NO_KEY, _NTE_NOT_FOUND) + + +# ---------------------------- +# DER helpers (minimal PKCS#10 CSR builder) +# ---------------------------- + def _der_len(n: int) -> bytes: if n < 0: raise ValueError("DER length cannot be negative") @@ -716,7 +520,6 @@ def _der_oid(oid: str) -> bytes: for p in parts[2:]: if p < 0: raise ValueError(f"Invalid OID component: {oid}") - # base-128 encoding stack = bytearray() if p == 0: stack.append(0) @@ -736,14 +539,12 @@ def _der_sequence(*items: bytes) -> bytes: def _der_set(*items: bytes) -> bytes: - # DER SET requires elements to be sorted by their full DER encoding. - enc = sorted(items) + enc = sorted(items) # DER SET requires sorting by full encoding return _der(0x31, b"".join(enc)) def _der_bitstring(data: bytes) -> bytes: - # 0 unused bits - return _der(0x03, b"\x00" + data) + return _der(0x03, b"\x00" + data) # 0 unused bits def _der_ia5string(value: str) -> bytes: @@ -752,31 +553,16 @@ def _der_ia5string(value: str) -> bytes: def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: - if not 0 <= tagnum <= 30: - raise ValueError("Unsupported tag number") return _der(0xA0 + tagnum, inner) def _der_context_implicit_constructed(tagnum: int, inner_content: bytes) -> bytes: - """ - Context-specific IMPLICIT, constructed. - - Used for PKCS#10 attributes: - attributes [0] IMPLICIT SET OF Attribute - Where we encode the SET OF's contents without the SET tag (0x31). - """ - if not 0 <= tagnum <= 30: - raise ValueError("Unsupported tag number") + """Context-specific IMPLICIT, constructed (CSR attributes use [0] IMPLICIT).""" return _der(0xA0 + tagnum, inner_content) def _der_name_cn_dc(cn: str, dc: str) -> bytes: - """ - Encode X.500 Name with CN and DC RDNs. - - CN (2.5.4.3) is encoded as UTF8String. - DC (0.9.2342.19200300.100.1.25) is usually IA5String (ASCII), else UTF8String. - """ + """Encode X.500 Name with CN and DC RDNs (CN UTF8String, DC IA5String if ASCII).""" cn_atv = _der_sequence(_der_oid("2.5.4.3"), _der_utf8string(cn)) cn_rdn = _der_set(cn_atv) @@ -787,7 +573,6 @@ def _der_name_cn_dc(cn: str, dc: str) -> bytes: dc_atv = _der_sequence(_der_oid("0.9.2342.19200300.100.1.25"), dc_value) dc_rdn = _der_set(dc_atv) - # RDNSequence is a SEQUENCE of SETs return _der_sequence(cn_rdn, dc_rdn) @@ -798,16 +583,11 @@ def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: def _der_algid_rsapss_sha256() -> bytes: - """ - AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), saltLength=32, trailerField=1. - - This matches what .NET / PowerShell emits for the working flow. - """ + """AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), saltLength=32.""" sha256 = _der_sequence(_der_oid("2.16.840.1.101.3.4.2.1"), _der_null()) mgf1 = _der_sequence(_der_oid("1.2.840.113549.1.1.8"), sha256) salt_len = _der_integer(32) trailer = _der_integer(1) - params = _der_sequence( _der_context_explicit(0, sha256), _der_context_explicit(1, mgf1), @@ -817,149 +597,118 @@ def _der_algid_rsapss_sha256() -> bytes: return _der_sequence(_der_oid("1.2.840.113549.1.1.10"), params) -# -------------------------------------------------------------------------------------- +# ---------------------------- # CNG/NCrypt wrappers -# -------------------------------------------------------------------------------------- - +# ---------------------------- -def _ncrypt_get_property(win32: Dict[str, Any], handle: Any, name: str) -> bytes: - """Get an NCrypt property value as raw bytes.""" +def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: ctypes_mod = win32["ctypes"] wintypes = win32["wintypes"] ncrypt = win32["ncrypt"] cb = wintypes.DWORD(0) - - status = ncrypt.NCryptGetProperty(handle, name, None, 0, ctypes_mod.byref(cb), 0) + status = ncrypt.NCryptGetProperty(h, name, None, 0, ctypes_mod.byref(cb), 0) if int(status) != 0 and cb.value == 0: _check_security_status(status, f"NCryptGetProperty({name})") - if cb.value == 0: return b"" - buf = (ctypes_mod.c_ubyte * cb.value)() - status = ncrypt.NCryptGetProperty(handle, name, buf, cb.value, ctypes_mod.byref(cb), 0) + status = ncrypt.NCryptGetProperty(h, name, buf, cb.value, ctypes_mod.byref(cb), 0) _check_security_status(status, f"NCryptGetProperty({name})") - return bytes(buf[: cb.value]) -def _create_keyguard_rsa_key(win32: Dict[str, Any]) -> Tuple[Any, Any, str]: - """ - Create a non-exportable RSA key protected with VBS/KeyGuard. +def _stable_key_name(client_id: str) -> str: + # Keep the name deterministic and safe for CNG key naming. + base = (client_id or "").strip() + safe = [] + for ch in base: + if ch.isalnum() or ch in ("-", "_"): + safe.append(ch) + else: + safe.append("_") + # Max length: keep some headroom + return "MsalMsiV2Key_" + "".join(safe)[:90] - Returns: - (prov_handle, key_handle, key_name) - key_name is the persisted CNG key name (container name). WinHTTP/SChannel can require it to - re-open the key when doing client-certificate authentication. +def _open_or_create_keyguard_rsa_key(win32: Dict[str, Any], *, key_name: str) -> Tuple[Any, Any, str, bool]: """ - from .managed_identity import MsiV2Error + Open a named per-boot KeyGuard RSA key if it exists; otherwise create it. + Returns: (prov_handle, key_handle, key_name, opened_existing) + """ ctypes_mod = win32["ctypes"] wintypes = win32["wintypes"] ncrypt = win32["ncrypt"] prov = win32["NCRYPT_PROV_HANDLE"]() - status = ncrypt.NCryptOpenStorageProvider(ctypes_mod.byref(prov), _DEFAULT_KSP_NAME, 0) + status = ncrypt.NCryptOpenStorageProvider(ctypes_mod.byref(prov), "Microsoft Software Key Storage Provider", 0) _check_security_status(status, "NCryptOpenStorageProvider") key = win32["NCRYPT_KEY_HANDLE"]() - key_name = "MsalMsiV2Key_" + _new_correlation_id() - flags = win32["NCRYPT_OVERWRITE_KEY_FLAG"] | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG | _NCRYPT_USE_PER_BOOT_KEY_FLAG + # 1) Try open first + status = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), str(key_name), _AT_SIGNATURE, 0) + if int(status) == 0: + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error("[msi_v2] Virtual Iso property missing/invalid; Credential Guard likely not active.") + return prov, key, str(key_name), True + + if not _is_key_not_found(status): + _check_security_status(status, f"NCryptOpenKey({key_name})") + + # 2) Create if missing + flags = ( + win32["NCRYPT_OVERWRITE_KEY_FLAG"] + | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG + | _NCRYPT_USE_PER_BOOT_KEY_FLAG + ) - # IMPORTANT: - # When a certificate is bound to a CNG key via CERT_KEY_PROV_INFO_PROP_ID, Schannel/WinHTTP - # re-opens the key using NCryptOpenKey and passes CRYPT_KEY_PROV_INFO.dwKeySpec as the legacy - # keyspec parameter. The CRYPT_KEY_PROV_INFO docs specify that for CNG (dwProvType==0), - # dwKeySpec must be AT_SIGNATURE or AT_KEYEXCHANGE (not CERT_NCRYPT_KEY_SPEC). - # - # We therefore create the key with dwLegacyKeySpec=AT_SIGNATURE so re-open works reliably. status = ncrypt.NCryptCreatePersistedKey( prov, ctypes_mod.byref(key), win32["BCRYPT_RSA_ALGORITHM"], - key_name, + str(key_name), _AT_SIGNATURE, flags, ) - try: - _check_security_status(status, "NCryptCreatePersistedKey") - - # Length must be DWORD. - length = wintypes.DWORD(int(_RSA_KEY_SIZE)) - status = ncrypt.NCryptSetProperty( - key, - win32["NCRYPT_LENGTH_PROPERTY"], - ctypes_mod.byref(length), - ctypes_mod.sizeof(length), - 0, - ) - _check_security_status(status, "NCryptSetProperty(Length)") - - # Key usage: signing is required; decrypt doesn't hurt for TLS use-cases. - usage = wintypes.DWORD(win32["NCRYPT_ALLOW_SIGNING_FLAG"] | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) - status = ncrypt.NCryptSetProperty( - key, - win32["NCRYPT_KEY_USAGE_PROPERTY"], - ctypes_mod.byref(usage), - ctypes_mod.sizeof(usage), - 0, - ) - _check_security_status(status, "NCryptSetProperty(Key Usage)") - - # Export policy: 0 (disallow export). - export_policy = wintypes.DWORD(0) - status = ncrypt.NCryptSetProperty( - key, - win32["NCRYPT_EXPORT_POLICY_PROPERTY"], - ctypes_mod.byref(export_policy), - ctypes_mod.sizeof(export_policy), - 0, - ) - _check_security_status(status, "NCryptSetProperty(Export Policy)") + if _status_u32(status) == _NTE_EXISTS: + # Race: another thread/process created it. + status2 = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), str(key_name), _AT_SIGNATURE, 0) + _check_security_status(status2, f"NCryptOpenKey({key_name}) after exists") + return prov, key, str(key_name), True - status = ncrypt.NCryptFinalizeKey(key, 0) - _check_security_status(status, "NCryptFinalizeKey") + _check_security_status(status, "NCryptCreatePersistedKey") - # Validate Virtual Iso property (Credential Guard / VBS). Helps fail fast if KeyGuard isn't active. - try: - vi = _ncrypt_get_property(win32, key, "Virtual Iso") - if vi is None or len(vi) < 4: - raise MsiV2Error("[msi_v2] Virtual Iso property missing/invalid; Credential Guard likely not active.") - except Exception as exc: - raise MsiV2Error("[msi_v2] Virtual Iso property not available; Credential Guard likely not active.") from exc + # Set properties only when we created the key. + length = wintypes.DWORD(int(_RSA_KEY_SIZE)) + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_LENGTH_PROPERTY"], ctypes_mod.byref(length), ctypes_mod.sizeof(length), 0) + _check_security_status(status, "NCryptSetProperty(Length)") - return prov, key, key_name + usage = wintypes.DWORD(win32["NCRYPT_ALLOW_SIGNING_FLAG"] | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_KEY_USAGE_PROPERTY"], ctypes_mod.byref(usage), ctypes_mod.sizeof(usage), 0) + _check_security_status(status, "NCryptSetProperty(Key Usage)") - except Exception: - # Best-effort cleanup. The caller also cleans up in obtain_token(). - try: - if key: - ncrypt.NCryptDeleteKey(key, 0) - except Exception: - pass - try: - if key: - ncrypt.NCryptFreeObject(key) - except Exception: - pass - try: - if prov: - ncrypt.NCryptFreeObject(prov) - except Exception: - pass - raise + export_policy = wintypes.DWORD(0) # non-exportable + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_EXPORT_POLICY_PROPERTY"], ctypes_mod.byref(export_policy), ctypes_mod.sizeof(export_policy), 0) + _check_security_status(status, "NCryptSetProperty(Export Policy)") + status = ncrypt.NCryptFinalizeKey(key, 0) + _check_security_status(status, "NCryptFinalizeKey") -def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int]: - """ - Export RSA public key (modulus, exponent) from an NCrypt key handle. + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error("[msi_v2] Virtual Iso property not available; Credential Guard likely not active.") - We export as BCRYPT_RSAPUBLIC_BLOB and parse it. - """ + return prov, key, str(key_name), False + + +def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int]: + """Export RSA public key (modulus, exponent) from an NCrypt key handle.""" from .managed_identity import MsiV2Error ctypes_mod = win32["ctypes"] @@ -978,18 +727,14 @@ def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int _check_security_status(status, "NCryptExportKey(RSAPUBLICBLOB)") blob = bytes(buf[: cb.value]) - # BCRYPT_RSAKEY_BLOB header is 6 DWORDs, little-endian. if len(blob) < 24: raise MsiV2Error("[msi_v2] RSAPUBLICBLOB too small") import struct - magic, bitlen, cb_exp, cb_mod, cb_p1, cb_p2 = struct.unpack("<6I", blob[:24]) if magic != win32["BCRYPT_RSAPUBLIC_MAGIC"]: raise MsiV2Error(f"[msi_v2] RSAPUBLICBLOB magic mismatch: 0x{magic:08X}") - if cb_p1 != 0 or cb_p2 != 0: - # Public blob should have primes=0; ignore if present (defensive). logger.debug("[msi_v2] RSAPUBLICBLOB contains primes unexpectedly (ignored).") offset = 24 @@ -1003,7 +748,6 @@ def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int exponent = int.from_bytes(exp_bytes, "big") modulus = int.from_bytes(mod_bytes, "big") - # sanity: header bitlen should match modulus bit length (often does). if bitlen != modulus.bit_length(): logger.debug("[msi_v2] RSA bit length mismatch: header=%d computed=%d", bitlen, modulus.bit_length()) @@ -1011,11 +755,7 @@ def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int def _ncrypt_sign_pss_sha256(win32: Dict[str, Any], key: Any, digest: bytes) -> bytes: - """ - Sign a SHA-256 digest using RSA-PSS via NCryptSignHash. - - NCryptSignHash expects the *hash digest*, not the original message. - """ + """Sign a SHA-256 digest using RSA-PSS via NCryptSignHash.""" from .managed_identity import MsiV2Error if len(digest) != 32: @@ -1031,48 +771,25 @@ def _ncrypt_sign_pss_sha256(win32: Dict[str, Any], key: Any, digest: bytes) -> b hash_buf = (ctypes_mod.c_ubyte * len(digest)).from_buffer_copy(digest) cb_sig = wintypes.DWORD(0) - status = ncrypt.NCryptSignHash( - key, - ctypes_mod.byref(pad), - hash_buf, - len(digest), - None, - 0, - ctypes_mod.byref(cb_sig), - win32["BCRYPT_PAD_PSS"], - ) + status = ncrypt.NCryptSignHash(key, ctypes_mod.byref(pad), hash_buf, len(digest), None, 0, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) if int(status) != 0 and cb_sig.value == 0: _check_security_status(status, "NCryptSignHash(size)") if cb_sig.value == 0: raise MsiV2Error("[msi_v2] NCryptSignHash returned empty signature size") sig_buf = (ctypes_mod.c_ubyte * cb_sig.value)() - status = ncrypt.NCryptSignHash( - key, - ctypes_mod.byref(pad), - hash_buf, - len(digest), - sig_buf, - cb_sig.value, - ctypes_mod.byref(cb_sig), - win32["BCRYPT_PAD_PSS"], - ) + status = ncrypt.NCryptSignHash(key, ctypes_mod.byref(pad), hash_buf, len(digest), sig_buf, cb_sig.value, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) _check_security_status(status, "NCryptSignHash") - return bytes(sig_buf[: cb_sig.value]) -# -------------------------------------------------------------------------------------- +# ---------------------------- # CSR builder (KeyGuard key handle) -# -------------------------------------------------------------------------------------- - +# ---------------------------- def _build_csr_b64(win32: Dict[str, Any], key: Any, client_id: str, tenant_id: str, cu_id: Any) -> str: """ - Build a PKCS#10 CSR signed by the KeyGuard key (RSA-PSS/SHA256) and return base64. - - The CSR includes a request attribute (OID _CU_ID_OID_STR) whose value is: - DER UTF8String(JSON(cuId)) + Build CSR signed by KeyGuard key (RSA-PSS SHA256), including cuId request attribute. """ modulus, exponent = _ncrypt_export_rsa_public(win32, key) @@ -1082,10 +799,10 @@ def _build_csr_b64(win32: Dict[str, Any], key: Any, client_id: str, tenant_id: s cuid_json = json.dumps(cu_id, separators=(",", ":"), ensure_ascii=False) cuid_val = _der_utf8string(cuid_json) - # Attribute: SEQUENCE { OID, SET { } } + # Attribute: SEQUENCE { OID, SET { } } attr = _der_sequence(_der_oid(_CU_ID_OID_STR), _der_set(cuid_val)) - # PKCS#10 attributes: [0] IMPLICIT SET OF Attribute + # attributes [0] IMPLICIT SET OF Attribute attrs_content = b"".join(sorted([attr])) attrs = _der_context_implicit_constructed(0, attrs_content) @@ -1098,10 +815,9 @@ def _build_csr_b64(win32: Dict[str, Any], key: Any, client_id: str, tenant_id: s return base64.b64encode(csr).decode("ascii") -# -------------------------------------------------------------------------------------- +# ---------------------------- # Certificate binding + WinHTTP mTLS -# -------------------------------------------------------------------------------------- - +# ---------------------------- def _create_cert_context_with_key( win32: Dict[str, Any], @@ -1109,49 +825,34 @@ def _create_cert_context_with_key( key: Any, key_name: str, *, - ksp_name: str = _DEFAULT_KSP_NAME, -) -> Tuple[Any, Tuple[Any, ...]]: + ksp_name: str = "Microsoft Software Key Storage Provider", +) -> Tuple[Any, Any, Tuple[Any, ...]]: """ - Create a CERT_CONTEXT from DER bytes and associate it with the given CNG private key. - - Why set multiple properties? - - WinHTTP/SChannel sometimes fails to locate the private key unless the cert context contains - enough information. We set (best-effort): + Create a CERT_CONTEXT from DER bytes and associate it with a CNG private key. - * CERT_NCRYPT_KEY_HANDLE_PROP_ID (78) - direct handle binding - * CERT_KEY_CONTEXT_PROP_ID (5) - CERT_KEY_CONTEXT with hNCryptKey + CERT_NCRYPT_KEY_SPEC - * CERT_KEY_PROV_INFO_PROP_ID (2) - CRYPT_KEY_PROV_INFO referencing the persisted key name + We set multiple properties to maximize compatibility across Schannel consumers: + - CERT_NCRYPT_KEY_HANDLE_PROP_ID (78): direct handle + - CERT_KEY_CONTEXT_PROP_ID (5): CERT_KEY_CONTEXT union (best-effort) + - CERT_KEY_PROV_INFO_PROP_ID (2): CRYPT_KEY_PROV_INFO so Schannel can reopen by name - The returned keepalive tuple MUST remain referenced for as long as the CERT_CONTEXT is used, - because it contains buffers referenced by the cert properties. - - Returns: - (PCCERT_CONTEXT, keepalive) + NOTE: For CNG keys (dwProvType==0), CRYPT_KEY_PROV_INFO.dwKeySpec must be AT_SIGNATURE or AT_KEYEXCHANGE. """ - from .managed_identity import MsiV2Error - ctypes_mod = win32["ctypes"] wintypes = win32["wintypes"] crypt32 = win32["crypt32"] enc = win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"] - # Keep DER bytes alive to be safe. - cert_buf = ctypes_mod.create_string_buffer(cert_der) - ctx = crypt32.CertCreateCertificateContext(enc, cert_buf, len(cert_der)) + buf = ctypes_mod.create_string_buffer(cert_der) + ctx = crypt32.CertCreateCertificateContext(enc, buf, len(cert_der)) if not ctx: _raise_win32_last_error("[msi_v2] CertCreateCertificateContext failed") - keepalive: List[Any] = [cert_buf] + keepalive: List[Any] = [buf] try: - key_value = int(getattr(key, "value", key) or 0) - if not key_value: - raise MsiV2Error("[msi_v2] Invalid CNG key handle (0)") - - # --- (A) CERT_NCRYPT_KEY_HANDLE_PROP_ID (78) - key_handle = ctypes_mod.c_void_p(key_value) + # (A) direct handle + key_handle = ctypes_mod.c_void_p(int(key.value)) keepalive.append(key_handle) ok = crypt32.CertSetCertificateContextProperty( @@ -1163,14 +864,14 @@ def _create_cert_context_with_key( if not ok: _raise_win32_last_error("[msi_v2] CertSetCertificateContextProperty(CERT_NCRYPT_KEY_HANDLE_PROP_ID) failed") - # --- (B) CERT_KEY_CONTEXT_PROP_ID (5) - optional but helpful + # (B) CERT_KEY_CONTEXT_PROP_ID (best-effort) CERT_KEY_CONTEXT_PROP_ID = 5 - CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF # wincrypt.h: CERT_NCRYPT_KEY_SPEC + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF class CERT_KEY_CONTEXT(ctypes_mod.Structure): _fields_ = [ ("cbSize", wintypes.DWORD), - ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), # union: HCRYPTPROV / NCRYPT_KEY_HANDLE + ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), ("dwKeySpec", wintypes.DWORD), ] @@ -1184,10 +885,9 @@ class CERT_KEY_CONTEXT(ctypes_mod.Structure): ctypes_mod.byref(key_ctx), ) if not ok: - # Not fatal in all environments; keep going. logger.debug("[msi_v2] Failed to set CERT_KEY_CONTEXT_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) - # --- (C) CERT_KEY_PROV_INFO_PROP_ID (2) - allows Schannel to re-open key by name + # (C) CERT_KEY_PROV_INFO_PROP_ID (so Schannel can reopen by name) CERT_KEY_PROV_INFO_PROP_ID = 2 class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): @@ -1205,16 +905,14 @@ class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): provider_buf = ctypes_mod.create_unicode_buffer(str(ksp_name)) keepalive.extend([container_buf, provider_buf]) - # For CNG keys (dwProvType==0), dwKeySpec is passed as dwLegacyKeySpec to NCryptOpenKey. - # It must be AT_SIGNATURE or AT_KEYEXCHANGE per CRYPT_KEY_PROV_INFO docs. prov_info = CRYPT_KEY_PROV_INFO( ctypes_mod.cast(container_buf, wintypes.LPWSTR), ctypes_mod.cast(provider_buf, wintypes.LPWSTR), - wintypes.DWORD(0), # CNG + wintypes.DWORD(0), # dwProvType=0 for CNG/KSP wintypes.DWORD(_NCRYPT_SILENT_FLAG), wintypes.DWORD(0), None, - wintypes.DWORD(_AT_SIGNATURE), + wintypes.DWORD(_AT_SIGNATURE), # IMPORTANT ) keepalive.append(prov_info) @@ -1225,10 +923,9 @@ class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): ctypes_mod.byref(prov_info), ) if not ok: - # If this fails, WinHTTP may still work via CERT_NCRYPT_KEY_HANDLE_PROP_ID. logger.debug("[msi_v2] Failed to set CERT_KEY_PROV_INFO_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) - return ctx, tuple(keepalive) + return ctx, buf, tuple(keepalive) except Exception: try: @@ -1238,42 +935,25 @@ class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): raise -def _winhttp_close(win32: Dict[str, Any], handle: Any) -> None: - """Close a WinHTTP HINTERNET handle (best-effort).""" +def _winhttp_close(win32: Dict[str, Any], h: Any) -> None: try: - if handle: - win32["winhttp"].WinHttpCloseHandle(handle) + if h: + win32["winhttp"].WinHttpCloseHandle(h) except Exception: pass -def _winhttp_set_option_dword(win32: Dict[str, Any], handle: Any, option: int, value: int, *, fatal: bool = False) -> None: - """Set a WinHTTP option that takes a DWORD value.""" - ctypes_mod = win32["ctypes"] - wintypes = win32["wintypes"] - winhttp = win32["winhttp"] - - v = wintypes.DWORD(int(value)) - ok = winhttp.WinHttpSetOption(handle, option, ctypes_mod.byref(v), ctypes_mod.sizeof(v)) - if not ok and fatal: - _raise_win32_last_error(f"[msi_v2] WinHttpSetOption({option}) failed") - - def _winhttp_post(win32: Dict[str, Any], url: str, cert_ctx: Any, body: bytes, headers: Dict[str, str]) -> Tuple[int, bytes]: """ - POST bytes to an https:// URL using WinHTTP + SChannel, presenting the provided cert context. - - Returns: - (status_code, response_body_bytes) + POST bytes to https URL using WinHTTP + SChannel, presenting the provided cert context. """ from .managed_identity import MsiV2Error + from urllib.parse import urlparse ctypes_mod = win32["ctypes"] wintypes = win32["wintypes"] winhttp = win32["winhttp"] - from urllib.parse import urlparse - u = urlparse(url) if u.scheme.lower() != "https": raise MsiV2Error(f"[msi_v2] Token endpoint must be https, got: {url!r}") @@ -1286,158 +966,114 @@ def _winhttp_post(win32: Dict[str, Any], url: str, cert_ctx: Any, body: bytes, h if u.query: path += "?" + u.query - # WinHTTP uses UTF-16 wide strings. - user_agent = "msal-python-msi-v2" - - h_session = winhttp.WinHttpOpen( - user_agent, - win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], - None, - None, - 0, - ) + h_session = winhttp.WinHttpOpen("msal-python-msi-v2", win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], None, None, 0) if not h_session: _raise_win32_last_error("[msi_v2] WinHttpOpen failed") + h_connect = None + h_request = None try: - # Best-effort: ensure client cert context is honored even when HTTP/2 is negotiated. - # Not all Windows builds support this option; ignore failures. - _winhttp_set_option_dword(win32, h_session, win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], 1, fatal=False) + # Best-effort: allow WinHTTP to send client cert for HTTP/2 negotiated connections. + enable = wintypes.DWORD(1) + try: + winhttp.WinHttpSetOption(h_session, win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], ctypes_mod.byref(enable), ctypes_mod.sizeof(enable)) + except Exception: + pass h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) if not h_connect: _raise_win32_last_error("[msi_v2] WinHttpConnect failed") - try: - h_request = winhttp.WinHttpOpenRequest( - h_connect, - "POST", - path, - None, - None, - None, - win32["WINHTTP_FLAG_SECURE"], - ) - if not h_request: - _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") - try: - # Set client certificate context on request. - CertContext = win32["CERT_CONTEXT"] - ok = winhttp.WinHttpSetOption( - h_request, - win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], - cert_ctx, - ctypes_mod.sizeof(CertContext), - ) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpSetOption(WINHTTP_OPTION_CLIENT_CERT_CONTEXT) failed") - - header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) - header_str = header_lines # unicode / wide - - body_buf = ctypes_mod.create_string_buffer(body) - - ok = winhttp.WinHttpSendRequest( - h_request, - header_str, - 0xFFFFFFFF, # -1L (auto compute) - body_buf, - len(body), - len(body), - 0, - ) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") - - ok = winhttp.WinHttpReceiveResponse(h_request, None) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") - - # Query status code as DWORD. - status = wintypes.DWORD(0) - status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) - index = wintypes.DWORD(0) - ok = winhttp.WinHttpQueryHeaders( - h_request, - win32["WINHTTP_QUERY_STATUS_CODE"] | win32["WINHTTP_QUERY_FLAG_NUMBER"], - None, - ctypes_mod.byref(status), - ctypes_mod.byref(status_size), - ctypes_mod.byref(index), - ) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpQueryHeaders(WINHTTP_QUERY_STATUS_CODE) failed") - - chunks: List[bytes] = [] - while True: - avail = wintypes.DWORD(0) - ok = winhttp.WinHttpQueryDataAvailable(h_request, ctypes_mod.byref(avail)) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpQueryDataAvailable failed") - if avail.value == 0: - break - buf = (ctypes_mod.c_ubyte * avail.value)() - read = wintypes.DWORD(0) - ok = winhttp.WinHttpReadData(h_request, buf, avail.value, ctypes_mod.byref(read)) - if not ok: - _raise_win32_last_error("[msi_v2] WinHttpReadData failed") - if read.value: - chunks.append(bytes(buf[: read.value])) - if read.value == 0: - break - - return int(status.value), b"".join(chunks) - finally: - _winhttp_close(win32, h_request) - finally: - _winhttp_close(win32, h_connect) + + h_request = winhttp.WinHttpOpenRequest(h_connect, "POST", path, None, None, None, win32["WINHTTP_FLAG_SECURE"]) + if not h_request: + _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") + + # Attach cert context (mTLS). + CertContext = win32["CERT_CONTEXT"] + ok = winhttp.WinHttpSetOption(h_request, win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], cert_ctx, ctypes_mod.sizeof(CertContext)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSetOption(WINHTTP_OPTION_CLIENT_CERT_CONTEXT) failed") + + header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) + + body_buf = ctypes_mod.create_string_buffer(body) + ok = winhttp.WinHttpSendRequest(h_request, header_lines, 0xFFFFFFFF, body_buf, len(body), len(body), 0) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") + + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + + ok = winhttp.WinHttpQueryHeaders( + h_request, + win32["WINHTTP_QUERY_STATUS_CODE"] | win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, + ctypes_mod.byref(status), + ctypes_mod.byref(status_size), + ctypes_mod.byref(index), + ) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryHeaders(WINHTTP_QUERY_STATUS_CODE) failed") + + chunks: List[bytes] = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable(h_request, ctypes_mod.byref(avail)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData(h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[: read.value])) + if read.value == 0: + break + + return int(status.value), b"".join(chunks) finally: + _winhttp_close(win32, h_request) + _winhttp_close(win32, h_connect) _winhttp_close(win32, h_session) -def _acquire_token_mtls_schannel( - win32: Dict[str, Any], - token_endpoint: str, - cert_ctx: Any, - client_id: str, - scope: str, -) -> Dict[str, Any]: - """ - Acquire an mtls_pop token from ESTS using WinHTTP/SChannel with the provided cert context. - """ +def _acquire_token_mtls_schannel(win32: Dict[str, Any], token_endpoint: str, cert_ctx: Any, client_id: str, scope: str) -> Dict[str, Any]: + """Acquire an mtls_pop token from ESTS using WinHTTP/SChannel.""" from .managed_identity import MsiV2Error from urllib.parse import urlencode - form = urlencode( - { - "grant_type": "client_credentials", - "client_id": client_id, - "scope": scope, - "token_type": "mtls_pop", - } - ).encode("utf-8") + form = urlencode({ + "grant_type": "client_credentials", + "client_id": client_id, + "scope": scope, + "token_type": "mtls_pop", + }).encode("utf-8") status, resp_body = _winhttp_post( win32, token_endpoint, cert_ctx, form, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }, + headers={"Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json"}, ) text = resp_body.decode("utf-8", errors="replace") if status < 200 or status >= 300: raise MsiV2Error(f"[msi_v2] ESTS token request failed: HTTP {status} Body={text!r}") - return _json_loads(text, "ESTS token") -# -------------------------------------------------------------------------------------- +# ---------------------------- # Public API -# -------------------------------------------------------------------------------------- - +# ---------------------------- def obtain_token( http_client, @@ -1446,27 +1082,7 @@ def obtain_token( *, attestation_enabled: bool = True, ) -> Dict[str, Any]: - """ - Acquire an mtls_pop access token using Windows KeyGuard + attestation. - - Args: - http_client: - Requests-like object that provides .get() and .post() returning responses with - .status_code, .text, .headers. (MSAL passes its own session by default.) - managed_identity: - MSAL-managed identity selector dict (system-assigned or user-assigned). Used only - to set optional IMDS query params for UAMI. - resource: - Resource or scope. If it doesn't end with "/.default", we append "/.default". - attestation_enabled: - Must be True for this KeyGuard flow. If False, we fail closed. - - Returns: - Dict with access_token, expires_in, token_type, and optional resource. - - Raises: - MsiV2Error on any failure (no MSI v1 fallback). - """ + """Acquire mtls_pop token using Windows KeyGuard + MAA attestation.""" from .managed_identity import MsiV2Error win32 = _load_win32() @@ -1479,12 +1095,10 @@ def obtain_token( prov = None key = None - key_name = None cert_ctx = None - cert_keepalive: Optional[Tuple[Any, ...]] = None try: - # 1) Read platform metadata (client_id, tenant_id, cuId, attestation endpoint). + # 1) metadata meta_url = base + _CSR_METADATA_PATH meta = _imds_get_json(http_client, meta_url, params, _imds_headers(corr)) @@ -1496,44 +1110,37 @@ def obtain_token( if not client_id or not tenant_id or cu_id is None: raise MsiV2Error(f"[msi_v2] getplatformmetadata missing required fields: {meta}") - # 2) Create KeyGuard RSA key (NCrypt). - prov, key, key_name = _create_keyguard_rsa_key(win32) + # 2) Open-or-create named per-boot KeyGuard RSA key + key_name = os.getenv(_KEY_NAME_ENVVAR) or _stable_key_name(str(client_id)) + prov, key, key_name, opened = _open_or_create_keyguard_rsa_key(win32, key_name=key_name) + logger.debug("[msi_v2] KeyGuard key name=%s opened_existing=%s", key_name, opened) - # 3) CSR signed with RSA-PSS/SHA256, includes cuId request attribute. + # 3) CSR csr_b64 = _build_csr_b64(win32, key, str(client_id), str(tenant_id), cu_id) - # 4) Attestation JWT (required in this flow). + # 4) Attestation (required for KeyGuard flow in this scenario) if not attestation_enabled: raise MsiV2Error("[msi_v2] attestation_enabled must be True for this KeyGuard flow.") if not attestation_endpoint: raise MsiV2Error("[msi_v2] attestationEndpoint missing from metadata.") - key_handle_int = int(getattr(key, "value", 0) or 0) - if not key_handle_int: - raise MsiV2Error("[msi_v2] Invalid key handle for attestation") - + # Use a stable cache key so the attestation module can reuse the JWT across opens/closes. from .msi_v2_attestation import get_attestation_jwt - att_jwt = get_attestation_jwt( attestation_endpoint=str(attestation_endpoint), client_id=str(client_id), - key_handle=key_handle_int, + key_handle=int(key.value), + cache_key=f"{key_name}", # stable per boot ) if not att_jwt or not str(att_jwt).strip(): raise MsiV2Error("[msi_v2] Attestation token is missing/empty; refusing to call issuecredential.") - # 5) Exchange CSR + attestation for an issued certificate (IMDS /issuecredential). + # 5) issuecredential issue_url = base + _ISSUE_CREDENTIAL_PATH issue_headers = _imds_headers(corr) issue_headers["Content-Type"] = "application/json" - cred = _imds_post_json( - http_client, - issue_url, - params, - issue_headers, - {"csr": csr_b64, "attestation_token": att_jwt}, - ) + cred = _imds_post_json(http_client, issue_url, params, issue_headers, {"csr": csr_b64, "attestation_token": att_jwt}) cert_b64 = _get_first(cred, "certificate", "Certificate") if not cert_b64: @@ -1547,10 +1154,10 @@ def obtain_token( canonical_client_id = _get_first(cred, "client_id", "clientId") or str(client_id) token_endpoint = _token_endpoint_from_credential(cred) - # 6) Bind KeyGuard key to the issued cert and request token over mTLS (SChannel). - cert_ctx, cert_keepalive = _create_cert_context_with_key(win32, cert_der, key, str(key_name)) - + # 6) Bind key->cert then call ESTS over mTLS using SChannel + cert_ctx, _, _ = _create_cert_context_with_key(win32, cert_der, key, key_name) scope = _resource_to_scope(resource) + token_json = _acquire_token_mtls_schannel(win32, token_endpoint, cert_ctx, canonical_client_id, scope) if token_json.get("access_token") and token_json.get("expires_in"): @@ -1560,25 +1167,17 @@ def obtain_token( "token_type": token_json.get("token_type") or "mtls_pop", "resource": token_json.get("resource"), } - - # Some error shapes could still be JSON; return raw for caller to interpret. return token_json finally: - # Cleanup: cert context (WinHTTP duplicates it internally during request). + # Crypt32: free cert context (request-scoped) try: if cert_ctx: crypt32.CertFreeCertificateContext(cert_ctx) except Exception: pass - # Cleanup: key and provider handles. - # The key is persisted, so we delete it explicitly and then free handles. - try: - if key: - ncrypt.NCryptDeleteKey(key, 0) - except Exception: - pass + # NCrypt: release handles (key persists for the boot because it is named + per-boot) try: if key: ncrypt.NCryptFreeObject(key) @@ -1589,6 +1188,3 @@ def obtain_token( ncrypt.NCryptFreeObject(prov) except Exception: pass - - # keepalive is intentionally unused; it just keeps buffers alive while cert_ctx existed. - _ = cert_keepalive diff --git a/msal/msi_v2_attestation.py b/msal/msi_v2_attestation.py index d46f9338..e0fe2f7b 100644 --- a/msal/msi_v2_attestation.py +++ b/msal/msi_v2_attestation.py @@ -5,29 +5,38 @@ """ Windows attestation for MSI v2 KeyGuard keys using AttestationClientLib.dll. -Equivalent to your PowerShell / C# P/Invoke signatures: - - int InitAttestationLib(ref AttestationLogInfo info); - int AttestKeyGuardImportKey(string endpoint, string authToken, string clientPayload, - IntPtr keyHandle, out IntPtr token, string clientId); - void FreeAttestationToken(IntPtr token); - void UninitAttestationLib(); +This module calls into AttestationClientLib.dll to mint an attestation JWT for a KeyGuard key handle. +It also provides a small in-memory cache to reuse the attestation JWT until ~90% of its lifetime. + +Caching notes: + - Cache is process-local (in-memory). It does not persist across process restarts. + - Cache is keyed by (attestation_endpoint, client_id, cache_key/auth_token/payload). + Provide a stable cache_key (e.g., the named per-boot key name) to maximize hits. + - If the token cannot be parsed or has no exp claim, it is not cached. + +Env vars: + - ATTESTATION_CLIENTLIB_PATH: absolute path to AttestationClientLib.dll (optional) + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" disables caching (default enabled) """ from __future__ import annotations +import base64 import ctypes +import json import logging import os import sys +import threading +import time from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p +from dataclasses import dataclass +from typing import Optional, Tuple logger = logging.getLogger(__name__) -# keep callback alive _NATIVE_LOG_CB = None - # void LogFunc(void* ctx, const char* tag, int lvl, const char* func, int line, const char* msg); _LogFunc = ctypes.CFUNCTYPE(None, c_void_p, c_char_p, c_int, c_char_p, c_int, c_char_p) @@ -46,10 +55,13 @@ def _default_logger(ctx, tag, lvl, func, line, msg): pass +def _truthy_env(name: str, default: str = "1") -> bool: + val = os.getenv(name, default) + return (val or "").strip().lower() in ("1", "true", "yes", "y", "on") + + def _maybe_add_dll_dirs(): - """ - Make DLL resolution more reliable (especially for packaged apps). - """ + """Make DLL resolution more reliable (especially for packaged apps).""" if sys.platform != "win32": return @@ -57,29 +69,12 @@ def _maybe_add_dll_dirs(): if not add_dir: return - # exe dir - try: - exe_dir = os.path.dirname(sys.executable) - if exe_dir and os.path.isdir(exe_dir): - add_dir(exe_dir) - except Exception: - pass - - # cwd - try: - cwd = os.getcwd() - if cwd and os.path.isdir(cwd): - add_dir(cwd) - except Exception: - pass - - # module dir - try: - mod_dir = os.path.dirname(__file__) - if mod_dir and os.path.isdir(mod_dir): - add_dir(mod_dir) - except Exception: - pass + for p in (os.path.dirname(sys.executable), os.getcwd(), os.path.dirname(__file__)): + try: + if p and os.path.isdir(p): + add_dir(p) + except Exception: + pass def _load_lib(): @@ -102,6 +97,99 @@ def _load_lib(): ) from exc +def _b64url_decode(s: str) -> bytes: + s = (s or "").strip() + s += "=" * ((4 - len(s) % 4) % 4) + return base64.urlsafe_b64decode(s.encode("ascii")) + + +def _try_extract_exp_iat(jwt: str) -> Tuple[Optional[int], Optional[int]]: + """ + Extract exp and iat (Unix seconds) from a JWT without validating signature. + Returns (exp, iat). Either can be None. + """ + try: + parts = jwt.split(".") + if len(parts) < 2: + return None, None + payload = json.loads(_b64url_decode(parts[1]).decode("utf-8", errors="replace")) + if not isinstance(payload, dict): + return None, None + + def _to_int(v): + if isinstance(v, bool): + return None + if isinstance(v, int): + return v + if isinstance(v, float): + return int(v) + if isinstance(v, str) and v.strip().isdigit(): + return int(v.strip()) + return None + + exp = _to_int(payload.get("exp")) + iat = _to_int(payload.get("iat")) + return exp, iat + except Exception: + return None, None + + +@dataclass(frozen=True) +class _CacheKey: + attestation_endpoint: str + client_id: str + cache_key: str + auth_token: str + client_payload: str + + +@dataclass +class _CacheEntry: + jwt: str + exp: int + refresh_after: float # epoch seconds when we should refresh + + +_CACHE_LOCK = threading.Lock() +_CACHE: dict[_CacheKey, _CacheEntry] = {} + + +def _cache_lookup(key: _CacheKey) -> Optional[str]: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return None + + now = time.time() + with _CACHE_LOCK: + entry = _CACHE.get(key) + if not entry: + return None + if now >= entry.refresh_after or now >= entry.exp - 5: + return None + return entry.jwt + + +def _cache_store(key: _CacheKey, jwt: str) -> None: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return + + exp, iat = _try_extract_exp_iat(jwt) + if exp is None: + return + + now = int(time.time()) + issued_at = iat if iat is not None else now + lifetime = exp - issued_at + if lifetime <= 0: + return + + # Refresh at 90% of lifetime, with a small absolute guard. + refresh_after = issued_at + (0.90 * lifetime) + refresh_after = min(refresh_after, exp - 10) + + with _CACHE_LOCK: + _CACHE[key] = _CacheEntry(jwt=jwt, exp=exp, refresh_after=float(refresh_after)) + + def get_attestation_jwt( *, attestation_endpoint: str, @@ -109,9 +197,15 @@ def get_attestation_jwt( key_handle: int, auth_token: str = "", client_payload: str = "{}", + cache_key: Optional[str] = None, ) -> str: """ Returns attestation JWT string. Raises MsiV2Error on failure. + + cache_key: + - Optional stable identifier used for caching (recommended: named per-boot key name). + - If not provided, key_handle is used as part of the cache key (less cache-friendly + if the key is opened/closed between calls). """ from .managed_identity import MsiV2Error @@ -122,6 +216,19 @@ def get_attestation_jwt( if not key_handle: raise MsiV2Error("[msi_v2_attestation] key_handle must be non-zero") + stable_cache_key = cache_key if cache_key is not None else f"handle:{int(key_handle)}" + ck = _CacheKey( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + cache_key=str(stable_cache_key), + auth_token=str(auth_token or ""), + client_payload=str(client_payload or "{}"), + ) + + cached = _cache_lookup(ck) + if cached: + return cached + lib = _load_lib() lib.InitAttestationLib.argtypes = [POINTER(AttestationLogInfo)] @@ -158,8 +265,8 @@ def get_attestation_jwt( try: rc = lib.AttestKeyGuardImportKey( attestation_endpoint.encode("utf-8"), - auth_token.encode("utf-8"), - client_payload.encode("utf-8"), + (auth_token or "").encode("utf-8"), + (client_payload or "{}").encode("utf-8"), c_void_p(int(key_handle)), ctypes.byref(token_ptr), client_id.encode("utf-8"), @@ -170,6 +277,10 @@ def get_attestation_jwt( raise MsiV2Error("[msi_v2_attestation] Attestation token pointer is NULL") token = ctypes.string_at(token_ptr.value).decode("utf-8", errors="replace") + if not token or "." not in token: + raise MsiV2Error("[msi_v2_attestation] Attestation token looks malformed") + + _cache_store(ck, token) return token finally: try: diff --git a/run_msi_v2_once.py b/run_msi_v2_once.py deleted file mode 100644 index 49165d3f..00000000 --- a/run_msi_v2_once.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -MSI v2 (mTLS PoP + KeyGuard Attestation) minimal sample for MSAL Python. - -Behavior: -- Requests mtls_pop + attestation -- STRICT: succeeds only if token_type == mtls_pop -- Prints ONLY: "token received" -- No resource call -""" - -import json -import os -import sys -import msal -import requests - - -DEFAULT_RESOURCE = "https://graph.microsoft.com" -RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") - -client = msal.ManagedIdentityClient( - msal.SystemAssignedManagedIdentity(), - http_client=requests.Session(), - token_cache=msal.TokenCache(), -) - - -def acquire_mtls_pop_token_strict(): - result = client.acquire_token_for_client( - resource=RESOURCE, - mtls_proof_of_possession=True, - with_attestation_support=True, - ) - - if "access_token" not in result: - print("FAIL: token acquisition failed") - return 2 - - token_type = result.get("token_type", "mtls_pop") - print("SUCCESS: token acquired") - print(" resource =", RESOURCE) - print(" is_mtls_pop =", token_type == "mtls_pop") - - # Minimal proof we got a real JWT-ish token (don't print it) - at = result["access_token"] - print(" token_len =", len(at)) - - -if __name__ == "__main__": - try: - acquire_mtls_pop_token_strict() - print("token received") - sys.exit(0) - except Exception as ex: - print("FAIL:", ex) - sys.exit(2) diff --git a/working/run_msi_v2_once.py b/working/run_msi_v2_once.py new file mode 100644 index 00000000..47ffdc7b --- /dev/null +++ b/working/run_msi_v2_once.py @@ -0,0 +1,56 @@ +import json +import os +import sys + +import msal +import requests + +DEFAULT_RESOURCE = "https://graph.microsoft.com" +RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") + +session = requests.Session() +cache = msal.TokenCache() + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=session, + token_cache=cache, +) + +def acquire_mtls_pop_token_strict(): + result = client.acquire_token_for_client( + resource=RESOURCE, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + + if "access_token" not in result: + raise RuntimeError(f"Token acquisition failed: {json.dumps(result, indent=2)}") + + token_type = (result.get("token_type") or "Bearer").lower() + if token_type != "mtls_pop": + raise RuntimeError( + f"Strict MSI v2 requested, but got token_type={result.get('token_type')}. " + f"Full result: {json.dumps(result, indent=2)}" + ) + + return result + +if __name__ == "__main__": + try: + r1 = acquire_mtls_pop_token_strict() + print("token received (1)") + + r2 = acquire_mtls_pop_token_strict() + print("token received (2)") + + # If MSAL exposes a cache indicator, print it (optional) + ts1 = r1.get("token_source") or r1.get("source") or "" + ts2 = r2.get("token_source") or r2.get("source") or "" + if ts1 or ts2: + print(f"source1={ts1} source2={ts2}") + + sys.exit(0) + except Exception as ex: + print("FAIL:", ex) + sys.exit(2) \ No newline at end of file From 6237b8bef9007743ee5b1b56f16abb133e2db509 Mon Sep 17 00:00:00 2001 From: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com> Date: Tue, 24 Feb 2026 18:37:12 -0800 Subject: [PATCH 4/5] Potential fix for code scanning alert no. 92: Clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- working/run_msi_v2_once.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/working/run_msi_v2_once.py b/working/run_msi_v2_once.py index 47ffdc7b..a34558c6 100644 --- a/working/run_msi_v2_once.py +++ b/working/run_msi_v2_once.py @@ -44,11 +44,11 @@ def acquire_mtls_pop_token_strict(): r2 = acquire_mtls_pop_token_strict() print("token received (2)") - # If MSAL exposes a cache indicator, print it (optional) + # If MSAL exposes a cache indicator, avoid printing its concrete value to logs ts1 = r1.get("token_source") or r1.get("source") or "" ts2 = r2.get("token_source") or r2.get("source") or "" if ts1 or ts2: - print(f"source1={ts1} source2={ts2}") + print("token sources are available (not logged for security)") sys.exit(0) except Exception as ex: From 78f20aa923dd75788945616e230ee06539977e2a Mon Sep 17 00:00:00 2001 From: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com> Date: Wed, 25 Feb 2026 11:18:06 -0800 Subject: [PATCH 5/5] Return certificate with token for mTLS Enhance mTLS token acquisition by returning the certificate alongside the access token. --- msal/msi_v2.py | 182 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 153 insertions(+), 29 deletions(-) diff --git a/msal/msi_v2.py b/msal/msi_v2.py index e6db2b0e..efd29c9e 100644 --- a/msal/msi_v2.py +++ b/msal/msi_v2.py @@ -16,6 +16,7 @@ - Uses a *named per-boot key*: opens the key if it already exists for this boot; otherwise creates it. - No MSI v1 fallback: any MSI v2 failure raises MsiV2Error. - Production-ready handle management: all WinHTTP / Crypt32 / NCrypt handles are released. + - Returns certificate with token for mTLS with resource. Environment variables (optional): - AZURE_POD_IDENTITY_AUTHORITY_HOST: override IMDS base URL (default http://169.254.169.254) @@ -36,6 +37,10 @@ logger = logging.getLogger(__name__) +# ---------------------------- +# IMDS constants +# ---------------------------- + _IMDS_DEFAULT_BASE = "http://169.254.169.254" _IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" @@ -48,9 +53,13 @@ _CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" -# ncrypt.h flags +# ---------------------------- +# NCrypt/CNG flags +# ---------------------------- + +# ncrypt.h flags for KeyGuard protection _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 -_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 _RSA_KEY_SIZE = 2048 @@ -64,6 +73,19 @@ _KEY_NAME_ENVVAR = "MSAL_MSI_V2_KEY_NAME" +# ---------------------------- +# Error codes +# ---------------------------- + +# Common NCRYPT "not found" style statuses (NTE_*). +_NTE_BAD_KEYSET = 0x80090016 +_NTE_NO_KEY = 0x8009000D +_NTE_NOT_FOUND = 0x80090011 +_NTE_EXISTS = 0x8009000F + +# Lazy-loaded Win32 API cache +_WIN32: Optional[Dict[str, Any]] = None + # ---------------------------- # Compatibility helpers (tests + cross-language parity) @@ -73,6 +95,12 @@ def get_cert_thumbprint_sha256(cert_pem: str) -> str: """ Return base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 comparisons. Accepts a PEM certificate string. + + Args: + cert_pem: PEM-formatted certificate string + + Returns: + Base64url-encoded SHA256 thumbprint without padding, or empty string on error """ try: from cryptography import x509 @@ -91,6 +119,13 @@ def get_cert_thumbprint_sha256(cert_pem: str) -> str: def verify_cnf_binding(token: str, cert_pem: str) -> bool: """ Verify that JWT payload contains cnf.x5t#S256 matching the cert thumbprint. + + Args: + token: Access token (JWT format) + cert_pem: PEM-formatted certificate + + Returns: + True if token's cnf.x5t#S256 claim matches certificate thumbprint """ try: parts = token.split(".") @@ -115,23 +150,51 @@ def verify_cnf_binding(token: str, cert_pem: str) -> bool: return False +def _der_to_pem(der_bytes: bytes) -> str: + """ + Convert DER certificate bytes to PEM string format. + + Args: + der_bytes: DER-encoded certificate bytes + + Returns: + PEM-formatted certificate string + """ + b64 = base64.b64encode(der_bytes).decode("ascii") + # PEM line wrapping at 64 characters + lines = [b64[i:i+64] for i in range(0, len(b64), 64)] + return "-----BEGIN CERTIFICATE-----\n" + "\n".join(lines) + "\n-----END CERTIFICATE-----" + + # ---------------------------- # IMDS helpers # ---------------------------- def _imds_base() -> str: + """Get IMDS base URL from environment or use default.""" return os.getenv(_IMDS_BASE_ENVVAR, _IMDS_DEFAULT_BASE).strip().rstrip("/") def _new_correlation_id() -> str: + """Generate a new correlation ID for request tracing.""" return str(uuid.uuid4()) def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: + """Build IMDS request headers with Metadata: true and correlation ID.""" return {"Metadata": "true", "x-ms-client-request-id": correlation_id or _new_correlation_id()} def _resource_to_scope(resource_or_scope: str) -> str: + """ + Normalize resource to scope format (append /.default if needed). + + Args: + resource_or_scope: Resource URI or scope string + + Returns: + Normalized scope string ending with /.default + """ s = (resource_or_scope or "").strip() if not s: raise ValueError("resource must be non-empty") @@ -157,17 +220,21 @@ def _der_utf8string(value: str) -> bytes: def _json_loads(text: str, what: str) -> Dict[str, Any]: + """Parse JSON with error context.""" from .managed_identity import MsiV2Error try: obj = json.loads(text) if not isinstance(obj, dict): raise TypeError("expected JSON object") return obj - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: raise MsiV2Error(f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: + """ + Get first non-empty value from object by multiple name variants (case-insensitive). + """ for n in names: if n in obj and obj[n] is not None and str(obj[n]).strip() != "": return str(obj[n]) @@ -199,6 +266,7 @@ def _mi_query_params(managed_identity: Optional[Dict[str, Any]]) -> Dict[str, st def _imds_get_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str]) -> Dict[str, Any]: + """GET request to IMDS with server verification.""" from .managed_identity import MsiV2Error resp = http_client.get(url, params=params, headers=headers) server = (resp.headers or {}).get("server", "") @@ -210,6 +278,7 @@ def _imds_get_json(http_client, url: str, params: Dict[str, str], headers: Dict[ def _imds_post_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str], body: Dict[str, Any]) -> Dict[str, Any]: + """POST request to IMDS with server verification.""" from .managed_identity import MsiV2Error resp = http_client.post(url, params=params, headers=headers, data=json.dumps(body, separators=(",", ":"))) server = (resp.headers or {}).get("server", "") @@ -221,6 +290,10 @@ def _imds_post_json(http_client, url: str, params: Dict[str, str], headers: Dict def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: + """ + Extract token endpoint from issuecredential response. + Prefers explicit token_endpoint, falls back to mtls_authentication_endpoint + tenant_id. + """ token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") if token_endpoint: return token_endpoint @@ -239,11 +312,18 @@ def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: # Win32 primitives (ctypes) # ---------------------------- -_WIN32: Optional[Dict[str, Any]] = None - - def _load_win32() -> Dict[str, Any]: - """Lazy-load Win32 APIs via ctypes (safe to import on non-Windows).""" + """ + Lazy-load Win32 APIs via ctypes (safe to import on non-Windows). + + Initializes: + - ncrypt.dll: CNG/NCrypt key management + - crypt32.dll: X.509 certificate handling + - winhttp.dll: HTTP client with SChannel support + + Returns: + Dictionary mapping API names to ctypes objects + """ global _WIN32 from .managed_identity import MsiV2Error @@ -320,19 +400,19 @@ class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): ncrypt.NCryptOpenKey.argtypes = [ NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), - ctypes.c_wchar_p, # key name - wintypes.DWORD, # legacy keyspec (AT_SIGNATURE/AT_KEYEXCHANGE) - wintypes.DWORD, # flags + ctypes.c_wchar_p, + wintypes.DWORD, + wintypes.DWORD, ] ncrypt.NCryptOpenKey.restype = SECURITY_STATUS ncrypt.NCryptCreatePersistedKey.argtypes = [ NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), - ctypes.c_wchar_p, # alg id - ctypes.c_wchar_p, # key name - wintypes.DWORD, # legacy keyspec - wintypes.DWORD, # flags + ctypes.c_wchar_p, + ctypes.c_wchar_p, + wintypes.DWORD, + wintypes.DWORD, ] ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS @@ -458,17 +538,12 @@ def _check_security_status(status: int, what: str) -> None: def _status_u32(status: int) -> int: + """Convert signed status to unsigned 32-bit value.""" return int(status) & 0xFFFFFFFF -# Common NCRYPT "not found" style statuses (NTE_*). -_NTE_BAD_KEYSET = 0x80090016 -_NTE_NO_KEY = 0x8009000D -_NTE_NOT_FOUND = 0x80090011 -_NTE_EXISTS = 0x8009000F - - def _is_key_not_found(status: int) -> bool: + """Check if status indicates key not found (NTE_*).""" return _status_u32(status) in (_NTE_BAD_KEYSET, _NTE_NO_KEY, _NTE_NOT_FOUND) @@ -477,6 +552,7 @@ def _is_key_not_found(status: int) -> bool: # ---------------------------- def _der_len(n: int) -> bytes: + """Encode DER length field.""" if n < 0: raise ValueError("DER length cannot be negative") if n < 0x80: @@ -490,14 +566,17 @@ def _der_len(n: int) -> bytes: def _der(tag: int, content: bytes) -> bytes: + """Encode DER TLV (tag-length-value).""" return bytes([tag]) + _der_len(len(content)) + content def _der_null() -> bytes: + """Encode DER NULL.""" return b"\x05\x00" def _der_integer(value: int) -> bytes: + """Encode DER INTEGER.""" if value < 0: raise ValueError("Only non-negative INTEGER supported") if value == 0: @@ -510,6 +589,7 @@ def _der_integer(value: int) -> bytes: def _der_oid(oid: str) -> bytes: + """Encode DER OID (e.g., "2.5.4.3").""" parts = [int(x) for x in oid.split(".")] if len(parts) < 2: raise ValueError(f"Invalid OID: {oid}") @@ -535,24 +615,29 @@ def _der_oid(oid: str) -> bytes: def _der_sequence(*items: bytes) -> bytes: + """Encode DER SEQUENCE.""" return _der(0x30, b"".join(items)) def _der_set(*items: bytes) -> bytes: + """Encode DER SET (sorted for canonical encoding).""" enc = sorted(items) # DER SET requires sorting by full encoding return _der(0x31, b"".join(enc)) def _der_bitstring(data: bytes) -> bytes: + """Encode DER BITSTRING (no unused bits).""" return _der(0x03, b"\x00" + data) # 0 unused bits def _der_ia5string(value: str) -> bytes: + """Encode DER IA5String (ASCII-only).""" raw = value.encode("ascii") return _der(0x16, raw) def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: + """Encode context-specific EXPLICIT tag [tagnum].""" return _der(0xA0 + tagnum, inner) @@ -577,6 +662,7 @@ def _der_name_cn_dc(cn: str, dc: str) -> bytes: def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: + """Encode SubjectPublicKeyInfo for RSA public key.""" rsa_pub = _der_sequence(_der_integer(modulus), _der_integer(exponent)) alg = _der_sequence(_der_oid("1.2.840.113549.1.1.1"), _der_null()) # rsaEncryption + NULL return _der_sequence(alg, _der_bitstring(rsa_pub)) @@ -602,6 +688,7 @@ def _der_algid_rsapss_sha256() -> bytes: # ---------------------------- def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: + """Get NCrypt property (two-pass: query size, then read).""" ctypes_mod = win32["ctypes"] wintypes = win32["wintypes"] ncrypt = win32["ncrypt"] @@ -619,7 +706,7 @@ def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: def _stable_key_name(client_id: str) -> str: - # Keep the name deterministic and safe for CNG key naming. + """Generate a stable, CNG-safe key name from client ID.""" base = (client_id or "").strip() safe = [] for ch in base: @@ -936,6 +1023,7 @@ class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): def _winhttp_close(win32: Dict[str, Any], h: Any) -> None: + """Close WinHTTP handle safely.""" try: if h: win32["winhttp"].WinHttpCloseHandle(h) @@ -1082,7 +1170,32 @@ def obtain_token( *, attestation_enabled: bool = True, ) -> Dict[str, Any]: - """Acquire mtls_pop token using Windows KeyGuard + MAA attestation.""" + """ + Acquire mtls_pop token using Windows KeyGuard + MAA attestation. + + Flow: + 1. getplatformmetadata: fetch client_id, tenant_id, cu_id, attestationEndpoint + 2. Open/create named per-boot KeyGuard RSA key (non-exportable) + 3. Build PKCS#10 CSR with cuId attribute, sign with RSA-PSS/SHA256 + 4. Get attestation JWT from MAA (or cached) + 5. issuecredential: submit CSR + attestation → get X.509 cert + 6. Create CERT_CONTEXT, bind to KeyGuard private key + 7. POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS + + Returns: + { + "access_token": "...", + "expires_in": 3600, + "token_type": "mtls_pop", + "resource": "...", + "cert_pem": "-----BEGIN CERTIFICATE-----\\n...", # For mTLS with resource + "cert_der_b64": "base64-encoded...", # DER format + "cert_thumbprint_sha256": "...", # For verification + } + + Raises: + MsiV2Error: on any failure (no fallback to MSI v1) + """ from .managed_identity import MsiV2Error win32 = _load_win32() @@ -1096,9 +1209,10 @@ def obtain_token( prov = None key = None cert_ctx = None + cert_der = None # Track for return try: - # 1) metadata + # 1) getplatformmetadata: fetch metadata (client_id, tenant_id, cu_id, attestationEndpoint) meta_url = base + _CSR_METADATA_PATH meta = _imds_get_json(http_client, meta_url, params, _imds_headers(corr)) @@ -1115,7 +1229,7 @@ def obtain_token( prov, key, key_name, opened = _open_or_create_keyguard_rsa_key(win32, key_name=key_name) logger.debug("[msi_v2] KeyGuard key name=%s opened_existing=%s", key_name, opened) - # 3) CSR + # 3) Build PKCS#10 CSR with cuId attribute, signed by KeyGuard key (RSA-PSS/SHA256) csr_b64 = _build_csr_b64(win32, key, str(client_id), str(tenant_id), cu_id) # 4) Attestation (required for KeyGuard flow in this scenario) @@ -1124,7 +1238,7 @@ def obtain_token( if not attestation_endpoint: raise MsiV2Error("[msi_v2] attestationEndpoint missing from metadata.") - # Use a stable cache key so the attestation module can reuse the JWT across opens/closes. + # Get attestation JWT from MAA (with stable cache key so it persists across opens/closes) from .msi_v2_attestation import get_attestation_jwt att_jwt = get_attestation_jwt( attestation_endpoint=str(attestation_endpoint), @@ -1135,7 +1249,7 @@ def obtain_token( if not att_jwt or not str(att_jwt).strip(): raise MsiV2Error("[msi_v2] Attestation token is missing/empty; refusing to call issuecredential.") - # 5) issuecredential + # 5) issuecredential: submit CSR + attestation JWT → get X.509 cert issue_url = base + _ISSUE_CREDENTIAL_PATH issue_headers = _imds_headers(corr) issue_headers["Content-Type"] = "application/json" @@ -1154,22 +1268,32 @@ def obtain_token( canonical_client_id = _get_first(cred, "client_id", "clientId") or str(client_id) token_endpoint = _token_endpoint_from_credential(cred) - # 6) Bind key->cert then call ESTS over mTLS using SChannel + # 6) Create CERT_CONTEXT and bind to KeyGuard private key cert_ctx, _, _ = _create_cert_context_with_key(win32, cert_der, key, key_name) scope = _resource_to_scope(resource) + # 7) POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS (client cert presentation) token_json = _acquire_token_mtls_schannel(win32, token_endpoint, cert_ctx, canonical_client_id, scope) + # Return token + certificate for mTLS use with resource if token_json.get("access_token") and token_json.get("expires_in"): + cert_pem = _der_to_pem(cert_der) + cert_thumbprint = get_cert_thumbprint_sha256(cert_pem) + return { "access_token": token_json["access_token"], "expires_in": int(token_json["expires_in"]), "token_type": token_json.get("token_type") or "mtls_pop", "resource": token_json.get("resource"), + # Certificate for mTLS with resource + "cert_pem": cert_pem, + "cert_der_b64": base64.b64encode(cert_der).decode("ascii"), + "cert_thumbprint_sha256": cert_thumbprint, } return token_json finally: + # Cleanup: release handles (key persists for the boot because it is named + per-boot) # Crypt32: free cert context (request-scoped) try: if cert_ctx: @@ -1177,7 +1301,7 @@ def obtain_token( except Exception: pass - # NCrypt: release handles (key persists for the boot because it is named + per-boot) + # NCrypt: release provider and key handles (key persists) try: if key: ncrypt.NCryptFreeObject(key)