diff --git a/judoscale/core/config.py b/judoscale/core/config.py index 40913c3..8cc1253 100644 --- a/judoscale/core/config.py +++ b/judoscale/core/config.py @@ -23,13 +23,11 @@ def is_release_instance(self): class Config(UserDict): - def __init__( - self, runtime_container: RuntimeContainer, api_base_url: str, env: Mapping - ): + def __init__(self, runtime_container: RuntimeContainer, env: Mapping): initialdata = dict( DEFAULTS, RUNTIME_CONTAINER=runtime_container, - API_BASE_URL=api_base_url, + API_BASE_URL=env.get("JUDOSCALE_URL"), ) for key in {"RQ", "CELERY"}: @@ -46,62 +44,28 @@ def __init__( @classmethod def initialize(cls, env: Mapping = os.environ): if env.get("JUDOSCALE_CONTAINER"): - return cls.for_custom(env) + container = env["JUDOSCALE_CONTAINER"] elif env.get("DYNO"): - return cls.for_heroku(env) + container = env["DYNO"] elif env.get("RENDER_INSTANCE_ID"): - return cls.for_render(env) + service_id = env.get("RENDER_SERVICE_ID") + container = env["RENDER_INSTANCE_ID"].replace(f"{service_id}-", "") elif env.get("ECS_CONTAINER_METADATA_URI"): - return cls.for_ecs(env) + container = env["ECS_CONTAINER_METADATA_URI"].split("/")[-1] elif env.get("FLY_MACHINE_ID"): - return cls.for_fly(env) + container = env["FLY_MACHINE_ID"] elif env.get("RAILWAY_REPLICA_ID"): - return cls.for_railway(env) + container = env["RAILWAY_REPLICA_ID"] else: - return cls.for_unknown(env) + container = "" - @classmethod - def for_heroku(cls, env: Mapping): - runtime_container = RuntimeContainer(env["DYNO"]) - api_base_url = env.get("JUDOSCALE_URL") - return cls(runtime_container, api_base_url, env) - - @classmethod - def for_render(cls, env: Mapping): - service_id = env.get("RENDER_SERVICE_ID") - instance = env.get("RENDER_INSTANCE_ID").replace(f"{service_id}-", "") - runtime_container = RuntimeContainer(instance) - api_base_url = env.get("JUDOSCALE_URL") or f"https://adapter.judoscale.com/api/{service_id}" - return cls(runtime_container, api_base_url, env) + config = cls(RuntimeContainer(container), env) - @classmethod - def for_ecs(cls, env: Mapping): - instance = env["ECS_CONTAINER_METADATA_URI"].split("/")[-1] - runtime_container = RuntimeContainer(instance) - api_base_url = env.get("JUDOSCALE_URL") - return cls(runtime_container, api_base_url, env) - - @classmethod - def for_fly(cls, env: Mapping): - runtime_container = RuntimeContainer(env["FLY_MACHINE_ID"]) - api_base_url = env.get("JUDOSCALE_URL") - return cls(runtime_container, api_base_url, env) + # Render legacy support: fall back to constructing URL from service ID + if not config["API_BASE_URL"] and env.get("RENDER_SERVICE_ID"): + config["API_BASE_URL"] = f"https://adapter.judoscale.com/api/{env['RENDER_SERVICE_ID']}" - @classmethod - def for_railway(cls, env: Mapping): - runtime_container = RuntimeContainer(env["RAILWAY_REPLICA_ID"]) - api_base_url = env.get("JUDOSCALE_URL") - return cls(runtime_container, api_base_url, env) - - @classmethod - def for_custom(cls, env: Mapping): - runtime_container = RuntimeContainer(env["JUDOSCALE_CONTAINER"]) - api_base_url = env.get("JUDOSCALE_URL") - return cls(runtime_container, api_base_url, env) - - @classmethod - def for_unknown(cls, env: Mapping): - return cls(RuntimeContainer(""), env.get("JUDOSCALE_URL"), env) + return config @property def is_enabled(self) -> bool: diff --git a/tests/test_config.py b/tests/test_config.py index da6b5a1..9842ff8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,20 +118,17 @@ def test_judoscale_log_level_env(self): "JUDOSCALE_LOG_LEVEL": "WARN", "JUDOSCALE_URL": "https://api.example.com", } - config = Config.for_heroku(fake_env) + config = Config.initialize(fake_env) assert config["RUNTIME_CONTAINER"] == "web.1" assert config["LOG_LEVEL"] == "WARN" assert config["API_BASE_URL"] == "https://api.example.com" def test_is_enabled(self): - config = Config(None, "", {}) - assert not config.is_enabled - - config = Config(None, None, {}) + config = Config(RuntimeContainer(""), {}) assert not config.is_enabled - config = Config(None, "https://some-url.com", {}) + config = Config(RuntimeContainer(""), {"JUDOSCALE_URL": "https://some-url.com"}) assert config.is_enabled def test_for_report(self): @@ -140,7 +137,7 @@ def test_for_report(self): "LOG_LEVEL": "WARN", "JUDOSCALE_URL": "https://api.example.com", } - config = Config.for_heroku(fake_env) + config = Config.initialize(fake_env) assert config.for_report == {"log_level": "WARN", "report_interval_seconds": 10} config.update({"LOG_LEVEL": "ERROR", "REPORT_INTERVAL_SECONDS": 20}) @@ -160,7 +157,7 @@ def test_update(self): "QUEUES": ["default", "high"], }, } - config = Config.for_heroku(fake_env) + config = Config.initialize(fake_env) assert config["API_BASE_URL"] == "https://api.example.com" assert config["RUNTIME_CONTAINER"] == "worker.1" assert config["LOG_LEVEL"] == "WARN" @@ -191,7 +188,7 @@ def test_update_lowercase_keys(self): "QUEUES": ["default", "high"], }, } - config = Config.for_heroku(fake_env) + config = Config.initialize(fake_env) assert config["LOG_LEVEL"] == "WARN" assert config["REPORT_INTERVAL_SECONDS"] == 10