Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions buffalogs/impossible_travel/management/commands/setup_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from argparse import RawTextHelpFormatter
from typing import Any, Tuple
from typing import Any, List, Tuple

from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
Expand Down Expand Up @@ -30,8 +31,30 @@ def _cast_value(val: str) -> Any:
return val


def _parse_list_values(inner: str) -> List[Any]:
"""Parse comma-separated values from a list string, handling quoted values with spaces."""
if not inner.strip():
return []

pattern = r"""
'([^']*)' | # single-quoted
"([^"]*)" | # double-quoted
([^,\[\]'"]+) # unquoted
"""

values = []
for match in re.finditer(pattern, inner, re.VERBOSE):
value = match.group(1) or match.group(2) or match.group(3)
if value is not None:
value = value.strip()
if value: # Skip empty values
values.append(_cast_value(value))

return values


def parse_field_value(item: str) -> Tuple[str, Any]:
"""Parse a string of the form FIELD=VALUE or FIELD=[val1,val2]"""
"""Parse a FIELD=VALUE string, supporting list syntax like FIELD=[val1, val2]."""
if "=" not in item:
raise CommandError(f"Invalid syntax '{item}': must be FIELD=VALUE")

Expand All @@ -40,7 +63,7 @@ def parse_field_value(item: str) -> Tuple[str, Any]:

if value.startswith("[") and value.endswith("]"):
inner = value[1:-1].strip()
parsed = [_cast_value(v) for v in inner.split(",") if v.strip()]
parsed = _parse_list_values(inner)
else:
parsed = _cast_value(value)

Expand All @@ -63,10 +86,22 @@ def create_parser(self, *args, **kwargs):
-r FIELD=VALUE Remove the specified VALUE from list values

Examples:
./manage.py setup_config -o allowed_countries=["Italy","Romania"]
./manage.py setup_config -r ignored_users=[admin]
# Override with multiple values (use quotes around the entire argument)
./manage.py setup_config -o "allowed_countries=['Italy', 'Romania', 'Germany']"

# Append multiple values to a list field
./manage.py setup_config -a "filtered_alerts_types=['New Device', 'User Risk Threshold', 'Anonymous IP Login']"

# Remove multiple values from a list field
./manage.py setup_config -r "ignored_users=['admin', 'bot', 'audit']"

# Mixed operations
./manage.py setup_config -o "allowed_countries=['Italy']" -r "ignored_users=['bot']" -a "filtered_alerts_types=['New Device', 'Impossible Travel']"

# Non-list field override
./manage.py setup_config -a alert_is_vip_only=True
./manage.py setup_config -o allowed_countries=["Italy"] -r ignored_users="bot" -r ignored_users=["audit"] -a filtered_alerts_types=["New Device"]

Note: When passing values with spaces, wrap the entire argument in quotes.

Additional options:
--set-default-values Reset all fields in Config to their default values
Expand Down Expand Up @@ -148,30 +183,29 @@ def handle(self, *args, **options):
if is_list and not isinstance(value, list):
value = [value]

# Validate values
values_to_validate = value if is_list else [value]
for val in values_to_validate:
for validator in getattr(field_obj, "validators", []):
try:
validator(val)
except ValidationError as e:
raise CommandError(f"Validation error on field '{field}' with value '{val}': {e}")

# Apply changes
# Apply changes first (before validation)
if is_list:
current = current or []
if mode == "append":
current += value
new_value = current + value
elif mode == "override":
current = value
new_value = value
elif mode == "remove":
current = [item for item in current if item not in value]
new_value = [item for item in current if item not in value]
else:
if mode != "override":
raise CommandError(f"Field '{field}' is not a list. Use --override to set its value.")
current = value
new_value = value

# Validate the final computed value
# For ArrayFields, validators expect the complete list (not individual items)
for validator in getattr(field_obj, "validators", []):
try:
validator(new_value)
except ValidationError as e:
raise CommandError(f"Validation error on field '{field}' with value '{new_value}': {e}")

setattr(config, field, current)
setattr(config, field, new_value)

