Draft
33 commits
465b651
use instantiate model to load data and model from ckpt
aditya0by0 Nov 16, 2025
2acd166
update readme
aditya0by0 Nov 16, 2025
cfbf392
set no grad for predict
aditya0by0 Nov 17, 2025
82b365c
predict pipeline in dm and lm
aditya0by0 Nov 26, 2025
fa6f1b5
there is no need that predict func must depend on trainer
aditya0by0 Nov 26, 2025
ae47608
model hparams for data predict pipeline and vice versa
aditya0by0 Nov 27, 2025
517a5a2
fix reader ident error
aditya0by0 Nov 27, 2025
40491e5
fix label None error
aditya0by0 Nov 27, 2025
7b7e48f
fix cli predict_from_file error
aditya0by0 Nov 27, 2025
6a38317
update readme for new prediction method
aditya0by0 Nov 27, 2025
31b12db
Revert "fix reader ident error"
aditya0by0 Nov 27, 2025
c6e8b61
modify pred logic to store model and dm as instance var
aditya0by0 Nov 28, 2025
63670dd
fix for unwanted args to predict_smiles
aditya0by0 Dec 6, 2025
d906ad4
avoid non_null_labels key in loss kwargs
aditya0by0 Dec 6, 2025
ba5884a
compile model
aditya0by0 Dec 6, 2025
5a17f72
revert the comment line for splits file path
aditya0by0 Dec 6, 2025
bd59d59
handle augment electra and old ckpt files
aditya0by0 Dec 12, 2025
6b8ae09
remove unnec device
aditya0by0 Dec 13, 2025
c029395
raise error for invalid smiles and return None
aditya0by0 Dec 15, 2025
6761079
rectify test as to return None for invalid strings
aditya0by0 Dec 15, 2025
bbb8f10
Merge branch 'dev' into fix/read_data
aditya0by0 Dec 15, 2025
d5e362b
pin rdkit version - https://github.com/ChEB-AI/python-chebai/issues/83
aditya0by0 Dec 16, 2025
5d302d0
handle None returns for invalid smiles
aditya0by0 Dec 16, 2025
b75209e
Merge branch 'dev' into fix/generalize_predict_func
aditya0by0 Dec 18, 2025
ead78b6
Merge branch 'fix/read_data' into fix/generalize_predict_func
aditya0by0 Dec 18, 2025
e86f03a
remove instanstior key from hparams as its causing unnecessary error
aditya0by0 Dec 18, 2025
994de55
return none for token limit too
aditya0by0 Dec 18, 2025
3dbf9ae
assert i/p and o/p dim greater than 0
aditya0by0 Jan 8, 2026
811fbf1
for None pred, fill corresponding row with None values across the cols
aditya0by0 Jan 8, 2026
0ea3490
save classification labels to checkpoints
aditya0by0 Jan 8, 2026
a1cdaca
update for prediction logic how ckpts with class labels
aditya0by0 Jan 8, 2026
330a8c2
pre-commit format
aditya0by0 Jan 8, 2026
704af93
update test for cli change
aditya0by0 Jan 8, 2026
16 changes: 12 additions & 4 deletions README.md
@@ -74,11 +74,19 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con

### Predicting classes given SMILES strings
```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input file should contain a list of line-separated SMILES strings. This generates a CSV file that contains
one row for each SMILES string and one column for each class (an example is shown below).
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.

* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).

* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.

* **`--save_to`** *(optional)*: Predictions are saved to this path as a CSV file, with one row per SMILES string and one column per predicted class. By default, the file is saved as `predictions.csv` in the current working directory.

* **`--classes_path`** *(optional)*: Path to the dataset’s `classes.txt` file, which maps model output indices to ChEBI IDs.
* Checkpoints created after PR #135 store the classification labels, so this parameter is not required.
* If provided, the CSV columns will be named using the ChEBI IDs.
* If omitted, the script will try to locate the file automatically. If it cannot be located, the columns will be numbered sequentially.
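
For example, with a checkpoint at `path/to/model.ckpt` and an input file `smiles.txt` containing one SMILES string per line (all paths here are illustrative):
```
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=path/to/model.ckpt --smiles_file_path=smiles.txt --save_to=predictions.csv
```
The resulting `predictions.csv` contains one row per input SMILES string; rows for SMILES strings that could not be processed are filled with `None` values.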

## Evaluation

7 changes: 6 additions & 1 deletion chebai/cli.py
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
apply_on="instantiate",
)

parser.link_arguments(
"data.classes_txt_file_path",
"model.init_args.classes_txt_file_path",
apply_on="instantiate",
)

for kind in ("train", "val", "test"):
for average in (
"micro-f1",
@@ -112,7 +118,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"predict_from_file": {"model"},
}


30 changes: 25 additions & 5 deletions chebai/models/base.py
@@ -40,15 +40,16 @@ def __init__(
pass_loss_kwargs: bool = True,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
classes_txt_file_path: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
# super().__init__()
if exclude_hyperparameter_logging is None:
exclude_hyperparameter_logging = tuple()
self.criterion = criterion
assert out_dim is not None, "out_dim must be specified"
assert input_dim is not None, "input_dim must be specified"
assert out_dim is not None and out_dim > 0, "out_dim must be specified and greater than 0"
assert input_dim is not None and input_dim > 0, "input_dim must be specified and greater than 0"
self.out_dim = out_dim
self.input_dim = input_dim
print(
@@ -77,6 +78,17 @@ def __init__(
self.validation_metrics = val_metrics
self.test_metrics = test_metrics
self.pass_loss_kwargs = pass_loss_kwargs
with open(classes_txt_file_path, "r") as f:
self.labels_list = [cls.strip() for cls in f.readlines()]
assert len(self.labels_list) > 0, "Class labels list is empty."
assert len(self.labels_list) == out_dim, (
f"Number of class labels ({len(self.labels_list)}) does not match "
f"the model output dimension ({out_dim})."
)

def on_save_checkpoint(self, checkpoint):
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
checkpoint["classification_labels"] = self.labels_list

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +112,7 @@ def __init_subclass__(cls, **kwargs):

def _get_prediction_and_labels(
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
) -> (torch.Tensor, torch.Tensor):
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gets the predictions and labels from the model output.

@@ -151,7 +163,7 @@ def _process_for_loss(
model_output: torch.Tensor,
labels: torch.Tensor,
loss_kwargs: Dict[str, Any],
) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Processes the data for loss computation.

@@ -237,7 +249,15 @@ def predict_step(
Returns:
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
"""
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
assert isinstance(batch, XYData)
batch = batch.to(self.device)
data = self._process_batch(batch, batch_idx)
model_output = self(data, **data.get("model_kwargs", dict()))

# Dummy labels to avoid errors in _get_prediction_and_labels
labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
return pr
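
# Example (illustrative sketch, assuming `model` is a loaded checkpoint and
# `batch` is a collated XYData batch): since predict_step no longer depends on
# a Trainer, it can be called directly, e.g.
#     with torch.no_grad():
#         preds = model.predict_step(batch, 0)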

def _execute(
self,
100 changes: 85 additions & 15 deletions chebai/preprocessing/datasets/base.py
@@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
for d in tqdm.tqdm(self._load_dict(path), total=lines)
if d["features"] is not None
]
# filter for missing features in resulting data, keep features length below token limit
data = [
val
for val in data
if val["features"] is not None
and (
self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
)
]

data = [val for val in data if self._filter_to_token_limit(val)]
return data

def _filter_to_token_limit(self, data_instance: dict) -> bool:
# filter for missing features in resulting data, keep features length below token limit
if data_instance["features"] is not None and (
self.n_token_limit is None
or len(data_instance["features"]) <= self.n_token_limit
):
return True
return False

def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
"""
Returns the train DataLoader.
@@ -401,22 +402,77 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
"""

return self.dataloader("test", shuffle=False, **kwargs)

def predict_dataloader(
self, *args, **kwargs
) -> Union[DataLoader, List[DataLoader]]:
self,
smiles_list: List[str],
model_hparams: Optional[dict] = None,
**kwargs,
) -> tuple[DataLoader, list[int]]:
"""
Returns the predict DataLoader.

Args:
*args: Additional positional arguments (unused).
smiles_list (List[str]): List of SMILES strings to predict.
model_hparams (Optional[dict]): Model hyperparameters.
Some prediction pre-processing pipelines may require these.
**kwargs: Additional keyword arguments, passed to dataloader().

Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices.
"""
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)

data, valid_indices = self._process_input_for_prediction(
smiles_list, model_hparams
)
return (
DataLoader(
data,
collate_fn=self.reader.collator,
batch_size=self.batch_size,
**kwargs,
),
valid_indices,
)
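
# Example (illustrative sketch, assuming `dm` is an instantiated datamodule):
#     loader, valid_indices = dm.predict_dataloader(["CC(=O)O", "not-a-smiles"])
# `valid_indices` maps DataLoader entries back to positions in the input list;
# SMILES that fail preprocessing or exceed the token limit are skipped.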

def _process_input_for_prediction(
self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
"""
Process input data for prediction.

Args:
smiles_list (List[str]): List of SMILES strings.
model_hparams (Optional[dict]): Model hyperparameters.
Some prediction pre-processing pipelines may require these.

Returns:
tuple[list, list]: Processed input data and valid indices.
"""
data, valid_indices = [], []
for idx, smiles in enumerate(smiles_list):
result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
if result is None or result["features"] is None:
continue
if not self._filter_to_token_limit(result):
continue
data.append(result)
valid_indices.append(idx)

return data, valid_indices

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes the `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
return self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
)

def prepare_data(self, *args, **kwargs) -> None:
if self._prepare_data_flag != 1:
@@ -563,6 +619,19 @@ def raw_file_names_dict(self) -> dict:
"""
raise NotImplementedError

@property
def classes_txt_file_path(self) -> str:
"""
Returns the path to the dataset's classes text file.

Returns:
str: Path to the `classes.txt` file in the processed data directory.
"""
# This property is also used in the following places:
# - results/prediction.py: to load class names for CSV column names
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
return os.path.join(self.processed_dir_main, "classes.txt")


class MergedDataset(XYBaseDataModule):
MERGED = []
@@ -1191,7 +1260,8 @@ def _retrieve_splits_from_csv(self) -> None:
print(f"Applying label filter from {self.apply_label_filter}...")
with open(self.apply_label_filter, "r") as f:
label_filter = [line.strip() for line in f]
with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:

with open(self.classes_txt_file_path, "r") as cf:
classes = [line.strip() for line in cf]
# reorder labels
old_labels = np.stack(df_data["labels"])
5 changes: 4 additions & 1 deletion chebai/preprocessing/reader.py
@@ -208,10 +208,13 @@ def _read_data(self, raw_data: str) -> List[int]:
print(f"RDKit failed to process {raw_data}")
print(f"\t{e}")
try:
mol = Chem.MolFromSmiles(raw_data.strip())
if mol is None:
raise ValueError(f"Invalid SMILES: {raw_data}")
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
except ValueError as e:
print(f"could not process {raw_data}")
print(f"\t{e}")
print(f"\tError: {e}")
return None

def _back_to_smiles(self, smiles_encoded):