diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index 4ae441a..fdf2e6d 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -77,7 +77,7 @@ def __init__(
             properties = self._sort_properties(properties)
         else:
             properties = []
-        self.properties = properties
+        self.properties: list[MolecularProperty] = properties
         assert isinstance(self.properties, list) and all(
             isinstance(p, MolecularProperty) for p in self.properties
         )
@@ -184,6 +184,54 @@ def _after_setup(self, **kwargs) -> None:
         self._setup_properties()
         super()._after_setup(**kwargs)
 
+    def _preprocess_smiles_for_pred(
+        self, idx, smiles: str, model_hparams: Optional[dict] = None
+    ) -> Optional[dict]:
+        """Preprocess prediction data. Returns None if the SMILES yields no features."""
+        # Add dummy labels because the collate function requires them.
+        # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
+        # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
+        result = self.reader.to_data(
+            {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+        )
+        if result is None or result["features"] is None:
+            return None
+        for property in self.properties:
+            property.encoder.eval = True
+            property_value = self.reader.read_property(smiles, property)
+            if property_value is None or len(property_value) == 0:
+                encoded_value = None
+            else:
+                encoded_value = torch.stack(
+                    [property.encoder.encode(v) for v in property_value]
+                )
+                if len(encoded_value.shape) == 3:
+                    encoded_value = encoded_value.squeeze(0)
+            result[property.name] = encoded_value
+
+        result["features"] = self._prediction_merge_props_into_base_wrapper(
+            result, model_hparams
+        )
+
+        # apply transformation, e.g. masking for pretraining task
+        if self.transform is not None:
+            result["features"] = self.transform(result["features"])
+
+        return result
+
+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
+        """
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
+        """
+        return self._merge_props_into_base(row)
+
 
 class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
     def __init__(
@@ -220,7 +268,7 @@ def __init__(
             f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
         )
 
-    def _merge_props_into_base(self, row: pd.Series) -> GeomData:
+    def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
         """
         Merge encoded molecular properties into the GeomData object.
 
@@ -488,6 +536,8 @@ def _merge_props_into_base(
             A GeomData object with merged features.
         """
         geom_data = row["features"]
+        if geom_data is None:
+            return None
         assert isinstance(geom_data, GeomData)
 
         is_atom_node = geom_data.is_atom_node
@@ -571,6 +621,29 @@ def _merge_props_into_base(
             is_graph_node=is_graph_node,
         )
 
+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
+        """
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
+        """
+        if (
+            model_hparams is None
+            or "in_channels" not in model_hparams["config"]
+            or model_hparams["config"]["in_channels"] is None
+        ):
+            raise ValueError(
+                f"model_hparams must be provided for data class: {self.__class__.__name__}"
+                f" which should contain 'in_channels' key with valid value in 'config' dictionary."
+            )
+        max_len_node_properties = int(model_hparams["config"]["in_channels"])
+        return self._merge_props_into_base(row, max_len_node_properties)
+
 
 class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
     READER = RandomFeatureInitializationReader