From a3284b6f2753747dd6fbaf4f5ac2a20fdad70154 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Wed, 26 Nov 2025 23:40:47 +0100
Subject: [PATCH 1/4] predict pipeline

---
 chebai_graph/preprocessing/datasets/chebi.py | 36 +++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index 4ae441a..722df7c 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -77,7 +77,7 @@ def __init__(
             properties = self._sort_properties(properties)
         else:
             properties = []
-        self.properties = properties
+        self.properties: list[MolecularProperty] = properties
         assert isinstance(self.properties, list) and all(
             isinstance(p, MolecularProperty) for p in self.properties
         )
@@ -361,6 +361,40 @@ def load_processed_data(
 
         return base_df[base_data[0].keys()].to_dict("records")
 
+    def _process_input_for_prediction(self, smiles_list: list[str]) -> list:
+        data = [
+            self.reader.to_data(
+                {"id": f"smiles_{idx}", "features": smiles, "labels": None}
+            )
+            for idx, smiles in enumerate(smiles_list)
+        ]
+        # Each element of data is a dict with 'id' and 'features' (GeomData).
+        # GeomData has only edge_index filled; node and edge features are empty.
+
+        assert len(data) == len(smiles_list), "Data length mismatch."
+        data_df = pd.DataFrame(data)
+
+        for idx, data_row in data_df.itertuples(index=True):
+            property_data = data_row
+            for property in self.properties:
+                property.encoder.eval = True
+                property_value = self.reader.read_property(smiles_list[idx], property)
+                if property_value is None or len(property_value) == 0:
+                    encoded_value = None
+                else:
+                    encoded_value = torch.stack(
+                        [property.encoder.encode(v) for v in property_value]
+                    )
+                    if len(encoded_value.shape) == 3:
+                        encoded_value = encoded_value.squeeze(0)
+                property_data[property.name] = encoded_value
+
+        property_data["features"] = property_data.apply(
+            lambda row: self._merge_props_into_base(row), axis=1
+        )
+
+        return data_df.to_dict("records")
+
 
 class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
     def __init__(self, properties=None, transform=None, **kwargs):
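
Note on PATCH 1/4: this adds a prediction entry point that takes raw SMILES
strings instead of the usual preprocessed dataset files. A minimal usage sketch
(hypothetical: `dm` stands for an instance of the data-module class that gains
this method; the SMILES strings are arbitrary examples):

    records = dm._process_input_for_prediction(["CCO", "c1ccccc1O"])
    for rec in records:
        print(rec["id"])        # "smiles_0", "smiles_1", ...
        print(rec["features"])  # GeomData built from the SMILES by the reader

Caveat in this first version: data_df.itertuples(index=True) yields one
namedtuple per row, so the `for idx, data_row in ...` unpacking only works for
a single-column frame, and the trailing property_data.apply(...) call targets
the loop variable rather than the DataFrame; PATCH 2/4 reworks this loop.
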
+ """ + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + # element of data is a dict with 'id' and 'features' (GeomData) + # GeomData has only edge_index filled but node and edges features are empty. + + assert len(data) == len(smiles_list), "Data length mismatch." + data_df = pd.DataFrame(data) + + props: list[dict] = [] + for data_row in data_df.itertuples(index=True): + row_prop_dict: dict = {} + for property in self.properties: + property.encoder.eval = True + property_value = self.reader.read_property( + smiles_list[data_row.Index], property + ) + if property_value is None or len(property_value) == 0: + encoded_value = None + else: + encoded_value = torch.stack( + [property.encoder.encode(v) for v in property_value] + ) + if len(encoded_value.shape) == 3: + encoded_value = encoded_value.squeeze(0) + row_prop_dict[property.name] = encoded_value + row_prop_dict["ident"] = data_row.ident + props.append(row_prop_dict) + + property_df = pd.DataFrame(props) + data_df = data_df.merge(property_df, on="ident", how="left") + return data_df + class GraphPropertiesMixIn(DataPropertiesSetter, ABC): def __init__( @@ -361,40 +417,6 @@ def load_processed_data( return base_df[base_data[0].keys()].to_dict("records") - def _process_input_for_prediction(self, smiles_list: list[str]) -> list: - data = [ - self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": None} - ) - for idx, smiles in enumerate(smiles_list) - ] - # element of data is a dict with 'id' and 'features' (GeomData) - # GeomData has only edge_index filled but node and edges features are empty. - - assert len(data) == len(smiles_list), "Data length mismatch." - data_df = pd.DataFrame(data) - - for idx, data_row in data_df.itertuples(index=True): - property_data = data_row - for property in self.properties: - property.encoder.eval = True - property_value = self.reader.read_property(smiles_list[idx], property) - if property_value is None or len(property_value) == 0: - encoded_value = None - else: - encoded_value = torch.stack( - [property.encoder.encode(v) for v in property_value] - ) - if len(encoded_value.shape) == 3: - encoded_value = encoded_value.squeeze(0) - property_data[property.name] = encoded_value - - property_data["features"] = property_data.apply( - lambda row: self._merge_props_into_base(row), axis=1 - ) - - return data_df.to_dict("records") - class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): @@ -605,6 +627,36 @@ def _merge_props_into_base( is_graph_node=is_graph_node, ) + def _process_input_for_prediction( + self, + smiles_list: list[str], + model_hparams: Optional[dict] = None, + ) -> list: + if ( + model_hparams is None + or "in_channels" not in model_hparams["config"] + or model_hparams["config"]["in_channels"] is None + ): + raise ValueError( + f"model_hparams must be provided for data class: {self.__class__.__name__}" + f" which should contain 'in_channels' key with valid value in 'config' dictionary." + ) + + max_len_node_properties = int(model_hparams["config"]["in_channels"]) + # Determine max_len_node_properties based on in_channels + + data_df = self._process_smiles_and_props(smiles_list) + data_df["features"] = data_df.apply( + lambda row: self._merge_props_into_base(row, max_len_node_properties), + axis=1, + ) + + # apply transformation, e.g. 
From a11ba5067b6aec9f94e48f74db42f19f0482522f Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Thu, 27 Nov 2025 23:29:09 +0100
Subject: [PATCH 3/4] fix irrelevant ident from reader error

---
 chebai_graph/preprocessing/datasets/chebi.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index 055cca9..8578160 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -206,7 +206,7 @@ def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
         """
         data = [
             self.reader.to_data(
-                {"id": f"smiles_{idx}", "features": smiles, "labels": None}
+                {"ident": f"smiles_{idx}", "features": smiles, "labels": None}
             )
             for idx, smiles in enumerate(smiles_list)
         ]
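
Note on PATCH 3/4: a one-line key rename; the dict handed to reader.to_data now
uses "ident" instead of "id", presumably matching the column name the reader
emits and that the "ident"-based merge from PATCH 2/4 relies on. A toy
illustration of the failure mode when the column is missing (pure pandas,
independent of the actual reader):

    import pandas as pd

    df = pd.DataFrame([{"id": "smiles_0", "features": "CCO"}])
    row = next(df.itertuples(index=True))
    row.ident  # AttributeError: 'Pandas' object has no attribute 'ident'
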
From 8e9d2255a2d08c14a542e9da0d82d4afa4850858 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Tue, 16 Dec 2025 11:55:52 +0100
Subject: [PATCH 4/4] adapt code for new logic to handle None returns

---
 chebai_graph/preprocessing/datasets/chebi.py | 121 +++++++++----------
 1 file changed, 54 insertions(+), 67 deletions(-)

diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py
index 8578160..fdf2e6d 100644
--- a/chebai_graph/preprocessing/datasets/chebi.py
+++ b/chebai_graph/preprocessing/datasets/chebi.py
@@ -184,61 +184,53 @@ def _after_setup(self, **kwargs) -> None:
         self._setup_properties()
         super()._after_setup(**kwargs)
 
-    def _process_input_for_prediction(
-        self,
-        smiles_list: list[str],
-        model_hparams: Optional[dict] = None,
-    ) -> list:
-        data_df = self._process_smiles_and_props(smiles_list)
-        data_df["features"] = data_df.apply(
-            lambda row: self._merge_props_into_base(row), axis=1
+    def _preprocess_smiles_for_pred(
+        self, idx, smiles: str, model_hparams: Optional[dict] = None
+    ) -> Optional[dict]:
+        """Preprocess prediction data."""
+        # Add dummy labels because the collate function requires them.
+        # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
+        # which later causes the `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
+        result = self.reader.to_data(
+            {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+        )
+        if result is None or result["features"] is None:
+            return None
+        for property in self.properties:
+            property.encoder.eval = True
+            property_value = self.reader.read_property(smiles, property)
+            if property_value is None or len(property_value) == 0:
+                encoded_value = None
+            else:
+                encoded_value = torch.stack(
+                    [property.encoder.encode(v) for v in property_value]
+                )
+                if len(encoded_value.shape) == 3:
+                    encoded_value = encoded_value.squeeze(0)
+            result[property.name] = encoded_value
+
+        result["features"] = self._prediction_merge_props_into_base_wrapper(
+            result, model_hparams
         )
 
         # apply transformation, e.g. masking for pretraining task
         if self.transform is not None:
-            data_df["features"] = data_df["features"].apply(self.transform)
+            result["features"] = self.transform(result["features"])
 
-        return data_df.to_dict("records")
+        return result
 
-    def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
         """
-        Process SMILES strings and compute molecular properties.
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
         """
-        data = [
-            self.reader.to_data(
-                {"ident": f"smiles_{idx}", "features": smiles, "labels": None}
-            )
-            for idx, smiles in enumerate(smiles_list)
-        ]
-        # Each element of data is a dict with 'id' and 'features' (GeomData).
-        # GeomData has only edge_index filled; node and edge features are empty.
-
-        assert len(data) == len(smiles_list), "Data length mismatch."
-        data_df = pd.DataFrame(data)
-
-        props: list[dict] = []
-        for data_row in data_df.itertuples(index=True):
-            row_prop_dict: dict = {}
-            for property in self.properties:
-                property.encoder.eval = True
-                property_value = self.reader.read_property(
-                    smiles_list[data_row.Index], property
-                )
-                if property_value is None or len(property_value) == 0:
-                    encoded_value = None
-                else:
-                    encoded_value = torch.stack(
-                        [property.encoder.encode(v) for v in property_value]
-                    )
-                    if len(encoded_value.shape) == 3:
-                        encoded_value = encoded_value.squeeze(0)
-                row_prop_dict[property.name] = encoded_value
-            row_prop_dict["ident"] = data_row.ident
-            props.append(row_prop_dict)
-
-        property_df = pd.DataFrame(props)
-        data_df = data_df.merge(property_df, on="ident", how="left")
-        return data_df
+        return self._merge_props_into_base(row)
 
 
 class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
     def __init__(
@@ -276,7 +268,7 @@ def __init__(
             f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
         )
 
-    def _merge_props_into_base(self, row: pd.Series) -> GeomData:
+    def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
         """
         Merge encoded molecular properties into the GeomData object.
 
@@ -544,6 +536,8 @@ def _merge_props_into_base(
             A GeomData object with merged features.
         """
         geom_data = row["features"]
+        if geom_data is None:
+            return None
         assert isinstance(geom_data, GeomData)
 
         is_atom_node = geom_data.is_atom_node
@@ -627,35 +621,29 @@ def _merge_props_into_base(
             is_graph_node=is_graph_node,
         )
 
-    def _process_input_for_prediction(
-        self,
-        smiles_list: list[str],
-        model_hparams: Optional[dict] = None,
-    ) -> list:
+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
+        """
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
+        """
         if (
             model_hparams is None
            or "in_channels" not in model_hparams["config"]
             or model_hparams["config"]["in_channels"] is None
         ):
             raise ValueError(
                 f"model_hparams must be provided for data class: {self.__class__.__name__}"
                 f" which should contain 'in_channels' key with valid value in 'config' dictionary."
             )
-        max_len_node_properties = int(model_hparams["config"]["in_channels"])
-        # Determine max_len_node_properties based on in_channels
-
-        data_df = self._process_smiles_and_props(smiles_list)
-        data_df["features"] = data_df.apply(
-            lambda row: self._merge_props_into_base(row, max_len_node_properties),
-            axis=1,
-        )
-
-        # apply transformation, e.g. masking for pretraining task
-        if self.transform is not None:
-            data_df["features"] = data_df["features"].apply(self.transform)
-
-        return data_df.to_dict("records")
+        max_len_node_properties = int(model_hparams["config"]["in_channels"])
+        return self._merge_props_into_base(row, max_len_node_properties)
 
 
 class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
     READER = RandomFeatureInitializationReader
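
Note on PATCH 4/4: the predict path becomes per-SMILES. _preprocess_smiles_for_pred
returns None for inputs the reader cannot turn into features, so callers can drop
them before collation, and the dummy labels [1, 2] keep the collator from routing
such rows through its empty-label handling (see the comment in the patch). A
hedged end-to-end sketch; the driver loop and `dm` are hypothetical, and the
hparams layout only mirrors the 'config'/'in_channels' check introduced above:

    # Assumption: `dm` is an instance of the data module whose wrapper
    # validates model_hparams["config"]["in_channels"].
    hparams = {"config": {"in_channels": 256}}
    records = []
    for idx, smi in enumerate(["CCO", "not a smiles"]):
        rec = dm._preprocess_smiles_for_pred(idx, smi, model_hparams=hparams)
        if rec is None:  # unparsable SMILES or empty features: skip the row
            continue
        records.append(rec)
    # `records` is then collated and fed to the model as usual.
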