diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..31f4212 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +# ============================================================================= +# MedeX — Pre-commit Hooks Configuration +# ============================================================================= +# Install: pip install pre-commit && pre-commit install +# Run all: pre-commit run --all-files +# Update: pre-commit autoupdate +# ============================================================================= + +repos: + # ── Black (code formatter) ───────────────────────────────────────────── + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + language_version: python3 + args: [--config=pyproject.toml] + + # ── Ruff (linter + import sorter) ────────────────────────────────────── + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.6 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + # ── General file hygiene ─────────────────────────────────────────────── + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-json + exclude: ".vscode/" + - id: check-added-large-files + args: [--maxkb=500] + - id: check-merge-conflict + - id: debug-statements + - id: detect-private-key + + # ── Security (bandit) ────────────────────────────────────────────────── + - repo: https://github.com/PyCQA/bandit + rev: 1.8.3 + hooks: + - id: bandit + args: [-r, src/, -c, pyproject.toml, -ll] + pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index f66ba89..a6e2ee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,8 @@ exclude = ''' [tool.ruff] line-length = 88 target-version = "py310" + +[tool.ruff.lint] select = [ "E", # pycodestyle errors "W", # pycodestyle warnings @@ -108,7 +110,7 @@ ignore = [ "B008", # do not perform function calls in argument defaults ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["medex"] [tool.mypy] diff --git a/src/medex/agent/controller.py b/src/medex/agent/controller.py index 51de0cd..a1aa214 100644 --- a/src/medex/agent/controller.py +++ b/src/medex/agent/controller.py @@ -619,7 +619,7 @@ def set_rag_service(self, service: Any) -> None: def set_memory_service(self, service: Any) -> None: """Set memory service.""" - self.memory_service = memory_service + self.memory_service = service # ========================================================================= # Main Entry Points @@ -657,7 +657,7 @@ async def process( logger.warning(f"Failed to load history: {e}") # Initialize state - state = self.state_manager.initialize(context) + self.state_manager.initialize(context) try: # Run agent loop @@ -713,7 +713,7 @@ async def process_stream( ) # Initialize state - state = self.state_manager.initialize(context) + self.state_manager.initialize(context) # Setup event forwarding event_queue: asyncio.Queue[AgentEvent] = asyncio.Queue() @@ -837,7 +837,7 @@ async def _handle_emergency(self, context: AgentContext) -> AgentResult: # Add any tool results if context.tool_results: emergency_response += "**Assessment:**\n" - for tool, result in context.tool_results.items(): + for _tool, result in context.tool_results.items(): if isinstance(result, dict): emergency_response += f"- {result.get('recommendation', '')}\n" diff --git a/src/medex/api/__init__.py b/src/medex/api/__init__.py index 27f24d1..d5ae3ed 100644 --- a/src/medex/api/__init__.py +++ b/src/medex/api/__init__.py @@ -82,13 +82,11 @@ ConnectionState, WebSocketHandler, WSCloseCode, -) -from .websocket import WSMessage as WebSocketMessage -from .websocket import ( WSMessageType, create_connection_manager, create_websocket_handler, ) +from .websocket import WSMessage as WebSocketMessage __all__ = [ # App diff --git a/src/medex/api/middleware.py b/src/medex/api/middleware.py index 32485e8..5a8e655 100644 --- a/src/medex/api/middleware.py +++ b/src/medex/api/middleware.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - pass + from .models import ErrorCode # ============================================================================= diff --git a/src/medex/api/routes/tools.py b/src/medex/api/routes/tools.py index 8a157b7..8e93cd6 100644 --- a/src/medex/api/routes/tools.py +++ b/src/medex/api/routes/tools.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from datetime import datetime from typing import Any @@ -133,7 +134,7 @@ async def check_interactions( """ try: # Call the tool directly (synchronous, uses local dictionary) - result: DrugInteractionResult = await check_drug_interactions( + result: DrugInteractionResult = await check_drug_interactions( # noqa: F821 drugs=request.drugs ) @@ -161,7 +162,7 @@ async def check_interactions( logger.error(f"Drug interaction check failed: {e}") raise HTTPException( status_code=500, detail=f"Error checking interactions: {str(e)}" - ) + ) from e # ============================================================================= @@ -184,7 +185,7 @@ async def calculate_drug_dosage(request: DosageRequest) -> DosageResponse: """ try: # Call the tool directly - result: DosageResult = await calculate_dosage( + result: DosageResult = await calculate_dosage( # noqa: F821 drug_name=request.drug_name, patient_weight_kg=request.patient_weight, patient_age_years=request.patient_age, @@ -207,7 +208,7 @@ async def calculate_drug_dosage(request: DosageRequest) -> DosageResponse: logger.error(f"Dosage calculation failed: {e}") raise HTTPException( status_code=500, detail=f"Error calculating dosage: {str(e)}" - ) + ) from e # ============================================================================= @@ -230,7 +231,7 @@ async def interpret_lab(request: LabInterpretRequest) -> LabInterpretResponse: """ try: # Call the tool directly - result: LabInterpretationResult = await interpret_lab_value( + result: LabInterpretationResult = await interpret_lab_value( # noqa: F821 test_name=request.test_name, value=request.value, unit=request.unit, @@ -254,7 +255,7 @@ async def interpret_lab(request: LabInterpretRequest) -> LabInterpretResponse: logger.error(f"Lab interpretation failed: {e}") raise HTTPException( status_code=500, detail=f"Error interpreting lab value: {str(e)}" - ) + ) from e # ============================================================================= diff --git a/src/medex/db/repositories.py b/src/medex/db/repositories.py index ae1fa54..9c74473 100644 --- a/src/medex/db/repositories.py +++ b/src/medex/db/repositories.py @@ -981,8 +981,9 @@ async def get_metrics_by_tool( func.avg(ToolExecution.latency_ms).label("avg_latency"), func.sum( func.cast( - ToolExecution.cache_hit == True, type_=func.Integer - ) # noqa: E712 + ToolExecution.cache_hit == True, # noqa: E712 + type_=func.Integer, + ) ).label("cache_hits"), ) .where(ToolExecution.tool_name == tool_name) diff --git a/src/medex/llm/parser.py b/src/medex/llm/parser.py index 83aae87..c4b6be5 100644 --- a/src/medex/llm/parser.py +++ b/src/medex/llm/parser.py @@ -547,7 +547,7 @@ def _extract_entities(self, content: str) -> dict[str, list[str]]: # Extract CIE-10 codes if self.config.extract_cie10: cie10_matches = CIE10_PATTERN.findall(content) - entities["cie10_codes"] = list(set(c.upper() for c in cie10_matches)) + entities["cie10_codes"] = list({c.upper() for c in cie10_matches}) # Extract drug mentions if self.config.extract_drugs: diff --git a/src/medex/llm/router.py b/src/medex/llm/router.py index 854fed3..94fa736 100644 --- a/src/medex/llm/router.py +++ b/src/medex/llm/router.py @@ -472,7 +472,7 @@ async def stream(self, request: LLMRequest) -> AsyncIterator[StreamChunk]: # Emit finish event latency_ms = (time.time() - start_time) * 1000 - ttft_ms = ( + _ttft_ms = ( (first_token_time - start_time) * 1000 if first_token_time else None ) diff --git a/src/medex/llm/service.py b/src/medex/llm/service.py index 0340227..ad3dfbc 100644 --- a/src/medex/llm/service.py +++ b/src/medex/llm/service.py @@ -278,7 +278,7 @@ async def query( ) # Execute request - start_time = time.time() + _start_time = time.time() if use_stream: # Collect streaming response diff --git a/src/medex/medical/formatter.py b/src/medex/medical/formatter.py index 0e50913..85a5c02 100644 --- a/src/medex/medical/formatter.py +++ b/src/medex/medical/formatter.py @@ -26,6 +26,7 @@ DiagnosticHypothesis, DiagnosticPlan, Medication, + Specialty, TreatmentPlan, TriageAssessment, TriageLevel, diff --git a/src/medex/medical/models.py b/src/medex/medical/models.py index 456ab14..91c78d5 100644 --- a/src/medex/medical/models.py +++ b/src/medex/medical/models.py @@ -556,7 +556,7 @@ def to_dict(self) -> dict[str, Any]: "symptoms": [s.to_dict() for s in self.symptoms], "vital_signs": self.vital_signs.to_dict() if self.vital_signs else None, "physical_exam": self.physical_exam, - "lab_values": [l.to_dict() for l in self.lab_values], + "lab_values": [lv.to_dict() for lv in self.lab_values], "imaging": self.imaging, "triage": self.triage.to_dict() if self.triage else None, "differential_diagnosis": [ diff --git a/src/medex/medical/reasoner.py b/src/medex/medical/reasoner.py index 5dbfc20..61a5bc7 100644 --- a/src/medex/medical/reasoner.py +++ b/src/medex/medical/reasoner.py @@ -340,7 +340,7 @@ def analyze(self, case: ClinicalCase) -> list[DiagnosticHypothesis]: # Build hypotheses hypotheses = [] - for dx_id, score, dx_pattern in scored_diagnoses[ + for _dx_id, score, dx_pattern in scored_diagnoses[ : self.config.max_differential ]: hypothesis = self._build_hypothesis( @@ -425,7 +425,7 @@ def _calculate_diagnosis_score( pattern_labs = dx_pattern.get("labs", []) if pattern_labs: lab_text = " ".join(lab_findings) - matches = sum(1 for l in pattern_labs if l in lab_text) + matches = sum(1 for lab in pattern_labs if lab in lab_text) lab_score = matches / len(pattern_labs) elif lab_findings: lab_score = 0.1 # Some credit for having labs @@ -464,9 +464,9 @@ def _build_hypothesis( if s in symptoms_text: supporting.append(f"Presenta: {s}") - for l in dx_pattern.get("labs", []): - if any(l in lf for lf in lab_findings): - supporting.append(f"Laboratorio: {l}") + for lab in dx_pattern.get("labs", []): + if any(lab in lf for lf in lab_findings): + supporting.append(f"Laboratorio: {lab}") # Determine next steps next_steps = self._get_diagnostic_steps(dx_pattern, case) @@ -539,7 +539,7 @@ def _get_diagnostic_steps( ) # Always add basic labs if not already ordered - existing_labs = {l.name.lower() for l in case.lab_values} + existing_labs = {lv.name.lower() for lv in case.lab_values} if "hemograma" not in existing_labs and "hemoglobina" not in existing_labs: steps.append("Hemograma completo") if "bioquímica" not in existing_labs: diff --git a/src/medex/rag/chunker.py b/src/medex/rag/chunker.py index 27d5e5b..3484511 100644 --- a/src/medex/rag/chunker.py +++ b/src/medex/rag/chunker.py @@ -147,7 +147,7 @@ def chunk(self, document: Document) -> list[Chunk]: sections = self._split_into_sections(content) char_offset = 0 - for section_idx, (header, section_content) in enumerate(sections): + for _section_idx, (header, section_content) in enumerate(sections): section_chunks = self._chunk_section( section_content, document_id=document.id, diff --git a/src/medex/rag/embedder.py b/src/medex/rag/embedder.py index e117929..74895d1 100644 --- a/src/medex/rag/embedder.py +++ b/src/medex/rag/embedder.py @@ -139,7 +139,7 @@ async def embed_chunks(self, chunks: list[Chunk]) -> list[Chunk]: texts = [c.content for c in chunks] embeddings = await self.embed_texts(texts) - for chunk, embedding in zip(chunks, embeddings): + for chunk, embedding in zip(chunks, embeddings, strict=False): chunk.embedding = embedding.vector return chunks @@ -300,7 +300,7 @@ async def embed_texts(self, texts: list[str]) -> list[Embedding]: all_vectors.extend(batch_vectors) # Fill in results and cache - for idx, vector in zip(uncached_indices, all_vectors): + for idx, vector in zip(uncached_indices, all_vectors, strict=False): embeddings[idx] = Embedding( vector=vector, model=self.config.model_name, @@ -428,7 +428,7 @@ async def embed_texts(self, texts: list[str]) -> list[Embedding]: all_vectors.extend(batch_vectors) # Fill results - for idx, vector in zip(uncached_indices, all_vectors): + for idx, vector in zip(uncached_indices, all_vectors, strict=False): if self.config.normalize: vector = self._normalize(vector) diff --git a/src/medex/rag/reranker.py b/src/medex/rag/reranker.py index c6de71a..fd738ac 100644 --- a/src/medex/rag/reranker.py +++ b/src/medex/rag/reranker.py @@ -156,7 +156,7 @@ async def rerank( normalized_scores = [(s - min_score) / score_range for s in scores] # Update results with rerank scores - for result, rerank_score in zip(results, normalized_scores): + for result, rerank_score in zip(results, normalized_scores, strict=False): result.rerank_score = float(rerank_score) result.relevance = self._score_to_relevance(rerank_score) @@ -540,7 +540,8 @@ async def rerank( chunk_scores = all_scores[result.chunk.id] if chunk_scores: weighted_score = ( - sum(s * w for s, w in zip(chunk_scores, weights)) / total_weight + sum(s * w for s, w in zip(chunk_scores, weights, strict=False)) + / total_weight ) result.rerank_score = weighted_score result.relevance = self._score_to_relevance(weighted_score) diff --git a/src/medex/security/audit.py b/src/medex/security/audit.py index 1d9f4bd..c03a300 100644 --- a/src/medex/security/audit.py +++ b/src/medex/security/audit.py @@ -56,7 +56,7 @@ async def count(self, query: AuditQuery) -> int: """Count matching events.""" pass - async def close(self) -> None: + async def close(self) -> None: # noqa: B027 """Close backend connection.""" pass diff --git a/src/medex/tools/medical/emergency_detector.py b/src/medex/tools/medical/emergency_detector.py index 906b0cb..de9652d 100644 --- a/src/medex/tools/medical/emergency_detector.py +++ b/src/medex/tools/medical/emergency_detector.py @@ -347,7 +347,7 @@ async def detect_emergency( symptoms_text = " ".join(symptoms_lower) # Check each red flag pattern - for flag_id, flag_data in RED_FLAGS.items(): + for _flag_id, flag_data in RED_FLAGS.items(): matched = False # Check primary symptoms @@ -596,7 +596,7 @@ async def check_critical_values( normalized_values[normalized_key] = value # Check each critical value - for crit_id, crit_data in CRITICAL_VALUES.items(): + for _crit_id, crit_data in CRITICAL_VALUES.items(): param = crit_data["parameter"].lower().replace(" ", "_") # Check if this parameter is in the provided values diff --git a/tests/e2e/test_ui_e2e.py b/tests/e2e/test_ui_e2e.py index a546da0..abd8efc 100644 --- a/tests/e2e/test_ui_e2e.py +++ b/tests/e2e/test_ui_e2e.py @@ -200,9 +200,9 @@ def test_page_has_lang_attribute(self, page: Page, medex_url: str): """HTML has lang attribute""" page.goto(medex_url, wait_until="networkidle") - html_lang = page.locator("html").get_attribute("lang") - # Streamlit may not set this, so we just check it exists - # assert html_lang is not None + _html_lang = page.locator("html").get_attribute("lang") + # Reflex should set lang attribute + # assert _html_lang is not None @pytest.mark.e2e def test_images_have_alt_text(self, page: Page, medex_url: str): diff --git a/tests/test_differential_diagnosis.py b/tests/test_differential_diagnosis.py index 5a15fa6..1b2ab2d 100644 --- a/tests/test_differential_diagnosis.py +++ b/tests/test_differential_diagnosis.py @@ -206,7 +206,7 @@ def test_fuzzy_match_synonyms(self): ("vomitos", "náuseas"), # Similar ] - for query, expected_symptom in test_cases: + for query, _expected_symptom in test_cases: result = get_differential_for_symptom(query) # Puede que no todos hagan match fuzzy, pero verificamos que no crashee # y que los que sí existen funcionen diff --git a/tests/test_infrastructure.py b/tests/test_infrastructure.py index 1f8ece4..3e3c076 100644 --- a/tests/test_infrastructure.py +++ b/tests/test_infrastructure.py @@ -508,14 +508,14 @@ async def test_full_conversation_flow(self, db_session, redis_client): messages_for_cache = [] - for i, (role, content) in enumerate( + for _i, (role, content) in enumerate( [ (MessageRole.USER, "What is hypertension?"), (MessageRole.ASSISTANT, "Hypertension is high blood pressure..."), (MessageRole.USER, "What are the symptoms?"), ] ): - msg = await msg_repo.create_message( + msg = await msg_repo.create_message( # noqa: F841 conversation_id=conv.id, role=role, content=content, diff --git a/tests/test_medex_logger.py b/tests/test_medex_logger.py index 7ea0ab1..8d5bd1d 100644 --- a/tests/test_medex_logger.py +++ b/tests/test_medex_logger.py @@ -67,7 +67,7 @@ class TestLoggerInitialization: def test_logger_creates_log_directory(self, temp_log_dir): """Verifica que se crea el directorio de logs.""" log_path = Path(temp_log_dir) / "new_logs" - logger = MedeXLogger(log_dir=str(log_path), enable_console=False) + _logger = MedeXLogger(log_dir=str(log_path), enable_console=False) assert log_path.exists() def test_logger_generates_session_id(self, logger): @@ -426,7 +426,7 @@ def test_audit_trail_entry(self): d = audit.to_dict() assert d["action"] == "test_action" - assert d["success"] == True + assert d["success"] is True # ============================================================================ diff --git a/tests/test_observability.py b/tests/test_observability.py index 55f6cc6..538ac42 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -267,8 +267,8 @@ def test_create_span(self, tracer: Tracer) -> None: def test_nested_spans(self, tracer: Tracer) -> None: """Test nested spans share trace ID.""" - with tracer.start_span("parent") as parent: - with tracer.start_span("child") as child: + with tracer.start_span("parent") as parent: # noqa: F841 + with tracer.start_span("child") as child: # noqa: F841 pass tracer.flush() @@ -300,7 +300,7 @@ def test_span_attributes(self, tracer: Tracer) -> None: def test_span_error(self, tracer: Tracer) -> None: """Test span error handling.""" try: - with tracer.start_span("error_span") as span: + with tracer.start_span("error_span") as span: # noqa: F841 raise ValueError("Test error") except ValueError: pass @@ -316,7 +316,7 @@ def test_span_duration(self, tracer: Tracer) -> None: """Test span duration calculation.""" import time - with tracer.start_span("timed_span") as span: + with tracer.start_span("timed_span") as span: # noqa: F841 time.sleep(0.01) tracer.flush()