From c907412ed3e94136f5ed41b2a49f0ebd48f96a6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:40 -0300 Subject: [PATCH 1/6] test: add tests/test_tools_v2.py --- tests/test_tools_v2.py | 1059 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1059 insertions(+) create mode 100644 tests/test_tools_v2.py diff --git a/tests/test_tools_v2.py b/tests/test_tools_v2.py new file mode 100644 index 0000000..e6bf12f --- /dev/null +++ b/tests/test_tools_v2.py @@ -0,0 +1,1059 @@ +# ============================================================================= +# MedeX - Tool System Tests (V2-aligned) +# ============================================================================= +""" +Comprehensive tests for the MedeX V2 Tool System. + +Tests cover: +- Tool models (ToolParameter, ToolDefinition, ToolCall, ToolResult, ToolError) +- Tool registry (registration, retrieval, filtering, enable/disable) +- Tool executor (execution, timeout, batch, metrics) +- Medical tools (drug interactions, dosage, labs, emergency) +- Tool service integration + +All tests use the actual V2 API signatures. +""" + +from __future__ import annotations + +import json +from uuid import uuid4 + +import pytest + +from medex.tools.executor import ToolExecutor +from medex.tools.models import ( + ParameterType, + ToolCall, + ToolCategory, + ToolDefinition, + ToolError, + ToolParameter, + ToolResult, + ToolStatus, +) +from medex.tools.registry import ( + ToolRegistry, + number_param, + string_param, +) + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_parameter() -> ToolParameter: + """Create a sample tool parameter.""" + return ToolParameter( + name="medication", + type=ParameterType.STRING, + description="Name of the medication", + required=True, + ) + + +@pytest.fixture +def sample_definition() -> ToolDefinition: + """Create a sample tool definition.""" + return ToolDefinition( + name="check_medication", + description="Check medication information", + category=ToolCategory.DRUG, + parameters=[ + ToolParameter( + name="medication", + type=ParameterType.STRING, + description="Medication name", + required=True, + ), + ToolParameter( + name="dosage", + type=ParameterType.NUMBER, + description="Dosage in mg", + required=False, + ), + ], + handler=lambda medication, dosage=None: { + "medication": medication, + "dosage": dosage, + }, + tags=["medication", "drug"], + ) + + +@pytest.fixture +def fresh_registry() -> ToolRegistry: + """Create a fresh registry for testing.""" + return ToolRegistry() + + +@pytest.fixture +def executor(fresh_registry: ToolRegistry) -> ToolExecutor: + """Create an executor with a fresh registry.""" + return ToolExecutor( + registry=fresh_registry, + max_concurrent=3, + default_timeout=5.0, + ) + + +# ============================================================================= +# ToolParameter Tests +# ============================================================================= + + +class TestToolParameter: + """Tests for ToolParameter model.""" + + def test_create_string_parameter(self): + """Test creating a string parameter.""" + param = ToolParameter( + name="patient_name", + type=ParameterType.STRING, + description="Patient's full name", + required=True, + ) + + assert param.name == "patient_name" + assert param.type == ParameterType.STRING + assert param.required is True + + def test_create_number_parameter_with_range(self): + """Test creating a number parameter with min/max.""" + param = ToolParameter( + name="age", + type=ParameterType.INTEGER, + description="Patient age", + required=True, + minimum=0, + maximum=150, + ) + + assert param.minimum == 0 + assert param.maximum == 150 + + def test_create_enum_parameter(self): + """Test creating an enum parameter.""" + param = ToolParameter( + name="severity", + type=ParameterType.STRING, + description="Severity level", + required=True, + enum=["low", "medium", "high"], + ) + + assert param.enum == ["low", "medium", "high"] + + def test_to_json_schema(self, sample_parameter: ToolParameter): + """Test conversion to JSON Schema.""" + schema = sample_parameter.to_json_schema() + + assert schema["type"] == "string" + assert schema["description"] == "Name of the medication" + + def test_to_json_schema_with_constraints(self): + """Test JSON Schema with constraints.""" + param = ToolParameter( + name="weight", + type=ParameterType.NUMBER, + description="Weight in kg", + required=True, + minimum=0.5, + maximum=500, + ) + + schema = param.to_json_schema() + + assert schema["type"] == "number" + assert schema["minimum"] == 0.5 + assert schema["maximum"] == 500 + + +# ============================================================================= +# ToolDefinition Tests +# ============================================================================= + + +class TestToolDefinition: + """Tests for ToolDefinition model.""" + + def test_create_definition(self, sample_definition: ToolDefinition): + """Test creating a tool definition.""" + assert sample_definition.name == "check_medication" + assert sample_definition.category == ToolCategory.DRUG + assert len(sample_definition.parameters) == 2 + assert sample_definition.enabled is True + + def test_to_openai_format(self, sample_definition: ToolDefinition): + """Test conversion to OpenAI function calling format.""" + openai_format = sample_definition.to_openai_format() + + assert openai_format["type"] == "function" + assert openai_format["function"]["name"] == "check_medication" + assert "parameters" in openai_format["function"] + assert openai_format["function"]["parameters"]["type"] == "object" + assert "medication" in openai_format["function"]["parameters"]["properties"] + + def test_to_anthropic_format(self, sample_definition: ToolDefinition): + """Test conversion to Anthropic tool format.""" + anthropic_format = sample_definition.to_anthropic_format() + + assert anthropic_format["name"] == "check_medication" + assert "input_schema" in anthropic_format + assert anthropic_format["input_schema"]["type"] == "object" + assert "medication" in anthropic_format["input_schema"]["properties"] + + def test_validate_arguments_valid(self, sample_definition: ToolDefinition): + """Test argument validation with valid arguments.""" + errors = sample_definition.validate_arguments({"medication": "aspirin"}) + assert len(errors) == 0 + + def test_validate_arguments_missing_required( + self, sample_definition: ToolDefinition + ): + """Test argument validation with missing required parameter.""" + errors = sample_definition.validate_arguments({}) + assert len(errors) > 0 + assert "medication" in errors[0].lower() + + +# ============================================================================= +# ToolCall Tests +# ============================================================================= + + +class TestToolCall: + """Tests for ToolCall model.""" + + def test_create_tool_call(self): + """Test creating a tool call.""" + call = ToolCall( + tool_name="check_drug_interactions", + arguments={"drugs": ["aspirin", "ibuprofen"]}, + ) + + assert call.tool_name == "check_drug_interactions" + assert call.arguments["drugs"] == ["aspirin", "ibuprofen"] + assert call.id is not None # UUID auto-generated + + def test_from_openai(self): + """Test creating from OpenAI response.""" + openai_response = { + "id": "call_123", + "function": { + "name": "calculate_dose", + "arguments": '{"drug": "amoxicillin", "weight": 70}', + }, + } + + call = ToolCall.from_openai(openai_response) + + assert call.tool_name == "calculate_dose" + assert call.arguments["drug"] == "amoxicillin" + assert call.arguments["weight"] == 70 + + def test_from_anthropic(self): + """Test creating from Anthropic response.""" + anthropic_response = { + "id": "toolu_123", + "name": "interpret_lab", + "input": {"hemoglobin": 12.5, "sex": "male"}, + } + + call = ToolCall.from_anthropic(anthropic_response) + + assert call.tool_name == "interpret_lab" + assert call.arguments["hemoglobin"] == 12.5 + + def test_to_dict(self): + """Test serialization to dict.""" + call = ToolCall( + tool_name="test_tool", + arguments={"key": "value"}, + ) + d = call.to_dict() + + assert "id" in d + assert d["tool_name"] == "test_tool" + assert d["arguments"] == {"key": "value"} + assert "created_at" in d + + def test_from_openai_invalid_json(self): + """Test handling of invalid JSON in arguments.""" + openai_response = { + "function": { + "name": "test", + "arguments": "{invalid json}", + }, + } + + call = ToolCall.from_openai(openai_response) + assert call.tool_name == "test" + assert call.arguments == {} + + +# ============================================================================= +# ToolResult Tests +# ============================================================================= + + +class TestToolResult: + """Tests for ToolResult model.""" + + def test_create_success_result(self): + """Test creating a successful result.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="calculate_dose", + status=ToolStatus.SUCCESS, + output={"dose": 500, "unit": "mg", "frequency": "q8h"}, + ) + + assert result.is_success is True + assert result.is_error is False + assert result.output["dose"] == 500 + assert result.error is None + + def test_create_error_result(self): + """Test creating an error result.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="calculate_dose", + status=ToolStatus.ERROR, + error="Invalid weight: must be positive", + error_code="VALIDATION_ERROR", + ) + + assert result.is_success is False + assert result.is_error is True + assert result.error == "Invalid weight: must be positive" + + def test_to_llm_format_success(self): + """Test formatting successful result for LLM.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="check_interactions", + status=ToolStatus.SUCCESS, + output={ + "interactions": [ + {"drug1": "warfarin", "drug2": "aspirin", "severity": "high"} + ] + }, + ) + + llm_format = result.to_llm_format() + + # to_llm_format returns a string + assert isinstance(llm_format, str) + content = json.loads(llm_format) + assert "interactions" in content + + def test_to_llm_format_error(self): + """Test formatting error result for LLM.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="failing_tool", + status=ToolStatus.ERROR, + error="Something went wrong", + ) + + llm_format = result.to_llm_format() + assert "failing_tool" in llm_format + assert "Something went wrong" in llm_format + + def test_to_dict(self): + """Test serialization to dict.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="test_tool", + status=ToolStatus.SUCCESS, + output={"ok": True}, + ) + + d = result.to_dict() + assert d["call_id"] == str(call_id) + assert d["tool_name"] == "test_tool" + assert d["status"] == "success" + assert d["output"] == {"ok": True} + + def test_to_openai_format(self): + """Test conversion to OpenAI tool result format.""" + call_id = uuid4() + result = ToolResult( + call_id=call_id, + tool_name="test_tool", + status=ToolStatus.SUCCESS, + output={"data": "value"}, + ) + + fmt = result.to_openai_format() + assert fmt["role"] == "tool" + assert fmt["tool_call_id"] == str(call_id) + assert "content" in fmt + + def test_timeout_status(self): + """Test timeout status is considered error.""" + result = ToolResult( + call_id=uuid4(), + tool_name="slow_tool", + status=ToolStatus.TIMEOUT, + ) + assert result.is_error is True + assert result.is_success is False + + def test_cancelled_status(self): + """Test cancelled status is considered error.""" + result = ToolResult( + call_id=uuid4(), + tool_name="cancelled_tool", + status=ToolStatus.CANCELLED, + ) + assert result.is_error is True + assert result.is_success is False + + +# ============================================================================= +# ToolError Tests +# ============================================================================= + + +class TestToolError: + """Tests for ToolError model.""" + + def test_create_error(self): + """Test creating a tool error.""" + error = ToolError( + code="VALIDATION_ERROR", + message="Invalid weight: must be positive", + tool_name="calculate_dose", + ) + + assert error.code == "VALIDATION_ERROR" + assert error.message == "Invalid weight: must be positive" + assert error.tool_name == "calculate_dose" + assert error.recoverable is True + + def test_to_dict(self): + """Test serialization to dict.""" + error = ToolError( + code="INTERNAL_ERROR", + message="Unexpected failure", + tool_name="broken_tool", + recoverable=False, + ) + + d = error.to_dict() + assert d["code"] == "INTERNAL_ERROR" + assert d["message"] == "Unexpected failure" + assert d["tool_name"] == "broken_tool" + assert d["recoverable"] is False + + def test_to_result(self): + """Test converting error to ToolResult.""" + call_id = uuid4() + error = ToolError( + code="TIMEOUT", + message="Tool timed out", + tool_name="slow_tool", + ) + + result = error.to_result(call_id) + assert result.call_id == call_id + assert result.tool_name == "slow_tool" + assert result.status == ToolStatus.ERROR + assert result.error == "Tool timed out" + assert result.error_code == "TIMEOUT" + + def test_error_codes(self): + """Test error code constants.""" + assert ToolError.VALIDATION_ERROR == "VALIDATION_ERROR" + assert ToolError.NOT_FOUND == "TOOL_NOT_FOUND" + assert ToolError.TIMEOUT == "EXECUTION_TIMEOUT" + assert ToolError.PERMISSION_DENIED == "PERMISSION_DENIED" + assert ToolError.RATE_LIMITED == "RATE_LIMITED" + assert ToolError.INTERNAL_ERROR == "INTERNAL_ERROR" + + +# ============================================================================= +# ToolRegistry Tests +# ============================================================================= + + +class TestToolRegistry: + """Tests for ToolRegistry.""" + + def test_register_tool(self, fresh_registry: ToolRegistry): + """Test registering a tool.""" + + async def my_tool(x: str) -> dict: + return {"result": x} + + definition = ToolDefinition( + name="my_tool", + description="A test tool", + category=ToolCategory.UTILITY, + parameters=[], + handler=my_tool, + ) + + fresh_registry.register(definition) + + assert fresh_registry.get("my_tool") is not None + + def test_get_nonexistent_tool(self, fresh_registry: ToolRegistry): + """Test getting a tool that doesn't exist.""" + result = fresh_registry.get("nonexistent") + assert result is None + + def test_get_by_category(self, fresh_registry: ToolRegistry): + """Test filtering tools by category.""" + for i, category in enumerate( + [ToolCategory.DRUG, ToolCategory.DRUG, ToolCategory.LAB] + ): + + async def handler() -> dict: + return {} + + fresh_registry.register( + ToolDefinition( + name=f"tool_{i}", + description=f"Tool {i}", + category=category, + parameters=[], + handler=handler, + ) + ) + + drug_tools = fresh_registry.get_by_category(ToolCategory.DRUG) + assert len(drug_tools) == 2 + + lab_tools = fresh_registry.get_by_category(ToolCategory.LAB) + assert len(lab_tools) == 1 + + def test_enable_disable_tool( + self, fresh_registry: ToolRegistry, sample_definition: ToolDefinition + ): + """Test enabling and disabling a tool.""" + fresh_registry.register(sample_definition) + + # Disable + assert fresh_registry.disable("check_medication") is True + tool_def = fresh_registry.get("check_medication") + assert tool_def.enabled is False + + # Enable + assert fresh_registry.enable("check_medication") is True + tool_def = fresh_registry.get("check_medication") + assert tool_def.enabled is True + + def test_unregister(self, fresh_registry: ToolRegistry): + """Test unregistering a tool.""" + + async def handler() -> dict: + return {} + + fresh_registry.register( + ToolDefinition( + name="removable", + description="Removable tool", + category=ToolCategory.UTILITY, + parameters=[], + handler=handler, + ) + ) + + assert fresh_registry.get("removable") is not None + result = fresh_registry.unregister("removable") + assert result is True + assert fresh_registry.get("removable") is None + + def test_unregister_nonexistent(self, fresh_registry: ToolRegistry): + """Test unregistering a tool that doesn't exist.""" + result = fresh_registry.unregister("nonexistent") + assert result is False + + def test_len_and_contains(self, fresh_registry: ToolRegistry): + """Test __len__ and __contains__.""" + + async def handler() -> dict: + return {} + + assert len(fresh_registry) == 0 + assert "my_tool" not in fresh_registry + + fresh_registry.register( + ToolDefinition( + name="my_tool", + description="Test", + category=ToolCategory.UTILITY, + parameters=[], + handler=handler, + ) + ) + + assert len(fresh_registry) == 1 + assert "my_tool" in fresh_registry + + def test_summary(self, fresh_registry: ToolRegistry, sample_definition): + """Test registry summary.""" + fresh_registry.register(sample_definition) + summary = fresh_registry.summary() + + assert "total_tools" in summary + assert summary["total_tools"] == 1 + + def test_to_openai_format( + self, fresh_registry: ToolRegistry, sample_definition: ToolDefinition + ): + """Test converting all tools to OpenAI format.""" + fresh_registry.register(sample_definition) + openai_tools = fresh_registry.to_openai_format() + + assert isinstance(openai_tools, list) + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + + def test_to_anthropic_format( + self, fresh_registry: ToolRegistry, sample_definition: ToolDefinition + ): + """Test converting all tools to Anthropic format.""" + fresh_registry.register(sample_definition) + anthropic_tools = fresh_registry.to_anthropic_format() + + assert isinstance(anthropic_tools, list) + assert len(anthropic_tools) == 1 + assert "input_schema" in anthropic_tools[0] + + +# ============================================================================= +# ToolExecutor Tests +# ============================================================================= + + +class TestToolExecutor: + """Tests for ToolExecutor.""" + + @pytest.mark.asyncio + async def test_execute_simple_tool(self, executor: ToolExecutor): + """Test executing a simple tool.""" + + async def greet(name: str) -> dict: + return {"greeting": f"Hello, {name}!"} + + definition = ToolDefinition( + name="greet", + description="Greet someone", + category=ToolCategory.UTILITY, + parameters=[string_param("name", "Name to greet")], + handler=greet, + ) + + executor._registry.register(definition) + + call = ToolCall(tool_name="greet", arguments={"name": "MedeX"}) + result = await executor.execute(call) + + assert result.is_success is True + assert result.output["greeting"] == "Hello, MedeX!" + + @pytest.mark.asyncio + async def test_execute_with_validation_error(self, executor: ToolExecutor): + """Test execution with invalid arguments.""" + + async def calculate(x: float, y: float) -> dict: + return {"sum": x + y} + + definition = ToolDefinition( + name="calculate", + description="Calculate sum", + category=ToolCategory.UTILITY, + parameters=[ + number_param("x", "First number", required=True), + number_param("y", "Second number", required=True), + ], + handler=calculate, + ) + + executor._registry.register(definition) + + # Missing required argument + call = ToolCall(tool_name="calculate", arguments={"x": 5}) + result = await executor.execute(call) + + assert result.is_success is False + + @pytest.mark.asyncio + async def test_execute_tool_not_found(self, executor: ToolExecutor): + """Test execution of non-existent tool.""" + call = ToolCall(tool_name="nonexistent_tool", arguments={}) + result = await executor.execute(call) + + assert result.is_success is False + assert result.is_error is True + + @pytest.mark.asyncio + async def test_execute_batch(self, executor: ToolExecutor): + """Test batch execution.""" + + async def echo(msg: str) -> dict: + return {"echo": msg} + + definition = ToolDefinition( + name="echo", + description="Echo message", + category=ToolCategory.UTILITY, + parameters=[string_param("msg", "Message")], + handler=echo, + ) + + executor._registry.register(definition) + + calls = [ + ToolCall(tool_name="echo", arguments={"msg": "one"}), + ToolCall(tool_name="echo", arguments={"msg": "two"}), + ToolCall(tool_name="echo", arguments={"msg": "three"}), + ] + + results = await executor.execute_batch(calls) + + assert len(results) == 3 + assert all(r.is_success for r in results) + assert results[0].output["echo"] == "one" + assert results[1].output["echo"] == "two" + assert results[2].output["echo"] == "three" + + @pytest.mark.asyncio + async def test_execute_batch_empty(self, executor: ToolExecutor): + """Test batch execution with empty list.""" + results = await executor.execute_batch([]) + assert results == [] + + @pytest.mark.asyncio + async def test_execute_batch_sequential(self, executor: ToolExecutor): + """Test sequential batch execution.""" + + async def counter(val: int) -> dict: + return {"value": val} + + definition = ToolDefinition( + name="counter", + description="Return value", + category=ToolCategory.UTILITY, + parameters=[], + handler=counter, + ) + + executor._registry.register(definition) + + calls = [ToolCall(tool_name="counter", arguments={"val": i}) for i in range(3)] + + results = await executor.execute_batch(calls, parallel=False) + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_get_metrics(self, executor: ToolExecutor): + """Test metrics collection.""" + + async def metric_tool() -> dict: + return {"ok": True} + + definition = ToolDefinition( + name="metric_tool", + description="Tool for metrics", + category=ToolCategory.UTILITY, + parameters=[], + handler=metric_tool, + ) + + executor._registry.register(definition) + + for _ in range(3): + call = ToolCall(tool_name="metric_tool", arguments={}) + await executor.execute(call) + + metrics = executor.get_metrics() + assert metrics["total_executions"] == 3 + assert metrics["successful"] == 3 + + @pytest.mark.asyncio + async def test_reset_metrics(self, executor: ToolExecutor): + """Test resetting metrics.""" + + async def dummy() -> dict: + return {} + + definition = ToolDefinition( + name="dummy", + description="Dummy", + category=ToolCategory.UTILITY, + parameters=[], + handler=dummy, + ) + + executor._registry.register(definition) + + call = ToolCall(tool_name="dummy", arguments={}) + await executor.execute(call) + + executor.reset_metrics() + metrics = executor.get_metrics() + assert metrics["total_executions"] == 0 + + +# ============================================================================= +# Medical Tools Tests +# ============================================================================= + + +class TestDrugInteractionTools: + """Tests for drug interaction tools.""" + + @pytest.mark.asyncio + async def test_check_drug_interactions_no_interactions(self): + """Test with drugs that don't interact.""" + from medex.tools.medical.drug_interactions import check_drug_interactions + + result = await check_drug_interactions(["amoxicillin", "paracetamol"]) + + assert "interactions" in result + + @pytest.mark.asyncio + async def test_get_drug_info(self): + """Test getting drug information.""" + from medex.tools.medical.drug_interactions import get_drug_info + + result = await get_drug_info("metformin") + + assert "name" in result or "drug" in result + assert "found" in result + + @pytest.mark.asyncio + async def test_check_drug_interactions_single_drug(self): + """Test with single drug (edge case).""" + from medex.tools.medical.drug_interactions import check_drug_interactions + + result = await check_drug_interactions(["aspirin"]) + assert "interactions" in result + + @pytest.mark.asyncio + async def test_check_drug_interactions_multiple(self): + """Test with multiple drugs.""" + from medex.tools.medical.drug_interactions import check_drug_interactions + + result = await check_drug_interactions(["warfarin", "aspirin", "ibuprofen"]) + assert "interactions" in result + + +class TestDosageCalculatorTools: + """Tests for dosage calculator tools.""" + + @pytest.mark.asyncio + async def test_calculate_pediatric_dose(self): + """Test pediatric dose calculation.""" + from medex.tools.medical.dosage_calculator import calculate_pediatric_dose + + result = await calculate_pediatric_dose( + drug_name="amoxicillin", + weight_kg=20, + ) + + assert "drug_name" in result or "drug" in result + + @pytest.mark.asyncio + async def test_calculate_bsa(self): + """Test BSA calculation.""" + from medex.tools.medical.dosage_calculator import calculate_bsa + + result = await calculate_bsa( + weight_kg=70, + height_cm=175, + ) + + assert "bsa_m2" in result + # bsa_m2 is a dict with multiple formula results + bsa = result["bsa_m2"] + assert isinstance(bsa, dict) + assert bsa["recommended"] > 0 + assert bsa["recommended"] < 3 + + @pytest.mark.asyncio + async def test_calculate_creatinine_clearance(self): + """Test CrCl calculation.""" + from medex.tools.medical.dosage_calculator import calculate_creatinine_clearance + + result = await calculate_creatinine_clearance( + age_years=65, + weight_kg=70, + creatinine_mg_dl=1.2, + is_female=False, + ) + + assert "creatinine_clearance_ml_min" in result + assert result["creatinine_clearance_ml_min"] > 0 + assert "gfr_category" in result + + @pytest.mark.asyncio + async def test_adjust_dose_renal(self): + """Test renal dose adjustment.""" + from medex.tools.medical.dosage_calculator import adjust_dose_renal + + result = await adjust_dose_renal( + drug_name="metformin", + gfr=45.0, + ) + + assert isinstance(result, dict) + + +class TestLabInterpreterTools: + """Tests for lab interpreter tools.""" + + @pytest.mark.asyncio + async def test_interpret_cbc(self): + """Test CBC interpretation.""" + from medex.tools.medical.lab_interpreter import interpret_cbc + + result = await interpret_cbc( + hemoglobin=10.5, + sex="female", + age_years=45, + mcv=75, + ) + + assert "interpretations" in result + assert "clinical_findings" in result + assert "differential_diagnoses" in result + + @pytest.mark.asyncio + async def test_interpret_liver_panel(self): + """Test liver panel interpretation.""" + from medex.tools.medical.lab_interpreter import interpret_liver_panel + + result = await interpret_liver_panel( + alt=150, + ast=120, + alp=90, + ) + + assert "pattern" in result + assert "de_ritis_ratio" in result + assert result["pattern"] in ["normal", "hepatocelular", "colestásico", "mixto"] + + @pytest.mark.asyncio + async def test_interpret_thyroid_panel(self): + """Test thyroid panel interpretation.""" + from medex.tools.medical.lab_interpreter import interpret_thyroid_panel + + result = await interpret_thyroid_panel( + tsh=0.1, + t4_free=2.5, + ) + + assert "thyroid_status" in result + assert "interpretations" in result + + +class TestEmergencyDetectorTools: + """Tests for emergency detector tools.""" + + @pytest.mark.asyncio + async def test_detect_emergency_cardiac(self): + """Test detection of cardiac emergency.""" + from medex.tools.medical.emergency_detector import detect_emergency + + result = await detect_emergency( + symptoms=["dolor torácico", "sudoración", "disnea"], + onset="súbito", + duration="30 minutos", + ) + + assert result["emergency_detected"] is True + assert result["triage"]["level"] <= 2 + + @pytest.mark.asyncio + async def test_detect_emergency_non_urgent(self): + """Test with non-urgent symptoms.""" + from medex.tools.medical.emergency_detector import detect_emergency + + result = await detect_emergency( + symptoms=["tos leve", "congestión nasal"], + onset="gradual", + duration="3 días", + ) + + assert result["triage"]["level"] >= 4 + + @pytest.mark.asyncio + async def test_check_critical_values(self): + """Test critical lab value detection.""" + from medex.tools.medical.emergency_detector import check_critical_values + + result = await check_critical_values(lab_values={"potasio": 7.0, "glucosa": 40}) + + assert "has_critical_values" in result + + @pytest.mark.asyncio + async def test_quick_triage(self): + """Test quick triage assessment.""" + from medex.tools.medical.emergency_detector import quick_triage + + result = await quick_triage( + chief_complaint="dolor abdominal severo", + severity="severo", + duration_hours=2, + worsening=True, + ) + + assert "triage_level" in result + assert "triage_color" in result + assert "recommendation" in result + + +# ============================================================================= +# Enum Tests +# ============================================================================= + + +class TestEnums: + """Tests for tool system enums.""" + + def test_tool_category_values(self): + """Test ToolCategory enum values.""" + assert ToolCategory.DRUG.value == "drug" + assert ToolCategory.LAB.value == "lab" + assert ToolCategory.DOSAGE.value == "dosage" + assert ToolCategory.EMERGENCY.value == "emergency" + assert ToolCategory.DIAGNOSIS.value == "diagnosis" + assert ToolCategory.UTILITY.value == "utility" + + def test_tool_status_values(self): + """Test ToolStatus enum values.""" + assert ToolStatus.PENDING.value == "pending" + assert ToolStatus.RUNNING.value == "running" + assert ToolStatus.SUCCESS.value == "success" + assert ToolStatus.ERROR.value == "error" + assert ToolStatus.TIMEOUT.value == "timeout" + + def test_parameter_type_values(self): + """Test ParameterType enum values.""" + assert ParameterType.STRING.value == "string" + assert ParameterType.NUMBER.value == "number" + assert ParameterType.INTEGER.value == "integer" + assert ParameterType.BOOLEAN.value == "boolean" + assert ParameterType.ARRAY.value == "array" + assert ParameterType.OBJECT.value == "object" + + +# ============================================================================= +# Run Tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From d217b668f2b1a982f8f4acb36b4bf89ae1c30252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:41 -0300 Subject: [PATCH 2/6] test: add tests/test_vision.py --- tests/test_vision.py | 226 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 tests/test_vision.py diff --git a/tests/test_vision.py b/tests/test_vision.py new file mode 100644 index 0000000..9b9085d --- /dev/null +++ b/tests/test_vision.py @@ -0,0 +1,226 @@ +# ============================================================================= +# MedeX - Vision Module Tests +# ============================================================================= +""" +Tests for the MedeX Vision module. + +Covers: +- ImageAnalyzer: validation, supported modalities, rejection messages +- ImagingModality: enum values +- ImageValidation: dataclass construction +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from medex.vision.analyzer import ImageAnalyzer, ImageValidation, ImagingModality + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def analyzer() -> ImageAnalyzer: + """Create an ImageAnalyzer instance.""" + return ImageAnalyzer() + + +@pytest.fixture +def valid_image(tmp_path: Path) -> Path: + """Create a valid test image file (small PNG-like).""" + img = tmp_path / "test_xray.png" + # Write minimal content (not a real image, but we only validate metadata) + img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + return img + + +@pytest.fixture +def valid_dicom(tmp_path: Path) -> Path: + """Create a valid DICOM test file.""" + dcm = tmp_path / "scan.dcm" + dcm.write_bytes(b"DICM" + b"\x00" * 100) + return dcm + + +@pytest.fixture +def oversized_image(tmp_path: Path) -> Path: + """Create an oversized test file (>50MB).""" + img = tmp_path / "huge.jpg" + # Write just over 50MB + img.write_bytes(b"\xff\xd8" + b"\x00" * (51 * 1024 * 1024)) + return img + + +# ============================================================================= +# ImagingModality Tests +# ============================================================================= + + +class TestImagingModality: + """Tests for ImagingModality enum.""" + + def test_modality_values(self): + """Test all modality enum values.""" + assert ImagingModality.RADIOGRAPHY.value == "RX" + assert ImagingModality.CT.value == "TAC" + assert ImagingModality.MRI.value == "RM" + assert ImagingModality.ULTRASOUND.value == "US" + assert ImagingModality.UNKNOWN.value == "UNKNOWN" + + def test_modality_members(self): + """Test all modality members exist.""" + members = list(ImagingModality) + assert len(members) == 5 + assert ImagingModality.UNKNOWN in members + + +# ============================================================================= +# ImageValidation Tests +# ============================================================================= + + +class TestImageValidation: + """Tests for ImageValidation dataclass.""" + + def test_create_valid_result(self): + """Test creating a valid validation result.""" + result = ImageValidation( + is_valid=True, + modality=ImagingModality.RADIOGRAPHY, + confidence=0.95, + message="Valid radiograph", + ) + + assert result.is_valid is True + assert result.modality == ImagingModality.RADIOGRAPHY + assert result.confidence == 0.95 + assert result.message == "Valid radiograph" + + def test_create_invalid_result(self): + """Test creating an invalid validation result.""" + result = ImageValidation( + is_valid=False, + modality=ImagingModality.UNKNOWN, + confidence=0.0, + message="File not found", + ) + + assert result.is_valid is False + assert result.modality == ImagingModality.UNKNOWN + assert result.confidence == 0.0 + + +# ============================================================================= +# ImageAnalyzer Tests +# ============================================================================= + + +class TestImageAnalyzer: + """Tests for ImageAnalyzer class.""" + + def test_analyzer_initialization(self, analyzer: ImageAnalyzer): + """Test analyzer creates successfully.""" + assert analyzer is not None + + def test_supported_extensions(self, analyzer: ImageAnalyzer): + """Test supported file extensions.""" + assert ".jpg" in analyzer.SUPPORTED_EXTENSIONS + assert ".jpeg" in analyzer.SUPPORTED_EXTENSIONS + assert ".png" in analyzer.SUPPORTED_EXTENSIONS + assert ".dcm" in analyzer.SUPPORTED_EXTENSIONS + assert ".dicom" in analyzer.SUPPORTED_EXTENSIONS + + def test_max_file_size(self, analyzer: ImageAnalyzer): + """Test max file size is 50MB.""" + assert analyzer.MAX_FILE_SIZE == 50 * 1024 * 1024 + + def test_validate_valid_png(self, analyzer: ImageAnalyzer, valid_image: Path): + """Test validation of valid PNG file.""" + result = analyzer.validate_image(str(valid_image)) + + assert result.is_valid is True + assert result.modality == ImagingModality.UNKNOWN # Detection happens via AI + assert result.confidence == 0.5 + + def test_validate_valid_dicom(self, analyzer: ImageAnalyzer, valid_dicom: Path): + """Test validation of valid DICOM file.""" + result = analyzer.validate_image(str(valid_dicom)) + + assert result.is_valid is True + + def test_validate_file_not_found(self, analyzer: ImageAnalyzer): + """Test validation with non-existent file.""" + result = analyzer.validate_image("/tmp/definitely_nonexistent_file.png") + + assert result.is_valid is False + assert result.modality == ImagingModality.UNKNOWN + assert result.confidence == 0.0 + assert "not found" in result.message.lower() + + def test_validate_unsupported_extension( + self, analyzer: ImageAnalyzer, tmp_path: Path + ): + """Test validation with unsupported file extension.""" + bad_file = tmp_path / "document.pdf" + bad_file.write_bytes(b"%PDF-1.4" + b"\x00" * 100) + + result = analyzer.validate_image(str(bad_file)) + + assert result.is_valid is False + assert "unsupported" in result.message.lower() + + def test_validate_oversized_file( + self, analyzer: ImageAnalyzer, oversized_image: Path + ): + """Test validation with oversized file.""" + result = analyzer.validate_image(str(oversized_image)) + + assert result.is_valid is False + assert "too large" in result.message.lower() + + def test_validate_jpg_extension(self, analyzer: ImageAnalyzer, tmp_path: Path): + """Test validation with .jpg extension.""" + jpg = tmp_path / "xray.jpg" + jpg.write_bytes(b"\xff\xd8\xff" + b"\x00" * 100) + + result = analyzer.validate_image(str(jpg)) + assert result.is_valid is True + + def test_validate_jpeg_extension(self, analyzer: ImageAnalyzer, tmp_path: Path): + """Test validation with .jpeg extension.""" + jpeg = tmp_path / "scan.jpeg" + jpeg.write_bytes(b"\xff\xd8\xff" + b"\x00" * 100) + + result = analyzer.validate_image(str(jpeg)) + assert result.is_valid is True + + def test_get_supported_modalities(self): + """Test getting supported modalities list.""" + modalities = ImageAnalyzer.get_supported_modalities() + + assert isinstance(modalities, list) + assert "RX" in modalities + assert "TAC" in modalities + assert "RM" in modalities + assert "US" in modalities + assert "UNKNOWN" not in modalities + + def test_format_rejection_message(self): + """Test standard rejection message.""" + msg = ImageAnalyzer.format_rejection_message() + + assert isinstance(msg, str) + assert len(msg) > 0 + assert "RX" in msg or "TAC" in msg # Contains modality references + + +# ============================================================================= +# Run Tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 0c23ddcb787022e622f1741cb21396e6d9c4937b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:43 -0300 Subject: [PATCH 3/6] test: add tests/test_knowledge.py --- tests/test_knowledge.py | 446 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 446 insertions(+) create mode 100644 tests/test_knowledge.py diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py new file mode 100644 index 0000000..ebe876c --- /dev/null +++ b/tests/test_knowledge.py @@ -0,0 +1,446 @@ +# ============================================================================= +# MedeX - Knowledge Module Tests +# ============================================================================= +""" +Tests for the MedeX Knowledge module. + +Covers: +- medical_base.py: MedicalCondition, Medication, DiagnosticProcedure, + ClinicalProtocol, MedicalKnowledgeBase +- pharmaceutical.py: DrugMonograph, DrugInteraction, PharmaceuticalDatabase, + InteractionSeverity, RouteOfAdministration +""" + +from __future__ import annotations + +import pytest + +from medex.knowledge.medical_base import ( + ClinicalProtocol, + DiagnosticProcedure, + MedicalCondition, + MedicalKnowledgeBase, + Medication, +) +from medex.knowledge.pharmaceutical import ( + DrugInteraction, + DrugMonograph, + InteractionSeverity, + PharmaceuticalDatabase, + RouteOfAdministration, +) + +# ============================================================================= +# InteractionSeverity Enum Tests +# ============================================================================= + + +class TestInteractionSeverity: + """Tests for InteractionSeverity enum.""" + + def test_severity_values(self): + """Test all severity enum values.""" + assert InteractionSeverity.MINOR.value == "minor" + assert InteractionSeverity.MODERATE.value == "moderate" + assert InteractionSeverity.MAJOR.value == "major" + assert InteractionSeverity.CONTRAINDICATED.value == "contraindicated" + + def test_severity_members(self): + """Test all severity members exist.""" + members = list(InteractionSeverity) + assert len(members) == 4 + + def test_severity_comparison(self): + """Test severity enum values are distinct.""" + assert InteractionSeverity.MINOR != InteractionSeverity.MAJOR + assert InteractionSeverity.MODERATE != InteractionSeverity.CONTRAINDICATED + + +# ============================================================================= +# RouteOfAdministration Enum Tests +# ============================================================================= + + +class TestRouteOfAdministration: + """Tests for RouteOfAdministration enum.""" + + def test_route_values(self): + """Test core route enum values.""" + assert RouteOfAdministration.ORAL.value == "oral" + assert RouteOfAdministration.IV.value == "intravenous" + assert RouteOfAdministration.IM.value == "intramuscular" + assert RouteOfAdministration.SC.value == "subcutaneous" + assert RouteOfAdministration.TOPICAL.value == "topical" + + def test_route_members(self): + """Test all route members exist.""" + members = list(RouteOfAdministration) + assert len(members) >= 5 # At least 5 standard routes + + +# ============================================================================= +# MedicalCondition Tests +# ============================================================================= + + +class TestMedicalCondition: + """Tests for MedicalCondition dataclass.""" + + def _make_condition(self, **overrides) -> MedicalCondition: + """Helper to create a MedicalCondition with defaults.""" + defaults = { + "icd10_code": "I10", + "name": "Hypertension", + "category": "Cardiovascular", + "description": "Elevated blood pressure", + "symptoms": ["headache", "dizziness"], + "risk_factors": ["obesity", "smoking"], + "complications": ["stroke", "heart failure"], + "diagnostic_criteria": ["BP > 140/90 on 2+ occasions"], + "differential_diagnosis": ["White coat hypertension"], + "treatment_protocol": ["Lifestyle changes", "ACE inhibitors"], + "emergency_signs": ["BP > 180/120", "Vision changes"], + "prognosis": "Good with treatment", + "follow_up": ["Quarterly BP checks"], + } + defaults.update(overrides) + return MedicalCondition(**defaults) + + def test_create_condition(self): + """Test creating a medical condition with all fields.""" + condition = self._make_condition() + + assert condition.name == "Hypertension" + assert condition.icd10_code == "I10" + assert condition.description == "Elevated blood pressure" + assert "headache" in condition.symptoms + assert len(condition.risk_factors) == 2 + + def test_condition_category(self): + """Test condition category field.""" + condition = self._make_condition(category="Endocrine") + assert condition.category == "Endocrine" + + def test_condition_complications(self): + """Test condition complications field.""" + condition = self._make_condition( + complications=["Retinopathy", "Nephropathy", "Neuropathy"] + ) + assert len(condition.complications) == 3 + assert "Nephropathy" in condition.complications + + def test_condition_emergency_signs(self): + """Test condition emergency signs field.""" + condition = self._make_condition( + emergency_signs=["Altered consciousness", "Chest pain"] + ) + assert len(condition.emergency_signs) == 2 + + +# ============================================================================= +# Medication Tests +# ============================================================================= + + +class TestMedication: + """Tests for Medication dataclass.""" + + def _make_medication(self, **overrides) -> Medication: + """Helper to create a Medication with defaults.""" + defaults = { + "name": "Amoxicillin", + "generic_name": "Amoxicillin", + "category": "Antibiotic", + "indications": ["Upper respiratory infection"], + "contraindications": ["Penicillin allergy"], + "dosage_adult": "500mg q8h", + "dosage_pediatric": "25mg/kg/day divided q8h", + "side_effects": ["Nausea", "Diarrhea", "Rash"], + "interactions": ["Methotrexate", "Warfarin"], + "monitoring": ["Renal function", "CBC"], + "pregnancy_category": "B", + } + defaults.update(overrides) + return Medication(**defaults) + + def test_create_medication(self): + """Test creating a medication with all fields.""" + med = self._make_medication() + + assert med.name == "Amoxicillin" + assert med.generic_name == "Amoxicillin" + assert med.category == "Antibiotic" + assert med.pregnancy_category == "B" + + def test_medication_dosage_fields(self): + """Test medication dosage fields.""" + med = self._make_medication( + dosage_adult="10mg/day", + dosage_pediatric="5mg/kg/day", + ) + assert med.dosage_adult == "10mg/day" + assert med.dosage_pediatric == "5mg/kg/day" + + def test_medication_interactions(self): + """Test medication drug interactions list.""" + med = self._make_medication(interactions=["Aspirin", "Ibuprofen", "Naproxen"]) + assert len(med.interactions) == 3 + assert "Aspirin" in med.interactions + + +# ============================================================================= +# DiagnosticProcedure Tests +# ============================================================================= + + +class TestDiagnosticProcedure: + """Tests for DiagnosticProcedure dataclass.""" + + def _make_procedure(self, **overrides) -> DiagnosticProcedure: + """Helper to create a DiagnosticProcedure with defaults.""" + defaults = { + "name": "Complete Blood Count", + "category": "Laboratory", + "indications": ["Anemia screening", "Infection evaluation"], + "contraindications": [], + "preparation": ["No fasting required"], + "procedure_steps": ["Venipuncture", "Analyze sample"], + "interpretation": ["Low Hgb suggests anemia"], + "complications": ["Bruising at puncture site"], + "cost_range": "Low", + } + defaults.update(overrides) + return DiagnosticProcedure(**defaults) + + def test_create_procedure(self): + """Test creating a diagnostic procedure.""" + proc = self._make_procedure() + + assert proc.name == "Complete Blood Count" + assert proc.category == "Laboratory" + assert len(proc.indications) == 2 + + def test_procedure_steps(self): + """Test procedure steps field.""" + proc = self._make_procedure( + procedure_steps=["Prepare patient", "Administer contrast", "Scan"] + ) + assert len(proc.procedure_steps) == 3 + + def test_procedure_cost_range(self): + """Test procedure cost range field.""" + proc = self._make_procedure(cost_range="High") + assert proc.cost_range == "High" + + +# ============================================================================= +# ClinicalProtocol Tests +# ============================================================================= + + +class TestClinicalProtocol: + """Tests for ClinicalProtocol dataclass.""" + + def _make_protocol(self, **overrides) -> ClinicalProtocol: + """Helper to create a ClinicalProtocol with defaults.""" + defaults = { + "name": "Sepsis Management", + "category": "Emergency", + "indication": "Suspected sepsis", + "steps": ["Measure lactate", "Blood cultures", "Antibiotics"], + "decision_points": ["Lactate > 2 mmol/L"], + "emergency_modifications": ["Vasopressors if MAP < 65"], + "evidence_level": "1A", + "last_updated": "2024-01-01", + } + defaults.update(overrides) + return ClinicalProtocol(**defaults) + + def test_create_protocol(self): + """Test creating a clinical protocol.""" + protocol = self._make_protocol() + + assert protocol.name == "Sepsis Management" + assert len(protocol.steps) == 3 + assert protocol.evidence_level == "1A" + + def test_protocol_category(self): + """Test protocol category field.""" + protocol = self._make_protocol(category="Cardiology") + assert protocol.category == "Cardiology" + + def test_protocol_decision_points(self): + """Test protocol decision points.""" + protocol = self._make_protocol( + decision_points=["BP < 90 systolic", "SpO2 < 92%"] + ) + assert len(protocol.decision_points) == 2 + + +# ============================================================================= +# MedicalKnowledgeBase Tests +# ============================================================================= + + +class TestMedicalKnowledgeBase: + """Tests for MedicalKnowledgeBase class.""" + + def test_create_knowledge_base(self): + """Test creating an empty knowledge base.""" + kb = MedicalKnowledgeBase() + assert kb is not None + + def test_knowledge_base_has_conditions(self): + """Test knowledge base conditions dict.""" + kb = MedicalKnowledgeBase() + assert isinstance(kb.conditions, dict) + + def test_knowledge_base_has_medications(self): + """Test knowledge base medications dict.""" + kb = MedicalKnowledgeBase() + assert isinstance(kb.medications, dict) + + def test_knowledge_base_has_procedures(self): + """Test knowledge base procedures dict.""" + kb = MedicalKnowledgeBase() + assert isinstance(kb.procedures, dict) + + def test_knowledge_base_has_protocols(self): + """Test knowledge base protocols dict.""" + kb = MedicalKnowledgeBase() + assert isinstance(kb.protocols, dict) + + +# ============================================================================= +# DrugMonograph Tests +# ============================================================================= + + +class TestDrugMonograph: + """Tests for DrugMonograph dataclass.""" + + def test_create_monograph(self): + """Test creating a drug monograph with all required fields.""" + mono = DrugMonograph( + name="Metformin", + generic_name="Metformin HCl", + brand_names=["Glucophage"], + drug_class="Biguanide", + therapeutic_category="Antidiabetic", + mechanism_of_action="Decreases hepatic glucose production", + indications=["Type 2 diabetes mellitus"], + dosages=["500mg BID"], + contraindications=["eGFR < 30"], + adverse_effects=["GI upset"], + pharmacokinetics="T1/2 ~6h", + monitoring_parameters=["HbA1c", "Renal function"], + patient_counseling=["Take with meals"], + storage_conditions="Room temperature", + pregnancy_category="B", + lactation_safety="Compatible", + pediatric_use=">10 years", + geriatric_use="Adjust for renal function", + cost_effectiveness="High", + ) + + assert mono.name == "Metformin" + assert mono.generic_name == "Metformin HCl" + assert mono.drug_class == "Biguanide" + assert "Glucophage" in mono.brand_names + + def test_monograph_has_expected_fields(self): + """Test monograph has key expected fields.""" + fields = DrugMonograph.__dataclass_fields__ + assert "name" in fields + assert "generic_name" in fields + assert "drug_class" in fields + + +# ============================================================================= +# DrugInteraction Tests +# ============================================================================= + + +class TestDrugInteraction: + """Tests for DrugInteraction dataclass.""" + + def test_create_interaction(self): + """Test creating a drug interaction.""" + interaction = DrugInteraction( + drug_a="Warfarin", + drug_b="Aspirin", + severity=InteractionSeverity.MAJOR, + mechanism="Synergistic anticoagulation", + clinical_effect="Increased bleeding risk", + management="Avoid or monitor INR closely", + onset="rapid", + documentation="excellent", + ) + + assert interaction.drug_a == "Warfarin" + assert interaction.drug_b == "Aspirin" + assert interaction.severity == InteractionSeverity.MAJOR + assert "bleeding" in interaction.clinical_effect.lower() + + def test_create_moderate_interaction(self): + """Test creating a moderate interaction.""" + interaction = DrugInteraction( + drug_a="Ibuprofen", + drug_b="Lisinopril", + severity=InteractionSeverity.MODERATE, + mechanism="NSAIDs reduce ACE inhibitor efficacy", + clinical_effect="Reduced antihypertensive effect", + management="Monitor blood pressure", + onset="delayed", + documentation="good", + ) + + assert interaction.severity == InteractionSeverity.MODERATE + + def test_create_contraindicated_interaction(self): + """Test creating a contraindicated interaction.""" + interaction = DrugInteraction( + drug_a="Methotrexate", + drug_b="Trimethoprim", + severity=InteractionSeverity.CONTRAINDICATED, + mechanism="Bone marrow suppression", + clinical_effect="Pancytopenia risk", + management="Do not combine", + onset="delayed", + documentation="excellent", + ) + + assert interaction.severity == InteractionSeverity.CONTRAINDICATED + + +# ============================================================================= +# PharmaceuticalDatabase Tests +# ============================================================================= + + +class TestPharmaceuticalDatabase: + """Tests for PharmaceuticalDatabase class.""" + + def test_create_database(self): + """Test creating a pharmaceutical database.""" + db = PharmaceuticalDatabase() + assert db is not None + + def test_database_has_monographs(self): + """Test database monographs collection.""" + db = PharmaceuticalDatabase() + assert isinstance(db.monographs, dict) + + def test_database_has_interactions(self): + """Test database interactions collection.""" + db = PharmaceuticalDatabase() + # interactions may be list or dict depending on implementation + assert isinstance(db.interactions, (list, dict)) + + +# ============================================================================= +# Run Tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 0360661643568d2c0cce9f6d74f07b35a66d710d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:45 -0300 Subject: [PATCH 4/6] test: add tests/test_providers.py --- tests/test_providers.py | 473 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 473 insertions(+) create mode 100644 tests/test_providers.py diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 0000000..ac8b8a8 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,473 @@ +# ============================================================================= +# MedeX - Providers Module Tests +# ============================================================================= +""" +Tests for the MedeX Providers module. + +Covers: +- ProviderConfig: creation, API key resolution (env / file / missing) +- ProviderResponse: success property, error handling +- ProviderStatus: enum values +- ModelProvider (ABC): initialization, status management, error classification +- ProviderManager: initialization, provider listing, set_provider +""" + +from __future__ import annotations + +import os +from collections.abc import Generator +from unittest.mock import patch + +import pytest + +from medex.providers.base import ( + ModelProvider, + ProviderConfig, + ProviderResponse, + ProviderStatus, +) + +# ============================================================================= +# ProviderStatus Enum Tests +# ============================================================================= + + +class TestProviderStatus: + """Tests for ProviderStatus enum.""" + + def test_status_values(self): + """Test all status enum values.""" + assert ProviderStatus.AVAILABLE.value == "available" + assert ProviderStatus.RATE_LIMITED.value == "rate_limited" + assert ProviderStatus.QUOTA_EXCEEDED.value == "quota_exceeded" + assert ProviderStatus.AUTH_ERROR.value == "auth_error" + assert ProviderStatus.UNAVAILABLE.value == "unavailable" + assert ProviderStatus.UNKNOWN.value == "unknown" + + def test_status_members(self): + """Test expected number of status members.""" + members = list(ProviderStatus) + assert len(members) == 6 + + def test_status_distinct_values(self): + """Test all status values are distinct.""" + values = [s.value for s in ProviderStatus] + assert len(values) == len(set(values)) + + +# ============================================================================= +# ProviderConfig Tests +# ============================================================================= + + +class TestProviderConfig: + """Tests for ProviderConfig dataclass.""" + + def test_create_config(self): + """Test creating a provider config with required fields.""" + config = ProviderConfig( + name="openai", + model_id="gpt-4", + ) + + assert config.name == "openai" + assert config.model_id == "gpt-4" + + def test_config_defaults(self): + """Test config defaults are set correctly.""" + config = ProviderConfig(name="test", model_id="m") + + assert config.api_key_env == "" + assert config.api_key_file == "" + assert config.base_url == "" + assert config.max_tokens == 4096 + assert config.temperature == 0.7 + assert config.supports_streaming is True + assert config.supports_vision is False + assert config.is_free is False + assert config.description == "" + + def test_config_with_base_url(self): + """Test config with custom base URL.""" + config = ProviderConfig( + name="custom", + model_id="local-model", + base_url="http://localhost:8080/v1", + ) + + assert config.base_url == "http://localhost:8080/v1" + + def test_config_with_custom_params(self): + """Test config with custom temperature and max_tokens.""" + config = ProviderConfig( + name="huggingface", + model_id="medex/classifier", + temperature=0.3, + max_tokens=1024, + ) + + assert config.temperature == 0.3 + assert config.max_tokens == 1024 + + def test_get_api_key_from_env(self): + """Test getting API key from environment variable.""" + config = ProviderConfig( + name="openai", + model_id="gpt-4", + api_key_env="TEST_MEDEX_API_KEY_PROV", + ) + + with patch.dict(os.environ, {"TEST_MEDEX_API_KEY_PROV": "sk-env-key-123"}): + key = config.get_api_key() + assert key == "sk-env-key-123" + + def test_get_api_key_from_file(self, tmp_path): + """Test getting API key from file.""" + key_file = tmp_path / "api_key.txt" + key_file.write_text("sk-file-key-456\n") + + config = ProviderConfig( + name="openai", + model_id="gpt-4", + api_key_file=str(key_file), + ) + + key = config.get_api_key() + assert key == "sk-file-key-456" + + def test_get_api_key_missing(self): + """Test getting API key when none is configured.""" + config = ProviderConfig( + name="openai", + model_id="gpt-4", + ) + + key = config.get_api_key() + assert key is None + + def test_get_api_key_env_takes_precedence(self, tmp_path): + """Test env var takes precedence over file.""" + key_file = tmp_path / "api_key.txt" + key_file.write_text("file-key") + + config = ProviderConfig( + name="test", + model_id="m", + api_key_env="TEST_MEDEX_PRIO_KEY", + api_key_file=str(key_file), + ) + + with patch.dict(os.environ, {"TEST_MEDEX_PRIO_KEY": "env-key"}): + key = config.get_api_key() + assert key == "env-key" + + def test_vision_support_flag(self): + """Test vision support configuration.""" + config = ProviderConfig( + name="vision-provider", + model_id="gpt-4-vision", + supports_vision=True, + ) + + assert config.supports_vision is True + + +# ============================================================================= +# ProviderResponse Tests +# ============================================================================= + + +class TestProviderResponse: + """Tests for ProviderResponse dataclass.""" + + def test_create_success_response(self): + """Test creating a successful response.""" + response = ProviderResponse( + content="Diagnosis: pneumonia", + model="gpt-4", + provider="openai", + tokens_used=150, + ) + + assert response.content == "Diagnosis: pneumonia" + assert response.model == "gpt-4" + assert response.provider == "openai" + assert response.tokens_used == 150 + assert response.success # truthy: content present and no error + + def test_create_error_response(self): + """Test creating an error response.""" + response = ProviderResponse( + content="", + model="gpt-4", + provider="openai", + error="Rate limited", + ) + + assert not response.success + assert response.error == "Rate limited" + + def test_success_requires_content_and_no_error(self): + """Test success property requires both content and no error.""" + # Content but no error → success + ok = ProviderResponse(content="result", model="m", provider="p") + assert ok.success + + # Error even with content → not success + err = ProviderResponse( + content="partial", model="m", provider="p", error="failed" + ) + assert not err.success + + # No content and no error → falsy + empty = ProviderResponse(content="", model="m", provider="p") + assert not empty.success + + def test_response_defaults(self): + """Test response default values.""" + response = ProviderResponse() + + assert response.content == "" + assert response.provider == "" + assert response.model == "" + assert response.status == ProviderStatus.AVAILABLE + assert response.tokens_used == 0 + assert response.error is None + + def test_response_with_status(self): + """Test response with explicit status.""" + response = ProviderResponse( + content="", + provider="openai", + model="gpt-4", + status=ProviderStatus.RATE_LIMITED, + error="Too many requests", + ) + + assert response.status == ProviderStatus.RATE_LIMITED + + +# ============================================================================= +# ModelProvider (ABC) Tests +# ============================================================================= + + +class ConcreteProvider(ModelProvider): + """Concrete implementation for testing abstract base class.""" + + def initialize(self) -> bool: + self._status = ProviderStatus.AVAILABLE + return True + + def generate( + self, + messages: list[dict], + system_prompt: str = "", + max_tokens: int | None = None, + temperature: float | None = None, + ) -> ProviderResponse: + return ProviderResponse( + content=f"Mock response (messages={len(messages)})", + model=self.config.model_id, + provider=self.config.name, + ) + + def stream( + self, + messages: list[dict], + system_prompt: str = "", + max_tokens: int | None = None, + temperature: float | None = None, + ) -> Generator[str, None, ProviderResponse]: + yield "chunk1" + yield "chunk2" + return ProviderResponse( + content="chunk1chunk2", + model=self.config.model_id, + provider=self.config.name, + ) + + +class TestModelProvider: + """Tests for ModelProvider abstract base class.""" + + def _make_provider( + self, name: str = "test", model_id: str = "test-model" + ) -> ConcreteProvider: + config = ProviderConfig(name=name, model_id=model_id) + return ConcreteProvider(config) + + def test_provider_initialization(self): + """Test provider initializes with config.""" + provider = self._make_provider() + + assert provider.config.name == "test" + assert provider.config.model_id == "test-model" + assert provider.status == ProviderStatus.UNKNOWN + + def test_provider_name_property(self): + """Test provider name property.""" + provider = self._make_provider(name="medex-hf") + assert provider.name == "medex-hf" + + def test_provider_model_id_property(self): + """Test provider model_id property.""" + provider = self._make_provider(model_id="gpt-4-turbo") + assert provider.model_id == "gpt-4-turbo" + + def test_provider_status_after_init(self): + """Test provider status is UNKNOWN after __init__.""" + provider = self._make_provider() + assert provider.status == ProviderStatus.UNKNOWN + + def test_provider_initialize(self): + """Test provider initialize sets status to AVAILABLE.""" + provider = self._make_provider() + result = provider.initialize() + + assert result is True + assert provider.status == ProviderStatus.AVAILABLE + + def test_provider_is_available(self): + """Test is_available property.""" + provider = self._make_provider() + assert provider.is_available is False # UNKNOWN initially + + provider.initialize() + assert provider.is_available is True + + def test_provider_generate(self): + """Test provider generate method.""" + provider = self._make_provider() + response = provider.generate(messages=[{"role": "user", "content": "Hello"}]) + + assert response.success + assert "messages=1" in response.content + + def test_update_status_quota_exceeded(self): + """Test status update on quota error (429).""" + provider = self._make_provider() + provider.initialize() + + status = provider._update_status_from_error( + Exception("Error 429: quota exceeded") + ) + assert status == ProviderStatus.QUOTA_EXCEEDED + assert provider.status == ProviderStatus.QUOTA_EXCEEDED + + def test_update_status_auth_error(self): + """Test status update on auth error (401).""" + provider = self._make_provider() + provider.initialize() + + status = provider._update_status_from_error( + Exception("Error 401: Unauthorized") + ) + assert status == ProviderStatus.AUTH_ERROR + assert provider.status == ProviderStatus.AUTH_ERROR + + def test_update_status_rate_limited(self): + """Test status update on rate limit error.""" + provider = self._make_provider() + provider.initialize() + + status = provider._update_status_from_error(Exception("Rate limit exceeded")) + assert status == ProviderStatus.RATE_LIMITED + assert provider.status == ProviderStatus.RATE_LIMITED + + def test_update_status_server_error(self): + """Test status update on generic error.""" + provider = self._make_provider() + provider.initialize() + + status = provider._update_status_from_error(Exception("Internal server error")) + assert status == ProviderStatus.UNAVAILABLE + assert provider.status == ProviderStatus.UNAVAILABLE + + +# ============================================================================= +# Provider Manager Tests +# ============================================================================= + + +class TestProviderManager: + """Tests for ProviderManager.""" + + def test_manager_import(self): + """Test manager can be imported.""" + from medex.providers.manager import ProviderManager + + assert ProviderManager is not None + + def test_manager_initialization(self): + """Test manager creates with default providers.""" + from medex.providers.manager import ProviderManager + + manager = ProviderManager() + # Manager auto-configures default providers (Moonshot, HuggingFace, etc.) + assert manager is not None + assert manager.auto_fallback is True + assert manager.primary_provider is not None + + def test_manager_with_primary(self): + """Test manager with a primary provider.""" + from medex.providers.manager import ProviderManager + + config = ProviderConfig(name="test", model_id="test-model") + provider = ConcreteProvider(config) + + manager = ProviderManager(primary_provider=provider) + assert manager.primary_provider is provider + + def test_manager_get_available_providers_default(self): + """Test listing available providers with default config.""" + from medex.providers.manager import ProviderManager + + manager = ProviderManager() + available = manager.get_available_providers() + assert isinstance(available, list) + # Default manager comes pre-configured with providers + assert len(available) >= 1 + + def test_manager_get_available_providers(self): + """Test listing available providers with initialized provider.""" + from medex.providers.manager import ProviderManager + + config = ProviderConfig(name="test", model_id="test-model") + provider = ConcreteProvider(config) + provider.initialize() + + manager = ProviderManager( + primary_provider=provider, + fallback_providers=[], + ) + available = manager.get_available_providers() + assert len(available) >= 1 + + def test_manager_set_provider(self): + """Test setting current provider.""" + from medex.providers.manager import ProviderManager + + config1 = ProviderConfig(name="provider-a", model_id="model-a") + config2 = ProviderConfig(name="provider-b", model_id="model-b") + prov1 = ConcreteProvider(config1) + prov2 = ConcreteProvider(config2) + prov1.initialize() + prov2.initialize() + + manager = ProviderManager( + primary_provider=prov1, + fallback_providers=[prov2], + ) + + manager.set_provider("provider-b") + assert manager.current_provider is not None + assert manager.current_provider.name == "provider-b" + + +# ============================================================================= +# Run Tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 2a10f319058d0d62572444d1bc341d320873035f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:46 -0300 Subject: [PATCH 5/6] test: add .github/workflows/ci.yml --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 50baae2..ef6de60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,6 +92,10 @@ jobs: tests/test_detection.py \ tests/test_differential_diagnosis.py \ tests/test_medex_logger.py \ + tests/test_tools_v2.py \ + tests/test_vision.py \ + tests/test_knowledge.py \ + tests/test_providers.py \ -v --tb=short \ --continue-on-collection-errors env: @@ -198,6 +202,10 @@ jobs: --ignore=tests/test_detection.py \ --ignore=tests/test_differential_diagnosis.py \ --ignore=tests/test_medex_logger.py \ + --ignore=tests/test_tools_v2.py \ + --ignore=tests/test_vision.py \ + --ignore=tests/test_knowledge.py \ + --ignore=tests/test_providers.py \ --continue-on-collection-errors 2>&1 | tail -50 || true env: MEDEX_ENV: test From 451a63a22928af02368d113945f595545f1bd029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Romero=20=20=20=20=20-=F0=9D=94=87=F0=9D=94=A2?= =?UTF-8?q?=F0=9D=94=A2=F0=9D=94=AD=E2=84=9C=F0=9D=94=9E=F0=9D=94=B1?= Date: Thu, 26 Feb 2026 16:56:48 -0300 Subject: [PATCH 6/6] test: remove broken V1 test_tools.py (replaced by test_tools_v2.py) --- tests/test_tools.py | 829 -------------------------------------------- 1 file changed, 829 deletions(-) delete mode 100644 tests/test_tools.py diff --git a/tests/test_tools.py b/tests/test_tools.py deleted file mode 100644 index bc640a4..0000000 --- a/tests/test_tools.py +++ /dev/null @@ -1,829 +0,0 @@ -# ============================================================================= -# MedeX - Tool System Tests -# ============================================================================= -""" -Comprehensive tests for the MedeX Tool System. - -Tests cover: -- Tool models (ToolParameter, ToolDefinition, ToolCall, ToolResult) -- Tool registry (registration, retrieval, filtering) -- Tool executor (execution, timeout, retry, caching) -- Medical tools (drug interactions, dosage, labs, emergency) -- Tool service (facade integration) -""" - -from __future__ import annotations - -import asyncio -import json -from uuid import uuid4 - -import pytest - -from medex.tools.executor import ToolExecutor, create_tool_executor -from medex.tools.models import ( - ParameterType, - ToolCall, - ToolCategory, - ToolDefinition, - ToolError, - ToolParameter, - ToolResult, - ToolStatus, -) -from medex.tools.registry import ( - ToolRegistry, - number_param, - string_param, - tool, -) - -# ============================================================================= -# Fixtures -# ============================================================================= - - -@pytest.fixture -def sample_parameter() -> ToolParameter: - """Create a sample tool parameter.""" - return ToolParameter( - name="medication", - type=ParameterType.STRING, - description="Name of the medication", - required=True, - ) - - -@pytest.fixture -def sample_definition() -> ToolDefinition: - """Create a sample tool definition.""" - return ToolDefinition( - name="check_medication", - description="Check medication information", - category=ToolCategory.DRUG, - parameters=[ - ToolParameter( - name="medication", - type=ParameterType.STRING, - description="Medication name", - required=True, - ), - ToolParameter( - name="dosage", - type=ParameterType.NUMBER, - description="Dosage in mg", - required=False, - ), - ], - handler=lambda medication, dosage=None: { - "medication": medication, - "dosage": dosage, - }, - tags=["medication", "drug"], - ) - - -@pytest.fixture -def fresh_registry() -> ToolRegistry: - """Create a fresh registry for testing.""" - return ToolRegistry() - - -@pytest.fixture -def executor() -> ToolExecutor: - """Create an executor for testing.""" - return create_tool_executor(max_concurrent=3, default_timeout=5.0) - - -# ============================================================================= -# ToolParameter Tests -# ============================================================================= - - -class TestToolParameter: - """Tests for ToolParameter model.""" - - def test_create_string_parameter(self): - """Test creating a string parameter.""" - param = ToolParameter( - name="patient_name", - type=ParameterType.STRING, - description="Patient's full name", - required=True, - ) - - assert param.name == "patient_name" - assert param.type == ParameterType.STRING - assert param.required is True - - def test_create_number_parameter_with_range(self): - """Test creating a number parameter with min/max.""" - param = ToolParameter( - name="age", - type=ParameterType.INTEGER, - description="Patient age", - required=True, - minimum=0, - maximum=150, - ) - - assert param.minimum == 0 - assert param.maximum == 150 - - def test_create_enum_parameter(self): - """Test creating an enum parameter.""" - param = ToolParameter( - name="severity", - type=ParameterType.STRING, - description="Severity level", - required=True, - enum=["low", "medium", "high"], - ) - - assert param.enum == ["low", "medium", "high"] - - def test_to_json_schema(self, sample_parameter: ToolParameter): - """Test conversion to JSON Schema.""" - schema = sample_parameter.to_json_schema() - - assert schema["type"] == "string" - assert schema["description"] == "Name of the medication" - - def test_to_json_schema_with_constraints(self): - """Test JSON Schema with constraints.""" - param = ToolParameter( - name="weight", - type=ParameterType.NUMBER, - description="Weight in kg", - required=True, - minimum=0.5, - maximum=500, - ) - - schema = param.to_json_schema() - - assert schema["type"] == "number" - assert schema["minimum"] == 0.5 - assert schema["maximum"] == 500 - - -# ============================================================================= -# ToolDefinition Tests -# ============================================================================= - - -class TestToolDefinition: - """Tests for ToolDefinition model.""" - - def test_create_definition(self, sample_definition: ToolDefinition): - """Test creating a tool definition.""" - assert sample_definition.name == "check_medication" - assert sample_definition.category == ToolCategory.DRUG - assert len(sample_definition.parameters) == 2 - assert sample_definition.status == ToolStatus.ENABLED - - def test_to_openai_format(self, sample_definition: ToolDefinition): - """Test conversion to OpenAI function calling format.""" - openai_format = sample_definition.to_openai_format() - - assert openai_format["type"] == "function" - assert openai_format["function"]["name"] == "check_medication" - assert "parameters" in openai_format["function"] - assert openai_format["function"]["parameters"]["type"] == "object" - assert "medication" in openai_format["function"]["parameters"]["properties"] - - def test_to_anthropic_format(self, sample_definition: ToolDefinition): - """Test conversion to Anthropic tool format.""" - anthropic_format = sample_definition.to_anthropic_format() - - assert anthropic_format["name"] == "check_medication" - assert "input_schema" in anthropic_format - assert anthropic_format["input_schema"]["type"] == "object" - assert "medication" in anthropic_format["input_schema"]["properties"] - - def test_validate_arguments_valid(self, sample_definition: ToolDefinition): - """Test argument validation with valid arguments.""" - errors = sample_definition.validate_arguments({"medication": "aspirin"}) - assert len(errors) == 0 - - def test_validate_arguments_missing_required( - self, sample_definition: ToolDefinition - ): - """Test argument validation with missing required parameter.""" - errors = sample_definition.validate_arguments({}) - assert len(errors) > 0 - assert "medication" in errors[0].lower() - - -# ============================================================================= -# ToolCall Tests -# ============================================================================= - - -class TestToolCall: - """Tests for ToolCall model.""" - - def test_create_tool_call(self): - """Test creating a tool call.""" - call = ToolCall( - id=str(uuid4()), - name="check_drug_interactions", - arguments={"drugs": ["aspirin", "ibuprofen"]}, - ) - - assert call.name == "check_drug_interactions" - assert call.arguments["drugs"] == ["aspirin", "ibuprofen"] - - def test_from_openai_response(self): - """Test creating from OpenAI response.""" - openai_response = { - "id": "call_123", - "function": { - "name": "calculate_dose", - "arguments": '{"drug": "amoxicillin", "weight": 70}', - }, - } - - call = ToolCall.from_openai_response(openai_response) - - assert call.id == "call_123" - assert call.name == "calculate_dose" - assert call.arguments["drug"] == "amoxicillin" - assert call.arguments["weight"] == 70 - - def test_from_anthropic_response(self): - """Test creating from Anthropic response.""" - anthropic_response = { - "id": "toolu_123", - "name": "interpret_lab", - "input": {"hemoglobin": 12.5, "sex": "male"}, - } - - call = ToolCall.from_anthropic_response(anthropic_response) - - assert call.id == "toolu_123" - assert call.name == "interpret_lab" - assert call.arguments["hemoglobin"] == 12.5 - - -# ============================================================================= -# ToolResult Tests -# ============================================================================= - - -class TestToolResult: - """Tests for ToolResult model.""" - - def test_create_success_result(self): - """Test creating a successful result.""" - result = ToolResult( - call_id="call_123", - tool_name="calculate_dose", - success=True, - data={"dose": 500, "unit": "mg", "frequency": "q8h"}, - ) - - assert result.success is True - assert result.data["dose"] == 500 - assert result.error is None - - def test_create_error_result(self): - """Test creating an error result.""" - error = ToolError( - code="VALIDATION_ERROR", - message="Invalid weight: must be positive", - ) - - result = ToolResult( - call_id="call_456", - tool_name="calculate_dose", - success=False, - error=error, - ) - - assert result.success is False - assert result.error.code == "VALIDATION_ERROR" - - def test_to_llm_format(self): - """Test formatting result for LLM.""" - result = ToolResult( - call_id="call_789", - tool_name="check_interactions", - success=True, - data={ - "interactions": [ - {"drug1": "warfarin", "drug2": "aspirin", "severity": "high"} - ] - }, - ) - - llm_format = result.to_llm_format() - - assert "tool_call_id" in llm_format or "id" in llm_format - assert "content" in llm_format - - # Content should be JSON string - content = json.loads(llm_format["content"]) - assert "interactions" in content - - -# ============================================================================= -# ToolRegistry Tests -# ============================================================================= - - -class TestToolRegistry: - """Tests for ToolRegistry.""" - - def test_register_tool(self, fresh_registry: ToolRegistry): - """Test registering a tool.""" - - async def my_tool(x: str) -> dict: - return {"result": x} - - definition = ToolDefinition( - name="my_tool", - description="A test tool", - category=ToolCategory.UTILITY, - parameters=[], - handler=my_tool, - ) - - fresh_registry.register(definition) - - assert fresh_registry.get_tool("my_tool") is not None - - def test_get_nonexistent_tool(self, fresh_registry: ToolRegistry): - """Test getting a tool that doesn't exist.""" - result = fresh_registry.get_tool("nonexistent") - assert result is None - - def test_get_tools_by_category(self, fresh_registry: ToolRegistry): - """Test filtering tools by category.""" - # Register tools in different categories - for i, category in enumerate( - [ToolCategory.DRUG, ToolCategory.DRUG, ToolCategory.LAB] - ): - - async def handler() -> dict: - return {} - - fresh_registry.register( - ToolDefinition( - name=f"tool_{i}", - description=f"Tool {i}", - category=category, - parameters=[], - handler=handler, - ) - ) - - drug_tools = fresh_registry.get_tools_by_category(ToolCategory.DRUG) - assert len(drug_tools) == 2 - - lab_tools = fresh_registry.get_tools_by_category(ToolCategory.LAB) - assert len(lab_tools) == 1 - - def test_enable_disable_tool( - self, fresh_registry: ToolRegistry, sample_definition: ToolDefinition - ): - """Test enabling and disabling a tool.""" - fresh_registry.register(sample_definition) - - # Disable - assert fresh_registry.disable_tool("check_medication") is True - tool = fresh_registry.get_tool("check_medication") - assert tool.status == ToolStatus.DISABLED - - # Enable - assert fresh_registry.enable_tool("check_medication") is True - tool = fresh_registry.get_tool("check_medication") - assert tool.status == ToolStatus.ENABLED - - def test_decorator_registration(self, fresh_registry: ToolRegistry): - """Test registering tool via decorator.""" - - @tool( - name="decorated_tool", - description="A tool registered via decorator", - category=ToolCategory.UTILITY, - parameters=[string_param("input", "The input")], - registry=fresh_registry, - ) - async def decorated_tool(input: str) -> dict: - return {"output": input.upper()} - - registered = fresh_registry.get_tool("decorated_tool") - assert registered is not None - assert registered.name == "decorated_tool" - - -# ============================================================================= -# ToolExecutor Tests -# ============================================================================= - - -class TestToolExecutor: - """Tests for ToolExecutor.""" - - @pytest.mark.asyncio - async def test_execute_simple_tool(self, executor: ToolExecutor): - """Test executing a simple tool.""" - - # Register a tool - async def greet(name: str) -> dict: - return {"greeting": f"Hello, {name}!"} - - definition = ToolDefinition( - name="greet", - description="Greet someone", - category=ToolCategory.UTILITY, - parameters=[string_param("name", "Name to greet")], - handler=greet, - ) - - executor._registry.register(definition) - - # Execute - call = ToolCall(id="call_1", name="greet", arguments={"name": "MedeX"}) - result = await executor.execute(call) - - assert result.success is True - assert result.data["greeting"] == "Hello, MedeX!" - - @pytest.mark.asyncio - async def test_execute_with_validation_error(self, executor: ToolExecutor): - """Test execution with invalid arguments.""" - - async def calculate(x: float, y: float) -> dict: - return {"sum": x + y} - - definition = ToolDefinition( - name="calculate", - description="Calculate sum", - category=ToolCategory.UTILITY, - parameters=[ - number_param("x", "First number", required=True), - number_param("y", "Second number", required=True), - ], - handler=calculate, - ) - - executor._registry.register(definition) - - # Missing required argument - call = ToolCall(id="call_2", name="calculate", arguments={"x": 5}) - result = await executor.execute(call) - - assert result.success is False - assert result.error is not None - - @pytest.mark.asyncio - async def test_execute_timeout(self, executor: ToolExecutor): - """Test execution timeout.""" - - async def slow_tool() -> dict: - await asyncio.sleep(10) # Will timeout - return {"done": True} - - definition = ToolDefinition( - name="slow_tool", - description="A slow tool", - category=ToolCategory.UTILITY, - parameters=[], - handler=slow_tool, - ) - - executor._registry.register(definition) - - call = ToolCall(id="call_3", name="slow_tool", arguments={}) - result = await executor.execute(call, timeout=0.1) - - assert result.success is False - assert ( - "timeout" in result.error.code.lower() - or "timeout" in result.error.message.lower() - ) - - @pytest.mark.asyncio - async def test_execute_batch(self, executor: ToolExecutor): - """Test batch execution.""" - - async def echo(msg: str) -> dict: - return {"echo": msg} - - definition = ToolDefinition( - name="echo", - description="Echo message", - category=ToolCategory.UTILITY, - parameters=[string_param("msg", "Message")], - handler=echo, - ) - - executor._registry.register(definition) - - calls = [ - ToolCall(id="call_a", name="echo", arguments={"msg": "one"}), - ToolCall(id="call_b", name="echo", arguments={"msg": "two"}), - ToolCall(id="call_c", name="echo", arguments={"msg": "three"}), - ] - - results = await executor.execute_batch(calls) - - assert len(results) == 3 - assert all(r.success for r in results) - assert results[0].data["echo"] == "one" - assert results[1].data["echo"] == "two" - assert results[2].data["echo"] == "three" - - @pytest.mark.asyncio - async def test_get_metrics(self, executor: ToolExecutor): - """Test metrics collection.""" - - async def metric_tool() -> dict: - return {"ok": True} - - definition = ToolDefinition( - name="metric_tool", - description="Tool for metrics", - category=ToolCategory.UTILITY, - parameters=[], - handler=metric_tool, - ) - - executor._registry.register(definition) - - # Execute a few times - for _ in range(3): - call = ToolCall(id=str(uuid4()), name="metric_tool", arguments={}) - await executor.execute(call) - - metrics = executor.get_metrics() - - assert metrics["total_calls"] == 3 - assert metrics["successful_calls"] == 3 - assert "metric_tool" in metrics.get("by_tool", {}) - - -# ============================================================================= -# Medical Tools Tests -# ============================================================================= - - -class TestDrugInteractionTools: - """Tests for drug interaction tools.""" - - @pytest.mark.asyncio - async def test_check_drug_interactions(self): - """Test checking drug interactions.""" - from medex.tools.medical.drug_interactions import check_drug_interactions - - result = await check_drug_interactions(["warfarin", "aspirin"]) - - assert "interactions" in result - assert len(result["interactions"]) > 0 - assert result["interactions"][0]["severity"] in ["alta", "moderada", "baja"] - - @pytest.mark.asyncio - async def test_check_drug_interactions_no_interactions(self): - """Test with drugs that don't interact.""" - from medex.tools.medical.drug_interactions import check_drug_interactions - - result = await check_drug_interactions(["amoxicillin", "paracetamol"]) - - assert "interactions" in result - # May or may not have interactions, but should not error - - @pytest.mark.asyncio - async def test_get_drug_info(self): - """Test getting drug information.""" - from medex.tools.medical.drug_interactions import get_drug_info - - result = await get_drug_info("metformin") - - assert "drug" in result - assert "found" in result - - -class TestDosageCalculatorTools: - """Tests for dosage calculator tools.""" - - @pytest.mark.asyncio - async def test_calculate_pediatric_dose(self): - """Test pediatric dose calculation.""" - from medex.tools.medical.dosage_calculator import calculate_pediatric_dose - - result = await calculate_pediatric_dose( - drug="amoxicillin", - weight_kg=20, - indication="otitis", - ) - - assert "drug" in result - assert "calculations" in result or "error" in result - - @pytest.mark.asyncio - async def test_calculate_bsa(self): - """Test BSA calculation.""" - from medex.tools.medical.dosage_calculator import calculate_bsa - - result = await calculate_bsa( - weight_kg=70, - height_cm=175, - formula="mosteller", - ) - - assert "bsa_m2" in result - assert result["bsa_m2"] > 0 - assert result["bsa_m2"] < 3 # Reasonable range - - @pytest.mark.asyncio - async def test_calculate_creatinine_clearance(self): - """Test CrCl calculation.""" - from medex.tools.medical.dosage_calculator import calculate_creatinine_clearance - - result = await calculate_creatinine_clearance( - age_years=65, - weight_kg=70, - serum_creatinine=1.2, - sex="male", - ) - - assert "crcl_ml_min" in result - assert result["crcl_ml_min"] > 0 - assert "stage" in result - - -class TestLabInterpreterTools: - """Tests for lab interpreter tools.""" - - @pytest.mark.asyncio - async def test_interpret_cbc(self): - """Test CBC interpretation.""" - from medex.tools.medical.lab_interpreter import interpret_cbc - - result = await interpret_cbc( - hemoglobin=10.5, - sex="female", - age_years=45, - mcv=75, - ) - - assert "interpretations" in result - assert "clinical_findings" in result - assert "differential_diagnoses" in result - - @pytest.mark.asyncio - async def test_interpret_liver_panel(self): - """Test liver panel interpretation.""" - from medex.tools.medical.lab_interpreter import interpret_liver_panel - - result = await interpret_liver_panel( - alt=150, - ast=120, - alp=90, - ) - - assert "pattern" in result - assert "de_ritis_ratio" in result - assert result["pattern"] in ["normal", "hepatocelular", "colestásico", "mixto"] - - @pytest.mark.asyncio - async def test_interpret_thyroid_panel(self): - """Test thyroid panel interpretation.""" - from medex.tools.medical.lab_interpreter import interpret_thyroid_panel - - result = await interpret_thyroid_panel( - tsh=0.1, - t4_free=2.5, - ) - - assert "thyroid_status" in result - assert "interpretations" in result - - -class TestEmergencyDetectorTools: - """Tests for emergency detector tools.""" - - @pytest.mark.asyncio - async def test_detect_emergency_cardiac(self): - """Test detection of cardiac emergency.""" - from medex.tools.medical.emergency_detector import detect_emergency - - result = await detect_emergency( - symptoms=["dolor torácico", "sudoración", "disnea"], - onset="súbito", - duration="30 minutos", - ) - - assert result["emergency_detected"] is True - assert result["triage"]["level"] <= 2 # High urgency - - @pytest.mark.asyncio - async def test_detect_emergency_non_urgent(self): - """Test with non-urgent symptoms.""" - from medex.tools.medical.emergency_detector import detect_emergency - - result = await detect_emergency( - symptoms=["tos leve", "congestión nasal"], - onset="gradual", - duration="3 días", - ) - - # Should not be marked as emergency - assert result["triage"]["level"] >= 4 - - @pytest.mark.asyncio - async def test_check_critical_values(self): - """Test critical lab value detection.""" - from medex.tools.medical.emergency_detector import check_critical_values - - result = await check_critical_values( - lab_values={"potassium": 7.0, "glucose": 40} - ) - - assert result["has_critical_values"] is True - assert len(result["critical_alerts"]) >= 1 - - @pytest.mark.asyncio - async def test_quick_triage(self): - """Test quick triage assessment.""" - from medex.tools.medical.emergency_detector import quick_triage - - result = await quick_triage( - chief_complaint="dolor abdominal severo", - severity="severo", - duration_hours=2, - worsening=True, - ) - - assert "triage_level" in result - assert "triage_color" in result - assert "recommendation" in result - - -# ============================================================================= -# Integration Tests -# ============================================================================= - - -class TestToolServiceIntegration: - """Integration tests for the complete tool system.""" - - @pytest.mark.asyncio - async def test_full_workflow(self): - """Test complete workflow from registration to execution.""" - from medex.tools.service import ToolService - - service = ToolService() - await service.initialize() - - try: - # Get tools for LLM - tools = service.get_tools_for_llm(format="openai") - assert len(tools) > 0 - - # Find and execute a tool - tool_call = ToolCall( - id="test_call", - name="calculate_bsa", - arguments={ - "weight_kg": 70, - "height_cm": 175, - "formula": "mosteller", - }, - ) - - result = await service.execute(tool_call) - assert result.success is True - assert "bsa_m2" in result.data - - finally: - await service.shutdown() - - @pytest.mark.asyncio - async def test_get_medical_tools(self): - """Test getting all medical tools.""" - from medex.tools.service import ToolService - - service = ToolService() - await service.initialize() - - try: - medical_tools = service.get_medical_tools() - assert len(medical_tools) > 0 - - # Check categories - categories = {t.category for t in medical_tools} - assert ToolCategory.DRUG in categories or ToolCategory.LAB in categories - - finally: - await service.shutdown() - - -# ============================================================================= -# Run Tests -# ============================================================================= - -if __name__ == "__main__": - pytest.main([__file__, "-v"])