config.save()
self.stdout.write(self.style.SUCCESS("Config updated successfully."))
107 changes: 106 additions & 1 deletion buffalogs/impossible_travel/tests/task/test_management_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.db.models.fields import Field
from django.test import TestCase
from impossible_travel.constants import AlertDetectionType, UserRiskScoreType
from impossible_travel.management.commands.setup_config import Command, parse_field_value
from impossible_travel.management.commands.setup_config import Command, _parse_list_values, parse_field_value
from impossible_travel.models import (
Config,
User,
Expand Down Expand Up @@ -260,6 +260,111 @@ def test_parse_field_value_numeric(self):
self.assertEqual(field_float, "vel_accepted")
self.assertEqual(value_float, 55.7)

# === Tests for _parse_list_values function (multiple values with spaces) ===

def test_parse_list_values_single_quotes_with_spaces(self):
# Testing _parse_list_values with single-quoted values containing spaces
result = _parse_list_values("'New Device', 'User Risk Threshold', 'Anonymous IP Login'")
self.assertEqual(result, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_list_values_double_quotes_with_spaces(self):
# Testing _parse_list_values with double-quoted values containing spaces
result = _parse_list_values('"New Device", "User Risk Threshold", "Anonymous IP Login"')
self.assertEqual(result, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_list_values_mixed_quotes(self):
# Testing _parse_list_values with mixed single and double quotes
result = _parse_list_values("'New Device', \"User Risk Threshold\", 'Impossible Travel'")
self.assertEqual(result, ["New Device", "User Risk Threshold", "Impossible Travel"])

def test_parse_list_values_unquoted_values(self):
# Testing _parse_list_values with unquoted values (no spaces)
result = _parse_list_values("admin, user1, user2")
self.assertEqual(result, ["admin", "user1", "user2"])

def test_parse_list_values_empty_string(self):
# Testing _parse_list_values with empty string
result = _parse_list_values("")
self.assertEqual(result, [])

def test_parse_list_values_whitespace_only(self):
# Testing _parse_list_values with whitespace only
result = _parse_list_values(" ")
self.assertEqual(result, [])

def test_parse_list_values_mixed_quoted_and_unquoted(self):
# Testing _parse_list_values with mixed quoted and unquoted values
result = _parse_list_values("'New Device', admin, \"Impossible Travel\"")
self.assertEqual(result, ["New Device", "admin", "Impossible Travel"])

# === Tests for parse_field_value with multiple values (Issue #499) ===

def test_parse_field_value_multiple_values_single_quotes(self):
# Testing parse_field_value with multiple single-quoted values containing spaces
field, value = parse_field_value("filtered_alerts_types=['New Device', 'User Risk Threshold', 'Anonymous IP Login']")
self.assertEqual(field, "filtered_alerts_types")
self.assertEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_field_value_multiple_values_double_quotes(self):
# Testing parse_field_value with multiple double-quoted values containing spaces
field, value = parse_field_value('filtered_alerts_types=["New Device", "User Risk Threshold", "Anonymous IP Login"]')
self.assertEqual(field, "filtered_alerts_types")
self.assertEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_field_value_multiple_countries(self):
# Testing parse_field_value with multiple country values
field, value = parse_field_value("allowed_countries=['Italy', 'Romania', 'Germany']")
self.assertEqual(field, "allowed_countries")
self.assertEqual(value, ["Italy", "Romania", "Germany"])

def test_parse_field_value_multiple_users_with_spaces(self):
# Testing parse_field_value with user values that could contain special chars
field, value = parse_field_value("ignored_users=['admin', 'bot', 'audit']")
self.assertEqual(field, "ignored_users")
self.assertEqual(value, ["admin", "bot", "audit"])

# === Integration tests for setup_config command with multiple values ===

def test_setup_config_append_multiple_values(self):
# Integration test: append multiple values to a list field
Config.objects.all().delete()
config = Config.objects.create(id=1, filtered_alerts_types=[])

call_command("setup_config", "-a", "filtered_alerts_types=['New Device', 'User Risk Threshold']")
config.refresh_from_db()

self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"])

def test_setup_config_override_multiple_values(self):
# Integration test: override with multiple values
Config.objects.all().delete()
config = Config.objects.create(id=1, allowed_countries=["USA"])

call_command("setup_config", "-o", "allowed_countries=['Italy', 'Romania', 'Germany']")
config.refresh_from_db()

self.assertListEqual(config.allowed_countries, ["Italy", "Romania", "Germany"])

def test_setup_config_remove_multiple_values(self):
# Integration test: remove multiple values from a list field
Config.objects.all().delete()
config = Config.objects.create(id=1, ignored_users=["admin", "bot", "audit", "system"])

call_command("setup_config", "-r", "ignored_users=['admin', 'bot']")
config.refresh_from_db()

self.assertListEqual(config.ignored_users, ["audit", "system"])

def test_setup_config_append_to_existing_values(self):
# Integration test: append multiple values to existing list
Config.objects.all().delete()
config = Config.objects.create(id=1, filtered_alerts_types=["New Device"])

call_command("setup_config", "-a", "filtered_alerts_types=['User Risk Threshold', 'Anonymous IP Login']")
config.refresh_from_db()

self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold", "Anonymous IP Login"])


class ResetUserRiskScoreCommandTests(TestCase):
def setUp(self):
Expand Down
Loading