diff --git a/biasanalyzer/api.py b/biasanalyzer/api.py index df009ac..e3f73b9 100644 --- a/biasanalyzer/api.py +++ b/biasanalyzer/api.py @@ -49,17 +49,11 @@ def set_root_omop(self): db = self.config['root_omop_cdm_database']['database'] db_url = f"postgresql://{user}:{password}@{host}:{port}/{db}" self.omop_cdm_db = OMOPCDMDatabase(db_url) - self.bias_db = BiasDatabase(':memory:') - # load postgres extension in duckdb bias_db so that cohorts in duckdb can be joined - # with OMOP CDM tables in omop_cdm_db - self.bias_db.load_postgres_extension() - self.bias_db.omop_cdm_db_url = db_url - + self.bias_db = BiasDatabase(':memory:', omop_db_url=db_url) elif db_type == 'duckdb': db_path = self.config['root_omop_cdm_database'].get('database', ":memory:") self.omop_cdm_db = OMOPCDMDatabase(db_path) - self.bias_db = BiasDatabase(db_path) - self.bias_db.omop_cdm_db_url = db_path + self.bias_db = BiasDatabase(':memory:', omop_db_url=db_path) else: notify_users(f"Unsupported database type: {db_type}") diff --git a/biasanalyzer/cohort.py b/biasanalyzer/cohort.py index d6b9698..f9f8865 100644 --- a/biasanalyzer/cohort.py +++ b/biasanalyzer/cohort.py @@ -1,6 +1,4 @@ -from sqlalchemy.exc import SQLAlchemyError from functools import reduce -import duckdb import pandas as pd from datetime import datetime from tqdm.auto import tqdm @@ -122,31 +120,34 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: try: # Execute read-only query from OMOP CDM database result = self.omop_db.execute_query(query) - # Create CohortDefinition - cohort_def = CohortDefinition( - name=cohort_name, - description=description, - created_date=datetime.now().date(), - creation_info=clean_string(query), - created_by=created_by - ) - cohort_def_id = self.bias_db.create_cohort_definition(cohort_def, progress_obj=tqdm) - progress.update(1) - - progress.set_postfix_str(stages[2]) - # Store cohort_definition and cohort data into BiasDatabase - cohort_df = pd.DataFrame(result) - cohort_df['cohort_definition_id'] = cohort_def_id - cohort_df = cohort_df.rename(columns={"person_id": "subject_id"}) - self.bias_db.create_cohort_in_bulk(cohort_df) - progress.update(1) - - tqdm.write(f"Cohort {cohort_name} successfully created.") - return CohortData(cohort_id=cohort_def_id, bias_db=self.bias_db, omop_db=self.omop_db) - except duckdb.Error as e: - notify_users(f"Error executing query: {e}") - return None - except SQLAlchemyError as e: + if result: + # Create CohortDefinition + cohort_def = CohortDefinition( + name=cohort_name, + description=description, + created_date=datetime.now().date(), + creation_info=clean_string(query), + created_by=created_by + ) + cohort_def_id = self.bias_db.create_cohort_definition(cohort_def, progress_obj=tqdm) + progress.update(1) + + progress.set_postfix_str(stages[2]) + # Store cohort_definition and cohort data into BiasDatabase + cohort_df = pd.DataFrame(result) + cohort_df['cohort_definition_id'] = cohort_def_id + cohort_df = cohort_df.rename(columns={"person_id": "subject_id"}) + self.bias_db.create_cohort_in_bulk(cohort_df) + progress.update(1) + + tqdm.write(f"Cohort {cohort_name} successfully created.") + return CohortData(cohort_id=cohort_def_id, bias_db=self.bias_db, omop_db=self.omop_db) + else: + progress.update(2) + notify_users(f"No cohort is created due to empty results being returned from query") + return None + except Exception as e: + progress.update(2) notify_users(f"Error executing query: {e}") if omop_session is not None: omop_session.close() diff --git a/biasanalyzer/cohort_query_builder.py b/biasanalyzer/cohort_query_builder.py index b995d3f..6c3f82a 100644 --- a/biasanalyzer/cohort_query_builder.py +++ b/biasanalyzer/cohort_query_builder.py @@ -71,9 +71,12 @@ def build_query_cohort_creation(self, cohort_config: dict) -> str: temporal_events=temporal_events ) - def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_count: int, vocab: str) -> str: + def build_concept_prevalence_query(self, db_schema: str, omop_alias: str, concept_type: str, cid: int, + filter_count: int, vocab: str) -> str: """ Build a SQL query for concept prevalence statistics for a given domain and cohort. + :param db_schema: BiasDatabase database schema under which all tables are stored. + :param omop_alias: OMOP database alias attached to the BiasDataBase in-memory duckdb :param concept_type: Domain from DOMAIN_MAPPING (e.g., 'condition_occurrence'). :param cid: Cohort definition ID. :param filter_count: Minimum count threshold for concepts with 0 meaning no filtering @@ -93,6 +96,8 @@ def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_cou # Load and render the template template = self.env.get_template("cohort_concept_prevalence_query.sql.j2") return template.render( + db_schema=db_schema, + omop=omop_alias, table_name=DOMAIN_MAPPING[concept_type]["table"], concept_id_column=DOMAIN_MAPPING[concept_type]["concept_id"], start_date_column=DOMAIN_MAPPING[concept_type]["start_date"], diff --git a/biasanalyzer/database.py b/biasanalyzer/database.py index 35499a4..2be453a 100644 --- a/biasanalyzer/database.py +++ b/biasanalyzer/database.py @@ -7,7 +7,8 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import create_engine, text from biasanalyzer.models import CohortDefinition -from biasanalyzer.sql import * +from biasanalyzer.sql import (AGE_DISTRIBUTION_QUERY, GENDER_DISTRIBUTION_QUERY, AGE_STATS_QUERY, + GENDER_STATS_QUERY, RACE_STATS_QUERY, ETHNICITY_STATS_QUERY) from biasanalyzer.utils import build_concept_hierarchy, print_hierarchy, find_roots, notify_users @@ -29,24 +30,41 @@ def __new__(cls, *args, **kwargs): cls._instance._initialize(*args, **kwargs) # Initialize only once return cls._instance - def _initialize(self, db_url): + def _initialize(self, db_url, omop_db_url=None): # by default, duckdb uses in memory database self.conn = duckdb.connect(db_url) - self.omop_cdm_db_url = None + self.schema = "biasanalyzer" + self.omop_alias = 'omop' + self.conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") + self.omop_cdm_db_url = omop_db_url + if omop_db_url is not None: + if omop_db_url.startswith('postgresql://'): + # omop db is postgreSQL + self.load_postgres_extension() + self.conn.execute(f""" + ATTACH '{self.omop_cdm_db_url}' as {self.omop_alias} (TYPE postgres) + """) + elif omop_db_url.endswith('.duckdb'): + self.conn.execute(f""" + ATTACH '{self.omop_cdm_db_url}' as {self.omop_alias} + """) + else: + raise ValueError("Unsupported OMOP database backend") + self._create_cohort_definition_table() self._create_cohort_table() def _create_cohort_definition_table(self): try: - self.conn.execute('CREATE SEQUENCE id_sequence START 1') + self.conn.execute(f'CREATE SEQUENCE {self.schema}.id_sequence START 1') except duckdb.Error as e: if "already exists" in str(e).lower(): notify_users("Sequence already exists, skipping creation.") else: raise - self.conn.execute(''' - CREATE TABLE IF NOT EXISTS cohort_definition ( - id INTEGER DEFAULT nextval('id_sequence'), + self.conn.execute(f''' + CREATE TABLE IF NOT EXISTS {self.schema}.cohort_definition ( + id INTEGER DEFAULT nextval('{self.schema}.id_sequence'), name VARCHAR NOT NULL, description VARCHAR, created_date DATE, @@ -58,18 +76,19 @@ def _create_cohort_definition_table(self): notify_users("Cohort Definition table created.") def _create_cohort_table(self): - self.conn.execute(''' - CREATE TABLE IF NOT EXISTS cohort ( + self.conn.execute(f''' + CREATE TABLE IF NOT EXISTS {self.schema}.cohort ( subject_id BIGINT, cohort_definition_id INTEGER, cohort_start_date DATE, cohort_end_date DATE, - FOREIGN KEY (cohort_definition_id) REFERENCES cohort_definition(id) + FOREIGN KEY (cohort_definition_id) REFERENCES {self.schema}.cohort_definition(id) ) ''') try: - self.conn.execute(''' - CREATE INDEX idx_cohort_dates ON cohort (cohort_definition_id, cohort_start_date, cohort_end_date); + self.conn.execute(f''' + CREATE INDEX idx_cohort_dates ON {self.schema}.cohort (cohort_definition_id, cohort_start_date, + cohort_end_date); ''') except duckdb.Error as e: if "already exists" in str(e).lower(): @@ -79,12 +98,12 @@ def _create_cohort_table(self): notify_users("Cohort table created.") def load_postgres_extension(self): - self.conn.execute("INSTALL postgres_scanner;") - self.conn.execute("LOAD postgres_scanner;") + self.conn.execute("INSTALL postgres;") + self.conn.execute("LOAD postgres;") def create_cohort_definition(self, cohort_definition: CohortDefinition, progress_obj=None): - self.conn.execute(''' - INSERT INTO cohort_definition (name, description, created_date, creation_info, created_by) + self.conn.execute(f''' + INSERT INTO {self.schema}.cohort_definition (name, description, created_date, creation_info, created_by) VALUES (?, ?, ?, ?, ?) ''', ( cohort_definition.name, @@ -97,7 +116,7 @@ def create_cohort_definition(self, cohort_definition: CohortDefinition, progress notify_users("Cohort definition inserted successfully.") # pragma: no cover else: progress_obj.write("Cohort definition inserted successfully.") - self.conn.execute("SELECT id from cohort_definition ORDER BY id DESC LIMIT 1") + self.conn.execute(f"SELECT id from {self.schema}.cohort_definition ORDER BY id DESC LIMIT 1") created_cohort_id = self.conn.fetchone()[0] return created_cohort_id @@ -105,14 +124,14 @@ def create_cohort_definition(self, cohort_definition: CohortDefinition, progress def create_cohort_in_bulk(self, cohort_df: pd.DataFrame): # make duckdb to treat cohort_df dataframe as a virtual table named "cohort_df" self.conn.register("cohort_df", cohort_df) - self.conn.execute(''' - INSERT INTO cohort (subject_id, cohort_definition_id, cohort_start_date, cohort_end_date) + self.conn.execute(f''' + INSERT INTO {self.schema}.cohort (subject_id, cohort_definition_id, cohort_start_date, cohort_end_date) SELECT subject_id, cohort_definition_id, cohort_start_date, cohort_end_date FROM cohort_df ''') def get_cohort_definition(self, cohort_definition_id): results = self.conn.execute(f''' - SELECT id, name, description, created_date, creation_info, created_by FROM cohort_definition + SELECT id, name, description, created_date, creation_info, created_by FROM {self.schema}.cohort_definition WHERE id = {cohort_definition_id} ''') headers = [desc[0] for desc in results.description] @@ -124,27 +143,13 @@ def get_cohort_definition(self, cohort_definition_id): def get_cohort(self, cohort_definition_id): results = self.conn.execute(f''' - SELECT subject_id, cohort_definition_id, cohort_start_date, cohort_end_date FROM cohort + SELECT subject_id, cohort_definition_id, cohort_start_date, cohort_end_date FROM {self.schema}.cohort WHERE cohort_definition_id = {cohort_definition_id} ''') headers = [desc[0] for desc in results.description] rows = results.fetchall() return [dict(zip(headers, row)) for row in rows] - def _create_omop_table(self, table_name): - if self.omop_cdm_db_url is not None and not self.omop_cdm_db_url.endswith('duckdb'): - # need to create person table from OMOP CDM postgreSQL database - self.conn.execute(f""" - CREATE TABLE IF NOT EXISTS {table_name} AS - SELECT * from postgres_scan('{self.omop_cdm_db_url}', 'public', {table_name}) - """) - return True # success - elif self.omop_cdm_db_url is None: - return False - else: # omop table is already included in duckdb - return True - - def _execute_query(self, query_str): results = self.conn.execute(query_str) @@ -163,15 +168,11 @@ def get_cohort_basic_stats(self, cohort_definition_id: int, variable=''): """ try: if variable: - if self._create_omop_table('person'): - query_str = self.__class__.stats_queries.get(variable) - if query_str is None: - raise ValueError(f"Statistics for variable '{variable}' is not available. " - f"Valid variables are {self.__class__.stats_queries.keys()}") - stats_query = query_str.format(cohort_definition_id) - else: - notify_users("Cannot connect to the OMOP database to query person table") - return None + query_str = self.__class__.stats_queries.get(variable) + if query_str is None: + raise ValueError(f"Statistics for variable '{variable}' is not available. " + f"Valid variables are {self.__class__.stats_queries.keys()}") + stats_query = query_str.format(ba_schema=self.schema, omop=self.omop_alias, cohort_definition_id=cohort_definition_id) else: # Query the cohort data to get basic statistics stats_query = f''' @@ -182,7 +183,7 @@ def get_cohort_basic_stats(self, cohort_definition_id: int, variable=''): cohort_end_date, cohort_end_date - cohort_start_date AS duration_days FROM - cohort + {self.schema}.cohort WHERE cohort_definition_id = {cohort_definition_id} ) SELECT @@ -213,16 +214,13 @@ def get_cohort_distributions(self, cohort_definition_id: int, variable: str): Get distribution statistics for a cohort from the cohort table. """ try: - if self._create_omop_table('person'): - query_str = self.__class__.distribution_queries.get(variable) - if not query_str: - raise ValueError(f"Distribution for variable '{variable}' is not available. " - f"Valid variables are {self.__class__.distribution_queries.keys()}") - query = query_str.format(cohort_definition_id) - return self._execute_query(query) - else: - notify_users("Cannot connect to the OMOP database to query person table") - return None + query_str = self.__class__.distribution_queries.get(variable) + if not query_str: + raise ValueError(f"Distribution for variable '{variable}' is not available. " + f"Valid variables are {self.__class__.distribution_queries.keys()}") + query = query_str.format(ba_schema=self.schema, omop=self.omop_alias, + cohort_definition_id=cohort_definition_id) + return self._execute_query(query) except Exception as e: notify_users(f"Error computing cohort {variable} distributions: {e}", level='error') return None @@ -236,40 +234,35 @@ def get_cohort_concept_stats(self, cohort_definition_id: int, qry_builder, concept_stats = {} try: - if (self._create_omop_table('concept') and self._create_omop_table('concept_ancestor') - and self._create_omop_table(concept_type)): - # validate input vocab if it is not None - if vocab is not None: - valid_vocabs = self._execute_query("SELECT distinct vocabulary_id FROM concept") - valid_vocab_ids = [row['vocabulary_id'] for row in valid_vocabs] - if vocab not in valid_vocab_ids: - err_msg = (f"input {vocab} is not a valid vocabulary in OMOP. " - f"Supported vocabulary ids are: {valid_vocab_ids}") - notify_users(err_msg, level='error') - raise ValueError(err_msg) - - query = qry_builder.build_concept_prevalence_query(concept_type, cohort_definition_id, - filter_count, vocab) - concept_stats[concept_type] = self._execute_query(query) - cs_df = pd.DataFrame(concept_stats[concept_type]) - # Combine concept_name and prevalence into a "details" column - cs_df["details"] = cs_df.apply( - lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, " - f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})", axis=1) - - if print_concept_hierarchy: - filtered_cs_df = cs_df[cs_df['ancestor_concept_id'] != cs_df['descendant_concept_id']] - roots = find_roots(filtered_cs_df) - hierarchy = build_concept_hierarchy(filtered_cs_df) - notify_users(f'cohort concept hierarchy for {concept_type} with root concept ids {roots}:') - for root in roots: - root_detail = cs_df[(cs_df['ancestor_concept_id'] == root) - & (cs_df['descendant_concept_id'] == root)]['details'].iloc[0] - print_hierarchy(hierarchy, parent=root, level=0, parent_details=root_detail) - return concept_stats - else: - err_msg = "Cannot connect to the OMOP database to query concept table" - raise ValueError(err_msg) + # validate input vocab if it is not None + if vocab is not None: + valid_vocabs = self._execute_query(f"SELECT distinct vocabulary_id FROM {self.omop_alias}.concept") + valid_vocab_ids = [row['vocabulary_id'] for row in valid_vocabs] + if vocab not in valid_vocab_ids: + err_msg = (f"input {vocab} is not a valid vocabulary in OMOP. " + f"Supported vocabulary ids are: {valid_vocab_ids}") + notify_users(err_msg, level='error') + raise ValueError(err_msg) + + query = qry_builder.build_concept_prevalence_query(self.schema, self.omop_alias, concept_type, + cohort_definition_id, filter_count, vocab) + concept_stats[concept_type] = self._execute_query(query) + cs_df = pd.DataFrame(concept_stats[concept_type]) + # Combine concept_name and prevalence into a "details" column + cs_df["details"] = cs_df.apply( + lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, " + f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})", axis=1) + + if print_concept_hierarchy: + filtered_cs_df = cs_df[cs_df['ancestor_concept_id'] != cs_df['descendant_concept_id']] + roots = find_roots(filtered_cs_df) + hierarchy = build_concept_hierarchy(filtered_cs_df) + notify_users(f'cohort concept hierarchy for {concept_type} with root concept ids {roots}:') + for root in roots: + root_detail = cs_df[(cs_df['ancestor_concept_id'] == root) + & (cs_df['descendant_concept_id'] == root)]['details'].iloc[0] + print_hierarchy(hierarchy, parent=root, level=0, parent_details=root_detail) + return concept_stats except Exception as e: err_msg = f"Error computing cohort concept stats: {e}" raise ValueError(err_msg) diff --git a/biasanalyzer/module_test.py b/biasanalyzer/module_test.py index 9fad973..cacc79d 100644 --- a/biasanalyzer/module_test.py +++ b/biasanalyzer/module_test.py @@ -9,10 +9,12 @@ def cohort_creation_template_test(bias_obj): cohort_data = bias_obj.create_cohort('COVID-19 patients', 'COVID-19 patients', os.path.join(os.path.dirname(__file__), '..', 'tests', 'assets', 'cohort_creation', - # 'extras', + 'extras', + 'diabetes_example2', + 'cohort_creation_config_baseline_example2.yaml'), # 'covid_example3', # 'cohort_creation_config_baseline_example3.yaml'), - 'test_cohort_creation_condition_occurrence_config_study.yaml'), + # 'test_cohort_creation_condition_occurrence_config_study.yaml'), 'system') if cohort_data: md = cohort_data.metadata @@ -24,19 +26,19 @@ def cohort_creation_template_test(bias_obj): print(f'the cohort race stats: {cohort_data.get_stats("race")}') print(f'the cohort ethnicity stats: {cohort_data.get_stats("ethnicity")}') print(f'the cohort age distributions: {cohort_data.get_distributions("age")}') + print(f'the cohort gender distributions: {cohort_data.get_distributions("gender")}') compare_stats = bias_obj.compare_cohorts(cohort_data.metadata['id'], cohort_data.metadata['id']) print(f'compare_stats: {compare_stats}') return def condition_cohort_test(bias_obj): - baseline_cohort_query = ('SELECT c.person_id, c.condition_start_date as cohort_start_date, ' - 'c.condition_end_date as cohort_end_date ' + baseline_cohort_query = ('SELECT c.person_id, MIN(c.condition_start_date) as cohort_start_date, ' + 'MAX(c.condition_end_date) as cohort_end_date ' 'FROM condition_occurrence c JOIN ' 'person p ON c.person_id = p.person_id ' - 'WHERE c.condition_concept_id = 37311061 ' - 'AND p.gender_concept_id = 8532 AND p.year_of_birth > 2000') - cohort_data = bias_obj.create_cohort('COVID-19 patients', 'COVID-19 patients', + 'WHERE c.condition_concept_id = 201826 GROUP BY c.person_id') + cohort_data = bias_obj.create_cohort('Diabetics patients', 'Diabetics patients', baseline_cohort_query, 'system') if cohort_data: md = cohort_data.metadata @@ -51,8 +53,8 @@ def condition_cohort_test(bias_obj): t1 = time.time() _, cohort_concept_hierarchy = cohort_data.get_concept_stats(concept_type='condition_occurrence', filter_count=5000) - concept_node = cohort_concept_hierarchy.get_node(concept_id=37311061) - print(f'concept_node 37311061 metric: {concept_node.get_metrics(md["id"])}') + concept_node = cohort_concept_hierarchy.get_node(concept_id=201826) + print(f'concept_node 201826 metric: {concept_node.get_metrics(md["id"])}') # Print the root node root_nodes = cohort_concept_hierarchy.get_root_nodes() @@ -62,7 +64,7 @@ def condition_cohort_test(bias_obj): print(f"Root: {root}", flush=True) print(f"Leaves: {leaves}", flush=True) for node in cohort_concept_hierarchy.iter_nodes(root_nodes[0].id, serialization=True): - print(node) + print(node) hier_dict = cohort_concept_hierarchy.to_dict() pprint.pprint(hier_dict, indent=2) @@ -107,6 +109,7 @@ def concept_test(bias_obj): pd.set_option('display.width', 1000) try: bias = BIAS() + # bias.set_config(os.path.join(os.path.dirname(__file__), '..', 'config_duckdb.yaml')) bias.set_config(os.path.join(os.path.dirname(__file__), '..', 'config.yaml')) bias.set_root_omop() diff --git a/biasanalyzer/sql.py b/biasanalyzer/sql.py index 10bbef7..05ac55d 100644 --- a/biasanalyzer/sql.py +++ b/biasanalyzer/sql.py @@ -10,8 +10,8 @@ CURRENT_DATE ) ) - p.year_of_birth AS age - FROM cohort c JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} ), -- Define age bins manually using SELECT statements and UNION ALL Age_Bins AS ( @@ -63,9 +63,9 @@ ELSE 'other' END AS gender, p.person_id - FROM cohort c - JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c + JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} ) cd ON gc.gender = cd.gender GROUP BY gc.gender ) @@ -88,8 +88,8 @@ CURRENT_DATE ) ) - p.year_of_birth AS age - FROM cohort c JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} ) -- Calculate age distribution statistics SELECT @@ -111,8 +111,8 @@ END AS gender, COUNT(*) AS gender_count, ROUND(COUNT(*) / SUM(COUNT(*)) OVER (), 2) as probability - FROM cohort c JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.gender_concept_id ''' @@ -128,8 +128,8 @@ END AS race, COUNT(*) AS race_count, ROUND(COUNT(*) / SUM(COUNT(*)) OVER (), 2) AS probability - FROM cohort c JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.race_concept_id ''' @@ -142,7 +142,7 @@ END AS ethnicity, COUNT(*) AS ethnicity_count, ROUND(COUNT(*) / SUM(COUNT(*)) OVER (), 2) AS probability - FROM cohort c JOIN person p ON c.subject_id = p.person_id - WHERE c.cohort_definition_id = {} + FROM {ba_schema}.cohort c JOIN {omop}.person p ON c.subject_id = p.person_id + WHERE c.cohort_definition_id = {cohort_definition_id} GROUP BY p.ethnicity_concept_id ''' diff --git a/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 b/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 index b832cd8..425e130 100644 --- a/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 +++ b/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 @@ -4,9 +4,9 @@ WITH cohort_events AS ( e.{{ concept_id_column }} AS concept_id, ct.subject_id FROM - cohort ct + {{ db_schema }}.cohort ct JOIN - {{ table_name }} e ON ct.subject_id = e.person_id + {{ omop }}.{{ table_name }} e ON ct.subject_id = e.person_id AND e.{{ start_date_column }} >= ct.cohort_start_date AND (ct.cohort_end_date IS NULL OR e.{{ start_date_column }} <= ct.cohort_end_date) WHERE ct.cohort_definition_id = {{ cid }} @@ -19,9 +19,9 @@ aggregated_counts AS ( FROM cohort_events ce JOIN - concept_ancestor ca ON ce.concept_id = ca.descendant_concept_id + {{ omop }}.concept_ancestor ca ON ce.concept_id = ca.descendant_concept_id JOIN - concept anc ON ca.ancestor_concept_id = anc.concept_id + {{ omop }}.concept anc ON ca.ancestor_concept_id = anc.concept_id WHERE anc.vocabulary_id = '{{ vocab }}' AND ca.min_levels_of_separation >= 0 @@ -34,7 +34,7 @@ concept_hierarchy AS ( ca.ancestor_concept_id, ca.descendant_concept_id FROM - concept_ancestor ca + {{ omop }}.concept_ancestor ca WHERE ca.min_levels_of_separation <= 1 AND ca.descendant_concept_id IN (SELECT concept_id FROM aggregated_counts WHERE count_in_cohort > {{ filter_count }}) @@ -45,7 +45,7 @@ SELECT DISTINCT c.concept_name, c.concept_code, ac.count_in_cohort, - (ac.count_in_cohort * 1.0 / (SELECT COUNT(DISTINCT subject_id) FROM cohort WHERE cohort_definition_id = {{ cid }})) AS prevalence, + (ac.count_in_cohort * 1.0 / (SELECT COUNT(DISTINCT subject_id) FROM {{ db_schema }}.cohort WHERE cohort_definition_id = {{ cid }})) AS prevalence, ch.ancestor_concept_id, ch.descendant_concept_id FROM @@ -53,7 +53,7 @@ FROM JOIN concept_hierarchy ch ON ac.concept_id = ch.descendant_concept_id JOIN - concept c ON ac.concept_id = c.concept_id + {{ omop }}.concept c ON ac.concept_id = c.concept_id WHERE ac.count_in_cohort > {{ filter_count }} ORDER BY diff --git a/config_duckdb.yaml b/config_duckdb.yaml new file mode 100644 index 0000000..6421b21 --- /dev/null +++ b/config_duckdb.yaml @@ -0,0 +1,7 @@ +root_omop_cdm_database: + database_type: duckdb + username: dummy + password: dummy + hostname: dummy + database: /home/hongyi/Downloads/synpuf_100k_omop_54.duckdb + port: 5432 diff --git a/tests/query_based/test_cohort_creation.py b/tests/query_based/test_cohort_creation.py index faafa94..8f82190 100644 --- a/tests/query_based/test_cohort_creation.py +++ b/tests/query_based/test_cohort_creation.py @@ -186,305 +186,6 @@ def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db): assert_equal(len(patient_ids), 2) assert_equal(patient_ids, {108, 110}) -def test_cohort_creation_mixed_domains(test_db): - """ - Test cohort creation with mixed domains (condition, drug, visit, procedure). - """ - bias = test_db - cohort = bias.create_cohort( - "Female diabetes patients born between 1970 and 2000", - "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " - "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " - "with patients born after 1995 and with cardiac surgery excluded", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_config.yaml'), - "test_user" - ) - - # Test cohort object and methods - assert cohort is not None, "Cohort creation failed" - print(f'metadata: {cohort.metadata}') - assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" - stats = cohort.get_stats() - assert stats is not None, "Created cohort's stats is None" - assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) - assert_equal(len(patient_ids), 2) - assert_equal(patient_ids, {1, 2}) - start_dates = [item['cohort_start_date'] for item in cohort.data] - assert_equal(len(start_dates), 2) - assert_equal(start_dates, [datetime.date(2020, 6, 1), datetime.date(2020, 6, 1)]) - end_dates = [item['cohort_end_date'] for item in cohort.data] - assert_equal(len(end_dates), 2) - assert_equal(end_dates, [datetime.date(2020, 6, 20), datetime.date(2020, 6, 20)]) - -def test_cohort_comparison(test_db): - bias = test_db - cohort_base = bias.create_cohort( - "COVID-19 patient", - "Cohort of young female patients", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), - "test_user" - ) - cohort_study = bias.create_cohort( - "Female diabetes patients born between 1970 and 2000", - "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " - "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " - "with patients born after 1995 and with cardiac surgery excluded", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_config.yaml'), - "test_user" - ) - results = bias.compare_cohorts(cohort_base.cohort_id, cohort_study.cohort_id) - assert {'gender_hellinger_distance': 0.0} in results - assert any('age_hellinger_distance' in r for r in results) - -def test_cohort_invalid(caplog, test_db): - caplog.clear() - with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - 'invalid_yaml_file.yml', - 'invalid_created_by') - assert 'cohort creation configuration file does not exist' in caplog.text - assert invalid_cohort is None - - caplog.clear() - with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - os.path.join(os.path.dirname(__file__), '..', 'assets', 'config', - 'test_config.yaml'), 'invalid_created_by') - assert 'configuration yaml file is not valid' in caplog.text - assert invalid_cohort is None - - with caplog.at_level(logging.INFO): - invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', - 'INVALID SQL QUERY STRING', - 'invalid_created_by') - assert 'Error executing query:' in caplog.text - assert invalid_cohort is None - -def test_create_cohort_sqlalchemy_error(monkeypatch, fresh_bias_obj): - # Mock omop_db methods - class MockOmopDB: - def get_session(self): - return self # not used after error - def execute_query(self, query): - raise SQLAlchemyError("Mocked SQLAlchemy error") - def close(self): - pass - - class MockBiasDB: - def create_cohort_definition(self, *args, **kwargs): - pass - def create_cohort_in_bulk(self, *args, **kwargs): - pass - def close(self): - pass - - fresh_bias_obj.omop_cdm_db = MockOmopDB() - fresh_bias_obj.bias_db = MockBiasDB() - - result = fresh_bias_obj.create_cohort("test", "desc", "SELECT * FROM person", "test_user") - - assert result is None - - -import os -import datetime -import logging -import pytest -from sqlalchemy.exc import SQLAlchemyError -from numpy.ma.testutils import assert_equal -from biasanalyzer.models import DemographicsCriteria, TemporalEvent, TemporalEventGroup - - -def test_cohort_yaml_validation(test_db): - invalid_data = { - "gender": "female", - "min_birth_year": 2000, - "max_birth_year": 1999 # Invalid: less than min_birth_year - } - with pytest.raises(ValueError): - DemographicsCriteria(**invalid_data) - - invalid_data = { - "event_type": "date", - "event_concept_id": "dummy" - } - # validate date event_type must have a timestamp field - with pytest.raises(ValueError): - TemporalEvent(**invalid_data) - - invalid_data = { - "operator": "BEFORE", - "events": [ - {'event_type': 'condition_occurrence', - 'event_concept_id': 201826}, - {'event_type': 'drug_exposure', - 'event_concept_id': 4285892}, - ], - "interval": [100, 50] - } - # validate interval start must be smaller than interval end - with pytest.raises(ValueError): - TemporalEventGroup(**invalid_data) - - # validate interval must be either a list of 2 integers or a None - invalid_data["interval"] = [123] - with pytest.raises(ValueError): - TemporalEventGroup(**invalid_data) - - # validate NOT operator cannot have more than one event - invalid_data["operator"] = "NOT" - with pytest.raises(ValueError): - TemporalEventGroup(**invalid_data) - - # validate BEFORE operator must have two events - invalid_data["operator"] = "BEFORE" - del invalid_data["events"][1] - with pytest.raises(ValueError): - TemporalEventGroup(**invalid_data) - - -def test_cohort_creation_baseline(caplog, test_db): - bias = test_db - cohort = bias.create_cohort( - "COVID-19 patient", - "Cohort of young female patients", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), - "test_user" - ) - - # Test cohort object and methods - assert cohort is not None, "Cohort creation failed" - cohort_id = cohort.cohort_id - assert bias.bias_db.get_cohort_definition(cohort_id)['name'] == "COVID-19 patient" - assert bias.bias_db.get_cohort_definition(cohort_id + 1) == {} - assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" - assert cohort.data is not None, "Cohort creation wrongly returned None data" - caplog.clear() - with caplog.at_level(logging.ERROR): - cohort.get_distributions('ethnicity') - assert "Distribution for variable 'ethnicity' is not available" in caplog.text - - assert len(cohort.get_distributions('age')) == 10, "Cohort get_distribution('age') does not return 10 age_bin items" - assert len(cohort.get_distributions('gender')) == 3, ("Cohort get_distribution('gender') does not return " - "3 gender_bin items") - - patient_ids = set([item['subject_id'] for item in cohort.data]) - assert_equal(len(patient_ids), 5) - assert_equal(patient_ids, {106, 108, 110, 111, 112}) - # select two patients to check for cohort_start_date and cohort_end_date automatically computed - patient_106 = next(item for item in cohort.data if item['subject_id'] == 106) - patient_108 = next(item for item in cohort.data if item['subject_id'] == 108) - - # Replace dates with actual values from your test data - assert_equal(patient_106['cohort_start_date'], datetime.date(2023, 3, 1), - "Incorrect cohort_start_date for patient 106") - assert_equal(patient_106['cohort_end_date'], datetime.date(2023, 3, 15), - "Incorrect cohort_end_date for patient 106") - assert_equal(patient_108['cohort_start_date'], datetime.date(2020, 4, 10), - "Incorrect cohort_start_date for patient 108") - assert_equal(patient_108['cohort_end_date'], datetime.date(2020, 4, 27), - "Incorrect cohort_end_date for patient 108") - - -def test_cohort_creation_study(test_db): - bias = test_db - cohort = bias.create_cohort( - "COVID-19 patient", - "Cohort of young female patients with COVID-19", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_study.yaml'), - "test_user" - ) - # Test cohort object and methods - assert cohort is not None, "Cohort creation failed" - assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" - assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - assert_equal(len(patient_ids), 4) - assert_equal(patient_ids, {108, 110, 111, 112}) - - -def test_cohort_creation_study2(caplog, test_db): - bias = test_db - caplog.clear() - with caplog.at_level(logging.INFO): - cohort = bias.create_cohort( - "COVID-19 patient", - "Cohort of young female patients with no COVID-19", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config_study2.yaml'), - "test_user", - delay=1 - ) - assert 'Simulating long-running task' in caplog.text - # Test cohort object and methods - assert cohort is not None, "Cohort creation failed" - assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" - assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - assert_equal(len(patient_ids), 1) - assert_equal(patient_ids, {106}) - - -def test_cohort_creation_all(caplog, test_db): - bias = test_db - cohort = bias.create_cohort( - "COVID-19 patient", - "Cohort of young female patients with COVID-19 who have the condition with difficulty breathing 2 to 5 days " - "before a COVID diagnosis 3/15/20-12/11/20 AND have at least one emergency room visit or at least " - "two inpatient visits", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config.yaml'), - "test_user" - ) - # Test cohort object and methods - assert cohort is not None, "Cohort creation failed" - assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" - assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" - stats = cohort.get_stats() - assert stats is not None, "Created cohort's stats is None" - gender_stats = cohort.get_stats(variable='gender') - assert gender_stats is not None, "Created cohort's gender stats is None" - caplog.clear() - with caplog.at_level(logging.ERROR): - cohort.get_stats(variable='address') - assert 'is not available' in caplog.text - assert gender_stats is not None, "Created cohort's gender stats is None" - assert cohort.data is not None, "Cohort creation wrongly returned None data" - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) - assert_equal(len(patient_ids), 2) - assert_equal(patient_ids, {108, 110}) - - -def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db): - bias = test_db - cohort = bias.create_cohort( - "Patients with COVID or other emergency conditions", - "Cohort of young female patients who either have COVID-19 with difficulty breathing 2 to 5 days " - "before a COVID diagnosis 3/15/20-12/11/20 OR have at least one emergency room visit or at least " - "two inpatient visits", - os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', - 'test_cohort_creation_multiple_temporal_groups_without_operator.yaml'), - "test_user" - ) - # Test cohort object and methods - patient_ids = set([item['subject_id'] for item in cohort.data]) - print(f'patient_ids: {patient_ids}', flush=True) - assert_equal(len(patient_ids), 2) - assert_equal(patient_ids, {108, 110}) - - def test_cohort_creation_mixed_domains(test_db): """ Test cohort creation with mixed domains (condition, drug, visit, procedure). @@ -523,7 +224,6 @@ def test_cohort_creation_mixed_domains(test_db): datetime.date(2020, 6, 20), datetime.date(2018, 1, 20)]) - def test_cohort_comparison(test_db): bias = test_db cohort_base = bias.create_cohort( @@ -546,7 +246,6 @@ def test_cohort_comparison(test_db): assert {'gender_hellinger_distance': 0.0} in results assert any('age_hellinger_distance' in r for r in results) - def test_cohort_invalid(caplog, test_db): caplog.clear() with caplog.at_level(logging.INFO): @@ -571,26 +270,21 @@ def test_cohort_invalid(caplog, test_db): assert 'Error executing query:' in caplog.text assert invalid_cohort is None - def test_create_cohort_sqlalchemy_error(monkeypatch, fresh_bias_obj): # Mock omop_db methods class MockOmopDB: def get_session(self): return self # not used after error - def execute_query(self, query): raise SQLAlchemyError("Mocked SQLAlchemy error") - def close(self): pass class MockBiasDB: def create_cohort_definition(self, *args, **kwargs): pass - def create_cohort_in_bulk(self, *args, **kwargs): pass - def close(self): pass @@ -601,7 +295,6 @@ def close(self): assert result is None - def test_cohort_creation_negative_instance(test_db): """ Test cohort creation with negative event_instance (last occurrence of a condition). diff --git a/tests/test_biasanalyzer_api.py b/tests/test_biasanalyzer_api.py index d245017..a47a3a5 100644 --- a/tests/test_biasanalyzer_api.py +++ b/tests/test_biasanalyzer_api.py @@ -68,8 +68,9 @@ def close(self): # --- Step 3: Mock BiasDatabase and its methods --- class MockBiasDatabase: - def __init__(self, path): - self.omop_cdm_db_url = None + def __init__(self, path, omop_db_url=None): + self.omop_cdm_db_url = "postgresql://testuser:testpass@localhost:5432/testdb" + self.omop_db_url = "postgresql://testuser:testpass@localhost:5432/testdb" def load_postgres_extension(self): pass @@ -123,7 +124,7 @@ def test_cohorts_union_concept_stats(test_db): # Show what cohorts exist in the test DB and print cohorts and stats so we know what raw data looks like cohorts_df = test_db.bias_db.conn.execute(""" SELECT cohort_definition_id, COUNT(*) as n_subjects - FROM cohort + FROM biasanalyzer.cohort WHERE cohort_definition_id = 1 or cohort_definition_id = 2 GROUP BY cohort_definition_id ORDER BY cohort_definition_id @@ -135,8 +136,8 @@ def test_cohorts_union_concept_stats(test_db): SELECT c.cohort_definition_id, co.condition_concept_id, COUNT(*) as n - FROM cohort c - JOIN condition_occurrence co + FROM biasanalyzer.cohort c + JOIN omop.condition_occurrence co ON c.subject_id = co.person_id WHERE c.cohort_definition_id = 1 or c.cohort_definition_id = 2 GROUP BY c.cohort_definition_id, co.condition_concept_id diff --git a/tests/test_database.py b/tests/test_database.py index 9944965..8fb1ed5 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,38 +1,10 @@ import duckdb import pytest -import logging +from unittest.mock import Mock from biasanalyzer.cohort_query_builder import CohortQueryBuilder from biasanalyzer.database import BiasDatabase -def test_create_omop_table_postgres(monkeypatch): - # Set up tracking dict - called = {"executed": False, "query": None} - - # Patch before BiasDatabase instance is created - def mock_execute(self, query): - called["executed"] = True - called["query"] = query - return None - - # Monkeypatch at class level first - monkeypatch.setattr(duckdb.DuckDBPyConnection, "execute", mock_execute) - - # Now create the instance (so it uses the patched class method) - BiasDatabase._instance = None - db = BiasDatabase(":memory:") - db.omop_cdm_db_url = None - result = db._create_omop_table("person") - assert result is False - - db.omop_cdm_db_url = "postgresql://user:pass@localhost:5432/mydb" - - result = db._create_omop_table("person") - - assert result is True - assert called["executed"] is True - assert "postgres_scan" in called["query"] - def test_load_postgres_extension_executes_twice(monkeypatch): # Reset singleton to get a clean instance BiasDatabase._instance = None @@ -52,8 +24,36 @@ def execute(self, query): # Assert that execute() was called twice assert len(calls) == 2 - assert "INSTALL postgres_scanner" in calls[0] - assert "LOAD postgres_scanner" in calls[1] + assert "INSTALL postgres" in calls[0] + assert "LOAD postgres" in calls[1] + +def test_bias_db_postgres_omop_db_url(monkeypatch): + # Reset singleton to get a clean instance + BiasDatabase._instance = None + calls = [] + + class MockConn: + def execute(self, query): + calls.append(query) + return self + + def close(self): + pass + + # Mock duckdb.connect to return our MockConn + mock_connect = Mock(return_value=MockConn()) + monkeypatch.setattr("duckdb.connect", mock_connect) + db = BiasDatabase(":memory:", omop_db_url="postgresql://testuser:testpass@localhost:5432/testdb") + + assert len(calls) >= 3 + assert any("INSTALL postgres" in call for call in calls), "INSTALL postgres must be run at BiasDatabase init" + assert any("LOAD postgres" in call for call in calls), "LOAD postgres must be run at BiasDatabase init" + assert any("ATTACH" in call for call in calls), "ATTACH must be run at BiasDatabase init" + +def test_bias_db_invalid_omop_db_url(): + BiasDatabase._instance = None + with pytest.raises(ValueError, match='Unsupported OMOP database backend'): + db = BiasDatabase(":memory:", omop_db_url='dummy_invalid_url') def test_create_cohort_definition_table_error_on_sequence(): BiasDatabase._instance = None