From 465b651bfbcf346569674002799879436c3a176a Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sun, 16 Nov 2025 23:29:17 +0100
Subject: [PATCH 01/30] use instantiate model to load data and model from ckpt

---
 chebai/trainer/CustomTrainer.py | 69 +++++++++++++++++++++++----------
 1 file changed, 49 insertions(+), 20 deletions(-)

diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index f7fbce26..e84aad4c 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Any, List, Optional, Tuple
 
 import pandas as pd
@@ -6,13 +7,17 @@
 from lightning import LightningModule, Trainer
 from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.cli import instantiate_module
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.loops.fit_loop import _FitLoop
 from lightning.pytorch.trainer import call
 from torch.nn.utils.rnn import pad_sequence
 
+from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.loggers.custom import CustomLogger
-from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
+from chebai.models.base import ChebaiBaseNet
+from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.reader import CLS_TOKEN
 
 log = logging.getLogger(__name__)
 
@@ -44,6 +49,7 @@ def __init__(self, *args, **kwargs):
 
         # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
         self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         """
@@ -76,12 +82,10 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
 
     def predict_from_file(
         self,
-        model: LightningModule,
         checkpoint_path: _PATH,
         input_path: _PATH,
         save_to: _PATH = "predictions.csv",
         classes_path: Optional[_PATH] = None,
-        **kwargs,
     ) -> None:
         """
         Loads a model from a checkpoint and makes predictions on input data from a file.
@@ -93,20 +97,21 @@ def predict_from_file(
             save_to: Path to save the predictions CSV file.
             classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
         """
-        loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
         with open(input_path, "r") as input:
            smiles_strings = [inp.strip() for inp in input.readlines()]
-        loaded_model.eval()
-        predictions = self._predict_smiles(loaded_model, smiles_strings)
-        predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
-        if classes_path is not None:
-            with open(classes_path, "r") as f:
-                predictions_df.columns = [cls.strip() for cls in f.readlines()]
-        predictions_df.index = smiles_strings
-        predictions_df.to_csv(save_to)
+        self._predict_smiles(
+            checkpoint_path,
+            smiles=smiles_strings,
+            classes_path=classes_path,
+            save_to=save_to,
+        )
 
     def _predict_smiles(
-        self, model: LightningModule, smiles: List[str]
+        self,
+        checkpoint_path: _PATH,
+        smiles: List[str],
+        classes_path: Optional[_PATH] = None,
+        save_to: _PATH = "predictions.csv",
     ) -> torch.Tensor:
         """
         Predicts the output for a list of SMILES strings using the model.
@@ -118,22 +123,47 @@ def _predict_smiles(
 
         Returns:
             A tensor containing the predictions.
""" - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.to(self.device) + model.eval() + + parsed_smiles = [dm.reader._read_data(s) for s in smiles] x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], + [torch.tensor(a, device=self.device) for a in parsed_smiles], batch_first=True, ) cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) + torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1) * CLS_TOKEN ) features = torch.cat((cls_tokens, x), dim=1) model_output = model({"features": features}) preds = torch.sigmoid(model_output["logits"]) - print(preds.shape) - return preds + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) @property def log_dir(self) -> Optional[str]: @@ -157,7 +187,6 @@ def log_dir(self) -> Optional[str]: class LoadDataLaterFitLoop(_FitLoop): - def on_advance_start(self) -> None: """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary so that the dataloaders can get information from the model. For example: The on_train_epoch_start From 2acd166b87f03df14839fa9c5c381b1669bc6e84 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Nov 2025 23:29:38 +0100 Subject: [PATCH 02/30] update readme --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 38 ++++++++++++++++++--------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index eeecd714..2555c0a6 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the one row for each SMILES string and one column for each class. 
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 68254007..f1357c88 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (
-            self.label_filter is None
-        ), "Filter balancing requires a filter"
+        assert (balance_after_filter is not None) or (self.label_filter is None), (
+            "Filter balancing requires a filter"
+        )
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert (
-            fold_index is None or self.use_inner_cross_validation is not None
-        ), "fold_index can only be set if cross validation is used"
+        assert fold_index is None or self.use_inner_cross_validation is not None, (
+            "fold_index can only be set if cross validation is used"
+        )
         if fold_index is not None and self.inner_k_folds is not None:
-            assert (
-                fold_index < self.inner_k_folds
-            ), "fold_index can't be larger than the total number of folds"
+            assert fold_index < self.inner_k_folds, (
+                "fold_index can't be larger than the total number of folds"
+            )
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert (
-            self._feature_vector_size is not None
-        ), "size of feature vector must be set"
+        assert self._feature_vector_size is not None, (
+            "size of feature vector must be set"
+        )
         return self._feature_vector_size
 
     @property
@@ -1190,7 +1190,8 @@ def _retrieve_splits_from_csv(self) -> None:
             print(f"Applying label filter from {self.apply_label_filter}...")
             with open(self.apply_label_filter, "r") as f:
                 label_filter = [line.strip() for line in f]
-            with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:
+
+            with open(self.classes_txt_file_path, "r") as cf:
                 classes = [line.strip() for line in cf]
             # reorder labels
             old_labels = np.stack(df_data["labels"])
@@ -1315,3 +1316,14 @@ def processed_file_names_dict(self) -> dict:
         if self.n_token_limit is not None:
             return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}
+
+    @property
+    def classes_txt_file_path(self) -> str:
+        """
+        Returns the path to the classes text file.
+
+        Returns:
+            str: The path to the classes text file.
+ """ + # This property also used in custom trainer `chebai/trainer/CustomTrainer.py` + return os.path.join(self.processed_dir_main, "classes.txt") From cfbf392f13773c0ce92719b24a83ef231cea2e32 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Nov 2025 14:50:16 +0100 Subject: [PATCH 03/30] set no grad for predict --- chebai/trainer/CustomTrainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index e84aad4c..6f1a542c 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -80,6 +80,7 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value + @torch.no_grad() def predict_from_file( self, checkpoint_path: _PATH, @@ -106,6 +107,7 @@ def predict_from_file( save_to=save_to, ) + @torch.no_grad() def _predict_smiles( self, checkpoint_path: _PATH, From 82b365ca31698da387f4077c13ea345d9572aa04 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 16:36:15 +0100 Subject: [PATCH 04/30] predict pipeline in dm and lm --- chebai/models/base.py | 8 ++++++- chebai/preprocessing/datasets/base.py | 31 ++++++++++++++++++++++----- chebai/trainer/CustomTrainer.py | 4 +++- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 7653f13c..808ea59e 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -232,7 +232,13 @@ def predict_step( Returns: Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. """ - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + labels = data["labels"] + model_output = self(data, **data.get("model_kwargs", dict())) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f1357c88..e2df794d 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] + + return self._filter_to_token_limit(data) + + def _filter_to_token_limit( + self, data: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: # filter for missing features in resulting data, keep features length below token limit - data = [ + return [ val for val in data if val["features"] is not None @@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: ) ] - return data - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -400,10 +404,13 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs + self, + smiles_list: List[str], + **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ Returns the predict DataLoader. @@ -415,7 +422,21 @@ def predict_dataloader( Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. 
""" - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + data = self._filter_to_token_limit(data) + + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 6f1a542c..acd468f2 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -4,7 +4,7 @@ import pandas as pd import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.fabric.utilities.types import _PATH from lightning.pytorch.cli import instantiate_module @@ -87,6 +87,7 @@ def predict_from_file( input_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, + **kwargs, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. @@ -114,6 +115,7 @@ def _predict_smiles( smiles: List[str], classes_path: Optional[_PATH] = None, save_to: _PATH = "predictions.csv", + **kwargs, ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. From fa6f1b521d05d38261973f602ff53232ec5eabb3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 18:16:36 +0100 Subject: [PATCH 05/30] there is no need that predict func must depend on trainer --- chebai/result/prediction.py | 111 ++++++++++++++++++++++++++++++++ chebai/trainer/CustomTrainer.py | 107 ++++-------------------------- 2 files changed, 122 insertions(+), 96 deletions(-) create mode 100644 chebai/result/prediction.py diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..cb3e1415 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,111 @@ +import os +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module +from torch.utils.data import DataLoader + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + super().__init__() + + def predict_from_file( + self, + checkpoint_path: _PATH, + input_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. + input_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
+ """ + with open(input_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + self.predict_smiles( + checkpoint_path, + smiles=smiles_strings, + classes_path=classes_path, + save_to=save_to, + ) + + @torch.inference_mode() + def predict_smiles( + self, + checkpoint_path: _PATH, + smiles: List[str], + classes_path: Optional[_PATH] = None, + save_to: Optional[_PATH] = None, + **kwargs, + ) -> torch.Tensor | None: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + model: The model to use for predictions. + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. + """ + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + print(f"Loaded datamodule class: {dm.__class__.__name__}") + + pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.eval() + # model = torch.compile(model) + + print(f"Loaded model class: {model.__class__.__name__}") + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + preds.append(model.predict_step(batch, batch_idx)) + + if not save_to: + # If no save path is provided, return the predictions tensor + return torch.cat(preds) + + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + CLI(Predictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index acd468f2..11ade921 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,23 +1,14 @@ import logging -import os -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd import torch from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.cli import instantiate_module from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence -from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.loggers.custom import CustomLogger -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.reader import CLS_TOKEN log = logging.getLogger(__name__) @@ -80,94 +71,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - @torch.no_grad() - def predict_from_file( + def predict( self, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes 
-
-        Args:
-            model: The model to use for predictions.
-            checkpoint_path: Path to the model checkpoint.
-            input_path: Path to the input file containing SMILES strings.
-            save_to: Path to save the predictions CSV file.
-            classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
-        """
-        with open(input_path, "r") as input:
-            smiles_strings = [inp.strip() for inp in input.readlines()]
-        self._predict_smiles(
-            checkpoint_path,
-            smiles=smiles_strings,
-            classes_path=classes_path,
-            save_to=save_to,
-        )
-
-    @torch.no_grad()
-    def _predict_smiles(
-        self,
-        checkpoint_path: _PATH,
-        smiles: List[str],
-        classes_path: Optional[_PATH] = None,
-        save_to: _PATH = "predictions.csv",
-        **kwargs,
-    ) -> torch.Tensor:
-        """
-        Predicts the output for a list of SMILES strings using the model.
-
-        Args:
-            model: The model to use for predictions.
-            smiles: A list of SMILES strings.
-
-        Returns:
-            A tensor containing the predictions.
-        """
-        ckpt_file = torch.load(
-            checkpoint_path, map_location=self.device, weights_only=False
-        )
-
-        ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path")
-        dm: XYBaseDataModule = instantiate_module(
-            XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"]
-        )
-
-        model: ChebaiBaseNet = instantiate_module(
-            ChebaiBaseNet, ckpt_file["hyper_parameters"]
-        )
-        model.to(self.device)
-        model.eval()
-
-        parsed_smiles = [dm.reader._read_data(s) for s in smiles]
-        x = pad_sequence(
-            [torch.tensor(a, device=self.device) for a in parsed_smiles],
-            batch_first=True,
-        )
-        cls_tokens = (
-            torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1)
-            * CLS_TOKEN
-        )
-        features = torch.cat((cls_tokens, x), dim=1)
-        model_output = model({"features": features})
-        preds = torch.sigmoid(model_output["logits"])
-
-        predictions_df = pd.DataFrame(preds.detach().cpu().numpy())
-
-        def _add_class_columns(class_file_path: _PATH):
-            with open(class_file_path, "r") as f:
-                predictions_df.columns = [cls.strip() for cls in f.readlines()]
-
-        if classes_path is not None:
-            _add_class_columns(classes_path)
-        elif os.path.exists(dm.classes_txt_file_path):
-            _add_class_columns(dm.classes_txt_file_path)
-
-        predictions_df.index = smiles
-        predictions_df.to_csv(save_to)
+        model=None,
+        dataloaders=None,
+        datamodule=None,
+        return_predictions=None,
+        ckpt_path=None,
+    ):
+        raise NotImplementedError(
+            "CustomTrainer.predict is not implemented. "
+            "Use `Predictor.predict_from_file` or `Predictor.predict_smiles` from `chebai/result/prediction.py` instead."
+        )
 
     @property
     def log_dir(self) -> Optional[str]:

From ae47608eb26918d497ef2cd60e750c4c2d585782 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Thu, 27 Nov 2025 13:14:18 +0100
Subject: [PATCH 06/30] model hparams for data predict pipeline and vice versa

---
 README.md                             |  2 +-
 chebai/preprocessing/datasets/base.py | 53 +++++++++++++++++----------
 chebai/result/prediction.py           | 23 ++++++------
 3 files changed, 47 insertions(+), 31 deletions(-)

diff --git a/README.md b/README.md
index 2555c0a6..713a9d42 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 You can evaluate a model trained on the ontology extension task in one of two ways:
 
 ### 1. Using the Jupyter Notebook
-An example notebook is provided at `tutorials/eval_model_basic.ipynb`. 
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 - Load your finetuned model and run the evaluation cells to compute metrics on the test set.
 
 ### 2. Using the Lightning CLI
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index e2df794d..5e3064b3 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (self.label_filter is None), (
-            "Filter balancing requires a filter"
-        )
+        assert (balance_after_filter is not None) or (
+            self.label_filter is None
+        ), "Filter balancing requires a filter"
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert fold_index is None or self.use_inner_cross_validation is not None, (
-            "fold_index can only be set if cross validation is used"
-        )
+        assert (
+            fold_index is None or self.use_inner_cross_validation is not None
+        ), "fold_index can only be set if cross validation is used"
         if fold_index is not None and self.inner_k_folds is not None:
-            assert fold_index < self.inner_k_folds, (
-                "fold_index can't be larger than the total number of folds"
-            )
+            assert (
+                fold_index < self.inner_k_folds
+            ), "fold_index can't be larger than the total number of folds"
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert self._feature_vector_size is not None, (
-            "size of feature vector must be set"
-        )
+        assert (
+            self._feature_vector_size is not None
+        ), "size of feature vector must be set"
         return self._feature_vector_size
 
     @property
@@ -410,6 +410,7 @@
     def predict_dataloader(
         self,
         smiles_list: List[str],
+        model_hparams: Optional[dict] = None,
         **kwargs,
     ) -> Union[DataLoader, List[DataLoader]]:
         """
@@ -423,6 +424,26 @@ def predict_dataloader(
 
         Returns:
             Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
         """
+
+        data = self._process_input_for_prediction(smiles_list, model_hparams)
+        return DataLoader(
+            data,
+            collate_fn=self.reader.collator,
+            batch_size=self.batch_size,
+            **kwargs,
+        )
+
+    def _process_input_for_prediction(
+        self, smiles_list: list[str], model_hparams: Optional[dict] = None
+    ) -> list:
+        """
+        Process input data for prediction.
+
+        Args:
+            smiles_list (List[str]): List of SMILES strings.
+
+        Returns:
+            List[Dict[str, Any]]: Processed input data.
+ """ data = [ self.reader.to_data( {"id": f"smiles_{idx}", "features": smiles, "labels": None} @@ -430,13 +451,7 @@ def predict_dataloader( for idx, smiles in enumerate(smiles_list) ] data = self._filter_to_token_limit(data) - - return DataLoader( - data, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) + return data def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index cb3e1415..250d68b6 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -67,25 +67,26 @@ def predict_smiles( checkpoint_path, map_location=self.device, weights_only=False ) - ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module( - XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] - ) + dm_hparams = ckpt_file["datamodule_hyper_parameters"] + dm_hparams.pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") - pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) - - model: ChebaiBaseNet = instantiate_module( - ChebaiBaseNet, ckpt_file["hyper_parameters"] - ) + model_hparams = ckpt_file["hyper_parameters"] + model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") + # For certain data prediction piplines, we may need model hyperparameters + pred_dl: DataLoader = dm.predict_dataloader( + smiles_list=smiles, model_hparams=model_hparams + ) + preds = [] for batch_idx, batch in enumerate(pred_dl): - preds.append(model.predict_step(batch, batch_idx)) + # For certain model prediction pipelines, we may need data module hyperparameters + preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) if not save_to: # If no save path is provided, return the predictions tensor From 517a5a2061927f95a6e2ceb4661d1948e3d1defd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:19:55 +0100 Subject: [PATCH 07/30] fix reader ident error --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 22b91a0e..7e41510c 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any: def _get_raw_id(self, row: Dict[str, Any]) -> Any: """Get raw ID from the row.""" - return row.get("ident", row["features"]) + return row.get("ident", row["id"]) def _get_raw_group(self, row: Dict[str, Any]) -> Any: """Get raw group from the row.""" From 40491e55f97d94d75e2f987ffd99f4f68e02f616 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:23:37 +0100 Subject: [PATCH 08/30] fix label None error --- chebai/models/base.py | 4 +++- chebai/result/prediction.py | 32 ++++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 808ea59e..c6c347a4 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -235,8 +235,10 @@ def predict_step( assert isinstance(batch, XYData) batch = batch.to(self.device) data = self._process_batch(batch, batch_idx) - labels = data["labels"] model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in 
+        labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
         pr, _ = self._get_prediction_and_labels(data, labels, model_output)
         return pr
 
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index 250d68b6..6f8e41c1 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -20,21 +20,25 @@ def predict_from_file(
     def predict_from_file(
         self,
         checkpoint_path: _PATH,
-        input_path: _PATH,
+        smiles_file_path: _PATH,
         save_to: _PATH = "predictions.csv",
         classes_path: Optional[_PATH] = None,
+        batch_size: Optional[int] = None,
     ) -> None:
         """
         Loads a model from a checkpoint and makes predictions on input data from a file.
 
         Args:
-            model: The model to use for predictions.
             checkpoint_path: Path to the model checkpoint.
-            input_path: Path to the input file containing SMILES strings.
+            smiles_file_path: Path to the input file containing SMILES strings.
             save_to: Path to save the predictions CSV file.
-            classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
+            classes_path: Optional path to a file containing class names:
+                if no class names are provided, code will try to get the class path
+                from the datamodule, else the columns will be numbered.
+            batch_size: Optional batch size for the DataLoader. If not provided,
+                the default from the datamodule will be used.
         """
-        with open(input_path, "r") as input:
+        with open(smiles_file_path, "r") as input:
             smiles_strings = [inp.strip() for inp in input.readlines()]
 
         self.predict_smiles(
@@ -42,6 +46,7 @@ def predict_from_file(
             smiles=smiles_strings,
             classes_path=classes_path,
             save_to=save_to,
+            batch_size=batch_size,
         )
 
     @torch.inference_mode()
@@ -51,16 +56,24 @@ def predict_smiles(
         smiles: List[str],
         classes_path: Optional[_PATH] = None,
         save_to: Optional[_PATH] = None,
+        batch_size: Optional[int] = None,
         **kwargs,
     ) -> torch.Tensor | None:
         """
         Predicts the output for a list of SMILES strings using the model.
 
         Args:
-            model: The model to use for predictions.
+            checkpoint_path: Path to the model checkpoint.
             smiles: A list of SMILES strings.
-
-        Returns:
+            classes_path: Optional path to a file containing class names. If no class
+                names are provided, code will try to get the class path from the datamodule,
+                else the columns will be numbered.
+            save_to: Optional path to save the predictions CSV file. If not provided,
+                predictions will be returned as a tensor.
+            batch_size: Optional batch size for the DataLoader. If not provided, the default
+                from the datamodule will be used.
+
+        Returns: (if save_to is None)
             A tensor containing the predictions.
""" ckpt_file = torch.load( @@ -71,10 +84,13 @@ def predict_smiles( dm_hparams.pop("splits_file_path") dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") + if batch_size is not None: + dm.batch_size = batch_size model_hparams = ckpt_file["hyper_parameters"] model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() + # TODO: Enable torch.compile when supported # model = torch.compile(model) print(f"Loaded model class: {model.__class__.__name__}") From 7b7e48f7bb581b87f89c162b0b3d4931a134872b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:16 +0100 Subject: [PATCH 09/30] fix cli predict_from_file error --- chebai/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/cli.py b/chebai/cli.py index 96262447..de48f615 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -83,7 +83,6 @@ def subcommands() -> Dict[str, Set[str]]: "validate": {"model", "dataloaders", "datamodule"}, "test": {"model", "dataloaders", "datamodule"}, "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, } From 6a383177eff88ba768076bd094c660aeeba27102 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:32 +0100 Subject: [PATCH 10/30] update readme for new prediction method --- README.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 713a9d42..7af8aad4 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,19 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] ----smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` -The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the -one row for each SMILES string and one column for each class. -The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs. + +* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`). + +* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line. + +* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. + +* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs. + + * If provided, the CSV columns will be named using the ChEBI IDs. + * If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially. ## Evaluation From 31b12dbf2ce85b685d2ee8bd721bf4f97e995db6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 23:26:10 +0100 Subject: [PATCH 11/30] Revert "fix reader ident error" This reverts commit 517a5a2061927f95a6e2ceb4661d1948e3d1defd. 
---
 chebai/preprocessing/reader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index 7e41510c..22b91a0e 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any:
 
     def _get_raw_id(self, row: Dict[str, Any]) -> Any:
         """Get raw ID from the row."""
-        return row.get("ident", row["id"])
+        return row.get("ident", row["features"])
 
     def _get_raw_group(self, row: Dict[str, Any]) -> Any:
         """Get raw group from the row."""

From c6e8b6137c6bc0ca0e97481587361a57d3ccbcda Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 28 Nov 2025 15:42:48 +0100
Subject: [PATCH 12/30] modify pred logic to store model and dm as instance var

---
 chebai/result/prediction.py | 140 ++++++++++++++++++++----------------
 1 file changed, 78 insertions(+), 62 deletions(-)

diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index 6f8e41c1..b84e59ad 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -13,116 +13,132 @@ class Predictor:
-    def __init__(self):
+    def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
+        """Initializes the Predictor with a model loaded from the checkpoint.
+
+        Args:
+            checkpoint_path: Path to the model checkpoint.
+            batch_size: Optional batch size for the DataLoader. If not provided,
+                the default from the datamodule will be used.
""" with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - self.predict_smiles( - checkpoint_path, + preds: torch.Tensor = self.predict_smiles( smiles=smiles_strings, classes_path=classes_path, save_to=save_to, - batch_size=batch_size, ) + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + _add_class_columns(self._dm.classes_txt_file_path) + + predictions_df.index = smiles_strings + predictions_df.to_csv(save_to) + @torch.inference_mode() def predict_smiles( self, - checkpoint_path: _PATH, smiles: List[str], - classes_path: Optional[_PATH] = None, - save_to: Optional[_PATH] = None, - batch_size: Optional[int] = None, - **kwargs, - ) -> torch.Tensor | None: + ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. Args: - checkpoint_path: Path to the model checkpoint. smiles: A list of SMILES strings. - classes_path: Optional path to a file containing class names. If no class - names are provided, code will try to get the class path from the datamodule, - else the columns will be numbered. - save_to: Optional path to save the predictions CSV file. If not provided, - predictions will be returned as a tensor. - batch_size: Optional batch size for the DataLoader. If not provided, the default - from the datamodule will be used. - - Returns: (if save_to is None) + + Returns: A tensor containing the predictions. """ - ckpt_file = torch.load( - checkpoint_path, map_location=self.device, weights_only=False - ) - - dm_hparams = ckpt_file["datamodule_hyper_parameters"] - dm_hparams.pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) - print(f"Loaded datamodule class: {dm.__class__.__name__}") - if batch_size is not None: - dm.batch_size = batch_size - - model_hparams = ckpt_file["hyper_parameters"] - model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) - model.eval() - # TODO: Enable torch.compile when supported - # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") - # For certain data prediction piplines, we may need model hyperparameters - pred_dl: DataLoader = dm.predict_dataloader( - smiles_list=smiles, model_hparams=model_hparams + pred_dl: DataLoader = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams ) preds = [] for batch_idx, batch in enumerate(pred_dl): # For certain model prediction pipelines, we may need data module hyperparameters - preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) + preds.append( + self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) + ) - if not save_to: - # If no save path is provided, return the predictions tensor - return torch.cat(preds) + return torch.cat(preds) - predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): - with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - if classes_path is not None: - _add_class_columns(classes_path) - elif os.path.exists(dm.classes_txt_file_path): - _add_class_columns(dm.classes_txt_file_path) +class MainPredictor: + @staticmethod + def 
+        checkpoint_path: _PATH,
+        smiles_file_path: _PATH,
+        save_to: _PATH = "predictions.csv",
+        classes_path: Optional[_PATH] = None,
+        batch_size: Optional[int] = None,
+    ) -> None:
+        predictor = Predictor(checkpoint_path, batch_size)
+        predictor.predict_from_file(
+            smiles_file_path,
+            save_to,
+            classes_path,
+        )
+
+    @staticmethod
+    def predict_smiles(
+        checkpoint_path: _PATH,
+        smiles: List[str],
+        batch_size: Optional[int] = None,
+    ) -> torch.Tensor:
+        predictor = Predictor(checkpoint_path, batch_size)
+        return predictor.predict_smiles(smiles)
 
 
 if __name__ == "__main__":
     # python chebai/result/prediction.py predict_from_file --help
-    CLI(Predictor, as_positional=False)
+    # python chebai/result/prediction.py predict_smiles --help
+    CLI(MainPredictor, as_positional=False)

From 63670ddada766d84c706752319841e2b7d5d4c60 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 6 Dec 2025 19:32:14 +0100
Subject: [PATCH 13/30] fix for unwanted args to predict_smiles

---
 chebai/result/prediction.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index b84e59ad..c558d9c7 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -27,7 +27,7 @@ def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
         )
 
         self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
-        self._dm_hparams.pop("splits_file_path")
+        # self._dm_hparams.pop("splits_file_path")
         self._dm: XYBaseDataModule = instantiate_module(
             XYBaseDataModule, self._dm_hparams
         )
@@ -63,11 +63,7 @@ def predict_from_file(
         with open(smiles_file_path, "r") as input:
             smiles_strings = [inp.strip() for inp in input.readlines()]
 
-        preds: torch.Tensor = self.predict_smiles(
-            smiles=smiles_strings,
-            classes_path=classes_path,
-            save_to=save_to,
-        )
+        preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)
 
         predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())

From d906ad404ecac72fe1236193de462391f9da607b Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 6 Dec 2025 20:22:18 +0100
Subject: [PATCH 14/30] avoid non_null_labels key in loss kwargs

---
 chebai/preprocessing/datasets/base.py | 5 ++++-
 chebai/result/prediction.py           | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 5e3064b3..ee60de99 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -444,9 +444,12 @@ def _process_input_for_prediction(
         Returns:
             List[Dict[str, Any]]: Processed input data.
         """
+        # Add dummy labels because the collate function requires them.
+        # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
+        # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
         data = [
             self.reader.to_data(
-                {"id": f"smiles_{idx}", "features": smiles, "labels": None}
+                {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
             )
             for idx, smiles in enumerate(smiles_list)
         ]
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index c558d9c7..a0a01050 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -65,7 +65,7 @@ def predict_from_file(
 
         preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)
 
-        predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())
+        predictions_df = pd.DataFrame(preds.detach().cpu().numpy())

From ba5884a996ed663ff2784d7dd3779c697df1540f Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 6 Dec 2025 20:37:58 +0100
Subject: [PATCH 15/30] compile model

---
 chebai/result/prediction.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index a0a01050..212a95e9 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -13,36 +13,47 @@ class Predictor:
-    def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
+    def __init__(
+        self,
+        checkpoint_path: _PATH,
+        batch_size: Optional[int] = None,
+        compile_model: bool = True,
+    ):
         """Initializes the Predictor with a model loaded from the checkpoint.
 
         Args:
             checkpoint_path: Path to the model checkpoint.
             batch_size: Optional batch size for the DataLoader. If not provided,
                 the default from the datamodule will be used.
+            compile_model: Whether to compile the model using torch.compile. Default is True.
         """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         ckpt_file = torch.load(
             checkpoint_path, map_location=self.device, weights_only=False
         )
 
+        print("-" * 50)
+        print(f"Loaded checkpoint from: {checkpoint_path}")
+        print("Below are the modules loaded from the checkpoint:")
+
         self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
         # self._dm_hparams.pop("splits_file_path")
         self._dm: XYBaseDataModule = instantiate_module(
             XYBaseDataModule, self._dm_hparams
         )
-        print(f"Loaded datamodule class: {self._dm.__class__.__name__}")
         if batch_size is not None and int(batch_size) > 0:
             self._dm.batch_size = int(batch_size)
+        print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}")
 
         self._model_hparams = ckpt_file["hyper_parameters"]
         self._model: ChebaiBaseNet = instantiate_module(
             ChebaiBaseNet, self._model_hparams
         )
+        print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
+
+        if compile_model:
+            self._model = torch.compile(self._model)
         self._model.eval()
-        # TODO: Enable torch.compile when supported
-        # model = torch.compile(model)
-        print(f"Loaded model class: {self._model.__class__.__name__}")
+        print("-" * 50)

From 5a17f722df61ce6e51b24c475a002e4cfbc75ffc Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 6 Dec 2025 20:43:48 +0100
Subject: [PATCH 16/30] revert the comment line for splits file path

---
 chebai/result/prediction.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index 212a95e9..e904c57d 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -36,7 +36,7 @@ def __init__(
         print("Below are the modules loaded from the checkpoint:")
 
         self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
-        # self._dm_hparams.pop("splits_file_path")
+        self._dm_hparams.pop("splits_file_path")
         self._dm: XYBaseDataModule = instantiate_module(
             XYBaseDataModule, self._dm_hparams
         )

From bd59d5924f1de6fae1d4a2d54e2cf6a4bb3a67a1 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 13 Dec 2025 00:16:42 +0100
Subject: [PATCH 17/30] handle augment electra and old ckpt files

---
 chebai/result/prediction.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index e904c57d..a3b82364 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -37,6 +37,13 @@ def __init__(
 
         self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
         self._dm_hparams.pop("splits_file_path")
+        self._dm_hparams.pop("augment_smiles", None)
+        self._dm_hparams.pop("aug_smiles_variations", None)
+        assert "_class_path" in self._dm_hparams, (
+            "Datamodule hyperparameters must include a '_class_path' key.\n"
+            "Hence, either the checkpoint is corrupted or "
+            "it was not saved properly with the latest lightning version"
+        )
         self._dm: XYBaseDataModule = instantiate_module(
             XYBaseDataModule, self._dm_hparams
         )

From 6b8ae0992e96a4ae77756521b9d5b1c54877d093 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 13 Dec 2025 19:58:48 +0100
Subject: [PATCH 18/30] remove unnec device

---
 chebai/trainer/CustomTrainer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index 11ade921..b9e4b0f3 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -1,7 +1,6 @@
 import logging
 from typing import Any, Optional, Tuple
 
-import torch
 from lightning import Trainer
 from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.loops.fit_loop import _FitLoop
 from lightning.pytorch.trainer import call
@@ -40,7 +39,6 @@ def __init__(self, *args, **kwargs):
 
         # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
         self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         """

From c0293959238b038165bb005e377c428373a46583 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 15 Dec 2025 15:07:22 +0100
Subject: [PATCH 19/30] raise error for invalid smiles and return None

---
 chebai/preprocessing/reader.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index 22b91a0e..d63671f7 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -203,10 +203,13 @@ def _read_data(self, raw_data: str) -> List[int]:
             print(f"RDKit failed to process {raw_data}")
             print(f"\t{e}")
         try:
+            mol = Chem.MolFromSmiles(raw_data.strip())
+            if mol is None:
+                raise ValueError(f"Invalid SMILES: {raw_data}")
             return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
         except ValueError as e:
             print(f"could not process {raw_data}")
-            print(f"\t{e}")
+            print(f"\tError: {e}")
             return None

From 676107936d557c70c55b55e8c33481915e1b1a70 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 15 Dec 2025 22:48:05 +0100
Subject: [PATCH 20/30] rectify test as to return None for invalid strings

---
 tests/unit/readers/testChemDataReader.py | 32 +++++++++++++-----------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py
index ec018f00..9d322f27 100644
--- a/tests/unit/readers/testChemDataReader.py
+++ b/tests/unit/readers/testChemDataReader.py
@@ -42,19 +42,22 @@ def test_read_data(self) -> None:
         """
         Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
         """
-        raw_data = "CC(=O)NC1[Mg-2]"
+        raw_data = "CC(=O)NC1CC1[Mg-2]"
         # Expected output as per the tokens already in the cache, and ")" getting added to it.
         expected_output: List[int] = [
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 0,  # C
-            EMBEDDING_OFFSET + 5,  # =
-            EMBEDDING_OFFSET + 3,  # O
-            EMBEDDING_OFFSET + 1,  # N
-            EMBEDDING_OFFSET + len(self.reader.cache),  # (
-            EMBEDDING_OFFSET + 2,  # C
+            EMBEDDING_OFFSET + 5,  # (
+            EMBEDDING_OFFSET + 3,  # =
+            EMBEDDING_OFFSET + 1,  # O
+            EMBEDDING_OFFSET + len(self.reader.cache),  # ) - new token
+            EMBEDDING_OFFSET + 2,  # N
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 4,  # 1
-            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2]
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 4,  # 1
+            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2] - new token
         ]
         result = self.reader._read_data(raw_data)
         self.assertEqual(
@@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None:
         Test the _read_data method with an invalid input.

         The invalid token should prompt a return value None
         """
-        raw_data = "%INVALID%"
-
-        result = self.reader._read_data(raw_data)
-        self.assertIsNone(
-            result,
-            "The output for invalid token '%INVALID%' should be None.",
-        )
+        # see https://github.com/ChEB-AI/python-chebai/issues/137
+        raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
+        for raw_data in raw_datas:
+            result = self.reader._read_data(raw_data)
+            self.assertIsNone(
+                result,
+                f"The output for invalid token '{raw_data}' should be None.",
+            )
 
     @patch("builtins.open", new_callable=mock_open)
     def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:

From d5e362b7a812da012235d52cbf29b4686263a51b Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Tue, 16 Dec 2025 21:32:19 +0100
Subject: [PATCH 21/30] pin rdkit version -
 https://github.com/ChEB-AI/python-chebai/issues/83

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index eb75643d..4ba71f8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "torch",
     "transformers",
     "pysmiles==1.1.2",
-    "rdkit",
+    "rdkit==2024.3.6",
     "lightning==2.5.1",
 ]

From 5d302d0c03687c70c187afd9608a4db8f55aa1a0 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Tue, 16 Dec 2025 11:53:24 +0100
Subject: [PATCH 22/30] handle None returns for invalid smiles

---
 chebai/preprocessing/datasets/base.py | 57 ++++++++++++++++++---------
 chebai/result/prediction.py           | 15 +++++--
 2 files changed, 49 insertions(+), 23 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index ee60de99..abda9471 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -412,49 +412,68 @@ def predict_dataloader(
         smiles_list: List[str],
         model_hparams: Optional[dict] = None,
         **kwargs,
-    ) -> Union[DataLoader, List[DataLoader]]:
+    ) -> tuple[DataLoader, list[int]]:
         """
         Returns the predict DataLoader.
 
         Args:
-            *args: Additional positional arguments (unused).
+            smiles_list (List[str]): List of SMILES strings to predict.
+            model_hparams (Optional[dict]): Model hyperparameters.
+                Some prediction pre-processing pipelines may require these.
             **kwargs: Additional keyword arguments, passed to dataloader().
 
         Returns:
-            Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
+            tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices.
         """
 
-        data = self._process_input_for_prediction(smiles_list, model_hparams)
-        return DataLoader(
-            data,
-            collate_fn=self.reader.collator,
-            batch_size=self.batch_size,
-            **kwargs,
+        data, valid_indices = self._process_input_for_prediction(
+            smiles_list, model_hparams
+        )
+        return (
+            DataLoader(
+                data,
+                collate_fn=self.reader.collator,
+                batch_size=self.batch_size,
+                **kwargs,
+            ),
+            valid_indices,
         )
 
     def _process_input_for_prediction(
         self, smiles_list: list[str], model_hparams: Optional[dict] = None
-    ) -> list:
+    ) -> tuple[list, list]:
         """
         Process input data for prediction.
 
         Args:
             smiles_list (List[str]): List of SMILES strings.
+            model_hparams (Optional[dict]): Model hyperparameters.
+                Some prediction pre-processing pipelines may require these.
 
         Returns:
-            List[Dict[str, Any]]: Processed input data.
+            tuple[list, list]: Processed input data and valid indices.
         """
+        data, valid_indices = [], []
+        for idx, smiles in enumerate(smiles_list):
+            result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
+            if result is None or result["features"] is None:
+                continue
+            data.append(result)
+            valid_indices.append(idx)
+
+        data = self._filter_to_token_limit(data)
+        return data, valid_indices
+
+    def _preprocess_smiles_for_pred(
+        self, idx, smiles: str, model_hparams: Optional[dict] = None
+    ) -> dict:
+        """Preprocess prediction data."""
         # Add dummy labels because the collate function requires them.
         # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
         # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
-        data = [
-            self.reader.to_data(
-                {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
-            )
-            for idx, smiles in enumerate(smiles_list)
-        ]
-        data = self._filter_to_token_limit(data)
-        return data
+        return self.reader.to_data(
+            {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+        )
 
     def prepare_data(self, *args, **kwargs) -> None:
         if self._prepare_data_flag != 1:
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index a3b82364..ad5775a1 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -6,7 +6,6 @@
 from jsonargparse import CLI
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.cli import instantiate_module
-from torch.utils.data import DataLoader
 
 from chebai.models.base import ChebaiBaseNet
 from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -101,7 +100,7 @@ def predict_from_file(
     def predict_smiles(
         self,
         smiles: List[str],
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor | None]:
         """
         Predicts the output for a list of SMILES strings using the model.
 
         Args:
             smiles: A list of SMILES strings.
 
         Returns:
             A tensor containing the predictions.
""" # For certain data prediction piplines, we may need model hyperparameters - pred_dl: DataLoader = self._dm.predict_dataloader( + pred_dl, valid_indices = self._dm.predict_dataloader( smiles_list=smiles, model_hparams=self._model_hparams ) @@ -122,8 +121,16 @@ def predict_smiles( preds.append( self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) ) + preds = torch.cat(preds) + + # Initialize output with None + output: list[torch.Tensor | None] = [None] * len(smiles) + + # Scatter predictions back + for pred, idx in zip(preds, valid_indices): + output[idx] = pred - return torch.cat(preds) + return output class MainPredictor: From e86f03ae8d5869bee3bd715691478710c4ed1258 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 18 Dec 2025 13:45:21 +0100 Subject: [PATCH 23/30] remove instanstior key from hparams as its causing unnecessary error --- chebai/result/prediction.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index ad5775a1..1410a6c5 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -30,6 +30,15 @@ def __init__( ckpt_file = torch.load( checkpoint_path, map_location=self.device, weights_only=False ) + assert ( + "_class_path" in ckpt_file["datamodule_hyper_parameters"] + and "_class_path" in ckpt_file["hyper_parameters"] + ), ( + "Datamodule and Model hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + print("-" * 50) print(f"For Loaded checkpoint from: {checkpoint_path}") print("Below are the modules loaded from the checkpoint:") @@ -38,11 +47,7 @@ def __init__( self._dm_hparams.pop("splits_file_path") self._dm_hparams.pop("augment_smiles", None) self._dm_hparams.pop("aug_smiles_variations", None) - assert "_class_path" in self._dm_hparams, ( - "Datamodule hyperparameters must include a '_class_path' key.\n" - "Hence, either the checkpoint is corrupted or " - "it was not saved properly with latest lightning version" - ) + self._dm_hparams.pop("_instantiator", None) self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) @@ -51,6 +56,7 @@ def __init__( print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") self._model_hparams = ckpt_file["hyper_parameters"] + self._model_hparams.pop("_instantiator", None) self._model: ChebaiBaseNet = instantiate_module( ChebaiBaseNet, self._model_hparams ) From 994de55469a1bab66e2c9ffdbaabb78a4d42ac7b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 18 Dec 2025 13:48:38 +0100 Subject: [PATCH 24/30] return none for token limit too --- chebai/preprocessing/datasets/base.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 19cc222f..d47f38f7 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -341,20 +341,17 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: if d["features"] is not None ] - return self._filter_to_token_limit(data) + data = [val for val in data if self._filter_to_token_limit(val)] + return data - def _filter_to_token_limit( - self, data: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def _filter_to_token_limit(self, data_instance: dict) -> bool: # filter for missing features in resulting data, keep features length below token limit - 
return [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ @@ -459,10 +456,11 @@ def _process_input_for_prediction( result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) if result is None or result["features"] is None: continue + if not self._filter_to_token_limit(result): + continue data.append(result) valid_indices.append(idx) - data = self._filter_to_token_limit(data) return data, valid_indices def _preprocess_smiles_for_pred( From 3dbf9ae69156d1aa81b67304dde7b1c335f6b3df Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 Jan 2026 17:27:43 +0100 Subject: [PATCH 25/30] assert i/p and o/p dim greater than 0 --- chebai/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 4be89678..19fc6c6a 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -47,8 +47,8 @@ def __init__( if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion - assert out_dim is not None, "out_dim must be specified" - assert input_dim is not None, "input_dim must be specified" + assert out_dim is not None and out_dim > 0, "out_dim must be specified" + assert input_dim is not None and input_dim > 0, "input_dim must be specified" self.out_dim = out_dim self.input_dim = input_dim print( From 811fbf1c2c1eb99a0b2499b99d40d68855c3eb4d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 Jan 2026 17:28:22 +0100 Subject: [PATCH 26/30] for None pred, fill corresponding row with None values across the cols --- chebai/result/prediction.py | 41 +++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 1410a6c5..0e478f66 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -86,20 +86,45 @@ def predict_from_file( with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) + CLASS_LABELS: list | None = None - predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): + def _add_class_columns(class_file_path: _PATH) -> list[str]: with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] + return [cls.strip() for cls in f.readlines()] if classes_path is not None: - _add_class_columns(classes_path) + CLASS_LABELS = _add_class_columns(classes_path) elif os.path.exists(self._dm.classes_txt_file_path): - _add_class_columns(self._dm.classes_txt_file_path) + CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path) + + preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings) + if all(pred is None for pred in preds): + print("No valid predictions were made. (All predictions are None.)") + return + + if CLASS_LABELS is not None and self._model.out_dim is not None: + assert len(CLASS_LABELS) > 0, "Class labels list is empty." 
+            assert len(CLASS_LABELS) == self._model.out_dim, (
+                f"Number of class labels ({len(CLASS_LABELS)}) does not match "
+                f"the model output dimension ({self._model.out_dim})."
+            )
+            num_of_cols = len(CLASS_LABELS)
+        elif CLASS_LABELS is not None:
+            assert len(CLASS_LABELS) > 0, "Class labels list is empty."
+            num_of_cols = len(CLASS_LABELS)
+        elif self._model.out_dim is not None:
+            num_of_cols = self._model.out_dim
+        else:
+            # find first non-None tensor to determine width
+            num_of_cols = next(x.numel() for x in preds if x is not None)
+            CLASS_LABELS = [f"class_{i}" for i in range(num_of_cols)]
+
+        rows = [
+            pred.tolist() if pred is not None else [None] * num_of_cols
+            for pred in preds
+        ]
+        predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings)
 
-        predictions_df.index = smiles_strings
         predictions_df.to_csv(save_to)

From 0ea34903622de1c2655c0c15d6aabd44a9d973fc Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Thu, 8 Jan 2026 19:46:30 +0100
Subject: [PATCH 27/30] save classification labels to checkpoints

---
 chebai/cli.py                         |  6 ++++
 chebai/models/base.py                 | 20 +++++++++++---
 chebai/preprocessing/datasets/base.py | 48 ++++++++++++++-------------
 chebai/result/prediction.py           |  1 +
 4 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/chebai/cli.py b/chebai/cli.py
index 5c517efa..8b51e45f 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
             apply_on="instantiate",
         )
 
+        parser.link_arguments(
+            "data.classes_txt_file_path",
+            "model.init_args.classes_txt_file_path",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in (
                 "micro-f1",
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 19fc6c6a..538180a1 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -40,6 +40,7 @@ def __init__(
         pass_loss_kwargs: bool = True,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
+        classes_txt_file_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -77,6 +78,21 @@ def __init__(
         self.validation_metrics = val_metrics
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs
+        # classes_txt_file_path defaults to None (e.g. for models created outside
+        # the CLI), so only read and validate the labels when a path is given
+        self.labels_list: Optional[list] = None
+        if classes_txt_file_path is not None:
+            with open(classes_txt_file_path, "r") as f:
+                self.labels_list = [cls.strip() for cls in f.readlines()]
+            assert len(self.labels_list) > 0, "Class labels list is empty."
+            assert len(self.labels_list) == out_dim, (
+                f"Number of class labels ({len(self.labels_list)}) does not match "
+                f"the model output dimension ({out_dim})."
+            )
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+        checkpoint["classification_labels"] = self.labels_list
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +116,7 @@ def __init_subclass__(cls, **kwargs):
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the predictions and labels from the model output.
 
@@ -151,7 +167,7 @@ def _process_for_loss(
         model_output: torch.Tensor,
         labels: torch.Tensor,
         loss_kwargs: Dict[str, Any],
-    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
         """
         Processes the data for loss computation.
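A quick way to see the effect of the new `on_save_checkpoint` hook above is to load a checkpoint produced with this patch and inspect the stored key directly. A minimal sketch (the checkpoint path is a placeholder, not a file from these patches; old checkpoints simply lack the key):

```python
import torch

# Load the raw checkpoint dict; weights_only=False matches how prediction.py loads it.
ckpt = torch.load("example.ckpt", map_location="cpu", weights_only=False)

# on_save_checkpoint stores the label list under this key
# (None if no classes file was linked when the model was built).
labels = ckpt.get("classification_labels")
if labels is not None:
    print(f"checkpoint stores {len(labels)} class labels, e.g. {labels[0]}")
```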
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index d47f38f7..2f08ae4e 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (
-            self.label_filter is None
-        ), "Filter balancing requires a filter"
+        assert (balance_after_filter is not None) or (self.label_filter is None), (
+            "Filter balancing requires a filter"
+        )
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert (
-            fold_index is None or self.use_inner_cross_validation is not None
-        ), "fold_index can only be set if cross validation is used"
+        assert fold_index is None or self.use_inner_cross_validation is not None, (
+            "fold_index can only be set if cross validation is used"
+        )
         if fold_index is not None and self.inner_k_folds is not None:
-            assert (
-                fold_index < self.inner_k_folds
-            ), "fold_index can't be larger than the total number of folds"
+            assert fold_index < self.inner_k_folds, (
+                "fold_index can't be larger than the total number of folds"
+            )
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert (
-            self._feature_vector_size is not None
-        ), "size of feature vector must be set"
+        assert self._feature_vector_size is not None, (
+            "size of feature vector must be set"
+        )
         return self._feature_vector_size
 
     @property
@@ -619,6 +619,19 @@ def raw_file_names_dict(self) -> dict:
         """
         raise NotImplementedError
 
+    @property
+    def classes_txt_file_path(self) -> str:
+        """
+        Returns the path to the classes text file.
+
+        Returns:
+            str: Path to the classes text file.
+        """
+        # This property is also used in the following places:
+        # - results/prediction.py: to load class names for the CSV column names
+        # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
+        return os.path.join(self.processed_dir_main, "classes.txt")
+
 
 class MergedDataset(XYBaseDataModule):
     MERGED = []
@@ -1373,14 +1386,3 @@ def processed_file_names_dict(self) -> dict:
         if self.n_token_limit is not None:
             return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}
-
-    @property
-    def classes_txt_file_path(self) -> str:
-        """
-        Returns the filename for the classes text file.
-
-        Returns:
-            str: The filename for the classes text file.
- """ - # This property also used in custom trainer `chebai/trainer/CustomTrainer.py` - return os.path.join(self.processed_dir_main, "classes.txt") diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 0e478f66..8664fd15 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -126,6 +126,7 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]: predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings) predictions_df.to_csv(save_to) + print(f"Predictions saved to: {save_to}") @torch.inference_mode() def predict_smiles( From a1cdacaffba62e4585ebaff1004537f388ff951d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 Jan 2026 20:05:54 +0100 Subject: [PATCH 28/30] update for prediction logic how ckpts with class labels --- README.md | 6 +++--- chebai/result/prediction.py | 19 ++++++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5054ddd2..b195324c 100644 --- a/README.md +++ b/README.md @@ -81,10 +81,10 @@ python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-t * **`--smiles_file_path`**: Path to a text file containing one SMILES string per line. -* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. - -* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs. +* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`. +* **`--classes_path`** *(optional)*: Path to the dataset’s `classes.txt` file, which maps model output indices to ChEBI IDs. + * Checkpoints created after PR #135 will have the classification labels stored in them and hence this parameter is not required. * If provided, the CSV columns will be named using the ChEBI IDs. * If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially. diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 8664fd15..fda8a308 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -62,6 +62,19 @@ def __init__( ) print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + self._classification_labels: list | None = ckpt_file.get( + "classification_labels", None + ) + if self._classification_labels is not None: + print(f"Loaded {len(self._classification_labels)} classification labels.") + assert len(self._classification_labels) > 0, ( + "Classification labels list is empty." + ) + assert len(self._classification_labels) == self._model.out_dim, ( + f"Number of class labels ({len(self._classification_labels)}) does not match " + f"the model output dimension ({self._model.out_dim})." 
+ ) + if compile_model: self._model = torch.compile(self._model) self._model.eval() @@ -92,7 +105,10 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]: with open(class_file_path, "r") as f: return [cls.strip() for cls in f.readlines()] - if classes_path is not None: + if self._classification_labels is not None: + CLASS_LABELS = self._classification_labels + # --- For old checkpoints that do not have classification_labels saved --- + elif classes_path is not None: CLASS_LABELS = _add_class_columns(classes_path) elif os.path.exists(self._dm.classes_txt_file_path): CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path) @@ -102,6 +118,7 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]: print("No valid predictions were made. (All predictions are None.)") return + # --- Logic for old checkpoints that do not have classification_labels saved --- if CLASS_LABELS is not None and self._model.out_dim is not None: assert len(CLASS_LABELS) > 0, "Class labels list is empty." assert len(CLASS_LABELS) == self._model.out_dim, ( From 330a8c2ea35545ace4e8791e97250fc492056df3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 Jan 2026 20:13:36 +0100 Subject: [PATCH 29/30] pre-commit format --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 24 ++++++++++++------------ chebai/result/prediction.py | 6 +++--- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index b195324c..11755aed 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-t * **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`. * **`--classes_path`** *(optional)*: Path to the dataset’s `classes.txt` file, which maps model output indices to ChEBI IDs. - * Checkpoints created after PR #135 will have the classification labels stored in them and hence this parameter is not required. + * Checkpoints created after PR #135 will have the classification labels stored in them and hence this parameter is not required. * If provided, the CSV columns will be named using the ChEBI IDs. * If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially. 
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 2f08ae4e..e04463f3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or (self.label_filter is None), ( - "Filter balancing requires a filter" - ) + assert (balance_after_filter is not None) or ( + self.label_filter is None + ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert fold_index is None or self.use_inner_cross_validation is not None, ( - "fold_index can only be set if cross validation is used" - ) + assert ( + fold_index is None or self.use_inner_cross_validation is not None + ), "fold_index can only be set if cross validation is used" if fold_index is not None and self.inner_k_folds is not None: - assert fold_index < self.inner_k_folds, ( - "fold_index can't be larger than the total number of folds" - ) + assert ( + fold_index < self.inner_k_folds + ), "fold_index can't be larger than the total number of folds" self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert self._feature_vector_size is not None, ( - "size of feature vector must be set" - ) + assert ( + self._feature_vector_size is not None + ), "size of feature vector must be set" return self._feature_vector_size @property diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index fda8a308..b7994de0 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -67,9 +67,9 @@ def __init__( ) if self._classification_labels is not None: print(f"Loaded {len(self._classification_labels)} classification labels.") - assert len(self._classification_labels) > 0, ( - "Classification labels list is empty." - ) + assert ( + len(self._classification_labels) > 0 + ), "Classification labels list is empty." assert len(self._classification_labels) == self._model.out_dim, ( f"Number of class labels ({len(self._classification_labels)}) does not match " f"the model output dimension ({self._model.out_dim})." 
From 704af93b15b24dfd665ea32e63edebc66b4704fa Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 Jan 2026 20:32:38 +0100 Subject: [PATCH 30/30] update test for cli change --- tests/unit/cli/classification_labels.txt | 10 ++++++++++ tests/unit/cli/mock_dm.py | 6 ++++++ tests/unit/cli/testCLI.py | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 tests/unit/cli/classification_labels.txt diff --git a/tests/unit/cli/classification_labels.txt b/tests/unit/cli/classification_labels.txt new file mode 100644 index 00000000..f50323b4 --- /dev/null +++ b/tests/unit/cli/classification_labels.txt @@ -0,0 +1,10 @@ +label_1 +label_2 +label_3 +label_4 +label_5 +label_6 +label_7 +label_8 +label_9 +label_10 \ No newline at end of file diff --git a/tests/unit/cli/mock_dm.py b/tests/unit/cli/mock_dm.py index 25116e21..e3fd60a7 100644 --- a/tests/unit/cli/mock_dm.py +++ b/tests/unit/cli/mock_dm.py @@ -1,3 +1,5 @@ +import os + import torch from lightning.pytorch.core.datamodule import LightningDataModule from torch.utils.data import DataLoader @@ -29,6 +31,10 @@ def num_of_labels(self): def feature_vector_size(self): return self._feature_vector_size + @property + def classes_txt_file_path(self) -> str: + return os.path.join("tests", "unit", "cli", "classification_labels.txt") + def train_dataloader(self): assert self.feature_vector_size is not None, "feature_vector_size must be set" # Dummy dataset for example purposes diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py index 863a6df3..584da5e7 100644 --- a/tests/unit/cli/testCLI.py +++ b/tests/unit/cli/testCLI.py @@ -9,7 +9,7 @@ def setUp(self): "fit", "--trainer=configs/training/default_trainer.yml", "--model=configs/model/ffn.yml", - "--model.init_args.hidden_layers=[10]", + "--model.init_args.hidden_layers=[1]", "--model.train_metrics=configs/metrics/micro-macro-f1.yml", "--data=tests/unit/cli/mock_dm_config.yml", "--model.pass_loss_kwargs=false",