Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions chebai_graph/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
properties = self._sort_properties(properties)
else:
properties = []
self.properties = properties
self.properties: list[MolecularProperty] = properties
assert isinstance(self.properties, list) and all(
isinstance(p, MolecularProperty) for p in self.properties
)
Expand Down Expand Up @@ -184,6 +184,54 @@ def _after_setup(self, **kwargs) -> None:
self._setup_properties()
super()._after_setup(**kwargs)

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
result = self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
)
if result is None or result["features"] is None:
return None
for property in self.properties:
property.encoder.eval = True
property_value = self.reader.read_property(smiles, property)
if property_value is None or len(property_value) == 0:
encoded_value = None
else:
encoded_value = torch.stack(
[property.encoder.encode(v) for v in property_value]
)
if len(encoded_value.shape) == 3:
encoded_value = encoded_value.squeeze(0)
result[property.name] = encoded_value

result["features"] = self._prediction_merge_props_into_base_wrapper(
result, model_hparams
)

# apply transformation, e.g. masking for pretraining task
if self.transform is not None:
result["features"] = self.transform(result["features"])

return result

def _prediction_merge_props_into_base_wrapper(
self, row: pd.Series | dict, model_hparams: Optional[dict] = None
) -> GeomData:
"""
Wrapper to merge properties into base features for prediction.

Args:
row: A dictionary or pd.Series containing 'features' and encoded properties.
Returns:
A GeomData object with merged features.
"""
return self._merge_props_into_base(row)


class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
def __init__(
Expand Down Expand Up @@ -220,7 +268,7 @@ def __init__(
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
)

def _merge_props_into_base(self, row: pd.Series) -> GeomData:
def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
"""
Merge encoded molecular properties into the GeomData object.

Expand Down Expand Up @@ -488,6 +536,8 @@ def _merge_props_into_base(
A GeomData object with merged features.
"""
geom_data = row["features"]
if geom_data is None:
return None
assert isinstance(geom_data, GeomData)

is_atom_node = geom_data.is_atom_node
Expand Down Expand Up @@ -571,6 +621,29 @@ def _merge_props_into_base(
is_graph_node=is_graph_node,
)

def _prediction_merge_props_into_base_wrapper(
self, row: pd.Series | dict, model_hparams: Optional[dict] = None
) -> GeomData:
"""
Wrapper to merge properties into base features for prediction.

Args:
row: A dictionary or pd.Series containing 'features' and encoded properties.
Returns:
A GeomData object with merged features.
"""
if (
model_hparams is None
or "in_channels" not in model_hparams["config"]
or model_hparams["config"]["in_channels"] is None
):
raise ValueError(
f"model_hparams must be provided for data class: {self.__class__.__name__}"
f" which should contain 'in_channels' key with valid value in 'config' dictionary."
)
max_len_node_properties = int(model_hparams["config"]["in_channels"])
return self._merge_props_into_base(row, max_len_node_properties)


class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
READER = RandomFeatureInitializationReader
Expand Down