HSG-12M

HSG-12M Spatial Multigraph Dataset.

This repository contains the code for the companion paper "HSG-12M: A Large-Scale Spatial Multigraph Dataset". It covers:

  • downloading the raw data files
  • preliminary featurization and processing into PyTorch Geometric datasets (both in-memory and on-disk)
  • benchmarking GNN baseline models
  • generating HSG-12M
  • deriving the six dataset variants and generic custom subsets

The 1401 data files are publicly available at Dataverse.

The dataset is generated by poly2graph.

See the tutorial.ipynb for an interactive start.

Installation

The package requires python>=3.12 and can be installed locally.

$ conda create -n hsg python=3.12 # python>=3.12
$ conda activate hsg

$ git clone https://github.com/sarinstein-yan/HSG-12M.git
$ cd HSG-12M
$ pip install -e . 
# Check Installation
$ python -c "import hsg; print(hsg.__version__)"

Alternatively, to install a specific CUDA build of PyTorch, set the CUDA environment variable to the desired tag (e.g., cu126 for CUDA 12.6) before running the installation command:

$ export CUDA=cu126 # On Linux or macOS
# On Windows (PowerShell): `$CUDA = "cu126"`
$ pip install . --extra-index-url https://download.pytorch.org/whl/${CUDA}

PyG Datasets

Raw data files are downloaded via easyDataverse, which occasionally times out when connecting to Dataverse. If dataset initialization fails with a timeout, retry or switch to a more stable network.
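Since the download can time out, initialization can be wrapped in a generic retry helper with exponential backoff. This is a minimal sketch, not part of the hsg API; the `with_retries` helper and its parameters are illustrative:

```python
import time

def with_retries(fn, attempts=3, base_delay=2.0):
    """Call fn(); on failure, wait with exponential backoff and try again."""
    for i in range(attempts):
        try:
            return fn()
        except Exception:
            if i == attempts - 1:
                raise  # out of attempts, re-raise the last error
            time.sleep(base_delay * 2 ** i)

# Usage against the dataset would look like:
# ds = with_retries(lambda: HSGInMemory(root=ROOT_DIR, subset="one-band"))
```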

The NetworkX MultiGraph dataset is processed to PyTorch Geometric Dataset via the scheme discussed in the companion paper.

Feel free to overwrite the process() method in HSGOnDisk or HSGInMemory to customize the featurization and processing.

Available dataset variants (subsets) are:

  • one-band: 24 classes, balanced
  • two-band: 275 classes, balanced
  • three-band: 1102 classes, balanced
  • topology: 1401 classes, imbalanced
  • all / all-static: 1401 classes, balanced

from hsg import HSGInMemory, HSGOnDisk
ROOT_DIR = 'path/to/dataset_root'

ds = HSGInMemory(root=ROOT_DIR, subset="one-band")

# Use HSGOnDisk if a large dataset variant overflows RAM
# ds = HSGOnDisk(root=ROOT_DIR, subset="one-band")

print("Number of graphs:", len(ds))
print("Number of classes:", ds.num_classes)
print("Number of node features:", ds.num_features)
print("Number of edge features:", ds.num_edge_features)
print("A graph:", ds[12345])

>>>

Number of graphs: 198744
Number of classes: 24
Number of node features: 4
Number of edge features: 13
A graph: Data(edge_index=[2, 6], x=[4, 4], edge_attr=[6, 13], y=[1])

Benchmarking

Training is done with PyTorch Lightning. The training code is located in src/hsg/training.py and scripts/benchmark.py.

Reproducing Benchmarking Results

Run the following command to reproduce the benchmarking results in the companion paper.

You may need to adjust the configurations in the script file.

# $ cd path/to/HSG-12M (if not already in it)
$ python scripts/benchmark.py
# OR: $ python src/hsg/training.py

Custom Training

Run an experiment for one model on one dataset variant (and, if specified, a fixed random seed).

Pass training configurations to hsg.Config():

cfg = hsg.Config(
    ### Environment
    data_root=ROOT_DIR,              # Root directory for data
    save_dir="path/to/results",       # Directory to save experiment results
    subset="one-band",               # Dataset subset to use
    seed=None,                       # Set random seed for reproducibility

    ### Training
    model_name="gcn",                # Model architecture (e.g., 'gcn', 'sage', 'gat')
    batch_size=4096,                 # Batch size for training
    max_epochs=100,                  # Maximum number of training epochs
    max_steps=-1,                    # Maximum number of training steps (-1 for unlimited)

    ### DataModule
    size_mode="edge",                # Batch size limitation mode ("edge" or "node")
    max_num_per_batch=2_000_000,     # Maximum number of edges/nodes per batch. Batches are rebalanced to avoid OOM
    transform=None,                  # Dataset transform

    ### Model Hyperparameters
    dim_h_gnn=64,                    # Hidden dimension size for GNN Conv layers
    dim_h_mlp=64,                    # Hidden dimension size for MLP layers
    num_layers_gnn=4,                # Number of GNN Conv layers
    num_layers_mlp=2,                # Number of MLP layers
    dropout=0.0,                     # Dropout rate
    num_heads=1,                     # Number of attention heads (for GAT/GATv2 only)
    kernel_size=5,                   # Kernel size (for MoNet/SplineCNN only)

    ### Optimizer
    lr_init=1e-3,                    # Initial learning rate
    lr_min=1e-5,                     # Minimum learning rate, for CosineAnnealing scheduler
    weight_decay=0.0,                # Weight decay for optimizer, for AdamW
    T_0=100,                         # Scheduler parameter, for CosineAnnealingWarmRestarts

    ### Trainer
    devices="auto",                  # Devices to use ("auto", int, or str)
    strategy="auto",                 # Training strategy ("auto" or specific strategy)
    val_check_interval=1.0,          # Validation check interval
    log_every_n_steps=50,            # Logging frequency (steps)
    deterministic=True,              # Deterministic training for reproducibility. NOTE: True will consume a lot more memory!
    profiler=None,                   # Profiler type (None or string)
    fast_dev_run=False,              # Fast development run (debug mode)
    num_sanity_val_steps=0,          # Number of sanity validation steps before training 
)
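The size_mode/max_num_per_batch options cap the total number of edges (or nodes) per batch so that large graphs do not cause OOM. The idea can be sketched as a greedy packing pass; this is an illustration of the concept only, not hsg's actual DataModule code, and graph sizes here are plain integers:

```python
def pack_by_budget(sizes, budget):
    """Greedily group items so each batch's total size stays within budget.

    `sizes` is a list of per-graph edge (or node) counts, in dataset order;
    an item larger than the budget gets a batch of its own.
    """
    batches, current, total = [], [], 0
    for idx, s in enumerate(sizes):
        if current and total + s > budget:
            batches.append(current)     # flush the full batch
            current, total = [], 0
        current.append(idx)
        total += s
    if current:
        batches.append(current)
    return batches

print(pack_by_budget([6, 6, 6, 20, 3], budget=12))  # → [[0, 1], [2], [3], [4]]
```

Note that the oversized graph (20 edges) still forms its own singleton batch; every other batch respects the 12-edge budget.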

Then run training with hsg.run_experiment(cfg):

results = hsg.run_experiment(cfg)

results is a dict containing the test metrics and training stats.

Training logs and checkpoints are saved together in Config.save_dir.

To view Tensorboard logs:

$ tensorboard --logdir path/to/results

Raw Data Files

The raw data files are hosted on Dataverse and can also be fetched directly with easyDataverse:

from easyDataverse import Dataverse

server_url = "https://dataverse.harvard.edu"
dataset_pid = "doi:10.7910/DVN/PYDSSQ"

dv = Dataverse(server_url)
hsg12m = dv.load_dataset(
    pid=dataset_pid, 
    # local directory to save files
    filedir='.',
    # which classes to download
    filenames=[f'raw/class_{i}.npz' for i in range(1401)],
    # Set to True to download the files, requires at least 257GB of free space
    download_files=False,
)

# dataset metadata
print(hsg12m.citation)

Content of each file


Each file stores spectral graph data generated from one polynomial class. The files are created by the hsg.HSG_Generator.generate_dataset(...) method and are named class_{class_index}.npz (see HSG-12M Generation).

| Key | Type | Meaning |
| --- | --- | --- |
| graphs_pickle | List[bytes] | List of serialized NetworkX MultiGraph objects, each corresponding to one (a, b) parameter pair. |
| y | int | Class label. |
| a_vals | np.ndarray | Values of parameter a used. |
| b_vals | np.ndarray | Values of parameter b used. |
| class_meta | Dict[str, Any] | Class-level metadata, including polynomial class and Hamiltonian information. |

Note

Except for y and class_meta which are class-level, the other three are aligned by index, i.e., the $i$-th graph in graphs_pickle was generated using a = a_vals[i] and b = b_vals[i].
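Given the schema above, a file with these keys can also be inspected directly with numpy. The sketch below writes and reads a tiny synthetic stand-in with the same keys; the graph payloads here are pickled placeholder dicts rather than real NetworkX MultiGraphs, and the filename is made up:

```python
import pickle
import numpy as np

# Synthetic stand-in for one class file, following the documented keys
graphs = [pickle.dumps({"nodes": [0, 1], "edges": [(0, 1)]}) for _ in range(3)]
np.savez(
    "class_demo.npz",
    graphs_pickle=np.array(graphs, dtype=object),
    y=9,
    a_vals=np.array([0.1, 0.5, 1.0]),
    b_vals=np.array([-1.0, 0.0, 1.0]),
    class_meta=np.array({"number_of_bands": 1}, dtype=object),
)

# Object arrays require allow_pickle=True on load
data = np.load("class_demo.npz", allow_pickle=True)
first_graph = pickle.loads(data["graphs_pickle"][0])
print(int(data["y"]), first_graph["edges"])  # → 9 [(0, 1)]
```

The index alignment noted above holds per file: entry i of graphs_pickle goes with a_vals[i] and b_vals[i].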

Content of class_meta


| Key | Type | Meaning |
| --- | --- | --- |
| parameter_symbols | Tuple[str, …] | Symbols of the free coefficients inserted into the polynomial (e.g. ('a', 'b')). |
| latex | str | LaTeX code for the fully expanded polynomial. |
| sympy_repr | str | sympy standard representation of the polynomial. |
| generator_symbols | Tuple[str, …] | The sympy.Poly's generators, normally ('z', '1/z', 'E'). |
| number_of_bands | int | Number of bands $b$ (= highest exponent of E of the base term -E**b). |
| max_left_hopping | int | The highest exponent of $z$, i.e. $q$ of the base term z**q. |
| max_right_hopping | int | The lowest exponent of $z$, i.e. $p = D - q$ of the base term z**(-p). |

Take the 9-th class as an example. If raw/class_9.npz has been downloaded, one can load the nx.MultiGraph objects, the class label y, the parameter values a_vals and b_vals, and the class-specific metadata as follows:

import hsg
from pathlib import Path

nx_graphs, y, a_vals, b_vals, class_meta = hsg.load_class(class_idx=9, raw_dir=Path(ROOT_DIR) / 'raw')

print("class label:", y)
print("a_vals:", a_vals)
print("b_vals:", b_vals)
print("polynomial latex:", class_meta['latex'])
print("polynomial parameter symbols:", class_meta['parameter_symbols'])
print("polynomial generators:", class_meta['generator_symbols'])
print("number of bands:", class_meta['number_of_bands'])
print("max left hopping:", class_meta['max_left_hopping'])
print("max right hopping:", class_meta['max_right_hopping'])

import sympy as sp
poly = sp.sympify(class_meta['sympy_repr'])
poly

>>>

class label: 9
a_vals: [-10.-5.j -10.-2.j -10.-1.j ...  10.+1.j  10.+2.j  10.+5.j]
b_vals: [-10.-5.j -10.-5.j -10.-5.j ...  10.+5.j  10.+5.j  10.+5.j]
polynomial latex: - E + \frac{a}{z} + b z + z^{2} + \frac{1}{z^{2}}
polynomial parameter symbols: ['a' 'b']
polynomial generators: ['z' '1/z' 'E']
number of bands: 1
max left hopping: 2
max right hopping: 2

$$\text{Poly}{\left( z^{2} + b z + \frac{1}{z^{2}} + a \frac{1}{z} - E, z, \frac{1}{z}, E, domain=\mathbb{Z}\left[a, b\right] \right)}$$

Node and Edge Attributes of the Hamiltonian Spectral Graph Object


Each Hamiltonian spectral graph is represented as a networkx.MultiGraph object.

Node Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| pos | (2,) np.ndarray | 2D position of the node in the complex energy plane $(\text{Re}(E), \text{Im}(E))$. |
| dos | float | Density of states at the node, i.e. the number of eigenvalues per unit area. |
| potential | float | Spectral potential at the node, i.e. the Ronkin function, an algebro-geometric property of the characteristic polynomial. |

Edge Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| weight | float | Length of the edge in the complex energy plane. |
| pts | (w, 2) np.ndarray | Discretized points along the edge, forming a path in the complex plane; w is the number of samples along the edge. |
| avg_dos | float | Average density of states sampled along the edge. |
| avg_potential | float | Average spectral potential sampled along the edge. |
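Since weight is the edge's length in the complex energy plane and pts is its discretization, summing segment lengths along pts should recover the weight up to discretization error. A small numpy check on a synthetic edge (not actual dataset values):

```python
import numpy as np

def polyline_length(pts):
    """Total length of a (w, 2) array of points, summed segment by segment."""
    return float(np.linalg.norm(np.diff(pts, axis=0), axis=1).sum())

# Synthetic edge: a right-angle path made of two unit segments
pts = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
print(polyline_length(pts))  # → 2.0
```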

Take the 1234-th graph in the class_9.npz file as an example:

import networkx as nx

gid = 1234
graph = nx_graphs[gid]

print(f"a = {a_vals[gid]}, b = {b_vals[gid]}\n")
print(f"Nodes: {graph.nodes(data=True)}\n")
print(f"Edges: {graph.edges(data=True)}\n")
print(f"Class metadata: {class_meta}")


import matplotlib.pyplot as plt

def plot_spatial_multigraph(graph, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    for es, ee, pts in graph.edges(data='pts'):
        ax.plot(*pts.T, c='tab:blue', lw=2, alpha=0.8)
    for n, (x, y) in graph.nodes(data='pos'):
        ax.scatter(x, y, c='tab:red', s=10, zorder=10)
    return ax

fig, ax = plt.subplots(figsize=(3, 3))
plot_spatial_multigraph(graph, ax=ax)
plt.tight_layout(); plt.show()

HSG-12M Generation

The dataset generator used in the companion paper is as follows:

import hsg
gen = hsg.HSG_Generator(
    root=ROOT_DIR,
    hopping_range=[4,5,6], 
    num_bands=[1,2,3],
    real_coeff_walk=[-10, -5, -2, -1, -0.5, -0.1, 0, 0.1, 0.5, 1, 2, 5, 10],
    imag_coeff_walk=[-5, -2, -1, 0, 1, 2, 5],
)
num_classes = len(gen.all_metas)

Run the following to generate the 1401 raw data files of HSG-12M (saved to HSG_Generator.root_dir + '/raw'):

for i in range(num_classes):
    gen.generate_dataset(
        class_idx=i,
        num_partition=20,
        # ^ generate the class in 20 partitions; set to 1 if RAM is large enough
        short_edge_threshold=30,
        # ^ merge nearby nodes within this distance threshold; see the `poly2graph` documentation
    )

Load T-HSG-5.1M

Again taking the 9-th class as an example, obtain the temporal graphs derived from class 9 (ensure the class_9.npz file is in the gen.root_dir/raw directory):

tg_9 = gen.get_temporal_graphs_by_class(class_idx=9)

To get the whole T-HSG-5.1M as a List[List[networkx.MultiGraph]] (ensure all 1401 files are downloaded or generated at gen.root_dir/raw):

thsg = []
y = []
from tqdm import tqdm
for i in tqdm(range(num_classes)):
    tg_i = gen.get_temporal_graphs_by_class(i)
    thsg.extend(tg_i)
    y.extend([i] * len(tg_i))

print("Number of temporal graphs:", len(thsg))

Citation

If you find this work useful, please cite our paper:

@misc{yan2025hsg12mlargescalespatialmultigraph,
      title={HSG-12M: A Large-Scale Spatial Multigraph Dataset}, 
      author={Xianquan Yan and Hakan Akgün and Kenji Kawaguchi and N. Duane Loh and Ching Hua Lee},
      year={2025},
      eprint={2506.08618},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.08618}, 
}
