From e8f678526153f2dab1633a09d0efd4c4dce22d5d Mon Sep 17 00:00:00 2001 From: Raman Date: Thu, 10 Oct 2024 17:36:40 +0200 Subject: [PATCH 1/4] implemented computation and gathering of domain module ids for each fold-model --- README.md | 2 ++ .../screening/gather_classifier_checkpoints.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+)
diff --git a/README.md b/README.md index 6c5c282..a10a0b7 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,8 @@ pip install . ----------------------------------------- ## Quick start +To predict using the model based on the TPS language model only, put the sequences of interest into a `.fasta` file and run + ```bash cd TerpeneMiner conda activate terpene_miner
diff --git a/terpeneminer/src/screening/gather_classifier_checkpoints.py b/terpeneminer/src/screening/gather_classifier_checkpoints.py index e4deb49..ef0530f 100644 --- a/terpeneminer/src/screening/gather_classifier_checkpoints.py +++ b/terpeneminer/src/screening/gather_classifier_checkpoints.py @@ -88,6 +88,22 @@ def parse_args() -> argparse.Namespace: with open(fold_class_latest_path / f"model_fold_{fold_i}.pkl", "rb") as file: model = pickle.load(file) model.classifier.classes_ = model.config.class_names + if hasattr(model, "allowed_feat_indices"): + with open("data/clustering__domain_dist_based_features.pkl", "rb") as file: + domain_module_id_2_dist_matrix_index = pickle.load(file)[-1] + + with open("data/domains_subset.pkl", "rb") as file: + feat_indices_subset = pickle.load(file)[-1] + domain_module_id_2_dist_matrix_index_subset = {domain_id: [i for i in indices if i in feat_indices_subset] + for domain_id, indices in + domain_module_id_2_dist_matrix_index.items() + if len([i for i in indices if i in feat_indices_subset])} + feat_idx_2_module_id = {} + for module_id, feat_indices in domain_module_id_2_dist_matrix_index_subset.items(): + for feat_idx in feat_indices: + feat_idx_2_module_id[feat_idx] = module_id + order_of_domain_modules = [feat_idx_2_module_id[feat_i] for feat_i in model.allowed_feat_indices] + model.classifier.order_of_domain_modules = order_of_domain_modules classifiers.append(model.classifier) with open(args.output_path, "wb") as file_writer:
From b0b9fe0492290467006148be9c7851d523a4ba90 Mon Sep 17 00:00:00 2001 From: Raman Date: Sat, 26 Oct 2024 18:51:54 +0200 Subject: [PATCH 2/4] bulk commit of work on a backend app to deploy terpeneminer --- README.md | 19 +- app.py | 282 ++++++++++++++++++ scripts/setup_env.sh | 2 +- .../Blastp/against_wetlab_data/config.yaml | 29 -- .../against_wetlab_data/config.yaml | 30 -- .../with_minor_reactions/config.yaml | 0 .../with_minor_reactions/config.yaml | 63 ---- .../config.yaml | 63 ---- .../HMM/against_wetlab_data/config.yaml | 30 -- .../config.yaml | 0 .../config.yaml | 0 .../config.yaml | 0 .../esm-1v_with_minor_reactions/config.yaml | 0 .../esm-2_with_minor_reactions/config.yaml | 0 .../main_config.yaml | 0 .../config.yaml | 0 .../config.yaml | 0 .../config.yaml | 0 terpeneminer/src/evaluation/plotting.py | 1 + terpeneminer/src/models/__init__.py | 2 +- terpeneminer/src/models/baselines/__init__.py | 2 +- .../gather_classifier_checkpoints.py | 18 +- .../comparing_to_known_domains.py | 206 +++++++++++++ .../comparing_to_known_domains_foldseek.py | 75 +++++ .../structure_processing/domain_detections.py | 128 ++++---- .../predict_domain_types.py | 89 ++++++ .../structural_algorithms.py | 21 +- .../train_domain_type_classifiers.py | 91 ++++++ 28 files changed, 863 insertions(+), 288 deletions(-) create mode 100644 app.py delete mode
100644 terpeneminer/configs/Blastp/against_wetlab_data/config.yaml delete mode 100644 terpeneminer/configs/CLEAN.ignore/against_wetlab_data/config.yaml rename terpeneminer/configs/{CLEAN.ignore => CLEAN}/with_minor_reactions/config.yaml (100%) delete mode 100644 terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions/config.yaml delete mode 100644 terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions_global_tuning/config.yaml delete mode 100644 terpeneminer/configs/HMM/against_wetlab_data/config.yaml rename terpeneminer/configs/PlmRandomForest/{tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore => tps_esm-1v-subseq_with_minor_reactions_global_tuning}/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/ankh_base_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/ankh_large_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/esm-1v_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/esm-2_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/main_config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/tps_ankh_base_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/tps_esm-1v-subseq_with_minor_reactions/config.yaml (100%) rename terpeneminer/configs/{PlmXgb.ignore => PlmXgb}/tps_esm-1v_with_minor_reactions/config.yaml (100%) create mode 100644 terpeneminer/src/structure_processing/comparing_to_known_domains.py create mode 100644 terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py create mode 100644 terpeneminer/src/structure_processing/predict_domain_types.py create mode 100644 terpeneminer/src/structure_processing/train_domain_type_classifiers.py diff --git a/README.md b/README.md index a10a0b7..efb75b5 100644 --- a/README.md +++ b/README.md @@ -278,9 +278,12 @@ cd TerpeneMiner conda activate terpene_miner python -m terpeneminer.src.structure_processing.domain_detections \ --needed-proteins-csv-path "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" \ + --csv-id-column "Uniprot ID" \ --input-directory-with-structures "data/alphafold_structs/" \ + --is-bfactor-confidence \ + --recompute-existing-secondary-structure-residues \ --n-jobs 16 --detections-output-path "data/filename_2_detected_domains_completed_confident.pkl" \ - --store-domains --domains-output-path "data/detected domains" > outputs/logs/tps_structures_segmentation.log 2>&1 + --store-domains --domains-output-path "data/detected_domains" > outputs/logs/tps_structures_segmentation.log 2>&1 ``` #### 2 - Pairwise comparison of the detected domains @@ -293,8 +296,9 @@ cd TerpeneMiner conda activate terpene_miner python -m terpeneminer.src.structure_processing.compute_pairwise_similarities_of_domains \ --name all \ - --n-jobs 64 \ - --precomputed-scores-path "data/precomputed_tmscores.pkl" > outputs/logs/pairwise_comparisons.log 2>&1 + --needed-proteins-csv-path "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" \ + --csv-id-column "Uniprot ID" \ + --n-jobs 64 > outputs/logs/pairwise_comparisons.log 2>&1 ``` Note the `--precomputed-scores-path` argument. It is used to store the previously computed TM-scores. For the efficiency of any future extensions of the project, we share the precomputed TM-scores in `data/precomputed_tmscores.pkl` on GitHub. 
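For orientation, the shared TM-score cache can be inspected directly. A minimal sketch, assuming the pickle holds a single dict-like mapping from region-id pairs to TM-scores (the patches below only show it loaded as one object named `regions_ids_2_tmscore`; the exact key layout is an assumption):

```python
# Minimal sketch: peek at the shared precomputed TM-score cache.
# Assumption: the pickle stores one dict-like mapping; the key layout is unverified.
import pickle

with open("data/precomputed_tmscores.pkl", "rb") as f:
    regions_ids_2_tmscore = pickle.load(f)

print(type(regions_ids_2_tmscore), len(regions_ids_2_tmscore))
for key, tmscore in list(regions_ids_2_tmscore.items())[:3]:
    print(key, tmscore)
```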
@@ -363,6 +367,13 @@ jupyter notebook Then, execute the notebook `notebooks/notebook_3_clustering_domains.ipynb`. +#### 4 - Train classifiers of domain types and novel-domain detectors +```bash +cd TerpeneMiner +conda activate terpene_miner +python -m terpeneminer.src.structure_processing.train_domain_type_classifiers > outputs/logs/domain_type_classifier_training.log 2>&1 +``` + ----------------------------------------- ### Predictive Modeling @@ -418,7 +429,7 @@ python -m terpeneminer.src.models.plm_domain_faster.get_domains_feature_importan ``` -###### Troubleshoting +###### Troubleshooting - Please note that if you run into the error `FileNotFoundError: [Errno 2] No such file or directory: '/model_fold_0.pkl'`, you might need to re-run the training of the model while specifying `save_trained_model: true` in the config.
diff --git a/app.py b/app.py new file mode 100644 index 0000000..d106022 --- /dev/null +++ b/app.py @@ -0,0 +1,282 @@ +from uuid import uuid4 + +from fastapi import FastAPI, File, UploadFile, BackgroundTasks, Form +from fastapi.responses import FileResponse +from pathlib import Path +import os +import pickle +from shutil import copyfile, rmtree +import logging +import subprocess +import re + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler() # Output to console + ] +) + +# Create FastAPI app instance +app = FastAPI()
+@app.post("/detect_domains/") +async def upload_file(file: UploadFile = File(...), + is_bfactor_confidence: bool = Form(...)): + # Read the contents of the uploaded file + file_contents = await file.read() + + # Prepare the temporary working directory with the reference structures + pdb_directory_temp = Path("_temp") + if not pdb_directory_temp.exists(): + pdb_directory_temp.mkdir() + af_source_path = Path("/home/samusevich/TerpeneMiner/data/alphafold_structs") + for pdb_standard_id in ["1ps1", "5eat", "3p5r", "P48449"]: + pdb_standard_file_path = af_source_path / f"{pdb_standard_id}.pdb" + copyfile(pdb_standard_file_path, pdb_directory_temp / f"{pdb_standard_id}.pdb") + + # Getting the ID + pdb_id = file.filename.split(".")[0] + pdb_id = re.sub(r'\(.*?\)', '', pdb_id) + pdb_id = "".join(pdb_id.replace("-", "").split()) + + # Define the path where the .pdb file will be saved + pdb_file_path = pdb_directory_temp / f"{pdb_id}.pdb" + + # Saving the ID into a csv file (overwrite to avoid accumulating stale IDs across requests) + id_filepath = f'{pdb_directory_temp / "dummy_id.csv"}' + with open(id_filepath, "w") as id_file: + id_file.write(f"ID\n{pdb_id}\n") + + # Save the content as a .pdb file + with open(pdb_file_path, "wb") as pdb_file: + pdb_file.write(file_contents) + temp_filepath_name = Path("data/alphafold_structs") / f"{pdb_id}.pdb" + temp_filepath_name_to_delete = not temp_filepath_name.exists() + if temp_filepath_name_to_delete: + copyfile(pdb_file_path, temp_filepath_name) + + domain_detections_path = f"_temp/filename_2_detected_domains_completed_confident_{pdb_id}.pkl" + detected_domain_structures_root = Path("_temp/detected_domains") + if not detected_domain_structures_root.exists(): + detected_domain_structures_root.mkdir() + os.system( + "python -m terpeneminer.src.structure_processing.domain_detections " + f'--needed-proteins-csv-path "{id_filepath}" ' + "--csv-id-column ID " + "--n-jobs 16 " + "--input-directory-with-structures _temp " + f"{'--is-bfactor-confidence ' if is_bfactor_confidence else ''}" + f'--detections-output-path "{domain_detections_path}" ' + f'--detected-regions-root-path _temp ' + f'--domains-output-path "{detected_domain_structures_root}" ' + "--store-domains " + "--recompute-existing-secondary-structure-residues " + "--do-not-store-intermediate-files" + ) + + with open(domain_detections_path, "rb") as file: + detected_domains = pickle.load(file) + + all_secondary_structure_residues_path = "_temp/file_2_all_residues.pkl" + with open(all_secondary_structure_residues_path, "rb") as file: + file_2_all_residues = pickle.load(file) + if pdb_id in file_2_all_residues: + secondary_structure_res = file_2_all_residues[pdb_id] + else: + secondary_structure_res = None + + logger.info("Detected %d domains. Starting comparison to the known domains...", len(detected_domains)) + if detected_domains: + current_computation_id = uuid4() + comparison_results_path = f"_temp/filename_2_regions_vs_known_reg_dists_{current_computation_id}.pkl" + os.system("python -m terpeneminer.src.structure_processing.comparing_to_known_domains_foldseek " + f'--known-domain-structures-root data/detected_domains/all ' + f'--detected-domain-structures-root "{detected_domain_structures_root}" ' + '--path-to-known-domains-subset data/domains_subset.pkl ' + f'--output-path "{comparison_results_path}" ' + f'--pdb-id "{pdb_id}"') + + logger.info("Compared detected domains to the known ones!") + + with open(comparison_results_path, "rb") as file: + comparison_results = pickle.load(file) + + domain_id_2_aligned_pdb = {} + + with open('data/reaction_types_and_kingdoms.pkl', 'rb') as file: + id_2_reaction_types, id_2_kingdom = pickle.load(file) + with open('data/domain_module_id_2_domain_type.pkl', 'rb') as file: + domain_module_id_2_domain_type = pickle.load(file) + with open('data/id_2_domain_config.pkl', 'rb') as file: + id_2_domain_config = pickle.load(file) + + for detected_domain_id in comparison_results[pdb_id]: + detected_domain_file_path = detected_domain_structures_root / f"{detected_domain_id}.pdb" + pdb_id_current = detected_domain_id.split('_')[0] + closest_known_domain_id, foldseek_tm_score = max([(known_domain_id, tmscore) + for known_domain_id, tmscore in comparison_results[pdb_id][detected_domain_id] + if known_domain_id.split('_')[0] != pdb_id_current], + key=lambda x: x[1]) + closest_known_domain_id_pdb_id = closest_known_domain_id.split('_')[0] + closest_known_domain_file_path = Path("data/detected_domains/all") / f"{closest_known_domain_id}.pdb" + + aligned_pdb_path = Path("_temp") / f"aligned_{detected_domain_id}_to_{closest_known_domain_id}" + # Run TM-align and capture the output + try: + result = subprocess.run( + ["TMalign", closest_known_domain_file_path, detected_domain_file_path, "-o", aligned_pdb_path], + check=True, + capture_output=True, # Capture stdout and stderr + text=True # Ensure output is in text form, not bytes + ) + except subprocess.CalledProcessError as e: + raise ValueError(f"TM-align failed, details: {e}") + + # Extract TM-score from the output + output = result.stdout + tm_score = None + for line in output.splitlines(): + if "TM-score" in line and "Chain_1" in line: # TM-score line (ignores local TM-scores) + tm_score = float(line.split()[1]) + break + domain_id_2_aligned_pdb[detected_domain_id] = {"closest_known_domain_pdb_id": closest_known_domain_id_pdb_id, + "whole_structure_domain_config": id_2_domain_config[closest_known_domain_id_pdb_id], + "closest_domain_type": domain_module_id_2_domain_type[closest_known_domain_id], + "closest_id_reaction_types": [tps_type.replace('Class', 'class') + for tps_type in id_2_reaction_types[closest_known_domain_id_pdb_id]], + "closest_id_kingdom": id_2_kingdom[closest_known_domain_id_pdb_id], + "tm_score": tm_score, + "aligned_pdb_name": f"{aligned_pdb_path.name}_all_atm"} + os.remove(aligned_pdb_path) + os.remove(f"{aligned_pdb_path}_all") + os.remove(f"{aligned_pdb_path}_atm") + os.remove(f"{aligned_pdb_path}_all_atm_lig") + + logger.info("Predicting domain types...") + domain_predictions_path = f"_temp/domain_id_2_predictions_{uuid4()}.pkl" + os.system( + "python -m terpeneminer.src.structure_processing.predict_domain_types " + "--tps-classifiers-path data/classifier_domain_and_plm_checkpoints.pkl " + "--domain-classifiers-path data/domain_type_predictors_foldseek.pkl " + f"--path-to-domain-comparisons {comparison_results_path} " + f'--id "{pdb_id}" ' + f'--output-path "{domain_predictions_path}" ') + + with open(domain_predictions_path, "rb") as file: + domain_id_2_predictions = pickle.load(file) + os.remove(comparison_results_path) + os.remove(domain_predictions_path) + + os.remove(pdb_file_path) + if temp_filepath_name_to_delete: + os.remove(temp_filepath_name) + os.remove(id_filepath) + os.remove(domain_detections_path) + rmtree(detected_domain_structures_root) + + return {"domains": detected_domains, "secondary_structure_residues": secondary_structure_res, + "comparison_to_known_domains": comparison_results[pdb_id] if detected_domains else None, + "domain_type_predictions": domain_id_2_predictions if detected_domains else None, + "aligned_pdb_filepaths": domain_id_2_aligned_pdb if detected_domains else None}
+def delete_file(file_path: str): + os.remove(file_path) + +# endpoint to download the aligned PDB file +@app.get("/download_pdb/{aligned_pdb_name}") +async def download_aligned_pdb(aligned_pdb_name: str, background_tasks: BackgroundTasks): + aligned_pdb_path = Path("_temp") / aligned_pdb_name + if not os.path.exists(aligned_pdb_path): + return {"error": "File not found"} + # schedule file deletion after the response is sent + background_tasks.add_task(delete_file, aligned_pdb_path) + return FileResponse(aligned_pdb_path, media_type='application/octet-stream', filename=Path(aligned_pdb_path).name)
+@app.post("/predict_tps/") +async def predict_tps(file: UploadFile = File(...), + is_bfactor_confidence: bool = Form(...)): + # Read the contents of the uploaded file + file_contents = await file.read() + + # Prepare the temporary working directory with the reference structures + pdb_directory_temp = Path("_temp") + if not pdb_directory_temp.exists(): + pdb_directory_temp.mkdir() + af_source_path = Path("/home/samusevich/TerpeneMiner/data/alphafold_structs") + for pdb_standard_id in ["1ps1", "5eat", "3p5r", "P48449"]: + pdb_standard_file_path = af_source_path / f"{pdb_standard_id}.pdb" + copyfile(pdb_standard_file_path, pdb_directory_temp / f"{pdb_standard_id}.pdb") + + # Getting the ID + pdb_id = file.filename.split(".")[0] + pdb_id = re.sub(r'\(.*?\)', '', pdb_id) + pdb_id = "".join(pdb_id.replace("-", "").split()) + + # Define the path where the .pdb file will be saved + pdb_file_path = pdb_directory_temp / f"{pdb_id}.pdb" + + # Saving the ID into a csv file (overwrite to avoid accumulating stale IDs across requests) + id_filepath = f'{pdb_directory_temp / "dummy_id.csv"}' + with open(id_filepath, "w") as id_file: + id_file.write(f"ID\n{pdb_id}\n") + + # Save the content as a .pdb file + with open(pdb_file_path, "wb") as pdb_file: + pdb_file.write(file_contents) + temp_filepath_name = Path("data/alphafold_structs") / f"{pdb_id}.pdb" + temp_filepath_name_to_delete = not temp_filepath_name.exists() + if temp_filepath_name_to_delete: + copyfile(pdb_file_path, temp_filepath_name) + + domain_detections_path = f"_temp/filename_2_detected_domains_completed_confident_{pdb_id}.pkl" + detected_domain_structures_root = Path("_temp/detected_domains") + if not detected_domain_structures_root.exists(): + detected_domain_structures_root.mkdir() + os.system( + "python -m terpeneminer.src.structure_processing.domain_detections " + f'--needed-proteins-csv-path "{id_filepath}" ' + "--csv-id-column ID " + "--n-jobs 16 " + "--input-directory-with-structures _temp " + f"{'--is-bfactor-confidence ' if is_bfactor_confidence else ''}" + f'--detections-output-path "{domain_detections_path}" ' + f'--detected-regions-root-path _temp ' + f'--domains-output-path "{detected_domain_structures_root}" ' + "--store-domains " + "--recompute-existing-secondary-structure-residues " + "--do-not-store-intermediate-files" + ) + + with open(domain_detections_path, "rb") as file: + detected_domains = pickle.load(file) + + logger.info("Detected %d domains. Starting comparison to the known domains...", len(detected_domains)) + if detected_domains: + current_computation_id = uuid4() + comparison_results_path = f"_temp/filename_2_regions_vs_known_reg_dists_{current_computation_id}.pkl" + os.system( + "python -m terpeneminer.src.structure_processing.comparing_to_known_domains " + "--input-directory-with-structures data/alphafold_structs/ " + "--n-jobs 16 " + f'--domain-detections-path "{domain_detections_path}" ' + "--domain-detections-residues-path _temp/file_2_all_residues.pkl " + "--path-to-all-known-domains data/alphafold_structs/regions_completed_very_confident_all_ALL.pkl " + "--path-to-known-domains-subset data/domains_subset.pkl " + f'--pdb-filepath "{temp_filepath_name}" ' + f'--output-path "{comparison_results_path}"') + + logger.info("Compared detected domains to the known ones!") + + with open(comparison_results_path, "rb") as file: + comparison_results = pickle.load(file) + + print(comparison_results) + +
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh index 2562a8d..548cf0b 100644 --- a/scripts/setup_env.sh +++ b/scripts/setup_env.sh @@ -1,7 +1,7 @@ #!/bin/bash # Create env and install required packages -conda create -n terpene_miner python==3.10.0 scikit-learn==1.5.1 pandas==2.2.2 numpy==1.26.4 scipy==1.13.0 jupyter matplotlib seaborn pymol==3.0.2 pymol-psico==3.4.19 tmalign==20170708 -c schrodinger -c speleo3 -c conda-forge -y +conda create -n terpene_miner python==3.10.0 scikit-learn==1.5.1 pandas==2.2.2 numpy==1.26.4 scipy==1.13.0 jupyter matplotlib seaborn foldseek pymol==3.0.2 pymol-psico==3.4.19 tmalign==20170708 -c schrodinger -c speleo3 -c conda-forge -c bioconda -y conda activate terpene_miner pip install torch --index-url https://download.pytorch.org/whl/cu121
diff --git a/terpeneminer/configs/Blastp/against_wetlab_data/config.yaml b/terpeneminer/configs/Blastp/against_wetlab_data/config.yaml deleted file mode 100644 index f4afbce..0000000 --- a/terpeneminer/configs/Blastp/against_wetlab_data/config.yaml +++ /dev/null @@ -1,29 +0,0 @@ -run_against_wetlab: True -id_col_name: "Uniprot ID" -target_col_name: "SMILES_substrate_canonical_no_stereo" -split_col_name: "stratified_phylogeny_based_split_with_minor_products" -class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "precursor substr", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C", -
"CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "isTPS"] -optimize_hyperparams: false -random_state: 0 -hyperparam_dimensions: none -seq_col_name: "Amino acid sequence" -n_calls_hyperparams_opt: 0 -n_neighbours: 1 -e_threshold: 0.001 -n_jobs: 64 -pred_batch_size: 32 -neg_val: "Unknown" -negatives_sample_path: "data/sampled_id_2_seq.pkl" -tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: false -reuse_existing_partial_results: false -load_per_class_params_from: "" diff --git a/terpeneminer/configs/CLEAN.ignore/against_wetlab_data/config.yaml b/terpeneminer/configs/CLEAN.ignore/against_wetlab_data/config.yaml deleted file mode 100644 index 1a3e762..0000000 --- a/terpeneminer/configs/CLEAN.ignore/against_wetlab_data/config.yaml +++ /dev/null @@ -1,30 +0,0 @@ -run_against_wetlab: True -id_col_name: "Uniprot ID" -target_col_name: "SMILES_substrate_canonical_no_stereo" -split_col_name: "stratified_phylogeny_based_split_with_minor_products" -class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "precursor substr", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C", - "CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "isTPS"] -optimize_hyperparams: false -random_state: 0 -hyperparam_dimensions: none -seq_col_name: "Amino acid sequence" -n_calls_hyperparams_opt: 0 -clean_installation_root: "/home/samusevich/CLEAN" -rhea2ec_link: "https://ftp.expasy.org/databases/rhea/tsv/rhea2ec.tsv" -rhea_reaction_smiles_link: "https://ftp.expasy.org/databases/rhea/tsv/rhea-reaction-smiles.tsv" -rhea_directions_link: "https://ftp.expasy.org/databases/rhea/tsv/rhea-directions.tsv" -clean_working_dir: "_clean_working_dir" -neg_val: "Unknown" -negatives_sample_path: "data/sampled_id_2_seq.pkl" -tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: false -reuse_existing_partial_results: false -load_per_class_params_from: "" diff --git a/terpeneminer/configs/CLEAN.ignore/with_minor_reactions/config.yaml b/terpeneminer/configs/CLEAN/with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/CLEAN.ignore/with_minor_reactions/config.yaml rename to terpeneminer/configs/CLEAN/with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions/config.yaml b/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions/config.yaml deleted file mode 100644 index d3205d6..0000000 --- a/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions/config.yaml +++ /dev/null @@ -1,63 +0,0 @@ -id_col_name: "Uniprot ID" -target_col_name: "SMILES_substrate_canonical_no_stereo" -split_col_name: "stratified_phylogeny_based_split_with_minor_products" -class_names: 
["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "precursor substr", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C", - "CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "isTPS"] -optimize_hyperparams: true -random_state: 0 -n_calls_hyperparams_opt: 350 -hyperparam_dimensions: - reg_lambda: - type: "float" - args: [1, 100.0, "uniform"] - gamma: - type: "float" - args: [1.0e-6, 0.5, "uniform"] - max_depth: - type: "int" - args: [1, 100, "uniform"] - subsample: - type: "float" - args: [0.7, 1.0, "log-uniform"] - colsample_bytree: - type: "float" - args: [0.7, 1.0, "log-uniform"] - scale_pos_weight: - type: "float" - args: [1.0e-6, 30.0, 'uniform'] - min_child_weight: - type: "int" - args: [1, 10] - n_estimators: - type: "int" - args: [20, 400, "uniform"] - max_train_negs_proportion: - type: "float" - args: [0.5, 0.99, "log-uniform" ] -n_jobs: -1 -objective: "binary:logistic" -booster: "gbtree" -reg_lambda: 1 -gamma: 2.0e-6 -max_depth: 6 -subsample: 1.0 -colsample_bytree: 1.0 -scale_pos_weight: 1 -min_child_weight: 1 -n_estimators: 50 -max_train_negs_proportion: 0.98 -neg_val: "Unknown" -save_trained_model: true -negatives_sample_path: "data/sampled_id_2_seq.pkl" -tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: true -reuse_existing_partial_results: false -load_per_class_params_from: "" diff --git a/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions_global_tuning/config.yaml deleted file mode 100644 index b4a25bf..0000000 --- a/terpeneminer/configs/DomainsXgb.ignore/with_minor_reactions_global_tuning/config.yaml +++ /dev/null @@ -1,63 +0,0 @@ -id_col_name: "Uniprot ID" -target_col_name: "SMILES_substrate_canonical_no_stereo" -split_col_name: "stratified_phylogeny_based_split_with_minor_products" -class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "precursor substr", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C", - "CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "isTPS"] -optimize_hyperparams: true -random_state: 0 -n_calls_hyperparams_opt: 350 -hyperparam_dimensions: - reg_lambda: - type: "float" - args: [1, 100.0, "uniform"] - gamma: - type: "float" - args: [1.0e-6, 0.5, "uniform"] - max_depth: - type: "int" - args: [1, 100, "uniform"] - subsample: - type: "float" - args: [0.7, 1.0, "log-uniform"] - colsample_bytree: - type: "float" - args: [0.7, 1.0, "log-uniform"] - scale_pos_weight: - type: "float" - args: [1.0e-6, 30.0, 'uniform'] - min_child_weight: - type: "int" - args: [1, 10] - n_estimators: - type: 
"int" - args: [20, 400, "uniform"] - max_train_negs_proportion: - type: "float" - args: [0.5, 0.99, "log-uniform" ] -n_jobs: -1 -objective: "binary:logistic" -booster: "gbtree" -reg_lambda: 1 -gamma: 2.0e-6 -max_depth: 6 -subsample: 1.0 -colsample_bytree: 1.0 -scale_pos_weight: 1 -min_child_weight: 1 -n_estimators: 50 -max_train_negs_proportion: 0.98 -neg_val: "Unknown" -save_trained_model: true -negatives_sample_path: "data/sampled_id_2_seq.pkl" -tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: false -reuse_existing_partial_results: false -load_per_class_params_from: "" diff --git a/terpeneminer/configs/HMM/against_wetlab_data/config.yaml b/terpeneminer/configs/HMM/against_wetlab_data/config.yaml deleted file mode 100644 index 0d84fe3..0000000 --- a/terpeneminer/configs/HMM/against_wetlab_data/config.yaml +++ /dev/null @@ -1,30 +0,0 @@ -run_against_wetlab: True -id_col_name: "Uniprot ID" -target_col_name: "SMILES_substrate_canonical_no_stereo" -split_col_name: "stratified_phylogeny_based_split_with_minor_products" -class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "precursor substr", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C", - "CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", - "isTPS"] -optimize_hyperparams: false -random_state: 0 -hyperparam_dimensions: none -seq_col_name: "Amino acid sequence" -n_calls_hyperparams_opt: 0 -search_e_threshold: 0.01 -n_jobs: 64 -pred_batch_size: 32 -zero_conf_level: 0.1 -group_column_name: "Kingdom" -neg_val: "Unknown" -negatives_sample_path: "data/sampled_id_2_seq.pkl" -tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: false -reuse_existing_partial_results: false -load_per_class_params_from: "" diff --git a/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml similarity index 100% rename from terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore/config.yaml rename to terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/ankh_base_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/ankh_base_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/ankh_base_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/ankh_base_with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/ankh_large_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/ankh_large_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/ankh_large_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/ankh_large_with_minor_reactions/config.yaml diff --git 
a/terpeneminer/configs/PlmXgb.ignore/esm-1v_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/esm-1v_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/esm-1v_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/esm-1v_with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/esm-2_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/esm-2_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/esm-2_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/esm-2_with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/main_config.yaml b/terpeneminer/configs/PlmXgb/main_config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/main_config.yaml rename to terpeneminer/configs/PlmXgb/main_config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/tps_ankh_base_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/tps_ankh_base_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/tps_ankh_base_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/tps_ankh_base_with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/tps_esm-1v-subseq_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/tps_esm-1v-subseq_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/tps_esm-1v-subseq_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/tps_esm-1v-subseq_with_minor_reactions/config.yaml diff --git a/terpeneminer/configs/PlmXgb.ignore/tps_esm-1v_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmXgb/tps_esm-1v_with_minor_reactions/config.yaml similarity index 100% rename from terpeneminer/configs/PlmXgb.ignore/tps_esm-1v_with_minor_reactions/config.yaml rename to terpeneminer/configs/PlmXgb/tps_esm-1v_with_minor_reactions/config.yaml diff --git a/terpeneminer/src/evaluation/plotting.py b/terpeneminer/src/evaluation/plotting.py index 6d96f02..971750c 100644 --- a/terpeneminer/src/evaluation/plotting.py +++ b/terpeneminer/src/evaluation/plotting.py @@ -162,6 +162,7 @@ def plot_boxplots_per_type( class_list = [] val_list = [] + print(model_2_class_2_metric_vals.keys()) present_type_names = set() for model_i, model_name in enumerate(models): class_dicts = model_2_class_2_metric_vals[model_name] diff --git a/terpeneminer/src/models/__init__.py b/terpeneminer/src/models/__init__.py index eb40e6b..b701a53 100644 --- a/terpeneminer/src/models/__init__.py +++ b/terpeneminer/src/models/__init__.py @@ -7,4 +7,4 @@ from .plm_domains_mlp import PlmDomainsMLP from .plm_domains_logistic_regression import PlmDomainsLogisticRegression -from .baselines import Blastp, Foldseek, HMM, PfamSUPFAM # , CLEAN +from .baselines import Blastp, Foldseek, HMM, PfamSUPFAM, CLEAN diff --git a/terpeneminer/src/models/baselines/__init__.py b/terpeneminer/src/models/baselines/__init__.py index 9249af5..e6de26d 100644 --- a/terpeneminer/src/models/baselines/__init__.py +++ b/terpeneminer/src/models/baselines/__init__.py @@ -3,5 +3,5 @@ from .foldseek import Foldseek from .hmm import HMM -# from .CLEAN import CLEAN +from .CLEAN import CLEAN from .pfam_supfam import PfamSUPFAM diff --git a/terpeneminer/src/screening/gather_classifier_checkpoints.py b/terpeneminer/src/screening/gather_classifier_checkpoints.py index ef0530f..360328b 100644 --- 
a/terpeneminer/src/screening/gather_classifier_checkpoints.py +++ b/terpeneminer/src/screening/gather_classifier_checkpoints.py @@ -94,15 +94,21 @@ def parse_args() -> argparse.Namespace: with open("data/domains_subset.pkl", "rb") as file: feat_indices_subset = pickle.load(file)[-1] - domain_module_id_2_dist_matrix_index_subset = {domain_id: [i for i in indices if i in feat_indices_subset] - for domain_id, indices in - domain_module_id_2_dist_matrix_index.items() - if len([i for i in indices if i in feat_indices_subset])} + domain_module_id_2_dist_matrix_index_subset = { + domain_id: [i for i in indices if i in feat_indices_subset] + for domain_id, indices in domain_module_id_2_dist_matrix_index.items() + if [i for i in indices if i in feat_indices_subset] + } feat_idx_2_module_id = {} - for module_id, feat_indices in domain_module_id_2_dist_matrix_index_subset.items(): + for ( + module_id, + feat_indices, + ) in domain_module_id_2_dist_matrix_index_subset.items(): for feat_idx in feat_indices: feat_idx_2_module_id[feat_idx] = module_id - order_of_domain_modules = [feat_idx_2_module_id[feat_i] for feat_i in model.allowed_feat_indices] + order_of_domain_modules = [ + feat_idx_2_module_id[feat_i] for feat_i in model.allowed_feat_indices + ] model.classifier.order_of_domain_modules = order_of_domain_modules classifiers.append(model.classifier)
diff --git a/terpeneminer/src/structure_processing/comparing_to_known_domains.py b/terpeneminer/src/structure_processing/comparing_to_known_domains.py new file mode 100644 index 0000000..3d364d9 --- /dev/null +++ b/terpeneminer/src/structure_processing/comparing_to_known_domains.py @@ -0,0 +1,206 @@ +"""This script compares detected TPS domains to the known ones""" + +import os +import argparse +from functools import partial +from multiprocessing import Pool +from pathlib import Path +from collections import defaultdict +import pickle +import logging +from shutil import copyfile +from uuid import uuid4 + +from pymol import cmd # type: ignore +from terpeneminer.src.structure_processing.structural_algorithms import ( + MappedRegion, + get_pairwise_tmscore, +) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +def parse_args() -> argparse.Namespace: + """ + This function parses arguments + :return: current argparse.Namespace + """ + parser = argparse.ArgumentParser( + description="A script to compare detected TPS domains to the known ones" + ) + parser.add_argument( + "--input-directory-with-structures", + help="A directory containing PDB structures", + type=str, + default="data/alphafold_structs/", + ) + parser.add_argument("--n-jobs", type=int, default=16) + parser.add_argument( + "--domain-detections-path", + help="A path to a dictionary with the detected domains", + type=str, + default="_temp/filename_2_detected_domains_completed_confident.pkl", + ) + parser.add_argument( + "--domain-detections-residues-path", + help="A path to a dictionary with the secondary-structure residues per file", + type=str, + default="_temp/file_2_all_residues.pkl", + ) + parser.add_argument("--path-to-all-known-domains", type=str, default="data/alphafold_structs/regions_completed_very_confident_all_ALL.pkl") + parser.add_argument("--path-to-known-domains-subset", type=str, default="data/domains_subset.pkl") + parser.add_argument("--number-of-workers", type=int, default=16) + parser.add_argument("--output-path", type=str, default="_temp/filename_2_regions_vs_known_reg_dists.pkl") + parser.add_argument("--pdb-filepath", type=str, default="") + return parser.parse_args()
+def compute_distances_to_known_regions( + segment_i: int, + current_region_segments: list[list[tuple[str, MappedRegion]]], + region_i: tuple[str, MappedRegion], + filename_2_all_residues: dict, + computation_id: str +): + """ + Computes pairwise distances (TM-scores) between a given region and all other regions within the same segment + and saves the results to a pickle file. + + This function calculates the TM-score (a structural similarity score) between a specific region + (`region_i`) and all regions in the same segment (`segment_i`) as defined by the `region_segments`. + If a results file already exists for the given region and segment, it skips the computation. + Otherwise, the function computes the distances, stores them in a list, and writes the results + to a temporary pickle file named based on the segment index and computation ID. + + Parameters + ---------- + segment_i : int + The index of the segment for which the distances should be computed. + current_region_segments : list[list[tuple[str, MappedRegion]]] + A list of segments, where each segment is a list of tuples. Each tuple consists of a string identifier + and a `MappedRegion` object representing a region within the segment. + region_i : tuple[str, MappedRegion] + A tuple containing a string identifier and a `MappedRegion` object for the region of interest + (the region whose distances to all other regions in the same segment are to be computed). + filename_2_all_residues : dict + A dictionary mapping filenames to residue data, which is used during the TM-score calculation. + computation_id : str + A unique string identifier for the computation, used to construct the output filename. + + Returns + ------- + list of tuple + A list of tuples, where each tuple contains the module ID of the compared region and the computed + TM-score distance to `region_i`. If the result already exists, the function will return `None`. + + Notes + ----- + The function checks if the results file already exists and skips computation if so. + The results are stored in a pickle file with the format: + `_temp_tm_region_<module_id>_segment_<segment_i>_<computation_id>.pkl`. + + The TM-score computation relies on the external function `get_pairwise_tmscore`, + which takes in a PyMOL command object (`cmd`), the regions being compared, and residue data. + """ + computation_results_path = f"_temp_tm_region_{region_i[1].module_id}_segment_{segment_i}_{computation_id}.pkl" + if os.path.exists(computation_results_path): + return + results = [] + for region_j in current_region_segments[segment_i]: + tmscore = get_pairwise_tmscore(cmd, region_i, region_j, filename_2_all_residues) + results.append((region_j[1].module_id, tmscore)) + with open(computation_results_path, "wb") as write_file: + pickle.dump(results, write_file) + return results
+if __name__ == "__main__": + args = parse_args() + + # loading secondary structure residues + input_directory = Path(args.input_directory_with_structures) + with open(args.domain_detections_residues_path, "rb") as f: + file_2_all_residues = pickle.load(f) + all_secondary_structure_residues_path = input_directory / "file_2_all_residues.pkl" + with open(all_secondary_structure_residues_path, "rb") as f: + file_2_all_residues_all = pickle.load(f) + file_2_all_residues.update(file_2_all_residues_all) + + # loading all known domains + with open(args.path_to_all_known_domains, "rb") as file: + regions_completed_very_confident_all_ALL = pickle.load(file) + with open(args.path_to_known_domains_subset, "rb") as file: + dom_subset, feat_indices_subset = pickle.load(file) + regions_all = [reg for reg in regions_completed_very_confident_all_ALL if reg[1].module_id in dom_subset] + + # preparing the data for parallel processing + n_workers = args.number_of_workers + temp_struct_name = input_directory / Path(args.pdb_filepath).name + if args.pdb_filepath != str(temp_struct_name): + copyfile(args.pdb_filepath, temp_struct_name) + + # loading detected domains + with open(args.domain_detections_path, "rb") as file: + filename_2_detected_regions_completed_confident = pickle.load(file) + filename_2_regions_vs_known_reg_dists = {} + type_2_regions = dict() + + cwd = os.getcwd() + os.chdir(input_directory) + + for filename, regions in filename_2_detected_regions_completed_confident.items(): + region_2_known_reg_dists = defaultdict(list) + for region in regions: + domain_type = region.domain + if domain_type not in type_2_regions: + type_2_regions[domain_type] = [el for el in regions_all if el[1].domain == domain_type] + regions_all_current_type = type_2_regions[domain_type] + + regions_segment_len = len(regions_all_current_type) // n_workers + 1 + region_segments = [] + start_i = 0 + logger.info("Known regions of type %s: %d", domain_type, len(regions_all_current_type)) + while start_i < len(regions_all_current_type): + region_segments.append(regions_all_current_type[start_i:start_i + regions_segment_len]) + start_i += regions_segment_len + + logger.info("Regions across the %d segments: %d", len(region_segments), sum(len(x) for x in region_segments)) + + computations_id = str(uuid4()) + partial_dist_compute = partial( + compute_distances_to_known_regions, + region_i=(filename, region), + current_region_segments=region_segments, + filename_2_all_residues=file_2_all_residues, + computation_id=computations_id + ) + with Pool(n_workers - 2) as p: + list_of_distances_list = p.map(partial_dist_compute, list(range(len(region_segments)))) + for results_path in Path('.').glob(f'*{computations_id}.pkl'): + with open(results_path, "rb") as f: + results_partial = pickle.load(f) + region_2_known_reg_dists[region.module_id].extend(results_partial) + filename_2_regions_vs_known_reg_dists[filename] = region_2_known_reg_dists + + os.chdir(cwd) + + with open(args.output_path, "wb") as file: + pickle.dump(filename_2_regions_vs_known_reg_dists, file) + + os.remove(temp_struct_name)
diff --git a/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py b/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py new file mode 100644 index 0000000..88ac61d --- /dev/null +++ b/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py @@ -0,0 +1,75 @@ +"""This script compares detected TPS domains to the known ones using Foldseek""" + +import os +import argparse +from pathlib import Path +from collections import defaultdict +import pickle +import logging +import subprocess +from uuid import uuid4 +from shutil import rmtree + +import pandas as pd # type: ignore + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +def parse_args() -> argparse.Namespace: + """ + This function parses arguments + :return: current argparse.Namespace + """ + parser = argparse.ArgumentParser( + description="A script to compare detected TPS domains to the known ones" + ) + parser.add_argument( + "--known-domain-structures-root", + help="A directory containing structures of known domains", + type=str, + default="data/detected_domains/all", + ) + parser.add_argument( + "--detected-domain-structures-root", + help="A path to new detected domain structures", + type=str, + default="_temp/detected_domains", + ) + parser.add_argument("--path-to-known-domains-subset", type=str, default="data/domains_subset.pkl") + parser.add_argument("--output-path", type=str, default="_temp/filename_2_regions_vs_known_reg_dists.pkl") + parser.add_argument("--pdb-id", type=str, default="") + return parser.parse_args()
+if __name__ == "__main__": + args = parse_args() + + working_dir = Path('_temp') + if not working_dir.exists(): + working_dir.mkdir() + tsv_path = working_dir / f'aln_all_domains_vs_all_{uuid4()}.tsv' + tmp_path = working_dir / f'tmp_all_{uuid4()}' + foldseek_comparison_output = subprocess.check_output( + f'foldseek easy-search {args.detected_domain_structures_root} {args.known_domain_structures_root} {tsv_path} {tmp_path} --max-seqs 3000 -e 0.1 --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,alntmscore'.split()) + df_foldseek = pd.read_csv(tsv_path, sep='\t', header=None, + names=['query', 'target', 'fident', 'alnlen', 'mismatch', 'gapopen', 'qstart', 'qend', + 'tstart', 'tend', 'evalue', 'bits', 'alntmscore']) + + region_2_known_reg_dists = defaultdict(list) + with open(args.path_to_known_domains_subset, "rb") as file: + dom_subset, feat_indices_subset = pickle.load(file) + + for _, row in df_foldseek.iterrows(): + if row['target'] in dom_subset: + region_2_known_reg_dists[row['query']].append([row['target'], float(row['alntmscore'])]) + filename_2_regions_vs_known_reg_dists = {args.pdb_id: region_2_known_reg_dists} + + os.remove(tsv_path) + rmtree(tmp_path) + + with open(args.output_path, "wb") as file: + pickle.dump(filename_2_regions_vs_known_reg_dists, file)
diff --git a/terpeneminer/src/structure_processing/domain_detections.py b/terpeneminer/src/structure_processing/domain_detections.py index 957c4b7..070c488 100644 --- a/terpeneminer/src/structure_processing/domain_detections.py +++ b/terpeneminer/src/structure_processing/domain_detections.py @@ -9,9 +9,9 @@ import logging import subprocess from datetime import datetime +from pymol import cmd # type: ignore import pandas as pd # type: ignore from Bio import PDB # type: ignore -from pymol
import cmd # type: ignore from tqdm.auto import tqdm # type: ignore from terpeneminer.src.structure_processing.structural_algorithms import ( SUPPORTED_DOMAINS, @@ -28,6 +28,11 @@ logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) +if not logger.hasHandlers(): + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) def parse_args() -> argparse.Namespace: @@ -43,6 +48,11 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--csv-id-column", + type=str, + default=None, + ) parser.add_argument( "--input-directory-with-structures", help="A directory containing PDB structures", @@ -61,12 +71,16 @@ def parse_args() -> argparse.Namespace: help="A flag to store detected domains", action="store_true", ) + parser.add_argument("--detected-regions-root-path", type=str, default="data/alphafold_structs/") parser.add_argument( "--domains-output-path", help="A root path for saving the detected domains to", type=str, default="data/detected_domains", ) + parser.add_argument("--is-bfactor-confidence", action="store_true") + parser.add_argument("--do-not-store-intermediate-files", action="store_true") + parser.add_argument("--recompute-existing-secondary-structure-residues", action="store_true") return parser.parse_args() @@ -108,7 +122,7 @@ def detect_domains_roughly( ) execution_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - if len(file_2_tmscore_residues_domain): + if len(file_2_tmscore_residues_domain) and not args.do_not_store_intermediate_files: plot_aligned_domains( file_2_tmscore_residues_domain, title=f"{domain_this} domain detections", @@ -136,12 +150,13 @@ def detect_domains_roughly( "Detected %d %s domains", len(regions_of_possible_domain), domain_this ) - with open( - output_root - / f"final_regions_{domain_this}s_tm_ALL_{execution_timestamp}.pkl", - "wb", - ) as result_file: - pickle.dump(regions_of_possible_domain, result_file) + if not args.do_not_store_intermediate_files: + with open( + output_root + / f"final_regions_{domain_this}s_tm_ALL_{execution_timestamp}.pkl", + "wb", + ) as result_file: + pickle.dump(regions_of_possible_domain, result_file) domain_2_possible_regions[domain_this] = regions_of_possible_domain @@ -178,20 +193,21 @@ def detect_domains_roughly( ) regions_of_possible_2nd_alphas.append(new_tuple) execution_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - if len(file_2_tmscore_residues_2nd_alpha): - plot_aligned_domains( - file_2_tmscore_residues_2nd_alpha, - title="2nd alpha domain detections", - save_path=output_root - / f"2nd_alpha_detections_{execution_timestamp}.png", - ) + if not args.do_not_store_intermediate_files: + if len(file_2_tmscore_residues_2nd_alpha): + plot_aligned_domains( + file_2_tmscore_residues_2nd_alpha, + title="2nd alpha domain detections", + save_path=output_root + / f"2nd_alpha_detections_{execution_timestamp}.png", + ) - with open( - output_root - / f"final_regions_2nd_alphas_tm_ALL_{execution_timestamp}.pkl", - "wb", - ) as result_file: - pickle.dump(regions_of_possible_2nd_alphas, result_file) + with open( + output_root + / f"final_regions_2nd_alphas_tm_ALL_{execution_timestamp}.pkl", + "wb", + ) as result_file: + pickle.dump(regions_of_possible_2nd_alphas, result_file) domain_2_possible_regions[domain_this] += regions_of_possible_2nd_alphas file_2_known_regions: dict = defaultdict(list) @@ -319,13 +335,14 @@ def get_confident_af_residues( # 
reading the needed proteins if args.needed_proteins_csv_path is not None: proteins_df = pd.read_csv(args.needed_proteins_csv_path) - relevant_protein_ids = set(proteins_df["Uniprot ID"].values) + relevant_protein_ids = set(proteins_df[args.csv_id_column].values) input_directory = Path(args.input_directory_with_structures) all_secondary_structure_residues_path = input_directory / "file_2_all_residues.pkl" - secondary_structure_computation_output = subprocess.check_output( - f"python -m terpeneminer.src.structure_processing.compute_secondary_structure_residues --input-directory {input_directory} --output-path {all_secondary_structure_residues_path}".split(), - ) + if not all_secondary_structure_residues_path.exists() or args.recompute_existing_secondary_structure_residues: + secondary_structure_computation_output = subprocess.check_output( + f"python -m terpeneminer.src.structure_processing.compute_secondary_structure_residues --input-directory {input_directory} --output-path {all_secondary_structure_residues_path}".split(), + ) with open(all_secondary_structure_residues_path, "rb") as file: file_2_all_residues = pickle.load(file) @@ -333,7 +350,7 @@ def get_confident_af_residues( cwd = os.getcwd() os.chdir(input_directory) blacklist_files = ( - {"1ps1.pdb", "5eat.pdb", "3p5r.pdb", "P48449.pdb"} + {"1ps1.pdb", "5eat.pdb", "3p5r.pdb"} .union({f"{domain}.pdb" for domain in SUPPORTED_DOMAINS}) .union({f"{domain}_object.pdb" for domain in SUPPORTED_DOMAINS}) ) @@ -430,23 +447,26 @@ def get_confident_af_residues( # Getting confident residues filename_2_known_regions_completed_confident = {} for filename, regions in tqdm(filename_2_known_regions_completed.items()): - conf_residues = get_confident_af_residues(filename) - new_regions = [] - for mapped_region_init in regions: - new_residues_mapping = { - res: res_dom - for res, res_dom in mapped_region_init.residues_mapping.items() - if res in conf_residues - } - new_regions.append( - MappedRegion( # pylint: disable=R0801 - module_id=mapped_region_init.module_id, - domain=mapped_region_init.domain, - tmscore=mapped_region_init.tmscore, - residues_mapping=new_residues_mapping, + if args.is_bfactor_confidence: + conf_residues = get_confident_af_residues(filename) + new_regions = [] + for mapped_region_init in regions: + new_residues_mapping = { + res: res_dom + for res, res_dom in mapped_region_init.residues_mapping.items() + if res in conf_residues + } + new_regions.append( + MappedRegion( # pylint: disable=R0801 + module_id=mapped_region_init.module_id, + domain=mapped_region_init.domain, + tmscore=mapped_region_init.tmscore, + residues_mapping=new_residues_mapping, + ) ) - ) - filename_2_known_regions_completed_confident[filename] = new_regions + filename_2_known_regions_completed_confident[filename] = new_regions + else: + filename_2_known_regions_completed_confident[filename] = regions # for further convenience, storing also regions separately per domain domain_2_regions_completed_confident = defaultdict(list) @@ -461,31 +481,26 @@ def get_confident_af_residues( (filename, region) ) - with open("regions_completed_very_confident_all_ALL.pkl", "wb") as f: + with open(Path(cwd) / args.detected_regions_root_path / "regions_completed_very_confident_all_ALL.pkl", "wb") as f: pickle.dump(domain_2_regions_completed_confident["all"], f) for domain_name in SUPPORTED_DOMAINS: - with open(f"regions_completed_very_confident_{domain_name}_ALL.pkl", "wb") as f: + with open(Path(cwd) / args.detected_regions_root_path / 
f"regions_completed_very_confident_{domain_name}_ALL.pkl", "wb") as f: pickle.dump(domain_2_regions_completed_confident[domain_name], f) - os.chdir(cwd) - # save the confident regions - with open(args.detections_output_path, "wb") as f: - pickle.dump(filename_2_known_regions_completed_confident, f) - if args.store_domains: - domains_output_path = Path(args.domains_output_path) + domains_output_path = Path(cwd) / args.domains_output_path if not domains_output_path.exists(): domains_output_path.mkdir(parents=True) for domain_name in SUPPORTED_DOMAINS: PATH = domains_output_path / f"tps_domain_detections_{domain_name}" - if not os.path.exists(PATH): - os.mkdir(PATH) + if not PATH.exists(): + PATH.mkdir(parents=True) for filename, protein_regions in tqdm( filename_2_known_regions_completed_confident.items() ): for region in protein_regions: - PATH = Path(f"../tps_domain_detections_{region.domain}") + PATH = Path(domains_output_path / f"tps_domain_detections_{region.domain}") mapped_residues = list(set(region.residues_mapping.keys())) cmd.delete(filename) cmd.load(f"{filename}.pdb") @@ -494,4 +509,11 @@ def get_confident_af_residues( f"{filename} & resi {compress_selection_list(mapped_residues)}", ) cmd.save(f"{PATH}/{region.module_id}.pdb", f"{region.module_id}") + print('saving to : ', f"{domains_output_path}/{region.module_id}.pdb") + cmd.save(f"{domains_output_path}/{region.module_id}.pdb", f"{region.module_id}") cmd.delete(filename) + + os.chdir(cwd) + # save the confident regions + with open(args.detections_output_path, "wb") as f: + pickle.dump(filename_2_known_regions_completed_confident, f) diff --git a/terpeneminer/src/structure_processing/predict_domain_types.py b/terpeneminer/src/structure_processing/predict_domain_types.py new file mode 100644 index 0000000..584cce2 --- /dev/null +++ b/terpeneminer/src/structure_processing/predict_domain_types.py @@ -0,0 +1,89 @@ +"""This script predicts domain types and novelty based on the TMScore distances between the detected domains and the known ones.""" +from collections import defaultdict + +from sklearn.ensemble import RandomForestClassifier # type: ignore +import pickle +import logging +import argparse +import numpy as np # type: ignore +from sklearn.model_selection import train_test_split # type: ignore +from sklearn.metrics import average_precision_score # type: ignore +from sklearn.preprocessing import MultiLabelBinarizer # type: ignore + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) +if not logger.hasHandlers(): + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + +def parse_args() -> argparse.Namespace: + """ + This function parses arguments + :return: current argparse.Namespace + """ + parser = argparse.ArgumentParser( + description="A script to detect novel and known domains" + ) + parser.add_argument( + "--tps-classifiers-path", + type=str, + default="data/classifier_domain_and_plm_checkpoints.pkl", + ) + parser.add_argument( + "--domain-classifiers-path", + type=str, + default="data/domain_type_predictors.pkl", + ) + parser.add_argument("--path-to-domain-comparisons", type=str) + parser.add_argument("--id", type=str) + parser.add_argument("--output-path", type=str) + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + with open(args.tps_classifiers_path, "rb") as file: + tps_classifiers = pickle.load(file) + with 
open(args.domain_classifiers_path, "rb") as file: + novel_domain_detectors, domain_type_classifiers = pickle.load(file) + with open(args.path_to_domain_comparisons, "rb") as file: + comparison_results = pickle.load(file) + comparison_results = comparison_results[args.id] + + domain_id_2_predictions = {} + for new_protein_domain_id, domain_comparisons_result in comparison_results.items(): + known_domain_id_2_tmscore = dict(domain_comparisons_result) + is_novel_predictions = [] + domain_type_2_pred_values = defaultdict(list) + for FOLD in range(5): + logger.info('Processing fold: %d', FOLD) + classifier = tps_classifiers[FOLD] + fold_domains_order = classifier.order_of_domain_modules + feat_vector = np.zeros(len(fold_domains_order)) + for i, known_domain_id in enumerate(fold_domains_order): + if known_domain_id in known_domain_id_2_tmscore: + feat_vector[i] = known_domain_id_2_tmscore[known_domain_id] + X_np = np.array(feat_vector).reshape(1, -1) + + classifier = domain_type_classifiers[FOLD] + domain_type_pred = classifier.predict_proba(X_np) + for class_name, pred_val in zip(classifier.classes_, domain_type_pred[0]): + domain_type_2_pred_values[class_name].append(pred_val) + domain_type_2_pred = {dom_type: np.mean(vals) for dom_type, vals in domain_type_2_pred_values.items()} + # max_pred = -float('inf') + # gen_type_2_pred = {} + # for domain_type, type_preds in domain_type_2_pred.items(): + # if + # gen_type_2_pred[domain_type_gen] = novel_domain_detectors[domain_type].predict_proba(X_np)[0][1] + # for type_preds in domain_type_2_pred.values(): + # max_pred = max(max_pred, np.max(type_preds)) + # domain_type_2_pred.update({"novel": 1 - max_pred}) + domain_id_2_predictions[new_protein_domain_id] = domain_type_2_pred + + with open(args.output_path, "wb") as file: + pickle.dump(domain_id_2_predictions, file) + + + + diff --git a/terpeneminer/src/structure_processing/structural_algorithms.py b/terpeneminer/src/structure_processing/structural_algorithms.py index 28c6cad..ff24f68 100644 --- a/terpeneminer/src/structure_processing/structural_algorithms.py +++ b/terpeneminer/src/structure_processing/structural_algorithms.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import Optional from uuid import uuid4 +import logging import matplotlib.pyplot as plt # type: ignore import numpy as np # type: ignore @@ -19,6 +20,14 @@ from scipy.spatial import KDTree # type: ignore from tqdm.auto import tqdm # type: ignore +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) +if not logger.hasHandlers(): + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + SUPPORTED_DOMAINS = {"alpha", "beta", "gamma", "delta", "epsilon"} DOMAIN_2_THRESHOLD = { "beta": (0.6, 50), @@ -78,7 +87,7 @@ def prepare_domain(pymol_cmd, domain_name: str) -> tuple: required_file = domain_2_standard[domain_name] if not exists_in_pymol(pymol_cmd, required_file): if not os.path.exists(f"{required_file}.pdb"): - raise FileNotFoundError + raise FileNotFoundError(f"{required_file}.pdb while being in {os.getcwd()}") pymol_cmd.load(f"{required_file}.pdb") if "_" in domain_name: @@ -179,7 +188,6 @@ def compute_full_mapping( obj_residues_set = set(map(int, file_2_all_residues[larger_obj])) residues_mapping_full = {} for domain_res in sorted_domain_residues: - # print('domain_res', domain_res) if domain_res in obj_res_2_mapped_shift: shift = obj_res_2_mapped_shift[domain_res] else: @@ 
-192,9 +200,7 @@ def compute_full_mapping( ) ) mapped_res = int(domain_res) + shift - # print('mapped_res', mapped_res) if mapped_res not in mapped_residues and mapped_res in obj_residues_set: - # print('added!') residues_mapping_full[mapped_res] = domain_res _temp1.append(domain_res) _temp2.append(mapped_res) @@ -262,7 +268,7 @@ def get_super_res_alignment( if not exists_in_pymol(pymol_cmd, larger_obj): if not os.path.exists(f"{larger_obj}.pdb"): - raise FileNotFoundError + raise FileNotFoundError(f"{larger_obj}.pdb while being in {os.getcwd()}") pymol_cmd.load(f"{larger_obj}.pdb") file_2_all_residues[larger_obj] = get_secondary_structure_residues_set( larger_obj, pymol_cmd @@ -494,7 +500,7 @@ def get_pairwise_tmscore( for filename in [filename_1, filename_2]: if not exists_in_pymol(pymol_cmd, filename): if not os.path.exists(f"{filename}.pdb"): - raise FileNotFoundError + raise FileNotFoundError(f"{filename}.pdb while being in {os.getcwd()}") pymol_cmd.load(f"{filename}.pdb") region_residues_1 = set(fill_short_gaps(set(region_1.residues_mapping.keys()))) @@ -650,6 +656,7 @@ def get_alignments( pdb_filenames = [filepath.stem for filepath in pdb_filepaths] with Pool(n_jobs) as pool: list_of_alignment_results = pool.map(align_partial, pdb_filenames) + file_2_tmscore_residues = defaultdict(list) for pdb_path, (tmscore, residues_mapping) in zip( @@ -886,7 +893,7 @@ def get_mapped_regions_with_surroundings( if not exists_in_pymol(cmd, filename): if not os.path.exists(f"{filename}.pdb"): - raise FileNotFoundError(f"{filename}.pdb") + raise FileNotFoundError(f"{filename}.pdb while being in {os.getcwd()}") cmd.load(f"{filename}.pdb") # for each mapped region, compute alpha helixes diff --git a/terpeneminer/src/structure_processing/train_domain_type_classifiers.py b/terpeneminer/src/structure_processing/train_domain_type_classifiers.py new file mode 100644 index 0000000..15bec95 --- /dev/null +++ b/terpeneminer/src/structure_processing/train_domain_type_classifiers.py @@ -0,0 +1,91 @@ +"""This script trains domain type classifiers and novelty detectors based on the TMScore distances between the detected domains and the known ones.""" + + +from sklearn.ensemble import RandomForestClassifier # type: ignore +import pickle +import logging +import numpy as np # type: ignore +from sklearn.model_selection import train_test_split # type: ignore +from sklearn.metrics import average_precision_score # type: ignore +from sklearn.preprocessing import MultiLabelBinarizer # type: ignore + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) +if not logger.hasHandlers(): + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + +if __name__ == "__main__": + + with open('data/classifier_domain_and_plm_checkpoints.pkl', 'rb') as file: + fold_classifiers = pickle.load(file) + with open("data/clustering__domain_dist_based_features.pkl", "rb") as file: + ( + feats_dom_dists, + all_ids_list_dom, + uniid_2_column_ids, + domain_module_id_2_dist_matrix_index, + ) = pickle.load(file) + with open("data/domains_subset.pkl", "rb") as file: + dom_subset, feat_indices_subset = pickle.load(file) + with open('data/domain_module_id_2_domain_type.pkl', 'rb') as file: + domain_module_id_2_domain_type = pickle.load(file) + with open("data/precomputed_tmscores.pkl", "rb") as file: + regions_ids_2_tmscore = pickle.load(file) + + + domain_type_classifiers = [] + novel_domain_detectors = [] + 
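+    # For each of the five cross-validation folds, the loop below describes every domain
+    # module unseen by that fold's TPS classifier through its vector of TM-scores against
+    # the fold's reference modules (classifier.order_of_domain_modules). Note that
+    # regions_ids_2_tmscore is keyed by the sorted pair of module ids, hence the sorted()
+    # call in the lookup.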
+    for FOLD in range(5):
+        logger.info('Processing fold: %d', FOLD)
+        classifier = fold_classifiers[FOLD]
+        new_fold_domains = [module_id for module_id in domain_module_id_2_dist_matrix_index.keys() if module_id not in classifier.order_of_domain_modules]
+        ref_types = {domain_module_id_2_domain_type[mod_id] for mod_id in classifier.order_of_domain_modules}
+        y = np.array([domain_module_id_2_domain_type[mod_id] for mod_id in new_fold_domains])
+        y_is_novel = np.array([int(dom_type not in ref_types) for dom_type in y])
+
+        X_list = []
+
+        for mod_id in new_fold_domains:
+            dists_current = []
+            for ref_mod_id in classifier.order_of_domain_modules:
+                dom_ids = tuple(sorted([mod_id, ref_mod_id]))
+                tmscore = regions_ids_2_tmscore[dom_ids]
+                dists_current.append(tmscore)
+            X_list.append(dists_current)
+        X_np = np.array(X_list)
+
+        # novelty detector
+        X_np_trn, X_np_test, y_is_novel_trn, y_is_novel_test = train_test_split(X_np, y_is_novel, stratify=y_is_novel)
+        classifier = RandomForestClassifier(500)
+        classifier.fit(X_np_trn, y_is_novel_trn)
+        y_pred = classifier.predict_proba(X_np_test)[:, 1]
+        logger.info(f'Novelty detection mAP: {average_precision_score(y_is_novel_test, y_pred):.3f}')
+        novelty_detector = RandomForestClassifier(500)
+        novelty_detector.fit(X_np, y_is_novel)
+        novel_domain_detectors.append(novelty_detector)
+
+        # domain type classifier
+        X_np_trn, X_np_test, y_trn, y_test = train_test_split(X_np, y, stratify=y)
+        label_binarizer = MultiLabelBinarizer()
+        # note: labels are wrapped into singleton lists; feeding raw strings to
+        # MultiLabelBinarizer would binarize their individual characters instead
+        y_trn = label_binarizer.fit_transform([[label] for label in y_trn])
+        y_test = label_binarizer.transform([[label] for label in y_test])
+        classifier = RandomForestClassifier(500)
+        classifier.fit(X_np_trn, y_trn)
+        y_pred = classifier.predict_proba(X_np_test)
+        y_pred_all = np.array([y_pred_class[:, 1] for y_pred_class in y_pred]).T
+        logger.info(f'Domain type classification mAP: {average_precision_score(y_test, y_pred_all):.3f}')
+
+        classifier = RandomForestClassifier(500)
+        classifier.fit(X_np, y)
+        domain_type_classifiers.append(classifier)
+
+    with open("data/domain_type_predictors.pkl", "wb") as file:
+        pickle.dump([novel_domain_detectors, domain_type_classifiers], file)

From 42ba19fb410f80e1e58068e77186631ccebd15bd Mon Sep 17 00:00:00 2001
From: Raman
Date: Mon, 4 Nov 2024 17:11:55 +0100
Subject: [PATCH 3/4] bulk commit of work on a backend app to deploy terpeneminer

---
 README.md                                     |  60 +++-
 app.py => app_faster_with_foldseek.py         | 282 +++++++++++++-----
 scripts/easy_predict_sequence_only.py         |  12 +-
 .../DomainsRandomForest/main_config.yaml      |   5 +-
 .../with_minor_reactions_foldseek/config.yaml |   3 +
 .../PlmDomainsRandomForest/main_config.yaml   |   7 +-
 .../config.yaml                               |   5 +
 .../config.yaml                               |   5 +
 .../config.yaml                               |   7 +
 .../config.yaml                               |   1 +
 .../config.yaml                               |   0
 .../config.yaml                               |   5 +
 .../config.yaml                               |   2 +
 .../config.yaml                               |   2 +
 .../config.yaml                               |   2 -
 .../config.yaml                               |   5 +
 .../config.yaml                               |   2 +
 .../config.yaml                               |   5 +
 .../config.yaml                               |   2 +
 terpeneminer/src/evaluation/evaluation.py     |   1 +
 .../experiment_runner.py                      |   2 +-
 terpeneminer/src/models/config_classes.py     |  17 ++
 .../models/domain_comparisons_randomforest.py |   8 +-
 .../models/ifaces/domains_sklearn_model.py    |  11 +-
 .../models/ifaces/features_sklearn_model.py   |  23 +-
 .../plm_domain_comparison_randomforest.py     |  42 ++-
 .../get_domains_feature_importances.py        |  73 +++--
 .../get_plm_feature_importances.py            | 172 +++++++++++
 terpeneminer/src/models/plm_randomforest.py   |  13 +-
 .../gather_classifier_checkpoints.py          |  36 ++-
 .../src/screening/gather_detections_to_csv.py |  12 +-
 .../src/screening/tps_predict_fasta.py        |  67 +++--
.../comparing_to_known_domains.py | 25 +- .../comparing_to_known_domains_foldseek.py | 2 +- .../structure_processing/domain_detections.py | 99 ++++-- .../predict_domain_types.py | 15 +- .../structural_algorithms.py | 38 ++- .../train_domain_type_classifiers.py | 72 ++--- terpeneminer/src/terpene_miner_main.py | 10 +- terpeneminer/src/utils/data.py | 38 ++- 40 files changed, 933 insertions(+), 255 deletions(-) rename app.py => app_faster_with_foldseek.py (50%) create mode 100644 terpeneminer/configs/DomainsRandomForest/with_minor_reactions_foldseek/config.yaml create mode 100644 terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning/config.yaml create mode 100644 terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset/config.yaml create mode 100644 terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml rename terpeneminer/configs/PlmDomainsRandomForest/{tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore => tps_esm-1v-subseq_with_minor_reactions_global_tuning}/config.yaml (73%) rename terpeneminer/configs/PlmDomainsRandomForest/{tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset.ignore => tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset}/config.yaml (100%) create mode 100644 terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml delete mode 100644 terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions.ignore/config.yaml create mode 100644 terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions/config.yaml create mode 100644 terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_plm_subset/config.yaml create mode 100644 terpeneminer/src/models/plm_domain_faster/get_plm_feature_importances.py diff --git a/README.md b/README.md index efb75b5..c05fe2b 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Table of contents - [6 - Evaluating performance](#6---evaluating-performance) - [7 - Visualization of performance](#7---visualization-of-performance) - [Screening large databases](#screening-large-databases) + - [TerpeneMiner deployment as a backend service](#terpeneminer-deployment-as-a-backend-service) - [Reference](#reference) @@ -92,6 +93,8 @@ pip install . ----------------------------------------- ## Quick start + +### Running sequence-based TPS detection and classification To predict using the model based on TPS language model only, put the sequences of interest into a `.fasta` file and run ```bash @@ -99,6 +102,7 @@ cd TerpeneMiner conda activate terpene_miner python scripts/easy_predict_sequence_only.py --input-fasta-path data/af_inputs_test.fasta --output-csv-path test_seqs_pred.csv --detection-threshold 0.2 --detect-precursor-synthase ``` + ----------------------------------------- ## Workflow @@ -371,7 +375,7 @@ Then, execute the notebook `notebooks/notebook_3_clustering_domains.ipynb`. 
 ```bash
 cd TerpeneMiner
 conda activate terpene_miner
-python -m terpeneminer.structure_processing.train_domain_type_classifiers > outputs/logs/domain_type_classifier_training.log 2>&1
+python -m terpeneminer.src.structure_processing.train_domain_type_classifiers > outputs/logs/domain_type_classifier_training.log 2>&1
 ```
 
 -----------------------------------------
@@ -425,7 +429,7 @@ After training a `PlmDomainsRandomForest`, to select the most important domains
 cd TerpeneMiner
 conda activate terpene_miner
 python -m terpeneminer.src.models.plm_domain_faster.get_domains_feature_importances \
-    --top-most-important-domain-features-per-model 200 --output-path "data/domains_subset.pkl" > outputs/logs/domains_subset.log 2>&1
+    --top-most-important-domain-features-per-model 200 --output-path "data/domains_subset.pkl"
 
 
@@ -463,7 +467,7 @@ bash scripts/tps_tune.sh # see the script for more details and accommodate to yo
 ```
 
 For reproducibility, we share outputs of the hyperparameter optimization
-on [zenodo](https://zenodo.org/records/10567437) as `outputs.zip`. You can simply unzip its contents to the `outputs`
+at [Zenodo](https://zenodo.org/records/10567437). You can simply unzip its contents to the `outputs`
 folder and run the subsequent evaluation steps.
 
 If you want to train a single model using the best hyperparameters found during the previously run optimization, then set `optimize_hyperparams: false` in the config and run
@@ -718,6 +722,7 @@
 cd TerpeneMiner
 conda activate terpene_miner
 python -m terpeneminer.src.screening.gather_classifier_checkpoints --output-path data/classifier_checkpoints.pkl
 ```
+Depending on how you trained the models for individual folds, you might need to set the `--use-all-folds` flag.
 
 Next, to estimate the required number of workers for the screening, run
 
@@ -749,6 +754,55 @@ python -m terpeneminer.src.screening.gather_detections_to_csv --screening-result
 
 -----------------------------------------
 
+## TerpeneMiner deployment as a backend service
+
+Prepare models for deployment:
+```bash
+cd TerpeneMiner
+conda activate terpene_miner
+terpene_miner_main --select-single-experiment run --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning
+python -m terpeneminer.src.models.plm_domain_faster.get_domains_feature_importances \
+    --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning \
+    --top-most-important-domain-features-per-model 50 --use-all-folds
+python -m terpeneminer.src.models.plm_domain_faster.get_plm_feature_importances \
+    --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning \
+    --top-most-important-plm-features-per-model 400 --use-all-folds
+terpene_miner_main --select-single-experiment run --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset
+python -m terpeneminer.src.screening.gather_classifier_checkpoints --output-path data/classifier_domain_and_plm_checkpoints.pkl --use-all-folds \
+    --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset
+python -m terpeneminer.src.structure_processing.train_domain_type_classifiers
+```
+Start the backend:
+```bash
+# specify port
+export PORT=<..>
+nohup uvicorn app_faster_with_foldseek:app --host 0.0.0.0 --port $PORT &> webserver_app.log &
+```
+For significantly slower but slightly more accurate
predictions: +```bash +cd TerpeneMiner +conda activate terpene_miner +terpene_miner_main --select-single-experiment run --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_with_minor_reactions_global_tuning +python -m terpeneminer.src.models.plm_domain_faster.get_domains_feature_importances \ + --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_with_minor_reactions_global_tuning \ + --top-most-important-domain-features-per-model 50 --use-all-folds +python -m terpeneminer.src.models.plm_domain_faster.get_plm_feature_importances \ + --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_with_minor_reactions_global_tuning \ + --top-most-important-plm-features-per-model 400 --use-all-folds +terpene_miner_main --select-single-experiment run --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset +python -m terpeneminer.src.screening.gather_classifier_checkpoints --output-path data/classifier_domain_and_plm_checkpoints.pkl --use-all-folds \ + --model PlmDomainsRandomForest --model-version tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset +python -m terpeneminer.src.structure_processing.train_domain_type_classifiers +``` +and then start the backend: +```bash +# specify port +export PORT=<..> +nohup uvicorn app:app --host 0.0.0.0 --port $PORT &> webserver_app.log & +``` + +----------------------------------------- + # Reference > Samusevich, R., Hebra, T. et al. Highly accurate discovery of terpene synthases powered by machine learning reveals diff --git a/app.py b/app_faster_with_foldseek.py similarity index 50% rename from app.py rename to app_faster_with_foldseek.py index d106022..790951b 100644 --- a/app.py +++ b/app_faster_with_foldseek.py @@ -1,3 +1,4 @@ +from functools import partial from uuid import uuid4 from fastapi import FastAPI, File, UploadFile, BackgroundTasks, Form @@ -8,7 +9,14 @@ from shutil import copyfile, rmtree import logging import subprocess +from dataclasses import dataclass import re +import numpy as np +from terpeneminer.src.embeddings_extraction.esm_transformer_utils import ( + compute_embeddings, + get_model_and_tokenizer, +) +from terpeneminer.src.utils.data import extract_sequences_from_pdb logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -22,31 +30,54 @@ ] ) +@dataclass +class MotifDetection: + start: int + end: int + motif: str + class_tps: str + +model, batch_converter, alphabet = get_model_and_tokenizer( + "esm-1v-finetuned-subseq", return_alphabet=True + ) + +compute_embeddings_partial = partial( + compute_embeddings, + bert_model=model, + converter=batch_converter, + padding_idx=alphabet.padding_idx, + model_repr_layer=33, + max_len=1022, +) + +with open('data/classifier_domain_and_plm_checkpoints.pkl', 'rb') as file: + fold_classifiers = pickle.load(file) + +with open('data/classifier_plm_checkpoints.pkl', 'rb') as file: + fold_plm_classifiers = pickle.load(file) + # Create FastAPI app instance app = FastAPI() -@app.post("/detect_domains/") -async def upload_file(file: UploadFile = File(...), - is_bfactor_confidence: bool = Form(...)): - # Read the contents of the uploaded file - file_contents = await file.read() +def detect_domains(file_contents, filename, is_bfactor_confidence): # Define the path where the .pdb file will be saved pdb_directory_temp = Path("_temp") if not pdb_directory_temp.exists(): pdb_directory_temp.mkdir() af_source_path = Path("/home/samusevich/TerpeneMiner/data/alphafold_structs") - 
for pdb_standard_id in ["1ps1", "5eat", "3p5r", "P48449"]: + for pdb_standard_id in ["1ps1", "5eat", "3p5r", "P48449"]: pdb_standard_file_path = af_source_path / f"{pdb_standard_id}.pdb" copyfile(pdb_standard_file_path, pdb_directory_temp / f"{pdb_standard_id}.pdb") # Getting the ID - pdb_id = file.filename.split(".")[0] + pdb_id = filename.split(".")[0] pdb_id = re.sub(r'\(.*?\)', '', pdb_id) pdb_id = "".join(pdb_id.replace("-", "").split()) # Define the path where the .pdb file will be saved pdb_file_path = pdb_directory_temp / f"{pdb_id}.pdb" + pdb_file_to_delete_afterwards = not pdb_file_path.exists() # Saving the ID into a csv file id_filepath = f'{pdb_directory_temp / "dummy_id.csv"}' @@ -54,12 +85,13 @@ async def upload_file(file: UploadFile = File(...), file.writelines(f"ID\n{pdb_id}\n") # Save the content as a .pdb file - with open(pdb_file_path, "wb") as pdb_file: - pdb_file.write(file_contents) + if pdb_file_to_delete_afterwards: + with open(pdb_file_path, "wb") as pdb_file: + pdb_file.write(file_contents) temp_filepath_name = Path("data/alphafold_structs") / f"{pdb_id}.pdb" + temp_filepath_name_to_delete = not temp_filepath_name.exists() if not temp_filepath_name.exists(): copyfile(pdb_file_path, temp_filepath_name) - temp_filepath_name_to_delete = not temp_filepath_name.exists() domain_detections_path = f"_temp/filename_2_detected_domains_completed_confident_{pdb_id}.pkl" detected_domain_structures_root = Path("_temp/detected_domains") @@ -80,6 +112,32 @@ async def upload_file(file: UploadFile = File(...), "--do-not-store-intermediate-files" ) + return pdb_id, pdb_file_path, temp_filepath_name, id_filepath, domain_detections_path, detected_domain_structures_root, pdb_file_to_delete_afterwards, temp_filepath_name_to_delete + + +def detect_known_motifs(sequence: str) -> list[MotifDetection]: + simple_regex = "DD..D" + motif_detections = [] + for x in re.finditer(simple_regex, sequence): + motif_detections.append(MotifDetection(x.start() + 1, x.end() + 1, "DDxxD", 'class I')) + + simple_regex = "[ND]D..[ST]...E" + for x in re.finditer(simple_regex, sequence): + motif_detections.append(MotifDetection(x.start() + 1, x.end() + 1, "NSE/DTE", 'class I')) + + simple_regex = "D.DD" + for x in re.finditer(simple_regex, sequence): + motif_detections.append(MotifDetection(x.start() + 1, x.end() + 1, "DxDD", 'class II')) + return motif_detections + +@app.post("/detect_domains/") +async def upload_file(file: UploadFile = File(...), + is_bfactor_confidence: bool = Form(...)): + # Read the contents of the uploaded file + file_contents = await file.read() + + pdb_id, pdb_file_path, temp_filepath_name, id_filepath, domain_detections_path, detected_domain_structures_root, pdb_file_to_delete_afterwards, temp_filepath_name_to_delete = detect_domains(file_contents, file.filename, is_bfactor_confidence) + with open(domain_detections_path, "rb") as file: detected_domains = pickle.load(file) @@ -173,7 +231,16 @@ async def upload_file(file: UploadFile = File(...), os.remove(comparison_results_path) os.remove(domain_predictions_path) - os.remove(pdb_file_path) + # detecting motifs + chain_2_seq = extract_sequences_from_pdb(pdb_file_path) + input_seq = list(set(chain_2_seq.values())) + if len(input_seq) > 1: + logger.warning(f"Multiple chains in the file {pdb_file_path} are not supported") + input_seq = input_seq[0] + motif_detections = detect_known_motifs(input_seq) + + if pdb_file_to_delete_afterwards: + os.remove(pdb_file_path) if temp_filepath_name_to_delete: os.remove(temp_filepath_name) 
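+    # note: only artifacts created for this request are deleted; structure files that
+    # existed before the upload are kept (tracked via pdb_file_to_delete_afterwards and
+    # temp_filepath_name_to_delete)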
os.remove(id_filepath) @@ -181,6 +248,7 @@ async def upload_file(file: UploadFile = File(...), rmtree(detected_domain_structures_root) return {"domains": detected_domains, "secondary_structure_residues": secondary_structure_res, + "motif_detections": motif_detections if detected_domains else None, "comparison_to_known_domains": comparison_results[pdb_id] if detected_domains else None, "domain_type_predictions": domain_id_2_predictions if detected_domains else None, "aligned_pdb_filepaths": domain_id_2_aligned_pdb if detected_domains else None} @@ -205,54 +273,8 @@ async def upload_file(file: UploadFile = File(...), # Read the contents of the uploaded file file_contents = await file.read() - # Define the path where the .pdb file will be saved - pdb_directory_temp = Path("_temp") - if not pdb_directory_temp.exists(): - pdb_directory_temp.mkdir() - af_source_path = Path("/home/samusevich/TerpeneMiner/data/alphafold_structs") - for pdb_standard_id in ["1ps1", "5eat", "3p5r", "P48449"]: - pdb_standard_file_path = af_source_path / f"{pdb_standard_id}.pdb" - copyfile(pdb_standard_file_path, pdb_directory_temp / f"{pdb_standard_id}.pdb") - - # Getting the ID - pdb_id = file.filename.split(".")[0] - pdb_id = re.sub(r'\(.*?\)', '', pdb_id) - pdb_id = "".join(pdb_id.replace("-", "").split()) - - # Define the path where the .pdb file will be saved - pdb_file_path = pdb_directory_temp / f"{pdb_id}.pdb" - - # Saving the ID into a csv file - id_filepath = f'{pdb_directory_temp / "dummy_id.csv"}' - with open(id_filepath, "a") as file: - file.writelines(f"ID\n{pdb_id}\n") - - # Save the content as a .pdb file - with open(pdb_file_path, "wb") as pdb_file: - pdb_file.write(file_contents) - temp_filepath_name = Path("data/alphafold_structs") / f"{pdb_id}.pdb" - if not temp_filepath_name.exists(): - copyfile(pdb_file_path, temp_filepath_name) - temp_filepath_name_to_delete = not temp_filepath_name.exists() - - domain_detections_path = f"_temp/filename_2_detected_domains_completed_confident_{pdb_id}.pkl" - detected_domain_structures_root = Path("_temp/detected_domains") - if not detected_domain_structures_root.exists(): - detected_domain_structures_root.mkdir() - os.system( - "python -m terpeneminer.src.structure_processing.domain_detections " - f'--needed-proteins-csv-path "{id_filepath}" ' - "--csv-id-column ID " - "--n-jobs 16 " - "--input-directory-with-structures _temp " - f"{'--is-bfactor-confidence ' if is_bfactor_confidence else ''}" - f'--detections-output-path "{domain_detections_path}" ' - f'--detected-regions-root-path _temp ' - f'--domains-output-path "{detected_domain_structures_root}" ' - "--store-domains " - "--recompute-existing-secondary-structure-residues " - "--do-not-store-intermediate-files" - ) + pdb_id, pdb_file_path, temp_filepath_name, id_filepath, domain_detections_path, detected_domain_structures_root, pdb_file_to_delete_afterwards, temp_filepath_name_to_delete = detect_domains( + file_contents, file.filename, is_bfactor_confidence) with open(domain_detections_path, "rb") as file: detected_domains = pickle.load(file) @@ -261,22 +283,142 @@ async def upload_file(file: UploadFile = File(...), if detected_domains: current_computation_id = uuid4() comparison_results_path = f"_temp/filename_2_regions_vs_known_reg_dists_{current_computation_id}.pkl" - os.system( - "python -m terpeneminer.src.structure_processing.comparing_to_known_domains " - "--input-directory-with-structures data/alphafold_structs/ " - "--n-jobs 16 " - f'--domain-detections-path "{domain_detections_path}" ' - 
"--domain-detections-residues-path _temp/file_2_all_residues.pkl " - "--path-to-all-known-domains data/alphafold_structs/regions_completed_very_confident_all_ALL.pkl " - "--path-to-known-domains-subset data/domains_subset.pkl " - f'--pdb-filepath "{temp_filepath_name}" ' - f'--output-path "{comparison_results_path}"') + os.system("python -m terpeneminer.src.structure_processing.comparing_to_known_domains_foldseek " + f'--known-domain-structures-root data/detected_domains/all ' + f'--detected-domain-structures-root "{detected_domain_structures_root}" ' + '--path-to-known-domains-subset data/domains_subset.pkl ' + f'--output-path "{comparison_results_path}" ' + f'--pdb-id "{pdb_id}"') logger.info("Compared detected domains to the known ones!") with open(comparison_results_path, "rb") as file: comparison_results = pickle.load(file) + os.remove(comparison_results_path) + else: + comparison_results = None + + logger.info("Computing embeddings..") + chain_2_seq = extract_sequences_from_pdb(pdb_file_path) + input_seq = list(set(chain_2_seq.values())) + if len(input_seq) > 1: + logger.warning(f"Multiple chains in the file {pdb_file_path} are not supported") + input_seq = input_seq[:1] + ( + enzyme_encodings_np_batch, + _, + ) = compute_embeddings_partial(input_seqs=input_seq) + + logger.info("Predicting TPS substrates..") + + predictions = [] + n_samples = len(enzyme_encodings_np_batch) + assert n_samples == 1, "Currently, the backend supports only one sample at a time" + for classifier_i, classifier in enumerate(fold_classifiers): + logger.info(f"Predicting with classifier {classifier_i + 1}/{len(fold_classifiers)}..") + logger.info("Comparing domain detections to the selected known examples") + dom_features_count = sum(map(len, classifier.domain_type_2_order_of_domain_modules.values())) + dom_feat = np.zeros(dom_features_count) + if comparison_results is not None: + current_comparison_results = comparison_results[pdb_id] + was_alpha_observed = False + for domain_detection in detected_domains[pdb_id]: + domain_type = domain_detection.domain + detection_id = domain_detection.module_id + known_domain_id_2_tmscore = dict(current_comparison_results[detection_id]) + if domain_type == 'alpha': + if not was_alpha_observed: + alpha_idx = 1 + was_alpha_observed = True + else: + alpha_idx = 2 + domain_type = f"alpha{alpha_idx}" + for known_module_id, dom_feat_idx in classifier.domain_type_2_order_of_domain_modules[domain_type]: + # assert known_module_id in known_domain_id_2_tmscore, f"Known module {known_module_id} not found in comparison results" + dom_feat[dom_feat_idx] = known_domain_id_2_tmscore.get(known_module_id, 0) + if np.max(dom_feat) < 0.4: + logger.warning("No meaningful domain comparisons. Skipping the model.. ") + if hasattr(classifier, "domain_feature_novelty_detector") and getattr(classifier, + "domain_feature_novelty_detector") is not None: + novelty_prediction = classifier.domain_feature_novelty_detector.predict(dom_feat)[0] + logger.warning( + f"Novelty prediction would have been {novelty_prediction}") + continue + dom_feat = 1 - dom_feat.reshape(1, -1) + if hasattr(classifier, "domain_feature_novelty_detector") and getattr(classifier, "domain_feature_novelty_detector") is not None: + novelty_prediction = classifier.domain_feature_novelty_detector.predict(dom_feat)[0] + if novelty_prediction == -1: + logger.warning("Data drift detected in domain comparisons. 
Skipping the model..")
+                    continue
+            if classifier.plm_feat_indices_subset is not None:
+                emb_plm = np.apply_along_axis(lambda i: i[classifier.plm_feat_indices_subset], 1, enzyme_encodings_np_batch)
+            else:
+                emb_plm = enzyme_encodings_np_batch
+            emb = np.concatenate((emb_plm, dom_feat), axis=1)
+
+            y_pred_proba = classifier.predict_proba(emb)
+            for sample_i in range(n_samples):
+                predictions_raw = {}
+                for class_i, class_name in enumerate(classifier.classes_):
+                    if class_name != "Unknown":
+                        predictions_raw[class_name] = y_pred_proba[class_i][sample_i, 1]
+                if len(predictions) == 0:
+                    predictions.append(
+                        {
+                            class_name: [value]
+                            for class_name, value in predictions_raw.items()
+                        }
+                    )
+                else:
+                    for class_name, value in predictions_raw.items():
+                        predictions[sample_i][class_name].append(value)
+    if len(predictions) == 0:
+        logger.warning("Falling back to PLM features only due to severe data drift in domain comparisons")
+        predictions = []
+        for classifier_i, classifier in enumerate(fold_plm_classifiers):
+            logger.info(f"Predicting with plm classifier {classifier_i + 1}/{len(fold_plm_classifiers)}..")
+            if hasattr(classifier, "plm_feature_novelty_detector") and getattr(classifier, "plm_feature_novelty_detector") is not None:
+                novelty_prediction = classifier.plm_feature_novelty_detector.predict(enzyme_encodings_np_batch)[0]
+                logger.warning(f"PLM embedding novelty prediction is {novelty_prediction}")
+            y_pred_proba = classifier.predict_proba(enzyme_encodings_np_batch)
+            for sample_i in range(n_samples):
+                predictions_raw = {}
+                for class_i, class_name in enumerate(classifier.classes_):
+                    if class_name != "Unknown":
+                        predictions_raw[class_name] = y_pred_proba[class_i][sample_i, 1]
+                if classifier_i == 0:
+                    predictions.append(
+                        {
+                            class_name: [value]
+                            for class_name, value in predictions_raw.items()
+                        }
+                    )
+                else:
+                    for class_name, value in predictions_raw.items():
+                        predictions[sample_i][class_name].append(value)
+
+    logger.info("Averaging predictions over all models..")
+    predictions_avg = []
+    for prediction in predictions:
+        predictions_avg.append(
+            {
+                class_name: np.mean(values)
+                for class_name, values in prediction.items()
+            }
+        )
+    if pdb_file_to_delete_afterwards:
+        os.remove(pdb_file_path)
+    if temp_filepath_name_to_delete:
+        os.remove(temp_filepath_name)
+    os.remove(id_filepath)
+    os.remove(domain_detections_path)
+    rmtree(detected_domain_structures_root)
+    return {'predictions': predictions_avg}
+
-    print(comparison_results)
diff --git a/scripts/easy_predict_sequence_only.py b/scripts/easy_predict_sequence_only.py
index 59960f2..fecaf27 100644
--- a/scripts/easy_predict_sequence_only.py
+++ b/scripts/easy_predict_sequence_only.py
@@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--output-csv-path", type=str, default="trembl_screening")
     parser.add_argument("--detection-threshold", type=float, default=0.3)
     parser.add_argument("--detect-precursor-synthases", action="store_true")
-
+    parser.add_argument("--model", type=str, default="esm-1v-finetuned-subseq")
     return parser.parse_args()
 
 
@@ -33,7 +33,13 @@ def main():
     plm_chkpt_path = Path("data/plm_checkpoints")
     if not plm_chkpt_path.exists():
         plm_chkpt_path.mkdir(parents=True)
-    plm_path = plm_chkpt_path / "checkpoint-tps-esm1v-t33-subseq.ckpt"
+    assert args.model in {
+        "esm-1v",
+        "esm-1v-finetuned-subseq",
+        "ankh_tps",
+        "ankh_base",
+    }, f"Model {args.model} is not supported.
Choose between esm-1v, esm-1v-finetuned-subseq, ankh_base, and ankh_tps" + plm_path = plm_chkpt_path / ("checkpoint-tps-esm1v-t33-subseq.ckpt" if args.model == "esm-1v-finetuned-subseq" else "tps_ankh_lr=5e-05_bs=32.pth") if not plm_path.exists(): logger.info("Downloading TPS language model checkpoint..") url = "https://drive.google.com/uc?id=1jU76oUl0-CmiB9m3XhaKmI2HorFhyxC7" @@ -48,7 +54,7 @@ def main(): if not Path(intermediate_outputs_root).exists(): Path(intermediate_outputs_root).mkdir(parents=True) os.system( - "python -m terpeneminer.src.screening.tps_predict_fasta --model esm-1v-finetuned-subseq" + f"python -m terpeneminer.src.screening.tps_predict_fasta --model {args.model}" f" --fasta-path {args.input_fasta_path} --output-root {intermediate_outputs_root}" f" --detect-precursor-synthases {args.detect_precursor_synthases}" f" --detection-threshold {args.detection_threshold}" diff --git a/terpeneminer/configs/DomainsRandomForest/main_config.yaml b/terpeneminer/configs/DomainsRandomForest/main_config.yaml index ec28b72..89d6690 100644 --- a/terpeneminer/configs/DomainsRandomForest/main_config.yaml +++ b/terpeneminer/configs/DomainsRandomForest/main_config.yaml @@ -11,7 +11,7 @@ class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "isTPS"] -optimize_hyperparams: true +optimize_hyperparams: false random_state: 0 n_calls_hyperparams_opt: 350 hyperparam_dimensions: @@ -42,7 +42,8 @@ neg_val: "Unknown" save_trained_model: true negatives_sample_path: "data/sampled_id_2_seq.pkl" tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: true +per_class_optimization: false per_class_with_multilabel_regularization: 0 reuse_existing_partial_results: false load_per_class_params_from: "" +foldseek_distances: false diff --git a/terpeneminer/configs/DomainsRandomForest/with_minor_reactions_foldseek/config.yaml b/terpeneminer/configs/DomainsRandomForest/with_minor_reactions_foldseek/config.yaml new file mode 100644 index 0000000..ee3486f --- /dev/null +++ b/terpeneminer/configs/DomainsRandomForest/with_minor_reactions_foldseek/config.yaml @@ -0,0 +1,3 @@ +include: ../main_config.yaml +tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" +foldseek_distances: true diff --git a/terpeneminer/configs/PlmDomainsRandomForest/main_config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/main_config.yaml index c56f19b..216dfce 100644 --- a/terpeneminer/configs/PlmDomainsRandomForest/main_config.yaml +++ b/terpeneminer/configs/PlmDomainsRandomForest/main_config.yaml @@ -11,7 +11,7 @@ class_names: ["CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O", "isTPS"] -optimize_hyperparams: true +optimize_hyperparams: false random_state: 0 n_calls_hyperparams_opt: 350 hyperparam_dimensions: @@ -37,12 +37,13 @@ n_estimators: 100 max_depth: 1000 n_jobs: -1 class_weight: None -max_train_negs_proportion: 0.98 +max_train_negs_proportion: 0.5 neg_val: "Unknown" save_trained_model: true negatives_sample_path: "data/sampled_id_2_seq.pkl" 
tps_cleaned_csv_path: "data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv" -per_class_optimization: true +per_class_optimization: false per_class_with_multilabel_regularization: 0 reuse_existing_partial_results: false load_per_class_params_from: "" +foldseek_distances: false diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning/config.yaml new file mode 100644 index 0000000..e2f6507 --- /dev/null +++ b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning/config.yaml @@ -0,0 +1,5 @@ +include: ../main_config.yaml +representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5" +per_class_optimization: false +foldseek_distances: true +requires_multioutputwrapper_for_multilabel: true diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset/config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset/config.yaml new file mode 100644 index 0000000..cc44b5c --- /dev/null +++ b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset/config.yaml @@ -0,0 +1,5 @@ +include: ../main_config.yaml +representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5" +per_class_optimization: false +n_calls_hyperparams_opt: 350 +foldseek_distances: true diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml new file mode 100644 index 0000000..fe81cbd --- /dev/null +++ b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_foldseek_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml @@ -0,0 +1,7 @@ +include: ../main_config.yaml +representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5" +per_class_optimization: false +n_calls_hyperparams_opt: 350 +foldseek_distances: true +requires_multioutputwrapper_for_multilabel: false + diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore/config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml similarity index 73% rename from terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore/config.yaml rename to terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml index 548d58f..71fc83c 100644 --- a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning.ignore/config.yaml +++ b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml @@ -1,3 +1,4 @@ include: ../main_config.yaml representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5" per_class_optimization: false +requires_multioutputwrapper_for_multilabel: true diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset.ignore/config.yaml 
b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset/config.yaml similarity index 100% rename from terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset.ignore/config.yaml rename to terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset/config.yaml diff --git a/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml new file mode 100644 index 0000000..1565550 --- /dev/null +++ b/terpeneminer/configs/PlmDomainsRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_domains_subset_plm_subset/config.yaml @@ -0,0 +1,5 @@ +include: ../main_config.yaml +representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5" +per_class_optimization: false +n_calls_hyperparams_opt: 350 +requires_multioutputwrapper_for_multilabel: false diff --git a/terpeneminer/configs/PlmRandomForest/ankh_base_with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/PlmRandomForest/ankh_base_with_minor_reactions_global_tuning/config.yaml index 62fcb98..abfd4aa 100644 --- a/terpeneminer/configs/PlmRandomForest/ankh_base_with_minor_reactions_global_tuning/config.yaml +++ b/terpeneminer/configs/PlmRandomForest/ankh_base_with_minor_reactions_global_tuning/config.yaml @@ -1,3 +1,5 @@ include: ../main_config.yaml representations_path: "data/gathered_embs_ankh_base_embs_avg.h5" per_class_optimization: false +requires_multioutputwrapper_for_multilabel: false +optimize_hyperparams: false \ No newline at end of file diff --git a/terpeneminer/configs/PlmRandomForest/esm-1v_with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/PlmRandomForest/esm-1v_with_minor_reactions_global_tuning/config.yaml index b6ccfc5..2577b4e 100644 --- a/terpeneminer/configs/PlmRandomForest/esm-1v_with_minor_reactions_global_tuning/config.yaml +++ b/terpeneminer/configs/PlmRandomForest/esm-1v_with_minor_reactions_global_tuning/config.yaml @@ -1,3 +1,5 @@ include: ../main_config.yaml representations_path: "data/gathered_embs_esm-1v_embs_avg.h5" per_class_optimization: false +requires_multioutputwrapper_for_multilabel: false +optimize_hyperparams: false \ No newline at end of file diff --git a/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions.ignore/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions.ignore/config.yaml deleted file mode 100644 index 5ac33fe..0000000 --- a/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions.ignore/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -include: ../main_config.yaml -representations_path: "data/gathered_embs_ankh_tps_embs_avg.h5" diff --git a/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions/config.yaml new file mode 100644 index 0000000..97d8c89 --- /dev/null +++ b/terpeneminer/configs/PlmRandomForest/tps_ankh_base_with_minor_reactions/config.yaml @@ -0,0 +1,5 @@ +include: ../main_config.yaml +representations_path: "data/gathered_embs_ankh_tps_embs_avg.h5" +per_class_optimization: false +requires_multioutputwrapper_for_multilabel: true +optimize_hyperparams: false \ No newline at end of file diff --git 
a/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml
index 548d58f..c4249bf 100644
--- a/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml
+++ b/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning/config.yaml
@@ -1,3 +1,5 @@
 include: ../main_config.yaml
 representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5"
 per_class_optimization: false
+requires_multioutputwrapper_for_multilabel: false
+optimize_hyperparams: false
\ No newline at end of file
diff --git a/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_plm_subset/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_plm_subset/config.yaml
new file mode 100644
index 0000000..5d49c61
--- /dev/null
+++ b/terpeneminer/configs/PlmRandomForest/tps_esm-1v-subseq_with_minor_reactions_global_tuning_plm_subset/config.yaml
@@ -0,0 +1,5 @@
+include: ../main_config.yaml
+representations_path: "data/gathered_embs_esm-1v-finetuned-subseq_embs_avg.h5"
+per_class_optimization: false
+requires_multioutputwrapper_for_multilabel: true
+optimize_hyperparams: false
\ No newline at end of file
diff --git a/terpeneminer/configs/PlmRandomForest/tps_esm-1v_with_minor_reactions_global_tuning/config.yaml b/terpeneminer/configs/PlmRandomForest/tps_esm-1v_with_minor_reactions_global_tuning/config.yaml
index 1ea945a..569a02b 100644
--- a/terpeneminer/configs/PlmRandomForest/tps_esm-1v_with_minor_reactions_global_tuning/config.yaml
+++ b/terpeneminer/configs/PlmRandomForest/tps_esm-1v_with_minor_reactions_global_tuning/config.yaml
@@ -2,3 +2,5 @@ include: ../main_config.yaml
 representations_path: "data/gathered_embs_esm-1v-finetuned_embs_avg.h5"
 per_class_optimization: false
 n_calls_hyperparams_opt: 350
+requires_multioutputwrapper_for_multilabel: true
+optimize_hyperparams: false
\ No newline at end of file
diff --git a/terpeneminer/src/evaluation/evaluation.py b/terpeneminer/src/evaluation/evaluation.py
index d499d6c..a501427 100644
--- a/terpeneminer/src/evaluation/evaluation.py
+++ b/terpeneminer/src/evaluation/evaluation.py
@@ -147,6 +147,7 @@ def eval_experiment(
             average_precision = average_precision_score(
                 y_true_category, y_pred_category
             )
+            print(f"Mean prediction for {class_name_to_record}: positives {y_pred_category[y_true_category == 1].mean():.3f}, negatives {y_pred_category[y_true_category == 0].mean():.3f}")
             mccf1 = summary_mccf1(
                 y_true_category, y_pred_category
             )["mccf1_metric"]
diff --git a/terpeneminer/src/experiments_orchestration/experiment_runner.py b/terpeneminer/src/experiments_orchestration/experiment_runner.py
index bf534d6..4b71090 100644
--- a/terpeneminer/src/experiments_orchestration/experiment_runner.py
+++ b/terpeneminer/src/experiments_orchestration/experiment_runner.py
@@ -264,7 +264,7 @@ def run_experiment(experiment_info: ExperimentInfo, load_hyperparameters: bool =
         elif (fold_root_dir / "all_classes").exists():
             fold_class_path = fold_root_dir / "all_classes"
         else:
-            raise ValueError("No fold_class_path found")
+            raise ValueError(f"No fold_class_path found for class {class_name} in folder {fold_root_dir}")
         previous_results = list(
             fold_class_path.glob(
                 "*/hyperparameters_optimization/optimization_results_detailed_*.pkl"
diff --git a/terpeneminer/src/models/config_classes.py b/terpeneminer/src/models/config_classes.py
index 4a79ec6..f42893b 100644
--- a/terpeneminer/src/models/config_classes.py
+++ b/terpeneminer/src/models/config_classes.py
@@ -37,11 +37,28 @@ class FeaturesRandomForestConfig(SklearnBaseConfig):
     per_class_with_multilabel_regularization: int
 
 
+@dataclass
+class DomainFeaturesRandomForestConfig(FeaturesRandomForestConfig):
+    """
+    A data class to store model attributes
+    """
+
+    foldseek_distances: bool
+
+
 @dataclass
 class EmbRandomForestConfig(EmbSklearnBaseConfig, FeaturesRandomForestConfig):
     """
     A data class to store the corresponding model attributes
     """
+    requires_multioutputwrapper_for_multilabel: bool = False
+
+
+@dataclass
+class EmbWithDomainsRandomForestConfig(EmbSklearnBaseConfig, DomainFeaturesRandomForestConfig):
+    """
+    A data class to store the corresponding model attributes
+    """
+    requires_multioutputwrapper_for_multilabel: bool = False
 
 
 @dataclass
diff --git a/terpeneminer/src/models/domain_comparisons_randomforest.py b/terpeneminer/src/models/domain_comparisons_randomforest.py
index 128b8f1..7c9a80b 100644
--- a/terpeneminer/src/models/domain_comparisons_randomforest.py
+++ b/terpeneminer/src/models/domain_comparisons_randomforest.py
@@ -5,7 +5,7 @@
 from sklearn.ensemble import RandomForestClassifier  # type: ignore
 
 from terpeneminer.src.models.ifaces import DomainsSklearnModel
-from terpeneminer.src.models.config_classes import FeaturesRandomForestConfig
+from terpeneminer.src.models.config_classes import DomainFeaturesRandomForestConfig
 
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -17,15 +17,15 @@ class DomainsRandomForest(DomainsSklearnModel):
 
     def __init__(
         self,
-        config: FeaturesRandomForestConfig,
+        config: DomainFeaturesRandomForestConfig,
     ):
         super().__init__(config=config)
         self.classifier_class = RandomForestClassifier
 
     @classmethod
-    def config_class(cls) -> Type[FeaturesRandomForestConfig]:
+    def config_class(cls) -> Type[DomainFeaturesRandomForestConfig]:
         """
         A getter of the model-specific config class
         :return: A dataclass for config storage
         """
-        return FeaturesRandomForestConfig
+        return DomainFeaturesRandomForestConfig
diff --git a/terpeneminer/src/models/ifaces/domains_sklearn_model.py b/terpeneminer/src/models/ifaces/domains_sklearn_model.py
index 94e4bf5..098c41f 100644
--- a/terpeneminer/src/models/ifaces/domains_sklearn_model.py
+++ b/terpeneminer/src/models/ifaces/domains_sklearn_model.py
@@ -32,7 +32,8 @@ def compare_domains_to_known_instances(
     for trn_id in trn_uni_ids:
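+        # each training protein contributes the columns of the precomputed domain-distance
+        # matrix that correspond to its detected domains, so only features observed during
+        # training are used for this fold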
         allowed_feat_indices.extend(model.uniid_2_column_ids[trn_id])
     if domain_indices_subset is not None:
-        allowed_feat_indices = list(set(allowed_feat_indices) & domain_indices_subset)
+        allowed_feat_indices = list(set(allowed_feat_indices).intersection(domain_indices_subset))
+    allowed_feat_indices = sorted(allowed_feat_indices)
     features_df_domain_detections = pd.DataFrame(
         {
             model.config.id_col_name: model.all_ids_list_dom,
@@ -55,7 +56,11 @@ def __init__(self, config: SklearnBaseConfig):
         for param, value in config.__dict__.items():
             setattr(self, param, value)
         self.config = config
-        with open("data/clustering__domain_dist_based_features.pkl", "rb") as file:
+        if hasattr(config, "foldseek_distances") and config.foldseek_distances:
+            domain_dist_path = "data/clustering__domain_dist_based_features_foldseek.pkl"
+        else:
+            domain_dist_path = "data/clustering__domain_dist_based_features.pkl"
+        with open(domain_dist_path, "rb") as file:
             (
                 self.feats_dom_dists,
                 self.all_ids_list_dom,
                 self.uniid_2_column_ids,
@@ -77,7 +82,7 @@ def _setup_features_df_for_current_data(self, input_df: pd.DataFrame):
             {
                 "Uniprot ID": ids_without_domain_detections,
                 "Emb": [
-                    np.zeros(len(self.allowed_feat_indices))
+                    np.ones(len(self.allowed_feat_indices))
                     for _ in range(len(ids_without_domain_detections))
                 ],
             }
diff --git a/terpeneminer/src/models/ifaces/features_sklearn_model.py b/terpeneminer/src/models/ifaces/features_sklearn_model.py
index 734f887..11e8a37 100644
--- a/terpeneminer/src/models/ifaces/features_sklearn_model.py
+++ b/terpeneminer/src/models/ifaces/features_sklearn_model.py
@@ -8,6 +8,8 @@
 import sklearn.base  # type: ignore
 from sklearn.multioutput import MultiOutputClassifier  # type: ignore
 from sklearn.preprocessing import MultiLabelBinarizer  # type: ignore
+from sklearn.calibration import CalibratedClassifierCV  # type: ignore
+from sklearn.ensemble import RandomForestClassifier  # type: ignore
 
 from .config_baseclasses import SklearnBaseConfig
 from .model_baseclass import BaseModel
@@ -85,9 +87,17 @@ def fit_core(self, train_df: pd.DataFrame, class_name: str = None):
             )
         except AttributeError:
             requires_multioutputwrapper_for_multilabel = False
+        model_params = self.get_model_specific_params()
+        # if self.classifier_class == RandomForestClassifier:
+        #     model_params["class_weight"] = "balanced"
+        #     logger.info("Balanced class weights are used")
         if requires_multioutputwrapper_for_multilabel:
+            logger.info("Fitting the model with MultiOutputClassifier...")
+            # a calibrated variant, MultiOutputClassifier(CalibratedClassifierCV(
+            # self.classifier_class(**model_params), cv=10)), was also used in experiments
             self.classifier = MultiOutputClassifier(
-                self.classifier_class(**self.get_model_specific_params())
+                self.classifier_class(**model_params)
             )
         else:
             self.classifier = self.classifier_class(
@@ -171,6 +181,9 @@ def predict_proba(
         test_df = val_df.merge(
             self.features_df, on=self.config.id_col_name, copy=False, how="left"
         ).set_index(self.config.id_col_name)
+
         logger.info(
             "In predict_proba(), features DF shape is: %d x %d",
             *self.features_df.shape,
         )
@@ -211,6 +224,11 @@
                     if isinstance(y_pred_proba, list)
                     else y_pred_proba[:, class_i]
                 )
+
         else:
             for class_i, class_name in enumerate(self.config.class_names):
                 if (
diff --git a/terpeneminer/src/models/plm_domain_comparison_randomforest.py b/terpeneminer/src/models/plm_domain_comparison_randomforest.py
index 27e4eb4..f6752d2 100644
--- a/terpeneminer/src/models/plm_domain_comparison_randomforest.py
+++ b/terpeneminer/src/models/plm_domain_comparison_randomforest.py
@@ -4,9 +4,11 @@
 
 import numpy as np  # type: ignore
 import pandas as pd  # type: ignore
+import logging
+from sklearn.ensemble import IsolationForest  # type: ignore
 
 from terpeneminer.src.models.config_classes import (
-    EmbRandomForestConfig,
+    EmbWithDomainsRandomForestConfig,
     EmbMLPConfig,
     EmbLogisticRegressionConfig,
 )
@@ -17,6 +19,10 @@
 )
 
 
+logger = logging.getLogger(__file__)
+logger.setLevel(logging.INFO)
+
+
 # pylint: disable=R0903, R0901
 class PlmDomainsRandomForest(PlmRandomForest):
     """
@@ -25,13 +31,17 @@ class PlmDomainsRandomForest(PlmRandomForest):
 
     def __init__(
         self,
-        config: EmbRandomForestConfig | EmbMLPConfig | EmbLogisticRegressionConfig,
+        config: EmbWithDomainsRandomForestConfig | EmbMLPConfig | EmbLogisticRegressionConfig,
     ):
         super().__init__(
             config=config,
         )
         # pylint: disable=R0801
-        with open("data/clustering__domain_dist_based_features.pkl", "rb") as file:
+        if hasattr(config, "foldseek_distances") and config.foldseek_distances:
+            domain_dist_path = "data/clustering__domain_dist_based_features_foldseek.pkl"
+        else:
+            domain_dist_path = "data/clustering__domain_dist_based_features.pkl"
+        with open(domain_dist_path, "rb") as file:
             (
                 self.feats_dom_dists,
                 self.all_ids_list_dom,
@@ -41,6 +51,8 @@
         self.allowed_feat_indices: list[int] = None  # type: ignore
         self.features_df_plm = self.features_df.copy()
         self.features_df = None
+        self.domain_feature_novelty_detector = None
+        self.plm_feature_novelty_detector = None
         # to experiment with the domain features subset
         if "domains_subset" in self.config.experiment_info.model_version:
             # to obtain the subset of domain features, run the following code:
             # python -m src.models.plm_domain_faster.get_domains_feature_importances
             with open("data/domains_subset.pkl", "rb") as file:
                 _, self.feat_indices_subset = pickle.load(file)
         else:
             self.feat_indices_subset = None
+        if "plm_subset" in self.config.experiment_info.model_version:
+            # to obtain the subset of PLM features, run the following code:
+            # python -m src.models.plm_domain_faster.get_plm_feature_importances
+            with open("data/plm_feats_subset.pkl", "rb") as file:
+                self.plm_feat_indices_subset = sorted(pickle.load(file))
+        else:
+            self.plm_feat_indices_subset = None
 
     def fit_core(self, train_df: pd.DataFrame, class_name: str = None):
         """
@@ -64,6 +83,19 @@ def fit_core(self, train_df: pd.DataFrame, class_name: str = None):
 
         dom_features_df["Emb_dom"] = dom_features_df["Emb"]
 
+        ninetieth_percentile = dom_features_df["Emb_dom"].apply(lambda x: np.percentile(1 - x, 90))
+        logger.info(f"Average 90th percentile of the TM-score: {ninetieth_percentile.mean()}")
+
+        # novelty detector to check for data drift
+        dom_feats_trn = np.stack(dom_features_df["Emb_dom"].values)
+
+        self.domain_feature_novelty_detector = IsolationForest(n_estimators=400).fit(dom_feats_trn)
+        logger.info(f"Novelty detector for domain features is trained. Proportion of outliers: {np.mean(self.domain_feature_novelty_detector.predict(dom_feats_trn) == -1):.2f}")
+
+        plm_feats = np.stack(self.features_df_plm["Emb"].values)
+        self.plm_feature_novelty_detector = IsolationForest(n_estimators=400).fit(plm_feats)
+        logger.info(f"Novelty detector for plm features is trained.
Proportion of outliers: {np.mean(self.plm_feature_novelty_detector.predict(plm_feats) == -1):.2f}") + self.features_df = self.features_df_plm.merge( dom_features_df[[self.config.id_col_name, "Emb_dom"]], on=self.config.id_col_name, @@ -72,7 +104,7 @@ def fit_core(self, train_df: pd.DataFrame, class_name: str = None): missing_dist_feats_bool_idx = self.features_df["Emb_dom"].isnull() self.features_df.loc[missing_dist_feats_bool_idx, "Emb_dom"] = pd.Series( [ - np.zeros(len(self.allowed_feat_indices)) + np.ones(len(self.allowed_feat_indices)) for _ in range(sum(missing_dist_feats_bool_idx)) ], index=self.features_df.loc[missing_dist_feats_bool_idx].index, @@ -91,4 +123,4 @@ def config_class(cls) -> Type[BaseConfig]: A getter of the model-specific config class :return: A dataclass for config storage """ - return EmbRandomForestConfig + return EmbWithDomainsRandomForestConfig diff --git a/terpeneminer/src/models/plm_domain_faster/get_domains_feature_importances.py b/terpeneminer/src/models/plm_domain_faster/get_domains_feature_importances.py index 5f7af83..ffd39bc 100644 --- a/terpeneminer/src/models/plm_domain_faster/get_domains_feature_importances.py +++ b/terpeneminer/src/models/plm_domain_faster/get_domains_feature_importances.py @@ -7,6 +7,9 @@ import pickle import pandas as pd # type: ignore +from sklearn.ensemble import RandomForestClassifier # type: ignore +from sklearn.multioutput import MultiOutputClassifier # type: ignore +from sklearn.calibration import CalibratedClassifierCV # type: ignore from terpeneminer.src.experiments_orchestration.experiment_selector import ( collect_single_experiment_arguments, @@ -58,14 +61,23 @@ def parse_args() -> argparse.Namespace: help="A flag to use all folds instead of individual fold checkpoints", action="store_true", ) + parser.add_argument("--model", type=str, default=None) + parser.add_argument("--model-version", type=str, default=None) return parser.parse_args() if __name__ == "__main__": config_root_path = get_config_root() - experiment_kwargs = collect_single_experiment_arguments(config_root_path) - experiment_info = ExperimentInfo(**experiment_kwargs) args = parse_args() + if args.model is None or args.model_version is None: + experiment_kwargs = collect_single_experiment_arguments(config_root_path) + else: + experiment_kwargs = { + "model_type": args.model, + "model_version": args.model_version, + } + experiment_info = ExperimentInfo(**experiment_kwargs) + with open(args.domain_features_path, "rb") as file: ( @@ -121,35 +133,38 @@ def parse_args() -> argparse.Namespace: ) from index_error with open(fold_class_latest_path / f"model_fold_{fold_i}.pkl", "rb") as file: model = pickle.load(file) - importances = model.classifier.feature_importances_ - number_of_domain_comparisons = len(model.allowed_feat_indices) - plm_embedding_size = len(importances) - number_of_domain_comparisons - feature_names = [f"tps_{i}" for i in range(plm_embedding_size)] + [ - idx_2_domain_id[feat_i] for feat_i in model.allowed_feat_indices - ] - forest_importances = pd.Series(importances, index=feature_names) - forest_importances_domains = pd.Series( - importances[plm_embedding_size:], - index=[idx_2_domain_id[feat_i] for feat_i in model.allowed_feat_indices], - ) - forest_importances_indices = pd.Series( - importances[plm_embedding_size:], - index=model.allowed_feat_indices, - ) - domains_subset = domains_subset.union( - set( - forest_importances_domains.sort_values(ascending=False) - .iloc[: args.top_most_important_domain_features_per_model] - .index + if 
isinstance(model.classifier, RandomForestClassifier):
+            classifiers_fold = [model.classifier]
+        elif isinstance(model.classifier, MultiOutputClassifier):
+            classifiers_fold = []
+            mo_estimators = model.classifier.estimators_
+            for mo_estimator in mo_estimators:
+                if isinstance(mo_estimator, RandomForestClassifier):
+                    classifiers_fold.append(mo_estimator)
+                elif isinstance(mo_estimator, CalibratedClassifierCV):
+                    for calibrated_classifier in mo_estimator.calibrated_classifiers_:
+                        if isinstance(calibrated_classifier.estimator, RandomForestClassifier):
+                            classifiers_fold.append(calibrated_classifier.estimator)
+        for classifier_ in classifiers_fold:
+            importances = classifier_.feature_importances_
+            number_of_domain_comparisons = len(model.allowed_feat_indices)
+            plm_embedding_size = len(importances) - number_of_domain_comparisons
+            feature_names = [f"tps_{i}" for i in range(plm_embedding_size)] + [
+                f"dom_{feat_i}" for feat_i in model.allowed_feat_indices
+            ]
+            forest_importances = pd.Series(importances, index=feature_names)
+            forest_importances_indices = pd.Series(
+                importances[plm_embedding_size:],
+                index=model.allowed_feat_indices,
             )
-        )
-        feat_indices_subset = feat_indices_subset.union(
-            set(
-                forest_importances_indices.sort_values(ascending=False)
-                .iloc[: args.top_most_important_domain_features_per_model]
-                .index
+            domains_subset = domains_subset.union({idx_2_domain_id[feat_i] for feat_i in model.allowed_feat_indices})
+            feat_indices_subset = feat_indices_subset.union(
+                set(
+                    forest_importances_indices.sort_values(ascending=False)
+                    .iloc[: args.top_most_important_domain_features_per_model]
+                    .index
+                )
             )
-        )
 
     terpene_synthases_df = pd.read_csv(args.tps_file_path)
     ids_rare_set = set(
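The unwrapping logic above is easy to get wrong, so here is a hedged, self-contained sketch of the same idea: recovering per-class random forests from a fitted `MultiOutputClassifier`, with or without calibration, and averaging their `feature_importances_` (synthetic data; attribute names as in scikit-learn >= 1.2, matching the diff):

```python
# A sketch of the extraction pattern in the hunk above, not the project's code.
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier

X, Y = np.random.rand(150, 20), np.random.randint(0, 2, (150, 2))
clf = MultiOutputClassifier(
    CalibratedClassifierCV(RandomForestClassifier(n_estimators=30), cv=5)
).fit(X, Y)

forests = []
for est in clf.estimators_:  # one fitted estimator per output class
    if isinstance(est, RandomForestClassifier):
        forests.append(est)
    elif isinstance(est, CalibratedClassifierCV):
        # each CV fold keeps its own fitted copy of the base forest
        forests.extend(
            cc.estimator for cc in est.calibrated_classifiers_
            if isinstance(cc.estimator, RandomForestClassifier)
        )

mean_importances = np.mean([f.feature_importances_ for f in forests], axis=0)
print(mean_importances.shape)  # (20,)
```
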
diff --git a/terpeneminer/src/models/plm_domain_faster/get_plm_feature_importances.py b/terpeneminer/src/models/plm_domain_faster/get_plm_feature_importances.py
new file mode 100644
index 0000000..89e48a2
--- /dev/null
+++ b/terpeneminer/src/models/plm_domain_faster/get_plm_feature_importances.py
@@ -0,0 +1,172 @@
+# pylint: disable=R0801
+"""The helper script to obtain feature importances for the PLM embedding features and select the most important ones.
+
+Usage: python -m src.models.plm_domain_faster.get_plm_feature_importances
+"""
+
+import argparse
+import pickle
+
+import pandas as pd  # type: ignore
+from sklearn.ensemble import RandomForestClassifier  # type: ignore
+from sklearn.multioutput import MultiOutputClassifier  # type: ignore
+from sklearn.calibration import CalibratedClassifierCV  # type: ignore
+
+from terpeneminer.src.experiments_orchestration.experiment_selector import (
+    collect_single_experiment_arguments,
+)
+from terpeneminer.src.utils.project_info import (
+    ExperimentInfo,
+    get_config_root,
+    get_output_root,
+)
+
+
+def parse_args() -> argparse.Namespace:
+    """
+    This function parses arguments
+    :return: current argparse.Namespace
+    """
+    parser = argparse.ArgumentParser(
+        description="A script to select the most important PLM embedding features from trained classifier checkpoints"
+    )
+    parser.add_argument(
+        "--top-most-important-plm-features-per-model",
+        help="A number of top features to take from each model (for 5-fold CV, there are 5 models)",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
+        "--output-path",
+        help="A file to save the selected PLM feature indices",
+        type=str,
+        default="data/plm_feats_subset.pkl",
+    )
+    parser.add_argument(
+        "--domain-features-path",
+        help="A file with precomputed domain features",
+        type=str,
+        default="data/clustering__domain_dist_based_features.pkl",
+    )
+    parser.add_argument(
+        "--n-folds", help="A number of folds used in CV", type=int, default=5
+    )
+    parser.add_argument(
+        "--tps-file-path",
+        help="A path to the TPS file",
+        type=str,
+        default="data/TPS-Nov19_2023_verified_all_reactions_with_neg_with_folds.csv",
+    )
+    parser.add_argument(
+        "--use-all-folds",
+        help="A flag to use all folds instead of individual fold checkpoints",
+        action="store_true",
+    )
+    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument("--model-version", type=str, default=None)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    config_root_path = get_config_root()
+    args = parse_args()
+    if args.model is None or args.model_version is None:
+        experiment_kwargs = collect_single_experiment_arguments(config_root_path)
+    else:
+        experiment_kwargs = {
+            "model_type": args.model,
+            "model_version": args.model_version,
+        }
+    experiment_info = ExperimentInfo(**experiment_kwargs)
+
+    with open(args.domain_features_path, "rb") as file:
+        (
+            feats_dom_dists,
+            all_ids_list_dom,
+            uniid_2_column_ids,
+            domain_module_id_2_dist_matrix_index,
+        ) = pickle.load(file)
+
+    idx_2_domain_id = {}
+    for domain_id, indices in domain_module_id_2_dist_matrix_index.items():
+        for i in indices:
+            idx_2_domain_id[i] = domain_id
+
+    n_folds = args.n_folds
+    experiment_output_folder_root = (
+        get_output_root() / experiment_info.model_type / experiment_info.model_version
+    )
+    assert (
+        experiment_output_folder_root.exists()
+    ), f"Output folder {experiment_output_folder_root} for {experiment_info} does not exist"
+    model_version_fold_folders = {
+        x.stem for x in experiment_output_folder_root.glob("*")
+    }
+    if not args.use_all_folds and (
+        len(model_version_fold_folders.intersection(set(map(str, range(n_folds)))))
+        == n_folds
+    ):
+        fold_2_root_dir = {
+            fold_i: experiment_output_folder_root / f"{fold_i}"
+            for fold_i in range(n_folds)
+        }
+    elif "all_folds" in model_version_fold_folders:
+        fold_2_root_dir = {
+            fold_i: experiment_output_folder_root / "all_folds"
+            for fold_i in range(n_folds)
+        }
+    else:
+        raise NotImplementedError(
+            f"Not all fold outputs found. 
Please run corresponding experiments ({experiment_info}) before evaluation"
+        )
+
+    feat_indices_subset: set = set()
+    for fold_i, fold_root_dir in fold_2_root_dir.items():
+        fold_class_path = fold_root_dir / "all_classes"
+        assert fold_class_path.exists(), "Only all_classes are supported"
+        try:
+            fold_class_latest_path = sorted(fold_class_path.glob("*"))[-1]
+        except IndexError as index_error:
+            raise NotImplementedError(
+                f"Please run corresponding experiments ({experiment_info}) before evaluation"
+            ) from index_error
+        with open(fold_class_latest_path / f"model_fold_{fold_i}.pkl", "rb") as file:
+            model = pickle.load(file)
+        if isinstance(model.classifier, RandomForestClassifier):
+            classifiers_fold = [model.classifier]
+        elif isinstance(model.classifier, MultiOutputClassifier):
+            classifiers_fold = []
+            mo_estimators = model.classifier.estimators_
+            for mo_estimator in mo_estimators:
+                if isinstance(mo_estimator, RandomForestClassifier):
+                    classifiers_fold.append(mo_estimator)
+                elif isinstance(mo_estimator, CalibratedClassifierCV):
+                    for calibrated_classifier in mo_estimator.calibrated_classifiers_:
+                        if isinstance(calibrated_classifier.estimator, RandomForestClassifier):
+                            classifiers_fold.append(calibrated_classifier.estimator)
+        for classifier_ in classifiers_fold:
+            importances = classifier_.feature_importances_
+            try:
+                number_of_domain_comparisons = len(model.allowed_feat_indices)
+            except AttributeError:
+                number_of_domain_comparisons = 0
+            plm_embedding_size = len(importances) - number_of_domain_comparisons
+            feature_names = [f"tps_{i}" for i in range(plm_embedding_size)]
+            if number_of_domain_comparisons:
+                feature_names += [
+                    idx_2_domain_id[feat_i] for feat_i in model.allowed_feat_indices
+                ]
+            forest_importances = pd.Series(importances, index=feature_names)
+            forest_importances_indices = pd.Series(
+                importances[:plm_embedding_size],
+                index=list(range(plm_embedding_size)),
+            )
+            feat_indices_subset = feat_indices_subset.union(
+                set(
+                    forest_importances_indices.sort_values(ascending=False)
+                    .iloc[: args.top_most_important_plm_features_per_model]
+                    .index
+                )
+            )
+    print('feat_indices_subset size: ', len(feat_indices_subset))
+    with open(args.output_path, "wb") as file_write:
+        pickle.dump(feat_indices_subset, file_write)
diff --git a/terpeneminer/src/models/plm_randomforest.py b/terpeneminer/src/models/plm_randomforest.py
index 32c6009..3eeb069 100644
--- a/terpeneminer/src/models/plm_randomforest.py
+++ b/terpeneminer/src/models/plm_randomforest.py
@@ -1,6 +1,6 @@
 """A class for Random Forest predictive models on top of protein language model (PLM) embeddings"""
 from typing import Type
-
+import pickle
 from sklearn.ensemble import RandomForestClassifier  # type: ignore
 
 from terpeneminer.src.models.config_classes import (
@@ -25,6 +25,17 @@ def __init__(
             config=config,
         )
         self.classifier_class = RandomForestClassifier
+        if "plm_subset" in self.config.experiment_info.model_version:
+            # to obtain the subset of PLM features, run the following code:
+            # python -m src.models.plm_domain_faster.get_plm_feature_importances
+            with open("data/plm_feats_subset.pkl", "rb") as file:
+                self.plm_feat_indices_subset = sorted(pickle.load(file))
+        else:
+            self.plm_feat_indices_subset = None
+        if self.plm_feat_indices_subset is not None:
+            self.features_df["Emb"] = self.features_df["Emb"].apply(
+                lambda x: x[self.plm_feat_indices_subset]
+            )
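Downstream, the saved pickle is consumed by slicing each embedding to the selected dimensions, exactly as `plm_randomforest.py` above does. A minimal sketch (synthetic embedding size, inlined stand-in for the pickled index set):

```python
# Hedged illustration of pruning PLM embeddings to a saved feature subset.
import numpy as np

full_embeddings = np.random.rand(4, 1280)   # stand-in for per-protein PLM vectors
# in the patch this set comes from data/plm_feats_subset.pkl; inlined here
plm_feat_indices_subset = sorted({3, 17, 256, 1024})

# keep only the selected embedding dimensions, in a stable order
pruned = full_embeddings[:, plm_feat_indices_subset]
print(pruned.shape)  # (4, 4)
```

diff --git a/terpeneminer/src/screening/gather_classifier_checkpoints.py 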
b/terpeneminer/src/screening/gather_classifier_checkpoints.py index 360328b..d7cb6f1 100644 --- a/terpeneminer/src/screening/gather_classifier_checkpoints.py +++ b/terpeneminer/src/screening/gather_classifier_checkpoints.py @@ -2,6 +2,7 @@ """A helper script to gather classifier checkpoints from an output directory""" import argparse import pickle +from collections import defaultdict # pylint: disable=unused-import import scipy.stats # type: ignore @@ -37,14 +38,22 @@ def parse_args() -> argparse.Namespace: help="A flag to use all folds instead of individual fold checkpoints", action="store_true", ) + parser.add_argument("--model", type=str, default=None) + parser.add_argument("--model-version", type=str, default=None) return parser.parse_args() if __name__ == "__main__": config_root_path = get_config_root() - experiment_kwargs = collect_single_experiment_arguments(config_root_path) - experiment_info = ExperimentInfo(**experiment_kwargs) args = parse_args() + if args.model is None or args.model_version is None: + experiment_kwargs = collect_single_experiment_arguments(config_root_path) + else: + experiment_kwargs = { + "model_type": args.model, + "model_version": args.model_version, + } + experiment_info = ExperimentInfo(**experiment_kwargs) n_folds = args.n_folds experiment_output_folder_root = ( get_output_root() / experiment_info.model_type / experiment_info.model_version @@ -92,6 +101,9 @@ def parse_args() -> argparse.Namespace: with open("data/clustering__domain_dist_based_features.pkl", "rb") as file: domain_module_id_2_dist_matrix_index = pickle.load(file)[-1] + with open('data/domain_2_start_end_cols.pkl', 'rb') as file: + domain_2_start_end_cols = pickle.load(file) + with open("data/domains_subset.pkl", "rb") as file: feat_indices_subset = pickle.load(file)[-1] domain_module_id_2_dist_matrix_index_subset = { @@ -110,6 +122,26 @@ def parse_args() -> argparse.Namespace: feat_idx_2_module_id[feat_i] for feat_i in model.allowed_feat_indices ] model.classifier.order_of_domain_modules = order_of_domain_modules + + assert model.allowed_feat_indices == sorted(model.allowed_feat_indices) + feat_idx_2_number_of_skipped_features = {-1: 0} + prev_present_index = -1 + for feat_idx in model.allowed_feat_indices: + feat_idx_2_number_of_skipped_features[feat_idx] = feat_idx_2_number_of_skipped_features[prev_present_index] + feat_idx - prev_present_index - 1 + prev_present_index = feat_idx + + domain_type_2_order_of_domain_modules = defaultdict(list) + for domain_type, (start_global_index, end_global_index) in domain_2_start_end_cols.items(): + for domain_type_feat_idx in range(start_global_index, end_global_index): + if domain_type_feat_idx in model.allowed_feat_indices: + domain_type_2_order_of_domain_modules[domain_type].append((feat_idx_2_module_id[domain_type_feat_idx], + domain_type_feat_idx - feat_idx_2_number_of_skipped_features[domain_type_feat_idx])) + model.classifier.domain_type_2_order_of_domain_modules = domain_type_2_order_of_domain_modules + + if hasattr(model, "domain_feature_novelty_detector") and getattr(model, "domain_feature_novelty_detector") is not None: + model.classifier.domain_feature_novelty_detector = model.domain_feature_novelty_detector + if hasattr(model, "plm_feat_indices_subset") and getattr(model, "plm_feat_indices_subset") is not None: + model.classifier.plm_feat_indices_subset = model.plm_feat_indices_subset classifiers.append(model.classifier) with open(args.output_path, "wb") as file_writer: diff --git a/terpeneminer/src/screening/gather_detections_to_csv.py 
b/terpeneminer/src/screening/gather_detections_to_csv.py
index 3fdea1b..76441cd 100644
--- a/terpeneminer/src/screening/gather_detections_to_csv.py
+++ b/terpeneminer/src/screening/gather_detections_to_csv.py
@@ -13,6 +13,16 @@
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
 
+SUBSTR_2_NAME = {"CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "Farnesyl pyrophosphate",
+                 "CC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "Geranyl pyrophosphate",
+                 "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "Geranylgeranyl pyrophosphate",
+                 "CC(C)=CCCC(C)=CCCC(C)=CCCC=C(C)CCC=C(C)CCC1OC1(C)C": "(S)-2,3-epoxysqualene",
+                 "CC1(C)CCCC2(C)C1CCC(=C)C2CCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "copalyl diphosphate",
+                 "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "Geranylfarnesyl pyrophosphate",
+                 "CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "2x Farnesyl pyrophosphate",
+                 "CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O.CC(C)=CCCC(C)=CCCC(C)=CCCC(C)=CCOP([O-])(=O)OP([O-])([O-])=O": "2x Geranylgeranyl pyrophosphate",
+                 }
+
 
 def parse_args() -> argparse.Namespace:
     """
@@ -63,7 +73,16 @@ def parse_args() -> argparse.Namespace:
         processed_files.append(str(detected_file))
 
     predicted_class_2_vals.update({"ID": ids})
-    df_detections = pd.DataFrame(predicted_class_2_vals)
+    # append the human-readable substrate name in parentheses where known;
+    # columns without a mapping (e.g. "ID", "isTPS") keep their exact names
+    df_detections = pd.DataFrame(
+        {
+            (
+                f"{class_name} ({SUBSTR_2_NAME[class_name]})"
+                if class_name in SUBSTR_2_NAME
+                else class_name
+            ): values
+            for class_name, values in predicted_class_2_vals.items()
+        }
+    )
     if len(df_detections) and "isTPS" in df_detections.columns:
         df_detections = df_detections.sort_values("isTPS", ascending=False)
     df_detections.to_csv(args.output_path, index=False)
diff --git a/terpeneminer/src/screening/tps_predict_fasta.py b/terpeneminer/src/screening/tps_predict_fasta.py
index 429328d..b3acd2b 100644
--- a/terpeneminer/src/screening/tps_predict_fasta.py
+++ b/terpeneminer/src/screening/tps_predict_fasta.py
@@ -19,6 +19,10 @@
     compute_embeddings,
     get_model_and_tokenizer,
 )
+from terpeneminer.src.embeddings_extraction.ankh_transformer_utils import (
+    compute_embeddings as ankh_compute_embeddings,
+    get_model_and_tokenizer as ankh_get_model_and_tokenizer
+)
 
 
 def _extract_id_from_entry(entry: tuple) -> str:
@@ -99,18 +103,28 @@ def main(arguments: argparse.Namespace):
     in the specified output directory.
 
     """
-    model, batch_converter, alphabet = get_model_and_tokenizer(
-        arguments.model, return_alphabet=True
-    )
-
-    compute_embeddings_partial = partial(
-        compute_embeddings,
-        bert_model=model,
-        converter=batch_converter,
-        padding_idx=alphabet.padding_idx,
-        model_repr_layer=33,
-        max_len=arguments.max_len,
-    )
+    if "esm" in arguments.model:
+        model, batch_converter, alphabet = get_model_and_tokenizer(
+            arguments.model, return_alphabet=True
+        )
+
+        compute_embeddings_partial = partial(
+            compute_embeddings,
+            bert_model=model,
+            converter=batch_converter,
+            padding_idx=alphabet.padding_idx,
+            model_repr_layer=33,
+            max_len=arguments.max_len,
+        )
+    elif "ankh" in arguments.model:
+        model, tokenizer = ankh_get_model_and_tokenizer(arguments.model)
+        compute_embeddings_partial = partial(
+            ankh_compute_embeddings, bert_model=model, tokenizer=tokenizer
+        )
+    else:
+        raise NotImplementedError(
+            f"Model {arguments.model} is not supported. 
Currently only esm, ankh model families are supported" + ) uniprot_generator = esm.data.read_fasta(arguments.fasta_path) @@ -133,25 +147,44 @@ def process_embeddings( predictions = [] n_samples = len(enzyme_encodings_np_batch) for classifier_i, classifier in enumerate(classifiers): - y_pred_proba = classifier.predict_proba(enzyme_encodings_np_batch) + if hasattr(classifier, "plm_feat_indices_subset") and classifier.plm_feat_indices_subset is not None: + emb_plm = np.apply_along_axis(lambda i: i[classifier.plm_feat_indices_subset], 1, + enzyme_encodings_np_batch) + else: + emb_plm = enzyme_encodings_np_batch + y_pred_proba = classifier.predict_proba(emb_plm) for sample_i in range(n_samples): predictions_raw = {} for class_i, class_name in enumerate(classifier.classes_): if class_name != "Unknown": predictions_raw[class_name] = y_pred_proba[class_i][sample_i, 1] + if sample_i == 0: + print('predictions_raw: ', predictions_raw) if classifier_i == 0: predictions.append( { - class_name.replace("-", "_"): value / len(classifiers) + class_name: [value] for class_name, value in predictions_raw.items() } ) else: for class_name, value in predictions_raw.items(): - predictions[sample_i][ - class_name.replace("-", "_") - ] += value / len(classifiers) - return predictions + predictions[sample_i][class_name].append(value) + # average the predictions + predictions_avg = [] + for prediction in predictions: + predictions_avg.append( + { + class_name: np.mean(values) + for class_name, values in prediction.items() + } + ) + print({ + class_name: len(values) + for class_name, values in prediction.items() + }) + print('predictions_avg: ', predictions_avg) + return predictions_avg next_batch = [] next_batch_ids = [] diff --git a/terpeneminer/src/structure_processing/comparing_to_known_domains.py b/terpeneminer/src/structure_processing/comparing_to_known_domains.py index 3d364d9..3486103 100644 --- a/terpeneminer/src/structure_processing/comparing_to_known_domains.py +++ b/terpeneminer/src/structure_processing/comparing_to_known_domains.py @@ -65,6 +65,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--number-of-workers", type=int, default=16) parser.add_argument("--output-path", type=str, default="_temp/filename_2_regions_vs_known_reg_dists.pkl") parser.add_argument("--pdb-filepath", type=str, default="") + parser.add_argument("--same-domain-only", action="store_true") return parser.parse_args() @@ -149,7 +150,8 @@ def compute_distances_to_known_regions( # preparing the data for parallel processing n_workers = args.number_of_workers temp_struct_name = input_directory / Path(args.pdb_filepath).name - if args.pdb_filepath != str(temp_struct_name): + temp_filepath_name_to_delete = not temp_struct_name.exists() + if args.pdb_filepath != str(temp_struct_name) and temp_filepath_name_to_delete: copyfile(args.pdb_filepath, temp_struct_name) # loading detected domains @@ -164,23 +166,21 @@ def compute_distances_to_known_regions( for filename, regions in filename_2_detected_regions_completed_confident.items(): region_2_known_reg_dists = defaultdict(list) for region in regions: - domain_type = region.domain - if domain_type not in type_2_regions: - type_2_regions[domain_type] = [el for el in regions_all if el[1].domain == domain_type] - regions_all_current_type = type_2_regions[domain_type] + if args.same_domain_only: + domain_type = region.domain + if domain_type not in type_2_regions: + type_2_regions[domain_type] = [el for el in regions_all if el[1].domain == domain_type] + regions_all_current_type = 
type_2_regions[domain_type] + else: + regions_all_current_type = regions_all regions_segment_len = len(regions_all_current_type) // n_workers + 1 region_segments = [] start_i = 0 - print('len(regions_all): ', len(regions_all)) while start_i < len(regions_all): region_segments.append(regions_all_current_type[start_i:start_i + regions_segment_len]) - print() start_i += regions_segment_len - - print('region_segments cout: ', sum([len(x) for x in region_segments])) - computations_id = str(uuid4()) partial_dist_compute = partial( compute_distances_to_known_regions, @@ -189,7 +189,6 @@ def compute_distances_to_known_regions( filename_2_all_residues=file_2_all_residues, computation_id=computations_id ) - print('list(range(len(region_segments))): ', list(range(len(region_segments)))) with Pool(n_workers - 2) as p: list_of_distances_list = p.map(partial_dist_compute, list(range(len(region_segments)))) for results_path in Path('.').glob(f'*{computations_id}.pkl'): @@ -202,5 +201,5 @@ def compute_distances_to_known_regions( with open(args.output_path, "wb") as file: pickle.dump(filename_2_regions_vs_known_reg_dists, file) - - os.remove(temp_struct_name) + if temp_filepath_name_to_delete: + os.remove(temp_struct_name) diff --git a/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py b/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py index 88ac61d..0a550a1 100644 --- a/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py +++ b/terpeneminer/src/structure_processing/comparing_to_known_domains_foldseek.py @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: tsv_path = working_dir / f'aln_all_domains_vs_all_{uuid4()}.tsv' tmp_path = working_dir / f'tmp_all_{uuid4()}' foldseek_comparison_output = subprocess.check_output( - f'foldseek easy-search {args.detected_domain_structures_root} {args.known_domain_structures_root} {tsv_path} {tmp_path} --max-seqs 3000 -e 0.1 --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,alntmscore'.split()) + f'foldseek easy-search {args.detected_domain_structures_root} {args.known_domain_structures_root} {tsv_path} {tmp_path} --max-seqs 5000 -e 1 -s 10 --exhaustive-search --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,alntmscore'.split()) df_foldseek = pd.read_csv(tsv_path, sep='\t', header=None, names=['query', 'target', 'fident', 'alnlen', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits', 'alntmscore']) diff --git a/terpeneminer/src/structure_processing/domain_detections.py b/terpeneminer/src/structure_processing/domain_detections.py index 070c488..1661cab 100644 --- a/terpeneminer/src/structure_processing/domain_detections.py +++ b/terpeneminer/src/structure_processing/domain_detections.py @@ -11,6 +11,7 @@ from datetime import datetime from pymol import cmd # type: ignore import pandas as pd # type: ignore +import numpy as np # type: ignore from Bio import PDB # type: ignore from tqdm.auto import tqdm # type: ignore from terpeneminer.src.structure_processing.structural_algorithms import ( @@ -131,9 +132,13 @@ def detect_domains_roughly( ) regions_of_possible_domain = [] + + for uniprot_id, current_detections in file_2_tmscore_residues_domain.items(): for i, (tm_score, res_mapping) in enumerate(current_detections): - if tm_score >= domain_2_threshold[domain_this][0]: + logger.info(f'tm_score: {tm_score:.2f}') + logger.info(f'len of res_mapping: {len(res_mapping)}') + if tm_score 
>= domain_2_threshold[domain_this][0] and len(res_mapping) >= domain_2_threshold[domain_this][1]: regions_of_possible_domain.append( ( uniprot_id, @@ -160,7 +165,7 @@ def detect_domains_roughly( domain_2_possible_regions[domain_this] = regions_of_possible_domain - if domain_this == "alpha": + if "alpha" in domain_this: file_2_mapped_regions = get_mapped_regions_per_file( {"alpha": file_2_tmscore_residues_domain}, domain_2_threshold ) @@ -211,7 +216,7 @@ def detect_domains_roughly( domain_2_possible_regions[domain_this] += regions_of_possible_2nd_alphas file_2_known_regions: dict = defaultdict(list) - for domain_name_to_include in ["alpha", "epsilon", "delta", "beta", "gamma"]: + for domain_name_to_include in ["alpha", "epsilon", "delta", "beta", "gamma", "alphaWeird"]: potential_regions = domain_2_possible_regions[domain_name_to_include] # filter clashes with already loaded domains regions_of_possible_domain_to_include = [ @@ -274,7 +279,7 @@ def can_there_be_unassigned_domain( file_name: str, filename_2_remaining_residues_mapping: dict[str, set[str]], filename_2_known_regions_mapping: dict[str, list[MappedRegion]], - min_len: int = 90, + min_continuous_len: int = 15, max_allowed_gap: int = 3, ) -> bool: """ @@ -283,7 +288,7 @@ def can_there_be_unassigned_domain( :param file_name: The name of the file to check for unassigned domains :param filename_2_remaining_residues_mapping: A dictionary mapping filenames to sets of remaining residues not yet assigned to any domain :param filename_2_known_regions_mapping: A dictionary mapping filenames to lists of known MappedRegion objects - :param min_len: The minimum length of residues required to consider the presence of an unassigned domain, defaults to 90 + :param min_continuous_len: The minimum length of residues required to consider the presence of an unassigned domain, defaults to 15 :param max_allowed_gap: The maximum gap allowed between residues in a continuous segment, defaults to 3 :return: True if there could be an unassigned domain in the file, otherwise False @@ -292,12 +297,12 @@ def can_there_be_unassigned_domain( return False region_types = {reg.domain for reg in filename_2_known_regions_mapping[file_name]} if "alpha" not in region_types: - return len(filename_2_remaining_residues_mapping[file_name]) > min_len + return len(filename_2_remaining_residues_mapping[file_name]) > min_continuous_len return ( len( find_continuous_segments_longer_than( filename_2_remaining_residues_mapping[file_name], - min_secondary_struct_len=min_len, + min_secondary_struct_len=min_continuous_len, max_allowed_gap=max_allowed_gap, ) ) @@ -330,6 +335,30 @@ def get_confident_af_residues( return confident_residues +def get_all_confidence_values( + uniprot_id: str +) -> list[int]: + """ + Retrieves a set of residues from an AlphaFold PDB file that have a confidence score (B-factor) above the specified threshold. 
+ + :param uniprot_id: The UniProt ID of the protein for which the PDB file is to be parsed + :param confidence_threshold: The minimum B-factor required for a residue to be considered confident, defaults to 70 + + :return: A set of residue numbers that have a confidence score above the specified threshold + """ + parser = PDB.PDBParser() + structure = parser.get_structure(uniprot_id, f"{uniprot_id}.pdb") + + values = [] + for model in structure: + for chain in model: + for residue in chain: + for atom in residue: + values.append(atom.get_bfactor()) + break + return values + + if __name__ == "__main__": args = parse_args() # reading the needed proteins @@ -377,40 +406,31 @@ def get_confident_af_residues( n_jobs=args.n_jobs, ) - # Assigning missed secondary structure parts to the closest domains - filename_2_known_regions_completed = get_mapped_regions_with_surroundings_parallel( - list(filename_2_known_regions.keys()), - file_2_all_residues, - filename_2_known_regions, - n_jobs=args.n_jobs, - ) - - # Get unsegmented parts and iterate over all domain types for best hit + # Get unsegmented parts file_2_remaining_residues = get_remaining_residues( - filename_2_known_regions_completed, file_2_all_residues + filename_2_known_regions, file_2_all_residues ) + file_2_remaining_residues_unassigned = file_2_remaining_residues.copy() + for filename in filename_2_known_regions.keys(): + segments = find_continuous_segments_longer_than( + file_2_remaining_residues[filename], + min_secondary_struct_len=120, + max_allowed_gap=25, + ) + residues = set(map(str, sum(segments, []))) + if residues: + file_2_remaining_residues_unassigned[filename] = residues - pdb_files_with_poteintial_unsegmented_domains = [ - filename - for filename in pdb_files - if can_there_be_unassigned_domain( - filename.stem, - file_2_remaining_residues, - filename_2_known_regions_completed, - min_len=70, - max_allowed_gap=5, - ) - ] - + # attempt to detect additional domains in the remaining parts domain_2_file_2_tmscore_residues = {} for domain_type, ( tmscore_threshold, mapping_size_threshold, ) in DOMAIN_2_THRESHOLD.items(): domain_2_file_2_tmscore_residues[domain_type] = get_alignments( - pdb_files_with_poteintial_unsegmented_domains, + pdb_files, domain_type, - file_2_remaining_residues, + file_2_remaining_residues_unassigned, tmscore_threshold, mapping_size_threshold, n_jobs=args.n_jobs, @@ -444,11 +464,25 @@ def get_confident_af_residues( ): filename_2_known_regions[uni_id].append(new_region) + + + # Assigning missed secondary structure parts to the closest domains + filename_2_known_regions_completed = get_mapped_regions_with_surroundings_parallel( + list(filename_2_known_regions.keys()), + file_2_all_residues, + filename_2_known_regions, + n_jobs=args.n_jobs, + ) + # Getting confident residues filename_2_known_regions_completed_confident = {} for filename, regions in tqdm(filename_2_known_regions_completed.items()): if args.is_bfactor_confidence: conf_residues = get_confident_af_residues(filename) + if len(conf_residues) < 0.6 * len(file_2_all_residues[filename]): + logger.warning("Too few confident residues, leaving top-80% most confident residues") + all_confidence_values = get_all_confidence_values(filename) + conf_residues = get_confident_af_residues(filename, np.percentile(all_confidence_values, 20)) new_regions = [] for mapped_region_init in regions: new_residues_mapping = { @@ -504,12 +538,13 @@ def get_confident_af_residues( mapped_residues = list(set(region.residues_mapping.keys())) cmd.delete(filename) 
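+            # drop any stale PyMOL object with this name and reload the structure
+            # from disk, so the residue selection below is built from a clean state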
cmd.load(f"{filename}.pdb") + print(f"{region.module_id}", + f"{filename} & resi {compress_selection_list(mapped_residues)}") cmd.select( f"{region.module_id}", f"{filename} & resi {compress_selection_list(mapped_residues)}", ) cmd.save(f"{PATH}/{region.module_id}.pdb", f"{region.module_id}") - print('saving to : ', f"{domains_output_path}/{region.module_id}.pdb") cmd.save(f"{domains_output_path}/{region.module_id}.pdb", f"{region.module_id}") cmd.delete(filename) diff --git a/terpeneminer/src/structure_processing/predict_domain_types.py b/terpeneminer/src/structure_processing/predict_domain_types.py index 584cce2..2abbf2b 100644 --- a/terpeneminer/src/structure_processing/predict_domain_types.py +++ b/terpeneminer/src/structure_processing/predict_domain_types.py @@ -46,7 +46,7 @@ def parse_args() -> argparse.Namespace: with open(args.tps_classifiers_path, "rb") as file: tps_classifiers = pickle.load(file) with open(args.domain_classifiers_path, "rb") as file: - novel_domain_detectors, domain_type_classifiers = pickle.load(file) + domain_type_classifiers = pickle.load(file) with open(args.path_to_domain_comparisons, "rb") as file: comparison_results = pickle.load(file) comparison_results = comparison_results[args.id] @@ -71,14 +71,11 @@ def parse_args() -> argparse.Namespace: for class_name, pred_val in zip(classifier.classes_, domain_type_pred[0]): domain_type_2_pred_values[class_name].append(pred_val) domain_type_2_pred = {dom_type: np.mean(vals) for dom_type, vals in domain_type_2_pred_values.items()} - # max_pred = -float('inf') - # gen_type_2_pred = {} - # for domain_type, type_preds in domain_type_2_pred.items(): - # if - # gen_type_2_pred[domain_type_gen] = novel_domain_detectors[domain_type].predict_proba(X_np)[0][1] - # for type_preds in domain_type_2_pred.values(): - # max_pred = max(max_pred, np.max(type_preds)) - # domain_type_2_pred.update({"novel": 1 - max_pred}) + max_pred = -float('inf') + gen_type_2_pred = {} + for type_preds in domain_type_2_pred.values(): + max_pred = max(max_pred, np.max(type_preds)) + domain_type_2_pred.update({"novel": 1 - max_pred}) domain_id_2_predictions[new_protein_domain_id] = domain_type_2_pred with open(args.output_path, "wb") as file: diff --git a/terpeneminer/src/structure_processing/structural_algorithms.py b/terpeneminer/src/structure_processing/structural_algorithms.py index ff24f68..e458198 100644 --- a/terpeneminer/src/structure_processing/structural_algorithms.py +++ b/terpeneminer/src/structure_processing/structural_algorithms.py @@ -28,13 +28,14 @@ handler.setFormatter(formatter) logger.addHandler(handler) -SUPPORTED_DOMAINS = {"alpha", "beta", "gamma", "delta", "epsilon"} +SUPPORTED_DOMAINS = {"alpha", "beta", "gamma", "delta", "epsilon", "alphaWeird"} DOMAIN_2_THRESHOLD = { - "beta": (0.6, 50), - "delta": (0.6, 50), - "epsilon": (0.6, 50), - "gamma": (0.55, 50), + "beta": (0.5, 50), + "delta": (0.5, 50), + "epsilon": (0.5, 50), + "gamma": (0.5, 50), "alpha": (0.35, 130), + "alphaWeird": (0.5, 100), } @@ -81,6 +82,7 @@ def prepare_domain(pymol_cmd, domain_name: str) -> tuple: "gamma": "3p5r", "delta": "P48449", "epsilon": "P48449", + "alphaWeird": "Q7Z859", } ) assert domain_name in domain_2_standard, f"Domain {domain_name} is not supported" @@ -99,6 +101,7 @@ def prepare_domain(pymol_cmd, domain_name: str) -> tuple: "gamma": " & resi 138-151+157-171+185-222+233-248+258-275+281-304+313-339 & chain A & ss H+S", "delta": " & resi 
73-87+385-399+401-403+405-421+454-470+480-493+531-547+553-570+585-599+610-622+633-638+649-662+667-680+707-722+727-729 & chain A & ss H+S", "epsilon": " & resi 103-115+123-134+151-164+171-183+191-200+213-217+226-228+231-246+254-263+268-270+273-277+291-306+309-330+337-351+356-371+376-378+510-515 & chain A & ss H+S", + "alphaWeird": " & resi 6-23+38-58+66-68+81-96+118-136+152-177+183-208+229-239+241-246 & chain A & ss H+S" }[domain_name] domain_name_new = f"{domain_name}_domain_{uuid4()}" pymol_cmd.select(domain_name_new, f"{required_file} {selection_condition}") @@ -891,6 +894,7 @@ def get_mapped_regions_with_surroundings( already_mapped_residues ) + if not exists_in_pymol(cmd, filename): if not os.path.exists(f"{filename}.pdb"): raise FileNotFoundError(f"{filename}.pdb while being in {os.getcwd()}") @@ -955,21 +959,23 @@ def get_mapped_regions_with_surroundings( if min_dist < helix_sheet_dist_threshold: if len(all_dists_with_regions) >= 2: # leave unassigned if it is similarly close to two different regions - second_closest_dist, second_closest_region_i = min( - [ + regions_apart_from_the_closest = [ (dist, region) for (dist, region) in all_dists_with_regions if dist > min_dist - ], - key=lambda x: x[0], - ) - if ( - second_closest_region_i == closest_region_i - or min_dist < 0.9 * second_closest_dist - ): - mapped_region_2_added_residues[closest_region_i].extend( - residue_segment_remaining + ] + if regions_apart_from_the_closest: + second_closest_dist, second_closest_region_i = min( + regions_apart_from_the_closest, + key=lambda x: x[0], ) + if ( + second_closest_region_i == closest_region_i + or min_dist < 0.9 * second_closest_dist + ): + mapped_region_2_added_residues[closest_region_i].extend( + residue_segment_remaining + ) else: mapped_region_2_added_residues[closest_region_i].extend( residue_segment_remaining diff --git a/terpeneminer/src/structure_processing/train_domain_type_classifiers.py b/terpeneminer/src/structure_processing/train_domain_type_classifiers.py index 15bec95..a9a2818 100644 --- a/terpeneminer/src/structure_processing/train_domain_type_classifiers.py +++ b/terpeneminer/src/structure_processing/train_domain_type_classifiers.py @@ -32,20 +32,23 @@ dom_subset, feat_indices_subset = pickle.load(file) with open('data/domain_module_id_2_domain_type.pkl', 'rb') as file: domain_module_id_2_domain_type = pickle.load(file) - with open("data/precomputed_tmscores.pkl", "rb") as file: + with open("data/precomputed_tmscores_foldseek.pkl", "rb") as file: regions_ids_2_tmscore = pickle.load(file) - domain_type_classifiers = [] - novel_domain_detectors = [] - + fold_2_domain_type_predictions = [] + fold_2_predictions = [] + y_is_novel_test_all, y_pred_novel_all = [], [] for FOLD in range(5): - logger.info('Processing fold: %d', FOLD) + hits_count = 0 + miss_count = 0 classifier = fold_classifiers[FOLD] - new_fold_domains = [module_id for module_id in domain_module_id_2_dist_matrix_index.keys() if module_id not in classifier.order_of_domain_modules] + new_fold_domains = [module_id for module_id in domain_module_id_2_dist_matrix_index.keys() if + module_id not in classifier.order_of_domain_modules] ref_types = {domain_module_id_2_domain_type[mod_id] for mod_id in classifier.order_of_domain_modules} y = np.array([domain_module_id_2_domain_type[mod_id] for mod_id in new_fold_domains]) y_is_novel = np.array([int(dom_type not in ref_types) for dom_type in y]) + # print(Counter(y), Counter(y_is_novel)) X_list = [] @@ -53,39 +56,38 @@ dists_current = [] for ref_mod_id in 
classifier.order_of_domain_modules: dom_ids = tuple(sorted([mod_id, ref_mod_id])) - tmscore = regions_ids_2_tmscore[dom_ids] + try: + tmscore = regions_ids_2_tmscore[dom_ids] + hits_count += 1 + except KeyError: + miss_count += 1 + tmscore = 0 dists_current.append(tmscore) X_list.append(dists_current) X_np = np.array(X_list) - # novelty detector - X_np_trn, X_np_test, y_is_novel_trn, y_is_novel_test = train_test_split(X_np, y_is_novel, stratify=y_is_novel) - classifier = RandomForestClassifier(500) - classifier.fit(X_np_trn, y_is_novel_trn) - y_pred = classifier.predict_proba(X_np_test)[:, 1] - logger.info(f'Novelty detection mAP: {average_precision_score(y_is_novel_test, y_pred):.3f}') - novelty_detector = RandomForestClassifier(500) - novelty_detector.fit(X_np, y_is_novel) - novel_domain_detectors.append(novelty_detector) + dom_classifier = RandomForestClassifier(500) + dom_classifier.fit(X_np, y) + domain_type_classifiers.append(dom_classifier) - #domain type classifier - X_np_trn, X_np_test, y_trn, y_test = train_test_split(X_np, y, stratify=y) - label_binarizer = MultiLabelBinarizer() - y_trn = label_binarizer.fit_transform( - y_trn - ) - y_test = label_binarizer.transform( - y_test - ) - classifier = RandomForestClassifier(500) - classifier.fit(X_np_trn, y_trn) - y_pred = classifier.predict_proba(X_np_test) - y_pred_all = np.array([y_pred_class[:, 1] for y_pred_class in y_pred]).T - logger.info(f'Domain type classification mAP: {average_precision_score(y_test, y_pred_all):.3f}') + # novelty detector evaluation + X_np_novel, y_novel = X_np[y_is_novel == 1], y[y_is_novel == 1] + X_np_known, y_known = X_np[y_is_novel == 0], y[y_is_novel == 0] + try: + X_np_trn, X_np_test, y_trn, y_test = train_test_split(X_np_known, y_known, stratify=y_known) + X_np_test = np.concatenate((X_np_test, X_np_novel)) + y_is_novel_test = np.concatenate((np.zeros(len(y_test)), np.ones(len(y_novel)))) - classifier = RandomForestClassifier(500) - classifier.fit(X_np, y) - domain_type_classifiers.append(classifier) + dom_classifier = RandomForestClassifier(500) + dom_classifier.fit(X_np_trn, y_trn) + y_pred_all = dom_classifier.predict_proba(X_np_test) + y_pred = 1 - y_pred_all.max(axis=1) + y_is_novel_test_all.extend(y_is_novel_test) + y_pred_novel_all.extend(y_pred) + except ValueError: + logger.warning(f'Not enough un-covered domain types for fold {FOLD} (it does not influence the final results, the fold is just excluded from the novelty detection evaluation metric)') + if sum(y_is_novel_test_all): + logger.info(f'Novelty detection mAP: {average_precision_score(y_is_novel_test_all, y_pred_novel_all):.3f}') - with open("data/domain_type_predictors.pkl", "wb") as file: - pickle.dump([novel_domain_detectors, domain_type_classifiers], file) + with open("data/domain_type_predictors_foldseek.pkl", "wb") as file: + pickle.dump(domain_type_classifiers, file) diff --git a/terpeneminer/src/terpene_miner_main.py b/terpeneminer/src/terpene_miner_main.py index 792db41..4e9dd36 100644 --- a/terpeneminer/src/terpene_miner_main.py +++ b/terpeneminer/src/terpene_miner_main.py @@ -30,6 +30,8 @@ def parse_args() -> argparse.Namespace: parser_run = subparsers.add_parser("run", help="Run experiment(s)") parser_run.set_defaults(cmd="run") parser_run.add_argument("--load-hyperparameters", action="store_true") + parser_run.add_argument("--model", type=str, default=None) + parser_run.add_argument("--model-version", type=str, default=None) parser_eval = subparsers.add_parser("evaluate", help="Evaluate experiment(s)") 
parser_eval.set_defaults(cmd="evaluate") @@ -195,7 +197,13 @@ def run_selected_experiments(args: argparse.Namespace): config_root_path = get_config_root() if args.select_single_experiment: - experiment_kwargs = collect_single_experiment_arguments(config_root_path) + if args.model is None or args.model_version is None: + experiment_kwargs = collect_single_experiment_arguments(config_root_path) + else: + experiment_kwargs = { + "model_type": args.model, + "model_version": args.model_version, + } experiment_info = ExperimentInfo(**experiment_kwargs) run_experiment(experiment_info, load_hyperparameters=args.load_hyperparameters) else: diff --git a/terpeneminer/src/utils/data.py b/terpeneminer/src/utils/data.py index e1787f7..2b7678a 100644 --- a/terpeneminer/src/utils/data.py +++ b/terpeneminer/src/utils/data.py @@ -5,7 +5,7 @@ import numpy as np # type: ignore import pandas as pd # type: ignore from indigo import Indigo # type: ignore - +from Bio.PDB import PDBParser, PPBuilder # type: ignore logging.getLogger("h5py").setLevel(logging.INFO) import h5py # type: ignore # pylint: disable=C0413 @@ -224,3 +224,39 @@ def get_canonical_smiles(smiles: str, without_stereo: bool = True): mol.clearCisTrans() mol.clearStereocenters() return mol.canonicalSmiles() + + +def extract_sequences_from_pdb(pdb_filepath: str): + """ + Extract amino acid sequences from a PDB file for each chain. + + Parameters: + pdb_file (str): Path to the PDB file. + + Returns: + dict: A dictionary where keys are chain IDs and values are amino acid sequences. + """ + # Create a PDBParser object + parser = PDBParser() + + # Parse the PDB file + structure = parser.get_structure('protein_structure', pdb_filepath) + + # Create a Polypeptide builder + ppb = PPBuilder() + + sequences = {} + + # Iterate over each model (usually there's only one) + for model in structure: + # Iterate over each chain in the model + for chain in model: + # Build polypeptides (sequences) from the chain + polypeptides = ppb.build_peptides(chain) + # Concatenate sequences if there are multiple polypeptides + sequence = ''.join([str(pp.get_sequence()) for pp in polypeptides]) + sequences[chain.id] = sequence + return sequences + + + From d21277ab7ee9f3a8c7ed782b14d9020d40580ee7 Mon Sep 17 00:00:00 2001 From: Raman Date: Sun, 10 Nov 2024 07:28:45 +0100 Subject: [PATCH 4/4] bulk commit of work on a backend app to deploy terpeneminer --- app_faster_with_foldseek.py | 39 +++++++++++-------- .../plm_domain_comparison_randomforest.py | 2 +- .../structure_processing/domain_detections.py | 18 ++++----- .../structural_algorithms.py | 8 ++-- 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/app_faster_with_foldseek.py b/app_faster_with_foldseek.py index 790951b..be9d636 100644 --- a/app_faster_with_foldseek.py +++ b/app_faster_with_foldseek.py @@ -50,12 +50,27 @@ class MotifDetection: max_len=1022, ) +model_fallback, batch_converter_fallback, alphabet_fallback = get_model_and_tokenizer( + "esm-1v", return_alphabet=True + ) +compute_embeddings_partial_fallback = partial( + compute_embeddings, + bert_model=model_fallback, + converter=batch_converter_fallback, + padding_idx=alphabet_fallback.padding_idx, + model_repr_layer=33, + max_len=1022, +) + with open('data/classifier_domain_and_plm_checkpoints.pkl', 'rb') as file: fold_classifiers = pickle.load(file) with open('data/classifier_plm_checkpoints.pkl', 'rb') as file: fold_plm_classifiers = pickle.load(file) +with open('data/classifier_plm_checkpoints_esm1v.pkl', 'rb') as file: + 
fold_plm_classifiers_fallback = pickle.load(file)
+
 # Create FastAPI app instance
 app = FastAPI()
 
@@ -338,18 +353,8 @@ async def upload_file(file: UploadFile = File(...),
                 dom_feat[dom_feat_idx] = known_domain_id_2_tmscore.get(known_module_id, 0)
             if np.max(dom_feat) < 0.4:
                 logger.warning("No meaningful domain comparisons. Skipping the model.. ")
-                if hasattr(classifier, "domain_feature_novelty_detector") and getattr(classifier,
-                                                                                      "domain_feature_novelty_detector") is not None:
-                    novelty_prediction = classifier.domain_feature_novelty_detector.predict(dom_feat)[0]
-                    logger.warning(
-                        f"Novelty prediction would have been {novelty_prediction}")
                 continue
             dom_feat = 1 - dom_feat.reshape(1, -1)
-            if hasattr(classifier, "domain_feature_novelty_detector") and getattr(classifier, "domain_feature_novelty_detector") is not None:
-                novelty_prediction = classifier.domain_feature_novelty_detector.predict(dom_feat)[0]
-                if novelty_prediction == -1:
-                    logger.warning("Data drift detected in domain comparisons. Skipping the model..")
-                    continue
             if classifier.plm_feat_indices_subset is not None:
                 emb_plm = np.apply_along_axis(lambda i: i[classifier.plm_feat_indices_subset], 1, enzyme_encodings_np_batch)
             else:
@@ -374,15 +379,15 @@ async def upload_file(file: UploadFile = File(...),
                     predictions[sample_i][class_name].append(value)
-        print('predictions: ', predictions)
         if len(predictions) == 0:
-            logger.warning("Falling back to PLM features only due to severe data drift in domain comparisons")
+            logger.warning("Falling back to generic PLM features due to severe data drift")
             predictions = []
+            # compute the fallback embeddings once; they do not depend on the classifier
+            (
+                enzyme_encodings_np_batch,
+                _,
+            ) = compute_embeddings_partial_fallback(input_seqs=input_seq)
-            for classifier_i, classifier in enumerate(fold_plm_classifiers):
+            for classifier_i, classifier in enumerate(fold_plm_classifiers_fallback):
-                logger.info(f"Predicting with plm classifier {classifier_i + 1}/{len(fold_classifiers)}..")
+                logger.info(f"Predicting with plm classifier {classifier_i + 1}/{len(fold_plm_classifiers_fallback)}..")
-                if hasattr(classifier, "plm_feature_novelty_detector") and getattr(classifier,
-                                                                                   "plm_feature_novelty_detector") is not None:
-                    novelty_prediction = classifier.plm_feature_novelty_detector.predict(enzyme_encodings_np_batch)[0]
-                    logger.warning(
-                        f"PLM emb novelty prediction is {novelty_prediction}")
 
                 y_pred_proba = classifier.predict_proba(enzyme_encodings_np_batch)
                 for sample_i in range(n_samples):
                     predictions_raw = {}
diff --git a/terpeneminer/src/models/plm_domain_comparison_randomforest.py b/terpeneminer/src/models/plm_domain_comparison_randomforest.py
index f6752d2..bda3147 100644
--- a/terpeneminer/src/models/plm_domain_comparison_randomforest.py
+++ b/terpeneminer/src/models/plm_domain_comparison_randomforest.py
@@ -89,7 +89,7 @@ def fit_core(self, train_df: pd.DataFrame, class_name: str = None):
 
         # novelty detector to check for data drift
         dom_feats_trn = np.stack(dom_features_df["Emb_dom"].values)
 
-        self.domain_feature_novelty_detector = IsolationForest(n_estimators=400).fit(dom_feats_trn)
+        self.domain_feature_novelty_detector = IsolationForest(contamination=0.02, n_estimators=400).fit(dom_feats_trn)
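+        # editorial note: contamination fixes the fraction of training points the
+        # detector may score as outliers; the scikit-learn default "auto" instead
+        # derives the decision threshold from the original isolation-forest paper
         logger.info(f"Novelty detector for domain features is trained. 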
Proportion of outliers: {np.mean(self.domain_feature_novelty_detector.predict(dom_feats_trn) == -1):.2f}") plm_feats = np.stack(self.features_df_plm["Emb"].values) diff --git a/terpeneminer/src/structure_processing/domain_detections.py b/terpeneminer/src/structure_processing/domain_detections.py index 1661cab..e966602 100644 --- a/terpeneminer/src/structure_processing/domain_detections.py +++ b/terpeneminer/src/structure_processing/domain_detections.py @@ -406,6 +406,14 @@ def get_all_confidence_values( n_jobs=args.n_jobs, ) + # Assigning missed secondary structure parts to the closest domains + filename_2_known_regions_completed = get_mapped_regions_with_surroundings_parallel( + list(filename_2_known_regions.keys()), + file_2_all_residues, + filename_2_known_regions, + n_jobs=args.n_jobs, + ) + # Get unsegmented parts file_2_remaining_residues = get_remaining_residues( filename_2_known_regions, file_2_all_residues @@ -464,16 +472,6 @@ def get_all_confidence_values( ): filename_2_known_regions[uni_id].append(new_region) - - - # Assigning missed secondary structure parts to the closest domains - filename_2_known_regions_completed = get_mapped_regions_with_surroundings_parallel( - list(filename_2_known_regions.keys()), - file_2_all_residues, - filename_2_known_regions, - n_jobs=args.n_jobs, - ) - # Getting confident residues filename_2_known_regions_completed_confident = {} for filename, regions in tqdm(filename_2_known_regions_completed.items()): diff --git a/terpeneminer/src/structure_processing/structural_algorithms.py b/terpeneminer/src/structure_processing/structural_algorithms.py index e458198..375def6 100644 --- a/terpeneminer/src/structure_processing/structural_algorithms.py +++ b/terpeneminer/src/structure_processing/structural_algorithms.py @@ -30,10 +30,10 @@ SUPPORTED_DOMAINS = {"alpha", "beta", "gamma", "delta", "epsilon", "alphaWeird"} DOMAIN_2_THRESHOLD = { - "beta": (0.5, 50), - "delta": (0.5, 50), - "epsilon": (0.5, 50), - "gamma": (0.5, 50), + "beta": (0.6, 50), + "delta": (0.6, 50), + "epsilon": (0.6, 50), + "gamma": (0.55, 50), "alpha": (0.35, 130), "alphaWeird": (0.5, 100), }
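
Taken together, the `DOMAIN_2_THRESHOLD` values above gate candidate domain hits on two axes at once: the TM-score of the superposition and the number of residues mapped onto the domain standard. The toy helper below (a hypothetical function, not part of the repository) spells out the check used in `detect_domains_roughly`, with the threshold table in its final, post-patch state:

```python
# A hedged sketch of the per-domain acceptance gate implied by the diffs above.
DOMAIN_2_THRESHOLD = {
    "beta": (0.6, 50),
    "delta": (0.6, 50),
    "epsilon": (0.6, 50),
    "gamma": (0.55, 50),
    "alpha": (0.35, 130),
    "alphaWeird": (0.5, 100),
}

def passes_domain_gate(domain: str, tm_score: float, n_mapped_residues: int) -> bool:
    """Accept a structural hit only if both the TM-score and the number of
    residues mapped onto the domain standard clear the per-domain minimums."""
    min_tm, min_mapped = DOMAIN_2_THRESHOLD[domain]
    return tm_score >= min_tm and n_mapped_residues >= min_mapped

# a low-TM-score but well-covered alpha hit still passes, because the alpha
# standard tolerates looser superpositions (0.35), while gamma does not (0.55)
print(passes_domain_gate("alpha", 0.38, 140))  # True
print(passes_domain_gate("gamma", 0.38, 140))  # False
```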