diff --git a/.agent/cohort_builder_enhancements.md b/.agent/cohort_builder_enhancements.md deleted file mode 100644 index 5a7c05e..0000000 --- a/.agent/cohort_builder_enhancements.md +++ /dev/null @@ -1,227 +0,0 @@ -# Cohort Builder API Enhancements - -## Summary - -This document summarizes the enhancements made to the `circe.cohort_builder` fluent API to support advanced cohort definition features including nested criteria groups, demographic filters, and comprehensive OHDSI CIRCE parity. - -## Key Features Added - -### 1. Nested Criteria Groups - -**Purpose**: Enable complex logical grouping of inclusion criteria (ANY, ALL, AT_LEAST) - -**Implementation**: -- Added `GroupConfig` dataclass to represent nested criteria groups -- Implemented `CriteriaGroupBuilder` class for fluent nested group construction -- Added methods: `any_of()`, `all_of()`, `at_least_of(count)`, `end_group()` -- Updated `_build_criteria_group()` to recursively process nested groups - -**Example**: -```python -from circe.cohort_builder import Cohort - -cohort = ( - Cohort('Complex Cohort') - .with_condition(1) - .any_of() - .require_drug(10).anytime_before() - .all_of() - .require_procedure(20).same_day() - .require_measurement(30).anytime_after() - .end_group() - .end_group() - .build() -) -``` - -### 2. 
Demographic Filters - -**Purpose**: Filter cohort members by demographic attributes (age, gender, race, ethnicity) - -**Implementation**: -- Extended `CohortSettings` with demographic fields: - - `gender_concepts: List[int]` - - `race_concepts: List[int]` - - `ethnicity_concepts: List[int]` - - `age_min: Optional[int]` - - `age_max: Optional[int]` -- Added methods to `CohortWithEntry` and `CohortWithCriteria`: - - `require_gender(*concept_ids)` - - `require_race(*concept_ids)` - - `require_ethnicity(*concept_ids)` - - `require_age(min_age, max_age)` -- Automatic creation of `DemographicCriteria` inclusion rule - -**Example**: -```python -cohort = ( - Cohort('Adults Only') - .with_condition(1) - .require_age(18, 65) - .require_gender(8507) # Male - .build() -) -``` - -### 3. Named Inclusion Rules - -**Purpose**: Support attrition tracking by naming inclusion rules - -**Implementation**: -- Added `begin_rule(name: str)` method to start a new named rule -- Modified `CohortWithCriteria` to manage rules as list of dictionaries -- Each rule contains: `{"name": str, "group": GroupConfig}` - -**Example**: -```python -cohort = ( - Cohort('Multi-Rule Cohort') - .with_condition(1) - .begin_rule('Age Criteria') - .require_age(18, 65) - .begin_rule('Drug Exposure') - .require_drug(10).anytime_before() - .build() -) -``` - -### 4. 
Advanced Query Filters - -**Purpose**: Support comprehensive OHDSI filtering capabilities - -**New QueryConfig Fields**: -- `gender_concepts: List[int]` - Filter by gender -- `visit_type_concepts: List[int]` - Filter by visit type -- `provider_specialty_concepts: List[int]` - Filter by provider specialty -- `source_concept_set_id: Optional[int]` - Filter by source concepts -- `restrict_visit: bool` - Restrict to same visit as index -- `ignore_observation_period: bool` - Ignore observation period constraints - -**New BaseQuery Methods**: -- `with_gender(*concept_ids)` -- `with_visit_type(*concept_ids)` -- `with_provider_specialty(*concept_ids)` -- `with_source_concept(concept_set_id)` -- `restrict_to_visit()` -- `ignore_observation_period()` -- `relative_to_index_end()` -- `relative_to_event_end()` - -**Example**: -```python -cohort = ( - Cohort('Inpatient Drug Exposure') - .with_condition(1) - .require_drug(10) - .with_visit_type(9201) # Inpatient - .restrict_to_visit() - .anytime_before() - .build() -) -``` - -### 5. Time Window Enhancements - -**Purpose**: Support relative time windows to index/event end dates - -**Implementation**: -- Added `use_index_end: bool` to `TimeWindow` -- Added `use_event_end: bool` to `TimeWindow` -- Implemented `relative_to_index_end()` and `relative_to_event_end()` methods -- Correctly map these flags to `Window` model in `_build_correlated_criteria()` - -### 6. Domain-Specific Filters - -**Purpose**: Support domain-specific filtering (measurements, drug exposures, eras, etc.) 
- -**Enhanced Mapping in `_config_to_criteria()`**: -- Measurement: `value_as_number` (value_min/max) -- DrugExposure: `days_supply`, `quantity` -- DrugEra/ConditionEra: `era_length` -- DoseEra: `dose` -- VisitOccurrence: `visit_length` -- VisitDetail: `visit_detail_length` -- ObservationPeriod: `period_length` -- Gender, visit type, provider specialty mapping -- Source concept mapping for all relevant domains - -## Architecture Changes - -### State Transitions - -The API maintains the following state progression: -``` -CohortBuilder (Cohort) - ↓ with_*() -CohortWithEntry - ↓ require_*() / exclude_*() / any_of() / all_of() -CohortWithCriteria | CriteriaGroupBuilder - ↓ build() / end_group() -CohortExpression -``` - -### Key Classes - -1. **CohortBuilder (alias: Cohort)** - - Entry point for cohort definition - - Sets entry event and title - -2. **CohortWithEntry** - - Configures observation windows - - Sets demographic filters - - Transitions to criteria state - -3. **CohortWithCriteria** - - Manages inclusion/exclusion rules - - Supports nested groups - - Builds final `CohortExpression` - -4. **CriteriaGroupBuilder** - - Manages nested criteria groups - - Supports recursive nesting - - Returns to parent on `end_group()` - -5. **CohortSettings** - - Stores cohort-wide configuration - - Demographics, exit strategy, era collapse - -6. **GroupConfig** - - Represents a criteria group - - Supports types: ALL, ANY, AT_LEAST, AT_MOST - - Contains list of `CriteriaConfig` and nested `GroupConfig` - -## Testing - -All features have been tested with example cohorts demonstrating: -- Simple cohorts with single criteria -- Nested groups (ANY within ALL) -- Demographic filtering -- Named inclusion rules -- Advanced query filters -- Time window configurations - -## Future Enhancements - -Potential areas for future development: -1. Correlated criteria within groups -2. Date adjustment support -3. Censoring criteria -4. Custom era strategies -5. 
Additional domain-specific filters -6. Validation and error handling improvements - -## API Compatibility - -All changes maintain backward compatibility with existing code while adding new optional features. The API is designed to be: -- **LLM-friendly**: Clear method names and guided state transitions -- **Type-safe**: Proper type hints throughout -- **Fluent**: Method chaining for readable cohort definitions -- **CIRCE-compatible**: Generates valid OHDSI CIRCE JSON - -## Files Modified - -- `circe/cohort_builder/builder.py` - Core builder implementation -- `circe/cohort_builder/query_builder.py` - Query configuration and base query -- `circe/cohort_builder/__init__.py` - Public API exports -- `circe/cohortdefinition/criteria.py` - Criteria models (no changes needed) -- `circe/cohortdefinition/core.py` - Core models (no changes needed) diff --git a/README.md b/README.md index 4e87bfd..4ccd550 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,68 @@ circe process cohort.json --validate --sql --markdown See the [CLI Documentation](#command-line-interface) section below for more details. -### Python API +## High-Level Building APIs + +CIRCE Python provides two high-level APIs for building cohorts without manually constructing complex JSON/Pydantic models. + +### 1. Context Manager API (`circe.cohort_builder`) +**Best for: LLMs, beginners, and interactive development.** +This Pythonic API uses `with` blocks and auto-builds on exit. + +```python +from circe.cohort_builder import CohortBuilder +from circe.vocabulary import concept_set, descendants +from circe.api import build_cohort_query + +# 1. Define concept sets +t2dm = concept_set(descendants(201826), id=1, name="T2DM") +metformin = concept_set(descendants(1503297), id=2, name="Metformin") + +# 2. 
Build cohort using context manager +with CohortBuilder("New Metformin Users with T2DM") as cohort: + cohort.with_concept_sets(t2dm, metformin) + cohort.with_drug(concept_set_id=2) # Entry: metformin exposure + cohort.first_occurrence() # First exposure only + cohort.with_observation_window(prior_days=365) # 365 days prior + cohort.min_age(18) # Adults only + cohort.require_condition(concept_set_id=1, within_days_before=365) + + with cohort.include_rule("No Prior Insulin") as rule: + rule.exclude_drug(3, anytime_before=True) + +# 3. Access the built expression and generate SQL +sql = build_cohort_query(cohort.expression) +``` + +### 2. Capr-style API (`circe.capr`) +**Best for: Power users familiar with the R `Capr` package.** +A functional, declarative API for programmatic cohort definition. + +```python +from circe.capr import ( + cohort, entry, attrition, condition_occurrence, drug_exposure, + at_least, exactly, with_all, during_interval, event_starts +) + +# Build using functional composition +my_cohort = cohort( + title="T2DM on Metformin", + entry=entry( + drug_exposure(concept_set_id=2, first_occurrence=True), + observation_window=(365, 0) + ), + attrition=attrition( + has_t2dm=with_all( + at_least(1, condition_occurrence(1), + during_interval(event_starts(before=365))) + ) + ) +).build() +``` + +## Advanced Usage: Raw Pydantic Models + +For full control, you can use the underlying Pydantic models that replicate the Java CIRCE-BE internal structure. 
```python from circe import CohortExpression @@ -91,56 +152,44 @@ from circe.cohortdefinition import PrimaryCriteria, ConditionOccurrence from circe.cohortdefinition.core import ObservationFilter, ResultLimit from circe.vocabulary import ConceptSet, ConceptSetExpression, ConceptSetItem, Concept -# Create a cohort expression +# Create a cohort expression using raw models cohort = CohortExpression( - title="Type 2 Diabetes Cohort", + title="Raw Model Example", primary_criteria=PrimaryCriteria( - criteria_list=[ - ConditionOccurrence( - codeset_id=1, - first=True - ) - ], + criteria_list=[ConditionOccurrence(codeset_id=1, first=True)], observation_window=ObservationFilter(prior_days=0, post_days=0), primary_limit=ResultLimit(type="All") ), - concept_sets=[ - ConceptSet( - id=1, - name="Type 2 Diabetes", - expression=ConceptSetExpression( - items=[ - ConceptSetItem( - concept=Concept( - concept_id=201826, - concept_name="Type 2 diabetes mellitus" - ), - include_descendants=True - ) - ] - ) - ) - ] + concept_sets=[...] ) +``` -# Generate SQL using the API -from circe.api import build_cohort_query -from circe.cohortdefinition import BuildExpressionQueryOptions - -options = BuildExpressionQueryOptions() -options.cdm_schema = 'cdm' -options.vocabulary_schema = 'cdm' -options.cohort_id = 1 -options.target_table = 'scratch.cohort' -sql = build_cohort_query(cohort, options) -print(sql) +## AI Agent Integration + +CIRCE Python provides skill documentation for AI agents that need to generate cohort definitions programmatically. 
+ +```python +from circe import get_cohort_builder_skill, list_skills + +# List available skills +print(list_skills()) # ['cohort_builder'] + +# Get skill documentation for an AI agent +skill_docs = get_cohort_builder_skill() +# Returns markdown documentation describing the CohortBuilder API ``` +The skill documentation includes: +- Context manager API usage patterns +- Available entry event methods +- Inclusion/exclusion criteria syntax +- Named rule contexts for attrition tracking + ## What's Included This package provides a complete Python implementation of CIRCE-BE with: -- **3,400+ passing tests** with focused coverage on core logic +- **800+ passing tests** with focused coverage on core logic - **18+ SQL builders** for all OMOP CDM domains: - Condition Occurrence/Era - Drug Exposure/Era diff --git a/circe/__init__.py b/circe/__init__.py index 93222da..bd4e165 100644 --- a/circe/__init__.py +++ b/circe/__init__.py @@ -32,6 +32,7 @@ build_cohort_query, cohort_print_friendly, ) +from .skills import get_cohort_builder_skill, get_skill, list_skills from circe.cohortdefinition import ( CohortExpression, Criteria, CorelatedCriteria, DemographicCriteria, @@ -186,5 +187,10 @@ def get_json_schema() -> dict: "cohort_expression_from_json", "build_cohort_query", "cohort_print_friendly", - "safe_model_rebuild" + "safe_model_rebuild", + # Skills for AI agents + "get_cohort_builder_skill", + "get_skill", + "list_skills", ] + diff --git a/circe/api.py b/circe/api.py index 363ed63..ae7b6be 100644 --- a/circe/api.py +++ b/circe/api.py @@ -7,7 +7,7 @@ - cohort_print_friendly(): Generate Markdown from cohort expression """ -from typing import Optional, List +from typing import Optional, List, Dict, Any, Literal from .cohortdefinition import ( CohortExpression, CohortExpressionQueryBuilder, @@ -127,3 +127,4 @@ def cohort_print_friendly( renderer = MarkdownRender(concept_sets=concept_sets, include_concept_sets=include_concept_sets) return 
renderer.render_cohort_expression(expression, title=title) + diff --git a/circe/capr/__init__.py b/circe/capr/__init__.py new file mode 100644 index 0000000..8cbe0e8 --- /dev/null +++ b/circe/capr/__init__.py @@ -0,0 +1,82 @@ +""" +Capr-style API for building OHDSI cohort definitions. + +This module provides a fluent interface modeled after the OHDSI/Capr R package +for constructing cohort definitions in Python. + +Example: + >>> from circe.capr import ( + ... cohort, entry, attrition, exit_strategy, + ... condition_occurrence, drug_exposure, + ... at_least, during_interval, event_starts + ... ) + >>> + >>> t2dm_cohort = cohort( + ... entry=entry( + ... condition_occurrence(concept_set_id=1, first_occurrence=True), + ... observation_window=(365, 0) + ... ) + ... ) +""" + +from circe.capr.cohort import cohort, entry, exit_strategy, era +from circe.capr.query import ( + condition_occurrence, condition_era, + drug_exposure, drug_era, dose_era, + procedure, measurement, observation, + visit, visit_detail, + device_exposure, specimen, death, + observation_period, payer_plan_period, location_region +) +from circe.capr.criteria import ( + at_least, at_most, exactly, + with_all, with_any +) +from circe.capr.window import ( + during_interval, event_starts, event_ends, + continuous_observation +) +from circe.capr.attrition import attrition +from circe.capr.templates import ( + sensitive_disease_cohort, + specific_disease_cohort, + acute_disease_cohort, + chronic_disease_cohort, + new_user_drug_cohort +) + +__all__ = [ + # Cohort construction + "cohort", "entry", "exit_strategy", "era", "attrition", + + # Domain queries - Conditions + "condition_occurrence", "condition_era", + + # Domain queries - Drugs + "drug_exposure", "drug_era", "dose_era", + + # Domain queries - Other clinical + "procedure", "measurement", "observation", + "visit", "visit_detail", + "device_exposure", "specimen", "death", + + # Domain queries - Administrative + "observation_period", "payer_plan_period", 
"location_region", + + # Occurrence counting + "at_least", "at_most", "exactly", + + # Grouping + "with_all", "with_any", + + # Time windows + "during_interval", "event_starts", "event_ends", + "continuous_observation", + + # Templates + "sensitive_disease_cohort", + "specific_disease_cohort", + "acute_disease_cohort", + "chronic_disease_cohort", + "new_user_drug_cohort" +] diff --git a/circe/capr/attrition.py b/circe/capr/attrition.py new file mode 100644 index 0000000..501af6c --- /dev/null +++ b/circe/capr/attrition.py @@ -0,0 +1,119 @@ +""" +Attrition functions for CohortComposer. + +These functions define inclusion and exclusion rules that filter the cohort +after the initial entry event. Modeled after OHDSI/Capr's attrition function. +""" + +from typing import Dict, Union, Optional, List +from dataclasses import dataclass, field +from circe.capr.criteria import CriteriaGroup, Criteria + + +@dataclass +class AttritionRule: + """ + Represents a named attrition rule (inclusion/exclusion criterion). + + Attributes: + name: Human-readable name for the rule (used in attrition reports) + description: Optional detailed description + expression: CriteriaGroup defining the rule logic + """ + name: str + description: Optional[str] = None + expression: Optional[CriteriaGroup] = None + + +@dataclass +class AttritionRules: + """ + Collection of attrition rules to apply to the cohort. + + Attributes: + rules: List of AttritionRule objects in order of application + """ + rules: List[AttritionRule] = field(default_factory=list) + + +def attrition( + **named_rules: Union[CriteriaGroup, Criteria] +) -> AttritionRules: + """ + Define attrition rules for the cohort. + + Args: + **named_rules: Keyword arguments where name is the rule name + and value is a CriteriaGroup or Criteria + + Returns: + AttritionRules object containing all defined rules + + Example: + >>> attrition( + ... no_prior_insulin=with_all( + ... exactly(0, drug_exposure(insulin_cs), aperture) + ... ), + ... 
has_t2dm_diagnosis=with_all( + ... at_least(1, condition_occurrence(t2dm_cs), aperture) + ... ) + ... ) + + Note: + Rule names should be descriptive as they appear in attrition reports. + Use underscores for multi-word names; they will be displayed with spaces. + """ + rules = [] + for name, expression in named_rules.items(): + # Convert single Criteria to CriteriaGroup if needed + if isinstance(expression, Criteria): + expression = CriteriaGroup( + group_type='ALL', + criteria_list=[expression] + ) + + # Convert underscores to spaces for display + display_name = name.replace('_', ' ') + + rules.append(AttritionRule( + name=display_name, + expression=expression + )) + + return AttritionRules(rules=rules) + + +def inclusion_rule( + name: str, + expression: Union[CriteriaGroup, Criteria], + description: Optional[str] = None +) -> AttritionRule: + """ + Create a single named inclusion rule. + + Args: + name: Human-readable name for the rule + expression: CriteriaGroup or Criteria defining the rule + description: Optional detailed description + + Returns: + AttritionRule object + + Example: + >>> rule = inclusion_rule( + ... name="Has HbA1c within 6 months", + ... expression=at_least(1, measurement(hba1c_cs), aperture), + ... description="Patient must have HbA1c measurement within 6 months after index" + ... ) + """ + if isinstance(expression, Criteria): + expression = CriteriaGroup( + group_type='ALL', + criteria_list=[expression] + ) + + return AttritionRule( + name=name, + expression=expression, + description=description + ) diff --git a/circe/capr/cohort.py b/circe/capr/cohort.py new file mode 100644 index 0000000..7580167 --- /dev/null +++ b/circe/capr/cohort.py @@ -0,0 +1,486 @@ +""" +Main cohort construction functions for CohortComposer. + +This module provides the primary API for building complete cohort definitions. +Modeled after OHDSI/Capr's cohort() function and related components. 
+""" + +from typing import Optional, List, Union, Any, TYPE_CHECKING +from dataclasses import dataclass, field + +from circe.capr.query import Query +from circe.capr.window import ObservationWindow, continuous_observation +from circe.capr.criteria import CriteriaGroup, Criteria +from circe.capr.attrition import AttritionRules, AttritionRule + +# Import circe models for conversion +from circe.cohortdefinition import ( + CohortExpression, PrimaryCriteria, InclusionRule, + CorelatedCriteria, Occurrence, ConditionOccurrence, DrugExposure, + ProcedureOccurrence, Measurement, Observation, VisitOccurrence, + VisitDetail, DeviceExposure, Specimen, Death, ObservationPeriod, + PayerPlanPeriod, LocationRegion, ConditionEra, DrugEra, DoseEra +) +from circe.cohortdefinition.core import ( + ObservationFilter, ResultLimit, Window, WindowBound, + NumericRange, DateOffsetStrategy, CustomEraStrategy, CollapseSettings +) +from circe.vocabulary import ConceptSet + + +@dataclass +class EntryEvent: + """ + Represents the entry event (index event) for a cohort. + + Attributes: + query: Domain query defining the entry event + observation_window: Continuous observation requirements + primary_criteria_limit: How to limit primary criteria ('First', 'All', 'Last') + additional_criteria: Optional additional criteria to apply + """ + query: Query + observation_window: ObservationWindow = field(default_factory=lambda: ObservationWindow(0, 0)) + primary_criteria_limit: str = "All" + additional_criteria: Optional[CriteriaGroup] = None + + +@dataclass +class ExitStrategy: + """ + Represents the exit strategy for a cohort. 
+ + Attributes: + strategy_type: 'observation', 'date_offset', or 'custom_era' + offset_days: Days to offset from index (for date_offset) + offset_field: Which date to offset from ('startDate' or 'endDate') + custom_era_gap_days: Gap days for custom era strategy + drug_codeset_id: Concept set ID for custom drug era + censor_events: List of queries for censoring events + """ + strategy_type: str = "observation" # 'observation', 'date_offset', 'custom_era' + offset_days: int = 0 + offset_field: str = "startDate" + custom_era_gap_days: int = 0 + drug_codeset_id: Optional[int] = None + censor_events: List[Query] = field(default_factory=list) + + +@dataclass +class CohortEra: + """ + Represents era settings for the cohort. + + Attributes: + era_days: Days to collapse era (gap days) + study_start_date: Optional study start date + study_end_date: Optional study end date + """ + era_days: int = 0 + study_start_date: Optional[str] = None + study_end_date: Optional[str] = None + + +@dataclass +class ComposedCohort: + """ + Intermediate representation of a composed cohort before conversion to CohortExpression. + + Attributes: + title: Cohort title + entry_event: Entry event definition + attrition: Attrition rules + exit: Exit strategy + era: Era settings + concept_sets: List of concept sets + """ + title: Optional[str] = None + entry_event: Optional[EntryEvent] = None + attrition: Optional[AttritionRules] = None + exit: Optional[ExitStrategy] = None + era: Optional[CohortEra] = None + concept_sets: List[ConceptSet] = field(default_factory=list) + + def build(self) -> CohortExpression: + """Convert to a CohortExpression that can generate SQL.""" + return _convert_to_cohort_expression(self) + + +def entry( + query: Query, + observation_window: Union[ObservationWindow, tuple, None] = None, + primary_criteria_limit: str = "All", + additional_criteria: Optional[CriteriaGroup] = None +) -> EntryEvent: + """ + Define the entry event (index event) for the cohort. 
+ + Args: + query: Domain query defining the entry event + observation_window: Observation requirements as ObservationWindow or (prior_days, post_days) + primary_criteria_limit: 'First', 'All', or 'Last' + additional_criteria: Optional additional criteria to apply at entry + + Returns: + EntryEvent object + + Example: + >>> entry( + ... condition_occurrence(concept_set_id=1, first_occurrence=True), + ... observation_window=(365, 0), + ... primary_criteria_limit="First" + ... ) + """ + # Convert tuple to ObservationWindow + if isinstance(observation_window, tuple): + observation_window = continuous_observation( + prior_days=observation_window[0], + post_days=observation_window[1] if len(observation_window) > 1 else 0 + ) + elif observation_window is None: + observation_window = continuous_observation(0, 0) + + return EntryEvent( + query=query, + observation_window=observation_window, + primary_criteria_limit=primary_criteria_limit, + additional_criteria=additional_criteria + ) + + +def exit_strategy( + end_strategy: str = "observation", + offset_days: int = 0, + offset_field: str = "startDate", + custom_era_gap_days: int = 0, + drug_codeset_id: Optional[int] = None, + censor_events: Optional[List[Query]] = None +) -> ExitStrategy: + """ + Define the exit strategy for the cohort. 
+ + Args: + end_strategy: 'observation' (end of observation), 'date_offset', or 'custom_era' + offset_days: Days to offset from index (for date_offset strategy) + offset_field: Which date to use for offset ('startDate' or 'endDate') + custom_era_gap_days: Gap days for custom drug era strategy + drug_codeset_id: Concept set ID for custom drug era + censor_events: List of queries that trigger censoring + + Returns: + ExitStrategy object + + Example: + >>> # End at end of observation + >>> exit_strategy(end_strategy="observation") + + >>> # End 365 days after index + >>> exit_strategy(end_strategy="date_offset", offset_days=365) + + >>> # Custom drug era with 30 day gap + >>> exit_strategy( + ... end_strategy="custom_era", + ... drug_codeset_id=2, + ... custom_era_gap_days=30 + ... ) + """ + return ExitStrategy( + strategy_type=end_strategy, + offset_days=offset_days, + offset_field=offset_field, + custom_era_gap_days=custom_era_gap_days, + drug_codeset_id=drug_codeset_id, + censor_events=censor_events or [] + ) + + +def observation_exit() -> ExitStrategy: + """Convenience function for observation-based exit strategy.""" + return exit_strategy(end_strategy="observation") + + +def date_offset_exit(days: int, from_field: str = "startDate") -> ExitStrategy: + """Convenience function for date offset exit strategy.""" + return exit_strategy( + end_strategy="date_offset", + offset_days=days, + offset_field=from_field + ) + + +def era( + era_days: int = 0, + study_start_date: Optional[str] = None, + study_end_date: Optional[str] = None +) -> CohortEra: + """ + Define era settings for the cohort. 
+ + Args: + era_days: Days to collapse successive cohort entries (gap days) + study_start_date: Optional study start date (YYYY-MM-DD) + study_end_date: Optional study end date (YYYY-MM-DD) + + Returns: + CohortEra object + + Example: + >>> era(era_days=0) # No collapsing + >>> era(era_days=30) # Collapse entries within 30 days + """ + return CohortEra( + era_days=era_days, + study_start_date=study_start_date, + study_end_date=study_end_date + ) + + +def cohort( + title: Optional[str] = None, + entry: Optional[EntryEvent] = None, + attrition: Optional[AttritionRules] = None, + exit: Optional[ExitStrategy] = None, + era: Optional[CohortEra] = None, + concept_sets: Optional[List[ConceptSet]] = None +) -> ComposedCohort: + """ + Create a complete cohort definition. + + Args: + title: Cohort title + entry: Entry event definition + attrition: Attrition rules (inclusion/exclusion) + exit: Exit strategy + era: Era settings + concept_sets: List of concept sets used in the cohort + + Returns: + ComposedCohort object that can be converted to CohortExpression + + Example: + >>> my_cohort = cohort( + ... title="T2DM on Metformin", + ... entry=entry( + ... drug_exposure(concept_set_id=2, first_occurrence=True), + ... observation_window=(365, 0) + ... ), + ... attrition=attrition( + ... has_t2dm=with_all( + ... at_least(1, condition_occurrence(1), aperture) + ... ) + ... ), + ... concept_sets=[t2dm_cs, metformin_cs] + ... 
) + >>> + >>> cohort_expression = my_cohort.build() + """ + return ComposedCohort( + title=title, + entry_event=entry, + attrition=attrition, + exit=exit, + era=era, + concept_sets=concept_sets or [] + ) + + +# ============================================================================= +# CONVERSION FUNCTIONS +# ============================================================================= + +def _convert_to_cohort_expression(composed: ComposedCohort) -> CohortExpression: + """Convert a ComposedCohort to a CohortExpression.""" + + # Build primary criteria + primary_criteria = None + if composed.entry_event: + primary_criteria = _build_primary_criteria(composed.entry_event) + + # Build inclusion rules from attrition + inclusion_rules = [] + if composed.attrition and composed.attrition.rules: + inclusion_rules = [ + _build_inclusion_rule(rule) + for rule in composed.attrition.rules + ] + + # Build end strategy + end_strategy = None + censoring_criteria = [] + if composed.exit: + end_strategy = _build_end_strategy(composed.exit) + if composed.exit.censor_events: + censoring_criteria = [ + _query_to_criteria(q) for q in composed.exit.censor_events + ] + + # Build collapse settings from era + collapse_settings = None + if composed.era: + collapse_settings = CollapseSettings( + era_pad=composed.era.era_days + ) + + return CohortExpression( + title=composed.title, + concept_sets=composed.concept_sets, + primary_criteria=primary_criteria, + inclusion_rules=inclusion_rules, + end_strategy=end_strategy, + censoring_criteria=censoring_criteria, + collapse_settings=collapse_settings + ) + + +def _build_primary_criteria(entry_event: EntryEvent) -> PrimaryCriteria: + """Convert EntryEvent to PrimaryCriteria.""" + criteria = _query_to_criteria(entry_event.query) + + return PrimaryCriteria( + criteria_list=[criteria], + observation_window=ObservationFilter( + prior_days=entry_event.observation_window.prior_days, + post_days=entry_event.observation_window.post_days + ), + 
primary_limit=ResultLimit(type=entry_event.primary_criteria_limit) + ) + + +def _build_inclusion_rule(rule: AttritionRule) -> InclusionRule: + """Convert AttritionRule to InclusionRule.""" + from circe.cohortdefinition.criteria import CriteriaGroup as CirceCriteriaGroup + + expression = None + if rule.expression: + expression = _build_criteria_group(rule.expression) + + return InclusionRule( + name=rule.name, + description=rule.description, + expression=expression + ) + + +def _build_criteria_group(group: CriteriaGroup): + """Convert composer CriteriaGroup to circe CriteriaGroup.""" + from circe.cohortdefinition.criteria import CriteriaGroup as CirceCriteriaGroup + + criteria_list = [] + for item in group.criteria_list: + if isinstance(item, Criteria): + criteria_list.append(_build_correlated_criteria(item)) + elif isinstance(item, CriteriaGroup): + # Nested group - not directly supported, flatten or nest + pass + + return CirceCriteriaGroup( + type=group.group_type, + criteria_list=criteria_list + ) + + +def _build_correlated_criteria(criteria: Criteria) -> CorelatedCriteria: + """Convert composer Criteria to CorelatedCriteria.""" + from circe.cohortdefinition.core import Window, WindowBound + + query_criteria = _query_to_criteria(criteria.query) + + # Build occurrence + occurrence_type_map = { + 'atLeast': 2, # AT_LEAST + 'atMost': 1, # AT_MOST + 'exactly': 0 # EXACTLY + } + occurrence = Occurrence( + type=occurrence_type_map.get(criteria.occurrence_type, 2), + count=criteria.count, + is_distinct=criteria.is_distinct + ) + + # Build window from aperture + start_window = None + if criteria.aperture and criteria.aperture.start_window: + interval = criteria.aperture.start_window + start_window = Window( + use_event_end=criteria.aperture.use_event_end, + start=WindowBound( + coeff=-1, + days=interval.start + ), + end=WindowBound( + coeff=1, + days=interval.end + ) + ) + + return CorelatedCriteria( + criteria=query_criteria, + start_window=start_window, + 
occurrence=occurrence, + restrict_visit=criteria.aperture.restrict_visit if criteria.aperture else False, + ignore_observation_period=criteria.aperture.ignore_observation_period if criteria.aperture else False + ) + + +def _build_end_strategy(exit_strat: ExitStrategy): + """Convert ExitStrategy to circe end strategy.""" + if exit_strat.strategy_type == "date_offset": + return DateOffsetStrategy( + date_field=exit_strat.offset_field, + offset=exit_strat.offset_days + ) + elif exit_strat.strategy_type == "custom_era": + return CustomEraStrategy( + drug_codeset_id=exit_strat.drug_codeset_id, + gap_days=exit_strat.custom_era_gap_days, + offset=0 + ) + # Observation exit returns None (default behavior) + return None + + +def _query_to_criteria(query: Query): + """Convert a Query to the appropriate domain Criteria object.""" + + domain_map = { + 'ConditionOccurrence': ConditionOccurrence, + 'ConditionEra': ConditionEra, + 'DrugExposure': DrugExposure, + 'DrugEra': DrugEra, + 'DoseEra': DoseEra, + 'ProcedureOccurrence': ProcedureOccurrence, + 'Measurement': Measurement, + 'Observation': Observation, + 'VisitOccurrence': VisitOccurrence, + 'VisitDetail': VisitDetail, + 'DeviceExposure': DeviceExposure, + 'Specimen': Specimen, + 'Death': Death, + 'ObservationPeriod': ObservationPeriod, + 'PayerPlanPeriod': PayerPlanPeriod, + 'LocationRegion': LocationRegion + } + + criteria_class = domain_map.get(query.domain) + if not criteria_class: + raise ValueError(f"Unknown domain: {query.domain}") + + # Build kwargs from query + kwargs = { + 'codeset_id': query.concept_set_id, + 'first': query.first_occurrence if query.first_occurrence else None + } + + # Add domain-specific options + options = query.criteria_options + if 'age' in options: + age = options['age'] + if isinstance(age, tuple): + kwargs['age'] = NumericRange(value=age[0], op='gte') + else: + kwargs['age'] = NumericRange(value=age, op='gte') + + # Filter out None values + kwargs = {k: v for k, v in kwargs.items() if v is 
not None} + + return criteria_class(**kwargs) diff --git a/circe/capr/criteria.py b/circe/capr/criteria.py new file mode 100644 index 0000000..5eab0db --- /dev/null +++ b/circe/capr/criteria.py @@ -0,0 +1,247 @@ +""" +Criteria functions for CohortComposer. + +These functions define occurrence counting and grouping logic for cohort criteria. +Modeled after OHDSI/Capr's atLeast, atMost, exactly, and group functions. +""" + +from typing import Optional, List, Union +from dataclasses import dataclass, field +from circe.capr.query import Query +from circe.capr.window import TimeWindow + + +@dataclass +class Criteria: + """ + Represents a criteria object that counts occurrences of a query within a time window. + + Attributes: + occurrence_type: 'atLeast', 'atMost', or 'exactly' + count: Number of occurrences required + query: The domain query to count + aperture: Time window for the search + is_distinct: Whether to count distinct occurrences + """ + occurrence_type: str # 'atLeast', 'atMost', 'exactly' + count: int + query: Query + aperture: Optional[TimeWindow] = None + is_distinct: bool = False + + +@dataclass +class CriteriaGroup: + """ + Represents a group of criteria combined with AND/OR logic. + + Attributes: + group_type: 'ALL' (AND) or 'ANY' (OR) + criteria_list: List of Criteria or nested CriteriaGroup objects + demographic_criteria: Optional demographic restrictions + """ + group_type: str # 'ALL' or 'ANY' + criteria_list: List[Union[Criteria, 'CriteriaGroup']] = field(default_factory=list) + demographic_criteria: Optional[dict] = None + + +def at_least( + count: int, + query: Query, + aperture: Optional[TimeWindow] = None, + is_distinct: bool = False +) -> Criteria: + """ + Create a criteria requiring at least N occurrences. 
+ + Args: + count: Minimum number of occurrences required (>=) + query: Domain query to search for + aperture: Time window for the search + is_distinct: If True, count distinct occurrences + + Returns: + Criteria object + + Example: + >>> # At least 2 drug exposures within 365 days before index + >>> at_least( + ... count=2, + ... query=drug_exposure(concept_set_id=1), + ... aperture=during_interval(event_starts(before=365, after=0)) + ... ) + """ + return Criteria( + occurrence_type='atLeast', + count=count, + query=query, + aperture=aperture, + is_distinct=is_distinct + ) + + +def at_most( + count: int, + query: Query, + aperture: Optional[TimeWindow] = None, + is_distinct: bool = False +) -> Criteria: + """ + Create a criteria requiring at most N occurrences. + + Args: + count: Maximum number of occurrences allowed (<=) + query: Domain query to search for + aperture: Time window for the search + is_distinct: If True, count distinct occurrences + + Returns: + Criteria object + + Example: + >>> # At most 1 prior condition + >>> at_most( + ... count=1, + ... query=condition_occurrence(concept_set_id=2), + ... aperture=during_interval(event_starts(before=365, after=0)) + ... ) + """ + return Criteria( + occurrence_type='atMost', + count=count, + query=query, + aperture=aperture, + is_distinct=is_distinct + ) + + +def exactly( + count: int, + query: Query, + aperture: Optional[TimeWindow] = None, + is_distinct: bool = False +) -> Criteria: + """ + Create a criteria requiring exactly N occurrences. + + Args: + count: Exact number of occurrences required (==) + query: Domain query to search for + aperture: Time window for the search + is_distinct: If True, count distinct occurrences + + Returns: + Criteria object + + Example: + >>> # Exactly 0 prior drug exposures (exclusion) + >>> exactly( + ... count=0, + ... query=drug_exposure(concept_set_id=3), + ... aperture=during_interval(event_starts(before=365, after=1)) + ... 
) + """ + return Criteria( + occurrence_type='exactly', + count=count, + query=query, + aperture=aperture, + is_distinct=is_distinct + ) + + +def with_all(*criteria: Union[Criteria, 'CriteriaGroup']) -> CriteriaGroup: + """ + Combine criteria with AND logic - all must be satisfied. + + Args: + *criteria: Variable number of Criteria or CriteriaGroup objects + + Returns: + CriteriaGroup with ALL (AND) logic + + Example: + >>> # Must have both conditions + >>> with_all( + ... at_least(1, condition_occurrence(1), aperture), + ... at_least(1, drug_exposure(2), aperture) + ... ) + """ + return CriteriaGroup( + group_type='ALL', + criteria_list=list(criteria) + ) + + +def with_any(*criteria: Union[Criteria, 'CriteriaGroup']) -> CriteriaGroup: + """ + Combine criteria with OR logic - at least one must be satisfied. + + Args: + *criteria: Variable number of Criteria or CriteriaGroup objects + + Returns: + CriteriaGroup with ANY (OR) logic + + Example: + >>> # Must have at least one of these conditions + >>> with_any( + ... at_least(1, condition_occurrence(1), aperture), + ... at_least(1, condition_occurrence(2), aperture) + ... ) + """ + return CriteriaGroup( + group_type='ANY', + criteria_list=list(criteria) + ) + + +# Convenience function for exclusions +def none_of( + query: Query, + aperture: Optional[TimeWindow] = None +) -> Criteria: + """ + Convenience function for excluding patients with any occurrence. + Equivalent to exactly(0, query, aperture). + + Args: + query: Domain query that should have zero occurrences + aperture: Time window for the search + + Returns: + Criteria requiring exactly 0 occurrences + + Example: + >>> # Exclude patients with prior insulin + >>> none_of( + ... drug_exposure(concept_set_id=3), + ... aperture=during_interval(event_starts(before=365, after=1)) + ... 
) + """ + return exactly(count=0, query=query, aperture=aperture) + + +def any_of( + query: Query, + aperture: Optional[TimeWindow] = None +) -> Criteria: + """ + Convenience function for requiring at least one occurrence. + Equivalent to at_least(1, query, aperture). + + Args: + query: Domain query that should have at least one occurrence + aperture: Time window for the search + + Returns: + Criteria requiring at least 1 occurrence + + Example: + >>> # Require at least one prior diagnosis + >>> any_of( + ... condition_occurrence(concept_set_id=1), + ... aperture=during_interval(event_starts(before=365, after=0)) + ... ) + """ + return at_least(count=1, query=query, aperture=aperture) diff --git a/circe/capr/query.py b/circe/capr/query.py new file mode 100644 index 0000000..432fedc --- /dev/null +++ b/circe/capr/query.py @@ -0,0 +1,837 @@ +""" +Domain query functions for CohortComposer. + +Each function creates a query object that specifies which concepts to extract +from which OMOP CDM domain table. These are used as entry events or within +criteria for attrition rules. + +Modeled after OHDSI/Capr R package query functions. +""" + +from typing import Optional, List, Union, Any +from dataclasses import dataclass, field + + +@dataclass +class Query: + """ + Base query object representing a domain-specific event search. 
+ + Attributes: + domain: The OMOP CDM domain (e.g., 'ConditionOccurrence') + concept_set_id: ID of the concept set to search for + first_occurrence: If True, only use the first occurrence per person + criteria_options: Additional domain-specific filter options + """ + domain: str + concept_set_id: Optional[int] = None + first_occurrence: bool = False + criteria_options: dict = field(default_factory=dict) + + def first(self) -> 'Query': + """Return a copy with first_occurrence=True.""" + return Query( + domain=self.domain, + concept_set_id=self.concept_set_id, + first_occurrence=True, + criteria_options=self.criteria_options.copy() + ) + + +def condition_occurrence( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + condition_status: Optional[List[int]] = None, + condition_type: Optional[List[int]] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for condition occurrences. 
+ + Args: + concept_set_id: ID of the concept set defining the conditions + first_occurrence: If True, only use first occurrence per person + age: Tuple of (min, max) age at occurrence, or single value for minimum + gender: List of gender concept IDs + condition_status: List of condition status concept IDs + condition_type: List of condition type concept IDs + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for condition occurrences + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if condition_status is not None: + options['condition_status'] = condition_status + if condition_type is not None: + options['condition_type'] = condition_type + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='ConditionOccurrence', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def condition_era( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + era_length: Optional[tuple] = None, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for condition eras. 
+ + Args: + concept_set_id: ID of the concept set defining the conditions + first_occurrence: If True, only use first era per person + era_length: Tuple of (min, max) era length in days + age: Tuple of (min, max) age at era start + gender: List of gender concept IDs + + Returns: + Query object for condition eras + """ + options = {} + if era_length is not None: + options['era_length'] = era_length + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + options.update(kwargs) + + return Query( + domain='ConditionEra', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def drug_exposure( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + drug_type: Optional[List[int]] = None, + route: Optional[List[int]] = None, + dose_unit: Optional[List[int]] = None, + days_supply: Optional[tuple] = None, + quantity: Optional[tuple] = None, + refills: Optional[tuple] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for drug exposures. 
+ + Args: + concept_set_id: ID of the concept set defining the drugs + first_occurrence: If True, only use first exposure per person + age: Tuple of (min, max) age at exposure + gender: List of gender concept IDs + drug_type: List of drug type concept IDs + route: List of route concept IDs + dose_unit: List of dose unit concept IDs + days_supply: Tuple of (min, max) days supply + quantity: Tuple of (min, max) quantity + refills: Tuple of (min, max) refills + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for drug exposures + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if drug_type is not None: + options['drug_type'] = drug_type + if route is not None: + options['route'] = route + if dose_unit is not None: + options['dose_unit'] = dose_unit + if days_supply is not None: + options['days_supply'] = days_supply + if quantity is not None: + options['quantity'] = quantity + if refills is not None: + options['refills'] = refills + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='DrugExposure', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def drug_era( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + era_length: Optional[tuple] = None, + gap_days: Optional[tuple] = None, + occurrence_count: Optional[tuple] = None, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for drug eras. 
+ + Args: + concept_set_id: ID of the concept set defining the drugs + first_occurrence: If True, only use first era per person + era_length: Tuple of (min, max) era length in days + gap_days: Tuple of (min, max) gap days + occurrence_count: Tuple of (min, max) occurrence count within era + age: Tuple of (min, max) age at era start + gender: List of gender concept IDs + + Returns: + Query object for drug eras + """ + options = {} + if era_length is not None: + options['era_length'] = era_length + if gap_days is not None: + options['gap_days'] = gap_days + if occurrence_count is not None: + options['occurrence_count'] = occurrence_count + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + options.update(kwargs) + + return Query( + domain='DrugEra', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def dose_era( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + era_length: Optional[tuple] = None, + unit: Optional[List[int]] = None, + dose_value: Optional[tuple] = None, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for dose eras. 
+ + Args: + concept_set_id: ID of the concept set defining the drugs + first_occurrence: If True, only use first era per person + era_length: Tuple of (min, max) era length in days + unit: List of unit concept IDs + dose_value: Tuple of (min, max) dose value + age: Tuple of (min, max) age at era start + gender: List of gender concept IDs + + Returns: + Query object for dose eras + """ + options = {} + if era_length is not None: + options['era_length'] = era_length + if unit is not None: + options['unit'] = unit + if dose_value is not None: + options['dose_value'] = dose_value + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + options.update(kwargs) + + return Query( + domain='DoseEra', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def procedure( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + procedure_type: Optional[List[int]] = None, + modifier: Optional[List[int]] = None, + quantity: Optional[tuple] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for procedure occurrences. 
+ + Args: + concept_set_id: ID of the concept set defining the procedures + first_occurrence: If True, only use first procedure per person + age: Tuple of (min, max) age at procedure + gender: List of gender concept IDs + procedure_type: List of procedure type concept IDs + modifier: List of modifier concept IDs + quantity: Tuple of (min, max) quantity + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for procedure occurrences + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if procedure_type is not None: + options['procedure_type'] = procedure_type + if modifier is not None: + options['modifier'] = modifier + if quantity is not None: + options['quantity'] = quantity + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='ProcedureOccurrence', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def measurement( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + measurement_type: Optional[List[int]] = None, + operator: Optional[List[int]] = None, + value_as_number: Optional[tuple] = None, + value_as_concept: Optional[List[int]] = None, + unit: Optional[List[int]] = None, + range_low: Optional[tuple] = None, + range_high: Optional[tuple] = None, + abnormal: Optional[bool] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for measurements. 
+ + Args: + concept_set_id: ID of the concept set defining the measurements + first_occurrence: If True, only use first measurement per person + age: Tuple of (min, max) age at measurement + gender: List of gender concept IDs + measurement_type: List of measurement type concept IDs + operator: List of operator concept IDs + value_as_number: Tuple of (min, max) numeric value + value_as_concept: List of value as concept IDs + unit: List of unit concept IDs + range_low: Tuple of (min, max) range low + range_high: Tuple of (min, max) range high + abnormal: If True, only include abnormal values + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for measurements + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if measurement_type is not None: + options['measurement_type'] = measurement_type + if operator is not None: + options['operator'] = operator + if value_as_number is not None: + options['value_as_number'] = value_as_number + if value_as_concept is not None: + options['value_as_concept'] = value_as_concept + if unit is not None: + options['unit'] = unit + if range_low is not None: + options['range_low'] = range_low + if range_high is not None: + options['range_high'] = range_high + if abnormal is not None: + options['abnormal'] = abnormal + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='Measurement', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def observation( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + observation_type: Optional[List[int]] = None, + value_as_number: Optional[tuple] = None, + 
value_as_string: Optional[str] = None, + value_as_concept: Optional[List[int]] = None, + qualifier: Optional[List[int]] = None, + unit: Optional[List[int]] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for observations. + + Args: + concept_set_id: ID of the concept set defining the observations + first_occurrence: If True, only use first observation per person + age: Tuple of (min, max) age at observation + gender: List of gender concept IDs + observation_type: List of observation type concept IDs + value_as_number: Tuple of (min, max) numeric value + value_as_string: String pattern to match + value_as_concept: List of value as concept IDs + qualifier: List of qualifier concept IDs + unit: List of unit concept IDs + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for observations + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if observation_type is not None: + options['observation_type'] = observation_type + if value_as_number is not None: + options['value_as_number'] = value_as_number + if value_as_string is not None: + options['value_as_string'] = value_as_string + if value_as_concept is not None: + options['value_as_concept'] = value_as_concept + if qualifier is not None: + options['qualifier'] = qualifier + if unit is not None: + options['unit'] = unit + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='Observation', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def visit( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: 
Optional[List[int]] = None, + visit_type: Optional[List[int]] = None, + visit_length: Optional[tuple] = None, + place_of_service: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for visit occurrences. + + Args: + concept_set_id: ID of the concept set defining the visits + first_occurrence: If True, only use first visit per person + age: Tuple of (min, max) age at visit + gender: List of gender concept IDs + visit_type: List of visit type concept IDs + visit_length: Tuple of (min, max) visit length in days + place_of_service: List of place of service concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for visit occurrences + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if visit_type is not None: + options['visit_type'] = visit_type + if visit_length is not None: + options['visit_length'] = visit_length + if place_of_service is not None: + options['place_of_service'] = place_of_service + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='VisitOccurrence', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def visit_detail( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + visit_detail_type: Optional[List[int]] = None, + visit_detail_length: Optional[tuple] = None, + place_of_service: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for visit details. 
+ + Args: + concept_set_id: ID of the concept set defining the visit details + first_occurrence: If True, only use first visit detail per person + age: Tuple of (min, max) age at visit detail + gender: List of gender concept IDs + visit_detail_type: List of visit detail type concept IDs + visit_detail_length: Tuple of (min, max) visit detail length in days + place_of_service: List of place of service concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for visit details + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if visit_detail_type is not None: + options['visit_detail_type'] = visit_detail_type + if visit_detail_length is not None: + options['visit_detail_length'] = visit_detail_length + if place_of_service is not None: + options['place_of_service'] = place_of_service + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='VisitDetail', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def device_exposure( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + device_type: Optional[List[int]] = None, + unique_device_id: Optional[str] = None, + quantity: Optional[tuple] = None, + visit_type: Optional[List[int]] = None, + provider_specialty: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for device exposures. 
+ + Args: + concept_set_id: ID of the concept set defining the devices + first_occurrence: If True, only use first device exposure per person + age: Tuple of (min, max) age at device exposure + gender: List of gender concept IDs + device_type: List of device type concept IDs + unique_device_id: Device identifier pattern + quantity: Tuple of (min, max) quantity + visit_type: List of visit type concept IDs + provider_specialty: List of provider specialty concept IDs + + Returns: + Query object for device exposures + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if device_type is not None: + options['device_type'] = device_type + if unique_device_id is not None: + options['unique_device_id'] = unique_device_id + if quantity is not None: + options['quantity'] = quantity + if visit_type is not None: + options['visit_type'] = visit_type + if provider_specialty is not None: + options['provider_specialty'] = provider_specialty + options.update(kwargs) + + return Query( + domain='DeviceExposure', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def specimen( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + specimen_type: Optional[List[int]] = None, + anatomic_site: Optional[List[int]] = None, + disease_status: Optional[List[int]] = None, + unit: Optional[List[int]] = None, + quantity: Optional[tuple] = None, + **kwargs +) -> Query: + """ + Create a query for specimens. 
+ + Args: + concept_set_id: ID of the concept set defining the specimens + first_occurrence: If True, only use first specimen per person + age: Tuple of (min, max) age at specimen collection + gender: List of gender concept IDs + specimen_type: List of specimen type concept IDs + anatomic_site: List of anatomic site concept IDs + disease_status: List of disease status concept IDs + unit: List of unit concept IDs + quantity: Tuple of (min, max) quantity + + Returns: + Query object for specimens + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if specimen_type is not None: + options['specimen_type'] = specimen_type + if anatomic_site is not None: + options['anatomic_site'] = anatomic_site + if disease_status is not None: + options['disease_status'] = disease_status + if unit is not None: + options['unit'] = unit + if quantity is not None: + options['quantity'] = quantity + options.update(kwargs) + + return Query( + domain='Specimen', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def death( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + death_type: Optional[List[int]] = None, + **kwargs +) -> Query: + """ + Create a query for death events. 
+ + Args: + concept_set_id: ID of the concept set defining the cause of death (optional) + first_occurrence: If True, only use first death event per person + age: Tuple of (min, max) age at death + gender: List of gender concept IDs + death_type: List of death type concept IDs + + Returns: + Query object for death events + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if death_type is not None: + options['death_type'] = death_type + options.update(kwargs) + + return Query( + domain='Death', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def observation_period( + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + period_type: Optional[List[int]] = None, + period_length: Optional[tuple] = None, + user_defined_period: Optional[tuple] = None, + **kwargs +) -> Query: + """ + Create a query for observation periods. 
+ + Args: + first_occurrence: If True, only use first observation period per person + age: Tuple of (min, max) age at period start + gender: List of gender concept IDs + period_type: List of period type concept IDs + period_length: Tuple of (min, max) period length in days + user_defined_period: Tuple of (start_date, end_date) to filter periods + + Returns: + Query object for observation periods + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if period_type is not None: + options['period_type'] = period_type + if period_length is not None: + options['period_length'] = period_length + if user_defined_period is not None: + options['user_defined_period'] = user_defined_period + options.update(kwargs) + + return Query( + domain='ObservationPeriod', + concept_set_id=None, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def payer_plan_period( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + age: Optional[tuple] = None, + gender: Optional[List[int]] = None, + payer_concept: Optional[List[int]] = None, + plan_concept: Optional[List[int]] = None, + period_length: Optional[tuple] = None, + **kwargs +) -> Query: + """ + Create a query for payer plan periods. 
+ + Args: + concept_set_id: ID of the concept set defining the payer/plan + first_occurrence: If True, only use first payer plan period per person + age: Tuple of (min, max) age at period start + gender: List of gender concept IDs + payer_concept: List of payer concept IDs + plan_concept: List of plan concept IDs + period_length: Tuple of (min, max) period length in days + + Returns: + Query object for payer plan periods + """ + options = {} + if age is not None: + options['age'] = age + if gender is not None: + options['gender'] = gender + if payer_concept is not None: + options['payer_concept'] = payer_concept + if plan_concept is not None: + options['plan_concept'] = plan_concept + if period_length is not None: + options['period_length'] = period_length + options.update(kwargs) + + return Query( + domain='PayerPlanPeriod', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=options + ) + + +def location_region( + concept_set_id: Optional[int] = None, + first_occurrence: bool = False, + **kwargs +) -> Query: + """ + Create a query for location regions. 
+ + Args: + concept_set_id: ID of the concept set defining the regions + first_occurrence: If True, only use first location per person + + Returns: + Query object for location regions + """ + return Query( + domain='LocationRegion', + concept_set_id=concept_set_id, + first_occurrence=first_occurrence, + criteria_options=kwargs + ) diff --git a/circe/capr/templates.py b/circe/capr/templates.py new file mode 100644 index 0000000..1cb2337 --- /dev/null +++ b/circe/capr/templates.py @@ -0,0 +1,140 @@ +from typing import Optional, List +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition import CohortExpression +from circe.vocabulary import ConceptSet + + +def sensitive_disease_cohort( + concept_set_id: int, + title: Optional[str] = None, + observation_prior_days: int = 0, + observation_post_days: int = 0, + concept_sets: Optional[List[ConceptSet]] = None +) -> CohortExpression: + """ + Create a sensitive disease cohort - all occurrences from a concept set. + """ + builder = ( + CohortBuilder(title or "Sensitive Disease Cohort") + .with_condition(concept_set_id) + .with_observation(prior_days=observation_prior_days, post_days=observation_post_days) + .all_occurrences() + ) + + if concept_sets: + builder.with_concept_sets(*concept_sets) + + return builder.build() + + +def specific_disease_cohort( + concept_set_id: int, + confirmation_days: int = 30, + inpatient_visit_concept_set_id: Optional[int] = None, + title: Optional[str] = None, + observation_prior_days: int = 365, + concept_sets: Optional[List[ConceptSet]] = None +) -> CohortExpression: + """ + Create a specific disease cohort requiring confirmation. 
+ """ + builder = ( + CohortBuilder(title or "Specific Disease Cohort") + .with_condition(concept_set_id) + .first_occurrence() + .with_observation(prior_days=observation_prior_days) + ) + + if concept_sets: + builder.with_concept_sets(*concept_sets) + + # confirmation criteria + builder = builder.require_condition(concept_set_id).within_days_after(confirmation_days) + + return builder.build() + + +def acute_disease_cohort( + concept_set_id: int, + washout_days: int = 180, + title: Optional[str] = None, + observation_prior_days: int = 0, + concept_sets: Optional[List[ConceptSet]] = None +) -> CohortExpression: + """ + Create an acute disease cohort with washout period. + """ + if observation_prior_days == 0: + observation_prior_days = washout_days + + builder = ( + CohortBuilder(title or "Acute Disease Cohort") + .with_condition(concept_set_id) + .all_occurrences() + .with_observation(prior_days=observation_prior_days) + ) + + if concept_sets: + builder.with_concept_sets(*concept_sets) + + builder = builder.exclude_condition(concept_set_id).within_days_before(washout_days) + + return builder.build() + + +def chronic_disease_cohort( + concept_set_id: int, + lookback_days: int = 365, + title: Optional[str] = None, + concept_sets: Optional[List[ConceptSet]] = None +) -> CohortExpression: + """ + Create a chronic disease cohort - first ever diagnosis. 
+ """ + builder = ( + CohortBuilder(title or "Chronic Disease Cohort") + .with_condition(concept_set_id) + .first_occurrence() + .with_observation(prior_days=lookback_days) + ) + + if concept_sets: + builder.with_concept_sets(*concept_sets) + + builder = builder.exclude_condition(concept_set_id).anytime_before() + + return builder.build() + + +def new_user_drug_cohort( + drug_concept_set_id: int, + washout_days: int = 365, + indication_concept_set_id: Optional[int] = None, + indication_lookback_days: int = 365, + title: Optional[str] = None, + observation_prior_days: int = 0, + concept_sets: Optional[List[ConceptSet]] = None +) -> CohortExpression: + """ + Create a new user drug cohort with clean washout. + """ + if observation_prior_days == 0: + observation_prior_days = washout_days + + builder = ( + CohortBuilder(title or "New User Drug Cohort") + .with_drug_era(drug_concept_set_id) + .first_occurrence() + .with_observation(prior_days=observation_prior_days) + ) + + builder = builder.exclude_drug(drug_concept_set_id).within_days_before(washout_days) + + if indication_concept_set_id: + builder = builder.require_condition(indication_concept_set_id).within_days_before(indication_lookback_days) + + if concept_sets: + builder.with_concept_sets(*concept_sets) + + return builder.build() + diff --git a/circe/capr/window.py b/circe/capr/window.py new file mode 100644 index 0000000..8a68367 --- /dev/null +++ b/circe/capr/window.py @@ -0,0 +1,219 @@ +""" +Time window functions for CohortComposer. + +These functions define the temporal relationship between events and an index point. +Modeled after OHDSI/Capr's aperture and window functions. +""" + +from typing import Optional, Union +from dataclasses import dataclass + + +@dataclass +class WindowBoundary: + """ + Represents a boundary point for a time window. 
+ + Attributes: + days: Number of days from the index point + index: Which date to use as reference ('startDate' or 'endDate') + """ + days: int + index: str = "startDate" + + +@dataclass +class TimeWindow: + """ + Represents a time window (aperture) relative to an index event. + + Attributes: + start_window: The start boundary of the window + end_window: The end boundary of the window (optional) + use_event_end: Whether to use event end date instead of start date + restrict_visit: Whether to restrict to same visit + ignore_observation_period: Whether to ignore observation period bounds + """ + start_window: Optional['Interval'] = None + end_window: Optional['Interval'] = None + use_event_end: bool = False + restrict_visit: bool = False + ignore_observation_period: bool = False + + +@dataclass +class Interval: + """ + Represents an interval with start and end boundaries. + + Attributes: + start: Start boundary (days before index, positive = before) + end: End boundary (days after index, positive = after) + index: Which date to use as reference ('startDate' or 'endDate') + """ + start: int # days before (positive = before index) + end: int # days after (positive = after index) + index: str = "startDate" + + +@dataclass +class ObservationWindow: + """ + Represents the continuous observation requirement. + + Attributes: + prior_days: Days of observation required before index + post_days: Days of observation required after index + """ + prior_days: int = 0 + post_days: int = 0 + + +def event_starts( + before: int = 0, + after: int = 0, + index: str = "startDate" +) -> Interval: + """ + Create an interval relative to the event start date. 
+ + Args: + before: Days before the index (positive value) + after: Days after the index (positive value) + index: Which date to use as reference ('startDate' or 'endDate') + + Returns: + Interval object + + Example: + >>> # 365 days before to 0 days after index start + >>> event_starts(before=365, after=0) + + >>> # 0 to 30 days after index start + >>> event_starts(before=0, after=30) + """ + return Interval( + start=before, + end=after, + index=index + ) + + +def event_ends( + before: int = 0, + after: int = 0, + index: str = "endDate" +) -> Interval: + """ + Create an interval relative to the event end date. + + Args: + before: Days before the index end (positive value) + after: Days after the index end (positive value) + index: Which date to use as reference (default: 'endDate') + + Returns: + Interval object + + Example: + >>> # 0 to 30 days after index end + >>> event_ends(before=0, after=30) + """ + return Interval( + start=before, + end=after, + index=index + ) + + +def during_interval( + start_window: Optional[Interval] = None, + end_window: Optional[Interval] = None, + use_event_end: bool = False, + restrict_visit: bool = False, + ignore_observation_period: bool = False +) -> TimeWindow: + """ + Create a time window (aperture) for criteria matching. + + Args: + start_window: Interval for the start of the window + end_window: Interval for the end of the window (optional) + use_event_end: If True, use event end date for matching + restrict_visit: If True, require same visit + ignore_observation_period: If True, ignore observation period bounds + + Returns: + TimeWindow object + + Example: + >>> # Events occurring 365 days before to 0 days after index + >>> during_interval( + ... start_window=event_starts(before=365, after=0) + ... ) + + >>> # Events occurring 0 to 30 days after index + >>> during_interval( + ... start_window=event_starts(before=0, after=30) + ... 
) + """ + return TimeWindow( + start_window=start_window, + end_window=end_window, + use_event_end=use_event_end, + restrict_visit=restrict_visit, + ignore_observation_period=ignore_observation_period + ) + + +def continuous_observation( + prior_days: int = 0, + post_days: int = 0 +) -> ObservationWindow: + """ + Define continuous observation requirements around the index event. + + Args: + prior_days: Days of continuous observation required before index + post_days: Days of continuous observation required after index + + Returns: + ObservationWindow object + + Example: + >>> # Require 365 days of observation before index + >>> continuous_observation(prior_days=365) + + >>> # Require 365 days before and 180 days after + >>> continuous_observation(prior_days=365, post_days=180) + """ + return ObservationWindow( + prior_days=prior_days, + post_days=post_days + ) + + +# Convenience aliases for common patterns +def anytime_before(index: str = "startDate") -> Interval: + """Events any time before the index (no lower bound).""" + return Interval(start=99999, end=1, index=index) + + +def anytime_after(index: str = "startDate") -> Interval: + """Events any time after the index (no upper bound).""" + return Interval(start=0, end=99999, index=index) + + +def same_day(index: str = "startDate") -> Interval: + """Events on the same day as the index.""" + return Interval(start=0, end=0, index=index) + + +def within_days_before(days: int, index: str = "startDate") -> Interval: + """Events within N days before the index (exclusive of index day).""" + return Interval(start=days, end=1, index=index) + + +def within_days_after(days: int, index: str = "startDate") -> Interval: + """Events within N days after the index (exclusive of index day).""" + return Interval(start=0, end=days, index=index) diff --git a/circe/chat.py b/circe/chat.py deleted file mode 100644 index 60b9afb..0000000 --- a/circe/chat.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -Chat module for interacting with LLMs to 
generate cohort definitions. -""" -import sys -import os -import json -import re -from pathlib import Path -from typing import Optional, List, Dict, Any - -from circe.prompt_builder import CohortPromptBuilder, ConceptSet - -def chat_command(args): - """ - Entry point for the chat command. - """ - start_chat( - model=args.model, - prompt_type=args.prompt_type, - output=args.output, - concept_sets_file=args.concept_sets, - input_file=args.input_file - ) - return 0 - -def start_chat( - model: Optional[str], - prompt_type: str, - output: Optional[str], - concept_sets_file: Optional[str], - input_file: Optional[str] = None -): - """ - Start the interactive chat session. - """ - # Check dependencies - try: - import litellm - from dotenv import load_dotenv - except ImportError: - print("Error: 'litellm' and 'python-dotenv' are required for chat functionality.", file=sys.stderr) - print("Please install them with: pip install litellm python-dotenv", file=sys.stderr) - return 1 - - # Load environment variables - load_dotenv() - - # Determine model - if not model: - model = os.getenv("LLM_MODEL", "gpt-4o") - # Handle optional temperature if needed, but litellm handles it or we pass it - - print(f"🚀 Starting Circe Chat") - print(f" Model: {model}") - print(f" Prompt: {prompt_type}") - print("-" * 50) - - # Load concept sets if provided - concept_sets_data = [] - if concept_sets_file: - try: - with open(concept_sets_file, 'r') as f: - raw_data = json.load(f) - # Expecting list of dicts with id, name - for item in raw_data: - concept_sets_data.append(ConceptSet( - id=item.get('id'), - name=item.get('name'), - description=item.get('description') - )) - print(f" Loaded {len(concept_sets_data)} concept sets from {concept_sets_file}") - except Exception as e: - print(f"Error loading concept sets: {e}", file=sys.stderr) - return 1 - - # Initialize builder - builder = CohortPromptBuilder() - - try: - system_prompt = builder.load_system_prompt(prompt_type) - except Exception as e: - 
print(f"Error loading system prompt: {e}", file=sys.stderr) - return 1 - - # Add inference instruction if no concept sets provided - if not concept_sets_data: - system_prompt += "\n\nIMPORTANT: No concept sets were provided.\n" \ - "You MUST infer appropriate concept sets from the clinical description.\n" \ - "1. Define them using `circe.vocabulary.concept_set`.\n" \ - "2. Add them to the builder using `.with_concept_sets(...)`.\n" \ - "3. Use valid OMOP Concept IDs (or realistic placeholders if exact IDs are unknown)." - - messages = [{"role": "system", "content": system_prompt}] - - print("\nPlease describe the cohort you want to build (or type 'quit' to exit):") - - first_turn = True - initial_input = None - - if input_file: - try: - initial_input = Path(input_file).read_text() - print(f" Loaded clinical description from {input_file}") - except Exception as e: - print(f"Error reading input file: {e}", file=sys.stderr) - return 1 - - while True: - try: - if first_turn and initial_input: - user_input = initial_input - print(f"\n> [Processing input from file...]") - else: - user_input = input("\n> ") - except (EOFError, KeyboardInterrupt): - print("\nExiting chat.") - break - - if user_input.lower() in ('quit', 'exit'): - break - - if not user_input.strip(): - continue - - # Turn off first_turn flag after we have a valid input - if first_turn: - first_turn = False - - # Construct user message - if len(messages) == 1: - # First user message - format nicely - formatted_content = f"\n---\n## User Task\n**Clinical Description:**\n{user_input}\n" - if concept_sets_data: - formatted_content += builder.format_concept_sets(concept_sets_data) - else: - formatted_content += "\nNo pre-defined concept sets provided. Please infer them." 
- - messages.append({"role": "user", "content": formatted_content}) - else: - messages.append({"role": "user", "content": user_input}) - - # Call AI - print("Thinking...") - try: - response = litellm.completion(model=model, messages=messages) - content = response.choices[0].message.content - print("\n" + content) - - messages.append({"role": "assistant", "content": content}) - - # Extract and process code - _process_response_content(content, output) - - except Exception as e: - print(f"\nError during API call: {e}", file=sys.stderr) - - -def _process_response_content(content: str, output_base: Optional[str]): - """ - Extract logic to find Python code, save it, and attempt to run it to generate JSON. - """ - # Look for python code block - code_match = re.search(r'```python\n(.*?)\n```', content, re.DOTALL) - if not code_match: - return - - code = code_match.group(1) - - # Determine output filenames - if output_base: - py_file = Path(output_base + ".py") - json_file = Path(output_base + ".json") - else: - # Default name - py_file = Path("cohort_definition.py") - json_file = Path("cohort_definition.json") - - # Save Python code - try: - py_file.write_text(code) - print(f"\n✅ Saved Python code to {py_file}") - except Exception as e: - print(f"Error saving Python file: {e}") - return - - # Attempt to execute and save JSON - # This involves running the code and capturing the 'cohort' variable or 'expression' variable - print(" Attempting to generate JSON...") - - try: - # Create a local scope - local_scope = {} - # We need to make sure the CWD is in path so imports work? 
- # Assuming we are running from project root or installed package - - exec(code, {}, local_scope) - - # Look for a CohortExpression or CohortBuilder object - # The prompt usually produces: - # cohort = CohortBuilder(...).build() - # So we look for 'cohort' - - cohort_obj = local_scope.get('cohort') - if not cohort_obj: - # Try to find any variable that is a tuple (builder) or CohortExpression - for k, v in local_scope.items(): - if hasattr(v, 'to_json'): # CohortExpression has to_json? Check API. - cohort_obj = v - break - - if cohort_obj: - # If it's the builder (tuple in some cases?), checks if it has build() - # But the prompt says `.build()` returns CohortExpression. - - # Check if it has 'to_json' or similar. - # circe.cohortdefinition.CohortExpression uses Pydantic? - # It inherits from Serializable? - - json_output = None - if hasattr(cohort_obj, 'json'): # Pydantic v1/v2 - json_output = cohort_obj.model_dump_json(indent=2) if hasattr(cohort_obj, 'model_dump_json') else cohort_obj.json(indent=2) - elif hasattr(cohort_obj, 'to_json'): - json_output = cohort_obj.to_json() - else: - # It might be a dict? - if isinstance(cohort_obj, dict): - json_output = json.dumps(cohort_obj, indent=2) - - if json_output: - json_file.write_text(json_output) - print(f"✅ Saved Cohort JSON to {json_file}") - else: - print(" Could not serialize 'cohort' object to JSON.") - else: - print(" Could not find 'cohort' variable in executed code.") - - except Exception as e: - print(f" Error executing generated code: {e}") - print(" (Ensure the generated code is valid and all dependencies are installed)") diff --git a/circe/cohort_builder/__init__.py b/circe/cohort_builder/__init__.py new file mode 100644 index 0000000..8315408 --- /dev/null +++ b/circe/cohort_builder/__init__.py @@ -0,0 +1,30 @@ +""" +State-Based Fluent Builder for OHDSI Cohort Definitions. + +This module provides a guided, LLM-friendly API where each method returns +an object with only valid next methods. 
The API guides users through +cohort construction step by step. + +Example: + >>> from circe.cohort_builder import CohortBuilder + >>> + >>> cohort = ( + ... CohortBuilder("T2DM Patients") + ... .with_condition(concept_set_id=1) + ... .first_occurrence() + ... .with_observation(prior_days=365) + ... .require_condition(concept_set_id=1) + ... .within_days_before(365) + ... .exclude_drug(concept_set_id=2) + ... .anytime_before() + ... .build() + ... ) +""" + +from circe.cohort_builder.builder import CohortBuilder, CohortWithEntry, CohortWithCriteria + +__all__ = [ + "CohortBuilder", + "CohortWithEntry", + "CohortWithCriteria" +] diff --git a/circe/cohort_builder/builder.py b/circe/cohort_builder/builder.py new file mode 100644 index 0000000..e7c68b9 --- /dev/null +++ b/circe/cohort_builder/builder.py @@ -0,0 +1,2758 @@ +""" +State-Based Fluent Builder for OHDSI Cohort Definitions. + +This module implements a guided API where each method returns an object +with only valid next methods, making it ideal for LLM agents.
+ +The state progression is: + Cohort -> CohortWithEntry -> CohortWithCriteria -> CohortExpression +""" + +from typing import Optional, List, Union, Dict, Any, TYPE_CHECKING +from dataclasses import dataclass, field +import copy + +from circe.cohort_builder.query_builder import ( + QueryConfig, TimeWindow, BaseQuery, + ConditionQuery, DrugQuery, DrugEraQuery, MeasurementQuery, + ProcedureQuery, VisitQuery, ObservationQuery, DeathQuery, + ConditionEraQuery, DeviceExposureQuery, SpecimenQuery, + ObservationPeriodQuery, PayerPlanPeriodQuery, LocationRegionQuery, + VisitDetailQuery, DoseEraQuery, CriteriaConfig, GroupConfig, CriteriaGroupBuilder +) + +# Import circe models for conversion +from circe.cohortdefinition import ( + CohortExpression, PrimaryCriteria, InclusionRule, + CorelatedCriteria, Occurrence, ConditionOccurrence, ConditionEra, DrugExposure, + ProcedureOccurrence, Measurement, Observation, VisitOccurrence, VisitDetail, + DeviceExposure, Death, DrugEra, DoseEra, Specimen, ObservationPeriod, + PayerPlanPeriod, LocationRegion, DemographicCriteria +) +from circe.cohortdefinition.core import ( + ObservationFilter, ResultLimit, Window, WindowBound, + NumericRange, DateAdjustment +) +from circe.cohortdefinition.criteria import CriteriaGroup as CirceCriteriaGroup +from circe.vocabulary import ConceptSet, Concept + +@dataclass +class CohortSettings: + """Stores cohort-wide settings like exit strategy and era logic.""" + exit_strategy_type: str = "observation" # observation, date_offset + exit_offset_days: int = 0 + exit_offset_field: str = "startDate" + era_days: int = 0 + censor_queries: List[QueryConfig] = field(default_factory=list) + + # Custom Era Strategy + custom_era_drug_codeset_id: Optional[int] = None + custom_era_gap_days: int = 30 + custom_era_offset: int = 0 + custom_era_days_supply_override: Optional[int] = None + + # Demographics + gender_concepts: List[int] = field(default_factory=list) + race_concepts: List[int] = field(default_factory=list) + 
ethnicity_concepts: List[int] = field(default_factory=list) + age_min: Optional[int] = None + age_max: Optional[int] = None + +class CohortBuilder: + """ + Starting point for building a cohort definition. + + Supports both context manager and fluent API patterns. + + Context Manager Example (Recommended): + >>> with CohortBuilder("My Cohort") as cohort: + ... cohort.with_condition(1) + ... cohort.require_drug(2, within_days_before=30) + >>> expression = cohort.expression # Built CohortExpression + + Fluent API Example: + >>> expression = (CohortBuilder("My Cohort") + ... .with_condition(1) + ... .require_drug(2, within_days_before=30) + ... .build()) + """ + + def __init__(self, title: str = "Untitled Cohort"): + """ + Create a new cohort builder. + + Args: + title: Human-readable title for the cohort + """ + self._title = title + self._concept_sets: List[ConceptSet] = [] + # Context manager state + self._state: Optional['CohortWithEntry'] = None + self._expression: Optional[CohortExpression] = None + self._in_context: bool = False + + # ========================================================================= + # CONTEXT MANAGER PROTOCOL + # ========================================================================= + + def __enter__(self) -> 'CohortBuilder': + """Enter context manager mode.""" + self._in_context = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + """Exit context manager and auto-build the cohort.""" + self._in_context = False + if self._state is not None: + self._expression = self._state.build() + return False # Don't suppress exceptions + + @property + def expression(self) -> CohortExpression: + """ + Get the built CohortExpression. + + Only available after exiting the context manager. + + Raises: + RuntimeError: If accessed before context exit or if no entry event defined + """ + if self._expression is None: + if self._in_context: + raise RuntimeError( + "Cannot access 'expression' while inside the context manager. 
" + "Exit the 'with' block first." + ) + raise RuntimeError( + "No cohort has been built. Define at least one entry event " + "(e.g., with_condition, with_drug) inside the context manager." + ) + return self._expression + + def _ensure_state(self) -> Union['CohortWithEntry', 'CohortWithCriteria']: + """Get or create the current state for method chaining within context.""" + if self._state is None: + raise RuntimeError( + "No entry event defined. Call with_condition(), with_drug(), etc. first." + ) + return self._state + + def _update_state(self, new_state: Union['CohortWithEntry', 'CohortWithCriteria', None]) -> None: + """Update internal state after a method that may change state.""" + if new_state is not None: + self._state = new_state + + # ========================================================================= + # CONTEXT MANAGER DELEGATION METHODS + # ========================================================================= + + def first_occurrence(self) -> 'CohortBuilder': + """Only use the first occurrence per person for entry events.""" + self._ensure_state().first_occurrence() + return self + + def all_occurrences(self) -> 'CohortBuilder': + """Use all occurrences per person.""" + self._ensure_state().all_occurrences() + return self + + def with_observation_window(self, prior_days: int = 0, post_days: int = 0) -> 'CohortBuilder': + """Set continuous observation requirements.""" + self._ensure_state().with_observation(prior_days=prior_days, post_days=post_days) + return self + + def min_age(self, age: int) -> 'CohortBuilder': + """Require minimum age at entry.""" + self._ensure_state().min_age(age) + return self + + def max_age(self, age: int) -> 'CohortBuilder': + """Require maximum age at entry.""" + self._ensure_state().max_age(age) + return self + + def require_gender(self, *concept_ids: int) -> 'CohortBuilder': + """Require specific gender concept IDs.""" + self._ensure_state().require_gender(*concept_ids) + return self + + def require_race(self, 
*concept_ids: int) -> 'CohortBuilder': + """Require specific race concept IDs.""" + self._ensure_state().require_race(*concept_ids) + return self + + def require_ethnicity(self, *concept_ids: int) -> 'CohortBuilder': + """Require specific ethnicity concept IDs.""" + self._ensure_state().require_ethnicity(*concept_ids) + return self + + def require_age(self, min_age: Optional[int] = None, max_age: Optional[int] = None) -> 'CohortBuilder': + """Require specific age range.""" + self._ensure_state().require_age(min_age, max_age) + return self + + def require_condition(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required condition occurrence.""" + result = self._ensure_state().require_condition(concept_set_id, **kwargs) + self._update_state(result) + return self + + def require_drug(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required drug exposure.""" + result = self._ensure_state().require_drug(concept_set_id, **kwargs) + self._update_state(result) + return self + + def require_procedure(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required procedure occurrence.""" + result = self._ensure_state().require_procedure(concept_set_id, **kwargs) + self._update_state(result) + return self + + def require_measurement(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required measurement.""" + result = self._ensure_state().require_measurement(concept_set_id, **kwargs) + self._update_state(result) + return self + + def require_observation(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required observation.""" + result = self._ensure_state().require_observation(concept_set_id, **kwargs) + self._update_state(result) + return self + + def require_visit(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Add a required visit occurrence.""" + result = self._ensure_state().require_visit(concept_set_id, **kwargs) + self._update_state(result) + return self + + def 
exclude_condition(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Exclude patients with a condition occurrence.""" + result = self._ensure_state().exclude_condition(concept_set_id, **kwargs) + self._update_state(result) + return self + + def exclude_drug(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Exclude patients with a drug exposure.""" + result = self._ensure_state().exclude_drug(concept_set_id, **kwargs) + self._update_state(result) + return self + + def exclude_procedure(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Exclude patients with a procedure occurrence.""" + result = self._ensure_state().exclude_procedure(concept_set_id, **kwargs) + self._update_state(result) + return self + + def exclude_measurement(self, concept_set_id: int, **kwargs) -> 'CohortBuilder': + """Exclude patients with a measurement.""" + result = self._ensure_state().exclude_measurement(concept_set_id, **kwargs) + self._update_state(result) + return self + + def include_rule(self, name: str) -> 'InclusionRuleContext': + """ + Create a named inclusion rule context. 
+ + Use with a nested 'with' block to group criteria: + + with CohortBuilder("My Cohort") as cohort: + cohort.with_condition(1) + + with cohort.include_rule("Prior Treatment") as rule: + rule.require_drug(2, anytime_before=True) + """ + return InclusionRuleContext(self, name) + + def exit_at_observation_end(self) -> 'CohortBuilder': + """Exit cohort at the end of the observation period.""" + self._ensure_state().exit_at_observation_end() + return self + + def exit_after_days(self, days: int, from_field: str = "startDate") -> 'CohortBuilder': + """Exit cohort N days after index start/end.""" + self._ensure_state().exit_after_days(days, from_field) + return self + + def collapse_era(self, days: int) -> 'CohortBuilder': + """Set the number of gap days to collapse successive cohort entries.""" + self._ensure_state()._to_criteria().collapse_era(days) + return self + + # ========================================================================= + # CONCEPT SET MANAGEMENT + # ========================================================================= + + def with_concept_sets(self, *concept_sets: ConceptSet) -> 'CohortBuilder': + """Add concept sets to the cohort.""" + self._concept_sets.extend(concept_sets) + return self + + # ========================================================================= + # COHORT MODIFICATION METHODS + # ========================================================================= + + @classmethod + def from_expression(cls, expression: CohortExpression, title: Optional[str] = None) -> 'CohortBuilder': + """ + Create a builder from an existing CohortExpression for modification. + + This creates a modifiable copy of the cohort. The original expression is preserved. 
+ + Args: + expression: Existing CohortExpression to modify + title: Optional new title (keeps original if not provided) + + Returns: + CohortBuilder initialized with the existing expression state + + Raises: + ValueError: If expression has no primary criteria + + Example: + >>> # Load existing cohort + >>> existing = CohortExpression.model_validate_json(json_data) + >>> + >>> # Modify it + >>> with CohortBuilder.from_expression(existing) as cohort: + ... cohort.require_drug(5, within_days_before=30) + ... cohort.remove_inclusion_rule("Old Rule") + >>> + >>> modified = cohort.expression + """ + # Create new builder + builder = cls(title or expression.title or "Modified Cohort") + + # Deep copy concept sets to avoid mutations + if expression.concept_sets: + builder._concept_sets = copy.deepcopy(expression.concept_sets) + + # Reconstruct state from expression + if not expression.primary_criteria: + raise ValueError("Cannot modify cohort without primary criteria (entry events)") + + builder._state = cls._reconstruct_state_from_expression(builder, expression) + + return builder + + def remove_inclusion_rule(self, name: str) -> 'CohortBuilder': + """ + Remove an inclusion rule by name. + + Args: + name: Name of the inclusion rule to remove + + Raises: + KeyError: If no rule with the given name exists + RuntimeError: If called before entry event is defined + + Returns: + Self for chaining + + Example: + >>> with CohortBuilder.from_expression(expr) as cohort: + ... 
cohort.remove_inclusion_rule("Prior Treatment") + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + + # Find and remove the rule + found = False + for i, rule in enumerate(criteria_state._rules): + if rule["name"] == name: + criteria_state._rules.pop(i) + found = True + break + + if not found: + raise KeyError(f"No inclusion rule found with name: {name}") + + self._update_state(criteria_state) + return self + + def remove_censoring_criteria(self, + criteria_type: Optional[str] = None, + concept_set_id: Optional[int] = None, + index: Optional[int] = None) -> 'CohortBuilder': + """ + Remove censoring criteria by type, concept set ID, or index. + + Exactly one argument must be provided. + + Args: + criteria_type: Type of criteria (e.g., "ConditionOccurrence", "DrugExposure", "Death") + concept_set_id: Concept set ID to match + index: Zero-based index of criteria to remove + + Raises: + ValueError: If no matching criteria found, multiple arguments provided, or no arguments + RuntimeError: If called before entry event is defined + + Returns: + Self for chaining + + Example: + >>> # Remove by type + >>> cohort.remove_censoring_criteria(criteria_type="Death") + >>> + >>> # Remove by concept set + >>> cohort.remove_censoring_criteria(concept_set_id=5) + >>> + >>> # Remove by index + >>> cohort.remove_censoring_criteria(index=0) + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + + # Validate arguments + args_provided = sum([criteria_type is not None, concept_set_id is not None, index is not None]) + if args_provided == 0: + raise ValueError("Must provide one of: criteria_type, concept_set_id, or index") + if args_provided > 1: + raise ValueError("Can only provide one of: criteria_type, concept_set_id, or index") + + censor_queries = criteria_state._settings.censor_queries + + # Remove by index + if index is not None: + if index < 0 
or index >= len(censor_queries): + raise ValueError(f"Index {index} out of range (0-{len(censor_queries)-1})") + censor_queries.pop(index) + self._update_state(criteria_state) + return self + + # Remove by type or concept set ID + found_idx = None + for i, query in enumerate(censor_queries): + if criteria_type is not None and query.domain == criteria_type: + found_idx = i + break + if concept_set_id is not None and query.concept_set_id == concept_set_id: + found_idx = i + break + + if found_idx is None: + if criteria_type: + raise ValueError(f"No censoring criteria found with type: {criteria_type}") + else: + raise ValueError(f"No censoring criteria found with concept_set_id: {concept_set_id}") + + censor_queries.pop(found_idx) + self._update_state(criteria_state) + return self + + def remove_entry_event(self, + criteria_type: Optional[str] = None, + concept_set_id: Optional[int] = None, + index: Optional[int] = None) -> 'CohortBuilder': + """ + Remove an entry event from primary criteria. + + Note: At least one entry event must remain after removal. + Exactly one argument must be provided. 
+ + Args: + criteria_type: Type of criteria (e.g., "ConditionOccurrence", "DrugExposure") + concept_set_id: Concept set ID to match + index: Zero-based index of entry event to remove + + Raises: + ValueError: If removal would leave no entry events, no match found, or invalid arguments + RuntimeError: If called before entry event is defined + + Returns: + Self for chaining + + Example: + >>> # Remove condition entry event with specific concept set + >>> cohort.remove_entry_event(concept_set_id=1) + >>> + >>> # Remove by type (removes first match) + >>> cohort.remove_entry_event(criteria_type="DrugExposure") + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + + # Validate arguments + args_provided = sum([criteria_type is not None, concept_set_id is not None, index is not None]) + if args_provided == 0: + raise ValueError("Must provide one of: criteria_type, concept_set_id, or index") + if args_provided > 1: + raise ValueError("Can only provide one of: criteria_type, concept_set_id, or index") + + entry_configs = criteria_state._entry_configs + + # Check that we won't remove the last entry event + if len(entry_configs) <= 1: + raise ValueError("Cannot remove the last entry event. 
At least one entry event must remain.") + + # Remove by index + if index is not None: + if index < 0 or index >= len(entry_configs): + raise ValueError(f"Index {index} out of range (0-{len(entry_configs)-1})") + entry_configs.pop(index) + self._update_state(criteria_state) + return self + + # Remove by type or concept set ID + found_idx = None + for i, config in enumerate(entry_configs): + if criteria_type is not None and config.domain == criteria_type: + found_idx = i + break + if concept_set_id is not None and config.concept_set_id == concept_set_id: + found_idx = i + break + + if found_idx is None: + if criteria_type: + raise ValueError(f"No entry event found with type: {criteria_type}") + else: + raise ValueError(f"No entry event found with concept_set_id: {concept_set_id}") + + entry_configs.pop(found_idx) + self._update_state(criteria_state) + return self + + def remove_concept_set(self, concept_set_id: int) -> 'CohortBuilder': + """ + Remove a concept set by ID. + + Warning: This does not remove criteria that reference this concept set. + Consider using remove_all_references() to clean up orphaned references. + + Args: + concept_set_id: ID of the concept set to remove + + Raises: + KeyError: If no concept set with the given ID exists + + Returns: + Self for chaining + + Example: + >>> cohort.remove_concept_set(concept_set_id=3) + """ + found_idx = None + for i, cs in enumerate(self._concept_sets): + if cs.id == concept_set_id: + found_idx = i + break + + if found_idx is None: + raise KeyError(f"No concept set found with ID: {concept_set_id}") + + self._concept_sets.pop(found_idx) + return self + + def remove_all_references(self, concept_set_id: int) -> 'CohortBuilder': + """ + Remove a concept set and all criteria that reference it. 
+ + This removes: + - The concept set itself + - Entry events using this concept set (skipped if every entry event uses it, since at least one must remain) + - Inclusion/exclusion criteria using this concept set + - Censoring criteria using this concept set + + Args: + concept_set_id: ID of the concept set to remove + + Returns: + Self for chaining + + Example: + >>> # Remove diabetes concept set and all related criteria + >>> cohort.remove_all_references(concept_set_id=3) + """ + # Remove concept set + try: + self.remove_concept_set(concept_set_id) + except KeyError: + pass # Concept set doesn't exist, continue with cleanup + + # Remove from entry events (if possible) + if self._state: + criteria_state = self._state._to_criteria() if hasattr(self._state, '_to_criteria') else self._state + + # Remove entry events with this concept set (keep at least one) + entry_configs = criteria_state._entry_configs + filtered_entries = [cfg for cfg in entry_configs if cfg.concept_set_id != concept_set_id] + if filtered_entries: # Only update if we have remaining entries + criteria_state._entry_configs = filtered_entries + + # Remove censoring criteria with this concept set + criteria_state._settings.censor_queries = [ + q for q in criteria_state._settings.censor_queries + if q.concept_set_id != concept_set_id + ] + + # Remove inclusion/exclusion criteria with this concept set + for rule in criteria_state._rules: + self._remove_criteria_with_concept_set(rule["group"], concept_set_id) + + self._update_state(criteria_state) + + return self + + def _remove_criteria_with_concept_set(self, group, concept_set_id: int): + """Recursively remove criteria referencing a concept set from a group.""" + from circe.cohort_builder.query_builder import GroupConfig, CriteriaConfig + + if not hasattr(group, 'criteria'): + return + + filtered_criteria = [] + for criterion in group.criteria: + if isinstance(criterion, GroupConfig): + # Recursively clean nested groups + self._remove_criteria_with_concept_set(criterion, concept_set_id) + # Keep the group if it still has 
criteria + if criterion.criteria: + filtered_criteria.append(criterion) + elif isinstance(criterion, CriteriaConfig): + # Keep criteria that don't reference this concept set + if criterion.query_config.concept_set_id != concept_set_id: + filtered_criteria.append(criterion) + else: + # Keep other types as-is + filtered_criteria.append(criterion) + + group.criteria = filtered_criteria + + def clear_inclusion_rules(self) -> 'CohortBuilder': + """ + Remove all inclusion rules. + + Returns: + Self for chaining + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + criteria_state._rules = [{"name": "Inclusion Criteria", "group": GroupConfig(type="ALL")}] + self._update_state(criteria_state) + return self + + def clear_censoring_criteria(self) -> 'CohortBuilder': + """ + Remove all censoring criteria. + + Returns: + Self for chaining + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + criteria_state._settings.censor_queries = [] + self._update_state(criteria_state) + return self + + def clear_demographic_criteria(self) -> 'CohortBuilder': + """ + Clear all demographic restrictions (age, gender, race, ethnicity). + + Returns: + Self for chaining + """ + state = self._ensure_state() + criteria_state = state._to_criteria() if hasattr(state, '_to_criteria') else state + criteria_state._settings.gender_concepts = [] + criteria_state._settings.race_concepts = [] + criteria_state._settings.ethnicity_concepts = [] + criteria_state._settings.age_min = None + criteria_state._settings.age_max = None + self._update_state(criteria_state) + return self + + @staticmethod + def _reconstruct_state_from_expression(builder: 'CohortBuilder', + expression: CohortExpression) -> 'CohortWithCriteria': + """ + Reconstruct builder state from a CohortExpression. + + This reverse-engineers the internal state so modifications can be applied. 
+ + Args: + builder: The parent CohortBuilder instance + expression: The CohortExpression to reconstruct from + + Returns: + CohortWithCriteria state ready for modifications + """ + import copy + from circe.cohort_builder.query_builder import GroupConfig, CriteriaConfig + + # Extract entry events from primary criteria + entry_configs = [] + if expression.primary_criteria and expression.primary_criteria.criteria_list: + for criteria in expression.primary_criteria.criteria_list: + config = builder._criteria_to_query_config(criteria) + entry_configs.append(config) + + # Extract observation window + prior_obs = 0 + post_obs = 0 + if expression.primary_criteria and expression.primary_criteria.observation_window: + prior_obs = expression.primary_criteria.observation_window.prior_days or 0 + post_obs = expression.primary_criteria.observation_window.post_days or 0 + + # Extract limits + limit = "All" + qualified_limit = "First" + expression_limit = "All" + + if expression.primary_criteria and expression.primary_criteria.primary_limit: + limit = expression.primary_criteria.primary_limit.type or "All" + if expression.qualified_limit: + qualified_limit = expression.qualified_limit.type or "First" + if expression.expression_limit: + expression_limit = expression.expression_limit.type or "All" + + # Extract settings + settings = builder._extract_settings_from_expression(expression) + + # Create CohortWithCriteria state + state = CohortWithCriteria( + parent=builder, + entry_configs=entry_configs, + prior_observation=prior_obs, + post_observation=post_obs, + limit=limit, + qualified_limit=qualified_limit, + expression_limit=expression_limit, + settings=settings + ) + + # Reconstruct inclusion rules + if expression.inclusion_rules: + # Clear default rule + state._rules = [] + + for rule in expression.inclusion_rules: + # Skip demographic criteria rule (handled in settings) + if rule.name == "Demographic Criteria": + continue + + group = GroupConfig(type="ALL") + if 
rule.expression: + builder._reconstruct_criteria_group(rule.expression, group) + + state._rules.append({"name": rule.name, "group": group}) + + # If no rules were added, ensure we have the default rule + if not state._rules: + state._rules = [{"name": "Inclusion Criteria", "group": GroupConfig(type="ALL")}] + + return state + + @staticmethod + def _criteria_to_query_config(criteria): + """ + Convert a Criteria object to a QueryConfig for the builder. + + Maps CIRCE criteria objects back to builder query configurations. + """ + from circe.cohort_builder.query_builder import QueryConfig, TimeWindow + + # Determine domain from criteria type + domain = criteria.__class__.__name__ + + # Extract concept set ID + concept_set_id = getattr(criteria, 'codeset_id', None) + + # Create basic config + config = QueryConfig( + domain=domain, + concept_set_id=concept_set_id, + time_window=TimeWindow() + ) + + # Extract first occurrence flag + if hasattr(criteria, 'first') and criteria.first: + config.first_occurrence = True + + # Extract age constraints + if hasattr(criteria, 'age') and criteria.age: + if hasattr(criteria.age, 'value'): + config.age_min = criteria.age.value + if hasattr(criteria.age, 'extent'): + config.age_max = criteria.age.extent + + # Note: More complex criteria attributes (date ranges, correlated criteria, etc.) + # are not fully reconstructed. This is a simplified mapping for common cases. + + return config + + @staticmethod + def _extract_settings_from_expression(expression: CohortExpression) -> CohortSettings: + """ + Extract cohort settings from expression. 
+ + Extracts: + - Exit strategy (observation end, date offset, custom era) + - Era collapse settings + - Censoring criteria + - Demographic criteria + """ + from circe.cohortdefinition.core import DateOffsetStrategy, CustomEraStrategy + + settings = CohortSettings() + + # Extract exit strategy + if expression.end_strategy: + if isinstance(expression.end_strategy, DateOffsetStrategy): + settings.exit_strategy_type = "date_offset" + settings.exit_offset_field = expression.end_strategy.date_field or "startDate" + settings.exit_offset_days = expression.end_strategy.offset or 0 + elif isinstance(expression.end_strategy, CustomEraStrategy): + settings.exit_strategy_type = "custom_era" + settings.custom_era_drug_codeset_id = expression.end_strategy.drug_codeset_id + # Fall back to 30 only when gap_days is unset; an explicit gap of 0 is preserved + settings.custom_era_gap_days = expression.end_strategy.gap_days if expression.end_strategy.gap_days is not None else 30 + settings.custom_era_offset = expression.end_strategy.offset or 0 + settings.custom_era_days_supply_override = expression.end_strategy.days_supply_override + + # Extract era collapse + if expression.collapse_settings and expression.collapse_settings.era_pad: + settings.era_days = expression.collapse_settings.era_pad + + # Extract censoring criteria + if expression.censoring_criteria: + for criteria in expression.censoring_criteria: + config = CohortBuilder._criteria_to_query_config(criteria) + settings.censor_queries.append(config) + + # Extract demographic criteria from inclusion rules + if expression.inclusion_rules: + for rule in expression.inclusion_rules: + if rule.name == "Demographic Criteria" and rule.expression: + if rule.expression.demographic_criteria_list: + demo = rule.expression.demographic_criteria_list[0] + + if demo.gender: + settings.gender_concepts = [c.concept_id for c in demo.gender] + if demo.race: + settings.race_concepts = [c.concept_id for c in demo.race] + if demo.ethnicity: + settings.ethnicity_concepts = [c.concept_id for c in demo.ethnicity] + if demo.age: + settings.age_min = demo.age.value + settings.age_max = 
demo.age.extent + + return settings + + @staticmethod + def _reconstruct_criteria_group(circe_group, builder_group): + """ + Recursively reconstruct a CriteriaGroup from CIRCE format to builder format. + + This is a simplified reconstruction that handles common cases. + Complex nested groups and correlated criteria may not be fully supported. + """ + from circe.cohort_builder.query_builder import GroupConfig, CriteriaConfig + + # Set group type + if circe_group.type: + builder_group.type = circe_group.type + + # Reconstruct count for AT_LEAST groups + if hasattr(circe_group, 'count') and circe_group.count: + builder_group.count = circe_group.count + + # Reconstruct criteria list + if hasattr(circe_group, 'criteria_list') and circe_group.criteria_list: + for item in circe_group.criteria_list: + # Check if it's a windowed criteria (has criteria field) + if hasattr(item, 'criteria'): + # This is a windowed criteria + config = CohortBuilder._criteria_to_query_config(item.criteria) + + # Determine if it's an exclusion based on occurrence count + is_exclusion = False + if hasattr(item, 'occurrence') and item.occurrence: + # Count = 0 typically means exclusion + is_exclusion = item.occurrence.count == 0 + + builder_group.criteria.append(CriteriaConfig( + query_config=config, + is_exclusion=is_exclusion + )) + elif hasattr(item, 'type'): + # This is a nested group + nested_group = GroupConfig(type=item.type) + CohortBuilder._reconstruct_criteria_group(item, nested_group) + builder_group.criteria.append(nested_group) + + # Handle groups (nested criteria groups) + if hasattr(circe_group, 'groups') and circe_group.groups: + for nested in circe_group.groups: + nested_group = GroupConfig(type=nested.type if hasattr(nested, 'type') else "ALL") + CohortBuilder._reconstruct_criteria_group(nested, nested_group) + builder_group.criteria.append(nested_group) + + # ========================================================================= + # ENTRY EVENT METHODS + # 
========================================================================= + + def _create_entry_event( + self, + query_class: type, + concept_set_id: Optional[int] = None, + **kwargs + ) -> Union['CohortBuilder', 'CohortWithEntry']: + """ + Helper method to create entry events with consistent logic. + + Args: + query_class: The query class to instantiate (e.g., ConditionQuery, DrugQuery) + concept_set_id: Optional concept set ID (some queries like Death don't need this) + **kwargs: Additional parameters to pass to apply_params() + + Returns: + Self if in context manager mode, otherwise CohortWithEntry + """ + # Create query with or without concept_set_id + if concept_set_id is not None: + query = query_class(concept_set_id, is_entry=True) + else: + query = query_class(is_entry=True) + + # Apply any additional parameters + query.apply_params(**kwargs) + + # Create cohort state and link parent + cohort = CohortWithEntry(self, query) + query._parent = cohort + + # Handle context manager vs fluent API + if self._in_context: + self._state = cohort + return self + return cohort + + def with_condition(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a condition occurrence.""" + return self._create_entry_event(ConditionQuery, concept_set_id, **kwargs) + + def with_drug(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a drug exposure.""" + return self._create_entry_event(DrugQuery, concept_set_id, **kwargs) + + def with_drug_era(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a drug era.""" + return self._create_entry_event(DrugEraQuery, concept_set_id, **kwargs) + + def with_procedure(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a procedure occurrence.""" + return self._create_entry_event(ProcedureQuery, concept_set_id, **kwargs) + + def 
with_measurement(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a measurement.""" + return self._create_entry_event(MeasurementQuery, concept_set_id, **kwargs) + + def with_visit(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a visit occurrence.""" + return self._create_entry_event(VisitQuery, concept_set_id, **kwargs) + + def with_observation(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to an observation (with concept set).""" + return self._create_entry_event(ObservationQuery, concept_set_id, **kwargs) + + def with_condition_era(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a condition era.""" + return self._create_entry_event(ConditionEraQuery, concept_set_id, **kwargs) + + def with_device_exposure(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a device exposure.""" + return self._create_entry_event(DeviceExposureQuery, concept_set_id, **kwargs) + + def with_specimen(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a specimen.""" + return self._create_entry_event(SpecimenQuery, concept_set_id, **kwargs) + + def with_death(self) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to death.""" + return self._create_entry_event(DeathQuery) + + def with_observation_period(self) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to an observation period.""" + return self._create_entry_event(ObservationPeriodQuery) + + def with_payer_plan_period(self, concept_set_id: int, **kwargs) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a payer plan period.""" + return self._create_entry_event(PayerPlanPeriodQuery, concept_set_id, **kwargs) + + def 
with_location_region(self, concept_set_id: int) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a location/region.""" + return self._create_entry_event(LocationRegionQuery, concept_set_id) + + def with_visit_detail(self, concept_set_id: int) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a visit detail.""" + return self._create_entry_event(VisitDetailQuery, concept_set_id) + + def with_dose_era(self, concept_set_id: int) -> Union['CohortBuilder', 'CohortWithEntry']: + """Set entry event to a dose era.""" + return self._create_entry_event(DoseEraQuery, concept_set_id) + +class InclusionRuleContext: + """ + Context manager for named inclusion rules. + + Provides a clean way to group criteria under a named rule for attrition tracking. + """ + + def __init__(self, builder: CohortBuilder, name: str): + self._builder = builder + self._name = name + + def __enter__(self) -> 'InclusionRuleContext': + """Enter the inclusion rule context.""" + state = self._builder._ensure_state() + new_state = state.begin_rule(self._name) + self._builder._update_state(new_state) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + """Exit the inclusion rule context.""" + state = self._builder._ensure_state() + new_state = state.end_rule() + self._builder._update_state(new_state) + return False + + def require_condition(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Add a required condition occurrence to this rule.""" + result = self._builder._ensure_state().require_condition(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def require_drug(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Add a required drug exposure to this rule.""" + result = self._builder._ensure_state().require_drug(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def require_procedure(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': 
+ """Add a required procedure occurrence to this rule.""" + result = self._builder._ensure_state().require_procedure(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def require_measurement(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Add a required measurement to this rule.""" + result = self._builder._ensure_state().require_measurement(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def require_observation(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Add a required observation to this rule.""" + result = self._builder._ensure_state().require_observation(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def require_visit(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Add a required visit occurrence to this rule.""" + result = self._builder._ensure_state().require_visit(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def exclude_condition(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Exclude patients with a condition occurrence.""" + result = self._builder._ensure_state().exclude_condition(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def exclude_drug(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Exclude patients with a drug exposure.""" + result = self._builder._ensure_state().exclude_drug(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def exclude_procedure(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Exclude patients with a procedure occurrence.""" + result = self._builder._ensure_state().exclude_procedure(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + + def exclude_measurement(self, concept_set_id: int, **kwargs) -> 'InclusionRuleContext': + """Exclude patients with a measurement.""" + result 
= self._builder._ensure_state().exclude_measurement(concept_set_id, **kwargs) + self._builder._update_state(result) + return self + +class CohortWithEntry: + """ + Cohort state after entry event is defined. + + Available methods: + - first_occurrence(): Only use first occurrence per person + - with_observation(): Set observation window + - require_*(): Add inclusion criteria + - exclude_*(): Add exclusion criteria + - build(): Finalize and create CohortExpression + """ + + def __init__(self, parent: CohortBuilder, entry_query: BaseQuery): + self._parent = parent + self._entry_queries = [entry_query] + self._prior_observation_days = 0 + self._post_observation_days = 0 + self._limit = "All" + self._qualified_limit = "First" + self._expression_limit = "All" + self._settings = CohortSettings() + + def _add_query(self, config: QueryConfig, is_exclusion: bool = False) -> 'CohortWithCriteria': + """Delegate query addition to the criteria state.""" + return self._to_criteria()._add_query(config, is_exclusion) + + def _add_censor_query(self, config: QueryConfig) -> 'CohortWithCriteria': + """Delegate censor query addition to the criteria state.""" + return self._to_criteria()._add_censor_query(config) + + def _add_or_entry(self, query_class: type, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """ + Helper method to add alternative entry events (OR logic). 
+ + Args: + query_class: The query class to instantiate + concept_set_id: Concept set ID for the query + **kwargs: Additional parameters to pass to apply_params() + + Returns: + Self for chaining + """ + query = query_class(concept_set_id, is_entry=True, parent=self) + query.apply_params(**kwargs) + self._entry_queries.append(query) + return self + + def or_with_condition(self, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """Add an alternative condition occurrence as entry event (OR logic).""" + return self._add_or_entry(ConditionQuery, concept_set_id, **kwargs) + + def or_with_drug(self, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """Add an alternative drug exposure as entry event (OR logic).""" + return self._add_or_entry(DrugQuery, concept_set_id, **kwargs) + + def or_with_procedure(self, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """Add an alternative procedure occurrence as entry event (OR logic).""" + return self._add_or_entry(ProcedureQuery, concept_set_id, **kwargs) + + def or_with_measurement(self, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """Add an alternative measurement as entry event (OR logic).""" + return self._add_or_entry(MeasurementQuery, concept_set_id, **kwargs) + + def or_with_visit(self, concept_set_id: int, **kwargs) -> 'CohortWithEntry': + """Add an alternative visit occurrence as entry event (OR logic).""" + return self._add_or_entry(VisitQuery, concept_set_id, **kwargs) + + def with_qualified_limit(self, limit: str) -> 'CohortWithEntry': + """Set the qualified limit (First, Last, All).""" + self._qualified_limit = limit + return self + + def with_expression_limit(self, limit: str) -> 'CohortWithEntry': + """Set the expression limit (First, Last, All).""" + self._expression_limit = limit + return self + + # Entry query filters (delegate to the last added query) + + def with_all(self) -> 'CriteriaGroupBuilder': + """Start a correlated criteria group for the last added entry.""" + return 
self._entry_queries[-1].with_all() + + def with_any(self) -> 'CriteriaGroupBuilder': + """Start a correlated criteria group with ANY logic for the last added entry.""" + return self._entry_queries[-1].with_any() + + def first_occurrence(self) -> 'CohortWithEntry': + """Only use the first occurrence per person for entry events.""" + for q in self._entry_queries: + q._get_config().first_occurrence = True + self._limit = "First" + return self + + def all_occurrences(self) -> 'CohortWithEntry': + """Use all occurrences per person.""" + self._limit = "All" + for q in self._entry_queries: + q._get_config().first_occurrence = False + return self + + def with_observation( + self, + prior_days: int = 0, + post_days: int = 0 + ) -> 'CohortWithEntry': + """ + Set continuous observation requirements. + + Args: + prior_days: Days of observation required before index + post_days: Days of observation required after index + + Returns: + Self for chaining + """ + self._prior_observation_days = prior_days + self._post_observation_days = post_days + return self + + def min_age(self, age: int) -> 'CohortWithEntry': + """Require minimum age at entry for all entry events.""" + for q in self._entry_queries: + q._get_config().age_min = age + return self + + def max_age(self, age: int) -> 'CohortWithEntry': + """Require maximum age at entry for all entry events.""" + for q in self._entry_queries: + q._get_config().age_max = age + return self + + def require_gender(self, *concept_ids: int) -> 'CohortWithEntry': + """Require specific gender concept IDs.""" + self._settings.gender_concepts.extend(concept_ids) + return self + + def require_race(self, *concept_ids: int) -> 'CohortWithEntry': + """Require specific race concept IDs.""" + self._settings.race_concepts.extend(concept_ids) + return self + + def require_ethnicity(self, *concept_ids: int) -> 'CohortWithEntry': + """Require specific ethnicity concept IDs.""" + self._settings.ethnicity_concepts.extend(concept_ids) + return self + + def 
require_age(self, min_age: Optional[int] = None, max_age: Optional[int] = None) -> 'CohortWithEntry': + """Require specific age range.""" + self._settings.age_min = min_age + self._settings.age_max = max_age + return self + + def begin_rule(self, name: str) -> 'CohortWithCriteria': + """Start a new named inclusion rule.""" + return self._to_criteria().begin_rule(name) + + def end_rule(self) -> 'CohortWithCriteria': + """Finish the current inclusion rule.""" + return self._to_criteria().end_rule() + + # Transition to CohortWithCriteria + def require_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """require_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """require_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """censor_on_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 
'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """censor_on_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_procedure(self, concept_set_id: int, **kwargs) -> Union['ProcedureQuery', 'CohortWithCriteria']: + """censor_on_procedure (Supports both chaining and parameter-based API)""" + query = ProcedureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_measurement(self, concept_set_id: int, **kwargs) -> Union['MeasurementQuery', 'CohortWithCriteria']: + """censor_on_measurement (Supports both chaining and parameter-based API)""" + query = MeasurementQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_observation(self, concept_set_id: int) -> ObservationQuery: + """Censor if an observation occurs.""" + return self._to_criteria().censor_on_observation(concept_set_id) + + def censor_on_visit(self, concept_set_id: 
int) -> VisitQuery: + """Censor if a visit occurs.""" + return self._to_criteria().censor_on_visit(concept_set_id) + + def censor_on_death(self, concept_set_id: Optional[int] = None) -> DeathQuery: + """Censor on death.""" + return self._to_criteria().censor_on_death(concept_set_id) + + def require_measurement(self, concept_set_id: int, **kwargs) -> Union['MeasurementQuery', 'CohortWithCriteria']: + """require_measurement (Supports both chaining and parameter-based API)""" + query = MeasurementQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """exclude_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """exclude_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 
'CohortWithCriteria']: + """require_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_device_exposure(self, concept_set_id: int, **kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + """require_device_exposure (Supports both chaining and parameter-based API)""" + query = DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """require_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_visit_detail(self, concept_set_id: int, **kwargs) -> Union['VisitDetailQuery', 'CohortWithCriteria']: + """require_visit_detail (Supports both chaining and parameter-based API)""" + query = VisitDetailQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 
'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """require_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_payer_plan_period(self, concept_set_id: int, **kwargs) -> Union['PayerPlanPeriodQuery', 'CohortWithCriteria']: + """require_payer_plan_period (Supports both chaining and parameter-based API)""" + query = PayerPlanPeriodQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 'CohortWithCriteria']: + """exclude_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_device_exposure(self, concept_set_id: int, **kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + """exclude_device_exposure (Supports both chaining and parameter-based API)""" + query = 
DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """exclude_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_visit_detail(self, concept_set_id: int, **kwargs) -> Union['VisitDetailQuery', 'CohortWithCriteria']: + """exclude_visit_detail (Supports both chaining and parameter-based API)""" + query = VisitDetailQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """exclude_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_payer_plan_period(self, 
concept_set_id: int, **kwargs) -> Union['PayerPlanPeriodQuery', 'CohortWithCriteria']: + """exclude_payer_plan_period (Supports both chaining and parameter-based API)""" + query = PayerPlanPeriodQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + # Exit strategies + def exit_at_observation_end(self) -> 'CohortWithCriteria': + """Exit cohort at the end of the observation period.""" + return self._to_criteria().exit_at_observation_end() + + def exit_after_days(self, days: int, from_field: str = "startDate") -> 'CohortWithCriteria': + """Exit cohort N days after index start/end.""" + return self._to_criteria().exit_after_days(days, from_field) + + def exit_at_era_end(self, concept_set_id: int, gap_days: int = 30, offset: int = 0, supply_override: Optional[int] = None) -> 'CohortWithCriteria': + """Exit cohort at the end of a drug era.""" + return self._to_criteria().exit_at_era_end(concept_set_id, gap_days, offset, supply_override) + + def any_of(self) -> 'CriteriaGroupBuilder': + """Start an 'Any Of' group in the current rule.""" + return self._to_criteria().any_of() + + def all_of(self) -> 'CriteriaGroupBuilder': + """Start an 'All Of' group in the current rule.""" + return self._to_criteria().all_of() + + def at_least_of(self, count: int) -> 'CriteriaGroupBuilder': + """Start an 'At Least N Of' group in the current rule.""" + return self._to_criteria().at_least_of(count) + + # Collection method delegates + def require_any_of(self, **kwargs) -> 'CohortWithCriteria': + """Delegate to CohortWithCriteria. 
See CohortWithCriteria.require_any_of for documentation.""" + return self._to_criteria().require_any_of(**kwargs) + + def require_all_of(self, **kwargs) -> 'CohortWithCriteria': + """Delegate to CohortWithCriteria. See CohortWithCriteria.require_all_of for documentation.""" + return self._to_criteria().require_all_of(**kwargs) + + def require_at_least_of(self, count: int, **kwargs) -> 'CohortWithCriteria': + """Delegate to CohortWithCriteria. See CohortWithCriteria.require_at_least_of for documentation.""" + return self._to_criteria().require_at_least_of(count, **kwargs) + + def exclude_any_of(self, **kwargs) -> 'CohortWithCriteria': + """Delegate to CohortWithCriteria. See CohortWithCriteria.exclude_any_of for documentation.""" + return self._to_criteria().exclude_any_of(**kwargs) + + def _to_criteria(self) -> 'CohortWithCriteria': + """Transition to criteria state.""" + return CohortWithCriteria( + parent=self._parent, + entry_configs=[q._get_config() for q in self._entry_queries], + prior_observation=self._prior_observation_days, + post_observation=self._post_observation_days, + limit=self._limit, + qualified_limit=self._qualified_limit, + expression_limit=self._expression_limit, + settings=self._settings + ) + + def build(self) -> CohortExpression: + """Build the final CohortExpression.""" + return self._to_criteria().build() + +class CohortWithCriteria: + """ + Cohort state after criteria have been added. 
+ + Available methods: + - require_*(): Add more inclusion criteria + - exclude_*(): Add more exclusion criteria + - exit_at_*() / exit_after_days(): Set cohort exit strategy + - collapse_era(): Set era gap days + - censor_on_*(): Add censoring events + - build(): Finalize and create CohortExpression + """ + + def __init__( + self, + parent: CohortBuilder, + entry_configs: List[QueryConfig], + prior_observation: int = 0, + post_observation: int = 0, + limit: str = "All", + qualified_limit: str = "First", + expression_limit: str = "All", + settings: Optional[CohortSettings] = None + ): + self._parent = parent + self._entry_configs = entry_configs + self._prior_observation = prior_observation + self._post_observation = post_observation + self._limit = limit + self._qualified_limit = qualified_limit + self._expression_limit = expression_limit + self._rules = [{"name": "Inclusion Criteria", "group": GroupConfig(type="ALL")}] + self._settings = settings or CohortSettings() + + def begin_rule(self, name: str) -> 'CohortWithCriteria': + """ + Start a new named inclusion rule. + + Subsequent criteria will be added to this rule until build() or + another begin_rule() is called. This is useful for attrition tracking. + """ + self._rules.append({"name": name, "group": GroupConfig(type="ALL")}) + return self + + def end_rule(self) -> 'CohortWithCriteria': + """ + Finish the current inclusion rule. + + This method is a no-op provided to balance .begin_rule() and make blocks more explicit.
+ """ + return self + + def _add_query(self, config: QueryConfig, is_exclusion: bool = False) -> 'CohortWithCriteria': + """Add a configured query to the current rule's criteria list.""" + self._rules[-1]["group"].criteria.append(CriteriaConfig( + query_config=config, + is_exclusion=is_exclusion + )) + return self + + def _add_censor_query(self, config: QueryConfig) -> 'CohortWithCriteria': + """Add a configured query to the censoring criteria list.""" + self._settings.censor_queries.append(config) + return self + + def any_of(self) -> 'CriteriaGroupBuilder': + """Start an 'Any Of' group in the current rule.""" + group = GroupConfig(type="ANY") + self._rules[-1]["group"].criteria.append(group) + return CriteriaGroupBuilder(self, group) + + def all_of(self) -> 'CriteriaGroupBuilder': + """Start an 'All Of' group in the current rule.""" + group = GroupConfig(type="ALL") + self._rules[-1]["group"].criteria.append(group) + return CriteriaGroupBuilder(self, group) + + def at_least_of(self, count: int) -> 'CriteriaGroupBuilder': + """Start an 'At Least N Of' group in the current rule.""" + group = GroupConfig(type="AT_LEAST", count=count) + self._rules[-1]["group"].criteria.append(group) + return CriteriaGroupBuilder(self, group) + + # Collection methods for simplified group creation + def require_any_of( + self, + condition_ids: Optional[List[int]] = None, + drug_ids: Optional[List[int]] = None, + drug_era_ids: Optional[List[int]] = None, + procedure_ids: Optional[List[int]] = None, + measurement_ids: Optional[List[int]] = None, + observation_ids: Optional[List[int]] = None, + visit_ids: Optional[List[int]] = None + ) -> 'CohortWithCriteria': + """ + Require ANY of the specified criteria (OR logic). + + This is a shortcut for creating an ANY group with multiple criteria + without manually chaining .any_of()...end_group(). 
+ + Args: + condition_ids: List of condition concept set IDs + drug_ids: List of drug concept set IDs + drug_era_ids: List of drug era concept set IDs + procedure_ids: List of procedure concept set IDs + measurement_ids: List of measurement concept set IDs + observation_ids: List of observation concept set IDs + visit_ids: List of visit concept set IDs + + Returns: + Self for continued chaining + + Example: + >>> cohort.require_any_of(drug_ids=[1, 2, 3]) + # Patient must have at least one of Drug 1, 2, or 3 + """ + group = GroupConfig(type="ANY") + + if condition_ids: + for cid in condition_ids: + config = QueryConfig(domain="ConditionOccurrence", concept_set_id=cid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) # Default: all time + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_ids: + for did in drug_ids: + config = QueryConfig(domain="DrugExposure", concept_set_id=did) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_era_ids: + for deid in drug_era_ids: + config = QueryConfig(domain="DrugEra", concept_set_id=deid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if procedure_ids: + for pid in procedure_ids: + config = QueryConfig(domain="ProcedureOccurrence", concept_set_id=pid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if measurement_ids: + for mid in measurement_ids: + config = QueryConfig(domain="Measurement", concept_set_id=mid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if observation_ids: + for oid in observation_ids: + config = 
QueryConfig(domain="Observation", concept_set_id=oid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if visit_ids: + for vid in visit_ids: + config = QueryConfig(domain="VisitOccurrence", concept_set_id=vid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if group.criteria: # Only add if we have at least one criterion + self._rules[-1]["group"].criteria.append(group) + + return self + + def require_all_of( + self, + condition_ids: Optional[List[int]] = None, + drug_ids: Optional[List[int]] = None, + drug_era_ids: Optional[List[int]] = None, + procedure_ids: Optional[List[int]] = None, + measurement_ids: Optional[List[int]] = None, + observation_ids: Optional[List[int]] = None, + visit_ids: Optional[List[int]] = None + ) -> 'CohortWithCriteria': + """ + Require ALL of the specified criteria (AND logic). + + This is a shortcut for creating an ALL group with multiple criteria. 
+ + Args: + condition_ids: List of condition concept set IDs + drug_ids: List of drug concept set IDs + drug_era_ids: List of drug era concept set IDs + procedure_ids: List of procedure concept set IDs + measurement_ids: List of measurement concept set IDs + observation_ids: List of observation concept set IDs + visit_ids: List of visit concept set IDs + + Returns: + Self for continued chaining + + Example: + >>> cohort.require_all_of(drug_ids=[1, 2], procedure_ids=[10]) + # Patient must have Drug 1 AND Drug 2 AND Procedure 10 + """ + group = GroupConfig(type="ALL") + + if condition_ids: + for cid in condition_ids: + config = QueryConfig(domain="ConditionOccurrence", concept_set_id=cid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_ids: + for did in drug_ids: + config = QueryConfig(domain="DrugExposure", concept_set_id=did) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_era_ids: + for deid in drug_era_ids: + config = QueryConfig(domain="DrugEra", concept_set_id=deid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if procedure_ids: + for pid in procedure_ids: + config = QueryConfig(domain="ProcedureOccurrence", concept_set_id=pid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if measurement_ids: + for mid in measurement_ids: + config = QueryConfig(domain="Measurement", concept_set_id=mid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if observation_ids: + for oid in observation_ids: + config = 
QueryConfig(domain="Observation", concept_set_id=oid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if visit_ids: + for vid in visit_ids: + config = QueryConfig(domain="VisitOccurrence", concept_set_id=vid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if group.criteria: + self._rules[-1]["group"].criteria.append(group) + + return self + + def require_at_least_of( + self, + count: int, + condition_ids: Optional[List[int]] = None, + drug_ids: Optional[List[int]] = None, + drug_era_ids: Optional[List[int]] = None, + procedure_ids: Optional[List[int]] = None, + measurement_ids: Optional[List[int]] = None, + observation_ids: Optional[List[int]] = None, + visit_ids: Optional[List[int]] = None + ) -> 'CohortWithCriteria': + """ + Require at least N of the specified criteria. + + This is a shortcut for creating an AT_LEAST group. 
+ + Args: + count: Minimum number of criteria that must be met + condition_ids: List of condition concept set IDs + drug_ids: List of drug concept set IDs + drug_era_ids: List of drug era concept set IDs + procedure_ids: List of procedure concept set IDs + measurement_ids: List of measurement concept set IDs + observation_ids: List of observation concept set IDs + visit_ids: List of visit concept set IDs + + Returns: + Self for continued chaining + + Example: + >>> cohort.require_at_least_of(2, procedure_ids=[10, 11, 12]) + # Patient must have at least 2 of the 3 procedures + """ + group = GroupConfig(type="AT_LEAST", count=count) + + if condition_ids: + for cid in condition_ids: + config = QueryConfig(domain="ConditionOccurrence", concept_set_id=cid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_ids: + for did in drug_ids: + config = QueryConfig(domain="DrugExposure", concept_set_id=did) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if drug_era_ids: + for deid in drug_era_ids: + config = QueryConfig(domain="DrugEra", concept_set_id=deid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if procedure_ids: + for pid in procedure_ids: + config = QueryConfig(domain="ProcedureOccurrence", concept_set_id=pid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if measurement_ids: + for mid in measurement_ids: + config = QueryConfig(domain="Measurement", concept_set_id=mid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if 
observation_ids: + for oid in observation_ids: + config = QueryConfig(domain="Observation", concept_set_id=oid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if visit_ids: + for vid in visit_ids: + config = QueryConfig(domain="VisitOccurrence", concept_set_id=vid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=False)) + + if group.criteria: + self._rules[-1]["group"].criteria.append(group) + + return self + + def exclude_any_of( + self, + condition_ids: Optional[List[int]] = None, + drug_ids: Optional[List[int]] = None, + drug_era_ids: Optional[List[int]] = None, + procedure_ids: Optional[List[int]] = None, + measurement_ids: Optional[List[int]] = None, + observation_ids: Optional[List[int]] = None, + visit_ids: Optional[List[int]] = None + ) -> 'CohortWithCriteria': + """ + Exclude if ANY of the specified criteria are present. + + This creates exclusion criteria with OR logic. 
+ + Args: + condition_ids: List of condition concept set IDs to exclude + drug_ids: List of drug concept set IDs to exclude + drug_era_ids: List of drug era concept set IDs to exclude + procedure_ids: List of procedure concept set IDs to exclude + measurement_ids: List of measurement concept set IDs to exclude + observation_ids: List of observation concept set IDs to exclude + visit_ids: List of visit concept set IDs to exclude + + Returns: + Self for continued chaining + + Example: + >>> cohort.exclude_any_of(drug_ids=[3, 4]) + # Exclude patients who have Drug 3 OR Drug 4 + """ + group = GroupConfig(type="ANY") + + if condition_ids: + for cid in condition_ids: + config = QueryConfig(domain="ConditionOccurrence", concept_set_id=cid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if drug_ids: + for did in drug_ids: + config = QueryConfig(domain="DrugExposure", concept_set_id=did) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if drug_era_ids: + for deid in drug_era_ids: + config = QueryConfig(domain="DrugEra", concept_set_id=deid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if procedure_ids: + for pid in procedure_ids: + config = QueryConfig(domain="ProcedureOccurrence", concept_set_id=pid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if measurement_ids: + for mid in measurement_ids: + config = QueryConfig(domain="Measurement", concept_set_id=mid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if observation_ids: + for oid in 
observation_ids: + config = QueryConfig(domain="Observation", concept_set_id=oid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if visit_ids: + for vid in visit_ids: + config = QueryConfig(domain="VisitOccurrence", concept_set_id=vid) + config.time_window = TimeWindow(days_before=99999, days_after=99999) + group.criteria.append(CriteriaConfig(query_config=config, is_exclusion=True)) + + if group.criteria: + self._rules[-1]["group"].criteria.append(group) + + return self + + # Inclusion methods - return query builders with self as parent + def require_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """require_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """require_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_drug_era(self, concept_set_id: int, **kwargs) -> Union['DrugEraQuery', 'CohortWithCriteria']: + """require_drug_era (Supports both chaining and parameter-based API)""" + query = DrugEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + 
query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_measurement(self, concept_set_id: int, **kwargs) -> Union['MeasurementQuery', 'CohortWithCriteria']: + """require_measurement (Supports both chaining and parameter-based API)""" + query = MeasurementQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_procedure(self, concept_set_id: int, **kwargs) -> Union['ProcedureQuery', 'CohortWithCriteria']: + """require_procedure (Supports both chaining and parameter-based API)""" + query = ProcedureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_visit(self, concept_set_id: int, **kwargs) -> Union['VisitQuery', 'CohortWithCriteria']: + """require_visit (Supports both chaining and parameter-based API)""" + query = VisitQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_observation(self, concept_set_id: int, **kwargs) -> Union['ObservationQuery', 'CohortWithCriteria']: + """require_observation 
(Supports both chaining and parameter-based API)""" + query = ObservationQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_visit_detail(self, concept_set_id: int, **kwargs) -> Union['VisitDetailQuery', 'CohortWithCriteria']: + """require_visit_detail (Supports both chaining and parameter-based API)""" + query = VisitDetailQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_death(self, **kwargs) -> Union['DeathQuery', 'CohortWithCriteria']: + """require_death (Supports both chaining and parameter-based API)""" + query = DeathQuery(parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_device(self, concept_set_id: int, **kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + """require_device (Supports both chaining and parameter-based API)""" + query = DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def 
require_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """require_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 'CohortWithCriteria']: + """require_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """require_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after',
'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """exclude_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """exclude_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_drug_era(self, concept_set_id: int, **kwargs) -> Union['DrugEraQuery', 'CohortWithCriteria']: + """exclude_drug_era (Supports both chaining and parameter-based API)""" + query = DrugEraQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_measurement(self, concept_set_id: int, **kwargs) -> Union['MeasurementQuery', 'CohortWithCriteria']: + """exclude_measurement (Supports both chaining and parameter-based API)""" + query = MeasurementQuery(concept_set_id, parent=self, 
is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_procedure(self, concept_set_id: int, **kwargs) -> Union['ProcedureQuery', 'CohortWithCriteria']: + """exclude_procedure (Supports both chaining and parameter-based API)""" + query = ProcedureQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 'CohortWithCriteria']: + """require_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 'CohortWithCriteria']: + """exclude_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_device_exposure(self, concept_set_id: int, 
**kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + """require_device_exposure (Supports both chaining and parameter-based API)""" + query = DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_device_exposure(self, concept_set_id: int, **kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + """exclude_device_exposure (Supports both chaining and parameter-based API)""" + query = DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """require_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """exclude_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 
'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_observation_period(self, **kwargs) -> Union['ObservationPeriodQuery', 'CohortWithCriteria']: + """require_observation_period (Supports both chaining and parameter-based API)""" + query = ObservationPeriodQuery(parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_observation_period(self, **kwargs) -> Union['ObservationPeriodQuery', 'CohortWithCriteria']: + + """exclude_observation_period (Supports both chaining and parameter-based API)""" + + query = ObservationPeriodQuery(parent=self, is_exclusion=True, is_censor=False) + + if kwargs: + + query.apply_params(**kwargs) + + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + + return query._finalize() + + return query + + def require_payer_plan_period(self, concept_set_id: int, **kwargs) -> Union['PayerPlanPeriodQuery', 'CohortWithCriteria']: + """require_payer_plan_period (Supports both chaining and parameter-based API)""" + query = PayerPlanPeriodQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_payer_plan_period(self, concept_set_id: int, **kwargs) -> Union['PayerPlanPeriodQuery', 'CohortWithCriteria']: + """exclude_payer_plan_period (Supports both chaining and 
parameter-based API)""" + query = PayerPlanPeriodQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_visit_detail(self, concept_set_id: int, **kwargs) -> Union['VisitDetailQuery', 'CohortWithCriteria']: + """require_visit_detail (Supports both chaining and parameter-based API)""" + query = VisitDetailQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_death(self, **kwargs) -> Union['DeathQuery', 'CohortWithCriteria']: + """require_death (Supports both chaining and parameter-based API)""" + query = DeathQuery(parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_device(self, concept_set_id: int, **kwargs) -> Union['DeviceExposureQuery', 'CohortWithCriteria']: + + """require_device (Supports both chaining and parameter-based API)""" + + query = DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + + if kwargs: + + query.apply_params(**kwargs) + + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + + return query._finalize() + + return query + + def 
require_specimen(self, concept_set_id: int, **kwargs) -> Union['SpecimenQuery', 'CohortWithCriteria']: + """require_specimen (Supports both chaining and parameter-based API)""" + query = SpecimenQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_drug_era(self, concept_set_id: int, **kwargs) -> Union['DrugEraQuery', 'CohortWithCriteria']: + """require_drug_era (Supports both chaining and parameter-based API)""" + query = DrugEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_condition_era(self, concept_set_id: int, **kwargs) -> Union['ConditionEraQuery', 'CohortWithCriteria']: + """require_condition_era (Supports both chaining and parameter-based API)""" + query = ConditionEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """require_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 
'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_visit_detail(self, concept_set_id: int, **kwargs) -> Union['VisitDetailQuery', 'CohortWithCriteria']: + """exclude_visit_detail (Supports both chaining and parameter-based API)""" + query = VisitDetailQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def require_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """require_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exclude_dose_era(self, concept_set_id: int, **kwargs) -> Union['DoseEraQuery', 'CohortWithCriteria']: + """exclude_dose_era (Supports both chaining and parameter-based API)""" + query = DoseEraQuery(concept_set_id, parent=self, is_exclusion=True, is_censor=False) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def exit_at_observation_end(self) -> 'CohortWithCriteria': + """Exit cohort at the end of the observation period.""" + self._settings.exit_strategy_type = "observation" + return self + + def exit_after_days(self, days: int, 
from_field: str = "startDate") -> 'CohortWithCriteria': + """Exit cohort N days after index start/end.""" + self._settings.exit_strategy_type = "date_offset" + self._settings.exit_offset_days = days + self._settings.exit_offset_field = from_field + return self + + def exit_at_era_end(self, concept_set_id: int, gap_days: int = 30, offset: int = 0, supply_override: Optional[int] = None) -> 'CohortWithCriteria': + """Exit cohort at the end of a drug era.""" + self._settings.exit_strategy_type = "custom_era" + self._settings.custom_era_drug_codeset_id = concept_set_id + self._settings.custom_era_gap_days = gap_days + self._settings.custom_era_offset = offset + self._settings.custom_era_days_supply_override = supply_override + return self + + # Censoring methods + def censor_on_condition(self, concept_set_id: int, **kwargs) -> Union['ConditionQuery', 'CohortWithCriteria']: + """censor_on_condition (Supports both chaining and parameter-based API)""" + query = ConditionQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_drug(self, concept_set_id: int, **kwargs) -> Union['DrugQuery', 'CohortWithCriteria']: + """censor_on_drug (Supports both chaining and parameter-based API)""" + query = DrugQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_procedure(self, concept_set_id: int, **kwargs) -> Union['ProcedureQuery', 'CohortWithCriteria']: + """censor_on_procedure (Supports both chaining and 
parameter-based API)""" + query = ProcedureQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_measurement(self, concept_set_id: int, **kwargs) -> Union['MeasurementQuery', 'CohortWithCriteria']: + """censor_on_measurement (Supports both chaining and parameter-based API)""" + query = MeasurementQuery(concept_set_id, parent=self, is_exclusion=False, is_censor=True) + if kwargs: + query.apply_params(**kwargs) + if any(p in kwargs for p in ['anytime_before', 'anytime_after', 'within_days_before', 'within_days_after', 'within_days', 'same_day', 'during_event', 'before_event_end']): + return query._finalize() + return query + + def censor_on_observation(self, concept_set_id: int) -> ObservationQuery: + """Censor if an observation occurs.""" + return ObservationQuery(concept_set_id, parent=self, is_censor=True) + + def censor_on_visit(self, concept_set_id: int) -> VisitQuery: + """Censor if a visit occurs.""" + return VisitQuery(concept_set_id, parent=self, is_censor=True) + + def censor_on_death(self, concept_set_id: Optional[int] = None) -> DeathQuery: + """Censor on death.""" + return DeathQuery(concept_set_id, parent=self, is_censor=True) + + def censor_on_device_exposure(self, concept_set_id: int) -> DeviceExposureQuery: + """Censor if a device exposure occurs.""" + return DeviceExposureQuery(concept_set_id, parent=self, is_censor=True) + + def collapse_era(self, days: int) -> 'CohortWithCriteria': + """Set the number of gap days to collapse successive cohort entries.""" + self._settings.era_days = days + return self + + def require_gender(self, *concept_ids: int) -> 'CohortWithCriteria': + """Require specific gender concept IDs.""" + 
self._settings.gender_concepts.extend(concept_ids) + return self + + def require_race(self, *concept_ids: int) -> 'CohortWithCriteria': + """Require specific race concept IDs.""" + self._settings.race_concepts.extend(concept_ids) + return self + + def require_ethnicity(self, *concept_ids: int) -> 'CohortWithCriteria': + """Require specific ethnicity concept IDs.""" + self._settings.ethnicity_concepts.extend(concept_ids) + return self + + def require_age(self, min_age: Optional[int] = None, max_age: Optional[int] = None) -> 'CohortWithCriteria': + """Require specific age range.""" + self._settings.age_min = min_age + self._settings.age_max = max_age + return self + + def build(self) -> CohortExpression: + """ + Build the final CohortExpression. + + Returns: + CohortExpression ready for SQL generation + """ + return _build_cohort_expression( + title=self._parent._title, + concept_sets=self._parent._concept_sets, + entry_configs=self._entry_configs, + prior_observation=self._prior_observation, + post_observation=self._post_observation, + limit=self._limit, + qualified_limit=self._qualified_limit, + expression_limit=self._expression_limit, + rules=self._rules, + settings=self._settings + ) + +# ============================================================================= +# CONVERSION FUNCTIONS +# ============================================================================= + +def _build_cohort_expression( + title: str, + concept_sets: List[ConceptSet], + entry_configs: List[QueryConfig], + prior_observation: int, + post_observation: int, + limit: str, + qualified_limit: str, + expression_limit: str, + rules: List[Dict[str, Any]], + settings: CohortSettings +) -> CohortExpression: + """Build a CohortExpression from the builder state.""" + + # Build primary criteria + entry_criteria_list = [_config_to_criteria(cfg) for cfg in entry_configs] + primary_criteria = PrimaryCriteria( + criteria_list=entry_criteria_list, + observation_window=ObservationFilter( + 
prior_days=prior_observation, + post_days=post_observation + ), + primary_criteria_limit=ResultLimit(type=limit) + ) + + # Build inclusion rules from rules + inclusion_rules = [] + + # Build demographic rule FIRST if needed + if settings.gender_concepts or settings.race_concepts or settings.ethnicity_concepts or settings.age_min is not None or settings.age_max is not None: + demographic = DemographicCriteria() + if settings.gender_concepts: + demographic.gender = [Concept(concept_id=c, concept_name="Gender") for c in settings.gender_concepts] + if settings.race_concepts: + demographic.race = [Concept(concept_id=c, concept_name="Race") for c in settings.race_concepts] + if settings.ethnicity_concepts: + demographic.ethnicity = [Concept(concept_id=c, concept_name="Ethnicity") for c in settings.ethnicity_concepts] + + if settings.age_min is not None or settings.age_max is not None: + op = 'bt' if (settings.age_min is not None and settings.age_max is not None) else ('gte' if settings.age_min is not None else 'lte') + demographic.age = NumericRange(value=settings.age_min, extent=settings.age_max, op=op) + + inclusion_rules.append(InclusionRule( + name="Demographic Criteria", + expression=CirceCriteriaGroup( + type="ALL", + demographic_criteria_list=[demographic] + ) + )) + + # Then add named rules from builder + for rule_data in rules: + rule_name = rule_data["name"] + root_group = rule_data["group"] + + if not root_group.criteria: + continue + + # If the root group has exactly ONE child and that child is a group (not a criteria), + # unwrap it and use it directly as the expression + if (len(root_group.criteria) == 1 and + isinstance(root_group.criteria[0], GroupConfig)): + # Use the nested group directly + expression = _build_criteria_group(root_group.criteria[0]) + else: + # Use the root group as-is + expression = _build_criteria_group(root_group) + + inclusion_rules.append(InclusionRule( + name=rule_name, + expression=expression + )) + + # Build end strategy + 
end_strategy = None + if settings.exit_strategy_type == "date_offset": + from circe.cohortdefinition.core import DateOffsetStrategy + end_strategy = DateOffsetStrategy( + date_field=settings.exit_offset_field, + offset=settings.exit_offset_days + ) + elif settings.exit_strategy_type == "custom_era": + from circe.cohortdefinition.core import CustomEraStrategy + end_strategy = CustomEraStrategy( + drug_codeset_id=settings.custom_era_drug_codeset_id, + gap_days=settings.custom_era_gap_days, + offset=settings.custom_era_offset, + days_supply_override=settings.custom_era_days_supply_override + ) + + # Build collapse settings + from circe.cohortdefinition.core import CollapseSettings + collapse_settings = CollapseSettings(era_pad=settings.era_days) + + # Build censoring criteria + censoring_criteria = [] + for cq in settings.censor_queries: + censoring_criteria.append(_config_to_criteria(cq)) + + return CohortExpression( + title=title, + concept_sets=concept_sets, + primary_criteria=primary_criteria, + inclusion_rules=inclusion_rules, + end_strategy=end_strategy, + collapse_settings=collapse_settings, + censoring_criteria=censoring_criteria, + qualified_limit=ResultLimit(type=qualified_limit), + expression_limit=ResultLimit(type=expression_limit) + ) + +def _config_to_criteria(config: QueryConfig): + """Convert a QueryConfig to a domain criteria object.""" + domain_map = { + 'ConditionOccurrence': ConditionOccurrence, + 'ConditionEra': ConditionEra, + 'DrugExposure': DrugExposure, + 'DrugEra': DrugEra, + 'DoseEra': DoseEra, + 'ProcedureOccurrence': ProcedureOccurrence, + 'Measurement': Measurement, + 'Observation': Observation, + 'VisitOccurrence': VisitOccurrence, + 'VisitDetail': VisitDetail, + 'DeviceExposure': DeviceExposure, + 'Specimen': Specimen, + 'ObservationPeriod': ObservationPeriod, + 'PayerPlanPeriod': PayerPlanPeriod, + 'LocationRegion': LocationRegion, + 'Death': Death + } + + criteria_class = domain_map.get(config.domain) + if not criteria_class: + raise 
ValueError(f"Unknown domain: {config.domain}") + + kwargs = { + 'codeset_id': config.concept_set_id, + 'first': config.first_occurrence if config.first_occurrence else None + } + + if config.age_min is not None or config.age_max is not None: + op = 'bt' if (config.age_min is not None and config.age_max is not None) else ('gte' if config.age_min is not None else 'lte') + kwargs['age'] = NumericRange(value=config.age_min, extent=config.age_max, op=op) + + # Map domain-specific filters + if config.domain == 'Measurement': + if config.value_min is not None or config.value_max is not None: + op = 'bt' if (config.value_min is not None and config.value_max is not None) else ('gte' if config.value_min is not None else 'lte') + kwargs['value_as_number'] = NumericRange(value=config.value_min, extent=config.value_max, op=op) + # Phase 2: Measurement-specific modifiers + if config.measurement_operator_concepts: + kwargs['operator'] = [Concept(concept_id=c, concept_name="Operator") for c in config.measurement_operator_concepts] + if config.range_low_ratio_min is not None or config.range_low_ratio_max is not None: + # Compare against None explicitly so a bound of 0 still counts as set + op = 'bt' if (config.range_low_ratio_min is not None and config.range_low_ratio_max is not None) else ('gte' if config.range_low_ratio_min is not None else 'lte') + kwargs['range_low_ratio'] = NumericRange(value=config.range_low_ratio_min, extent=config.range_low_ratio_max, op=op) + if config.range_high_ratio_min is not None or config.range_high_ratio_max is not None: + # Compare against None explicitly so a bound of 0 still counts as set + op = 'bt' if (config.range_high_ratio_min is not None and config.range_high_ratio_max is not None) else ('gte' if config.range_high_ratio_min is not None else 'lte') + kwargs['range_high_ratio'] = NumericRange(value=config.range_high_ratio_min, extent=config.range_high_ratio_max, op=op) + + if config.domain == 'DrugExposure': + if config.days_supply_min is not None or config.days_supply_max is not None: + op = 'bt' if (config.days_supply_min is not None and config.days_supply_max is not None) else ('gte' if config.days_supply_min is not None else 'lte') +
kwargs['days_supply'] = NumericRange(value=config.days_supply_min, extent=config.days_supply_max, op=op) + if config.quantity_min is not None or config.quantity_max is not None: + op = 'bt' if (config.quantity_min is not None and config.quantity_max is not None) else ('gte' if config.quantity_min is not None else 'lte') + kwargs['quantity'] = NumericRange(value=config.quantity_min, extent=config.quantity_max, op=op) + # Phase 2: Drug-specific modifiers + if config.drug_route_concepts: + kwargs['route_concept'] = [Concept(concept_id=c, concept_name="Route") for c in config.drug_route_concepts] + if config.refills_min is not None or config.refills_max is not None: + # Compare against None explicitly so a bound of 0 (e.g. zero refills) still counts as set + op = 'bt' if (config.refills_min is not None and config.refills_max is not None) else ('gte' if config.refills_min is not None else 'lte') + kwargs['refills'] = NumericRange(value=config.refills_min, extent=config.refills_max, op=op) + if config.dose_min is not None or config.dose_max is not None: + # Compare against None explicitly so a bound of 0 still counts as set + op = 'bt' if (config.dose_min is not None and config.dose_max is not None) else ('gte' if config.dose_min is not None else 'lte') + kwargs['effective_drug_dose'] = NumericRange(value=config.dose_min, extent=config.dose_max, op=op) + + if config.domain == 'Measurement': + if config.unit_concepts: + kwargs['unit'] = [Concept(concept_id=c, concept_name="Unit") for c in config.unit_concepts] + if config.abnormal is not None: + kwargs['abnormal'] = config.abnormal + if config.range_low_min is not None or config.range_low_max is not None: + op = 'bt' if (config.range_low_min is not None and config.range_low_max is not None) else ('gte' if config.range_low_min is not None else 'lte') + kwargs['range_low'] = NumericRange(value=config.range_low_min, extent=config.range_low_max, op=op) + if config.range_high_min is not None or config.range_high_max is not None: + op = 'bt' if (config.range_high_min is not None and config.range_high_max is not None) else ('gte' if config.range_high_min is not None else 'lte') + kwargs['range_high'] = NumericRange(value=config.range_high_min,
extent=config.range_high_max, op=op) + if config.value_as_concept_concepts: + kwargs['value_as_concept'] = [Concept(concept_id=c, concept_name="Value") for c in config.value_as_concept_concepts] + + if config.domain in ['DrugEra', 'ConditionEra']: + if config.era_length_min is not None or config.era_length_max is not None: + op = 'bt' if (config.era_length_min is not None and config.era_length_max is not None) else ('gte' if config.era_length_min is not None else 'lte') + kwargs['era_length'] = NumericRange(value=config.era_length_min, extent=config.era_length_max, op=op) + if config.value_min is not None or config.value_max is not None: + op = 'bt' if (config.value_min is not None and config.value_max is not None) else ('gte' if config.value_min is not None else 'lte') + kwargs['occurrence_count'] = NumericRange(value=config.value_min, extent=config.value_max, op=op) + if config.domain == 'DrugEra' and (config.extent_min is not None or config.extent_max is not None): + op = 'bt' if (config.extent_min is not None and config.extent_max is not None) else ('gte' if config.extent_min is not None else 'lte') + kwargs['gap_days'] = NumericRange(value=config.extent_min, extent=config.extent_max, op=op) + + if config.domain == 'DoseEra': + if config.dose_min is not None or config.dose_max is not None: + op = 'bt' if (config.dose_min is not None and config.dose_max is not None) else ('gte' if config.dose_min is not None else 'lte') + kwargs['dose_value'] = NumericRange(value=config.dose_min, extent=config.dose_max, op=op) + if config.era_length_min is not None or config.era_length_max is not None: + op = 'bt' if (config.era_length_min is not None and config.era_length_max is not None) else ('gte' if config.era_length_min is not None else 'lte') + kwargs['era_length'] = NumericRange(value=config.era_length_min, extent=config.era_length_max, op=op) + + # Phase 2: Procedure-specific modifiers + if config.domain == 'ProcedureOccurrence': + if config.procedure_modifier_concepts: 
+ kwargs['modifier'] = [Concept(concept_id=c, concept_name="Modifier") for c in config.procedure_modifier_concepts] + if config.quantity_min is not None or config.quantity_max is not None: + # Compare against None explicitly so a bound of 0 still counts as set + op = 'bt' if (config.quantity_min is not None and config.quantity_max is not None) else ('gte' if config.quantity_min is not None else 'lte') + kwargs['quantity'] = NumericRange(value=config.quantity_min, extent=config.quantity_max, op=op) + + if config.domain in ['VisitOccurrence', 'VisitDetail', 'ObservationPeriod']: + if config.value_min is not None or config.value_max is not None: + op = 'bt' if (config.value_min is not None and config.value_max is not None) else ('gte' if config.value_min is not None else 'lte') + if config.domain == 'VisitOccurrence': + kwargs['visit_length'] = NumericRange(value=config.value_min, extent=config.value_max, op=op) + elif config.domain == 'VisitDetail': + kwargs['visit_detail_length'] = NumericRange(value=config.value_min, extent=config.value_max, op=op) + elif config.domain == 'ObservationPeriod': + kwargs['period_length'] = NumericRange(value=config.value_min, extent=config.value_max, op=op) + + # Phase 2: Visit-specific modifiers (outside value range check) + if config.domain == 'VisitOccurrence': + if config.place_of_service_concepts: + kwargs['place_of_service'] = [Concept(concept_id=c, concept_name="Place of Service") for c in config.place_of_service_concepts] + + # Phase 2: Observation-specific modifiers + if config.domain == 'Observation': + if config.qualifier_concepts: + kwargs['qualifier'] = [Concept(concept_id=c, concept_name="Qualifier") for c in config.qualifier_concepts] + if config.value_as_string: + kwargs['value_as_string'] = config.value_as_string + + # Map common filters + if config.gender_concepts: + kwargs['gender'] = [Concept(concept_id=c, concept_name="Gender") for c in config.gender_concepts] + + if config.visit_type_concepts: + kwargs['visit_type'] = [Concept(concept_id=c, concept_name="Visit Type") for c in config.visit_type_concepts] + + if
config.condition_type_concepts: + kwargs['condition_type'] = [Concept(concept_id=c, concept_name="Condition Type") for c in config.condition_type_concepts] + + if config.drug_type_concepts: + kwargs['drug_type'] = [Concept(concept_id=c, concept_name="Drug Type") for c in config.drug_type_concepts] + + if config.procedure_type_concepts: + kwargs['procedure_type'] = [Concept(concept_id=c, concept_name="Procedure Type") for c in config.procedure_type_concepts] + + if config.measurement_type_concepts: + kwargs['measurement_type'] = [Concept(concept_id=c, concept_name="Measurement Type") for c in config.measurement_type_concepts] + + if config.observation_type_concepts: + kwargs['observation_type'] = [Concept(concept_id=c, concept_name="Observation Type") for c in config.observation_type_concepts] + + if config.device_type_concepts: + kwargs['device_type'] = [Concept(concept_id=c, concept_name="Device Type") for c in config.device_type_concepts] + + if config.provider_specialty_concepts: + kwargs['provider_specialty'] = [Concept(concept_id=c, concept_name="Provider Specialty") for c in config.provider_specialty_concepts] + + if config.source_concept_set_id is not None: + source_field_map = { + 'ConditionOccurrence': 'condition_source_concept', + 'DrugExposure': 'drug_source_concept', + 'ProcedureOccurrence': 'procedure_source_concept', + 'Measurement': 'measurement_source_concept', + 'Observation': 'observation_source_concept', + 'DeviceExposure': 'device_source_concept', + 'Specimen': 'specimen_source_concept', + 'Death': 'death_source_concept', + 'VisitOccurrence': 'visit_source_concept', + 'VisitDetail': 'visit_detail_source_concept' + } + source_field = source_field_map.get(config.domain) + if source_field: + kwargs[source_field] = config.source_concept_set_id + + # Filter out None values + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + criteria_obj = criteria_class(**kwargs) + + # Add correlated criteria if present + if config.correlated_criteria: 
+ criteria_obj.correlated_criteria = _build_criteria_group(config.correlated_criteria) + + # Add date adjustment if present + if config.start_date_offset != 0 or config.end_date_offset != 0: + criteria_obj.date_adjustment = DateAdjustment( + start_offset=config.start_date_offset, + end_offset=config.end_date_offset, + start_with=config.start_date_with, + end_with=config.end_date_with + ) + + return criteria_obj + +def _build_criteria_group(group_cfg: GroupConfig) -> CirceCriteriaGroup: + """Recursively build a CirceCriteriaGroup from GroupConfig.""" + criteria_list = [] + groups = [] + + for item in group_cfg.criteria: + if isinstance(item, CriteriaConfig): + cc = _build_correlated_criteria(item) + criteria_list.append(cc) + elif isinstance(item, GroupConfig): + groups.append(_build_criteria_group(item)) + + return CirceCriteriaGroup( + type=group_cfg.type, + count=group_cfg.count, + criteria_list=criteria_list, + groups=groups + ) + +def _build_correlated_criteria(criteria_cfg: CriteriaConfig) -> CorelatedCriteria: + """Convert a CriteriaConfig to a CorelatedCriteria.""" + config = criteria_cfg.query_config + query_criteria = _config_to_criteria(config) + + # Build occurrence + if criteria_cfg.is_exclusion: + occurrence = Occurrence(type=0, count=0, is_distinct=False) # exactly 0 + else: + type_map = {"exactly": 0, "atMost": 1, "atLeast": 2} + occ_type = type_map.get(config.occurrence_type, 2) + occurrence = Occurrence( + type=occ_type, + count=config.occurrence_count, + is_distinct=config.is_distinct + ) + + # Build window + start_window = None + if config.time_window: + tw = config.time_window + start_window = Window( + use_index_end=tw.use_index_end, + use_event_end=tw.use_event_end, + start=WindowBound(coeff=-1, days=tw.days_before), + end=WindowBound(coeff=1, days=tw.days_after) + ) + + return CorelatedCriteria( + criteria=query_criteria, + start_window=start_window, + occurrence=occurrence, + restrict_visit=config.restrict_visit, + 
ignore_observation_period=config.ignore_observation_period + ) diff --git a/circe/cohort_builder/query_builder.py b/circe/cohort_builder/query_builder.py new file mode 100644 index 0000000..53cfb6b --- /dev/null +++ b/circe/cohort_builder/query_builder.py @@ -0,0 +1,716 @@ +""" +Query builder classes for the fluent API. + +These classes provide configuration for domain-specific queries +with time windows, occurrence counts, and filters. +""" + +from typing import Optional, List, Union, Any, TYPE_CHECKING +from dataclasses import dataclass, field + +if TYPE_CHECKING: + from circe.cohort_builder.builder import CohortWithCriteria + + +@dataclass +class TimeWindow: + """Represents a time window relative to the index event.""" + days_before: int = 0 + days_after: int = 0 + use_index_end: bool = False + use_event_end: bool = False + + +@dataclass +class QueryConfig: + """Configuration for a domain query.""" + domain: str + concept_set_id: Optional[int] = None + first_occurrence: bool = False + age_min: Optional[int] = None + age_max: Optional[int] = None + gender_concepts: List[int] = field(default_factory=list) + occurrence_count: int = 1 + occurrence_type: str = "atLeast" # atLeast, atMost, exactly + time_window: Optional[TimeWindow] = None + is_distinct: bool = False + restrict_visit: bool = False + ignore_observation_period: bool = False + + # Domain specific filters + value_min: Optional[float] = None + value_max: Optional[float] = None + extent_min: Optional[float] = None + extent_max: Optional[float] = None + op: Optional[str] = None + unit_concepts: List[int] = field(default_factory=list) + abnormal: Optional[bool] = None + range_low_min: Optional[float] = None + range_low_max: Optional[float] = None + range_high_min: Optional[float] = None + range_high_max: Optional[float] = None + + # Date Adjustment + start_date_offset: int = 0 + end_date_offset: int = 0 + start_date_with: str = "START_DATE" + end_date_with: str = "END_DATE" + + status_concepts: List[int] = 
field(default_factory=list) + visit_type_concepts: List[int] = field(default_factory=list) + condition_type_concepts: List[int] = field(default_factory=list) + drug_type_concepts: List[int] = field(default_factory=list) + procedure_type_concepts: List[int] = field(default_factory=list) + measurement_type_concepts: List[int] = field(default_factory=list) + observation_type_concepts: List[int] = field(default_factory=list) + device_type_concepts: List[int] = field(default_factory=list) + provider_specialty_concepts: List[int] = field(default_factory=list) + source_concept_set_id: Optional[int] = None + days_supply_min: Optional[int] = None + days_supply_max: Optional[int] = None + quantity_min: Optional[float] = None + quantity_max: Optional[float] = None + era_length_min: Optional[int] = None + era_length_max: Optional[int] = None + dose_min: Optional[float] = None + dose_max: Optional[float] = None + correlated_criteria: Optional['GroupConfig'] = None + + # Measurement specific + value_as_concept_concepts: List[int] = field(default_factory=list) + measurement_operator_concepts: List[int] = field(default_factory=list) + range_low_ratio_min: Optional[float] = None + range_low_ratio_max: Optional[float] = None + range_high_ratio_min: Optional[float] = None + range_high_ratio_max: Optional[float] = None + + # Procedure specific + procedure_modifier_concepts: List[int] = field(default_factory=list) + + # Drug specific + drug_route_concepts: List[int] = field(default_factory=list) + refills_min: Optional[int] = None + refills_max: Optional[int] = None + + # Visit specific + admitted_from_concepts: List[int] = field(default_factory=list) + discharged_to_concepts: List[int] = field(default_factory=list) + place_of_service_concepts: List[int] = field(default_factory=list) + + # Observation specific + qualifier_concepts: List[int] = field(default_factory=list) + value_as_string: Optional[str] = None + + +@dataclass +class CriteriaConfig: + """Stores a configured criteria for 
the cohort.""" + query_config: QueryConfig + is_exclusion: bool = False + rule_name: Optional[str] = None + + +@dataclass +class GroupConfig: + """Configuration for a group of criteria.""" + type: str = "ALL" # ALL, ANY, AT_LEAST, AT_MOST + count: Optional[int] = None + criteria: List[Union[CriteriaConfig, 'GroupConfig']] = field(default_factory=list) + + +class BaseQuery: + """ + Base class for domain-specific query builders. + """ + + def __init__( + self, + domain: str, + concept_set_id: Optional[int] = None, + parent: Optional[Union['CohortWithCriteria', 'CriteriaGroupBuilder']] = None, + is_entry: bool = False, + is_exclusion: bool = False, + is_censor: bool = False + ): + self._config = QueryConfig( + domain=domain, + concept_set_id=concept_set_id + ) + self._parent = parent + self._is_entry = is_entry + self._is_exclusion = is_exclusion + self._is_censor = is_censor + + def apply_params(self, **kwargs) -> 'BaseQuery': + """Batch apply parameters to the query configuration.""" + # Occurrence counting + if 'at_least' in kwargs: + self._config.occurrence_count = kwargs['at_least'] + self._config.occurrence_type = "atLeast" + elif 'at_most' in kwargs: + self._config.occurrence_count = kwargs['at_most'] + self._config.occurrence_type = "atMost" + elif 'exactly' in kwargs: + self._config.occurrence_count = kwargs['exactly'] + self._config.occurrence_type = "exactly" + + if 'distinct' in kwargs: + self._config.is_distinct = kwargs['distinct'] + + # Age + if 'age_min' in kwargs: + self._config.age_min = kwargs['age_min'] + if 'age_max' in kwargs: + self._config.age_max = kwargs['age_max'] + + # Time Windows + if 'anytime_before' in kwargs and kwargs['anytime_before']: + self._config.time_window = TimeWindow(days_before=99999, days_after=0) + elif 'anytime_after' in kwargs and kwargs['anytime_after']: + self._config.time_window = TimeWindow(days_before=0, days_after=99999) + elif 'within_days_before' in kwargs: + self._config.time_window = 
TimeWindow(days_before=kwargs['within_days_before'], days_after=0) + elif 'within_days_after' in kwargs: + self._config.time_window = TimeWindow(days_before=0, days_after=kwargs['within_days_after']) + elif 'within_days' in kwargs: + if isinstance(kwargs['within_days'], tuple): + before, after = kwargs['within_days'] + self._config.time_window = TimeWindow(days_before=before, days_after=after) + elif 'same_day' in kwargs and kwargs['same_day']: + self._config.time_window = TimeWindow(days_before=0, days_after=0) + elif 'during_event' in kwargs and kwargs['during_event']: + self._config.time_window = TimeWindow(days_before=0, days_after=0, use_index_end=True) + elif 'before_event_end' in kwargs: + self._config.time_window = TimeWindow(days_before=kwargs['before_event_end'], days_after=0, use_index_end=True) + + if 'relative_to_index_end' in kwargs and kwargs['relative_to_index_end']: + if self._config.time_window: + self._config.time_window.use_index_end = True + else: + self._config.time_window = TimeWindow(use_index_end=True) + + if 'relative_to_event_end' in kwargs and kwargs['relative_to_event_end']: + if self._config.time_window: + self._config.time_window.use_event_end = True + else: + self._config.time_window = TimeWindow(use_event_end=True) + + # Common Modifiers + if 'restrict_visit' in kwargs: + self._config.restrict_visit = kwargs['restrict_visit'] + if 'ignore_observation_period' in kwargs: + self._config.ignore_observation_period = kwargs['ignore_observation_period'] + if 'first_occurrence' in kwargs: + self._config.first_occurrence = kwargs['first_occurrence'] + + if 'gender' in kwargs: + self._config.gender_concepts = kwargs['gender'] if isinstance(kwargs['gender'], list) else [kwargs['gender']] + + return self + + # --- Fluent Chaining Methods (Time Windows & Occurrence) --- + + def at_least(self, count: int) -> 'BaseQuery': + """Require at least N occurrences.""" + self._config.occurrence_count = count + self._config.occurrence_type = "atLeast" +
return self + + def at_most(self, count: int) -> 'BaseQuery': + """Require at most N occurrences.""" + self._config.occurrence_count = count + self._config.occurrence_type = "atMost" + return self + + def exactly(self, count: int) -> 'BaseQuery': + """Require exactly N occurrences.""" + self._config.occurrence_count = count + self._config.occurrence_type = "exactly" + return self + + def anytime_before(self) -> Any: + """Events occurring any time before the index.""" + self._config.time_window = TimeWindow(days_before=99999, days_after=0) + return self._finalize() + + def anytime_after(self) -> Any: + """Events occurring any time after the index.""" + self._config.time_window = TimeWindow(days_before=0, days_after=99999) + return self._finalize() + + def within_days_before(self, days: int) -> Any: + """Events occurring within N days before the index.""" + self._config.time_window = TimeWindow(days_before=days, days_after=0) + return self._finalize() + + def within_days_after(self, days: int) -> Any: + """Events occurring within N days after the index.""" + self._config.time_window = TimeWindow(days_before=0, days_after=days) + return self._finalize() + + def within_days(self, before: int = 0, after: int = 0) -> Any: + """Events occurring within a window of [before, after] days.""" + self._config.time_window = TimeWindow(days_before=before, days_after=after) + return self._finalize() + + def same_day(self) -> Any: + """Events occurring on the same day as the index.""" + self._config.time_window = TimeWindow(days_before=0, days_after=0) + return self._finalize() + + def during_event(self) -> Any: + """Events occurring within the duration of the index event.""" + self._config.time_window = TimeWindow(days_before=0, days_after=0, use_index_end=True) + return self._finalize() + + def before_event_end(self, days: int = 0) -> Any: + """Events occurring before the end of the index event.""" + self._config.time_window = TimeWindow(days_before=days, days_after=0, 
use_index_end=True) + return self._finalize() + + def restrict_to_visit(self) -> 'BaseQuery': + """Restrict criteria to the same visit as the index event.""" + self._config.restrict_visit = True + return self + + def between_visits(self) -> 'BaseQuery': + """Shortcut for restrict_to_visit().""" + return self.restrict_to_visit() + + def ignore_observation_period(self) -> 'BaseQuery': + """Ignore observation period requirements for this criteria.""" + self._config.ignore_observation_period = True + return self + + def with_distinct(self) -> 'BaseQuery': + """Count distinct occurrences (e.g., distinct days or distinct status).""" + self._config.is_distinct = True + return self + + def _finalize(self) -> Any: + """Add this query to the parent and return for chaining.""" + if self._parent is None: + return self + + if hasattr(self._parent, "_add_censor_query") and self._is_censor: + return self._parent._add_censor_query(self._config) + + return self._parent._add_query(self._config, self._is_exclusion) + + def _get_config(self) -> "QueryConfig": + """Get the query configuration.""" + return self._config + + def build(self) -> Any: + """Finalize the query and build the cohort.""" + return self._finalize().build() + + +class ConditionQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("ConditionOccurrence", concept_set_id, **kwargs) + + def with_status(self, *concept_ids: int) -> 'ConditionQuery': + """Require specific condition status concept IDs.""" + self._config.status_concepts.extend(concept_ids) + return self + + def with_condition_type(self, *concept_ids: int) -> 'ConditionQuery': + """Require specific condition type concept IDs.""" + self._config.condition_type_concepts.extend(concept_ids) + return self + + def apply_params(self, **kwargs) -> 'ConditionQuery': + super().apply_params(**kwargs) + if 'status' in kwargs: + self._config.status_concepts = kwargs['status'] if isinstance(kwargs['status'], list) else 
[kwargs['status']] + if 'type' in kwargs: + self._config.condition_type_concepts = kwargs['type'] if isinstance(kwargs['type'], list) else [kwargs['type']] + return self + + +class DrugQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("DrugExposure", concept_set_id, **kwargs) + + def with_days_supply(self, min_days: int, max_days: Optional[int] = None) -> 'DrugQuery': + """Require specific days supply range.""" + self._config.days_supply_min = min_days + self._config.days_supply_max = max_days + return self + + def with_quantity(self, min_qty: float, max_qty: Optional[float] = None) -> 'DrugQuery': + """Require specific quantity range.""" + self._config.quantity_min = min_qty + self._config.quantity_max = max_qty + return self + + def with_drug_type(self, *concept_ids: int) -> 'DrugQuery': + """Require specific drug type concept IDs.""" + self._config.drug_type_concepts.extend(concept_ids) + return self + + def with_route(self, *concept_ids: int) -> 'DrugQuery': + """Require specific drug route concept IDs.""" + self._config.drug_route_concepts.extend(concept_ids) + return self + + def with_refills(self, min_refills: int, max_refills: Optional[int] = None) -> 'DrugQuery': + """Require specific refills range.""" + self._config.refills_min = min_refills + self._config.refills_max = max_refills + return self + + def with_dose(self, min_dose: float, max_dose: Optional[float] = None) -> 'DrugQuery': + """Require specific effective drug dose range.""" + self._config.dose_min = min_dose + self._config.dose_max = max_dose + return self + + def apply_params(self, **kwargs) -> 'DrugQuery': + super().apply_params(**kwargs) + if 'days_supply_min' in kwargs: + self._config.days_supply_min = kwargs['days_supply_min'] + if 'days_supply_max' in kwargs: + self._config.days_supply_max = kwargs['days_supply_max'] + if 'type' in kwargs: + self._config.drug_type_concepts = kwargs['type'] if isinstance(kwargs['type'], list) else 
[kwargs['type']] + if 'quantity_min' in kwargs: + self._config.quantity_min = kwargs['quantity_min'] + if 'quantity_max' in kwargs: + self._config.quantity_max = kwargs['quantity_max'] + if 'route' in kwargs: + self._config.drug_route_concepts = kwargs['route'] if isinstance(kwargs['route'], list) else [kwargs['route']] + if 'refills_min' in kwargs: + self._config.refills_min = kwargs['refills_min'] + if 'refills_max' in kwargs: + self._config.refills_max = kwargs['refills_max'] + if 'dose_min' in kwargs: + self._config.dose_min = kwargs['dose_min'] + if 'dose_max' in kwargs: + self._config.dose_max = kwargs['dose_max'] + return self + + +class DrugEraQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("DrugEra", concept_set_id, **kwargs) + + def apply_params(self, **kwargs) -> 'DrugEraQuery': + super().apply_params(**kwargs) + if 'era_length_min' in kwargs: + self._config.era_length_min = kwargs['era_length_min'] + if 'era_length_max' in kwargs: + self._config.era_length_max = kwargs['era_length_max'] + if 'gap_days_min' in kwargs: + self._config.extent_min = kwargs['gap_days_min'] + if 'gap_days_max' in kwargs: + self._config.extent_max = kwargs['gap_days_max'] + if 'occurrence_count_min' in kwargs: + self._config.occurrence_count = kwargs['occurrence_count_min'] + self._config.value_min = kwargs['occurrence_count_min'] + if 'occurrence_count_max' in kwargs: + self._config.value_max = kwargs['occurrence_count_max'] + return self + + +class MeasurementQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("Measurement", concept_set_id, **kwargs) + + def with_value(self, min_val: float, max_val: Optional[float] = None) -> 'MeasurementQuery': + """Require specific value range.""" + self._config.value_min = min_val + self._config.value_max = max_val + return self + + def with_unit(self, *concept_ids: int) -> 'MeasurementQuery': + """Require specific unit concept 
IDs.""" + self._config.unit_concepts.extend(concept_ids) + return self + + def is_abnormal(self, value: bool = True) -> 'MeasurementQuery': + """Restrict to abnormal values.""" + self._config.abnormal = value + return self + + def with_range_low_ratio(self, min_ratio: float, max_ratio: Optional[float] = None) -> 'MeasurementQuery': + """Require specific range low ratio.""" + self._config.range_low_ratio_min = min_ratio + self._config.range_low_ratio_max = max_ratio + return self + + def with_range_high_ratio(self, min_ratio: float, max_ratio: Optional[float] = None) -> 'MeasurementQuery': + """Require specific range high ratio.""" + self._config.range_high_ratio_min = min_ratio + self._config.range_high_ratio_max = max_ratio + return self + + def with_operator(self, *concept_ids: int) -> 'MeasurementQuery': + """Require specific operator concept IDs.""" + self._config.measurement_operator_concepts.extend(concept_ids) + return self + + def with_value_as_concept(self, *concept_ids: int) -> 'MeasurementQuery': + """Require specific value as concept IDs.""" + self._config.value_as_concept_concepts.extend(concept_ids) + return self + + def apply_params(self, **kwargs) -> 'MeasurementQuery': + super().apply_params(**kwargs) + if 'value_min' in kwargs: + self._config.value_min = kwargs['value_min'] + if 'value_max' in kwargs: + self._config.value_max = kwargs['value_max'] + if 'unit' in kwargs: + self._config.unit_concepts = kwargs['unit'] if isinstance(kwargs['unit'], list) else [kwargs['unit']] + if 'abnormal' in kwargs: + self._config.abnormal = kwargs['abnormal'] + if 'range_low_min' in kwargs: + self._config.range_low_min = kwargs['range_low_min'] + if 'range_low_max' in kwargs: + self._config.range_low_max = kwargs['range_low_max'] + if 'range_high_min' in kwargs: + self._config.range_high_min = kwargs['range_high_min'] + if 'range_high_max' in kwargs: + self._config.range_high_max = kwargs['range_high_max'] + if 'value_as_concept' in kwargs: + 
self._config.value_as_concept_concepts = kwargs['value_as_concept'] if isinstance(kwargs['value_as_concept'], list) else [kwargs['value_as_concept']] + if 'operator' in kwargs: + self._config.measurement_operator_concepts = kwargs['operator'] if isinstance(kwargs['operator'], list) else [kwargs['operator']] + return self + + +class ProcedureQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("ProcedureOccurrence", concept_set_id, **kwargs) + + def with_procedure_type(self, *concept_ids: int) -> 'ProcedureQuery': + """Require specific procedure type concept IDs.""" + self._config.procedure_type_concepts.extend(concept_ids) + return self + + def with_modifier(self, *concept_ids: int) -> 'ProcedureQuery': + """Require specific procedure modifier concept IDs.""" + self._config.procedure_modifier_concepts.extend(concept_ids) + return self + + def with_quantity(self, min_qty: float, max_qty: Optional[float] = None) -> 'ProcedureQuery': + """Require specific quantity range.""" + self._config.quantity_min = min_qty + self._config.quantity_max = max_qty + return self + + def apply_params(self, **kwargs) -> 'ProcedureQuery': + super().apply_params(**kwargs) + if 'quantity_min' in kwargs: + self._config.quantity_min = kwargs['quantity_min'] + if 'quantity_max' in kwargs: + self._config.quantity_max = kwargs['quantity_max'] + if 'modifier' in kwargs: + self._config.procedure_modifier_concepts = kwargs['modifier'] if isinstance(kwargs['modifier'], list) else [kwargs['modifier']] + if 'type' in kwargs: + self._config.procedure_type_concepts = kwargs['type'] if isinstance(kwargs['type'], list) else [kwargs['type']] + return self + + +class VisitQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("VisitOccurrence", concept_set_id, **kwargs) + + def with_visit_type(self, *concept_ids: int) -> 'VisitQuery': + """Require specific visit type concept IDs.""" + 
self._config.visit_type_concepts.extend(concept_ids) + return self + + def with_place_of_service(self, *concept_ids: int) -> 'VisitQuery': + """Require specific place of service concept IDs.""" + self._config.place_of_service_concepts.extend(concept_ids) + return self + + def with_length(self, min_days: int, max_days: Optional[int] = None) -> 'VisitQuery': + """Require specific visit length range.""" + self._config.value_min = min_days + self._config.value_max = max_days + return self + + def apply_params(self, **kwargs) -> 'VisitQuery': + super().apply_params(**kwargs) + if 'length_min' in kwargs: + self._config.value_min = kwargs['length_min'] + if 'length_max' in kwargs: + self._config.value_max = kwargs['length_max'] + if 'place_of_service' in kwargs: + self._config.place_of_service_concepts = kwargs['place_of_service'] if isinstance(kwargs['place_of_service'], list) else [kwargs['place_of_service']] + return self + + +class ObservationQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("Observation", concept_set_id, **kwargs) + + def with_observation_type(self, *concept_ids: int) -> 'ObservationQuery': + """Require specific observation type concept IDs.""" + self._config.observation_type_concepts.extend(concept_ids) + return self + + def with_qualifier(self, *concept_ids: int) -> 'ObservationQuery': + """Require specific qualifier concept IDs.""" + self._config.qualifier_concepts.extend(concept_ids) + return self + + def with_value_as_string(self, value: str) -> 'ObservationQuery': + """Require specific value as string.""" + self._config.value_as_string = value + return self + + def apply_params(self, **kwargs) -> 'ObservationQuery': + super().apply_params(**kwargs) + if 'qualifier' in kwargs: + self._config.qualifier_concepts = kwargs['qualifier'] if isinstance(kwargs['qualifier'], list) else [kwargs['qualifier']] + if 'value_as_string' in kwargs: + self._config.value_as_string = kwargs['value_as_string'] 
+ return self + + +class DeathQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("Death", concept_set_id, **kwargs) + + +class ConditionEraQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("ConditionEra", concept_set_id, **kwargs) + + def apply_params(self, **kwargs) -> 'ConditionEraQuery': + super().apply_params(**kwargs) + if 'era_length_min' in kwargs: + self._config.era_length_min = kwargs['era_length_min'] + if 'era_length_max' in kwargs: + self._config.era_length_max = kwargs['era_length_max'] + if 'occurrence_count_min' in kwargs: + self._config.value_min = kwargs['occurrence_count_min'] + if 'occurrence_count_max' in kwargs: + self._config.value_max = kwargs['occurrence_count_max'] + return self + + +class DeviceExposureQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("DeviceExposure", concept_set_id, **kwargs) + + +class SpecimenQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("Specimen", concept_set_id, **kwargs) + + +class ObservationPeriodQuery(BaseQuery): + def __init__(self, **kwargs): + super().__init__("ObservationPeriod", concept_set_id=None, **kwargs) + + def apply_params(self, **kwargs) -> 'ObservationPeriodQuery': + super().apply_params(**kwargs) + if 'length_min' in kwargs: + self._config.value_min = kwargs['length_min'] + if 'length_max' in kwargs: + self._config.value_max = kwargs['length_max'] + return self + + +class PayerPlanPeriodQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("PayerPlanPeriod", concept_set_id, **kwargs) + + +class LocationRegionQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("LocationRegion", concept_set_id, **kwargs) + + +class VisitDetailQuery(BaseQuery): + def __init__(self, 
concept_set_id: Optional[int] = None, **kwargs): + super().__init__("VisitDetail", concept_set_id, **kwargs) + + +class DoseEraQuery(BaseQuery): + def __init__(self, concept_set_id: Optional[int] = None, **kwargs): + super().__init__("DoseEra", concept_set_id, **kwargs) + + +class CriteriaGroupBuilder: + def __init__(self, parent: Union['CriteriaGroupBuilder', 'CohortWithCriteria', 'BaseQuery'], group: GroupConfig): + self._parent = parent + self._group = group + + def _add_query(self, config: QueryConfig, is_exclusion: bool = False) -> 'CriteriaGroupBuilder': + self._group.criteria.append(CriteriaConfig( + query_config=config, + is_exclusion=is_exclusion + )) + return self + + def end_group(self) -> Any: + """End this group and return to parent context.""" + return self._parent + + def require_condition(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return ConditionQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_drug(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return DrugQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_drug_era(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return DrugEraQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_measurement(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return MeasurementQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_procedure(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return ProcedureQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_visit(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return VisitQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_observation(self, 
concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return ObservationQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_visit_detail(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return VisitDetailQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_death(self, **kwargs) -> "CriteriaGroupBuilder": + return DeathQuery(parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_device(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return DeviceExposureQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_specimen(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return SpecimenQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_condition_era(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return ConditionEraQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def require_payer_plan_period(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return PayerPlanPeriodQuery(concept_set_id, parent=self, is_exclusion=False).apply_params(**kwargs)._finalize() + + def exclude_condition(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return ConditionQuery(concept_set_id, parent=self, is_exclusion=True).apply_params(**kwargs)._finalize() + + def exclude_drug(self, concept_set_id: int, **kwargs) -> "CriteriaGroupBuilder": + return DrugQuery(concept_set_id, parent=self, is_exclusion=True).apply_params(**kwargs)._finalize() + + + def any_of(self) -> 'CriteriaGroupBuilder': + new_group = GroupConfig(type="ANY") + self._group.criteria.append(new_group) + return CriteriaGroupBuilder(self, new_group) + + def all_of(self) -> 'CriteriaGroupBuilder': + new_group = GroupConfig(type="ALL") + 
self._group.criteria.append(new_group) + return CriteriaGroupBuilder(self, new_group) + + def at_least_of(self, count: int) -> 'CriteriaGroupBuilder': + new_group = GroupConfig(type="AT_LEAST", count=count) + self._group.criteria.append(new_group) + return CriteriaGroupBuilder(self, new_group) diff --git a/circe/cohortdefinition/__init__.py b/circe/cohortdefinition/__init__.py index e558191..9128634 100644 --- a/circe/cohortdefinition/__init__.py +++ b/circe/cohortdefinition/__init__.py @@ -37,6 +37,22 @@ from .concept_set_expression_query_builder import ConceptSetExpressionQueryBuilder from .interfaces import IGetCriteriaSqlDispatcher, IGetEndStrategySqlDispatcher from .printfriendly import MarkdownRender +from .validators import ( + is_first_event, + has_exclusion_rules, + has_inclusion_rule_by_name, + get_exclusion_count, + has_censoring_criteria, + get_censoring_criteria_types, + has_additional_criteria, + has_end_strategy, + get_end_strategy_type, + get_primary_criteria_types, + has_observation_window, + get_primary_limit_type, + get_concept_set_count, + has_concept_sets, +) __all__ = [ # Main cohort class @@ -76,7 +92,23 @@ "IGetCriteriaSqlDispatcher", "IGetEndStrategySqlDispatcher", # Print-Friendly - "MarkdownRender" + "MarkdownRender", + + # Validator Functions + "is_first_event", + "has_exclusion_rules", + "has_inclusion_rule_by_name", + "get_exclusion_count", + "has_censoring_criteria", + "get_censoring_criteria_types", + "has_additional_criteria", + "has_end_strategy", + "get_end_strategy_type", + "get_primary_criteria_types", + "has_observation_window", + "get_primary_limit_type", + "get_concept_set_count", + "has_concept_sets", ] # Rebuild models with forward references after all imports are complete diff --git a/circe/cohortdefinition/cohort.py b/circe/cohortdefinition/cohort.py index b62f62c..9404cf7 100644 --- a/circe/cohortdefinition/cohort.py +++ b/circe/cohortdefinition/cohort.py @@ -45,8 +45,8 @@ class CohortExpression(CirceBaseModel): Java 
equivalent: org.ohdsi.circe.cohortdefinition.CohortExpression """ - concept_sets: Optional[List[ConceptSet]] = Field( - default=None, + concept_sets: List[ConceptSet] = Field( + default_factory=list, validation_alias=AliasChoices("ConceptSets", "conceptSets"), serialization_alias="ConceptSets" ) @@ -89,8 +89,8 @@ class CohortExpression(CirceBaseModel): validation_alias=AliasChoices("Title", "title"), serialization_alias="Title" ) - inclusion_rules: Optional[List[InclusionRule]] = Field( - default=None, + inclusion_rules: List[InclusionRule] = Field( + default_factory=list, validation_alias=AliasChoices("InclusionRules", "inclusionRules"), serialization_alias="InclusionRules" ) @@ -99,14 +99,30 @@ class CohortExpression(CirceBaseModel): validation_alias=AliasChoices("CensorWindow", "censorWindow"), serialization_alias="CensorWindow" ) - censoring_criteria: Optional[List[CriteriaType]] = Field( - default=None, - validation_alias=AliasChoices("CensoringCriteria", "censoringCriteria"), + censoring_criteria: List[CriteriaType] = Field( + default_factory=list, + validation_alias=AliasChoices("CensoringCriteria", "censoring_criteria", "censoringCriteria"), serialization_alias="CensoringCriteria" ) model_config = ConfigDict(populate_by_name=True) + @field_validator('inclusion_rules', mode='before') + @classmethod + def allow_none_inclusion_rules(cls, v: Any) -> Any: + """Convert None to empty list for inclusion_rules.""" + if v is None: + return [] + return v + + @field_validator('concept_sets', mode='before') + @classmethod + def allow_none_concept_sets(cls, v: Any) -> Any: + """Convert None to empty list for concept_sets.""" + if v is None: + return [] + return v + @field_validator('end_strategy', mode='before') @classmethod def deserialize_end_strategy(cls, v: Any) -> Any: @@ -141,6 +157,8 @@ def deserialize_censoring_criteria(cls, v: Any) -> Any: Censoring criteria come as [{"ConditionOccurrence": {...}}, ...] 
and need to be unwrapped and deserialized to Criteria objects. """ + if v is None: + return [] if not v or not isinstance(v, list): return v @@ -224,8 +242,6 @@ def add_concept_set(self, concept_set: ConceptSet) -> None: """ if not isinstance(concept_set, ConceptSet): raise TypeError("Expected ConceptSet instance") - if self.concept_sets is None: - self.concept_sets = [] self.concept_sets.append(concept_set) def remove_concept_set_by_id(self, id_: int) -> None: @@ -241,8 +257,6 @@ def add_inclusion_rule(self, rule: InclusionRule) -> None: """ if not isinstance(rule, InclusionRule): raise TypeError("Expected InclusionRule instance") - if self.inclusion_rules is None: - self.inclusion_rules = [] self.inclusion_rules.append(rule) def remove_inclusion_rule_by_name(self, name: str) -> None: @@ -258,8 +272,6 @@ def add_censoring_criteria(self, criteria: Criteria) -> None: """ if not isinstance(criteria, Criteria): raise TypeError("Expected Criteria instance") - if self.censoring_criteria is None: - self.censoring_criteria = [] self.censoring_criteria.append(criteria) def remove_censoring_criteria_by_type(self, criteria_type: str) -> None: @@ -382,6 +394,171 @@ def _normalize_for_checksum(self, data: Any) -> Any: return data + # ========================================================================= + # VALIDATION AND PROPERTY CHECKING METHODS + # ========================================================================= + + def is_first_event(self) -> bool: + """Check if cohort uses first event criteria. + + Returns: + True if all primary criteria have first=True, False otherwise. 
+ """ + if not self.primary_criteria or not self.primary_criteria.criteria_list: + return False + + # Check if all criteria have first=True + for criteria in self.primary_criteria.criteria_list: + # Get the first attribute, handling both direct attribute and nested structure + first_value = getattr(criteria, 'first', None) + if first_value is not True: + return False + + return True + + def has_exclusion_rules(self) -> bool: + """Check if cohort has exclusion rules (inclusion rules). + + Note: In CIRCE terminology, "inclusion rules" act as exclusion criteria. + + Returns: + True if the cohort has any inclusion rules. + """ + return bool(self.inclusion_rules and len(self.inclusion_rules) > 0) + + def get_exclusion_count(self) -> int: + """Get the number of exclusion rules (inclusion rules). + + Returns: + The number of inclusion rules. + """ + return len(self.inclusion_rules) if self.inclusion_rules else 0 + + def has_inclusion_rule_by_name(self, name: str) -> bool: + """Check if an inclusion rule with the given name exists. + + This is useful for checking if specific rules are shared between cohorts. + + Args: + name: The name of the inclusion rule to search for. + + Returns: + True if an inclusion rule with the given name exists. + """ + if not self.inclusion_rules: + return False + + for rule in self.inclusion_rules: + if getattr(rule, 'name', None) == name: + return True + + return False + + def has_censoring_criteria(self) -> bool: + """Check if cohort has censoring criteria. + + Returns: + True if censoring criteria are defined. + """ + return bool(self.censoring_criteria and len(self.censoring_criteria) > 0) + + def get_censoring_criteria_types(self) -> List[str]: + """Get list of censoring criteria class names. + + Returns: + List of class names (e.g., ['ConditionOccurrence', 'DrugExposure']). 
+ """ + if not self.censoring_criteria: + return [] + + return [criteria.__class__.__name__ for criteria in self.censoring_criteria] + + def has_additional_criteria(self) -> bool: + """Check if cohort has additional criteria defined and not empty. + + Returns: + True if additional criteria are defined and not empty. + """ + if not self.additional_criteria: + return False + + # Check if the criteria group is not empty + return not self.additional_criteria.is_empty() + + def has_end_strategy(self) -> bool: + """Check if cohort has an end strategy defined. + + Returns: + True if an end strategy is defined. + """ + return self.end_strategy is not None + + def get_end_strategy_type(self) -> Optional[str]: + """Get the type of end strategy. + + Returns: + 'DateOffset', 'CustomEra', or None if no end strategy is defined. + """ + if not self.end_strategy: + return None + + class_name = self.end_strategy.__class__.__name__ + if class_name == 'DateOffsetStrategy': + return 'DateOffset' + elif class_name == 'CustomEraStrategy': + return 'CustomEra' + else: + return class_name + + def get_primary_criteria_types(self) -> List[str]: + """Get list of primary criteria class names. + + Returns: + List of class names (e.g., ['ConditionOccurrence', 'DrugExposure']). + """ + if not self.primary_criteria or not self.primary_criteria.criteria_list: + return [] + + return [criteria.__class__.__name__ for criteria in self.primary_criteria.criteria_list] + + def has_observation_window(self) -> bool: + """Check if observation window is defined in primary criteria. + + Returns: + True if observation window is defined. + """ + if not self.primary_criteria: + return False + + return self.primary_criteria.observation_window is not None + + def get_primary_limit_type(self) -> Optional[str]: + """Get the primary limit type. + + Returns: + The primary limit type (e.g., 'All', 'First') or None. 
+ """ + if not self.primary_criteria or not self.primary_criteria.primary_limit: + return None + + return getattr(self.primary_criteria.primary_limit, 'type', None) + + def get_concept_set_count(self) -> int: + """Get the number of concept sets. + + Returns: + The number of concept sets. + """ + return len(self.concept_sets) if self.concept_sets else 0 + + def has_concept_sets(self) -> bool: + """Check if concept sets are defined. + + Returns: + True if concept sets are defined. + """ + return bool(self.concept_sets and len(self.concept_sets) > 0) + def _repr_markdown_(self) -> str: """IPython notebook markdown representation. diff --git a/circe/cohortdefinition/core.py b/circe/cohortdefinition/core.py index 664e2a7..d27e181 100644 --- a/circe/cohortdefinition/core.py +++ b/circe/cohortdefinition/core.py @@ -200,7 +200,7 @@ class CollapseSettings(CirceBaseModel): serialization_alias="EraPad" ) collapse_type: Optional[CollapseType] = Field( - default=None, + default=CollapseType.ERA, validation_alias=AliasChoices("CollapseType", "collapseType"), serialization_alias="CollapseType" ) diff --git a/circe/cohortdefinition/criteria.py b/circe/cohortdefinition/criteria.py index f0f58de..5a835e2 100644 --- a/circe/cohortdefinition/criteria.py +++ b/circe/cohortdefinition/criteria.py @@ -1004,8 +1004,8 @@ class CriteriaGroup(BaseModel): Java equivalent: org.ohdsi.circe.cohortdefinition.CriteriaGroup """ - criteria_list: Optional[List['CorelatedCriteria']] = Field( - default=None, + criteria_list: List['CorelatedCriteria'] = Field( + default_factory=list, validation_alias=AliasChoices("CriteriaList", "criteriaList"), serialization_alias="CriteriaList" ) @@ -1014,13 +1014,13 @@ class CriteriaGroup(BaseModel): validation_alias=AliasChoices("Count", "count"), serialization_alias="Count" ) - groups: Optional[List['CriteriaGroup']] = Field( - default=None, + groups: List['CriteriaGroup'] = Field( + default_factory=list, validation_alias=AliasChoices("Groups", "groups"), 
serialization_alias="Groups" ) - demographic_criteria_list: Optional[List[DemographicCriteria]] = Field( - default=None, + demographic_criteria_list: List[DemographicCriteria] = Field( + default_factory=list, validation_alias=AliasChoices("DemographicCriteriaList", "demographicCriteriaList"), serialization_alias="DemographicCriteriaList" ) @@ -1039,10 +1039,19 @@ def is_empty(self) -> bool: has_demographic = self.demographic_criteria_list and len(self.demographic_criteria_list) > 0 return not (has_criteria or has_groups or has_demographic) + @field_validator('demographic_criteria_list', mode='before') + @classmethod + def allow_none_demographic(cls, v: Any) -> Any: + if v is None: + return [] + return v + @field_validator('groups', mode='before') @classmethod def deserialize_groups(cls, v: Any) -> Any: # Same Logic as before, just local + if v is None: + return [] if not v or not isinstance(v, list): return v result = [] @@ -1061,6 +1070,8 @@ def deserialize_groups(cls, v: Any) -> Any: @classmethod def deserialize_criteria_list(cls, v: Any) -> Any: # Logic adapted for local CorelatedCriteria + if v is None: + return [] if not v or not isinstance(v, list): return v @@ -1238,8 +1249,8 @@ class PrimaryCriteria(BaseModel): Java equivalent: org.ohdsi.circe.cohortdefinition.PrimaryCriteria """ - criteria_list: Optional[List[CriteriaType]] = Field( - default=None, + criteria_list: List[CriteriaType] = Field( + default_factory=list, validation_alias=AliasChoices("CriteriaList", "criteriaList"), serialization_alias="CriteriaList" ) @@ -1259,6 +1270,8 @@ class PrimaryCriteria(BaseModel): @field_validator('criteria_list', mode='before') @classmethod def deserialize_criteria_list(cls, v: Any) -> Any: + if v is None: + return [] if not v or not isinstance(v, list): return v diff --git a/circe/cohortdefinition/validators.py b/circe/cohortdefinition/validators.py new file mode 100644 index 0000000..2b22cec --- /dev/null +++ b/circe/cohortdefinition/validators.py @@ -0,0 +1,220 @@ 
+""" +Validator functions for cohort expressions. + +This module provides standalone validator functions that delegate to CohortExpression +methods. These functions provide a functional API for users who prefer it, while the +primary implementation lives in the CohortExpression class methods. + +All functions accept a CohortExpression instance and return validation results. +""" + +from typing import List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .cohort import CohortExpression + + +def is_first_event(cohort_expression: 'CohortExpression') -> bool: + """Check if cohort uses first event criteria. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if all primary criteria have first=True, False otherwise. + + Example: + >>> from circe.cohortdefinition import CohortExpression + >>> cohort = CohortExpression.model_validate(json_data) + >>> if is_first_event(cohort): + ... print("This cohort uses first event criteria") + """ + return cohort_expression.is_first_event() + + +def has_exclusion_rules(cohort_expression: 'CohortExpression') -> bool: + """Check if cohort has exclusion rules (inclusion rules). + + Note: In CIRCE terminology, "inclusion rules" act as exclusion criteria. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if the cohort has any inclusion rules. + + Example: + >>> if has_exclusion_rules(cohort): + ... print(f"Cohort has {get_exclusion_count(cohort)} exclusion rules") + """ + return cohort_expression.has_exclusion_rules() + + +def get_exclusion_count(cohort_expression: 'CohortExpression') -> int: + """Get the number of exclusion rules (inclusion rules). + + Args: + cohort_expression: The cohort expression to check. + + Returns: + The number of inclusion rules. + """ + return cohort_expression.get_exclusion_count() + + +def has_inclusion_rule_by_name(cohort_expression: 'CohortExpression', name: str) -> bool: + """Check if an inclusion rule with the given name exists. 
+ + This is useful for checking if specific rules are shared between cohorts. + + Args: + cohort_expression: The cohort expression to check. + name: The name of the inclusion rule to search for. + + Returns: + True if an inclusion rule with the given name exists. + + Example: + >>> if has_inclusion_rule_by_name(cohort, "Prior Cancer"): + ... print("This cohort excludes patients with prior cancer") + """ + return cohort_expression.has_inclusion_rule_by_name(name) + + +def has_censoring_criteria(cohort_expression: 'CohortExpression') -> bool: + """Check if cohort has censoring criteria. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if censoring criteria are defined. + + Example: + >>> if has_censoring_criteria(cohort): + ... types = get_censoring_criteria_types(cohort) + ... print(f"Censoring on: {', '.join(types)}") + """ + return cohort_expression.has_censoring_criteria() + + +def get_censoring_criteria_types(cohort_expression: 'CohortExpression') -> List[str]: + """Get list of censoring criteria class names. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + List of class names (e.g., ['ConditionOccurrence', 'DrugExposure']). + """ + return cohort_expression.get_censoring_criteria_types() + + +def has_additional_criteria(cohort_expression: 'CohortExpression') -> bool: + """Check if cohort has additional criteria defined and not empty. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if additional criteria are defined and not empty. + """ + return cohort_expression.has_additional_criteria() + + +def has_end_strategy(cohort_expression: 'CohortExpression') -> bool: + """Check if cohort has an end strategy defined. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if an end strategy is defined. + + Example: + >>> if has_end_strategy(cohort): + ... strategy_type = get_end_strategy_type(cohort) + ... 
print(f"End strategy: {strategy_type}") + """ + return cohort_expression.has_end_strategy() + + +def get_end_strategy_type(cohort_expression: 'CohortExpression') -> Optional[str]: + """Get the type of end strategy. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + 'DateOffset', 'CustomEra', or None if no end strategy is defined. + """ + return cohort_expression.get_end_strategy_type() + + +def get_primary_criteria_types(cohort_expression: 'CohortExpression') -> List[str]: + """Get list of primary criteria class names. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + List of class names (e.g., ['ConditionOccurrence', 'DrugExposure']). + + Example: + >>> types = get_primary_criteria_types(cohort) + >>> print(f"Entry events: {', '.join(types)}") + """ + return cohort_expression.get_primary_criteria_types() + + +def has_observation_window(cohort_expression: 'CohortExpression') -> bool: + """Check if observation window is defined in primary criteria. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + True if observation window is defined. + """ + return cohort_expression.has_observation_window() + + +def get_primary_limit_type(cohort_expression: 'CohortExpression') -> Optional[str]: + """Get the primary limit type. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + The primary limit type (e.g., 'All', 'First') or None. + """ + return cohort_expression.get_primary_limit_type() + + +def get_concept_set_count(cohort_expression: 'CohortExpression') -> int: + """Get the number of concept sets. + + Args: + cohort_expression: The cohort expression to check. + + Returns: + The number of concept sets. + """ + return cohort_expression.get_concept_set_count() + + +def has_concept_sets(cohort_expression: 'CohortExpression') -> bool: + """Check if concept sets are defined. + + Args: + cohort_expression: The cohort expression to check. 
+ + Returns: + True if concept sets are defined. + + Example: + >>> if has_concept_sets(cohort): + ... count = get_concept_set_count(cohort) + ... print(f"Cohort uses {count} concept sets") + """ + return cohort_expression.has_concept_sets() diff --git a/circe/skills/__init__.py b/circe/skills/__init__.py new file mode 100644 index 0000000..ad38ef4 --- /dev/null +++ b/circe/skills/__init__.py @@ -0,0 +1,70 @@ +""" +Skills module for AI agent integration. + +This module provides documentation and API reference for AI agents +that use the circe package to build OHDSI cohort definitions. + +Usage: + from circe.skills import get_cohort_builder_skill + + # Get the skill documentation as a string + skill_docs = get_cohort_builder_skill() +""" + +from importlib.resources import files +from typing import Optional + + +def get_cohort_builder_skill() -> str: + """ + Return the CohortBuilder skill documentation for AI agents. + + This returns a markdown document describing how to use the + CohortBuilder context manager API to build OHDSI cohort definitions. + + Returns: + str: Markdown documentation for the CohortBuilder API + + Example: + >>> from circe.skills import get_cohort_builder_skill + >>> skill = get_cohort_builder_skill() + >>> print(skill[:100]) + --- + description: Build OHDSI cohort definitions using the Pythonic context manager API + --- + """ + skill_file = files("circe.skills").joinpath("cohort_builder.md") + return skill_file.read_text() + + +def get_skill(name: str = "cohort_builder") -> Optional[str]: + """ + Return skill documentation by name. 
+ + Args: + name: Name of the skill (default: "cohort_builder") + + Returns: + str: Skill documentation, or None if not found + + Available skills: + - cohort_builder: Build OHDSI cohort definitions using CohortBuilder + """ + skill_map = { + "cohort_builder": get_cohort_builder_skill, + } + + func = skill_map.get(name) + if func: + return func() + return None + + +def list_skills() -> list: + """ + List all available skills. + + Returns: + list: Names of available skills + """ + return ["cohort_builder"] diff --git a/circe/skills/cohort_builder.md b/circe/skills/cohort_builder.md new file mode 100644 index 0000000..d328f21 --- /dev/null +++ b/circe/skills/cohort_builder.md @@ -0,0 +1,584 @@ +--- +description: Build OHDSI cohort definitions using the Pythonic context manager API +--- + +# Cohort Builder Skill + +Build OHDSI cohort definitions using the `CohortBuilder` context manager. + +## Key Principles + +1. **Use named concept sets** - Always attach concept sets to the cohort so they appear in output +2. **Return expressions from functions** - Define cohorts as functions for readability +3. **Save to cohorts directory** - Output multiple cohorts to a common directory by name +4. **Apply constraints only when clinically specified** - Do not add population filters (age, gender, observation windows) unless they are explicitly requested in the clinical description or essential to the phenotype definition. When in doubt, start broad and let the clinical description guide you. 
+ +## Basic Pattern + +```python +from circe.cohort_builder import CohortBuilder +from circe.vocabulary import concept_set, descendants + + +def create_diabetes_cohort(): + """Create a Type 2 Diabetes cohort.""" + # Define concept sets (attached to cohort) + t2dm = concept_set(descendants(201826), id=1, name="Type 2 Diabetes") + metformin = concept_set(descendants(1503297), id=2, name="Metformin") + + with CohortBuilder("T2DM with Prior Metformin") as cohort: + cohort.with_concept_sets(t2dm, metformin) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + + with cohort.include_rule("Prior Metformin") as rule: + rule.require_drug(2, anytime_before=True) + + return cohort.expression + + +# Save cohort to file +if __name__ == "__main__": + from pathlib import Path + + cohort = create_diabetes_cohort() + output_dir = Path("cohorts") + output_dir.mkdir(exist_ok=True) + + output_file = output_dir / "new_t2dm_on_metformin.json" + output_file.write_text(cohort.model_dump_json(by_alias=True, indent=2)) +``` + +## API Reference + +### Context Manager + +```python +with CohortBuilder("Title") as cohort: + # Define cohort inside block + cohort.with_condition(1) + +expression = cohort.expression # Access after block exits +``` + +### Concept Sets (Required for Output) + +Always define and attach concept sets: + +```python +from circe.vocabulary import concept_set, descendants + +# Create concept sets with meaningful names +t2dm = concept_set(descendants(201826), id=1, name="Type 2 Diabetes") +metformin = concept_set(descendants(1503297), id=2, name="Metformin") + +with CohortBuilder("My Cohort") as cohort: + cohort.with_concept_sets(t2dm, metformin) # Attach to cohort + cohort.with_condition(concept_set_id=1) # Reference by ID +``` + +### Entry Events + +```python +cohort.with_condition(concept_set_id) +cohort.with_drug(concept_set_id) +cohort.with_procedure(concept_set_id) +cohort.with_measurement(concept_set_id) +cohort.with_observation(concept_set_id) 
+cohort.with_visit(concept_set_id) +cohort.with_death() +``` + +### Entry Configuration + +```python +# Occurrence limiting +cohort.first_occurrence() +cohort.all_occurrences() + +# Observation period requirements (apply only if clinically necessary) +cohort.with_observation_window(prior_days=365, post_days=0) # Example values + +# Age constraints (apply only if specified in clinical description) +cohort.min_age(18) # Example: minimum age +cohort.max_age(65) # Example: maximum age +cohort.require_age(min_age=18, max_age=65) # Example: age range + +# Demographic constraints (apply only if specified in clinical description) +cohort.require_gender(8507, 8532) # Example: gender concept IDs (separate args) +cohort.require_race(8527) # Example: race concept ID +cohort.require_ethnicity(38003563) # Example: ethnicity concept ID +``` + +### Inclusion/Exclusion with Time Windows + +```python +cohort.require_condition(id, within_days_before=30) +cohort.require_drug(id, anytime_before=True) +cohort.exclude_condition(id, within_days_after=90) +cohort.exclude_drug(id, same_day=True) +``` + +**Time window options:** +- `within_days_before=N` +- `within_days_after=N` +- `anytime_before=True` +- `anytime_after=True` +- `same_day=True` +- `during_event=True` + +### Named Inclusion Rules + +For attrition tracking, use nested contexts: + +```python +with cohort.include_rule("Prior Treatment") as rule: + rule.require_drug(2, anytime_before=True) + +with cohort.include_rule("No Contraindications") as rule: + rule.exclude_condition(3, within_days_before=365) +``` + +## Decision Guide: When to Apply Population Constraints + +Use this guide to determine whether to apply age, gender, observation window, or other population filters: + +### ✅ Apply Age Constraints When: +- The clinical description explicitly mentions age (e.g., "adults 18+", "elderly 65+", "children under 12") +- The phenotype is inherently age-specific (e.g., "pediatric asthma", "geriatric hip fracture") +- Age is part of a 
validated algorithm (e.g., Sentinel, OMOP phenotype library) + +### ❌ Do NOT Apply Age Constraints When: +- The clinical description does not mention age +- You're building a general population cohort +- The concept set itself defines the population (e.g., "pregnancy" is inherently age-restricted) + +### ✅ Apply Observation Window When: +- You need to confirm "new" or "incident" cases (require prior observation to rule out prevalent cases) +- The phenotype requires baseline characteristics (need lookback period) +- The clinical description specifies continuous enrollment + +### ❌ Do NOT Apply Observation Window When: +- Building a simple prevalence cohort +- The clinical description doesn't mention enrollment or baseline periods +- You're identifying all occurrences regardless of observation history + +### ✅ Apply Gender/Race/Ethnicity When: +- The clinical description explicitly specifies demographic restrictions +- The phenotype is inherently demographic-specific (e.g., "prostate cancer" is male-specific) + +### ❌ Do NOT Apply Demographic Filters When: +- The clinical description does not mention demographics +- You're building a general population cohort + +### General Rule: +**When in doubt, start broad.** It's easier to add constraints later than to realize you've been excluding valid patients all along. + + +## Complete Example with Multiple Cohorts + +```python +from pathlib import Path +from circe.cohort_builder import CohortBuilder +from circe.vocabulary import concept_set, descendants + + +def create_pediatric_asthma_cohort(): + """Pediatric asthma cohort - no age restriction applied. + + Note: No age filter is specified because the clinical definition + does not restrict by age. The concept set itself defines the population. 
+ """ + asthma = concept_set(descendants(317009), id=1, name="Asthma") + + with CohortBuilder("Pediatric Asthma") as cohort: + cohort.with_concept_sets(asthma) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + # No age filter - let the data define the population + + return cohort.expression + + +def create_elderly_hip_fracture_cohort(): + """Hip fracture in elderly patients. + + Note: Age restriction IS applied here because it's clinically relevant + to the phenotype definition (elderly-specific outcome). + """ + hip_fx = concept_set(descendants(4230399), id=1, name="Hip Fracture") + + with CohortBuilder("Elderly Hip Fracture") as cohort: + cohort.with_concept_sets(hip_fx) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + cohort.min_age(65) # Example: Age IS a clinical requirement here + + return cohort.expression + + +def create_new_diabetes_users_cohort(): + """New diabetes diagnosis with continuous enrollment. + + Note: Observation window IS applied because we need to ensure + patients are observable before diagnosis (to confirm 'new' diagnosis). 
+ """ + t2dm = concept_set(descendants(201826), id=1, name="Type 2 Diabetes") + + with CohortBuilder("New Type 2 Diabetes") as cohort: + cohort.with_concept_sets(t2dm) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + # Example: Observation window ensures we can confirm "new" diagnosis + cohort.with_observation_window(prior_days=365) + + return cohort.expression + + +def save_cohort(expression, name: str, output_dir: Path = Path("cohorts")): + """Save cohort expression to JSON file.""" + output_dir.mkdir(exist_ok=True) + filename = name.lower().replace(" ", "_") + ".json" + output_path = output_dir / filename + output_path.write_text(expression.model_dump_json(by_alias=True, indent=2)) + return output_path + + +if __name__ == "__main__": + cohorts = [ + ("pediatric_asthma", create_pediatric_asthma_cohort()), + ("elderly_hip_fracture", create_elderly_hip_fracture_cohort()), + ("new_type_2_diabetes", create_new_diabetes_users_cohort()), + ] + + for name, expr in cohorts: + path = save_cohort(expr, name) + print(f"Saved: {path}") +``` + +## Modifying Existing Cohorts + +If you have an existing cohort definition in your context (as a `CohortExpression` object or JSON), you can modify it instead of creating a new one from scratch.
+ +### When to Modify vs Create New + +**✅ Modify an existing cohort when:** +- You need to make small adjustments to an existing definition +- You want to add/remove specific criteria while preserving the rest +- You're iterating on a cohort based on feedback +- The existing cohort is close to what you need + +**❌ Create a new cohort when:** +- The changes are substantial (different entry event, completely different logic) +- You need a variant for comparison (keep both versions) +- The clinical description is fundamentally different + +### Loading Existing Cohorts + +Use `CohortBuilder.from_expression()` to load an existing cohort for modification: + +```python +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition import CohortExpression + +# Load from JSON +existing_json = """{ ... cohort definition ... }""" +existing = CohortExpression.model_validate_json(existing_json) + +# Load into builder for modification +with CohortBuilder.from_expression(existing) as cohort: + # Make modifications here + cohort.require_drug(5, within_days_before=30) + cohort.remove_inclusion_rule("Old Rule") + +modified = cohort.expression +``` + +### Modification Operations + +#### Adding New Criteria + +Use the same API as building new cohorts: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Add new inclusion criteria + cohort.require_measurement(4, within_days_after=90) + + # Add new exclusion criteria + cohort.exclude_condition(6, anytime_before=True) + + # Add new concept sets + from circe.vocabulary import concept_set, descendants + hba1c = concept_set(descendants(3004410), id=4, name="HbA1c") + cohort.with_concept_sets(hba1c) +``` + +#### Removing Inclusion Rules + +Remove rules by their name: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Remove a specific rule + cohort.remove_inclusion_rule("Prior Treatment") + + # Or clear all rules and start fresh + cohort.clear_inclusion_rules() +``` + +#### Removing 
Censoring Criteria + +Remove censoring criteria by type, concept set ID, or index: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Remove by type + cohort.remove_censoring_criteria(criteria_type="Death") + + # Remove by concept set ID + cohort.remove_censoring_criteria(concept_set_id=5) + + # Remove by index + cohort.remove_censoring_criteria(index=0) + + # Or clear all censoring criteria + cohort.clear_censoring_criteria() +``` + +#### Removing Entry Events + +Remove entry events while ensuring at least one remains: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Remove by concept set ID + cohort.remove_entry_event(concept_set_id=1) + + # Remove by type (removes first match) + cohort.remove_entry_event(criteria_type="DrugExposure") + + # Remove by index + cohort.remove_entry_event(index=0) +``` + +**Note**: You cannot remove the last entry event. At least one must remain. + +#### Removing Concept Sets + +Remove concept sets with or without cleaning up references: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Remove just the concept set (leaves references) + cohort.remove_concept_set(concept_set_id=3) + + # Remove concept set AND all criteria that reference it + cohort.remove_all_references(concept_set_id=3) +``` + +**Recommendation**: Use `remove_all_references()` to avoid orphaned criteria. 
+ +#### Clearing Demographic Criteria + +Remove all demographic restrictions: + +```python +with CohortBuilder.from_expression(existing) as cohort: + # Clear age, gender, race, ethnicity restrictions + cohort.clear_demographic_criteria() +``` + +### Practical Modification Examples + +#### Example 1: Refining an Existing Cohort + +```python +def refine_diabetes_cohort(existing_cohort: CohortExpression): + """Refine diabetes cohort by removing age restriction and adding HbA1c requirement.""" + + # Add new concept set for HbA1c + hba1c = concept_set(descendants(3004410), id=4, name="HbA1c Measurement") + + with CohortBuilder.from_expression(existing_cohort, title="Refined Diabetes Cohort") as cohort: + # Remove age restriction (if it exists) + cohort.clear_demographic_criteria() + + # Add HbA1c measurement requirement + cohort.with_concept_sets(hba1c) + cohort.require_measurement(4, within_days_after=90) + + return cohort.expression +``` + +#### Example 2: Removing Outdated Criteria + +```python +def remove_outdated_exclusions(existing_cohort: CohortExpression): + """Remove outdated exclusion criteria from cohort.""" + + with CohortBuilder.from_expression(existing_cohort) as cohort: + # Remove specific exclusion rules + cohort.remove_inclusion_rule("Cancer Exclusion") + cohort.remove_inclusion_rule("Pregnancy Exclusion") + + # Remove death censoring + cohort.remove_censoring_criteria(criteria_type="Death") + + return cohort.expression +``` + +#### Example 3: Simplifying a Complex Cohort + +```python +def simplify_cohort(existing_cohort: CohortExpression): + """Simplify cohort by removing all inclusion rules and keeping only entry event.""" + + with CohortBuilder.from_expression(existing_cohort, title="Simplified Cohort") as cohort: + # Clear all inclusion rules + cohort.clear_inclusion_rules() + + # Clear all censoring criteria + cohort.clear_censoring_criteria() + + # Clear demographic restrictions + cohort.clear_demographic_criteria() + + return cohort.expression 
+``` + +#### Example 4: Adapting for Different Population + +```python +def adapt_for_pediatric_population(adult_cohort: CohortExpression): + """Adapt an adult cohort for pediatric population.""" + + with CohortBuilder.from_expression(adult_cohort, title="Pediatric Version") as cohort: + # Remove adult age restriction + cohort.clear_demographic_criteria() + + # Add pediatric age restriction + cohort.max_age(17) + + # Remove adult-specific criteria (example) + try: + cohort.remove_inclusion_rule("Pregnancy Screening") + except KeyError: + pass # Rule doesn't exist, continue + + return cohort.expression +``` + +### Decision Guide: Modify or Create New? + +Use this flowchart to decide: + +1. **Is the entry event the same?** + - No → Create new cohort + - Yes → Continue + +2. **Are you changing >50% of the criteria?** + - Yes → Consider creating new cohort + - No → Continue + +3. **Do you need to keep both versions?** + - Yes → Create new cohort (with different title) + - No → Modify existing + +4. **Are the changes additive (adding criteria)?** + - Yes → Modify existing + - Mostly removals → Modify existing + +5. **Is this a one-time adjustment or a variant?** + - One-time → Modify existing + - Variant → Create new with different title + +### Important Notes + +**Preservation of Original**: Modifications create a copy. The original `CohortExpression` is never mutated. 
+ +```python +original = CohortExpression.model_validate_json(json_data) + +with CohortBuilder.from_expression(original) as cohort: + cohort.clear_inclusion_rules() + +modified = cohort.expression + +# original is unchanged +assert len(original.inclusion_rules) > 0 # Still has rules +assert len(modified.inclusion_rules) == 0 # Rules cleared +``` + +**Error Handling**: Removal methods raise errors for invalid operations: + +```python +try: + cohort.remove_inclusion_rule("Nonexistent Rule") +except KeyError: + print("Rule not found") + +try: + cohort.remove_entry_event(concept_set_id=999) +except ValueError: + print("Entry event not found or would leave zero entry events") +``` + +**State Reconstruction Limitations**: Complex cohorts with deeply nested logic or correlated criteria may not be fully reconstructed. The modification API works best with: +- Simple inclusion/exclusion rules +- Standard time windows +- Common criteria types + +For highly complex cohorts, consider creating a new one from scratch. + +## Common Pitfalls to Avoid + +### 1. **The "Adult Default" Trap** +❌ **Wrong**: Automatically adding `cohort.min_age(18)` to every cohort +✅ **Right**: Only add age constraints when clinically specified + +### 2. **The "365-Day Observation" Assumption** +❌ **Wrong**: Adding `cohort.with_observation_window(prior_days=365)` by default +✅ **Right**: Only require observation windows when needed for "new" diagnosis or baseline assessment + +### 3. **Copying Boilerplate Without Thinking** +❌ **Wrong**: Copying entire example code including all constraints +✅ **Right**: Read the clinical description carefully and apply only relevant constraints + +### 4. 
**Ignoring the Clinical Description** +❌ **Wrong**: Building what you think the cohort should be +✅ **Right**: Building exactly what the clinical description specifies + +### Example of Correct Approach: + +**Clinical Description**: "Patients with Type 2 Diabetes" + +```python +# ✅ CORRECT: No age or observation window (not specified) +def create_t2dm_cohort(): + t2dm = concept_set(descendants(201826), id=1, name="Type 2 Diabetes") + + with CohortBuilder("Type 2 Diabetes") as cohort: + cohort.with_concept_sets(t2dm) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + + return cohort.expression +``` + +```python +# ❌ WRONG: Added constraints not in description +def create_t2dm_cohort(): + t2dm = concept_set(descendants(201826), id=1, name="Type 2 Diabetes") + + with CohortBuilder("Type 2 Diabetes") as cohort: + cohort.with_concept_sets(t2dm) + cohort.with_condition(concept_set_id=1) + cohort.first_occurrence() + cohort.min_age(18) # ← NOT SPECIFIED + cohort.with_observation_window(prior_days=365) # ← NOT SPECIFIED + + return cohort.expression +``` diff --git a/circe/vocabulary/__init__.py b/circe/vocabulary/__init__.py index d7b5fd0..7696afd 100644 --- a/circe/vocabulary/__init__.py +++ b/circe/vocabulary/__init__.py @@ -9,10 +9,161 @@ Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents. """ +from typing import Optional, List, Union from .concept import ( Concept, ConceptSet, ConceptSetExpression, ConceptSetItem ) + +# ============================================================================= +# COMPOSER HELPER FUNCTIONS +# ============================================================================= + +class ConceptReference: + """ + Lightweight reference to a concept for use in concept_set() builder. 
+ + Attributes: + concept_id: The OMOP concept ID + include_descendants: Whether to include descendant concepts + include_mapped: Whether to include mapped concepts + is_excluded: Whether this concept is excluded + """ + def __init__( + self, + concept_id: int, + include_descendants: bool = False, + include_mapped: bool = False, + is_excluded: bool = False + ): + self.concept_id = concept_id + self.include_descendants = include_descendants + self.include_mapped = include_mapped + self.is_excluded = is_excluded + + +def descendants(concept_id: int) -> ConceptReference: + """ + Create a concept reference that includes all descendants. + + Args: + concept_id: The OMOP concept ID + + Returns: + ConceptReference with include_descendants=True + + Example: + >>> # Type 2 Diabetes and all descendants + >>> descendants(201826) + """ + return ConceptReference( + concept_id=concept_id, + include_descendants=True + ) + + +def mapped(concept_id: int) -> ConceptReference: + """ + Create a concept reference that includes mapped concepts. + + Args: + concept_id: The OMOP concept ID + + Returns: + ConceptReference with include_mapped=True + """ + return ConceptReference( + concept_id=concept_id, + include_mapped=True + ) + + +def exclude(concept_ref: Union[int, ConceptReference]) -> ConceptReference: + """ + Mark a concept or concept reference as excluded. + + Args: + concept_ref: Concept ID or ConceptReference to exclude + + Returns: + ConceptReference with is_excluded=True + """ + if isinstance(concept_ref, int): + return ConceptReference(concept_id=concept_ref, is_excluded=True) + return ConceptReference( + concept_id=concept_ref.concept_id, + include_descendants=concept_ref.include_descendants, + include_mapped=concept_ref.include_mapped, + is_excluded=True + ) + + +def concept_set( + *concepts: Union[int, ConceptReference], + id: Optional[int] = None, + name: Optional[str] = None +) -> ConceptSet: + """ + Create a concept set from concept IDs or references. 
+ + Args: + *concepts: Variable number of concept IDs or ConceptReference objects + id: Optional ID for the concept set (defaults to 0 if not provided) + name: Optional name for the concept set (defaults to "Concept Set") + + Returns: + ConceptSet object ready for use in cohort definitions + + Example: + >>> # Simple concept set with descendants + >>> t2dm = concept_set( + ... descendants(201826), + ... name="Type 2 Diabetes" + ... ) + + >>> # Multiple concepts + >>> diabetes = concept_set( + ... descendants(201826), # T2DM + ... descendants(201254), # T1DM + ... name="All Diabetes" + ... ) + + >>> # With exclusions + >>> t2dm_no_secondary = concept_set( + ... descendants(201826), + ... exclude(descendants(443238)), # Exclude secondary diabetes + ... name="T2DM (no secondary)" + ... ) + """ + items = [] + for concept in concepts: + if isinstance(concept, int): + # Simple concept ID + items.append(ConceptSetItem( + concept=Concept(concept_id=concept), + include_descendants=False, + include_mapped=False, + is_excluded=False + )) + elif isinstance(concept, ConceptReference): + # ConceptReference with options + items.append(ConceptSetItem( + concept=Concept(concept_id=concept.concept_id), + include_descendants=concept.include_descendants, + include_mapped=concept.include_mapped, + is_excluded=concept.is_excluded + )) + + return ConceptSet( + id=id or 0, # Will need to be set by user or auto-assigned + name=name or "Concept Set", + expression=ConceptSetExpression(items=items) + ) + + __all__ = [ - "Concept", "ConceptSet", "ConceptSetExpression", "ConceptSetItem" + # Core classes + "Concept", "ConceptSet", "ConceptSetExpression", "ConceptSetItem", + # Composer helpers + "concept_set", "descendants", "mapped", "exclude", "ConceptReference" ] \ No newline at end of file diff --git a/diabetes_cohort.json b/diabetes_cohort.json new file mode 100644 index 0000000..867b5aa --- /dev/null +++ b/diabetes_cohort.json @@ -0,0 +1,48 @@ +{ + "ConceptSets": [ + { + "id": 1, + "name": "Type 2 Diabetes Mellitus", 
+ "expression": { + "isExcluded": false, + "includeMapped": false, + "includeDescendants": false, + "items": [ + { + "concept": { + "CONCEPT_ID": 201826, + "CONCEPT_NAME": "Type 2 diabetes mellitus", + "CONCEPT_CODE": "44054006", + "CONCEPT_CLASS_ID": "Clinical Finding", + "STANDARD_CONCEPT": "S", + "DOMAIN_ID": "Condition", + "VOCABULARY_ID": "SNOMED" + }, + "isExcluded": false, + "includeMapped": false, + "includeDescendants": true + } + ] + } + } + ], + "PrimaryCriteria": { + "CriteriaList": [ + { + "ConditionOccurrence": { + "CodesetId": 1, + "First": true, + "ConditionTypeExclude": false + } + } + ], + "ObservationWindow": { + "PriorDays": 0, + "PostDays": 0 + }, + "PrimaryCriteriaLimit": { + "Type": "All" + } + }, + "Title": "Patients with Type 2 Diabetes" +} \ No newline at end of file diff --git a/diabetes_cohort.sql b/diabetes_cohort.sql new file mode 100644 index 0000000..309c36d --- /dev/null +++ b/diabetes_cohort.sql @@ -0,0 +1,286 @@ +CREATE TABLE #Codesets ( + codeset_id int NOT NULL, + concept_id bigint NOT NULL +) +; + +INSERT INTO #Codesets (codeset_id, concept_id) +SELECT 1 as codeset_id, c.concept_id FROM (select distinct I.concept_id FROM +( + select concept_id from @vocabulary_database_schema.CONCEPT where (concept_id in (201826)) + +UNION select c.concept_id + from @vocabulary_database_schema.CONCEPT c + join @vocabulary_database_schema.CONCEPT_ANCESTOR ca on c.concept_id = ca.descendant_concept_id + WHERE c.invalid_reason is null + and (ca.ancestor_concept_id in (201826)) + +) I + +) C; + +UPDATE STATISTICS #Codesets; + + +SELECT event_id, person_id, start_date, end_date, op_start_date, op_end_date, visit_occurrence_id +INTO #qualified_events +FROM +( + select pe.event_id, pe.person_id, pe.start_date, pe.end_date, pe.op_start_date, pe.op_end_date, row_number() over (partition by pe.person_id order by pe.start_date ASC) as ordinal, cast(pe.visit_occurrence_id as bigint) as visit_occurrence_id + FROM (-- Begin Primary Events +select P.ordinal 
as event_id, P.person_id, P.start_date, P.end_date, op_start_date, op_end_date, cast(P.visit_occurrence_id as bigint) as visit_occurrence_id +FROM +( + select E.person_id, E.start_date, E.end_date, + row_number() OVER (PARTITION BY E.person_id ORDER BY E.sort_date ASC, E.event_id) ordinal, + OP.observation_period_start_date as op_start_date, OP.observation_period_end_date as op_end_date, cast(E.visit_occurrence_id as bigint) as visit_occurrence_id + FROM + ( + +-- Begin Condition Occurrence Criteria +SELECT C.person_id, C.condition_occurrence_id as event_id, C.start_date, C.end_date, + C.visit_occurrence_id, C.start_date as sort_date +FROM +( + SELECT co.person_id,co.condition_occurrence_id,co.condition_concept_id,co.visit_occurrence_id,co.condition_start_date as start_date, COALESCE(co.condition_end_date, DATEADD(day,1,co.condition_start_date)) as end_date , row_number() over (PARTITION BY co.person_id ORDER BY co.condition_start_date, co.condition_occurrence_id) as ordinal + FROM @cdm_database_schema.CONDITION_OCCURRENCE co + JOIN #Codesets cs on (co.condition_concept_id = cs.concept_id and cs.codeset_id = 1) +) C + +WHERE C.ordinal = 1 +-- End Condition Occurrence Criteria + + ) E + JOIN @cdm_database_schema.observation_period OP on E.person_id = OP.person_id and E.start_date >= OP.observation_period_start_date and E.start_date <= op.observation_period_end_date + WHERE DATEADD(day,0,OP.OBSERVATION_PERIOD_START_DATE) <= E.START_DATE AND DATEADD(day,0,E.START_DATE) <= OP.OBSERVATION_PERIOD_END_DATE +) P + +-- End Primary Events +) pe + +) QE + +; + +--- Inclusion Rule Inserts + +CREATE TABLE #inclusion_events (inclusion_rule_id bigint, + person_id bigint, + event_id bigint +); + +select event_id, person_id, start_date, end_date, op_start_date, op_end_date +into #included_events +FROM ( + SELECT event_id, person_id, start_date, end_date, op_start_date, op_end_date, row_number() over (partition by person_id order by start_date ASC) as ordinal + from + ( + select 
Q.event_id, Q.person_id, Q.start_date, Q.end_date, Q.op_start_date, Q.op_end_date, SUM(coalesce(POWER(cast(2 as bigint), I.inclusion_rule_id), 0)) as inclusion_rule_mask + from #qualified_events Q + LEFT JOIN #inclusion_events I on I.person_id = Q.person_id and I.event_id = Q.event_id + GROUP BY Q.event_id, Q.person_id, Q.start_date, Q.end_date, Q.op_start_date, Q.op_end_date + ) MG -- matching groups + +) Results + +; + + + +-- generate cohort periods into #final_cohort +select person_id, start_date, end_date +INTO #cohort_rows +from ( -- first_ends + select F.person_id, F.start_date, F.end_date + FROM ( + select I.event_id, I.person_id, I.start_date, CE.end_date, row_number() over (partition by I.person_id, I.event_id order by CE.end_date) as ordinal + from #included_events I + join ( -- cohort_ends +-- cohort exit dates +-- By default, cohort exit at the event's op end date +select event_id, person_id, op_end_date as end_date from #included_events + ) CE on I.event_id = CE.event_id and I.person_id = CE.person_id and CE.end_date >= I.start_date + ) F + WHERE F.ordinal = 1 +) FE; + + +select person_id, min(start_date) as start_date, DATEADD(day,-1 * 0, max(end_date)) as end_date +into #final_cohort +from ( + select person_id, start_date, end_date, sum(is_start) over (partition by person_id order by start_date, is_start desc rows unbounded preceding) group_idx + from ( + select person_id, start_date, end_date, + case when max(end_date) over (partition by person_id order by start_date rows between unbounded preceding and 1 preceding) >= start_date then 0 else 1 end is_start + from ( + select person_id, start_date, DATEADD(day,0,end_date) as end_date + from #cohort_rows + ) CR + ) ST +) GR +group by person_id, group_idx; + +DELETE FROM @target_database_schema.@target_cohort_table where cohort_definition_id = 1; +INSERT INTO @target_database_schema.@target_cohort_table (cohort_definition_id, subject_id, cohort_start_date, cohort_end_date) +select 1 as 
cohort_definition_id, person_id, start_date, end_date +FROM #final_cohort CO +; + +{1 != 0}?{ +-- BEGIN: Censored Stats + +delete from @results_database_schema.cohort_censor_stats where cohort_definition_id = 1; + +-- END: Censored Stats +} +{1 != 0 & 0 != 0}?{ + +CREATE TABLE #inclusion_rules (rule_sequence int); + +-- Find the event that is the 'best match' per person. +-- the 'best match' is defined as the event that satisfies the most inclusion rules. +-- ties are solved by choosing the event that matches the earliest inclusion rule, and then earliest. + +select q.person_id, q.event_id +into #best_events +from #qualified_events Q +join ( + SELECT R.person_id, R.event_id, ROW_NUMBER() OVER (PARTITION BY R.person_id ORDER BY R.rule_count DESC,R.min_rule_id ASC, R.start_date ASC) AS rank_value + FROM ( + SELECT Q.person_id, Q.event_id, COALESCE(COUNT(DISTINCT I.inclusion_rule_id), 0) AS rule_count, COALESCE(MIN(I.inclusion_rule_id), 0) AS min_rule_id, Q.start_date + FROM #qualified_events Q + LEFT JOIN #inclusion_events I ON q.person_id = i.person_id AND q.event_id = i.event_id + GROUP BY Q.person_id, Q.event_id, Q.start_date + ) R +) ranked on Q.person_id = ranked.person_id and Q.event_id = ranked.event_id +WHERE ranked.rank_value = 1 +; + +-- modes of generation: (the same tables store the results for the different modes, identified by the mode_id column) +-- 0: all events +-- 1: best event + + +-- BEGIN: Inclusion Impact Analysis - event +-- calculte matching group counts +delete from @results_database_schema.cohort_inclusion_result where cohort_definition_id = 1 and mode_id = 0; +insert into @results_database_schema.cohort_inclusion_result (cohort_definition_id, inclusion_rule_mask, person_count, mode_id) +select 1 as cohort_definition_id, inclusion_rule_mask, count_big(*) as person_count, 0 as mode_id +from +( + select Q.person_id, Q.event_id, CAST(SUM(coalesce(POWER(cast(2 as bigint), I.inclusion_rule_id), 0)) AS bigint) as inclusion_rule_mask + from 
#qualified_events Q + LEFT JOIN #inclusion_events I on q.person_id = i.person_id and q.event_id = i.event_id + GROUP BY Q.person_id, Q.event_id +) MG -- matching groups +group by inclusion_rule_mask +; + +-- calculate gain counts +delete from @results_database_schema.cohort_inclusion_stats where cohort_definition_id = 1 and mode_id = 0; +insert into @results_database_schema.cohort_inclusion_stats (cohort_definition_id, rule_sequence, person_count, gain_count, person_total, mode_id) +select 1 as cohort_definition_id, ir.rule_sequence, coalesce(T.person_count, 0) as person_count, coalesce(SR.person_count, 0) gain_count, EventTotal.total, 0 as mode_id +from #inclusion_rules ir +left join +( + select i.inclusion_rule_id, count_big(i.event_id) as person_count + from #qualified_events Q + JOIN #inclusion_events i on Q.person_id = I.person_id and Q.event_id = i.event_id + group by i.inclusion_rule_id +) T on ir.rule_sequence = T.inclusion_rule_id +CROSS JOIN (select count(*) as total_rules from #inclusion_rules) RuleTotal +CROSS JOIN (select count_big(event_id) as total from #qualified_events) EventTotal +LEFT JOIN @results_database_schema.cohort_inclusion_result SR on SR.mode_id = 0 AND SR.cohort_definition_id = 1 AND (POWER(cast(2 as bigint),RuleTotal.total_rules) - POWER(cast(2 as bigint),ir.rule_sequence) - 1) = SR.inclusion_rule_mask -- POWER(2,rule count) - POWER(2,rule sequence) - 1 is the mask for 'all except this rule' +; + +-- calculate totals +delete from @results_database_schema.cohort_summary_stats where cohort_definition_id = 1 and mode_id = 0; +insert into @results_database_schema.cohort_summary_stats (cohort_definition_id, base_count, final_count, mode_id) +select 1 as cohort_definition_id, PC.total as person_count, coalesce(FC.total, 0) as final_count, 0 as mode_id +FROM +(select count_big(event_id) as total from #qualified_events) PC, +(select sum(sr.person_count) as total + from @results_database_schema.cohort_inclusion_result sr + CROSS JOIN (select 
count(*) as total_rules from #inclusion_rules) RuleTotal + where sr.mode_id = 0 and sr.cohort_definition_id = 1 and sr.inclusion_rule_mask = POWER(cast(2 as bigint),RuleTotal.total_rules)-1 +) FC +; + +-- END: Inclusion Impact Analysis - event + +-- BEGIN: Inclusion Impact Analysis - person +-- calculte matching group counts +delete from @results_database_schema.cohort_inclusion_result where cohort_definition_id = 1 and mode_id = 1; +insert into @results_database_schema.cohort_inclusion_result (cohort_definition_id, inclusion_rule_mask, person_count, mode_id) +select 1 as cohort_definition_id, inclusion_rule_mask, count_big(*) as person_count, 1 as mode_id +from +( + select Q.person_id, Q.event_id, CAST(SUM(coalesce(POWER(cast(2 as bigint), I.inclusion_rule_id), 0)) AS bigint) as inclusion_rule_mask + from #best_events Q + LEFT JOIN #inclusion_events I on q.person_id = i.person_id and q.event_id = i.event_id + GROUP BY Q.person_id, Q.event_id +) MG -- matching groups +group by inclusion_rule_mask +; + +-- calculate gain counts +delete from @results_database_schema.cohort_inclusion_stats where cohort_definition_id = 1 and mode_id = 1; +insert into @results_database_schema.cohort_inclusion_stats (cohort_definition_id, rule_sequence, person_count, gain_count, person_total, mode_id) +select 1 as cohort_definition_id, ir.rule_sequence, coalesce(T.person_count, 0) as person_count, coalesce(SR.person_count, 0) gain_count, EventTotal.total, 1 as mode_id +from #inclusion_rules ir +left join +( + select i.inclusion_rule_id, count_big(i.event_id) as person_count + from #best_events Q + JOIN #inclusion_events i on Q.person_id = I.person_id and Q.event_id = i.event_id + group by i.inclusion_rule_id +) T on ir.rule_sequence = T.inclusion_rule_id +CROSS JOIN (select count(*) as total_rules from #inclusion_rules) RuleTotal +CROSS JOIN (select count_big(event_id) as total from #best_events) EventTotal +LEFT JOIN @results_database_schema.cohort_inclusion_result SR on SR.mode_id = 1 
AND SR.cohort_definition_id = 1 AND (POWER(cast(2 as bigint),RuleTotal.total_rules) - POWER(cast(2 as bigint),ir.rule_sequence) - 1) = SR.inclusion_rule_mask -- POWER(2,rule count) - POWER(2,rule sequence) - 1 is the mask for 'all except this rule' +; + +-- calculate totals +delete from @results_database_schema.cohort_summary_stats where cohort_definition_id = 1 and mode_id = 1; +insert into @results_database_schema.cohort_summary_stats (cohort_definition_id, base_count, final_count, mode_id) +select 1 as cohort_definition_id, PC.total as person_count, coalesce(FC.total, 0) as final_count, 1 as mode_id +FROM +(select count_big(event_id) as total from #best_events) PC, +(select sum(sr.person_count) as total + from @results_database_schema.cohort_inclusion_result sr + CROSS JOIN (select count(*) as total_rules from #inclusion_rules) RuleTotal + where sr.mode_id = 1 and sr.cohort_definition_id = 1 and sr.inclusion_rule_mask = POWER(cast(2 as bigint),RuleTotal.total_rules)-1 +) FC +; + +-- END: Inclusion Impact Analysis - person + + +TRUNCATE TABLE #best_events; +DROP TABLE #best_events; + +TRUNCATE TABLE #inclusion_rules; +DROP TABLE #inclusion_rules; +} + + + + +TRUNCATE TABLE #cohort_rows; +DROP TABLE #cohort_rows; + +TRUNCATE TABLE #final_cohort; +DROP TABLE #final_cohort; + +TRUNCATE TABLE #inclusion_events; +DROP TABLE #inclusion_events; + +TRUNCATE TABLE #qualified_events; +DROP TABLE #qualified_events; + +TRUNCATE TABLE #included_events; +DROP TABLE #included_events; + +TRUNCATE TABLE #Codesets; +DROP TABLE #Codesets; + \ No newline at end of file diff --git a/examples/basic_cohort_fluent.py b/examples/basic_cohort_fluent.py new file mode 100644 index 0000000..3bdb476 --- /dev/null +++ b/examples/basic_cohort_fluent.py @@ -0,0 +1,78 @@ +""" +Basic Cohort Definition Example - Fluent Builder API + +This example demonstrates how to create a simple cohort definition +for patients with Type 2 Diabetes using the fluent cohort builder API. 
+""" + +from circe.cohort_builder import CohortBuilder +from circe.vocabulary import concept_set, descendants +from circe.api import build_cohort_query +from circe.cohortdefinition.cohort_expression_query_builder import BuildExpressionQueryOptions + + +def create_diabetes_cohort(): + """Create a simple Type 2 Diabetes cohort definition using the fluent API.""" + + # Define the Type 2 Diabetes concept set using vocabulary helpers + diabetes_concept_set = concept_set( + descendants(201826), # Type 2 diabetes mellitus + id=1, + name="Type 2 Diabetes Mellitus" + ) + + # Build cohort using fluent API + cohort = ( + CohortBuilder("Patients with Type 2 Diabetes") + .with_concept_sets(diabetes_concept_set) + .with_condition(concept_set_id=1) # Entry on T2DM diagnosis + .first_occurrence() # First diagnosis only + .build() + ) + + return cohort + + +def generate_sql_from_cohort(cohort): + """Generate SQL from the cohort definition.""" + + options = BuildExpressionQueryOptions() + options.cohort_id = 1 + options.generate_stats = True + + sql = build_cohort_query(cohort, options) + + return sql + + +if __name__ == "__main__": + # Create the cohort definition + print("Creating Type 2 Diabetes cohort definition (Fluent API)...") + cohort = create_diabetes_cohort() + + # Display cohort information + print(f"\nCohort Title: {cohort.title}") + print(f"Number of Concept Sets: {len(cohort.concept_sets) if cohort.concept_sets else 0}") + if cohort.concept_sets: + print(f"Concept Set: {cohort.concept_sets[0].name}") + + # Generate SQL + print("\nGenerating SQL...") + sql = generate_sql_from_cohort(cohort) + + # Display first 500 characters of SQL + print(f"\nGenerated SQL (first 500 chars):") + print(sql[:500]) + print("...") + + # Save outputs + output_file = "diabetes_cohort_fluent.sql" + with open(output_file, "w") as f: + f.write(sql) + print(f"\nFull SQL saved to: {output_file}") + + json_output = cohort.model_dump_json(indent=2, by_alias=True, exclude_none=True) + json_file = 
"diabetes_cohort_fluent.json" + with open(json_file, "w") as f: + f.write(json_output) + print(f"Cohort definition saved to: {json_file}") diff --git a/examples/complex_cohort_fluent.py b/examples/complex_cohort_fluent.py new file mode 100644 index 0000000..0094f34 --- /dev/null +++ b/examples/complex_cohort_fluent.py @@ -0,0 +1,128 @@ +""" +Complex Cohort Definition Example - Fluent Builder API + +This example demonstrates advanced features using the fluent builder: +- Entry event with first occurrence +- Age restrictions +- Observation window requirements +- Inclusion criteria with time windows +- Exclusion criteria +""" + +from circe.cohort_builder import CohortBuilder +from circe.vocabulary import concept_set, descendants +from circe.api import build_cohort_query +from circe.cohortdefinition.cohort_expression_query_builder import BuildExpressionQueryOptions + + +def create_complex_cohort(): + """ + Create a cohort for patients with: + 1. Type 2 Diabetes diagnosis (entry event) + 2. Age 18+ at index + 3. 365 days prior observation + 4. Metformin prescription within 30 days after diagnosis + 5. No insulin exposure in the 180 days before diagnosis + 6. 
HbA1c measurement within 180 days after diagnosis + """ + + # Define concept sets + t2dm = concept_set( + descendants(201826), # Type 2 diabetes mellitus + id=1, + name="Type 2 Diabetes Mellitus" + ) + + metformin = concept_set( + descendants(1503297), # Metformin + id=2, + name="Metformin" + ) + + insulin = concept_set( + descendants(1511348), # Insulin + id=3, + name="Insulin" + ) + + hba1c = concept_set( + descendants(3004410), # HbA1c measurement + id=4, + name="Hemoglobin A1c" + ) + + # Build cohort using fluent API + cohort = ( + CohortBuilder("T2DM Patients on Metformin (Complex)") + .with_concept_sets(t2dm, metformin, insulin, hba1c) + + # Entry: First T2DM diagnosis + .with_condition(concept_set_id=1) + .first_occurrence() + .with_observation(prior_days=365) + .min_age(18) + + # Inclusion: Must have metformin within 30 days after + .require_drug(concept_set_id=2) + .within_days_after(30) + + # Inclusion: Must have HbA1c within 180 days after + .require_measurement(concept_set_id=4) + .within_days_after(180) + + # Exclusion: No prior insulin + .exclude_drug(concept_set_id=3) + .within_days_before(180) + + .build() + ) + + return cohort + + +def generate_sql_from_cohort(cohort): + """Generate SQL from the cohort definition.""" + + options = BuildExpressionQueryOptions() + options.cohort_id = 1 + options.generate_stats = True + + sql = build_cohort_query(cohort, options) + return sql + + +if __name__ == "__main__": + print("Creating complex T2DM cohort definition (Fluent API)...") + cohort = create_complex_cohort() + + # Display cohort information + print(f"\nCohort Title: {cohort.title}") + print(f"Number of Concept Sets: {len(cohort.concept_sets) if cohort.concept_sets else 0}") + print(f"Has Primary Criteria: {cohort.primary_criteria is not None}") + print(f"Number of Inclusion Rules: {len(cohort.inclusion_rules) if cohort.inclusion_rules else 0}") + + if cohort.concept_sets: + print("\nConcept Sets:") + for cs in cohort.concept_sets: + print(f" - 
{cs.name} (ID: {cs.id})") + + # Generate SQL + print("\nGenerating SQL...") + sql = generate_sql_from_cohort(cohort) + + # Display first 1000 characters of SQL + print(f"\nGenerated SQL (first 1000 chars):") + print(sql[:1000]) + print("...") + + # Save outputs + output_file = "complex_cohort_fluent.sql" + with open(output_file, "w") as f: + f.write(sql) + print(f"\nFull SQL saved to: {output_file}") + + json_output = cohort.model_dump_json(indent=2, by_alias=True, exclude_none=True) + json_file = "complex_cohort_fluent.json" + with open(json_file, "w") as f: + f.write(json_output) + print(f"Cohort definition saved to: {json_file}") diff --git a/examples/secondary_primary_malignancy.py b/examples/secondary_primary_malignancy.py new file mode 100644 index 0000000..700cef7 --- /dev/null +++ b/examples/secondary_primary_malignancy.py @@ -0,0 +1,89 @@ +""" +Example: Secondary Primary Malignancy Cohort Definition + +This example demonstrates a complex cohort definition for identifying +patients with a secondary primary cancer (SPM) using the context manager API. + +Clinical Description: + - Target Population: Adults (18-85) with a history of solid tumor malignancy. + - Entry Event: First diagnosis of a solid tumor malignancy. + - Requirement 1: A second distinct primary cancer at a different anatomic site. + - Requirement 2: At least a 1-year time gap between the first and second cancer. + - Requirement 3: Evidence of diagnostic confirmation for the second cancer + (e.g., biopsy, imaging, or specific tumor markers). + - Requirement 4: Active treatment for the second cancer (chemotherapy, + radiation, or surgery). + - Exclusion: Patients with evidence of metastatic disease before or near + the second cancer diagnosis are excluded to ensure the second cancer + is indeed a primary malignancy and not a recurrence or metastasis. 
+""" + +import json +from circe.cohort_builder import CohortBuilder +from circe.api import build_cohort_query, cohort_print_friendly + + +def create_secondary_primary_malignancy_cohort(): + """Build the cohort using the context manager API.""" + + with CohortBuilder("Secondary Primary Malignancy") as cohort: + # --- Entry event: First solid tumor malignancy --- + cohort.with_condition(1) # Concept set 1: Solid tumor malignancies + cohort.first_occurrence() + cohort.with_observation_window(prior_days=365) + cohort.require_age(18, 85) + + # --- Inclusion Rule 1: Second distinct cancer at different site --- + with cohort.include_rule("Second Primary Cancer at Different Site") as rule: + # Must have a record for a different anatomic site after index + rule.require_condition(2, anytime_after=True) + + # --- Inclusion Rule 2: Minimum time gap (at least 1 year) --- + with cohort.include_rule("Time Gap Requirement") as rule: + rule.require_condition(2, within_days_after=365) + + # --- Inclusion Rule 3: Exclude metastatic disease --- + with cohort.include_rule("No Metastatic Disease") as rule: + # Exclude if any metastatic disease present + rule.exclude_condition(100, anytime_before=True) + rule.exclude_condition(101, anytime_before=True) + rule.exclude_condition(102, anytime_before=True) + rule.exclude_condition(103, anytime_before=True) + + # --- Inclusion Rule 4: Diagnostic Confirmation --- + with cohort.include_rule("Diagnostic Confirmation") as rule: + # Biopsy procedure + rule.require_procedure(200, within_days_after=30) + + # --- Inclusion Rule 5: Active Cancer Treatment --- + with cohort.include_rule("Active Cancer Treatment") as rule: + # Chemotherapy + rule.require_drug(400, anytime_after=True) + + return cohort.expression + + +if __name__ == "__main__": + # Create the cohort + cohort = create_secondary_primary_malignancy_cohort() + + # Export to JSON + cohort_json = cohort.model_dump_json(by_alias=True, exclude_none=True, indent=2) + print("--- Cohort JSON 
(Snippet) ---") + print(cohort_json[:500] + "...") + + # Generate human-readable Markdown + markdown = cohort_print_friendly(cohort) + print("\n--- Human-Readable Description ---") + print(markdown[:500] + "...") + + # Generate SQL + from circe.cohortdefinition import BuildExpressionQueryOptions + options = BuildExpressionQueryOptions() + options.cdm_schema = "my_cdm" + options.target_table = "results.cohort" + options.cohort_id = 123 + + sql = build_cohort_query(cohort, options) + print("\n--- SQL Query (Snippet) ---") + print(sql[:500] + "...") diff --git a/pyproject.toml b/pyproject.toml index a566850..44be533 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ include = ["circe*"] exclude = ["circe.tests*"] [tool.setuptools.package-data] -circe = ["py.typed"] +circe = ["py.typed", "skills/*.md"] [tool.black] line-length = 88 diff --git a/scripts/generate_skill_backup.py b/scripts/generate_skill_backup.py deleted file mode 100644 index 880d7f4..0000000 --- a/scripts/generate_skill_backup.py +++ /dev/null @@ -1,357 +0,0 @@ -""" -Generate SKILL.md for the cohort_builder from the actual codebase. - -This script introspects the cohort builder classes to extract: -- Available methods and their signatures -- Chaining behavior (return types) -- Docstrings and usage examples -- Valid parameter combinations - -The generated SKILL.md will be a single source of truth that cannot -drift from the actual API implementation. 
-""" - -import inspect -from typing import get_type_hints, List, Dict, Any, Set -from dataclasses import dataclass -import sys -from pathlib import Path - -# Add project root to path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from circe.cohort_builder.builder import CohortBuilder, CohortWithEntry, CohortWithCriteria -from circe.cohort_builder.query_builder import ( - BaseQuery, ConditionQuery, DrugQuery, DrugEraQuery, MeasurementQuery, - ProcedureQuery, VisitQuery, ObservationQuery, DeathQuery, - ConditionEraQuery, DeviceExposureQuery, SpecimenQuery, - ObservationPeriodQuery, PayerPlanPeriodQuery, LocationRegionQuery, - VisitDetailQuery, DoseEraQuery, CriteriaGroupBuilder -) - - -@dataclass -class MethodInfo: - """Information about a method.""" - name: str - signature: str - return_type: str - docstring: str - parameters: List[Dict[str, Any]] - is_chainable: bool - finalizes: bool # Returns parent builder (breaks chain) - - -class SkillGenerator: - """Generates SKILL.md from the cohort builder codebase.""" - - def __init__(self): - self.builder_methods: List[MethodInfo] = [] - self.entry_methods: List[MethodInfo] = [] - self.criteria_methods: List[MethodInfo] = [] - self.query_modifiers: Dict[str, List[MethodInfo]] = {} - self.time_windows: List[MethodInfo] = [] - - def extract_method_info(self, cls, method_name: str) -> MethodInfo: - """Extract information about a method.""" - method = getattr(cls, method_name) - sig = inspect.signature(method) - - # Get return type - return_annotation = sig.return_annotation - if return_annotation == inspect.Signature.empty: - return_type = "Unknown" - else: - return_type = str(return_annotation).replace("'", "") - - # Build parameter list - params = [] - for param_name, param in sig.parameters.items(): - if param_name == 'self': - continue - param_info = { - 'name': param_name, - 'type': str(param.annotation) if param.annotation != inspect.Parameter.empty else 'Any', - 'default': param.default if param.default != 
inspect.Parameter.empty else None, - 'required': param.default == inspect.Parameter.empty - } - params.append(param_info) - - # Build signature string - param_strs = [] - for p in params: - if p['default'] is not None: - param_strs.append(f"{p['name']}={p['default']}") - else: - param_strs.append(p['name']) - signature = f"{method_name}({', '.join(param_strs)})" - - # Get docstring - docstring = inspect.getdoc(method) or "" - - # Determine if method finalizes (returns parent) or chains (returns self) - finalizes = 'CohortWithCriteria' in return_type or 'CohortWithEntry' in return_type - is_chainable = return_type != 'None' and not finalizes - - return MethodInfo( - name=method_name, - signature=signature, - return_type=return_type, - docstring=docstring, - parameters=params, - is_chainable=is_chainable, - finalizes=finalizes - ) - - def discover_methods(self): - """Discover all public methods from the builder classes.""" - - # CohortBuilder entry methods - for name, method in inspect.getmembers(CohortBuilder, predicate=inspect.isfunction): - if name.startswith('_') or name == 'with_concept_sets': - continue - if name.startswith('with_'): - self.builder_methods.append(self.extract_method_info(CohortBuilder, name)) - - # CohortWithEntry methods - for name, method in inspect.getmembers(CohortWithEntry, predicate=inspect.isfunction): - if name.startswith('_'): - continue - if name in ['first_occurrence', 'with_observation', 'min_age', 'max_age', - 'require_age', 'require_gender', 'require_race', 'require_ethnicity', - 'begin_rule', 'any_of', 'all_of', 'at_least_of']: - self.entry_methods.append(self.extract_method_info(CohortWithEntry, name)) - - # CohortWithCriteria methods - for name, method in inspect.getmembers(CohortWithCriteria, predicate=inspect.isfunction): - if name.startswith('_'): - continue - if name.startswith('require_') or name.startswith('exclude_') or \ - name in ['any_of', 'all_of', 'at_least_of', 'begin_rule', 'build', - 'require_any_of', 
'require_all_of', 'require_at_least_of', 'exclude_any_of']: - self.criteria_methods.append(self.extract_method_info(CohortWithCriteria, name)) - - # BaseQuery time windows - for name, method in inspect.getmembers(BaseQuery, predicate=inspect.isfunction): - if name in ['within_days_before', 'within_days_after', 'within_days', - 'anytime_before', 'anytime_after', 'same_day', 'restrict_to_visit', - 'during_event', 'before_event_end']: - self.time_windows.append(self.extract_method_info(BaseQuery, name)) - - # Domain-specific modifiers - modifier_map = { - 'BaseQuery': ['at_least', 'at_most', 'exactly', 'with_distinct', 'ignore_observation_period'], - 'ProcedureQuery': ['with_quantity', 'with_modifier'], - 'MeasurementQuery': ['with_operator', 'with_value', 'with_unit', 'is_abnormal', - 'with_range_low_ratio', 'with_range_high_ratio'], - 'DrugQuery': ['with_route', 'with_dose', 'with_refills', 'with_days_supply', 'with_quantity'], - 'VisitQuery': ['with_length', 'with_place_of_service'], - 'ObservationQuery': ['with_qualifier', 'with_value_as_string'] - } - - for cls_name, methods in modifier_map.items(): - cls = globals().get(cls_name) - if cls: - self.query_modifiers[cls_name] = [] - for method_name in methods: - if hasattr(cls, method_name): - self.query_modifiers[cls_name].append( - self.extract_method_info(cls, method_name) - ) - - def generate_markdown(self) -> str: - """Generate the SKILL.md content.""" - md = [] - - # Header - md.append("---") - md.append("description: Build OHDSI cohort definitions using the fluent Python API") - md.append("---") - md.append("") - md.append("# Cohort Builder Skill") - md.append("") - md.append("Build OHDSI cohort definitions step-by-step using the fluent `cohort_builder` API.") - md.append("") - md.append("**⚠️ AUTO-GENERATED**: This file is generated from the codebase. 
Do not edit manually.") - md.append("") - - # Entry Events - md.append("## Entry Event Methods") - md.append("") - md.append("Start building a cohort with one of these methods on `CohortBuilder`:") - md.append("") - md.append("```python") - for method in sorted(self.builder_methods, key=lambda m: m.name): - md.append(f"CohortBuilder(\"Title\").{method.signature}") - md.append("```") - md.append("") - - # Entry Configuration - md.append("## Entry Configuration Methods") - md.append("") - md.append("After defining the entry event, configure it with:") - md.append("") - for method in sorted(self.entry_methods, key=lambda m: m.name): - if method.name in ['first_occurrence', 'with_observation', 'min_age', 'max_age']: - md.append(f"### `.{method.signature}`") - if method.docstring: - md.append(f"{method.docstring}") - md.append("") - - # Demographics - md.append("## Demographic Criteria") - md.append("") - md.append("Add demographic requirements:") - md.append("") - for method in sorted(self.entry_methods, key=lambda m: m.name): - if method.name.startswith('require_'): - md.append(f"- `.{method.signature}`: {method.docstring.split('.')[0] if method.docstring else ''}") - md.append("") - - # CRITICAL CHAINING RULE - md.append("## ⚠️ CRITICAL CHAINING RULE") - md.append("") - md.append("**Modifiers MUST be called BEFORE time windows!**") - md.append("") - md.append("Time window methods finalize the criteria and return to the parent builder.") - md.append("Once a time window is called, you cannot chain further modifiers.") - md.append("") - md.append("✅ **CORRECT**:") - md.append("```python") - md.append(".require_drug(10).at_least(2).within_days_before(30)") - md.append("```") - md.append("") - md.append("❌ **INCORRECT**:") - md.append("```python") - md.append(".require_drug(10).within_days_before(30).at_least(2) # ERROR!") - md.append("```") - md.append("") - - # Time Windows - md.append("## Time Window Methods (Call LAST)") - md.append("") - md.append("These methods 
finalize the criteria:") - md.append("") - for method in sorted(self.time_windows, key=lambda m: m.name): - md.append(f"- `.{method.signature}`: {method.docstring.split('.')[0] if method.docstring else ''}") - md.append("") - - # Modifiers - md.append("## Modifier Methods (Call BEFORE time windows)") - md.append("") - for cls_name, methods in self.query_modifiers.items(): - if methods: - md.append(f"### {cls_name}") - md.append("") - for method in sorted(methods, key=lambda m: m.name): - md.append(f"- `.{method.signature}`") - md.append("") - - # Inclusion Criteria - md.append("## Inclusion Criteria Methods") - md.append("") - md.append("Build complex criteria with:") - md.append("") - for method in sorted(self.criteria_methods, key=lambda m: m.name): - if method.name in ['require_any_of', 'require_all_of', 'require_at_least_of', 'exclude_any_of']: - md.append(f"### `.{method.signature}`") - if method.docstring: - md.append(f"{method.docstring[:200]}...") - md.append("") - - return "\n".join(md) - - def run(self, output_path: str): - """Run the skill generator.""" - print("🔍 Discovering methods...") - self.discover_methods() - - print(f"✅ Found {len(self.builder_methods)} entry methods") - print(f"✅ Found {len(self.entry_methods)} configuration methods") - print(f"✅ Found {len(self.criteria_methods)} criteria methods") - print(f"✅ Found {len(self.time_windows)} time window methods") - - print("\n📝 Generating SKILL.md...") - content = self.generate_markdown() - - with open(output_path, 'w') as f: - f.write(content) - - print(f"✅ Generated {output_path}") - print(f"📊 Total lines: {len(content.splitlines())}") - return content - - def update_system_prompt(self, skill_content: str, prompt_path: str): - """Update a system prompt with the generated skill.""" - print(f"📝 Updating {prompt_path}...") - - try: - # Read existing prompt - with open(prompt_path, 'r') as f: - prompt_content = f.read() - except FileNotFoundError: - print(f"⚠️ Prompt file not found: 
{prompt_path}") - return - - # Find the SKILL section markers - start_marker = "[BEGIN SKILL.MD CONTENT]" - end_marker = "[END SKILL.MD CONTENT]" - - start_idx = prompt_content.find(start_marker) - end_idx = prompt_content.find(end_marker) - - if start_idx == -1 or end_idx == -1: - print(f"⚠️ Could not find SKILL section markers in {prompt_path}") - return - - # Replace the content between markers - # Skip the frontmatter from skill content - skill_lines = skill_content.splitlines() - skill_body = [] - in_frontmatter = False - for line in skill_lines: - if line.strip() == "---": - if not in_frontmatter: - in_frontmatter = True - else: - in_frontmatter = False - continue - if not in_frontmatter: - skill_body.append(line) - - new_skill_section = "\n".join(skill_body).strip() - - new_prompt = ( - prompt_content[:start_idx + len(start_marker)] + - "\n\n" + new_skill_section + "\n\n" + - prompt_content[end_idx:] - ) - - # Write updated prompt - with open(prompt_path, 'w') as f: - f.write(new_prompt) - - print(f"✅ Updated {prompt_path}") - - - -if __name__ == "__main__": - generator = SkillGenerator() - - # Generate SKILL.md - skill_output = ".agent/skills/cohort_builder/SKILL.md" - skill_content = generator.run(skill_output) - - # Update all system prompt variants - prompts = [ - ("prompts/reasoning_models_prompt.md", "Reasoning Models"), - ("prompts/standard_models_prompt.md", "Standard Models"), - ("prompts/fast_models_prompt.md", "Fast Models"), - ] - - for prompt_path, model_type in prompts: - generator.update_system_prompt(skill_content, prompt_path) - - print("\n✅ All documentation updated!") - print(f" - SKILL.md") - print(f" - {len(prompts)} model-specific prompts") diff --git a/tests/test_advanced_features.py b/tests/test_advanced_features.py new file mode 100644 index 0000000..61de016 --- /dev/null +++ b/tests/test_advanced_features.py @@ -0,0 +1,193 @@ +""" +Tests for Phase 4 (Time Window Enhancements) and Phase 5 (Advanced Features). 
+""" + +import pytest +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition.cohort import CohortExpression + + +def test_between_visits(): + """Test between_visits() restricts to same visit.""" + cohort = ( + CohortBuilder("Test Between Visits") + .with_condition(1) + .begin_rule("Same Visit Procedure") + .require_procedure(10).between_visits() + .build() + ) + + assert isinstance(cohort, CohortExpression) + # Check restrict_visit flag was set + # This maps to VisitFilter in CIRCE which requires same visit_occurrence_id + + +def test_during_event(): + """Test during_event() requires event within index duration.""" + cohort = ( + CohortBuilder("Test During Event") + .with_drug(1) + .begin_rule("Measurement During Drug") + .require_measurement(10).during_event() + .build() + ) + + # Should set time window with use_index_end=True and 0/0 days + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify time window configuration + + +def test_before_event_end(): + """Test before_event_end() relative to index end date.""" + cohort = ( + CohortBuilder("Test Before Event End") + .with_condition(1) + .begin_rule("Drug Before Condition Ends") + .require_drug(10).before_event_end(days=7) + .build() + ) + + # Should set time window relative to index end, 7 days before + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify time window with use_index_end=True + + +def test_anytime_before_unlimited(): + """Test anytime_before() with no limit.""" + cohort = ( + CohortBuilder("Test Anytime Before Unlimited") + .with_drug(1) + .begin_rule("Prior Condition") + .require_condition(10).anytime_before() + .build() + ) + + # Should set large lookback window + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify large lookback window + + +def test_within_days_before_with_years(): + """Test within_days_before() with 5 year limit.""" + cohort = ( + CohortBuilder("Test Within 5 Years Before") + 
.with_condition(1) + .begin_rule("Prior Procedure") + .require_procedure(10).within_days_before(1825) # 5 * 365 + .build() + ) + + # Should set time window to 1825 days before + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify 1825 day lookback + + +def test_anytime_after_unlimited(): + """Test anytime_after() with no limit.""" + cohort = ( + CohortBuilder("Test Anytime After Unlimited") + .with_condition(1) + .begin_rule("Future Death") + .require_death().anytime_after() + .build() + ) + + # Should set large lookahead window + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify large lookahead window + + +def test_within_days_after_with_years(): + """Test within_days_after() with 1 year limit.""" + cohort = ( + CohortBuilder("Test Within 1 Year After") + .with_drug(1) + .begin_rule("Death Within Year") + .require_death().within_days_after(365) # 1 * 365 + .build() + ) + + # Should set time window to 365 days after + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Verify 365 day lookahead + + +def test_with_distinct(): + """Test with_distinct() for distinct counting.""" + cohort = ( + CohortBuilder("Test Distinct Counting") + .with_condition(1) + .begin_rule("Distinct Measurements") + .require_measurement(10).with_distinct().at_least(3).anytime_before() + .build() + ) + + # Should set is_distinct flag + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.occurrence.is_distinct == True + + +def test_ignore_observation_period(): + """Test ignore_observation_period() flag.""" + cohort = ( + CohortBuilder("Test Ignore Obs Period") + .with_drug(1) + .begin_rule("Condition Outside Obs") + .require_condition(10).ignore_observation_period().anytime_before() + .build() + ) + + # Should set ignore_observation_period flag + # This maps to criteria-level flag in CIRCE + + +def test_chaining_advanced_features(): + """Test chaining multiple advanced features.""" + cohort = ( + 
CohortBuilder("Test Feature Chaining") + .with_condition(1) + .begin_rule("Complex Criteria") + .require_measurement(10)\ + .with_distinct()\ + .ignore_observation_period()\ + .at_least(2)\ + .within_days_before(3650) # 10 * 365 + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.occurrence.is_distinct == True + # Also verify ignore_observation_period and occurrence count + + +def test_time_window_combinations(): + """Test combining different time window methods.""" + cohort = ( + CohortBuilder("Test Time Window Mix") + .with_condition(1) + .begin_rule("Past 2 Years") + .require_drug(10).within_days_before(730) # 2 * 365 + .begin_rule("After Index") + .require_procedure(20).anytime_after() + .build() + ) + + assert len(cohort.inclusion_rules) == 2 + + +def test_between_visits_with_modifiers(): + """Test between_visits() combined with Phase 2 modifiers.""" + cohort = ( + CohortBuilder("Test Between Visits + Modifiers") + .with_condition(1) + .begin_rule("Same-Visit High-Dose Drug") + .require_drug(10)\ + .with_route(4132161)\ + .with_dose(min_dose=50.0)\ + .between_visits() + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.route_concept is not None + # Also verify restrict_visit flag diff --git a/tests/test_capr.py b/tests/test_capr.py new file mode 100644 index 0000000..e6a6025 --- /dev/null +++ b/tests/test_capr.py @@ -0,0 +1,350 @@ +""" +Unit tests for CohortComposer module. + +Tests the fluent API for building cohort definitions. 
+""" + +import pytest +from circe.capr import ( + cohort, entry, exit_strategy, era, attrition, + condition_occurrence, drug_exposure, drug_era, measurement, procedure, + observation, visit, visit_detail, device_exposure, specimen, death, + observation_period, payer_plan_period, location_region, + condition_era, dose_era, + at_least, at_most, exactly, with_all, with_any, + during_interval, event_starts, event_ends, continuous_observation, + sensitive_disease_cohort, specific_disease_cohort, + acute_disease_cohort, chronic_disease_cohort, new_user_drug_cohort +) +from circe.vocabulary import concept_set, descendants, mapped, exclude + + +class TestQueryFunctions: + """Test domain query functions.""" + + def test_condition_occurrence_basic(self): + query = condition_occurrence(concept_set_id=1) + assert query.domain == 'ConditionOccurrence' + assert query.concept_set_id == 1 + assert query.first_occurrence is False + + def test_condition_occurrence_first(self): + query = condition_occurrence(concept_set_id=1, first_occurrence=True) + assert query.first_occurrence is True + + def test_condition_occurrence_with_age(self): + query = condition_occurrence(concept_set_id=1, age=(18, 65)) + assert query.criteria_options['age'] == (18, 65) + + def test_drug_exposure_basic(self): + query = drug_exposure(concept_set_id=2) + assert query.domain == 'DrugExposure' + assert query.concept_set_id == 2 + + def test_drug_exposure_with_days_supply(self): + query = drug_exposure(concept_set_id=2, days_supply=(30, 90)) + assert query.criteria_options['days_supply'] == (30, 90) + + def test_drug_era(self): + query = drug_era(concept_set_id=2, first_occurrence=True) + assert query.domain == 'DrugEra' + assert query.first_occurrence is True + + def test_measurement(self): + query = measurement(concept_set_id=3, value_as_number=(6.5, 10.0)) + assert query.domain == 'Measurement' + assert query.criteria_options['value_as_number'] == (6.5, 10.0) + + def test_procedure(self): + query = 
procedure(concept_set_id=4) + assert query.domain == 'ProcedureOccurrence' + + def test_all_domains_create_valid_queries(self): + """Test that all domain query functions work.""" + queries = [ + condition_occurrence(concept_set_id=1), + condition_era(concept_set_id=1), + drug_exposure(concept_set_id=1), + drug_era(concept_set_id=1), + dose_era(concept_set_id=1), + procedure(concept_set_id=1), + measurement(concept_set_id=1), + observation(concept_set_id=1), + visit(concept_set_id=1), + visit_detail(concept_set_id=1), + device_exposure(concept_set_id=1), + specimen(concept_set_id=1), + death(), + observation_period(), + payer_plan_period(concept_set_id=1), + location_region(concept_set_id=1), + ] + assert len(queries) == 16 + + +class TestWindowFunctions: + """Test time window functions.""" + + def test_event_starts(self): + interval = event_starts(before=365, after=0) + assert interval.start == 365 + assert interval.end == 0 + assert interval.index == "startDate" + + def test_event_ends(self): + interval = event_ends(before=0, after=30) + assert interval.start == 0 + assert interval.end == 30 + assert interval.index == "endDate" + + def test_during_interval(self): + window = during_interval( + start_window=event_starts(before=365, after=0) + ) + assert window.start_window is not None + assert window.start_window.start == 365 + + def test_continuous_observation(self): + obs = continuous_observation(prior_days=365, post_days=0) + assert obs.prior_days == 365 + assert obs.post_days == 0 + + +class TestCriteriaFunctions: + """Test occurrence counting functions.""" + + def test_at_least(self): + criteria = at_least( + count=2, + query=drug_exposure(concept_set_id=1), + aperture=during_interval(event_starts(before=365, after=0)) + ) + assert criteria.occurrence_type == 'atLeast' + assert criteria.count == 2 + + def test_at_most(self): + criteria = at_most( + count=1, + query=condition_occurrence(concept_set_id=1), + aperture=during_interval(event_starts(before=365, 
after=0)) + ) + assert criteria.occurrence_type == 'atMost' + assert criteria.count == 1 + + def test_exactly(self): + criteria = exactly( + count=0, + query=drug_exposure(concept_set_id=1), + aperture=during_interval(event_starts(before=365, after=1)) + ) + assert criteria.occurrence_type == 'exactly' + assert criteria.count == 0 + + def test_with_all(self): + group = with_all( + at_least(1, condition_occurrence(1)), + at_least(1, drug_exposure(2)) + ) + assert group.group_type == 'ALL' + assert len(group.criteria_list) == 2 + + def test_with_any(self): + group = with_any( + at_least(1, condition_occurrence(1)), + at_least(1, condition_occurrence(2)) + ) + assert group.group_type == 'ANY' + assert len(group.criteria_list) == 2 + + +class TestAttrition: + """Test attrition functions.""" + + def test_attrition_single_rule(self): + rules = attrition( + no_prior_drug=with_all( + exactly(0, drug_exposure(1)) + ) + ) + assert len(rules.rules) == 1 + assert rules.rules[0].name == "no prior drug" + + def test_attrition_multiple_rules(self): + rules = attrition( + rule_one=with_all(at_least(1, condition_occurrence(1))), + rule_two=with_all(exactly(0, drug_exposure(2))) + ) + assert len(rules.rules) == 2 + + +class TestCohortConstruction: + """Test main cohort construction functions.""" + + def test_simple_entry_cohort(self): + c = cohort( + title="Simple Cohort", + entry=entry( + condition_occurrence(concept_set_id=1, first_occurrence=True), + observation_window=(365, 0) + ) + ) + assert c.title == "Simple Cohort" + assert c.entry_event is not None + assert c.entry_event.observation_window.prior_days == 365 + + def test_cohort_with_attrition(self): + c = cohort( + title="Cohort with Attrition", + entry=entry( + condition_occurrence(concept_set_id=1, first_occurrence=True), + observation_window=(365, 0) + ), + attrition=attrition( + no_prior_drug=with_all( + exactly(0, drug_exposure(2), during_interval(event_starts(before=365, after=1))) + ) + ) + ) + assert c.attrition is 
not None + assert len(c.attrition.rules) == 1 + + def test_build_produces_cohort_expression(self): + c = cohort( + title="Build Test", + entry=entry( + condition_occurrence(concept_set_id=1, first_occurrence=True), + observation_window=(365, 0) + ) + ) + expr = c.build() + assert expr.title == "Build Test" + assert expr.primary_criteria is not None + + def test_exit_strategy_observation(self): + ex = exit_strategy(end_strategy="observation") + assert ex.strategy_type == "observation" + + def test_exit_strategy_date_offset(self): + ex = exit_strategy(end_strategy="date_offset", offset_days=365) + assert ex.strategy_type == "date_offset" + assert ex.offset_days == 365 + + +class TestTemplates: + """Test generic cohort templates.""" + + def test_sensitive_disease_cohort(self): + expr = sensitive_disease_cohort( + concept_set_id=1, + title="Sensitive Test" + ) + assert expr.title == "Sensitive Test" + assert expr.primary_criteria is not None + + def test_specific_disease_cohort(self): + expr = specific_disease_cohort( + concept_set_id=1, + confirmation_days=30, + title="Specific Test" + ) + assert expr.title == "Specific Test" + assert expr.inclusion_rules is not None + + def test_acute_disease_cohort(self): + expr = acute_disease_cohort( + concept_set_id=1, + washout_days=180, + title="Acute Test" + ) + assert expr.title == "Acute Test" + assert expr.inclusion_rules is not None + + def test_chronic_disease_cohort(self): + expr = chronic_disease_cohort( + concept_set_id=1, + lookback_days=365, + title="Chronic Test" + ) + assert expr.title == "Chronic Test" + assert expr.inclusion_rules is not None + + def test_new_user_drug_cohort(self): + expr = new_user_drug_cohort( + drug_concept_set_id=2, + washout_days=365, + title="New User Test" + ) + assert expr.title == "New User Test" + assert expr.inclusion_rules is not None + + +class TestConceptSetHelpers: + """Test vocabulary helper functions.""" + + def test_descendants(self): + ref = descendants(201826) + assert 
ref.concept_id == 201826 + assert ref.include_descendants is True + + def test_mapped(self): + ref = mapped(201826) + assert ref.concept_id == 201826 + assert ref.include_mapped is True + + def test_exclude_int(self): + ref = exclude(201826) + assert ref.concept_id == 201826 + assert ref.is_excluded is True + + def test_exclude_reference(self): + ref = exclude(descendants(201826)) + assert ref.concept_id == 201826 + assert ref.include_descendants is True + assert ref.is_excluded is True + + def test_concept_set_simple(self): + cs = concept_set(201826, name="T2DM") + assert cs.name == "T2DM" + assert len(cs.expression.items) == 1 + + def test_concept_set_with_descendants(self): + cs = concept_set(descendants(201826), name="T2DM with descendants") + assert cs.expression.items[0].include_descendants is True + + def test_concept_set_multiple(self): + cs = concept_set( + descendants(201826), + descendants(201254), + name="Multiple Diabetes" + ) + assert len(cs.expression.items) == 2 + + +class TestIntegration: + """Integration tests for full cohort building.""" + + def test_full_cohort_with_all_components(self): + c = cohort( + title="Full Integration Test", + entry=entry( + drug_exposure(concept_set_id=2, first_occurrence=True), + observation_window=(365, 0), + primary_criteria_limit="First" + ), + attrition=attrition( + has_indication=with_all( + at_least(1, condition_occurrence(1), during_interval(event_starts(before=365, after=0))) + ), + no_prior_drug=with_all( + exactly(0, drug_exposure(2), during_interval(event_starts(before=365, after=1))) + ) + ), + exit=exit_strategy(end_strategy="observation"), + era=era(era_days=0) + ) + + expr = c.build() + assert expr.title == "Full Integration Test" + assert expr.primary_criteria is not None + assert expr.inclusion_rules is not None + assert len(expr.inclusion_rules) == 2 diff --git a/tests/test_cohort_builder.py b/tests/test_cohort_builder.py new file mode 100644 index 0000000..b62ec21 --- /dev/null +++ 
b/tests/test_cohort_builder.py @@ -0,0 +1,283 @@ +""" +Unit tests for the Simplified Parameter-Based Fluent Builder. +""" + +import pytest +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition import CohortExpression + + +# ============================================================================= +# CONTEXT MANAGER TESTS +# ============================================================================= + +class TestContextManager: + """Test the context manager API for CohortBuilder.""" + + def test_context_manager_basic(self): + """Test basic context manager usage.""" + with CohortBuilder("Context Test") as cohort: + cohort.with_condition(1) + + # After exiting, expression should be built + expr = cohort.expression + assert isinstance(expr, CohortExpression) + assert expr.title == "Context Test" + + def test_context_manager_auto_builds(self): + """Verify that auto-build happens on context exit.""" + with CohortBuilder("Auto Build Test") as cohort: + cohort.with_drug(2) + cohort.first_occurrence() + cohort.with_observation_window(prior_days=365) + + expr = cohort.expression + assert expr.primary_criteria.observation_window.prior_days == 365 + + def test_context_manager_with_criteria(self): + """Test context manager with inclusion/exclusion criteria.""" + with CohortBuilder("Criteria Test") as cohort: + cohort.with_condition(1) + cohort.require_drug(2, within_days_before=30) + cohort.exclude_condition(3, anytime_before=True) + + expr = cohort.expression + assert expr.inclusion_rules is not None + assert len(expr.inclusion_rules) >= 1 + + def test_context_manager_result_access_before_exit_raises(self): + """Accessing expression inside context should raise.""" + with pytest.raises(RuntimeError, match="inside the context manager"): + with CohortBuilder("Error Test") as cohort: + cohort.with_condition(1) + _ = cohort.expression # Should raise + + def test_context_manager_no_entry_event_raises(self): + """Accessing expression with no entry event 
should raise.""" + with CohortBuilder("Empty Test") as cohort: + pass # No entry event defined + + with pytest.raises(RuntimeError, match="No cohort has been built"): + _ = cohort.expression + + def test_context_manager_nested_rules(self): + """Test nested inclusion rule contexts.""" + with CohortBuilder("Nested Rules Test") as cohort: + cohort.with_condition(1) + + with cohort.include_rule("Prior Treatment") as rule: + rule.require_drug(2, anytime_before=True) + + with cohort.include_rule("Lab Confirmation") as rule: + rule.require_measurement(3, same_day=True) + + expr = cohort.expression + # Should have two named rules + rule_names = [r.name for r in expr.inclusion_rules] + assert "Prior Treatment" in rule_names + assert "Lab Confirmation" in rule_names + + def test_context_manager_exception_still_builds(self): + """Even if an exception occurs, the cohort should be built if possible.""" + cohort = CohortBuilder("Exception Test") + try: + with cohort: + cohort.with_condition(1) + raise ValueError("Test exception") + except ValueError: + pass + + # Cohort should still be built + expr = cohort.expression + assert isinstance(expr, CohortExpression) + + def test_context_manager_chaining_returns_self(self): + """Context manager methods should return self for chaining.""" + with CohortBuilder("Chaining Test") as cohort: + result = cohort.with_condition(1) + assert result is cohort + + result = cohort.first_occurrence() + assert result is cohort + + result = cohort.require_drug(2, within_days_before=30) + assert result is cohort + + +# ============================================================================= +# FLUENT API TESTS (Backwards Compatibility) +# ============================================================================= + + +class TestCohortStart: + """Test initial Cohort state.""" + + def test_create_cohort(self): + cohort = CohortBuilder("Test Cohort") + assert cohort._title == "Test Cohort" + + def test_with_condition_returns_entry_state(self): + 
result = CohortBuilder("Test").with_condition(concept_set_id=1) + assert hasattr(result, 'first_occurrence') + assert hasattr(result, 'with_observation') + assert hasattr(result, 'build') + + def test_with_drug_returns_entry_state(self): + result = CohortBuilder("Test").with_drug(concept_set_id=1) + assert hasattr(result, 'first_occurrence') + + +class TestCohortWithEntry: + """Test CohortWithEntry state.""" + + def test_first_occurrence_chainable(self): + result = (CohortBuilder("Test") + .with_condition(1) + .first_occurrence()) + assert result._entry_queries[0]._get_config().first_occurrence is True + + def test_with_observation(self): + result = (CohortBuilder("Test") + .with_condition(1) + .with_observation(prior_days=365, post_days=30)) + assert result._prior_observation_days == 365 + assert result._post_observation_days == 30 + + def test_min_age(self): + result = (CohortBuilder("Test") + .with_condition(1) + .min_age(18)) + assert result._entry_queries[0]._get_config().age_min == 18 + + def test_build_from_entry(self): + expr = (CohortBuilder("Simple") + .with_condition(1) + .build()) + assert isinstance(expr, CohortExpression) + assert expr.title == "Simple" + + +class TestCohortWithCriteria: + """Test CohortWithCriteria state.""" + + def test_require_returns_criteria_state(self): + # In new API, require_drug returns CohortWithCriteria directly if params are passed + result = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, anytime_before=True)) + # It returns CohortWithCriteria, so it should have build(), require_*, etc. 
+ assert hasattr(result, 'require_condition') + assert hasattr(result, 'build') + + def test_chained_criteria(self): + result = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, within_days=(0, 30)) # within_days_after(30) + .exclude_drug(3, anytime_before=True)) + assert len(result._rules[0]["group"].criteria) == 2 + + +class TestBuildCohortExpression: + """Test building final CohortExpression.""" + + def test_simple_cohort(self): + expr = (CohortBuilder("Simple Cohort") + .with_condition(1) + .first_occurrence() + .with_observation(prior_days=365) + .build()) + + assert expr.title == "Simple Cohort" + assert expr.primary_criteria is not None + assert expr.primary_criteria.observation_window.prior_days == 365 + + def test_cohort_with_inclusion(self): + expr = (CohortBuilder("With Inclusion") + .with_drug(2) + .first_occurrence() + .require_condition(1, within_days_before=365) + .build()) + + assert expr.inclusion_rules is not None + assert len(expr.inclusion_rules) == 1 + + def test_cohort_with_exclusion(self): + expr = (CohortBuilder("With Exclusion") + .with_condition(1) + .exclude_drug(2, anytime_before=True) + .build()) + + assert expr.inclusion_rules is not None + + def test_cohort_with_multiple_criteria(self): + expr = (CohortBuilder("Multiple Criteria") + .with_drug(2) + .first_occurrence() + .with_observation(prior_days=365) + .min_age(18) + .require_condition(1, within_days_before=365) + .exclude_drug(2, within_days_before=365) + .require_measurement(3, within_days_after=30) + .build()) + + assert expr.primary_criteria is not None + assert expr.inclusion_rules is not None + + +class TestQueryMethods: + """Test query configuration via parameters.""" + + def test_within_days_before(self): + result = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, within_days_before=365)) + + config = result._rules[0]["group"].criteria[0].query_config + assert config.time_window.days_before == 365 + assert config.time_window.days_after == 0 + + 
def test_within_days_after(self): + result = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, within_days_after=30)) + + config = result._rules[0]["group"].criteria[0].query_config + assert config.time_window.days_before == 0 + assert config.time_window.days_after == 30 + + def test_anytime_before(self): + result = (CohortBuilder("Test") + .with_condition(1) + .exclude_drug(2, anytime_before=True)) + + config = result._rules[0]["group"].criteria[0].query_config + assert config.time_window.days_before == 99999 + + def test_same_day(self): + result = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, same_day=True)) + + config = result._rules[0]["group"].criteria[0].query_config + assert config.time_window.days_before == 0 + assert config.time_window.days_after == 0 + + +def test_begin_end_rule(): + """Test that begin_rule and end_rule work correctly.""" + cohort = ( + CohortBuilder("Test Rule Blocks") + .with_condition(1) + .begin_rule("Rule A") + .require_drug(2, anytime_before=True) + .end_rule() + .begin_rule("Rule B") + .require_measurement(3, same_day=True) + .end_rule() + .build() + ) + + # Inclusion rules should be "Rule A" and "Rule B" + # (The default "Inclusion Criteria" rule is skipped because it's empty) + assert cohort.inclusion_rules[0].name == "Rule A" + assert cohort.inclusion_rules[1].name == "Rule B" diff --git a/tests/test_cohort_expression.py b/tests/test_cohort_expression.py index 8a72fed..c967608 100644 --- a/tests/test_cohort_expression.py +++ b/tests/test_cohort_expression.py @@ -34,7 +34,7 @@ def test_cohort_expression_empty_initialization(self): """Test CohortExpression with no parameters.""" cohort = CohortExpression() - self.assertIsNone(cohort.concept_sets) + self.assertEqual(cohort.concept_sets, []) self.assertIsNone(cohort.qualified_limit) self.assertIsNone(cohort.additional_criteria) self.assertIsNone(cohort.end_strategy) @@ -43,9 +43,9 @@ def test_cohort_expression_empty_initialization(self): 
self.assertIsNone(cohort.expression_limit) self.assertIsNone(cohort.collapse_settings) self.assertIsNone(cohort.title) - self.assertIsNone(cohort.inclusion_rules) + self.assertEqual(cohort.inclusion_rules, []) self.assertIsNone(cohort.censor_window) - self.assertIsNone(cohort.censoring_criteria) + self.assertEqual(cohort.censoring_criteria, []) def test_cohort_expression_with_title(self): """Test CohortExpression with title.""" @@ -370,8 +370,42 @@ def test_cohort_expression_with_none_values(self): self.assertIsNone(cohort.title) self.assertIsNone(cohort.primary_criteria) - self.assertIsNone(cohort.concept_sets) + self.assertEqual(cohort.concept_sets, []) + def test_cohort_expression_inclusion_rules_none_to_list(self): + """Test that inclusion_rules=None is converted to empty list.""" + # Test via constructor + cohort = CohortExpression(inclusion_rules=None) + self.assertEqual(cohort.inclusion_rules, []) + + # Test via JSON validation + cohort_json = CohortExpression.model_validate({"InclusionRules": None}) + self.assertEqual(cohort_json.inclusion_rules, []) + + def test_cohort_expression_list_defaults(self): + """Test defaults and None handling for list fields.""" + # 1. Default Initialization + c = CohortExpression() + self.assertEqual(c.concept_sets, []) + self.assertEqual(c.censoring_criteria, []) + self.assertEqual(c.inclusion_rules, []) + + # 2. None Initialization + c_none = CohortExpression( + concept_sets=None, + censoring_criteria=None, + inclusion_rules=None + ) + self.assertEqual(c_none.concept_sets, []) + self.assertEqual(c_none.censoring_criteria, []) + self.assertEqual(c_none.inclusion_rules, []) + + # 3. 
JSON Null + c_json = CohortExpression.model_validate_json('{"ConceptSets": null, "CensoringCriteria": null, "InclusionRules": null}') + self.assertEqual(c_json.concept_sets, []) + self.assertEqual(c_json.censoring_criteria, []) + self.assertEqual(c_json.inclusion_rules, []) + def test_cohort_expression_empty_string_title(self): """Test CohortExpression with empty string title.""" cohort = CohortExpression(title="") diff --git a/tests/test_cohort_modification.py b/tests/test_cohort_modification.py new file mode 100644 index 0000000..2819e1e --- /dev/null +++ b/tests/test_cohort_modification.py @@ -0,0 +1,400 @@ +""" +Unit tests for cohort modification capabilities. +""" + +import pytest +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition import CohortExpression, InclusionRule +from circe.vocabulary import ConceptSet, Concept + + +class TestCohortModification: + """Test cohort modification capabilities.""" + + def test_from_expression_basic(self): + """Test loading a cohort from expression.""" + # Build a simple cohort + expr = (CohortBuilder("Original Cohort") + .with_condition(1) + .first_occurrence() + .with_observation(prior_days=365) + .build()) + + # Load it for modification + builder = CohortBuilder.from_expression(expr) + + # Verify it loaded correctly + assert builder._title == "Original Cohort" + assert builder._state is not None + assert len(builder._state._entry_configs) == 1 + assert builder._state._prior_observation == 365 + + def test_from_expression_with_new_title(self): + """Test loading with a new title.""" + expr = (CohortBuilder("Original") + .with_drug(2) + .build()) + + builder = CohortBuilder.from_expression(expr, title="Modified") + assert builder._title == "Modified" + + def test_from_expression_preserves_concept_sets(self): + """Test that concept sets are preserved.""" + cs = ConceptSet(id=1, name="Test Concept Set") + + expr = (CohortBuilder("Test") + .with_concept_sets(cs) + .with_condition(1) + .build()) + + 
builder = CohortBuilder.from_expression(expr) + assert len(builder._concept_sets) == 1 + assert builder._concept_sets[0].id == 1 + + def test_from_expression_no_primary_criteria_raises(self): + """Test that loading without primary criteria raises error.""" + # Create an invalid expression manually + expr = CohortExpression(title="Invalid") + + with pytest.raises(ValueError, match="Cannot modify cohort without primary criteria"): + CohortBuilder.from_expression(expr) + + def test_modify_and_add_criteria(self): + """Test modifying an existing cohort by adding criteria.""" + # Create original cohort + expr = (CohortBuilder("Diabetes") + .with_condition(1) + .first_occurrence() + .build()) + + # Modify it + with CohortBuilder.from_expression(expr) as cohort: + cohort.require_drug(2, within_days_before=30) + cohort.exclude_condition(3, anytime_before=True) + + modified = cohort.expression + + # Verify modifications + assert modified.title == "Diabetes" + assert len(modified.inclusion_rules) >= 1 + + def test_remove_inclusion_rule(self): + """Test removing an inclusion rule by name.""" + # Create cohort with named rules + expr = (CohortBuilder("Test") + .with_condition(1) + .begin_rule("Rule A") + .require_drug(2, anytime_before=True) + .end_rule() + .begin_rule("Rule B") + .require_measurement(3, same_day=True) + .end_rule() + .build()) + + # Remove Rule A + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_inclusion_rule("Rule A") + + modified = cohort.expression + rule_names = [r.name for r in modified.inclusion_rules] + + assert "Rule A" not in rule_names + assert "Rule B" in rule_names + + def test_remove_inclusion_rule_invalid_name_raises(self): + """Test that removing non-existent rule raises KeyError.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .build()) + + with pytest.raises(KeyError, match="No inclusion rule found"): + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_inclusion_rule("Nonexistent Rule") + + def 
test_remove_censoring_criteria_by_type(self): + """Test removing censoring criteria by type.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .censor_on_drug(2, anytime_after=True) + .censor_on_death() + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_censoring_criteria(criteria_type="Death") + + modified = cohort.expression + + # Should have one censoring criteria left (drug) + assert len(modified.censoring_criteria) == 1 + assert modified.censoring_criteria[0].__class__.__name__ == "DrugExposure" + + def test_remove_censoring_criteria_by_concept_set(self): + """Test removing censoring criteria by concept set ID.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .censor_on_drug(2, anytime_after=True) + .censor_on_condition(3, anytime_after=True) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_censoring_criteria(concept_set_id=2) + + modified = cohort.expression + + # Should have one censoring criteria left (condition with cs_id=3) + assert len(modified.censoring_criteria) == 1 + assert modified.censoring_criteria[0].codeset_id == 3 + + def test_remove_censoring_criteria_by_index(self): + """Test removing censoring criteria by index.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .censor_on_drug(2, anytime_after=True) + .censor_on_condition(3, anytime_after=True) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_censoring_criteria(index=0) + + modified = cohort.expression + + # Should have one censoring criteria left + assert len(modified.censoring_criteria) == 1 + + def test_remove_censoring_criteria_no_args_raises(self): + """Test that calling without arguments raises ValueError.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .build()) + + with pytest.raises(ValueError, match="Must provide one of"): + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_censoring_criteria() + + def 
test_remove_censoring_criteria_multiple_args_raises(self): + """Test that providing multiple arguments raises ValueError.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .build()) + + with pytest.raises(ValueError, match="Can only provide one of"): + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_censoring_criteria(criteria_type="Death", index=0) + + def test_remove_entry_event_by_concept_set(self): + """Test removing an entry event by concept set ID.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .or_with_drug(2) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_entry_event(concept_set_id=1) + + modified = cohort.expression + + # Should have one entry event left (drug) + assert len(modified.primary_criteria.criteria_list) == 1 + assert modified.primary_criteria.criteria_list[0].__class__.__name__ == "DrugExposure" + + def test_remove_entry_event_by_type(self): + """Test removing an entry event by type.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .or_with_drug(2) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_entry_event(criteria_type="ConditionOccurrence") + + modified = cohort.expression + + # Should have one entry event left (drug) + assert len(modified.primary_criteria.criteria_list) == 1 + assert modified.primary_criteria.criteria_list[0].__class__.__name__ == "DrugExposure" + + def test_remove_last_entry_event_raises(self): + """Test that removing the last entry event raises ValueError.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .build()) + + with pytest.raises(ValueError, match="Cannot remove the last entry event"): + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_entry_event(concept_set_id=1) + + def test_remove_concept_set(self): + """Test removing a concept set by ID.""" + cs1 = ConceptSet(id=1, name="CS1") + cs2 = ConceptSet(id=2, name="CS2") + + expr = (CohortBuilder("Test") + 
.with_concept_sets(cs1, cs2) + .with_condition(1) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_concept_set(concept_set_id=1) + + modified = cohort.expression + + assert len(modified.concept_sets) == 1 + assert modified.concept_sets[0].id == 2 + + def test_remove_concept_set_invalid_id_raises(self): + """Test that removing non-existent concept set raises KeyError.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .build()) + + with pytest.raises(KeyError, match="No concept set found"): + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_concept_set(concept_set_id=999) + + def test_remove_all_references(self): + """Test removing a concept set and all references.""" + cs1 = ConceptSet(id=1, name="CS1") + cs2 = ConceptSet(id=2, name="CS2") + + expr = (CohortBuilder("Test") + .with_concept_sets(cs1, cs2) + .with_condition(1) + .or_with_drug(2) + .require_measurement(1, same_day=True) # References cs1 + .censor_on_condition(1, anytime_after=True) # References cs1 + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.remove_all_references(concept_set_id=1) + + modified = cohort.expression + + # Concept set should be removed + assert len(modified.concept_sets) == 1 + assert modified.concept_sets[0].id == 2 + + # Entry event with cs1 should be removed + assert len(modified.primary_criteria.criteria_list) == 1 + assert modified.primary_criteria.criteria_list[0].codeset_id == 2 + + # Censoring criteria with cs1 should be removed + assert len(modified.censoring_criteria) == 0 + + def test_clear_inclusion_rules(self): + """Test clearing all inclusion rules.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .require_drug(2, within_days_before=30) + .exclude_condition(3, anytime_before=True) + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.clear_inclusion_rules() + + modified = cohort.expression + + # Should have only the default empty rule + assert 
len(modified.inclusion_rules) <= 1 + if modified.inclusion_rules: + assert modified.inclusion_rules[0].name == "Inclusion Criteria" + + def test_clear_censoring_criteria(self): + """Test clearing all censoring criteria.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .censor_on_drug(2, anytime_after=True) + .censor_on_death() + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.clear_censoring_criteria() + + modified = cohort.expression + assert len(modified.censoring_criteria) == 0 + + def test_clear_demographic_criteria(self): + """Test clearing demographic restrictions.""" + expr = (CohortBuilder("Test") + .with_condition(1) + .min_age(18) + .max_age(65) + .require_gender(8507, 8532) # Male, Female + .build()) + + with CohortBuilder.from_expression(expr) as cohort: + cohort.clear_demographic_criteria() + + modified = cohort.expression + + # Demographic rule should be removed or empty + demo_rules = [r for r in modified.inclusion_rules if r.name == "Demographic Criteria"] + assert len(demo_rules) == 0 + + def test_complex_modification_workflow(self): + """Test a realistic workflow with multiple modifications.""" + # Create original cohort + cs1 = ConceptSet(id=1, name="Diabetes") + cs2 = ConceptSet(id=2, name="Metformin") + cs3 = ConceptSet(id=3, name="Cancer") + + expr = (CohortBuilder("Diabetes Cohort") + .with_concept_sets(cs1, cs2, cs3) + .with_condition(1) + .first_occurrence() + .with_observation(prior_days=365) + .min_age(18) + .begin_rule("Prior Treatment") + .require_drug(2, within_days_before=365) + .end_rule() + .begin_rule("Cancer Exclusion") + .exclude_condition(3, anytime_before=True) + .end_rule() + .build()) + + # Modify it + cs4 = ConceptSet(id=4, name="Insulin") + + with CohortBuilder.from_expression(expr) as cohort: + # Remove old exclusion rule + cohort.remove_inclusion_rule("Cancer Exclusion") + + # Add new concept set + cohort.with_concept_sets(cs4) + + # Add new criteria + cohort.require_measurement(4, 
within_days_after=90) + + # Clear age restriction + cohort.clear_demographic_criteria() + + modified = cohort.expression + + # Verify changes + assert modified.title == "Diabetes Cohort" + assert len(modified.concept_sets) == 4 + + rule_names = [r.name for r in modified.inclusion_rules] + assert "Prior Treatment" in rule_names + assert "Cancer Exclusion" not in rule_names + + # Demographic criteria should be cleared + demo_rules = [r for r in modified.inclusion_rules if r.name == "Demographic Criteria"] + assert len(demo_rules) == 0 + + def test_modification_preserves_original(self): + """Test that modifications don't affect the original expression.""" + # Create original + expr = (CohortBuilder("Original") + .with_condition(1) + .require_drug(2, within_days_before=30) + .build()) + + original_rule_count = len(expr.inclusion_rules) + + # Modify copy + with CohortBuilder.from_expression(expr) as cohort: + cohort.clear_inclusion_rules() + cohort.require_measurement(3, same_day=True) + + # Original should be unchanged + assert len(expr.inclusion_rules) == original_rule_count diff --git a/tests/test_collection_methods.py b/tests/test_collection_methods.py new file mode 100644 index 0000000..d8992ae --- /dev/null +++ b/tests/test_collection_methods.py @@ -0,0 +1,222 @@ +""" +Tests for collection methods in the fluent cohort builder. + +These methods allow simplified creation of grouped criteria without +manually managing .any_of()...end_group() chains. 
+""" + +import pytest +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition.cohort import CohortExpression + + +def test_require_any_of_drugs(): + """Test require_any_of with drug IDs.""" + cohort = ( + CohortBuilder("Test ANY Drugs") + .with_condition(1) + .require_any_of(drug_ids=[10, 11, 12]) + .build() + ) + + assert isinstance(cohort, CohortExpression) + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].name == "Inclusion Criteria" + assert cohort.inclusion_rules[0].expression.type == "ANY" + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 3 + + # Verify all are drug criteria + for cc in cohort.inclusion_rules[0].expression.criteria_list: + assert cc.criteria.__class__.__name__ == "DrugExposure" + + +def test_require_any_of_mixed_domains(): + """Test require_any_of with multiple domain types.""" + cohort = ( + CohortBuilder("Test ANY Mixed") + .with_condition(1) + .require_any_of( + condition_ids=[20, 21], + drug_ids=[30], + procedure_ids=[40, 41, 42] + ) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == "ANY" + # Should have 2 conditions + 1 drug + 3 procedures = 6 criteria + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 6 + + +def test_require_all_of(): + """Test require_all_of creates ALL group.""" + cohort = ( + CohortBuilder("Test ALL") + .with_drug(1) + .require_all_of(procedure_ids=[10, 11]) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == "ALL" + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 2 + + +def test_require_at_least_of(): + """Test require_at_least_of with count parameter.""" + cohort = ( + CohortBuilder("Test AT_LEAST") + .with_condition(1) + .require_at_least_of(2, procedure_ids=[10, 11, 12, 13]) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == 
"AT_LEAST" + assert cohort.inclusion_rules[0].expression.count == 2 + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 4 + + +def test_exclude_any_of(): + """Test exclude_any_of creates exclusion criteria.""" + cohort = ( + CohortBuilder("Test Exclusion") + .with_condition(1) + .exclude_any_of(drug_ids=[20, 21]) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + # Verify exclusion criteria were created + assert cohort.inclusion_rules[0].expression.type == "ANY" + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 2 + + # Check that criteria have occurrence count of 0 (exclusion) + for cc in cohort.inclusion_rules[0].expression.criteria_list: + assert cc.occurrence.count == 0 + assert cc.occurrence.type == 0 # exactly 0 + + +def test_collection_methods_chaining(): + """Test that collection methods can be chained together.""" + cohort = ( + CohortBuilder("Test Chaining") + .with_condition(1) + .require_any_of(drug_ids=[10, 11]) + .require_all_of(measurement_ids=[20, 21]) + .exclude_any_of(procedure_ids=[30]) + .build() + ) + + # Should have 3 groups in inclusion rules + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == "ALL" + assert len(cohort.inclusion_rules[0].expression.groups) == 3 + + +def test_collection_with_named_rules(): + """Test collection methods with named inclusion rules.""" + cohort = ( + CohortBuilder("Test Named Rules") + .with_condition(1) + .begin_rule("Prior Medications") + .require_any_of(drug_ids=[10, 11, 12]) + .begin_rule("Recent Procedures") + .require_at_least_of(2, procedure_ids=[20, 21, 22, 23]) + .build() + ) + + assert len(cohort.inclusion_rules) == 2 + assert cohort.inclusion_rules[0].name == "Prior Medications" + assert cohort.inclusion_rules[0].expression.type == "ANY" + assert cohort.inclusion_rules[1].name == "Recent Procedures" + assert cohort.inclusion_rules[1].expression.type == "AT_LEAST" + + +def test_empty_collection_ignored(): + """Test 
that collection methods with no IDs don't add empty groups.""" + cohort = ( + CohortBuilder("Test Empty") + .with_condition(1) + .require_any_of() # No IDs provided + .build() + ) + + # Should have no inclusion rules since no criteria were added + assert cohort.inclusion_rules is None or len(cohort.inclusion_rules) == 0 + + +def test_collection_with_observation_ids(): + """Test collection methods with observation concept sets.""" + cohort = ( + CohortBuilder("Test Observations") + .with_condition(1) + .require_any_of(observation_ids=[5, 6, 7]) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == "ANY" + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 3 + + for cc in cohort.inclusion_rules[0].expression.criteria_list: + assert cc.criteria.__class__.__name__ == "Observation" + + +def test_collection_with_visit_ids(): + """Test collection methods with visit concept sets.""" + cohort = ( + CohortBuilder("Test Visits") + .with_drug(1) + .require_all_of(visit_ids=[100, 101]) + .build() + ) + + assert len(cohort.inclusion_rules) == 1 + assert cohort.inclusion_rules[0].expression.type == "ALL" + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 2 + + for cc in cohort.inclusion_rules[0].expression.criteria_list: + assert cc.criteria.__class__.__name__ == "VisitOccurrence" + + +def test_require_at_least_of_single_type(): + """Test at_least with a single domain type.""" + cohort = ( + CohortBuilder("Test AT_LEAST Single") + .with_condition(1) + .require_at_least_of(3, drug_ids=[10, 11, 12, 13, 14]) + .build() + ) + + assert cohort.inclusion_rules[0].expression.type == "AT_LEAST" + assert cohort.inclusion_rules[0].expression.count == 3 + assert len(cohort.inclusion_rules[0].expression.criteria_list) == 5 + + +def test_real_world_example(): + """Test a real-world clinical scenario using collection methods.""" + cohort = ( + CohortBuilder("Diabetes with Complications") + 
.with_condition(1) # Entry: Type 2 Diabetes + .first_occurrence() + .with_observation(prior_days=365) + .require_age(18, 75) + .begin_rule("Antidiabetic Medications") + .require_any_of(drug_ids=[10, 11, 12]) # Metformin, Insulin, GLP-1 + .begin_rule("Diabetic Complications") + .require_at_least_of( + 1, + condition_ids=[20, 21, 22, 23] # Retinopathy, Neuropathy, Nephropathy, CVD + ) + .begin_rule("No Cancer History") + .exclude_any_of(condition_ids=[30, 31, 32]) # Various cancers + .build() + ) + + assert len(cohort.inclusion_rules) == 4 # Age + 3 named rules + assert cohort.inclusion_rules[0].name == "Demographic Criteria" + assert cohort.inclusion_rules[1].name == "Antidiabetic Medications" + assert cohort.inclusion_rules[2].name == "Diabetic Complications" + assert cohort.inclusion_rules[3].name == "No Cancer History" diff --git a/tests/test_criteria_classes.py b/tests/test_criteria_classes.py index 04b5975..3552d12 100644 --- a/tests/test_criteria_classes.py +++ b/tests/test_criteria_classes.py @@ -17,7 +17,8 @@ ConditionOccurrence, DrugExposure, ProcedureOccurrence, VisitOccurrence, Observation, Measurement, DeviceExposure, Specimen, Death, VisitDetail, ObservationPeriod, PayerPlanPeriod, LocationRegion, ConditionEra, - DrugEra, DoseEra, GeoCriteria, WindowedCriteria + DrugEra, DoseEra, GeoCriteria, WindowedCriteria, + CriteriaGroup, PrimaryCriteria, DemographicCriteria ) from circe.cohortdefinition.core import ( TextFilter, WindowBound, Window, @@ -25,6 +26,56 @@ ConceptSetSelection ) from circe.vocabulary.concept import Concept +import json + + +class TestStructureDefaults(unittest.TestCase): + """Test defaults and None handling for structural classes.""" + + def test_criteria_group_list_fields(self): + """Test defaults and None handling for CriteriaGroup list fields.""" + # 1. Default Initialization + cg = CriteriaGroup() + self.assertEqual(cg.criteria_list, []) + self.assertEqual(cg.groups, []) + self.assertEqual(cg.demographic_criteria_list, []) + + # 2. 
None Initialization + cg_none = CriteriaGroup( + criteria_list=None, + groups=None, + demographic_criteria_list=None + ) + self.assertEqual(cg_none.criteria_list, []) + self.assertEqual(cg_none.groups, []) + self.assertEqual(cg_none.demographic_criteria_list, []) + + # 3. JSON Null + cg_json = CriteriaGroup.model_validate_json(json.dumps({ + "CriteriaList": None, + "Groups": None, + "DemographicCriteriaList": None + })) + self.assertEqual(cg_json.criteria_list, []) + self.assertEqual(cg_json.groups, []) + self.assertEqual(cg_json.demographic_criteria_list, []) + + def test_primary_criteria_list_fields(self): + """Test defaults and None handling for PrimaryCriteria list fields.""" + # 1. Default Initialization + pc = PrimaryCriteria() + self.assertEqual(pc.criteria_list, []) + + # 2. None Initialization + pc_none = PrimaryCriteria(criteria_list=None) + self.assertEqual(pc_none.criteria_list, []) + + # 3. JSON Null + pc_json = PrimaryCriteria.model_validate_json(json.dumps({ + "CriteriaList": None + })) + self.assertEqual(pc_json.criteria_list, []) + class TestConditionOccurrence(unittest.TestCase): diff --git a/tests/test_domain_modifiers.py b/tests/test_domain_modifiers.py new file mode 100644 index 0000000..d8c15e0 --- /dev/null +++ b/tests/test_domain_modifiers.py @@ -0,0 +1,168 @@ +""" +Tests for domain-specific modifiers in the new parameter-based fluent cohort builder API. 
+""" + +import pytest +from circe.cohort_builder import CohortBuilder +from circe.cohortdefinition.cohort import CohortExpression + + +def test_procedure_with_quantity(): + """Test procedure quantity modifier.""" + cohort = ( + CohortBuilder("Test Procedure Quantity") + .with_condition(1) + .begin_rule("Procedure Rule") + .require_procedure(10, quantity_min=1, quantity_max=5, anytime_before=True) + .build() + ) + + assert isinstance(cohort, CohortExpression) + # Check that quantity range was applied + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.quantity is not None + assert criteria.criteria.quantity.value == 1.0 + assert criteria.criteria.quantity.extent == 5.0 + + +def test_measurement_with_operator(): + """Test measurement operator modifier.""" + cohort = ( + CohortBuilder("Test Measurement Operator") + .with_condition(1) + .begin_rule("Measurement Rule") + .require_measurement(10, operator=4172704, anytime_before=True) # Greater than + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.operator is not None + assert len(criteria.criteria.operator) == 1 + assert criteria.criteria.operator[0].concept_id == 4172704 + + +def test_measurement_with_range_ratios(): + """Test measurement range ratio modifiers.""" + cohort = ( + CohortBuilder("Test Measurement Ratios") + .with_drug(1) + .begin_rule("Measurement Rule") + .require_measurement(10, + range_low_ratio_min=0.5, + range_low_ratio_max=1.5, + range_high_ratio_min=1.0, + range_high_ratio_max=2.0, + anytime_before=True) + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + # Check ratios (these need to be added to the builder later if not already supported) + # For now, just verifying the API call works and values are in QueryConfig + # The actual mapping to OHDSI JSON happens in cohortdefinition/builders/measurement.py + # which I should also check. 
+ + +def test_drug_with_route(): + """Test drug route modifier.""" + cohort = ( + CohortBuilder("Test Drug Route") + .with_condition(1) + .begin_rule("Drug Rule") + .require_drug(10, route=4132161, anytime_before=True) # Oral + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.route_concept is not None + assert len(criteria.criteria.route_concept) == 1 + assert criteria.criteria.route_concept[0].concept_id == 4132161 + + +def test_drug_with_refills(): + """Test drug refills modifier.""" + cohort = ( + CohortBuilder("Test Drug Refills") + .with_condition(1) + .begin_rule("Drug Rule") + .require_drug(10, refills_min=1, refills_max=12, anytime_before=True) + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.refills is not None + assert criteria.criteria.refills.value == 1 + assert criteria.criteria.refills.extent == 12 + + +def test_drug_with_dose(): + """Test drug dose modifier.""" + cohort = ( + CohortBuilder("Test Drug Dose") + .with_procedure(1) + .begin_rule("Drug Rule") + .require_drug(10, dose_min=10.0, dose_max=50.0, anytime_before=True) + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.effective_drug_dose is not None + assert criteria.criteria.effective_drug_dose.value == 10.0 + assert criteria.criteria.effective_drug_dose.extent == 50.0 + + +def test_visit_with_place_of_service(): + """Test visit place_of_service modifier.""" + cohort = ( + CohortBuilder("Test Visit Place of Service") + .with_condition(1) + .begin_rule("Visit Rule") + .require_visit(10, place_of_service=8546, anytime_before=True) # Hospice + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.place_of_service is not None + assert len(criteria.criteria.place_of_service) == 1 + assert criteria.criteria.place_of_service[0].concept_id == 8546 + + +def 
test_observation_with_qualifier(): + """Test observation qualifier modifier.""" + cohort = ( + CohortBuilder("Test Observation Qualifier") + .with_condition(1) + .begin_rule("Observation Rule") + .require_observation(10, qualifier=4214956, anytime_before=True) + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.qualifier is not None + assert len(criteria.criteria.qualifier) == 1 + assert criteria.criteria.qualifier[0].concept_id == 4214956 + + +def test_multiple_modifiers_chained(): + """Test multiple modifiers on one criterion via parameters.""" + cohort = ( + CohortBuilder("Test Multi-Modifier") + .with_condition(1) + .begin_rule("Drug Rule") + .require_drug(10, + route=4132161, + refills_min=2, + refills_max=6, + dose_min=5.0, + dose_max=20.0, + days_supply_min=30, + days_supply_max=90, + anytime_before=True) + .build() + ) + + criteria = cohort.inclusion_rules[0].expression.criteria_list[0] + assert criteria.criteria.route_concept is not None + assert criteria.criteria.refills is not None + assert criteria.criteria.effective_drug_dose is not None + assert criteria.criteria.days_supply is not None diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..b2e1576 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,337 @@ +""" +Tests for cohort expression validator functions. + +This module tests both the instance methods on CohortExpression and the +standalone validator functions. 
+"""
+
+import pytest
+import json
+from pathlib import Path
+
+from circe.cohortdefinition import (
+    CohortExpression,
+    PrimaryCriteria,
+    ConditionOccurrence,
+    DrugExposure,
+    InclusionRule,
+    CriteriaGroup,
+    DateOffsetStrategy,
+    CustomEraStrategy,
+    ObservationFilter,
+    ResultLimit,
+)
+from circe.cohortdefinition.validators import (
+    is_first_event,
+    has_exclusion_rules,
+    has_inclusion_rule_by_name,
+    get_exclusion_count,
+    has_censoring_criteria,
+    get_censoring_criteria_types,
+    has_additional_criteria,
+    has_end_strategy,
+    get_end_strategy_type,
+    get_primary_criteria_types,
+    has_observation_window,
+    get_primary_limit_type,
+    get_concept_set_count,
+    has_concept_sets,
+)
+from circe.vocabulary import ConceptSet
+
+
+class TestIsFirstEvent:
+    """Test is_first_event method and function."""
+
+    def test_empty_cohort(self):
+        """Test with empty cohort expression."""
+        cohort = CohortExpression()
+        assert cohort.is_first_event() is False
+        assert is_first_event(cohort) is False
+
+    def test_no_primary_criteria(self):
+        """Test with no primary criteria."""
+        cohort = CohortExpression(primary_criteria=None)
+        assert cohort.is_first_event() is False
+
+    def test_all_first_true(self):
+        """Test when all criteria have first=True."""
+        criteria1 = ConditionOccurrence(codeset_id=1, first=True)
+        criteria2 = DrugExposure(codeset_id=2, first=True)
+
+        primary = PrimaryCriteria(criteria_list=[criteria1, criteria2])
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.is_first_event() is True
+        assert is_first_event(cohort) is True
+
+    def test_all_first_false(self):
+        """Test when all criteria have first=False."""
+        criteria1 = ConditionOccurrence(codeset_id=1, first=False)
+        criteria2 = DrugExposure(codeset_id=2, first=False)
+
+        primary = PrimaryCriteria(criteria_list=[criteria1, criteria2])
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.is_first_event() is False
+
+    def test_mixed_first_values(self):
+        """Test when criteria have mixed first values."""
+        criteria1 = ConditionOccurrence(codeset_id=1, first=True)
+        criteria2 = DrugExposure(codeset_id=2, first=False)
+
+        primary = PrimaryCriteria(criteria_list=[criteria1, criteria2])
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.is_first_event() is False
+
+    def test_first_none(self):
+        """Test when first is None."""
+        criteria1 = ConditionOccurrence(codeset_id=1, first=None)
+
+        primary = PrimaryCriteria(criteria_list=[criteria1])
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.is_first_event() is False
+
+
+class TestExclusionRules:
+    """Test exclusion rule methods and functions."""
+
+    def test_no_exclusion_rules(self):
+        """Test cohort with no exclusion rules."""
+        cohort = CohortExpression()
+
+        assert cohort.has_exclusion_rules() is False
+        assert has_exclusion_rules(cohort) is False
+        assert cohort.get_exclusion_count() == 0
+        assert get_exclusion_count(cohort) == 0
+
+    def test_with_exclusion_rules(self):
+        """Test cohort with exclusion rules."""
+        rule1 = InclusionRule(name="Rule 1", description="First rule")
+        rule2 = InclusionRule(name="Rule 2", description="Second rule")
+
+        cohort = CohortExpression(inclusion_rules=[rule1, rule2])
+
+        assert cohort.has_exclusion_rules() is True
+        assert has_exclusion_rules(cohort) is True
+        assert cohort.get_exclusion_count() == 2
+        assert get_exclusion_count(cohort) == 2
+
+    def test_has_inclusion_rule_by_name(self):
+        """Test finding inclusion rule by name."""
+        rule1 = InclusionRule(name="Prior Cancer", description="Exclude prior cancer")
+        rule2 = InclusionRule(name="Age Limit", description="Age restriction")
+
+        cohort = CohortExpression(inclusion_rules=[rule1, rule2])
+
+        assert cohort.has_inclusion_rule_by_name("Prior Cancer") is True
+        assert has_inclusion_rule_by_name(cohort, "Prior Cancer") is True
+        assert cohort.has_inclusion_rule_by_name("Age Limit") is True
+        assert cohort.has_inclusion_rule_by_name("Nonexistent") is False
+        assert has_inclusion_rule_by_name(cohort, "Nonexistent") is False
+
+    def test_has_inclusion_rule_by_name_no_rules(self):
+        """Test finding rule by name when no rules exist."""
+        cohort = CohortExpression()
+
+        assert cohort.has_inclusion_rule_by_name("Any Name") is False
+
+
+class TestCensoringCriteria:
+    """Test censoring criteria methods and functions."""
+
+    def test_no_censoring_criteria(self):
+        """Test cohort with no censoring criteria."""
+        cohort = CohortExpression()
+
+        assert cohort.has_censoring_criteria() is False
+        assert has_censoring_criteria(cohort) is False
+        assert cohort.get_censoring_criteria_types() == []
+        assert get_censoring_criteria_types(cohort) == []
+
+    def test_with_censoring_criteria(self):
+        """Test cohort with censoring criteria."""
+        censor1 = ConditionOccurrence(codeset_id=1)
+        censor2 = DrugExposure(codeset_id=2)
+
+        cohort = CohortExpression(censoring_criteria=[censor1, censor2])
+
+        assert cohort.has_censoring_criteria() is True
+        assert has_censoring_criteria(cohort) is True
+
+        types = cohort.get_censoring_criteria_types()
+        assert len(types) == 2
+        assert "ConditionOccurrence" in types
+        assert "DrugExposure" in types
+
+        assert get_censoring_criteria_types(cohort) == types
+
+
+class TestAdditionalCriteria:
+    """Test additional criteria methods and functions."""
+
+    def test_no_additional_criteria(self):
+        """Test cohort with no additional criteria."""
+        cohort = CohortExpression()
+
+        assert cohort.has_additional_criteria() is False
+        assert has_additional_criteria(cohort) is False
+
+    def test_empty_additional_criteria(self):
+        """Test cohort with empty additional criteria group."""
+        empty_group = CriteriaGroup()
+        cohort = CohortExpression(additional_criteria=empty_group)
+
+        assert cohort.has_additional_criteria() is False
+
+    def test_with_additional_criteria(self):
+        """Test cohort with non-empty additional criteria."""
+        criteria = ConditionOccurrence(codeset_id=1)
+        from circe.cohortdefinition.criteria import CorelatedCriteria
+        from circe.cohortdefinition.core import Window
+
+        correlated = CorelatedCriteria(criteria=criteria)
+        group = CriteriaGroup(criteria_list=[correlated])
+        cohort = CohortExpression(additional_criteria=group)
+
+        assert cohort.has_additional_criteria() is True
+        assert has_additional_criteria(cohort) is True
+
+
+class TestEndStrategy:
+    """Test end strategy methods and functions."""
+
+    def test_no_end_strategy(self):
+        """Test cohort with no end strategy."""
+        cohort = CohortExpression()
+
+        assert cohort.has_end_strategy() is False
+        assert has_end_strategy(cohort) is False
+        assert cohort.get_end_strategy_type() is None
+        assert get_end_strategy_type(cohort) is None
+
+    def test_date_offset_strategy(self):
+        """Test cohort with DateOffset end strategy."""
+        strategy = DateOffsetStrategy(offset=30, date_field="StartDate")
+        cohort = CohortExpression(end_strategy=strategy)
+
+        assert cohort.has_end_strategy() is True
+        assert has_end_strategy(cohort) is True
+        assert cohort.get_end_strategy_type() == "DateOffset"
+        assert get_end_strategy_type(cohort) == "DateOffset"
+
+    def test_custom_era_strategy(self):
+        """Test cohort with CustomEra end strategy."""
+        strategy = CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0)
+        cohort = CohortExpression(end_strategy=strategy)
+
+        assert cohort.has_end_strategy() is True
+        assert cohort.get_end_strategy_type() == "CustomEra"
+        assert get_end_strategy_type(cohort) == "CustomEra"
+
+
+class TestPrimaryCriteria:
+    """Test primary criteria methods and functions."""
+
+    def test_no_primary_criteria(self):
+        """Test cohort with no primary criteria."""
+        cohort = CohortExpression()
+
+        assert cohort.get_primary_criteria_types() == []
+        assert get_primary_criteria_types(cohort) == []
+        assert cohort.has_observation_window() is False
+        assert has_observation_window(cohort) is False
+        assert cohort.get_primary_limit_type() is None
+        assert get_primary_limit_type(cohort) is None
+
+    def test_primary_criteria_types(self):
+        """Test getting primary criteria types."""
+        criteria1 = ConditionOccurrence(codeset_id=1)
+        criteria2 = DrugExposure(codeset_id=2)
+
+        primary = PrimaryCriteria(criteria_list=[criteria1, criteria2])
+        cohort = CohortExpression(primary_criteria=primary)
+
+        types = cohort.get_primary_criteria_types()
+        assert len(types) == 2
+        assert "ConditionOccurrence" in types
+        assert "DrugExposure" in types
+        assert get_primary_criteria_types(cohort) == types
+
+    def test_observation_window(self):
+        """Test observation window detection."""
+        obs_window = ObservationFilter(prior_days=365, post_days=0)
+        primary = PrimaryCriteria(
+            criteria_list=[ConditionOccurrence(codeset_id=1)],
+            observation_window=obs_window
+        )
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.has_observation_window() is True
+        assert has_observation_window(cohort) is True
+
+    def test_primary_limit_type(self):
+        """Test getting primary limit type."""
+        limit = ResultLimit(type="First")
+        primary = PrimaryCriteria(
+            criteria_list=[ConditionOccurrence(codeset_id=1)],
+            primary_limit=limit
+        )
+        cohort = CohortExpression(primary_criteria=primary)
+
+        assert cohort.get_primary_limit_type() == "First"
+        assert get_primary_limit_type(cohort) == "First"
+
+
+class TestConceptSets:
+    """Test concept set methods and functions."""
+
+    def test_no_concept_sets(self):
+        """Test cohort with no concept sets."""
+        cohort = CohortExpression()
+
+        assert cohort.has_concept_sets() is False
+        assert has_concept_sets(cohort) is False
+        assert cohort.get_concept_set_count() == 0
+        assert get_concept_set_count(cohort) == 0
+
+    def test_with_concept_sets(self):
+        """Test cohort with concept sets."""
+        cs1 = ConceptSet(id=1, name="Diabetes")
+        cs2 = ConceptSet(id=2, name="Hypertension")
+
+        cohort = CohortExpression(concept_sets=[cs1, cs2])
+
+        assert cohort.has_concept_sets() is True
+        assert has_concept_sets(cohort) is True
+        assert cohort.get_concept_set_count() == 2
+        assert get_concept_set_count(cohort) == 2
+
+
+class TestRealCohortDefinition:
+    """Test with real cohort definition from test files."""
+
+    def test_isolated_immune_thrombocytopenia(self):
+        """Test with the isolated immune thrombocytopenia cohort."""
+        test_file = Path(__file__).parent / "cohorts" / "isolated_immune_thrombocytopenia.json"
+
+        if not test_file.exists():
+            pytest.skip("Test cohort file not found")
+
+        with open(test_file) as f:
+            data = json.load(f)
+
+        cohort = CohortExpression.model_validate(data)
+
+        # This cohort should have concept sets
+        assert cohort.has_concept_sets() is True
+        assert cohort.get_concept_set_count() > 0
+
+        # Check primary criteria
+        assert len(cohort.get_primary_criteria_types()) > 0
+
+        # This cohort has inclusion rules (exclusion criteria)
+        assert cohort.has_exclusion_rules() is True
+        assert cohort.get_exclusion_count() > 0