From 8053b474ccf5849f60729421f25021fcf56f30f5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Nov 2025 13:39:37 +0100 Subject: [PATCH 1/8] add doc to static gni class --- .../preprocessing/reader/static_gni.py | 176 +++++++++++++++--- 1 file changed, 148 insertions(+), 28 deletions(-) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 106c528..a9fac8c 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -1,19 +1,70 @@ """ -Abboud, Ralph, et al. -"The surprising power of graph neural networks with random node initialization." -arXiv preprint arXiv:2010.01179 (2020). +RandomFeatureInitializationReader +-------------------------------- -Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py +Implements random node / edge / molecule feature initialization for graph neural +networks following: + +Abboud, R., et al. (2020). "The surprising power of graph neural networks with +random node initialization." arXiv preprint arXiv:2010.01179. + +Code reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py + +This module provides a reader that replaces node/edge/molecule features with +randomly initialized tensors drawn from a selected distribution. + +Notes +----- +- This reader subclasses GraphPropertyReader and is intended to be used where a + graph object with attributes `x`, `edge_attr`, and optionally `molecule_attr` + is expected (e.g., `torch_geometric.data.Data`). +- The reader only performs random initialization and does not support reading + specific properties from the input data. """ +from typing import Any, Optional + import torch +from torch import Tensor from torch_geometric.data import Data as GeomData from .reader import GraphPropertyReader class RandomFeatureInitializationReader(GraphPropertyReader): - DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] + """ + Reader that initializes node, bond (edge), and molecule features with + random values according to a chosen distribution. + + Supported distributions: + - "normal" : standard normal (mean=0, std=1) + - "uniform" : uniform in [-1, 1] + - "xavier_normal" : Xavier normal initialization + - "xavier_uniform" : Xavier uniform initialization + - "zeros" : all zeros + + Parameters + ---------- + num_node_properties : int + Number of features to generate per node. + num_bond_properties : int + Number of features to generate per edge/bond. + num_molecule_properties : int + Number of global molecule-level features to generate. + distribution : str, optional + One of the supported distributions (default: "normal"). + *args, **kwargs : Any + Additional positional and keyword arguments passed to the parent + GraphPropertyReader. + """ + + DISTRIBUTIONS = [ + "normal", + "uniform", + "xavier_normal", + "xavier_uniform", + "zeros", + ] def __init__( self, @@ -21,54 +72,123 @@ def __init__( num_bond_properties: int, num_molecule_properties: int, distribution: str = "normal", - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) - self.num_node_properties = num_node_properties - self.num_bond_properties = num_bond_properties - self.num_molecule_properties = num_molecule_properties - assert distribution in self.DISTRIBUTIONS - self.distribution = distribution + if distribution not in self.DISTRIBUTIONS: + raise ValueError( + f"distribution must be one of {self.DISTRIBUTIONS}, got '{distribution}'" + ) + + self.num_node_properties: int = int(num_node_properties) + self.num_bond_properties: int = int(num_bond_properties) + self.num_molecule_properties: int = int(num_molecule_properties) + self.distribution: str = distribution def name(self) -> str: """ - Get the name identifier of the reader. + Return a human-readable identifier for this reader configuration. + + Returns + ------- + str + A name encoding the chosen distribution and generated feature sizes. + """ + return ( + f"gni-{self.distribution}" + f"-node{self.num_node_properties}" + f"-bond{self.num_bond_properties}" + f"-mol{self.num_molecule_properties}" + ) - Returns: - str: The name of the reader. + def _read_data(self, raw_data: Any) -> Optional[GeomData]: """ - return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}" + Read and return a `torch_geometric.data.Data` object with randomized + node/edge/molecule features. + + This method calls the parent's `_read_data` to obtain a graph object, + then replaces `x`, `edge_attr` and sets `molecule_attr` with new tensors. - def _read_data(self, raw_data): - data: GeomData = super()._read_data(raw_data) + Parameters + ---------- + raw_data : Any + Raw input that the parent reader understands. + + Returns + ------- + Optional[GeomData] + A `Data` object with randomized attributes or `None` if the parent + `_read_data` returned `None`. + """ + data: Optional[GeomData] = super()._read_data(raw_data) if data is None: return None - random_x = torch.empty(data.x.shape[0], self.num_node_properties) - random_edge_attr = torch.empty( - data.edge_attr.shape[0], self.num_bond_properties + # Ensure expected attributes exist (torch_geometric Data may vary). + num_nodes = int(data.x.shape[0]) if getattr(data, "x", None) is not None else 0 + num_edges = ( + int(data.edge_attr.shape[0]) + if getattr(data, "edge_attr", None) is not None + else 0 + ) + + # Create random tensors of the requested shapes. + random_x: Tensor = torch.empty(num_nodes, self.num_node_properties) + random_edge_attr: Tensor = torch.empty(num_edges, self.num_bond_properties) + random_molecule_properties: Tensor = torch.empty( + 1, self.num_molecule_properties ) - random_molecule_properties = torch.empty(1, self.num_molecule_properties) + # Initialize them according to the chosen distribution. self.random_gni(random_x, self.distribution) self.random_gni(random_edge_attr, self.distribution) self.random_gni(random_molecule_properties, self.distribution) + # Assign randomized attributes back to the data object. data.x = random_x data.edge_attr = random_edge_attr + # Use `molecule_attr` as the name in this codebase; if your Data object + # expects a different name (e.g., `u` or `global_attr`) adapt accordingly. data.molecule_attr = random_molecule_properties + return data - def read_property(self, *args, **kwargs) -> Exception: - """This reader does not support reading specific properties.""" - raise NotImplementedError("This reader only performs random initialization.") + def read_property(self, *args: Any, **kwargs: Any) -> None: + """ + This reader does not support reading specific properties from the input. + It only performs random initialization of features. + + Raises + ------ + NotImplementedError + Always raised to indicate unsupported operation. + """ + raise NotImplementedError( + "RandomFeatureInitializationReader only performs random initialization." + ) @staticmethod - def random_gni(tensor: torch.Tensor, distribution: str) -> None: + def random_gni(tensor: Tensor, distribution: str) -> None: + """ + Fill `tensor` in-place according to the requested initialization. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to initialize in-place. + distribution : str + One of the supported distribution identifiers. + + Raises + ------ + ValueError + If an unknown distribution string is provided. + """ if distribution == "normal": torch.nn.init.normal_(tensor) elif distribution == "uniform": + # Uniform in [-1, 1] torch.nn.init.uniform_(tensor, a=-1.0, b=1.0) elif distribution == "xavier_normal": torch.nn.init.xavier_normal_(tensor) @@ -77,4 +197,4 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None: elif distribution == "zeros": torch.nn.init.zeros_(tensor) else: - raise ValueError("Unknown distribution type") + raise ValueError(f"Unknown distribution type: '{distribution}'") From 592036259452c2a61321af57edac3740272ea347 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Nov 2025 13:42:45 +0100 Subject: [PATCH 2/8] add aug graph readme --- README.md | 116 ++++++++++++++++++ .../preprocessing/reader/static_gni.py | 17 +-- 2 files changed, 120 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index c28dee3..35b773a 100644 --- a/README.md +++ b/README.md @@ -75,3 +75,119 @@ The list can be found in the `configs/data/chebi50_graph_properties.yml` file. ```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml ``` + +## Augmented Graphs + +```bash +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0 +``` + +### Model Hyperparameters + +#### **GAT Architecture** + +To use a GAT-based model, choose **one** of the following configs: + +- **Atom–Motif–Graph Node Pooling** + ```bash + --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml + ``` + +- **Atom-Augmented Node Pooling** + ```bash + --model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml + ``` + +- **Standard Pooling** + ```bash + --model=../python-chebai-graph/configs/model/gat.yml + ``` + +#### GAT-specific hyperparameters + +- **Number of message-passing layers** + ```bash + --model.config.num_layers=5 # Default: 4 + ``` + +- **Attention heads** + ```bash + --model.config.heads=4 # Default: 8 + ``` + *Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified).* + +- **Use GATv2** + ```bash + --model.config.v2=True # Default: False + ``` + + +#### **ResGated Architecture** + +To use a ResGated GNN model, choose **one** of the following configs: + +- **Atom–Motif–Graph Node Pooling** + ```bash + --model=../python-chebai-graph/configs/model/res_aug_amgpool.yml + ``` + +- **Atom-Augmented Node Pooling** + ```bash + --model=../python-chebai-graph/configs/model/res_aug_aagpool.yml + ``` + +- **Standard Pooling** + ```bash + --model=../python-chebai-graph/configs/model/resgated.yml + ``` + + +#### **Common Hyperparameters** + +These can be used for both GAT and ResGated architectures: + +- **Dropout** + ```bash + --model.config.dropout=0.1 # Default: 0 + ``` + +- **Number of final linear layers** + ```bash + --model.n_linear_layers=2 # Default: 1 + ``` + +## Random Node Initialization + + + + + + +### Static Node Intialization + +In this type of node initialization, the node properties ( and/or edge properties) of the given molecular graph is initialized only once during dataset creation with given node initiliazation scheme. + + +In the below config, for each node we the 158 node properties we retrieve from RDKit along and add 54 features to node (specified by `--data.pad_node_features=45`) which is drawn from normal distribution (by default.) You can change the distribution from which additional features are drawn by using `--data.distribution=zeros` + +below are the available distributions: +["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] + +Similary, each edge is initializaed with 7 properties from RDKit and 4 additional features drawn from given distribution. + + +``` +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 +``` + +if you to use all the features for node (and edge) drawn from given distribution, use the data class +`--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml` . Refer the data class code. + +### Dynamic Node Initialization +In this type of node initialization, the node properties ( and/or edge properties) of the given molecular graph is initialized at each forward pass of the model with given node initiliazation scheme. + + +```bash + +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0 +``` diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index a9fac8c..da9f847 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -125,20 +125,11 @@ def _read_data(self, raw_data: Any) -> Optional[GeomData]: if data is None: return None - # Ensure expected attributes exist (torch_geometric Data may vary). - num_nodes = int(data.x.shape[0]) if getattr(data, "x", None) is not None else 0 - num_edges = ( - int(data.edge_attr.shape[0]) - if getattr(data, "edge_attr", None) is not None - else 0 - ) - - # Create random tensors of the requested shapes. - random_x: Tensor = torch.empty(num_nodes, self.num_node_properties) - random_edge_attr: Tensor = torch.empty(num_edges, self.num_bond_properties) - random_molecule_properties: Tensor = torch.empty( - 1, self.num_molecule_properties + random_x = torch.empty(data.x.shape[0], self.num_node_properties) + random_edge_attr = torch.empty( + data.edge_attr.shape[0], self.num_bond_properties ) + random_molecule_properties = torch.empty(1, self.num_molecule_properties) # Initialize them according to the chosen distribution. self.random_gni(random_x, self.distribution) From 759e45604d5feaa49440c00072204a5cc398b594 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Nov 2025 14:01:40 +0100 Subject: [PATCH 3/8] doc for dynamic gni --- chebai_graph/models/dynamic_gni.py | 61 +++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 20a6c6a..bb48571 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -1,3 +1,23 @@ +""" +ResGatedDynamicGNIGraphPred +------------------------------------------------ + +Module providing a ResGated GNN model that applies Random Node Initialization +(RNI) dynamically at each forward pass. This follows the approach from: + +Abboud, R., et al. (2020). "The surprising power of graph neural networks with +random node initialization." arXiv preprint arXiv:2010.01179. + +The module exposes: +- ResGatedDynamicGNI: a model that can either completely replace node/edge + features with random tensors each forward pass or pad existing features with + additional random features. +- ResGatedDynamicGNIGraphPred: a thin wrapper that instantiates the above for + graph-level prediction pipelines. +""" + +__all__ = ["ResGatedDynamicGNIGraphPred"] + from typing import Any import torch @@ -14,12 +34,37 @@ class ResGatedDynamicGNI(GraphModelBase): """ - Base model class for applying ResGatedGraphConv layers to graph-structured data - with dynamic initialization of features for nodes and edges. - - Args: - config (dict): Configuration dictionary containing model hyperparameters. - **kwargs: Additional keyword arguments for parent class. + ResGated GNN with dynamic Random Node Initialization (RNI). + + This model supports two modes controlled by the `config`: + + - complete_randomness (bool-like): If True, **replace** node and edge + features entirely with randomly initialized tensors each forward pass. + If False, the model **pads** existing features with extra randomly + initialized features on-the-fly. + + - pad_node_features (int, optional): Number of random columns to append + to each node feature vector when `complete_randomness` is False. + + - pad_edge_features (int, optional): Number of random columns to append + to each edge feature vector when `complete_randomness` is False. + + - distribution (str): Distribution for random initialization. Must be one + of RandomFeatureInitializationReader.DISTRIBUTIONS. + + Parameters + ---------- + config : Dict[str, Any] + Configuration dictionary containing model hyperparameters. Expected keys + used by this class: + - distribution (optional, default "normal") + - complete_randomness (optional, default "True") + - pad_node_features (optional, int) + - pad_edge_features (optional, int) + Keys required by GraphModelBase (e.g., in_channels, hidden_channels, + out_channels, num_layers, edge_dim) should also be present. + **kwargs : Any + Additional keyword arguments forwarded to GraphModelBase. """ def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -96,6 +141,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor: new_x = None new_edge_attr = None + + # If replacing features entirely with random values if self.complete_randomness: new_x = torch.empty( graph_data.x.shape[0], graph_data.x.shape[1], device=self.device @@ -110,6 +157,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor: RandomFeatureInitializationReader.random_gni( new_edge_attr, self.distribution ) + + # If padding existing features with additional random columns else: if self.pad_node_features is not None: pad_node = torch.empty( From e9f239e922bc266a211f2d1bdcd255885b7c8cf0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Nov 2025 14:13:57 +0100 Subject: [PATCH 4/8] refine readme for gni --- README.md | 54 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 35b773a..929a58c 100644 --- a/README.md +++ b/README.md @@ -156,38 +156,68 @@ These can be used for both GAT and ResGated architectures: --model.n_linear_layers=2 # Default: 1 ``` -## Random Node Initialization +# Random Node Initialization +## Static Node Initialization +In this type of node initialization, the node properties (and/or edge properties) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme. -### Static Node Intialization +``` +python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 +``` -In this type of node initialization, the node properties ( and/or edge properties) of the given molecular graph is initialized only once during dataset creation with given node initiliazation scheme. +In the above config, for each node we use the 158 node properties retrieved from RDKit and add 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). You can change the distribution using: +``` +--data.distribution=zeros +``` -In the below config, for each node we the 158 node properties we retrieve from RDKit along and add 54 features to node (specified by `--data.pad_node_features=45`) which is drawn from normal distribution (by default.) You can change the distribution from which additional features are drawn by using `--data.distribution=zeros` +Available distributions: -below are the available distributions: +``` ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] +``` + +Similarly, each edge is initialized with 7 RDKit properties and 4 additional features drawn from the given distribution. -Similary, each edge is initializaed with 7 properties from RDKit and 4 additional features drawn from given distribution. +If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: ``` -python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 +--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml ``` -if you to use all the features for node (and edge) drawn from given distribution, use the data class -`--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml` . Refer the data class code. +Refer to the data class code for details. -### Dynamic Node Initialization -In this type of node initialization, the node properties ( and/or edge properties) of the given molecular graph is initialized at each forward pass of the model with given node initiliazation scheme. +## Dynamic Node Initialization -```bash +In this type of node initialization, the node properties (and/or edge properties) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme. + +Currently, dynamic node initialization is implemented only for the **resgated** architecture by specifying: + +``` +--model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml +``` + +To keep RDKit features and *add* dynamically initialized features: + +``` +--model.config.complete_randomness=False +--model.config.pad_node_features=45 +``` +The additional features are drawn from normal distribution (default). You can change it using: + +``` +--model.config.distribution=uniform +``` + +If all features should be initialized from the given distribution, remove the complete_randomness flag (default is True). + +``` python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0 ``` From fe9b4ef6b7f3df5842199ed617cb6606089eeb91 Mon Sep 17 00:00:00 2001 From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com> Date: Sat, 22 Nov 2025 14:24:09 +0100 Subject: [PATCH 5/8] commands on same line for compactness Updated the README to use a more concise format for model configuration options and hyperparameters. --- README.md | 68 ++++++++++--------------------------------------------- 1 file changed, 12 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 929a58c..83096ef 100644 --- a/README.md +++ b/README.md @@ -88,75 +88,31 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo To use a GAT-based model, choose **one** of the following configs: -- **Atom–Motif–Graph Node Pooling** - ```bash - --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml - ``` - -- **Atom-Augmented Node Pooling** - ```bash - --model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml - ``` - -- **Standard Pooling** - ```bash - --model=../python-chebai-graph/configs/model/gat.yml - ``` +- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml` +- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml` +- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml` #### GAT-specific hyperparameters -- **Number of message-passing layers** - ```bash - --model.config.num_layers=5 # Default: 4 - ``` - -- **Attention heads** - ```bash - --model.config.heads=4 # Default: 8 - ``` - *Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified).* - -- **Use GATv2** - ```bash - --model.config.v2=True # Default: False - ``` - +- **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4) +- **Attention heads**: `--model.config.heads=4` (Default: 8) + > Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified). +- **Use GATv2**: `--model.config.v2=True` (default: False) #### **ResGated Architecture** To use a ResGated GNN model, choose **one** of the following configs: -- **Atom–Motif–Graph Node Pooling** - ```bash - --model=../python-chebai-graph/configs/model/res_aug_amgpool.yml - ``` - -- **Atom-Augmented Node Pooling** - ```bash - --model=../python-chebai-graph/configs/model/res_aug_aagpool.yml - ``` - -- **Standard Pooling** - ```bash - --model=../python-chebai-graph/configs/model/resgated.yml - ``` - +- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_amgpool.yml` +- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_aagpool.yml` +- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/resgated.yml` #### **Common Hyperparameters** These can be used for both GAT and ResGated architectures: -- **Dropout** - ```bash - --model.config.dropout=0.1 # Default: 0 - ``` - -- **Number of final linear layers** - ```bash - --model.n_linear_layers=2 # Default: 1 - ``` - - +- **Dropout**: `--model.config.dropout=0.1` (default: 0) +- **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1) # Random Node Initialization From 524f665b967babe725fc9e7fc7ec2585af3aca82 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 22 Nov 2025 19:49:23 +0100 Subject: [PATCH 6/8] readme - minor clarification --- README.md | 49 ++++++++++++++++++++----------------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 83096ef..450800d 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,9 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo ## Augmented Graphs +Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs. + + ```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0 ``` @@ -95,7 +98,7 @@ To use a GAT-based model, choose **one** of the following configs: #### GAT-specific hyperparameters - **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4) -- **Attention heads**: `--model.config.heads=4` (Default: 8) +- **Attention heads**: `--model.config.heads=4` (default: 8) > Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified). - **Use GATv2**: `--model.config.v2=True` (default: False) @@ -118,62 +121,50 @@ These can be used for both GAT and ResGated architectures: ## Static Node Initialization -In this type of node initialization, the node properties (and/or edge properties) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme. - +In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme. -``` +```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 ``` -In the above config, for each node we use the 158 node properties retrieved from RDKit and add 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). You can change the distribution using: +In the above command, for each node we use the 158 node features (corresponding the node properties defined in `chebi50_graph_properties.yml`) which are retrieved from RDKit and add 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). -``` ---data.distribution=zeros -``` +You can change the distribution using the following config in above command: `--data.distribution=zeros` -Available distributions: +Available distributions: `"normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"` -``` -["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] -``` -Similarly, each edge is initialized with 7 RDKit properties and 4 additional features drawn from the given distribution. +Similarly, each edge is initialized with 7 RDKit features and 4 additional features drawn from the given distribution. -If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: +If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: `--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml` -``` ---data=../python-chebai-graph/configs/data/chebi50_static_gni.yml -``` Refer to the data class code for details. ## Dynamic Node Initialization -In this type of node initialization, the node properties (and/or edge properties) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme. +In this type of node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme. -Currently, dynamic node initialization is implemented only for the **resgated** architecture by specifying: -``` ---model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml -``` -To keep RDKit features and *add* dynamically initialized features: +Currently, dynamic node initialization is implemented only for the **resgated** architecture by specifying: `--model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml` + +To keep RDKit features and *add* dynamically initialized features use the following config in the command: ``` --model.config.complete_randomness=False --model.config.pad_node_features=45 ``` -The additional features are drawn from normal distribution (default). You can change it using: - -``` ---model.config.distribution=uniform -``` +The additional features are drawn from normal distribution (default). You can change it using:`--model.config.distribution=uniform` If all features should be initialized from the given distribution, remove the complete_randomness flag (default is True). -``` + +Please find below the command for a typical dynamic node initialization: + +```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0 ``` From 37c76c11edce7aa48292008490c9f68024986e46 Mon Sep 17 00:00:00 2001 From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:32:35 +0100 Subject: [PATCH 7/8] Enhance README with augmented graphs explanation Expanded the section on augmented graphs to explain the use of artificial nodes representing functional groups. Added details on connection schemes and provided commands for model and data configuration. --- README.md | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 450800d..22fa5ec 100644 --- a/README.md +++ b/README.md @@ -78,8 +78,20 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo ## Augmented Graphs -Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs. +Graph Neural Networks (GNNs) often fail to explicitly leverage the chemically meaningful substructures present within molecules (i.e. **functional groups (FGs)**). To make this implicit information explicitly accessible to GNNs, we augment molecular graphs with **artificial nodes** that represent these substructures. The resulting graph are referred to as **augmented graphs**. +> Note: Rings are also treated as functional groups in our work. + +In these augmented graphs, each functional group node is connected to the atoms that constitute the group. Additionally, two functional group nodes are connected if any atom belonging to one group shares a bond with an atom from the other group. We further introduce a **graph node**, an extra node connected to all functional group nodes. + +Among all the connection schemes we evaluated, this configuration delivered the strongest performance. We denote it using the abbreviation **WFG_WFGE_WGN** in our work and is shown in below figure. + +mol_to_aug_mol + +
+
+ +Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs. ```bash python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0 @@ -91,17 +103,23 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo To use a GAT-based model, choose **one** of the following configs: -- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml` -- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml` - **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml` + > Standard pooling sums the learned representations from all the nodes to produce a single representation which is used for classification. + +- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml` + > With this pooling stratergy, the learned representations are first separated into **two distinct sets**: those from atom nodes and those from all artificial nodes (both functional groups and the graph node). The representations within each set are aggregated separately (using summation) to yield two distinct single vectors. These two resulting vectors are then concatenated before being passed to the classification layer. + +- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml` + > This approach employs a finer granularity of separation, distinguishing learned representations into **three distinct sets**: atom nodes, Functional Group (FG) nodes, and the single graph node. Summation is performed separately on the atom node set and the FG node set, yielding two vectors. These two vectors are then concatenated along with the single vector corresponding to the graph node before the final linear layer. #### GAT-specific hyperparameters - **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4) - **Attention heads**: `--model.config.heads=4` (default: 8) - > Note: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified). + > **Note**: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified). - **Use GATv2**: `--model.config.v2=True` (default: False) - + > **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491). + #### **ResGated Architecture** To use a ResGated GNN model, choose **one** of the following configs: @@ -117,9 +135,9 @@ These can be used for both GAT and ResGated architectures: - **Dropout**: `--model.config.dropout=0.1` (default: 0) - **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1) -# Random Node Initialization +## Random Node Initialization -## Static Node Initialization +### Static Node Initialization In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme. @@ -143,7 +161,7 @@ If you want all node (and edge) features to be drawn from a given distribution ( Refer to the data class code for details. -## Dynamic Node Initialization +### Dynamic Node Initialization In this type of node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme. From 83ebc629e4207a699ee8f74608d5e986f697fb3c Mon Sep 17 00:00:00 2001 From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:39:02 +0100 Subject: [PATCH 8/8] Fix typos and improve clarity in README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 22fa5ec..a6be90a 100644 --- a/README.md +++ b/README.md @@ -145,9 +145,9 @@ In this type of node initialization, the node features (and/or edge features) of python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0 ``` -In the above command, for each node we use the 158 node features (corresponding the node properties defined in `chebi50_graph_properties.yml`) which are retrieved from RDKit and add 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). +In the above command, for each node we use the 158 node features (corresponding the node properties defined in `chebi50_graph_properties.yml`) which are retrieved from RDKit and additional 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default). -You can change the distribution using the following config in above command: `--data.distribution=zeros` +You can change the distribution from which additional are drawn by using the following config in above command: `--data.distribution=zeros` Available distributions: `"normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"`