Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,23 @@ This package provides a complete Python implementation of CIRCE-BE with:
- Measurement, Observation
- Visit Occurrence/Detail
- Device Exposure, Specimen
- Death, Location Region
- Observation Period, Payer Plan Period
- And more...
- **Full cohort expression validation** with comprehensive error checking
- **Markdown rendering** for human-readable cohort descriptions
- **Complete CLI interface** with 4 commands (validate, generate-sql, render-markdown, process)
- Specimen, Death
- Payer Plan Period, Location Region
- **Full Cohort Expression Validation** with 40+ checker implementations
- **Markdown Rendering** for human-readable descriptions
- **Complete CLI Interface** for validation, SQL, and rendering
- **Extension System** to support custom CDM domains

## Extensions

`circe_py` includes a powerful extension system that allows adding support for custom CDM domains.

Included Extensions:

- **OHDSI Waveform Extension**: Support for the OHDSI Waveform Extension specification (waveform_occurrence, waveform_registry, waveform_channel_metadata, waveform_feature). See [waveform_extension/README.md](waveform_extension/README.md).

For information on how to implement your own extension, see the [Developer Guide for Extensions](docs/developer/extensions.rst).

- **Java interoperability** - supports both camelCase and snake_case field names for seamless Java CIRCE-BE compatibility

## ⚠️ Java Fidelity Requirement
Expand Down
10 changes: 9 additions & 1 deletion circe/cohortdefinition/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .utils import BuilderUtils, BuilderOptions, CriteriaColumn
from .base import CriteriaSqlBuilder
from circe.extensions import get_registry
from .condition_occurrence import ConditionOccurrenceSqlBuilder
from .drug_exposure import DrugExposureSqlBuilder
from .procedure_occurrence import ProcedureOccurrenceSqlBuilder
Expand All @@ -28,6 +29,12 @@
from .visit_detail import VisitDetailSqlBuilder
from .location_region import LocationRegionSqlBuilder

# Extension support
def get_builder_for_criteria(criteria):
"""Get a SQL builder for a criteria instance, checking extensions first."""
registry = get_registry()
return registry.get_builder(criteria)

__all__ = [
# Utility classes
"BuilderUtils", "BuilderOptions", "CriteriaColumn",
Expand All @@ -51,5 +58,6 @@
"ObservationPeriodSqlBuilder",
"PayerPlanPeriodSqlBuilder",
"VisitDetailSqlBuilder",
"LocationRegionSqlBuilder"
"LocationRegionSqlBuilder",
"get_builder_for_criteria"
]
31 changes: 25 additions & 6 deletions circe/cohortdefinition/cohort_expression_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
ConditionOccurrenceSqlBuilder, DeathSqlBuilder, DeviceExposureSqlBuilder,
MeasurementSqlBuilder, ObservationSqlBuilder, SpecimenSqlBuilder,
VisitOccurrenceSqlBuilder, DrugExposureSqlBuilder, ProcedureOccurrenceSqlBuilder,
ConditionEraSqlBuilder, DrugEraSqlBuilder, DoseEraSqlBuilder, ObservationPeriodSqlBuilder, PayerPlanPeriodSqlBuilder,
VisitDetailSqlBuilder, LocationRegionSqlBuilder
ConditionEraSqlBuilder, DrugEraSqlBuilder, DoseEraSqlBuilder, ObservationPeriodSqlBuilder, PayerPlanPeriodSqlBuilder,
VisitDetailSqlBuilder, LocationRegionSqlBuilder, get_builder_for_criteria
)
from circe.extensions import get_registry
from .interfaces import IGetCriteriaSqlDispatcher, IGetEndStrategySqlDispatcher
from .concept_set_expression_query_builder import ConceptSetExpressionQueryBuilder

Expand Down Expand Up @@ -1296,7 +1297,19 @@ def get_criteria_sql(self, criteria: Criteria, options: Optional[BuilderOptions]
'DoseEra': DoE,
}

