diff --git a/src/urlscan/pro/channel.py b/src/urlscan/pro/channel.py index 6dbe06e..4a71f1e 100644 --- a/src/urlscan/pro/channel.py +++ b/src/urlscan/pro/channel.py @@ -2,13 +2,14 @@ from typing import Any -from urlscan.client import BaseClient, _compact +from urlscan.client import BaseClient from urlscan.types import ( ChannelPermissionType, ChannelTypeType, FrequencyType, WeekDaysType, ) +from urlscan.utils import _compact, _merge class Channel(BaseClient): @@ -40,6 +41,7 @@ def create( ignore_time: bool | None = None, week_days: list[WeekDaysType] | None = None, permissions: list[ChannelPermissionType] | None = None, + **kwargs: Any, ) -> dict: """Create a new channel. @@ -55,6 +57,7 @@ def create( ignore_time (bool | None, optional): Whether to ignore time constraints. Defaults to None. week_days (list[WeekDaysType] | None, optional): Days of the week alerts will be generated (Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday). Defaults to None. permissions (list[ChannelPermissionType] | None, optional): Permissions associated with this channel (team:read, team:write). Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Object containing the created channel. @@ -64,19 +67,22 @@ def create( """ channel: dict[str, Any] = _compact( - { - "type": channel_type, - "name": name, - "webhookURL": webhook_url, - "frequency": frequency, - "emailAddresses": email_addresses, - "utcTime": utc_time, - "isActive": is_active, - "isDefault": is_default, - "ignoreTime": ignore_time, - "weekDays": week_days, - "permissions": permissions, - } + _merge( + { + "type": channel_type, + "name": name, + "webhookURL": webhook_url, + "frequency": frequency, + "emailAddresses": email_addresses, + "utcTime": utc_time, + "isActive": is_active, + "isDefault": is_default, + "ignoreTime": ignore_time, + "weekDays": week_days, + "permissions": permissions, + }, + kwargs, + ) ) data = {"channel": channel} @@ -113,6 +119,7 @@ def update( ignore_time: bool | None = None, week_days: list[WeekDaysType] | None = None, permissions: list[ChannelPermissionType] | None = None, + **kwargs: Any, ) -> dict: """Update an existing channel. @@ -129,29 +136,32 @@ def update( ignore_time (bool | None, optional): Whether to ignore time constraints. Defaults to None. week_days (list[WeekDaysType] | None, optional): Days of the week alerts will be generated (Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday). Defaults to None. permissions (list[ChannelPermissionType] | None, optional): Permissions associated with this channel (team:read, team:write). Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Object containing the updated channel. - Reference: https://docs.urlscan.io/apis/urlscan-openapi/channels/channelsupdate """ channel: dict[str, Any] = _compact( - { - "type": channel_type, - "name": name, - "webhookURL": webhook_url, - "frequency": frequency, - "emailAddresses": email_addresses, - "utcTime": utc_time, - "isActive": is_active, - "isDefault": is_default, - "ignoreTime": ignore_time, - "weekDays": week_days, - "permissions": permissions, - } + _merge( + { + "type": channel_type, + "name": name, + "webhookURL": webhook_url, + "frequency": frequency, + "emailAddresses": email_addresses, + "utcTime": utc_time, + "isActive": is_active, + "isDefault": is_default, + "ignoreTime": ignore_time, + "weekDays": week_days, + "permissions": permissions, + }, + kwargs, + ) ) data = {"channel": channel} diff --git a/src/urlscan/pro/incident.py b/src/urlscan/pro/incident.py index 37e9a7c..fb90de4 100644 --- a/src/urlscan/pro/incident.py +++ b/src/urlscan/pro/incident.py @@ -8,6 +8,7 @@ ScanIntervalModeType, WatchedAttributeType, ) +from urlscan.utils import _merge class Incident(BaseClient): @@ -33,6 +34,7 @@ def create( scan_interval_after_malicious: int | None = None, incident_profile: str | None = None, expire_after: int | None = None, + **kwargs: Any, ) -> dict: """Create an incident with specific options. @@ -54,6 +56,7 @@ def create( scan_interval_after_malicious (int | None, optional): How to change the scan interval after the observable became malicious. Defaults to None. incident_profile (str | None, optional): ID of the incident profile to use when creating this incident. Defaults to None. expire_after (int | None, optional): Seconds until the incident will automatically be closed. Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Incident body. @@ -63,25 +66,28 @@ def create( """ incident: dict[str, Any] = _compact( - { - "observable": observable, - "visibility": visibility, - "channels": channels, - "scanInterval": scan_interval, - "scanIntervalMode": scan_interval_mode, - "watchedAttributes": watched_attributes, - "userAgents": user_agents, - "userAgentsPerInterval": user_agents_per_interval, - "countries": countries, - "countriesPerInterval": countries_per_interval, - "stopDelaySuspended": stop_delay_suspended, - "stopDelayInactive": stop_delay_inactive, - "stopDelayMalicious": stop_delay_malicious, - "scanIntervalAfterSuspended": scan_interval_after_suspended, - "scanIntervalAfterMalicious": scan_interval_after_malicious, - "incidentProfile": incident_profile, - "expireAfter": expire_after, - } + _merge( + { + "observable": observable, + "visibility": visibility, + "channels": channels, + "scanInterval": scan_interval, + "scanIntervalMode": scan_interval_mode, + "watchedAttributes": watched_attributes, + "userAgents": user_agents, + "userAgentsPerInterval": user_agents_per_interval, + "countries": countries, + "countriesPerInterval": countries_per_interval, + "stopDelaySuspended": stop_delay_suspended, + "stopDelayInactive": stop_delay_inactive, + "stopDelayMalicious": stop_delay_malicious, + "scanIntervalAfterSuspended": scan_interval_after_suspended, + "scanIntervalAfterMalicious": scan_interval_after_malicious, + "incidentProfile": incident_profile, + "expireAfter": expire_after, + }, + kwargs, + ) ) data = {"incident": incident} @@ -124,6 +130,7 @@ def update( scan_interval_after_malicious: int | None = None, incident_profile: str | None = None, expire_after: int | None = None, + **kwargs: Any, ) -> dict: """Update specific runtime options of the incident. @@ -146,6 +153,7 @@ def update( scan_interval_after_malicious (int | None, optional): How to change the scan interval after the observable became malicious. Defaults to None. incident_profile (str | None, optional): ID of the incident profile to use when creating this incident. Defaults to None. expire_after (int | None, optional): Seconds until the incident will automatically be closed. Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Incident body. @@ -155,25 +163,28 @@ def update( """ incident: dict[str, Any] = _compact( - { - "observable": observable, - "visibility": visibility, - "channels": channels, - "scanInterval": scan_interval, - "scanIntervalMode": scan_interval_mode, - "watchedAttributes": watched_attributes, - "userAgents": user_agents, - "userAgentsPerInterval": user_agents_per_interval, - "countries": countries, - "countriesPerInterval": countries_per_interval, - "stopDelaySuspended": stop_delay_suspended, - "stopDelayInactive": stop_delay_inactive, - "stopDelayMalicious": stop_delay_malicious, - "scanIntervalAfterSuspended": scan_interval_after_suspended, - "scanIntervalAfterMalicious": scan_interval_after_malicious, - "incidentProfile": incident_profile, - "expireAfter": expire_after, - } + _merge( + { + "observable": observable, + "visibility": visibility, + "channels": channels, + "scanInterval": scan_interval, + "scanIntervalMode": scan_interval_mode, + "watchedAttributes": watched_attributes, + "userAgents": user_agents, + "userAgentsPerInterval": user_agents_per_interval, + "countries": countries, + "countriesPerInterval": countries_per_interval, + "stopDelaySuspended": stop_delay_suspended, + "stopDelayInactive": stop_delay_inactive, + "stopDelayMalicious": stop_delay_malicious, + "scanIntervalAfterSuspended": scan_interval_after_suspended, + "scanIntervalAfterMalicious": scan_interval_after_malicious, + "incidentProfile": incident_profile, + "expireAfter": expire_after, + }, + kwargs, + ) ) data = {"incident": incident} diff --git a/src/urlscan/pro/livescan.py b/src/urlscan/pro/livescan.py index 5263adc..d838057 100644 --- a/src/urlscan/pro/livescan.py +++ b/src/urlscan/pro/livescan.py @@ -4,6 +4,7 @@ from urlscan.client import BaseClient, _compact from urlscan.types import LiveScanResourceType, VisibilityType +from urlscan.utils import _merge class LiveScan(BaseClient): @@ -32,6 +33,7 @@ def task( extra_headers: dict[str, str] | None = None, enable_features: list[str] | None = None, disable_features: list[str] | None = None, + **kwargs: Any, ) -> dict: """Task a URL to be scanned. @@ -46,6 +48,7 @@ def task( extra_headers (dict[str, str] | None, optional): Extra HTTP headers. Defaults to None. enable_features (list[str] | None, optional): Features to enable. Defaults to None. disable_features (list[str] | None, optional): Features to disable. Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Response containing the scan UUID. @@ -61,15 +64,18 @@ def task( } ) scanner: dict[str, Any] = _compact( - { - "pageTimeout": page_timeout, - "captureDelay": capture_delay, - "extraHeaders": extra_headers, - "enableFeatures": enable_features, - "disableFeatures": disable_features, - } + _merge( + { + "pageTimeout": page_timeout, + "captureDelay": capture_delay, + "extraHeaders": extra_headers, + "enableFeatures": enable_features, + "disableFeatures": disable_features, + }, + kwargs, + ) ) - data: dict[str, Any] = _compact({"task": task, "scanner": scanner}) + data: dict[str, Any] = {"task": task, "scanner": scanner} res = self._post(f"/api/v1/livescan/{scanner_id}/task/", json=data) return self._response_to_json(res) @@ -85,6 +91,7 @@ def scan( extra_headers: dict[str, str] | None = None, enable_features: list[str] | None = None, disable_features: list[str] | None = None, + **kwargs: Any, ) -> dict: """Task a URL to be scanned. The HTTP request will block until the scan has finished. @@ -97,6 +104,7 @@ def scan( extra_headers (dict[str, str] | None, optional): Extra HTTP headers. Defaults to None. enable_features (list[str] | None, optional): Features to enable. Defaults to None. disable_features (list[str] | None, optional): Features to disable. Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Response containing the scan UUID. @@ -112,15 +120,18 @@ def scan( } ) scanner: dict[str, Any] = _compact( - { - "pageTimeout": page_timeout, - "captureDelay": capture_delay, - "extraHeaders": extra_headers, - "enableFeatures": enable_features, - "disableFeatures": disable_features, - } + _merge( + { + "pageTimeout": page_timeout, + "captureDelay": capture_delay, + "extraHeaders": extra_headers, + "enableFeatures": enable_features, + "disableFeatures": disable_features, + }, + kwargs, + ) ) - data: dict[str, Any] = _compact({"task": task, "scanner": scanner}) + data: dict[str, Any] = {"task": task, "scanner": scanner} res = self._post(f"/api/v1/livescan/{scanner_id}/scan/", json=data) return self._response_to_json(res) diff --git a/src/urlscan/pro/saved_search.py b/src/urlscan/pro/saved_search.py index f449328..1f13a0c 100644 --- a/src/urlscan/pro/saved_search.py +++ b/src/urlscan/pro/saved_search.py @@ -4,6 +4,7 @@ from urlscan.client import BaseClient, _compact from urlscan.types import PermissionType, SavedSearchDataSource, TLPType +from urlscan.utils import _merge class SavedSearch(BaseClient): @@ -32,6 +33,7 @@ def create( tlp: TLPType | None = None, user_tags: list[str] | None = None, permissions: list[PermissionType] | None = None, + **kwargs: Any, ) -> dict: """Create a Saved Search. @@ -53,6 +55,7 @@ def create( permissions (list[PermissionType] | None, optional): Determine whether only other users on the same team or everyone on urlscan Pro can see the search. Valid values: "public:read", "team:read", "team:write". Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Created Saved Search object containing the search properties and unique _id. @@ -62,16 +65,19 @@ def create( """ search: dict[str, Any] = _compact( - { - "datasource": datasource, - "query": query, - "name": name, - "description": description, - "longDescription": long_description, - "tlp": tlp, - "userTags": user_tags, - "permissions": permissions, - } + _merge( + { + "datasource": datasource, + "query": query, + "name": name, + "description": description, + "longDescription": long_description, + "tlp": tlp, + "userTags": user_tags, + "permissions": permissions, + }, + kwargs, + ) ) data: dict[str, Any] = {"search": search} @@ -90,6 +96,7 @@ def update( tlp: TLPType | None = None, user_tags: list[str] | None = None, permissions: list[PermissionType] | None = None, + **kwargs: Any, ) -> dict: """Update a Saved Search. @@ -112,6 +119,7 @@ def update( permissions (list[PermissionType] | None, optional): Determine whether only other users on the same team or everyone on urlscan Pro can see the search. Valid values: "public:read", "team:read", "team:write". Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Updated Saved Search object containing the search properties and unique _id. @@ -121,16 +129,19 @@ def update( """ search: dict[str, Any] = _compact( - { - "datasource": datasource, - "query": query, - "name": name, - "description": description, - "longDescription": long_description, - "tlp": tlp, - "userTags": user_tags, - "permissions": permissions, - } + _merge( + { + "datasource": datasource, + "query": query, + "name": name, + "description": description, + "longDescription": long_description, + "tlp": tlp, + "userTags": user_tags, + "permissions": permissions, + }, + kwargs, + ) ) data: dict[str, Any] = {"search": search} diff --git a/src/urlscan/pro/subscription.py b/src/urlscan/pro/subscription.py index 9c6bee9..d49a4ff 100644 --- a/src/urlscan/pro/subscription.py +++ b/src/urlscan/pro/subscription.py @@ -11,6 +11,7 @@ SubscriptionPermissionType, WeekDaysType, ) +from urlscan.utils import _merge class Subscription(BaseClient): @@ -46,6 +47,7 @@ def create( incident_visibility: IncidentVisibilityType | None = None, incident_creation_mode: IncidentCreationModeType | None = None, incident_watch_keys: IncidentWatchKeyType | None = None, + **kwargs: Any, ) -> dict: """Create a new subscription. @@ -65,6 +67,7 @@ def create( incident_visibility (IncidentVisibilityType | None, optional): Incident visibility for this subscription ("unlisted" or "private"). Defaults to None. incident_creation_mode (IncidentCreationModeType | None, optional): Incident creation rule for this subscription ("none", "default", "always", or "ignore-if-exists"). Defaults to None. incident_watch_keys (IncidentWatchKeyType | None, optional): Source/key to watch in the incident (scans/page.url, scans/page.domain, scans/page.ip, scans/page.apexDomain, hostnames/hostname, hostnames/ip, hostnames/domain). Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Response containing the created subscription with an '_id' field. @@ -74,23 +77,26 @@ def create( """ subscription: dict[str, Any] = _compact( - { - "searchIds": search_ids, - "frequency": frequency, - "emailAddresses": email_addresses, - "name": name, - "description": description, - "isActive": is_active, - "ignoreTime": ignore_time, - "weekDays": week_days, - "permissions": permissions, - "channelIds": channel_ids, - "incidentChannelIds": incident_channel_ids, - "incidentProfileId": incident_profile_id, - "incidentVisibility": incident_visibility, - "incidentCreationMode": incident_creation_mode, - "incidentWatchKeys": incident_watch_keys, - } + _merge( + { + "searchIds": search_ids, + "frequency": frequency, + "emailAddresses": email_addresses, + "name": name, + "description": description, + "isActive": is_active, + "ignoreTime": ignore_time, + "weekDays": week_days, + "permissions": permissions, + "channelIds": channel_ids, + "incidentChannelIds": incident_channel_ids, + "incidentProfileId": incident_profile_id, + "incidentVisibility": incident_visibility, + "incidentCreationMode": incident_creation_mode, + "incidentWatchKeys": incident_watch_keys, + }, + kwargs, + ) ) data = {"subscription": subscription} @@ -116,6 +122,7 @@ def update( incident_visibility: IncidentVisibilityType | None = None, incident_creation_mode: IncidentCreationModeType | None = None, incident_watch_keys: IncidentWatchKeyType | None = None, + **kwargs: Any, ) -> dict: """Update the settings for a subscription. @@ -136,6 +143,7 @@ def update( incident_visibility (IncidentVisibilityType | None, optional): Incident visibility for this subscription ("unlisted" or "private"). Defaults to None. incident_creation_mode (IncidentCreationModeType | None, optional): Incident creation rule for this subscription ("none", "default", "always", or "ignore-if-exists"). Defaults to None. incident_watch_keys (IncidentWatchKeyType | None, optional): Source/key to watch in the incident (scans/page.url, scans/page.domain, scans/page.ip, scans/page.apexDomain, hostnames/hostname, hostnames/ip, hostnames/domain). Defaults to None. + **kwargs: Additional parameters to include in the request payload. Returns: dict: Response containing the updated subscription with an '_id' field. @@ -145,23 +153,26 @@ def update( """ subscription: dict[str, Any] = _compact( - { - "searchIds": search_ids, - "frequency": frequency, - "emailAddresses": email_addresses, - "name": name, - "description": description, - "isActive": is_active, - "ignoreTime": ignore_time, - "weekDays": week_days, - "permissions": permissions, - "channelIds": channel_ids, - "incidentChannelIds": incident_channel_ids, - "incidentProfileId": incident_profile_id, - "incidentVisibility": incident_visibility, - "incidentCreationMode": incident_creation_mode, - "incidentWatchKeys": incident_watch_keys, - } + _merge( + { + "searchIds": search_ids, + "frequency": frequency, + "emailAddresses": email_addresses, + "name": name, + "description": description, + "isActive": is_active, + "ignoreTime": ignore_time, + "weekDays": week_days, + "permissions": permissions, + "channelIds": channel_ids, + "incidentChannelIds": incident_channel_ids, + "incidentProfileId": incident_profile_id, + "incidentVisibility": incident_visibility, + "incidentCreationMode": incident_creation_mode, + "incidentWatchKeys": incident_watch_keys, + }, + kwargs, + ) ) data = {"subscription": subscription} diff --git a/src/urlscan/utils.py b/src/urlscan/utils.py index 6e39daa..6e49302 100644 --- a/src/urlscan/utils.py +++ b/src/urlscan/utils.py @@ -4,6 +4,7 @@ import gzip import os import tarfile +from typing import Any StrOrBytesPath = str | bytes | os.PathLike[str] | os.PathLike[bytes] @@ -13,6 +14,18 @@ def _compact(d: dict) -> dict: return {k: v for k, v in d.items() if v is not None} +def _merge(d: dict, kwargs: dict[str, Any]) -> dict: + """Merge a dictionary with additional key-value pairs.""" + result = d.copy() + for k, v in kwargs.items(): + if k in result: + raise ValueError(f"Recived multiple values for key: {k}") + + result[k] = v + + return result + + def parse_datetime(s: str) -> datetime.datetime: """Parse an ISO 8601 datetime string to a datetime object.""" dt = datetime.datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%fZ") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index fc8deda..c16614c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -4,7 +4,22 @@ import pytest -from urlscan.utils import extract, parse_datetime +from urlscan.utils import _merge, extract, parse_datetime + + +def test_merge(): + def inner(**kwargs): + return _merge({"a": 1, "b": 2}, kwargs) + + assert inner(c=3, d=4) == {"a": 1, "b": 2, "c": 3, "d": 4} + + +def test_merge_with_duplication(): + def inner(**kwargs): + return _merge({"a": 1, "b": 2}, kwargs) + + with pytest.raises(ValueError): + inner(b=3) @pytest.mark.parametrize(