diff --git a/README.md b/README.md
index 45b238cb..11755aed 100644
--- a/README.md
+++ b/README.md
@@ -74,11 +74,19 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
 ### Predicting classes given SMILES strings
 
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path where the predictions are saved as a CSV file. The CSV contains one row per SMILES string and one column per predicted class. Defaults to `predictions.csv` in the current working directory.
+
+* **`--classes_path`** *(optional)*: Path to the dataset's `classes.txt` file, which maps model output indices to ChEBI IDs.
+  * Checkpoints created after PR #135 store the classification labels, so this parameter is not required for them.
+  * If provided, the CSV columns are named using the ChEBI IDs.
+  * If omitted, the script tries to locate the file automatically; if it cannot be located, the columns are numbered sequentially.
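+
+The same functionality is also available programmatically through the `Predictor` class added in `chebai/result/prediction.py`. A minimal sketch (all paths below are placeholders):
+
+```python
+from chebai.result.prediction import Predictor
+
+# Placeholder paths; point these at your own checkpoint and SMILES file.
+predictor = Predictor(checkpoint_path="path/to/model.ckpt")
+predictor.predict_from_file(
+    smiles_file_path="path/to/smiles.txt",
+    save_to="predictions.csv",
+)
+```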
 
 ## Evaluation
diff --git a/chebai/cli.py b/chebai/cli.py
index 1aaba53c..8b51e45f 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
             apply_on="instantiate",
         )
 
+        parser.link_arguments(
+            "data.classes_txt_file_path",
+            "model.init_args.classes_txt_file_path",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in (
                 "micro-f1",
@@ -112,7 +118,6 @@ def subcommands() -> Dict[str, Set[str]]:
             "validate": {"model", "dataloaders", "datamodule"},
             "test": {"model", "dataloaders", "datamodule"},
             "predict": {"model", "dataloaders", "datamodule"},
-            "predict_from_file": {"model"},
         }
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 37bb6ef6..538180a1 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -40,6 +40,7 @@ def __init__(
         pass_loss_kwargs: bool = True,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
+        classes_txt_file_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -47,8 +48,8 @@ def __init__(
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
-        assert out_dim is not None, "out_dim must be specified"
-        assert input_dim is not None, "input_dim must be specified"
+        assert out_dim is not None and out_dim > 0, "out_dim must be specified and positive"
+        assert input_dim is not None and input_dim > 0, "input_dim must be specified and positive"
         self.out_dim = out_dim
         self.input_dim = input_dim
         print(
@@ -77,6 +78,17 @@ def __init__(
         self.validation_metrics = val_metrics
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs
+        with open(classes_txt_file_path, "r") as f:
+            self.labels_list = [cls.strip() for cls in f.readlines()]
+        assert len(self.labels_list) > 0, "Class labels list is empty."
+        assert len(self.labels_list) == out_dim, (
+            f"Number of class labels ({len(self.labels_list)}) does not match "
+            f"the model output dimension ({out_dim})."
+        )
+
+    def on_save_checkpoint(self, checkpoint):
+        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+        checkpoint["classification_labels"] = self.labels_list
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +112,7 @@ def __init_subclass__(cls, **kwargs):
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the predictions and labels from the model output.
 
@@ -151,7 +163,7 @@ def _process_for_loss(
         model_output: torch.Tensor,
         labels: torch.Tensor,
         loss_kwargs: Dict[str, Any],
-    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
         """
         Processes the data for loss computation.
 
@@ -237,7 +249,15 @@ def predict_step(
 
         Returns:
             Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
""" - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in _get_prediction_and_labels + labels = torch.zeros((len(batch), self.out_dim)).to(self.device) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 02b6ec72..e04463f3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] - # filter for missing features in resulting data, keep features length below token limit - data = [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + data = [val for val in data if self._filter_to_token_limit(val)] return data + def _filter_to_token_limit(self, data_instance: dict) -> bool: + # filter for missing features in resulting data, keep features length below token limit + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -401,22 +402,77 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs - ) -> Union[DataLoader, List[DataLoader]]: + self, + smiles_list: List[str], + model_hparams: Optional[dict] = None, + **kwargs, + ) -> tuple[DataLoader, list[int]]: """ Returns the predict DataLoader. Args: - *args: Additional positional arguments (unused). + smiles_list (List[str]): List of SMILES strings to predict. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. **kwargs: Additional keyword arguments, passed to dataloader(). Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data, valid_indices = self._process_input_for_prediction( + smiles_list, model_hparams + ) + return ( + DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ), + valid_indices, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> tuple[list, list]: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. + + Returns: + tuple[list, list]: Processed input data and valid indices. 
+ """ + data, valid_indices = [], [] + for idx, smiles in enumerate(smiles_list): + result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) + if result is None or result["features"] is None: + continue + if not self._filter_to_token_limit(result): + continue + data.append(result) + valid_indices.append(idx) + + return data, valid_indices + + def _preprocess_smiles_for_pred( + self, idx, smiles: str, model_hparams: Optional[dict] = None + ) -> dict: + """Preprocess prediction data.""" + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. + return self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: @@ -563,6 +619,19 @@ def raw_file_names_dict(self) -> dict: """ raise NotImplementedError + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - results/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") + class MergedDataset(XYBaseDataModule): MERGED = [] @@ -1191,7 +1260,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 2b3b1b0e..669c8fbf 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -208,10 +208,13 @@ def _read_data(self, raw_data: str) -> List[int]: print(f"RDKit failed to process {raw_data}") print(f"\t{e}") try: + mol = Chem.MolFromSmiles(raw_data.strip()) + if mol is None: + raise ValueError(f"Invalid SMILES: {raw_data}") return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] except ValueError as e: print(f"could not process {raw_data}") - print(f"\t{e}") + print(f"\tError: {e}") return None def _back_to_smiles(self, smiles_encoded): diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..b7994de0 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,214 @@ +import os +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. 
+                If not provided, the default from the datamodule will be used.
+            compile_model: Whether to compile the model using torch.compile. Default is True.
+        """
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        ckpt_file = torch.load(
+            checkpoint_path, map_location=self.device, weights_only=False
+        )
+        assert (
+            "_class_path" in ckpt_file["datamodule_hyper_parameters"]
+            and "_class_path" in ckpt_file["hyper_parameters"]
+        ), (
+            "Datamodule and model hyperparameters must include a '_class_path' key.\n"
+            "Hence, either the checkpoint is corrupted or it was not saved properly "
+            "with a recent Lightning version."
+        )
+
+        print("-" * 50)
+        print(f"Loaded checkpoint from: {checkpoint_path}")
+        print("Below are the modules loaded from the checkpoint:")
+
+        self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
+        self._dm_hparams.pop("splits_file_path")
+        self._dm_hparams.pop("augment_smiles", None)
+        self._dm_hparams.pop("aug_smiles_variations", None)
+        self._dm_hparams.pop("_instantiator", None)
+        self._dm: XYBaseDataModule = instantiate_module(
+            XYBaseDataModule, self._dm_hparams
+        )
+        if batch_size is not None and int(batch_size) > 0:
+            self._dm.batch_size = int(batch_size)
+        print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}")
+
+        self._model_hparams = ckpt_file["hyper_parameters"]
+        self._model_hparams.pop("_instantiator", None)
+        self._model: ChebaiBaseNet = instantiate_module(
+            ChebaiBaseNet, self._model_hparams
+        )
+        print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
+
+        self._classification_labels: list | None = ckpt_file.get(
+            "classification_labels", None
+        )
+        if self._classification_labels is not None:
+            print(f"Loaded {len(self._classification_labels)} classification labels.")
+            assert (
+                len(self._classification_labels) > 0
+            ), "Classification labels list is empty."
+            assert len(self._classification_labels) == self._model.out_dim, (
+                f"Number of class labels ({len(self._classification_labels)}) does not match "
+                f"the model output dimension ({self._model.out_dim})."
+            )
+
+        if compile_model:
+            self._model = torch.compile(self._model)
+        self._model.eval()
+        print("-" * 50)
+
+    def predict_from_file(
+        self,
+        smiles_file_path: _PATH,
+        save_to: _PATH = "predictions.csv",
+        classes_path: Optional[_PATH] = None,
+    ) -> None:
+        """
+        Loads a model from a checkpoint and makes predictions on input data from a file.
+
+        Args:
+            smiles_file_path: Path to the input file containing SMILES strings.
+            save_to: Path to save the predictions CSV file.
+            classes_path: Optional path to a file containing class names. If not provided,
+                the code tries to get the classes file path from the datamodule; otherwise,
+                the columns are numbered.
+ """ + with open(smiles_file_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + CLASS_LABELS: list | None = None + + def _add_class_columns(class_file_path: _PATH) -> list[str]: + with open(class_file_path, "r") as f: + return [cls.strip() for cls in f.readlines()] + + if self._classification_labels is not None: + CLASS_LABELS = self._classification_labels + # --- For old checkpoints that do not have classification_labels saved --- + elif classes_path is not None: + CLASS_LABELS = _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path) + + preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings) + if all(pred is None for pred in preds): + print("No valid predictions were made. (All predictions are None.)") + return + + # --- Logic for old checkpoints that do not have classification_labels saved --- + if CLASS_LABELS is not None and self._model.out_dim is not None: + assert len(CLASS_LABELS) > 0, "Class labels list is empty." + assert len(CLASS_LABELS) == self._model.out_dim, ( + f"Number of class labels ({len(CLASS_LABELS)}) does not match " + f"the model output dimension ({self._model.out_dim})." + ) + num_of_cols = len(CLASS_LABELS) + elif CLASS_LABELS is not None: + assert len(CLASS_LABELS) > 0, "Class labels list is empty." + num_of_cols = len(CLASS_LABELS) + elif self._model.out_dim is not None: + num_of_cols = self._model.out_dim + else: + # find first non-None tensor to determine width + num_of_cols = next(x.numel() for x in preds if x is not None) + CLASS_LABELS = [f"class_{i}" for i in range(num_of_cols)] + + rows = [ + pred.tolist() if pred is not None else [None] * num_of_cols + for pred in preds + ] + predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings) + + predictions_df.to_csv(save_to) + print(f"Predictions saved to: {save_to}") + + @torch.inference_mode() + def predict_smiles( + self, + smiles: List[str], + ) -> list[torch.Tensor | None]: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. 
+ """ + # For certain data prediction piplines, we may need model hyperparameters + pred_dl, valid_indices = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams + ) + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + # For certain model prediction pipelines, we may need data module hyperparameters + preds.append( + self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) + ) + preds = torch.cat(preds) + + # Initialize output with None + output: list[torch.Tensor | None] = [None] * len(smiles) + + # Scatter predictions back + for pred, idx in zip(preds, valid_indices): + output[idx] = pred + + return output + + +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + batch_size: Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + classes_path, + ) + + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> torch.Tensor: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index e93cff85..b9e4b0f3 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,18 +1,13 @@ import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd -import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader log = logging.getLogger(__name__) @@ -74,68 +69,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - def predict_from_file( + def predict( self, - model: LightningModule, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
- """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) - - def _predict_smiles( - self, model: LightningModule, smiles: List[str] - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], - batch_first=True, - ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) - * CLS_TOKEN + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead." ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - if model.model_type == "regression": - preds = model_output["logits"] - else: - preds = torch.sigmoid(model_output["logits"]) - - return preds @property def log_dir(self) -> Optional[str]: @@ -159,7 +104,6 @@ def log_dir(self) -> Optional[str]: class LoadDataLaterFitLoop(_FitLoop): - def on_advance_start(self) -> None: """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary so that the dataloaders can get information from the model. 
         For example: The on_train_epoch_start
diff --git a/pyproject.toml b/pyproject.toml
index 56bcd64d..aa538b6c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "torch",
     "transformers",
     "pysmiles==1.1.2",
-    "rdkit",
+    "rdkit==2024.3.6",
    "lightning==2.5.1",
 ]
 
diff --git a/tests/unit/cli/classification_labels.txt b/tests/unit/cli/classification_labels.txt
new file mode 100644
index 00000000..f50323b4
--- /dev/null
+++ b/tests/unit/cli/classification_labels.txt
@@ -0,0 +1,10 @@
+label_1
+label_2
+label_3
+label_4
+label_5
+label_6
+label_7
+label_8
+label_9
+label_10
\ No newline at end of file
diff --git a/tests/unit/cli/mock_dm.py b/tests/unit/cli/mock_dm.py
index 25116e21..e3fd60a7 100644
--- a/tests/unit/cli/mock_dm.py
+++ b/tests/unit/cli/mock_dm.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 from lightning.pytorch.core.datamodule import LightningDataModule
 from torch.utils.data import DataLoader
@@ -29,6 +31,10 @@ def num_of_labels(self):
     def feature_vector_size(self):
         return self._feature_vector_size
 
+    @property
+    def classes_txt_file_path(self) -> str:
+        return os.path.join("tests", "unit", "cli", "classification_labels.txt")
+
     def train_dataloader(self):
         assert self.feature_vector_size is not None, "feature_vector_size must be set"
         # Dummy dataset for example purposes
diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py
index 863a6df3..584da5e7 100644
--- a/tests/unit/cli/testCLI.py
+++ b/tests/unit/cli/testCLI.py
@@ -9,7 +9,7 @@ def setUp(self):
             "fit",
             "--trainer=configs/training/default_trainer.yml",
             "--model=configs/model/ffn.yml",
-            "--model.init_args.hidden_layers=[10]",
+            "--model.init_args.hidden_layers=[1]",
             "--model.train_metrics=configs/metrics/micro-macro-f1.yml",
             "--data=tests/unit/cli/mock_dm_config.yml",
             "--model.pass_loss_kwargs=false",
diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py
index ec018f00..9d322f27 100644
--- a/tests/unit/readers/testChemDataReader.py
+++ b/tests/unit/readers/testChemDataReader.py
@@ -42,19 +42,22 @@ def test_read_data(self) -> None:
         """
         Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
         """
-        raw_data = "CC(=O)NC1[Mg-2]"
+        raw_data = "CC(=O)NC1CC1[Mg-2]"
         # Expected output as per the tokens already in the cache, and ")" getting added to it.
         expected_output: List[int] = [
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 0,  # C
-            EMBEDDING_OFFSET + 5,  # =
-            EMBEDDING_OFFSET + 3,  # O
-            EMBEDDING_OFFSET + 1,  # N
-            EMBEDDING_OFFSET + len(self.reader.cache),  # (
-            EMBEDDING_OFFSET + 2,  # C
+            EMBEDDING_OFFSET + 5,  # (
+            EMBEDDING_OFFSET + 3,  # =
+            EMBEDDING_OFFSET + 1,  # O
+            EMBEDDING_OFFSET + len(self.reader.cache),  # ) - new token
+            EMBEDDING_OFFSET + 2,  # N
             EMBEDDING_OFFSET + 0,  # C
             EMBEDDING_OFFSET + 4,  # 1
-            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2]
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 0,  # C
+            EMBEDDING_OFFSET + 4,  # 1
+            EMBEDDING_OFFSET + len(self.reader.cache) + 1,  # [Mg-2] - new token
         ]
         result = self.reader._read_data(raw_data)
         self.assertEqual(
@@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None:
         Test the _read_data method with an invalid input.
         The invalid token should prompt a return value None
         """
-        raw_data = "%INVALID%"
-
-        result = self.reader._read_data(raw_data)
-        self.assertIsNone(
-            result,
-            "The output for invalid token '%INVALID%' should be None.",
-        )
+        # see https://github.com/ChEB-AI/python-chebai/issues/137
+        raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
+        for raw_data in raw_datas:
+            result = self.reader._read_data(raw_data)
+            self.assertIsNone(
+                result,
+                f"The output for invalid token '{raw_data}' should be None.",
+            )
 
     @patch("builtins.open", new_callable=mock_open)
     def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
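
For reference, the `classification_labels` entry written by `ChebaiBaseNet.on_save_checkpoint` above can be inspected directly from a saved checkpoint. A minimal sketch, assuming a checkpoint created after this change (the checkpoint path is a placeholder):

```python
import torch

# "model.ckpt" is a placeholder path to a checkpoint produced after this change.
ckpt = torch.load("model.ckpt", map_location="cpu", weights_only=False)

# One label (e.g., a ChEBI ID) per model output column, as written by on_save_checkpoint.
labels = ckpt.get("classification_labels")
print(labels[:5] if labels else "no classification labels stored in this checkpoint")
```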