if criteria_type in criteria_class_map:
# Check if it's a registered extension criteria
registry = get_registry()
if criteria_type and criteria_type in registry._criteria_classes:
try:
criteria_data = dict(criteria_data) if criteria_data else {}
# Add defaults if needed
if 'first' not in criteria_data or criteria_data.get('first') is None:
criteria_data['first'] = False

criteria = registry._criteria_classes[criteria_type].model_validate(criteria_data, strict=False)
except Exception as e:
raise ValueError(f"Failed to deserialize extension criteria: {criteria_type} - {e}")
elif criteria_type in criteria_class_map:
try:
# Make a mutable copy to add defaults
criteria_data = dict(criteria_data) if criteria_data else {}
Expand All @@ -1315,10 +1328,16 @@ def get_criteria_sql(self, criteria: Criteria, options: Optional[BuilderOptions]
criteria = criteria_class_map[criteria_type].model_validate(criteria_data, strict=False)
except Exception as e:
raise ValueError(f"Failed to deserialize criteria from dict: {criteria_type} - {e}")
else:
raise ValueError(f"Unknown criteria type in dict: {criteria_type}")
else:
raise ValueError(f"Invalid criteria dict structure: {criteria}")
raise ValueError(f"Unknown criteria type in dict: {criteria_type}")
else:
if isinstance(criteria, dict):
raise ValueError(f"Invalid criteria dict structure: {criteria}")

# Check for extension builder first
extension_builder = get_builder_for_criteria(criteria)
if extension_builder:
return self._get_criteria_sql_from_builder(extension_builder, criteria, options)

# Import here to avoid circular dependency - use the already imported names
if isinstance(criteria, ConditionOccurrence):
Expand Down
53 changes: 42 additions & 11 deletions circe/cohortdefinition/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
Reference: JAVA_CLASS_MAPPINGS.md for Java equivalents.
"""

from typing import List, Optional, Any, ClassVar, Union, TYPE_CHECKING
from pydantic import BaseModel, Field, ConfigDict, model_serializer, AliasChoices, field_validator
from typing import List, Optional, Any, ClassVar, Union, TYPE_CHECKING, Annotated
from pydantic import BaseModel, Field, ConfigDict, model_serializer, AliasChoices, field_validator, BeforeValidator
from enum import Enum
from ..vocabulary.concept import Concept
from .core import (
Expand Down Expand Up @@ -218,13 +218,26 @@ class Criteria(CirceBaseModel):
@model_serializer(mode='wrap')
def _serialize_polymorphic(self, serializer, info):
"""Serialize with polymorphic type wrapper for Java compatibility."""
# Get the serialized data using default serialization
data = serializer(self)
# Wrap in class name for polymorphic deserialization in Java
# Only wrap if this is a subclass (not the base Criteria class)
if self.__class__.__name__ != 'Criteria':
return {self.__class__.__name__: data}
return data
if self.__class__.__name__ == 'Criteria':
return serializer(self)

# For subclasses (extensions), we want to ensure all fields are included
# even if serialized via a base class Union link.
# We manually build the dict to avoid infinite recursion with model_dump()
data = {}
for field_name, field_info in self.model_fields.items():
value = getattr(self, field_name)
if value is not None:
# Use serialization_alias if it exists, otherwise use field name
# Note: alias_generator (PascalCase) is handled via serialization_alias
# effectively if we use the right property.
# In Pydantic V2, serialization_alias is often the PascalCase version if configured.
alias = field_info.serialization_alias or field_name
# If it's a generic field without explicit alias, it might need PascalCase
# but most CIRCE fields have explicit aliases.
data[alias] = value

return {self.__class__.__name__: data}

def accept(self, dispatcher: Any, options: Optional[Any] = None) -> str:
"""Accept method for visitor pattern."""
Expand Down Expand Up @@ -1204,14 +1217,32 @@ def normalize_window(window_dict: dict) -> dict:


# Define CriteriaType Union for strict typing
CriteriaType = Union[
# Define CriteriaType Union for strict typing
_CriteriaTypeUnion = Union[
ConditionOccurrence, DrugExposure, ProcedureOccurrence,
VisitOccurrence, Observation, Measurement, DeviceExposure,
Specimen, Death, VisitDetail, ObservationPeriod,
PayerPlanPeriod, LocationRegion, ConditionEra,
DrugEra, DoseEra
DrugEra, DoseEra, Criteria
]

def _validate_criteria_extension(v: Any) -> Any:
"""Validate criteria checking extensions registry for custom types."""
if isinstance(v, dict) and len(v) == 1:
key = next(iter(v))
try:
from circe.extensions import get_registry
registry = get_registry()
cls = registry.get_criteria_class(key)
if cls:
# Found registered criteria class, deserialize it
return cls.model_validate(v[key])
except ImportError:
pass
return v

CriteriaType = Annotated[_CriteriaTypeUnion, BeforeValidator(_validate_criteria_extension)]

# Map for dynamic lookup
NAMES_TO_CLASSES = {
'ConditionOccurrence': ConditionOccurrence,
Expand Down
30 changes: 26 additions & 4 deletions circe/cohortdefinition/printfriendly/markdown_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,36 @@ class MarkdownRender:
subdirectory and mirror the structure of Java's .ftl files.
"""

def __init__(self, concept_sets: Optional[List[ConceptSet]] = None, include_concept_sets: bool = False):
def __init__(self, concept_sets: Optional[List[ConceptSet]] = None, include_concept_sets: bool = False, template_paths: Optional[List[Path]] = None):
"""Initialize the markdown renderer.

Args:
concept_sets: Optional list of concept sets for resolving codeset IDs to names
include_concept_sets: Whether to include concept set tables in the output (default: False)
template_paths: Optional list of additional template directories to search
"""
self._concept_sets = concept_sets or []
self._include_concept_sets = include_concept_sets

# Initialize Jinja2 environment
template_dir = Path(__file__).parent / 'templates'
# Initialize Jinja2 environment with multiple loaders
built_in_template_dir = Path(__file__).parent / 'templates'

# Start with built-in templates
loaders = [jinja2.FileSystemLoader(str(built_in_template_dir))]

# Add user provided paths
if template_paths:
for path in template_paths:
loaders.append(jinja2.FileSystemLoader(str(path)))

# Add registry paths
from circe.extensions import get_registry
registry = get_registry()
for path in registry.template_paths:
loaders.append(jinja2.FileSystemLoader(str(path)))

self._env = jinja2.Environment(
loader=jinja2.FileSystemLoader(str(template_dir)),
loader=jinja2.ChoiceLoader(loaders),
trim_blocks=True,
lstrip_blocks=True,
autoescape=False # We're generating markdown, not HTML
Expand All @@ -54,6 +70,12 @@ def __init__(self, concept_sets: Optional[List[ConceptSet]] = None, include_conc
self._env.filters['format_date'] = self._format_date
self._env.filters['format_number'] = self._format_number

# Add extension helper to look up template name for a criteria instance
def get_template_for_criteria(criteria):
return registry.get_template(criteria)

self._env.globals['get_template_for_criteria'] = get_template_for_criteria

# Register global functions
self._env.globals['codeset_name'] = self._codeset_name
self._env.globals['format_date'] = self._format_date
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
============================================ #}
{%- macro Criteria(c, level=0, isPlural=true, countCriteria={}, indexLabel="cohort entry") -%}
{%- set type_name = c.__class__.__name__ -%}
{%- if type_name == "ConditionEra" -%}{{ ConditionEra(c, level, isPlural, countCriteria, indexLabel) }}
{%- set custom_template = get_template_for_criteria(c) -%}
{%- if custom_template -%}
{%- with criteria=c, level=level, isPlural=isPlural, countCriteria=countCriteria, indexLabel=indexLabel -%}
{%- include custom_template -%}
{%- endwith -%}
{%- elif type_name == "ConditionEra" -%}{{ ConditionEra(c, level, isPlural, countCriteria, indexLabel) }}
{%- elif type_name == "ConditionOccurrence" -%}{{ ConditionOccurrence(c, level, isPlural, countCriteria, indexLabel) }}
{%- elif type_name == "Death" -%}{{ Death(c, level, isPlural, countCriteria, indexLabel) }}
{%- elif type_name == "DeviceExposure" -%}{{ DeviceExposure(c, level, isPlural, countCriteria, indexLabel) }}
Expand Down
115 changes: 115 additions & 0 deletions circe/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Extension Registry for OMOP CDM.

This module provides the central registry for managing extensions to circe-py,
allowing external projects to register custom criteria classes, SQL builders,
and markdown renderers.
"""
from typing import Dict, List, Optional, Type, Set, Union
from pathlib import Path

# Forward references to avoid circular imports
# Actual imports happen inside methods or with TYPE_CHECKING
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .cohortdefinition.criteria import Criteria
from .cohortdefinition.builders.base import CriteriaSqlBuilder

class ExtensionRegistry:
"""Central registry for OMOP CDM extensions."""

def __init__(self):
# Maps criteria names to criteria classes (for JSON deserialization)
self._criteria_classes: Dict[str, Type['Criteria']] = {}

# Maps criteria types to SQL builder classes
self._sql_builders: Dict[Type['Criteria'], Type['CriteriaSqlBuilder']] = {}

# Maps criteria types to markdown template names
self._markdown_templates: Dict[Type['Criteria'], str] = {}

# List of paths to search for Jinja2 templates
self._template_paths: List[Path] = []

def register_criteria_class(self, name: str, cls: Type['Criteria']) -> None:
"""Register a new criteria class for JSON deserialization.

Args:
name: The name of the criteria type (e.g. "WaveformOccurrence")
cls: The Criteria subclass
"""
self._criteria_classes[name] = cls

def register_sql_builder(self, criteria_cls: Type['Criteria'], builder_cls: Type['CriteriaSqlBuilder']) -> None:
"""Register a SQL builder for a criteria type.

Args:
criteria_cls: The Criteria subclass
builder_cls: The CriteriaSqlBuilder subclass
"""
self._sql_builders[criteria_cls] = builder_cls

def register_markdown_template(self, criteria_cls: Type['Criteria'], template_name: str) -> None:
"""Register a Jinja2 template for markdown rendering.

Args:
criteria_cls: The Criteria subclass
template_name: The name of the template file (e.g. "waveform_occurrence.j2")
"""
self._markdown_templates[criteria_cls] = template_name

def add_template_path(self, path: Path) -> None:
"""Add a path to search for Jinja2 templates.

Args:
path: Path to a directory containing Jinja2 templates
"""
if path not in self._template_paths:
self._template_paths.append(path)

def get_builder(self, criteria: 'Criteria') -> Optional['CriteriaSqlBuilder']:
"""Get the SQL builder for a criteria instance.

Args:
criteria: The criteria instance

Returns:
An instance of the registered SQL builder, or None if not found
"""
builder_cls = self._sql_builders.get(type(criteria))
return builder_cls() if builder_cls else None

def get_template(self, criteria: 'Criteria') -> Optional[str]:
"""Get the markdown template name for a criteria instance.

Args:
criteria: The criteria instance

Returns:
The template name, or None if not found
"""
return self._markdown_templates.get(type(criteria))

def get_criteria_class(self, name: str) -> Optional[Type['Criteria']]:
"""Get a registered criteria class by name.

Args:
name: The name of the criteria type

Returns:
The Criteria subclass, or None if not found
"""
return self._criteria_classes.get(name)

@property
def template_paths(self) -> List[Path]:
"""Get all registered template paths."""
return list(self._template_paths)

# Global registry instance
_registry = ExtensionRegistry()

def get_registry() -> ExtensionRegistry:
"""Get the global extension registry instance."""
return _registry
Loading
Loading