From 15a9b083583d47226e1b2a061a3031dee6cde431 Mon Sep 17 00:00:00 2001 From: Theo Viel Date: Fri, 16 Jan 2026 04:17:44 -0800 Subject: [PATCH 1/3] Inference refacto, minor readme changes, fix warning, some formatting --- README.md | 12 +- configs/configs_data.py | 2 +- rnapro/config/config.py | 28 +- rnapro/data/infer_data_pipeline.py | 204 ++++--- rnapro/model/generator.py | 8 +- rnapro/model/modules/diffusion.py | 2 +- rnapro/model/modules/pairformer.py | 143 +++-- rnapro/model/modules/ribonanzanet.py | 556 ++++++++++-------- .../openfold_local/utils/precision_utils.py | 2 +- rnapro/utils/inference.py | 148 +++++ rnapro_inference_example.sh | 58 +- runner/inference.py | 284 ++++----- 12 files changed, 847 insertions(+), 600 deletions(-) create mode 100644 rnapro/utils/inference.py diff --git a/README.md b/README.md index 4e81352..215faad 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ To train the model, you will need the CCD cache. The CCD cache is generated by p ```sh -python3 scripts/gen_ccd_cache.py +python3 preprocess/gen_ccd_cache.py ``` ``` @@ -141,6 +141,7 @@ release_data/ ### 2. Training + We provide the trained model checkpoint via [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/resources/rnapro) and HuggingFace ([Public best](https://huggingface.co/nvidia/RNAPro-Public-Best-500M) and [Private best](https://huggingface.co/nvidia/RNAPro-Private-Best-500M)). We provide a convenience script for training. Please modify it according to your purpose: @@ -155,7 +156,7 @@ sh rnapro_train_example.sh For details on the input format and output format, please refer to the [overview](model_cards/overview.md). -### 1. Prepare inputs +### Prepare inputs - Input csv files - Prepare a CSV file with the columns: target_id and sequence. @@ -171,6 +172,10 @@ For details on the input format and output format, please refer to the [overview - `python preprocess/convert_templates_to_pt_files.py --input_csv path/to/submission.csv --output_name path/to/template_features.pt --max_n 40` - Use with `--use_template ca_precomputed --template_data path/to/template_features.pt`. +- CCD cache (same as training) + -`python preprocess/gen_ccd_cache.py` + +- Model weights are available via [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/resources/rnapro) and HuggingFace ([Public best](https://huggingface.co/nvidia/RNAPro-Public-Best-500M) and [Private best](https://huggingface.co/nvidia/RNAPro-Private-Best-500M)). ### Inference via Bash Script @@ -200,7 +205,8 @@ The script configures and forwards the following parameters to the CLI: - `--num_workers`: Data loader workers. - `--triangle_attention` / `--triangle_multiplicative`: Kernel backends (`torch`, `cuequivariance`, etc.). - `--sequences_csv`: Optional CSV with headers `sequence,target_id` for batched inference. - +- `--max_len`: Maximum length of the sequence. Longer sequences will be skipped during inference (default: `10000`). +- `--logger`: Logger to use by the inference runner (default: `logging`). Supports `logging` and `print`. ### Acceleration diff --git a/configs/configs_data.py b/configs/configs_data.py index dc7a9d1..8c10e87 100644 --- a/configs/configs_data.py +++ b/configs/configs_data.py @@ -160,7 +160,7 @@ or (not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH)) or (not os.path.exists(PDB_CLUSTER_FILE_PATH)) ): - print("Try to find the ccd cache data in the code directory for inference.") + # print("Try to find the ccd cache data in the code directory for inference.") current_file_path = os.path.abspath(__file__) current_directory = os.path.dirname(current_file_path) code_directory = os.path.dirname(current_directory) diff --git a/rnapro/config/config.py b/rnapro/config/config.py index 3a4513e..55bfcec 100644 --- a/rnapro/config/config.py +++ b/rnapro/config/config.py @@ -222,7 +222,7 @@ def merge_configs(self, new_configs: dict) -> ConfigDict: def parse_configs( configs: dict, arg_str: str = None, fill_required_with_null: bool = False -) -> ConfigDict: +): """ Parses and merges configuration settings from a dictionary and command-line arguments. @@ -237,6 +237,25 @@ def parse_configs( """ manager = ConfigManager(configs, fill_required_with_null=fill_required_with_null) parser = argparse.ArgumentParser() + + # This is new + parser.add_argument( + "--max_len", + type=int, + default=10000, + required=False, + help="Maximum length of the sequence. Longer sequences will be skipped during inference" + ) + + # This is new + parser.add_argument( + "--logger", + type=str, + default="logging", + required=False, + help="Logger to use during inference. Supports 'logging' and 'print'" + ) + # Register arguments for key, ( dtype, @@ -252,6 +271,13 @@ def parse_configs( merged_configs = manager.merge_configs( vars(parser.parse_args(arg_str.split())) if arg_str else {} ) + + max_len = parser.parse_args(arg_str.split()).max_len + merged_configs.max_len = max_len + + logger = parser.parse_args(arg_str.split()).logger + merged_configs.logger = logger + return merged_configs diff --git a/rnapro/data/infer_data_pipeline.py b/rnapro/data/infer_data_pipeline.py index 570bae6..b070bde 100644 --- a/rnapro/data/infer_data_pipeline.py +++ b/rnapro/data/infer_data_pipeline.py @@ -29,18 +29,16 @@ import json import logging -import os import time import traceback import warnings -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Mapping, Optional import numpy as np import torch -from biotite.structure import AtomArray + from torch.utils.data import DataLoader, Dataset, DistributedSampler -from biotite.structure import AtomArray + from biotite.structure.atoms import AtomArray -from torch.utils.data import Dataset from rnapro.data.tokenizer import TokenArray from pathlib import Path @@ -114,94 +112,94 @@ def __init__( with open(self.input_json_path, "r") as f: self.inputs = json.load(f) - self.ribonanza_net_tokenizer= { - 'A': 0, - 'C': 1, - 'G': 2, - 'U': 3, - 'PAD': 4, - 'X': 5, + self.ribonanza_net_tokenizer = { + "A": 0, + "C": 1, + "G": 2, + "U": 3, + "PAD": 4, + "X": 5, } self.use_template = use_template logger.info(template_data) self.template_features = torch.load(template_data, weights_only=False) - print('#########################', 'template loaded', template_data) - + # print("- Template loaded : "" template_data) def _read_msa_file(self, target_id: str) -> Optional[list[str]]: """ Read MSA sequences from {target_id}.MSA.fasta file. - + Args: target_id: Target identifier for filename construction. - + Returns: List of RNA sequences (with gaps '-') or None if unavailable. - + Format: Standard FASTA with RNA nucleotides (A,C,G,U,N) and gaps (-). """ - + msa_file_path = self.rna_msa_dir / f"{target_id}.MSA.fasta" if not msa_file_path.exists(): logger.info(f"MSA file not found for target {target_id}: {msa_file_path}") return None - + try: sequences = [] - with open(msa_file_path, 'r') as f: + with open(msa_file_path, "r") as f: logger.info(f"Reading MSA file for target {target_id}: {msa_file_path}") current_seq = "" for line in f: line = line.strip() - if line.startswith('>'): + if line.startswith(">"): if current_seq: sequences.append(current_seq) current_seq = "" else: current_seq += line - + # Add the last sequence if current_seq: sequences.append(current_seq) - + return sequences if sequences else None - + except (IOError, OSError) as e: logger.warning(f"Error reading MSA file for target {target_id}: {e}") return None + def _create_msa_features( self, single_sample_dict: Mapping[str, Any], atom_array: AtomArray, token_array: TokenArray, crop_start: int = 0, - crop_end: Optional[int] = None + crop_end: Optional[int] = None, ) -> dict: """ Generate MSA features from FASTA files with cropping alignment. - + Nucleotide Encoding: Local indices: A=0, G=1, C=2, U=3, N=4, gap=5 Unified indices: A=21, G=22, C=23, U=24, N=25, gap=31 - + Args: single_sample_dict: Sample dict with cropped target sequence. atom_array: AtomArray (unused). token_array: TokenArray (unused). crop_start: MSA crop start position. crop_end: MSA crop end position (None if no cropping). - + Returns: - Dict with keys: msa, has_deletion, deletion_value, deletion_mean, + Dict with keys: msa, has_deletion, deletion_value, deletion_mean, profile, rna_unpair_num_alignments. Empty dict if no MSA available. - + Note: MSA sequences cropped to match target sequence length. Profile is 32D covering proteins (0-20), RNA (21-25), DNA (26-30), gap (31). """ sample_name = single_sample_dict["name"] - + # Read MSA sequences for this target msa_sequences = self._read_msa_file(sample_name) if not msa_sequences: @@ -209,46 +207,52 @@ def _create_msa_features( logger.debug(f"No MSA data available for {sample_name}") return {} if len(msa_sequences) > self.rna_msa_seq_limit: - msa_sequences = msa_sequences[:self.rna_msa_seq_limit] + msa_sequences = msa_sequences[: self.rna_msa_seq_limit] try: from rnapro.data.constants import RNA_NT_TO_ID - + # The sequence in single_sample_dict is already cropped, so use it as-is sequence = single_sample_dict["sequences"][0]["rnaSequence"]["sequence"] seq_len = len(sequence) - + num_msa_seq = len(msa_sequences) - + # Use official RNA mapping that includes gap character at index 5 rna_mapping_with_gap = RNA_NT_TO_ID - + # Create MSA feature arrays for cropped sequence length using unified residue indices # MSA embedder expects integer indices in unified residue space (converts to one-hot internally) from rnapro.data.constants import RNA_STD_RESIDUES, STD_RESIDUES_WITH_GAP - + # Mapping from local RNA_NT_TO_ID indices to unified system indices rna_local_to_unified = { 0: RNA_STD_RESIDUES["A"], # A: 0 -> 21 - 1: RNA_STD_RESIDUES["G"], # G: 1 -> 22 + 1: RNA_STD_RESIDUES["G"], # G: 1 -> 22 2: RNA_STD_RESIDUES["C"], # C: 2 -> 23 3: RNA_STD_RESIDUES["U"], # U: 3 -> 24 4: RNA_STD_RESIDUES["N"], # N: 4 -> 25 5: STD_RESIDUES_WITH_GAP["-"], # Gap: 5 -> 31 } - + # Create MSA array with integer indices in unified 32D space # MSA embedder expects int64 indices, not one-hot encoding - msa_array = np.full((num_msa_seq, seq_len), 25, dtype=np.int64) # Default to 'N' (index 25) - + msa_array = np.full( + (num_msa_seq, seq_len), 25, dtype=np.int64 + ) # Default to 'N' (index 25) + for seq_idx, msa_seq in enumerate(msa_sequences): # Apply cropping to MSA sequences to match target sequence cropping if crop_end is not None: # Crop MSA sequence to match target sequence cropping - msa_seq_cropped = msa_seq[crop_start:crop_end] if len(msa_seq) > crop_start else msa_seq + msa_seq_cropped = ( + msa_seq[crop_start:crop_end] + if len(msa_seq) > crop_start + else msa_seq + ) else: msa_seq_cropped = msa_seq - + for pos_idx, nucleotide in enumerate(msa_seq_cropped): if pos_idx < seq_len: # Ensure we don't exceed sequence length # Map nucleotide to local index first, then to unified index @@ -256,19 +260,27 @@ def _create_msa_features( unified_idx = rna_local_to_unified[local_idx] # Store unified index directly msa_array[seq_idx, pos_idx] = unified_idx - + # Create deletion features based on gap characters (index 31 in unified space) gap_idx = STD_RESIDUES_WITH_GAP["-"] # Index 31 - has_deletion = (msa_array == gap_idx).astype(np.bool_) # True where gaps are present - + has_deletion = (msa_array == gap_idx).astype( + np.bool_ + ) # True where gaps are present + # Apply arctan transformation to deletion counts for consistency with standard pipeline # Since we have binary gap presence (0 or 1), this maps: 0 -> 0, 1 -> ~0.187 - deletion_counts = has_deletion.astype(np.float32) # Treat gap presence as count=1 - deletion_value = (2 / np.pi) * np.arctan(deletion_counts / 3) # Standard arctan transform - + deletion_counts = has_deletion.astype( + np.float32 + ) # Treat gap presence as count=1 + deletion_value = (2 / np.pi) * np.arctan( + deletion_counts / 3 + ) # Standard arctan transform + # Calculate deletion_mean (mean deletion probability across MSA sequences) - deletion_mean = np.mean(deletion_value, axis=0).astype(np.float32) # Shape: (seq_len,) - + deletion_mean = np.mean(deletion_value, axis=0).astype( + np.float32 + ) # Shape: (seq_len,) + # Create unified 32-dimensional profile by averaging MSA (including gaps) # Profile covers all residue types: proteins (0-20) + RNA (21-25) + DNA (26-30) + gap (31) # This matches the standard _make_msa_profile implementation used in InferenceMSAFeaturizer @@ -276,20 +288,22 @@ def _create_msa_features( res_type_hits = msa_array[..., None] == all_res_types[None, ...] res_type_counts = res_type_hits.sum(axis=0) profile = (res_type_counts / num_msa_seq).astype(np.float32) - + # Create MSA feature dictionary msa_features = { "msa": msa_array, - "has_deletion": has_deletion, + "has_deletion": has_deletion, "deletion_value": deletion_value, "deletion_mean": deletion_mean, "profile": profile, "rna_unpair_num_alignments": np.array([num_msa_seq], dtype=np.int32), } - - logger.debug(f"Created MSA features for {sample_name} with {num_msa_seq} sequences") + + logger.debug( + f"Created MSA features for {sample_name} with {num_msa_seq} sequences" + ) return msa_features - + except Exception as e: logger.warning(f"Error generating MSA features for {sample_name}: {e}") return {} @@ -315,7 +329,7 @@ def process_one( sample2feat = SampleDictToFeatures( single_sample_dict, ) - seq = single_sample_dict["sequences"][0]['rnaSequence']['sequence'] + seq = single_sample_dict["sequences"][0]["rnaSequence"]["sequence"] features_dict, atom_array, token_array = sample2feat.get_feature_dict() features_dict["distogram_rep_atom_mask"] = torch.Tensor( atom_array.distogram_rep_atom_mask @@ -323,8 +337,9 @@ def process_one( entity_poly_type = sample2feat.entity_poly_type t1 = time.time() - msa_features = self._create_msa_features(single_sample_dict, atom_array, token_array) - + msa_features = self._create_msa_features( + single_sample_dict, atom_array, token_array + ) # Make dummy features for not implemented features dummy_feats = ["template"] @@ -340,7 +355,7 @@ def process_one( # Transform to right data type feat = data_type_transform(feat_or_label_dict=features_dict) - feat['seq'] = seq + feat["seq"] = seq t2 = time.time() @@ -421,28 +436,42 @@ def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], AtomArray, s except Exception as e: data, atom_array = {}, None error_message = f"{e}:\n{traceback.format_exc()}" - print('error_message', error_message) + print("error_message", error_message) data["sample_name"] = single_sample_dict["name"] data["sample_index"] = index - - sequence=[self.ribonanza_net_tokenizer[nt] for nt in data['input_feature_dict']['seq']] - sequence=np.array(sequence) - sequence=torch.tensor(sequence) - data['input_feature_dict']['tokenized_seq'] = sequence - print('#'*10, 'use','self.template_idx', self.template_idx) - if self.use_template == 'masked_templates': - if self.use_template and data['sample_name'] in self.template_features: - template_ca = torch.from_numpy(self.template_features[data['sample_name']]['xyz'][:, [self.template_idx]]).permute(1,0,2).float() - data['input_feature_dict']['template_coords'] = template_ca - data['input_feature_dict']['template_coords_mask'] = torch.ones(1, len(sequence), dtype=torch.bool) - data['input_feature_dict']['n_templates'] = torch.tensor([1]) + sequence = [ + self.ribonanza_net_tokenizer[nt] for nt in data["input_feature_dict"]["seq"] + ] + sequence = np.array(sequence) + sequence = torch.tensor(sequence) + data["input_feature_dict"]["tokenized_seq"] = sequence + + if self.use_template == "masked_templates": + print(f"- Using template masked templates #{self.template_idx}") + if self.use_template and data["sample_name"] in self.template_features: + template_ca = ( + torch.from_numpy( + self.template_features[data["sample_name"]]["xyz"][ + :, [self.template_idx] + ] + ) + .permute(1, 0, 2) + .float() + ) + data["input_feature_dict"]["template_coords"] = template_ca + data["input_feature_dict"]["template_coords_mask"] = torch.ones( + 1, len(sequence), dtype=torch.bool + ) + data["input_feature_dict"]["n_templates"] = torch.tensor([1]) else: template_ca = torch.ones(1, len(sequence), 3) - data['input_feature_dict']['template_coords'] = template_ca - data['input_feature_dict']['template_coords_mask'] = torch.ones(1, len(sequence), dtype=torch.bool) - data['input_feature_dict']['n_templates'] = torch.tensor([1]) - elif self.use_template == 'ca_precomputed': + data["input_feature_dict"]["template_coords"] = template_ca + data["input_feature_dict"]["template_coords_mask"] = torch.ones( + 1, len(sequence), dtype=torch.bool + ) + data["input_feature_dict"]["n_templates"] = torch.tensor([1]) + elif self.use_template == "ca_precomputed": # template_idx selects top-k templates: 0->top1, 1->top2, 2->top3, 3->top4, 4->top5 template_combinations = [ [0], @@ -451,11 +480,20 @@ def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], AtomArray, s [0, 1, 2, 3], [0, 1, 2, 3, 4], ] - print('template_combinations[self.template_idx]', template_combinations[self.template_idx]) - template_ca = torch.from_numpy(self.template_features[data['sample_name']]['xyz'][:, template_combinations[self.template_idx]]).permute(1,0,2).float() + print(f"- Using template ca precomputed #{self.template_idx} - combinations: {template_combinations[self.template_idx]}") + template_ca = ( + torch.from_numpy( + self.template_features[data["sample_name"]]["xyz"][ + :, template_combinations[self.template_idx] + ] + ) + .permute(1, 0, 2) + .float() + ) - data['input_feature_dict']['template_coords'] = template_ca - data['input_feature_dict']['template_coords_mask'] = torch.ones(len(template_ca), len(sequence)) - data['input_feature_dict']['n_templates'] = torch.tensor([len(template_ca)]) + data["input_feature_dict"]["template_coords"] = template_ca + data["input_feature_dict"]["template_coords_mask"] = torch.ones( + len(template_ca), len(sequence) + ) + data["input_feature_dict"]["n_templates"] = torch.tensor([len(template_ca)]) return data, atom_array, error_message - diff --git a/rnapro/model/generator.py b/rnapro/model/generator.py index cee1e88..d5ab3c1 100644 --- a/rnapro/model/generator.py +++ b/rnapro/model/generator.py @@ -55,7 +55,7 @@ def __init__( self.sigma_data = sigma_data self.p_mean = p_mean self.p_std = p_std - print(f"train scheduler {self.sigma_data}") + # print(f"train scheduler {self.sigma_data}") def __call__( self, size: torch.Size, device: torch.device = torch.device("cpu") @@ -98,7 +98,7 @@ def __init__( self.s_max = s_max self.s_min = s_min self.rho = rho - print(f"inference scheduler {self.sigma_data}") + # print(f"inference scheduler {self.sigma_data}") def __call__( self, @@ -330,10 +330,10 @@ def sample_diffusion_training( ) for i in range(no_chunks): x_noisy_i = (x_gt_augment + noise)[ - ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size, :, : + ..., i * diffusion_chunk_size: (i + 1) * diffusion_chunk_size, :, : ] t_hat_noise_level_i = sigma[ - ..., i * diffusion_chunk_size : (i + 1) * diffusion_chunk_size + ..., i * diffusion_chunk_size: (i + 1) * diffusion_chunk_size ] x_denoised_i = denoise_net( x_noisy=x_noisy_i, diff --git a/rnapro/model/modules/diffusion.py b/rnapro/model/modules/diffusion.py index 53e2733..c9cc6eb 100644 --- a/rnapro/model/modules/diffusion.py +++ b/rnapro/model/modules/diffusion.py @@ -98,7 +98,7 @@ def __init__( # Line10-Line12 self.transition_s1 = Transition(c_in=self.c_s, n=2) self.transition_s2 = Transition(c_in=self.c_s, n=2) - print(f"Diffusion Module has {self.sigma_data}") + # print(f"Diffusion Module has {self.sigma_data}") def forward( self, diff --git a/rnapro/model/modules/pairformer.py b/rnapro/model/modules/pairformer.py index c744cdf..9856049 100644 --- a/rnapro/model/modules/pairformer.py +++ b/rnapro/model/modules/pairformer.py @@ -33,9 +33,8 @@ import torch import torch.nn as nn -from torch.nn.functional import one_hot -from rnapro.model.modules.primitives import LinearNoBias, Transition, Linear +from rnapro.model.modules.primitives import LinearNoBias, Transition, Linear from rnapro.model.modules.transformer import AttentionPairBias from rnapro.model.utils import ( pad_at_dim, @@ -60,7 +59,6 @@ ) - class PairformerBlock(nn.Module): """Implements Algorithm 17 [Line2-Line8] in AF3 c_hidden_mul is set as openfold @@ -1018,8 +1016,6 @@ def forward( return 0 - - class TemplateEmbedderAllatom(nn.Module): """ Implements Algorithm 16 in AF3 @@ -1081,8 +1077,7 @@ def __init__( self.relu = nn.ReLU() self.projection = LinearNoBias( - in_features=sum(self.input_feature1.values()) - + 5 + 5, + in_features=sum(self.input_feature1.values()) + 5 + 5, out_features=sum(self.input_feature1.values()) + sum(self.input_feature2.values()), ) @@ -1111,13 +1106,13 @@ def forward( """ # In this version, we do not use TemplateEmbedder by setting n_blocks=0 # if "template_restype" not in input_feature_dict or self.n_blocks < 1: - # return 0 + # return 0 # Load relevant features res_type = input_feature_dict["template_restype"] frame_rot = input_feature_dict["template_frame_rot"] frame_t = input_feature_dict["template_frame_t"] frame_mask = input_feature_dict["template_mask_frame"] - cb_coords = input_feature_dict["c1_coords"] + # cb_coords = input_feature_dict["c1_coords"] ca_coords = input_feature_dict["c1_coords"] cb_mask = input_feature_dict["template_mask_cb"] template_mask = input_feature_dict["template_mask"].any(dim=1).float() @@ -1131,17 +1126,23 @@ def forward( b_cb_mask = b_cb_mask[..., None] b_frame_mask = b_frame_mask[..., None] - ca_coords = input_feature_dict['c1_coords']##.to(dtype) + ca_coords = input_feature_dict["c1_coords"] # .to(dtype) B, T, _ = ca_coords.shape - + # Compute template features with torch.autocast(device_type="cuda", enabled=False): # Compute distogram ca_dists = torch.cdist(ca_coords, ca_coords) - boundaries = torch.linspace(self.distogram['min_bin'], self.distogram['max_bin'], self.distogram['no_bins'] - 1) + boundaries = torch.linspace( + self.distogram["min_bin"], + self.distogram["max_bin"], + self.distogram["no_bins"] - 1, + ) boundaries = boundaries.to(ca_dists.device) distogram = (ca_dists[..., None] > boundaries).sum(dim=-1).long() - distogram = one_hot(distogram, num_classes=self.distogram['no_bins']).float() + distogram = one_hot( + distogram, num_classes=self.distogram["no_bins"] + ).float() # Compute unit vector in each frame frame_rot = frame_rot.unsqueeze(1).transpose(-1, -2) @@ -1151,7 +1152,6 @@ def forward( norm = torch.norm(vector, dim=-1, keepdim=True) unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector)) unit_vector = unit_vector.squeeze(-1) - # Concatenate input features a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask] @@ -1170,17 +1170,17 @@ def forward( v = v + a_tij # TODO: pairformer # v = v.view(B * T, *v.shape[2:]) - v = v + self.pairformer_stack(v, v, pair_mask)[1] # first v is dummy and not used because we set c_s=0 in PairformerStack + v = ( + v + self.pairformer_stack(v, v, pair_mask)[1] + ) # first v is dummy and not used because we set c_s=0 in PairformerStack v = self.layernorm_v(v) # v = v.view(B, T, *v.shape[1:]) - # Aggregate templates template_mask = template_mask[:, None, None, None] num_templates = num_templates.unsqueeze(0)[:, None, None] u = (v * template_mask).sum(0) / num_templates - # u = v.sum(dim=0) # TODO: Aggregate templates @@ -1191,7 +1191,7 @@ def forward( # Compute output projection u = self.linear_no_bias_u(self.relu(u)) return u - + class TemplateEmbedderCa(nn.Module): """ @@ -1261,7 +1261,6 @@ def __init__( + sum(self.input_feature2.values()), ) - def forward( self, input_feature_dict: dict[str, Any], @@ -1286,33 +1285,40 @@ def forward( """ # In this version, we do not use TemplateEmbedder by setting n_blocks=0 # if "template_restype" not in input_feature_dict or self.n_blocks < 1: - # return 0 + # return 0 - ca_coords = input_feature_dict['template_ca']##.to(dtype) + ca_coords = input_feature_dict["template_ca"] # .to(dtype) B, T, _ = ca_coords.shape from torch.nn.functional import one_hot + # Compute template features with torch.autocast(device_type="cuda", enabled=False): # Compute distogram ca_dists = torch.cdist(ca_coords, ca_coords) ca_dists[torch.isnan(ca_dists)] = 0.0 - boundaries = torch.linspace(self.distogram['min_bin'], self.distogram['max_bin'], self.distogram['no_bins'] - 1) + boundaries = torch.linspace( + self.distogram["min_bin"], + self.distogram["max_bin"], + self.distogram["no_bins"] - 1, + ) boundaries = boundaries.to(ca_dists.device) distogram = (ca_dists[..., None] > boundaries).sum(dim=-1).long() - distogram = one_hot(distogram, num_classes=self.distogram['no_bins']).float() + distogram = one_hot( + distogram, num_classes=self.distogram["no_bins"] + ).float() # TODO:Compute unit vector in each frame # TODO: Concatenate input features - a_tij = distogram #[distogram, b_cb_mask, unit_vector, b_frame_mask] + a_tij = distogram # [distogram, b_cb_mask, unit_vector, b_frame_mask] # a_tij = torch.cat(a_tij, dim=-1) # TODO: Concatenate restype_i and restype_j - #res_type_i = res_type[:, :, :, None] - #res_type_j = res_type[:, :, None, :] - #res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1) - #res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1) - #a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1) + # res_type_i = res_type[:, :, :, None] + # res_type_j = res_type[:, :, None, :] + # res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1) + # res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1) + # a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1) a_tij = self.projection(a_tij) a_tij = self.linear_no_bias_a(a_tij) @@ -1321,15 +1327,16 @@ def forward( v = v + a_tij # TODO: pairformer # v = v.view(B * T, *v.shape[2:]) - v = v + self.pairformer_stack(v, v, pair_mask)[1] # first v is dummy and not used because we set c_s=0 in PairformerStack + v = ( + v + self.pairformer_stack(v, v, pair_mask)[1] + ) # first v is dummy and not used because we set c_s=0 in PairformerStack v = self.layernorm_v(v) - template_mask = ca_coords.sum(1).sum(1) != 0 template_mask = template_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) v = v * template_mask - u = v.sum(dim=0) + u = v.sum(dim=0) u = u / (torch.clamp(template_mask.sum(), min=1)) u = self.linear_no_bias_u(self.relu(u)) @@ -1339,13 +1346,13 @@ def forward( class RNATemplateEmbedder(nn.Module): """ RNA Template Embedder for processing RNA structural templates. - + This module processes RNA template structures by: 1. Embedding distance information from template coordinates 2. Processing through pairformer attention blocks 3. Averaging over multiple templates 4. Outputting template-informed pair representations - + Args: n_blocks (int): Number of pairformer blocks for processing. Default: 2 c_s (int): Dimension of single representation. Default: 64 @@ -1372,14 +1379,14 @@ def __init__( distance_bin_step: float = 1.25, zero_init_final_linear: bool = True, ) -> None: - + super(RNATemplateEmbedder, self).__init__() self.n_blocks = n_blocks self.c_s = c_s self.c_z = c_z self.c_s_inputs = c_s_inputs - + # Distance binning parameters self.distance_bin_start = distance_bin_start self.distance_bin_end = distance_bin_end @@ -1399,7 +1406,7 @@ def __init__( self.linear_no_bias_chem = Linear( in_features=1, out_features=self.c_s, bias=False ) - + # Linear layers for combining s_inputs with pair features self.linear_no_bias_s1 = Linear( in_features=self.c_s_inputs, out_features=self.c_z, bias=False @@ -1464,30 +1471,37 @@ def forward( triangle_attention (bool): Enable triangle attention inplace_safe (bool): Enable in-place operations for memory efficiency chunk_size (Optional[int]): Chunk size for memory-efficient processing - + Returns: - torch.Tensor: Template-enhanced pair representations [..., N_token, N_token, c_z] + torch.Tensor: Template-enhanced pair representations [..., N_token, N_token, c_z] """ if self.n_blocks < 1: return z - if "n_templates" not in input_feature_dict or input_feature_dict["n_templates"] < 1: + if ( + "n_templates" not in input_feature_dict + or input_feature_dict["n_templates"] < 1 + ): return z - + # Extract C1' template coordinates and masks - n_templates = input_feature_dict['n_templates'] - template_coords = input_feature_dict['template_coords'] # [..., n_templates, N_token, 3] - template_coords_mask = input_feature_dict['template_coords_mask'] # [..., n_templates, N_token] - + n_templates = input_feature_dict["n_templates"] + template_coords = input_feature_dict[ + "template_coords" + ] # [..., n_templates, N_token, 3] + template_coords_mask = input_feature_dict[ + "template_coords_mask" + ] # [..., n_templates, N_token] + # Initialize single features s = self.input_s_ln(torch.clamp(s, min=-512, max=512)) - if 'chemical_mapping_profile' in input_feature_dict: + if "chemical_mapping_profile" in input_feature_dict: # Embed and add 1D chemical mapping profile if available - chemical_mapping_profile = input_feature_dict['chemical_mapping_profile'] # [..., N_token] - s = s + self.linear_no_bias_chem( - chemical_mapping_profile.unsqueeze(dim=-1) - ) - + chemical_mapping_profile = input_feature_dict[ + "chemical_mapping_profile" + ] # [..., N_token] + s = s + self.linear_no_bias_chem(chemical_mapping_profile.unsqueeze(dim=-1)) + # Initialize pair features with single feature projections z_init = ( self.linear_no_bias_s1(s_inputs)[..., None, :, :] @@ -1499,12 +1513,12 @@ def forward( # Process each template individually for idx in range(n_templates): - + # Start with initial pair features z_pair = z_init + z # Add distance information to pair features - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): _template_coords = template_coords[..., idx, :, :].to(torch.float32) _mask = ( template_coords_mask[..., idx, :][..., :, None] @@ -1525,22 +1539,27 @@ def forward( x=distance_pred, lower_bins=self.lower_bins, upper_bins=self.upper_bins, - ) * _mask[..., None] + ) + * _mask[..., None] + ) + z_pair += ( + self.linear_no_bias_d_wo_onehot(distance_pred.unsqueeze(dim=-1)) + * _mask[..., None] ) - z_pair += self.linear_no_bias_d_wo_onehot( - distance_pred.unsqueeze(dim=-1) - ) * _mask[..., None] else: z_pair = z_pair + self.linear_no_bias_d( one_hot( x=distance_pred, lower_bins=self.lower_bins, upper_bins=self.upper_bins, - ) * _mask[..., None] + ) + * _mask[..., None] + ) + z_pair = ( + z_pair + + self.linear_no_bias_d_wo_onehot(distance_pred.unsqueeze(dim=-1)) + * _mask[..., None] ) - z_pair = z_pair + self.linear_no_bias_d_wo_onehot( - distance_pred.unsqueeze(dim=-1) - ) * _mask[..., None] # Process through pairformer stack _, z_pair = self.pairformer_stack( @@ -1556,7 +1575,7 @@ def forward( # Accumulate results z_template += z_pair - + # Average over templates z_template = z_template / n_templates diff --git a/rnapro/model/modules/ribonanzanet.py b/rnapro/model/modules/ribonanzanet.py index 9a17eaf..6658208 100644 --- a/rnapro/model/modules/ribonanzanet.py +++ b/rnapro/model/modules/ribonanzanet.py @@ -13,28 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -import torch -import torch.nn as nn -from torch import einsum -import torch.nn.functional as F -from einops import rearrange -import torch.utils.checkpoint as checkpoint -import yaml - -class Config: - def __init__(self, **entries): - self.__dict__.update(entries) - self.entries=entries - - def print(self): - print(self.entries) - -def load_config_from_yaml(file_path): - with open(file_path, 'r') as file: - config = yaml.safe_load(file) - return Config(**config) - # Copyright 2021 AlQuraishi Laboratory # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -49,13 +27,35 @@ def load_config_from_yaml(file_path): # See the License for the specific language governing permissions and # limitations under the License. - +import yaml +import math import torch import torch.nn as nn +from torch import einsum +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from einops import rearrange +from torch.nn.parameter import Parameter from functools import partialmethod from typing import Union, List +class Config: + def __init__(self, **entries): + self.__dict__.update(entries) + self.entries = entries + + def print(self): + print(self.entries) + + +def load_config_from_yaml(file_path): + with open(file_path, "r") as file: + config = yaml.safe_load(file) + return Config(**config) + + class Dropout(nn.Module): """ Implementation of dropout with the ability to share the dropout mask @@ -75,7 +75,7 @@ def __init__(self, r: float, batch_dim: Union[int, List[int]]): super(Dropout, self).__init__() self.r = r - if type(batch_dim) == int: + if isinstance(batch_dim, int): batch_dim = [batch_dim] self.batch_dim = batch_dim self.dropout = nn.Dropout(self.r) @@ -114,20 +114,25 @@ class DropoutColumnwise(Dropout): __init__ = partialmethod(Dropout.__init__, batch_dim=-2) -def recursive_linear_init(m,scale_factor): + +def recursive_linear_init(m, scale_factor): for child_name, child in m.named_modules(): - if 'gate' not in child_name: - custom_weight_init(child,scale_factor) + if "gate" not in child_name: + custom_weight_init(child, scale_factor) + def custom_weight_init(m, scale_factor): if isinstance(m, nn.Linear): - d_model = m.in_features # Set d_model to the input dimension of the linear layer - upper = 1.0 / (d_model ** 0.5) * scale_factor - lower = -1.0 / (d_model ** 0.5) * scale_factor + d_model = ( + m.in_features + ) # Set d_model to the input dimension of the linear layer + upper = 1.0 / (d_model**0.5) * scale_factor + lower = -1.0 / (d_model**0.5) * scale_factor torch.nn.init.uniform_(m.weight, lower, upper) if m.bias is not None: torch.nn.init.zeros_(m.bias) + class TransitionLayer(nn.Module): def __init__(self, input_dim, n=4): super(TransitionLayer, self).__init__() @@ -139,18 +144,18 @@ def __init__(self, input_dim, n=4): def forward(self, x): # Step 1: Apply LayerNorm x = self.layer_norm(x) - + # Step 2: Compute a and b using LinearNoBias (implemented with Linear and bias=False) a = self.linear_a(x) b = self.linear_b(x) - + # Step 3: Element-wise multiplication of swish(a) and b swish_a = a * torch.sigmoid(a) # Swish activation directly in forward x = swish_a * b - + # Step 4: Pass through another LinearNoBias layer x = self.linear_out(x) - + return x @@ -158,37 +163,50 @@ def init_weights(m): if m is not None and isinstance(m, nn.Linear): pass + class Mish(nn.Module): def __init__(self): super().__init__() def forward(self, x): - #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) - return x *( torch.tanh(F.softplus(x))) + # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) + return x * (torch.tanh(F.softplus(x))) -from torch.nn.parameter import Parameter def gem(x, p=3, eps=1e-6): - return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p) + return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1.0 / p) + + class GeM(nn.Module): def __init__(self, p=3, eps=1e-6): - super(GeM,self).__init__() - self.p = Parameter(torch.ones(1)*p) + super(GeM, self).__init__() + self.p = Parameter(torch.ones(1) * p) self.eps = eps + def forward(self, x): return gem(x, p=self.p, eps=self.eps) + def __repr__(self): - return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' + return ( + self.__class__.__name__ + + "(" + + "p=" + + "{:.4f}".format(self.p.data.tolist()[0]) + + ", " + + "eps=" + + str(self.eps) + + ")" + ) class ScaledDotProductAttention(nn.Module): - ''' Scaled Dot-Product Attention ''' + """Scaled Dot-Product Attention""" def __init__(self, temperature, attn_dropout=0.1): super().__init__() self.temperature = temperature self.dropout = nn.Dropout(attn_dropout) - #self.gamma=torch.tensor(32.0) + # self.gamma=torch.tensor(32.0) def forward(self, q, k, v, mask=None, attn_mask=None): attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature @@ -197,15 +215,16 @@ def forward(self, q, k, v, mask=None, attn_mask=None): attn = attn + mask # this is actually the bias if attn_mask is not None: - attn = attn.float().masked_fill(attn_mask == -1, float('-1e9')) + attn = attn.float().masked_fill(attn_mask == -1, float("-1e9")) attn = self.dropout(F.softmax(attn, dim=-1)) output = torch.matmul(attn, v) return output, attn + class MultiHeadAttention(nn.Module): - ''' Multi-Head Attention module ''' + """Multi-Head Attention module""" def __init__(self, d_model, n_head, d_k, d_v, dropout=0.1): super().__init__() @@ -217,20 +236,19 @@ def __init__(self, d_model, n_head, d_k, d_v, dropout=0.1): self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) - #self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + # self.fc = nn.Linear(n_head * d_v, d_model, bias=False) - self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) + self.attention = ScaledDotProductAttention(temperature=d_k**0.5) # self.dropout = nn.Dropout(dropout) # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) - - def forward(self, q, k, v, mask=None,src_mask=None): + def forward(self, q, k, v, mask=None, src_mask=None): d_k, d_v, n_head = self.d_k, self.d_v, self.n_head sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) - residual = q + # residual = q # Pass through the pre-attention projection: b x lq x (n*dv) # Separate different heads: b x lq x n x dv @@ -245,11 +263,13 @@ def forward(self, q, k, v, mask=None,src_mask=None): mask = mask # For head axis broadcasting if src_mask is not None: - src_mask=src_mask.clone().unsqueeze(-1).long() - src_mask[src_mask==0]=-1 - src_mask=src_mask.float() - #src_mask=src_mask.unsqueeze(-1)#.float() - attn_mask=torch.matmul(src_mask,src_mask.permute(0,2,1)).unsqueeze(1).long() + src_mask = src_mask.clone().unsqueeze(-1).long() + src_mask[src_mask == 0] = -1 + src_mask = src_mask.float() + # src_mask=src_mask.unsqueeze(-1)#.float() + attn_mask = ( + torch.matmul(src_mask, src_mask.permute(0, 2, 1)).unsqueeze(1).long() + ) q, attn = self.attention(q, k, v, mask=mask, attn_mask=attn_mask) else: q, attn = self.attention(q, k, v, mask=mask) @@ -260,100 +280,130 @@ def forward(self, q, k, v, mask=None,src_mask=None): return q, attn + class ConvTransformerEncoderLayer(nn.Module): - def __init__(self, d_model, nhead, - dim_feedforward, pairwise_dimension, use_triangular_attention, dim_msa, dropout=0.1, k = 3, - ): + def __init__( + self, + d_model, + nhead, + dim_feedforward, + pairwise_dimension, + use_triangular_attention, + dim_msa, + dropout=0.1, + k=3, + ): super(ConvTransformerEncoderLayer, self).__init__() - #self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout) - + # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.self_attn = MultiHeadAttention( + d_model, nhead, d_model // nhead, d_model // nhead, dropout=dropout + ) - #self.linear1 = nn.Linear(d_model, dim_feedforward) - #self.dropout = nn.Dropout(dropout) - #self.linear2 = nn.Linear(dim_feedforward, d_model) + # self.linear1 = nn.Linear(d_model, dim_feedforward) + # self.dropout = nn.Dropout(dropout) + # self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) - #self.norm3 = nn.LayerNorm(d_model) - #self.norm4 = nn.LayerNorm(d_model) + # self.norm3 = nn.LayerNorm(d_model) + # self.norm4 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) - #self.dropout3 = nn.Dropout(dropout) - #self.dropout4 = nn.Dropout(dropout) + # self.dropout3 = nn.Dropout(dropout) + # self.dropout4 = nn.Dropout(dropout) - self.pairwise2heads=nn.Linear(pairwise_dimension,nhead,bias=False) - self.pairwise_norm=nn.LayerNorm(pairwise_dimension) + self.pairwise2heads = nn.Linear(pairwise_dimension, nhead, bias=False) + self.pairwise_norm = nn.LayerNorm(pairwise_dimension) self.activation = nn.GELU() - #self.conv=nn.Conv1d(d_model,d_model,k,padding=k//2) + # self.conv=nn.Conv1d(d_model,d_model,k,padding=k//2) - self.triangle_update_out=TriangleMultiplicativeModule(dim=pairwise_dimension,mix='outgoing') - self.triangle_update_in=TriangleMultiplicativeModule(dim=pairwise_dimension,mix='ingoing') + self.triangle_update_out = TriangleMultiplicativeModule( + dim=pairwise_dimension, mix="outgoing" + ) + self.triangle_update_in = TriangleMultiplicativeModule( + dim=pairwise_dimension, mix="ingoing" + ) - self.pair_dropout_out=DropoutRowwise(dropout) - self.pair_dropout_in=DropoutRowwise(dropout) + self.pair_dropout_out = DropoutRowwise(dropout) + self.pair_dropout_in = DropoutRowwise(dropout) - self.use_triangular_attention=use_triangular_attention + self.use_triangular_attention = use_triangular_attention if self.use_triangular_attention: - self.triangle_attention_out=TriangleAttention(in_dim=pairwise_dimension, - dim=pairwise_dimension//4, - wise='row') - self.triangle_attention_in=TriangleAttention(in_dim=pairwise_dimension, - dim=pairwise_dimension//4, - wise='col') - - self.pair_attention_dropout_out=DropoutRowwise(dropout) - self.pair_attention_dropout_in=DropoutColumnwise(dropout) - - self.outer_product_mean=Outer_Product_Mean(in_dim=d_model,dim_msa=dim_msa,pairwise_dim=pairwise_dimension) - + self.triangle_attention_out = TriangleAttention( + in_dim=pairwise_dimension, dim=pairwise_dimension // 4, wise="row" + ) + self.triangle_attention_in = TriangleAttention( + in_dim=pairwise_dimension, dim=pairwise_dimension // 4, wise="col" + ) + + self.pair_attention_dropout_out = DropoutRowwise(dropout) + self.pair_attention_dropout_in = DropoutColumnwise(dropout) + + self.outer_product_mean = Outer_Product_Mean( + in_dim=d_model, dim_msa=dim_msa, pairwise_dim=pairwise_dimension + ) # self.sequence_transititon=TransitionLayer(d_model) # self.pair_transition=TransitionLayer(pairwise_dimension) - self.sequence_transititon=nn.Sequential(nn.Linear(d_model,d_model*4), - nn.ReLU(), - nn.Linear(d_model*4,d_model)) + self.sequence_transititon = nn.Sequential( + nn.Linear(d_model, d_model * 4), nn.ReLU(), nn.Linear(d_model * 4, d_model) + ) - self.pair_transition=nn.Sequential( nn.LayerNorm(pairwise_dimension), - nn.Linear(pairwise_dimension,pairwise_dimension*4), - nn.ReLU(), - nn.Linear(pairwise_dimension*4,pairwise_dimension)) + self.pair_transition = nn.Sequential( + nn.LayerNorm(pairwise_dimension), + nn.Linear(pairwise_dimension, pairwise_dimension * 4), + nn.ReLU(), + nn.Linear(pairwise_dimension * 4, pairwise_dimension), + ) - def forward(self,input): + def forward(self, input): - src , pairwise_features, src_mask, return_aw= input + src, pairwise_features, src_mask, return_aw = input # src_mask=None # return_aw=False - use_gradient_checkpoint=False + # use_gradient_checkpoint = False - pairwise_bias=self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0,3,1,2) + pairwise_bias = self.pairwise2heads( + self.pairwise_norm(pairwise_features) + ).permute(0, 3, 1, 2) - #self attention - res=src - src,attention_weights = self.self_attn(src, src, src, mask=pairwise_bias, src_mask=src_mask) - src=res+self.dropout1(src) + # self attention + res = src + src, attention_weights = self.self_attn( + src, src, src, mask=pairwise_bias, src_mask=src_mask + ) + src = res + self.dropout1(src) src = self.norm1(src) - - #sequence transition - res=src - src=self.sequence_transititon(src) + + # sequence transition + res = src + src = self.sequence_transititon(src) src = res + self.dropout2(src) src = self.norm2(src) - #pair track ops - pairwise_features=pairwise_features+self.outer_product_mean(src) - pairwise_features=pairwise_features+self.pair_dropout_out(self.triangle_update_out(pairwise_features,src_mask)) - pairwise_features=pairwise_features+self.pair_dropout_in(self.triangle_update_in(pairwise_features,src_mask)) + # pair track ops + pairwise_features = pairwise_features + self.outer_product_mean(src) + pairwise_features = pairwise_features + self.pair_dropout_out( + self.triangle_update_out(pairwise_features, src_mask) + ) + pairwise_features = pairwise_features + self.pair_dropout_in( + self.triangle_update_in(pairwise_features, src_mask) + ) if self.use_triangular_attention: - pairwise_features=pairwise_features+self.pair_attention_dropout_out(self.triangle_attention_out(pairwise_features,src_mask)) - pairwise_features=pairwise_features+self.pair_attention_dropout_in(self.triangle_attention_in(pairwise_features,src_mask)) - pairwise_features=pairwise_features+self.pair_transition(pairwise_features) + pairwise_features = pairwise_features + self.pair_attention_dropout_out( + self.triangle_attention_out(pairwise_features, src_mask) + ) + pairwise_features = pairwise_features + self.pair_attention_dropout_in( + self.triangle_attention_in(pairwise_features, src_mask) + ) + pairwise_features = pairwise_features + self.pair_transition(pairwise_features) if return_aw: - return src,pairwise_features,attention_weights + return src, pairwise_features, attention_weights else: - return src,pairwise_features + return src, pairwise_features + class PositionalEncoding(nn.Module): @@ -363,14 +413,16 @@ def __init__(self, d_model, dropout=0.1, max_len=200): pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:x.size(0), :] + x = x + self.pe[: x.size(0), :] return self.dropout(x) @@ -378,18 +430,19 @@ class Outer_Product_Mean(nn.Module): def __init__(self, in_dim=256, dim_msa=32, pairwise_dim=64): super(Outer_Product_Mean, self).__init__() self.proj_down1 = nn.Linear(in_dim, dim_msa) - self.proj_down2 = nn.Linear(dim_msa ** 2, pairwise_dim) + self.proj_down2 = nn.Linear(dim_msa**2, pairwise_dim) - def forward(self,seq_rep, pair_rep=None): - seq_rep=self.proj_down1(seq_rep) - outer_product = torch.einsum('bid,bjc -> bijcd', seq_rep, seq_rep) - outer_product = rearrange(outer_product, 'b i j c d -> b i j (c d)') + def forward(self, seq_rep, pair_rep=None): + seq_rep = self.proj_down1(seq_rep) + outer_product = torch.einsum("bid,bjc -> bijcd", seq_rep, seq_rep) + outer_product = rearrange(outer_product, "b i j c d -> b i j (c d)") outer_product = self.proj_down2(outer_product) if pair_rep is not None: - outer_product=outer_product+pair_rep + outer_product = outer_product + pair_rep + + return outer_product - return outer_product class relpos(nn.Module): @@ -398,7 +451,7 @@ def __init__(self, dim=64): self.linear = nn.Linear(33, dim) def forward(self, src): - L=src.shape[1] + L = src.shape[1] res_id = torch.arange(L).to(src.device).unsqueeze(0) device = res_id.device bin_values = torch.arange(-16, 17, device=device) @@ -410,22 +463,19 @@ def forward(self, src): p = self.linear(d_onehot) return p + def exists(val): return val is not None + def default(val, d): return val if exists(val) else d + class TriangleMultiplicativeModule(nn.Module): - def __init__( - self, - *, - dim, - hidden_dim = None, - mix = 'ingoing' - ): + def __init__(self, *, dim, hidden_dim=None, mix="ingoing"): super().__init__() - assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing' + assert mix in {"ingoing", "outgoing"}, "mix must be either ingoing or outgoing" hidden_dim = default(hidden_dim, dim) self.norm = nn.LayerNorm(dim) @@ -440,24 +490,24 @@ def __init__( # initialize all gating to be identity for gate in (self.left_gate, self.right_gate, self.out_gate): - nn.init.constant_(gate.weight, 0.) - nn.init.constant_(gate.bias, 1.) + nn.init.constant_(gate.weight, 0.0) + nn.init.constant_(gate.bias, 1.0) - if mix == 'outgoing': - self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d' - elif mix == 'ingoing': - self.mix_einsum_eq = '... k i d, ... k j d -> ... i j d' + if mix == "outgoing": + self.mix_einsum_eq = "... i k d, ... j k d -> ... i j d" + elif mix == "ingoing": + self.mix_einsum_eq = "... k i d, ... k j d -> ... i j d" self.to_out_norm = nn.LayerNorm(hidden_dim) self.to_out = nn.Linear(hidden_dim, dim) - def forward(self, x, src_mask = None): - src_mask=src_mask.unsqueeze(-1).float() - mask = torch.matmul(src_mask,src_mask.permute(0,2,1)) + def forward(self, x, src_mask=None): + src_mask = src_mask.unsqueeze(-1).float() + mask = torch.matmul(src_mask, src_mask.permute(0, 2, 1)) - assert x.shape[1] == x.shape[2], 'feature map must be symmetrical' + assert x.shape[1] == x.shape[2], "feature map must be symmetrical" if exists(mask): - mask = rearrange(mask, 'b i j -> b i j ()') + mask = rearrange(mask, "b i j -> b i j ()") x = self.norm(x) @@ -484,113 +534,123 @@ def forward(self, x, src_mask = None): class RibonanzaNet(nn.Module): - #def __init__(self, ntoken=5, nclass=1, ninp=512, nhead=8, nlayers=9, kmers=9, dropout=0): + # def __init__(self, ntoken=5, nclass=1, ninp=512, nhead=8, nlayers=9, kmers=9, dropout=0): def __init__(self, config): super(RibonanzaNet, self).__init__() - self.config=config - nhid=config.ninp*4 - self._tied_weights_keys = [] #avoids AttributeError: 'RibonanzaNet' object has no attribute '_tied_weights_keys' + self.config = config + nhid = config.ninp * 4 + self._tied_weights_keys = ( + [] + ) # avoids AttributeError: 'RibonanzaNet' object has no attribute '_tied_weights_keys' self.transformer_encoder = [] - print(f"constructing {config.nlayers} ConvTransformerEncoderLayers") + # print(f"constructing {config.nlayers} ConvTransformerEncoderLayers") for i in range(config.nlayers): - if i!= config.nlayers-1: - k=config.k + if i != config.nlayers - 1: + k = config.k else: - k=1 - self.transformer_encoder.append(ConvTransformerEncoderLayer(d_model = config.ninp, nhead = config.nhead, - dim_feedforward = nhid, - pairwise_dimension= config.pairwise_dimension, - use_triangular_attention=config.use_triangular_attention, - dim_msa=config.dim_msa, - dropout = config.dropout, k=k)) - - self.transformer_encoder= nn.ModuleList(self.transformer_encoder) - - for i,layer in enumerate(self.transformer_encoder): - scale_factor=1/(i+1)**0.5 - #scale_factor=i+1 - #scale_factor=0 - recursive_linear_init(layer,scale_factor) - + k = 1 + self.transformer_encoder.append( + ConvTransformerEncoderLayer( + d_model=config.ninp, + nhead=config.nhead, + dim_feedforward=nhid, + pairwise_dimension=config.pairwise_dimension, + use_triangular_attention=config.use_triangular_attention, + dim_msa=config.dim_msa, + dropout=config.dropout, + k=k, + ) + ) + + self.transformer_encoder = nn.ModuleList(self.transformer_encoder) + + for i, layer in enumerate(self.transformer_encoder): + scale_factor = 1 / (i + 1) ** 0.5 + # scale_factor=i+1 + # scale_factor=0 + recursive_linear_init(layer, scale_factor) + self.encoder = nn.Embedding(config.ntoken, config.ninp, padding_idx=4) - self.decoder = nn.Linear(config.ninp,config.nclass) + self.decoder = nn.Linear(config.ninp, config.nclass) - recursive_linear_init(self.decoder,scale_factor) + recursive_linear_init(self.decoder, scale_factor) - self.outer_product_mean=Outer_Product_Mean(in_dim=config.ninp,dim_msa=config.dim_msa,pairwise_dim=config.pairwise_dimension) - self.pos_encoder=relpos(config.pairwise_dimension) - self.use_gradient_checkpoint=False + self.outer_product_mean = Outer_Product_Mean( + in_dim=config.ninp, + dim_msa=config.dim_msa, + pairwise_dim=config.pairwise_dimension, + ) + self.pos_encoder = relpos(config.pairwise_dimension) + self.use_gradient_checkpoint = False def custom(self, module): def custom_forward(*inputs): inputs = module(inputs[0]) return inputs + return custom_forward - def forward(self, src,src_mask=None,return_aw=False): - B,L=src.shape + def forward(self, src, src_mask=None, return_aw=False): + B, L = src.shape src = src - src = self.encoder(src).reshape(B,L,-1) - - pairwise_features=self.outer_product_mean(src) - pairwise_features=pairwise_features+self.pos_encoder(src) + src = self.encoder(src).reshape(B, L, -1) - attention_weights=[] - for i,layer in enumerate(self.transformer_encoder): - src,pairwise_features=layer([src, pairwise_features, src_mask, return_aw]) + pairwise_features = self.outer_product_mean(src) + pairwise_features = pairwise_features + self.pos_encoder(src) - output = self.decoder(src).squeeze(-1)+pairwise_features.mean()*0 + attention_weights = [] + for i, layer in enumerate(self.transformer_encoder): + src, pairwise_features = layer( + [src, pairwise_features, src_mask, return_aw] + ) + output = self.decoder(src).squeeze(-1) + pairwise_features.mean() * 0 if return_aw: return output, attention_weights else: return output - def get_embeddings(self, src,src_mask=None,return_aw=False): - B,L=src.shape + def get_embeddings(self, src, src_mask=None, return_aw=False): + B, L = src.shape src = src - src = self.encoder(src).reshape(B,L,-1) - + src = self.encoder(src).reshape(B, L, -1) if self.use_gradient_checkpoint: - pairwise_features=checkpoint.checkpoint(self.custom(self.outer_product_mean), src) - pairwise_features=pairwise_features+self.pos_encoder(src) + pairwise_features = checkpoint.checkpoint( + self.custom(self.outer_product_mean), src + ) + pairwise_features = pairwise_features + self.pos_encoder(src) else: - pairwise_features=self.outer_product_mean(src) - pairwise_features=pairwise_features+self.pos_encoder(src) - - - all_sequence_features=[] - all_pairwise_features=[] - for i,layer in enumerate(self.transformer_encoder): - src,pairwise_features=checkpoint.checkpoint(self.custom(layer), - [src, pairwise_features, src_mask, return_aw], - use_reentrant=False) + pairwise_features = self.outer_product_mean(src) + pairwise_features = pairwise_features + self.pos_encoder(src) + + all_sequence_features = [] + all_pairwise_features = [] + for i, layer in enumerate(self.transformer_encoder): + src, pairwise_features = checkpoint.checkpoint( + self.custom(layer), + [src, pairwise_features, src_mask, return_aw], + use_reentrant=False, + ) all_sequence_features.append(src) all_pairwise_features.append(pairwise_features) - all_sequence_features = torch.stack(all_sequence_features,0) - all_pairwise_features = torch.stack(all_pairwise_features,0) + all_sequence_features = torch.stack(all_sequence_features, 0) + all_pairwise_features = torch.stack(all_pairwise_features, 0) return all_sequence_features, all_pairwise_features - - - class TriangleAttention(nn.Module): - def __init__(self, in_dim=128, dim=32, n_heads=4, wise='row'): + def __init__(self, in_dim=128, dim=32, n_heads=4, wise="row"): super(TriangleAttention, self).__init__() self.n_heads = n_heads self.wise = wise self.norm = nn.LayerNorm(in_dim) self.to_qkv = nn.Linear(in_dim, dim * 3 * n_heads, bias=False) self.linear_for_pair = nn.Linear(in_dim, n_heads, bias=False) - self.to_gate = nn.Sequential( - nn.Linear(in_dim, in_dim), - nn.Sigmoid() - ) + self.to_gate = nn.Sequential(nn.Linear(in_dim, in_dim), nn.Sigmoid()) self.to_out = nn.Linear(n_heads * dim, in_dim) # self.to_out.weight.data.fill_(0.) # self.to_out.bias.data.fill_(0.) @@ -609,50 +669,51 @@ def forward(self, z, src_mask): take src_mask and spawn pairwise mask, and unsqueeze accordingly """ - #spwan pair mask - src_mask[src_mask==0]=-1 - src_mask=src_mask.unsqueeze(-1).float() - attn_mask=torch.matmul(src_mask,src_mask.permute(0,2,1)) - + # spwan pair mask + src_mask[src_mask == 0] = -1 + src_mask = src_mask.unsqueeze(-1).float() + attn_mask = torch.matmul(src_mask, src_mask.permute(0, 2, 1)) wise = self.wise z = self.norm(z) q, k, v = torch.chunk(self.to_qkv(z), 3, -1) - q, k, v = map(lambda x: rearrange(x, 'b i j (h d)->b i j h d', h=self.n_heads), (q, k, v)) + q, k, v = map( + lambda x: rearrange(x, "b i j (h d)->b i j h d", h=self.n_heads), (q, k, v) + ) b = self.linear_for_pair(z) gate = self.to_gate(z) - scale = q.size(-1) ** .5 - if wise == 'row': - eq_attn = 'brihd,brjhd->brijh' - eq_multi = 'brijh,brjhd->brihd' - b = rearrange(b, 'b i j (r h)->b r i j h', r=1) + scale = q.size(-1) ** 0.5 + if wise == "row": + eq_attn = "brihd,brjhd->brijh" + eq_multi = "brijh,brjhd->brihd" + b = rearrange(b, "b i j (r h)->b r i j h", r=1) softmax_dim = 3 - attn_mask=rearrange(attn_mask, 'b i j->b 1 i j 1') - elif wise == 'col': - eq_attn = 'bilhd,bjlhd->bijlh' - eq_multi = 'bijlh,bjlhd->bilhd' - b = rearrange(b, 'b i j (l h)->b i j l h', l=1) + attn_mask = rearrange(attn_mask, "b i j->b 1 i j 1") + elif wise == "col": + eq_attn = "bilhd,bjlhd->bijlh" + eq_multi = "bijlh,bjlhd->bilhd" + b = rearrange(b, "b i j (l h)->b i j l h", l=1) softmax_dim = 2 - attn_mask=rearrange(attn_mask, 'b i j->b i j 1 1') + attn_mask = rearrange(attn_mask, "b i j->b i j 1 1") else: - raise ValueError('wise should be col or row!') - logits = (torch.einsum(eq_attn, q, k) / scale + b) - logits = logits.masked_fill(attn_mask == -1, float('-1e-9')) + raise ValueError("wise should be col or row!") + logits = torch.einsum(eq_attn, q, k) / scale + b + logits = logits.masked_fill(attn_mask == -1, float("-1e-9")) attn = logits.softmax(softmax_dim) out = torch.einsum(eq_multi, attn, v) - out = gate * rearrange(out, 'b i j h d-> b i j (h d)') + out = gate * rearrange(out, "b i j h d-> b i j (h d)") z_ = self.to_out(out) return z_ class GatedSequenceFeatureInjector(nn.Module): - def __init__(self, c_s_new: int, c_s: int, gate_type='channel'): + def __init__(self, c_s_new: int, c_s: int, gate_type="channel"): super().__init__() self.proj = nn.Linear(c_s_new, c_s) # project LM feature to match s dim - if gate_type == 'channel': + if gate_type == "channel": self.gate_param = nn.Parameter(torch.zeros(c_s)) # one gate per channel - elif gate_type == 'scalar': + elif gate_type == "scalar": self.gate_param = nn.Parameter(torch.tensor(0.0)) # one global gate else: raise ValueError("gate_type must be 'channel' or 'scalar'") @@ -670,7 +731,7 @@ def forward(self, s: torch.Tensor, new_seq_feature: torch.Tensor) -> torch.Tenso """ new_proj = self.proj(new_seq_feature) # [N_res, C_s] - if self.gate_type == 'channel': + if self.gate_type == "channel": gate = torch.sigmoid(self.gate_param).view(1, -1) # [1,1,C_s] else: gate = torch.sigmoid(self.gate_param) # scalar @@ -679,20 +740,20 @@ def forward(self, s: torch.Tensor, new_seq_feature: torch.Tensor) -> torch.Tenso gated_feature = gate * new_proj # [1, N_res, C_s] return s + gated_feature - + class GatedPairwiseFeatureInjector(nn.Module): - def __init__(self, c_pair: int, c_z: int, gate_type='channel'): + def __init__(self, c_pair: int, c_z: int, gate_type="channel"): super().__init__() self.proj = nn.Linear(c_pair, c_z) # project pair_feature to match z dim - if gate_type == 'channel': + if gate_type == "channel": self.gate_param = nn.Parameter(torch.zeros(c_z)) # one gate per channel - elif gate_type == 'scalar': + elif gate_type == "scalar": self.gate_param = nn.Parameter(torch.tensor(0.0)) # one global gate else: raise ValueError("gate_type must be 'channel' or 'scalar'") - + self.gate_type = gate_type def forward(self, z: torch.Tensor, pair_feature: torch.Tensor) -> torch.Tensor: @@ -706,7 +767,7 @@ def forward(self, z: torch.Tensor, pair_feature: torch.Tensor) -> torch.Tensor: """ pair_proj = self.proj(pair_feature) # [N_res, N_res, C_z] - if self.gate_type == 'channel': + if self.gate_type == "channel": gate = torch.sigmoid(self.gate_param).view(1, 1, -1) # [1,1,C_z] else: # scalar gate gate = torch.sigmoid(self.gate_param) # scalar @@ -716,10 +777,11 @@ def forward(self, z: torch.Tensor, pair_feature: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": - from Functions import * + from Functions import * # noqa + config = load_config_from_yaml("configs/pairwise.yaml") - model=RibonanzaNet(config).cuda() - x=torch.ones(4,128).long().cuda() - mask=torch.ones(4,128).long().cuda() - mask[:,120:]=0 - print(model(x,src_mask=mask).shape) + model = RibonanzaNet(config).cuda() + x = torch.ones(4, 128).long().cuda() + mask = torch.ones(4, 128).long().cuda() + mask[:, 120:] = 0 + print(model(x, src_mask=mask).shape) diff --git a/rnapro/openfold_local/utils/precision_utils.py b/rnapro/openfold_local/utils/precision_utils.py index 43cfb74..ed0970c 100644 --- a/rnapro/openfold_local/utils/precision_utils.py +++ b/rnapro/openfold_local/utils/precision_utils.py @@ -17,7 +17,7 @@ def is_fp16_enabled(): # Autocast world - fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + fp16_enabled = torch.get_autocast_dtype('cuda') == torch.float16 fp16_enabled = fp16_enabled and torch.is_autocast_enabled() return fp16_enabled diff --git a/rnapro/utils/inference.py b/rnapro/utils/inference.py new file mode 100644 index 0000000..b7e1a51 --- /dev/null +++ b/rnapro/utils/inference.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import numpy as np +import pandas as pd +from biotite.structure.io import pdbx + + +class dotdict(dict): + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + return None + + +def create_input_json(sequence, target_id, verbose=False): + if verbose: + print("input_no_msa") + input_json = [ + { + "sequences": [ + { + "rnaSequence": { + "sequence": sequence, + "count": 1, + } + } + ], + "name": target_id, + } + ] + return input_json + + +def extract_c1_coordinates(cif_file_path): + try: + # Read the CIF file using the correct biotite method + with open(cif_file_path, "r") as f: + cif_data = pdbx.CIFFile.read(f) + + # Get structure from CIF data + atom_array = pdbx.get_structure(cif_data, model=1) + + # Clean atom names and find C1' atoms + atom_names_clean = np.char.strip(atom_array.atom_name.astype(str)) + mask_c1 = atom_names_clean == "C1'" + c1_atoms = atom_array[mask_c1] + + if len(c1_atoms) == 0: + print(f"Warning: No C1' atoms found in {cif_file_path}") + return None + + # Sort by residue ID and return coordinates + sort_indices = np.argsort(c1_atoms.res_id) + c1_atoms_sorted = c1_atoms[sort_indices] + c1_coords = c1_atoms_sorted.coord + + return c1_coords + except Exception as e: + print(f"Error extracting C1' coordinates from {cif_file_path}: {e}") + return None + + +def process_sequence(sequence, target_id, temp_dir, verbose=False): + if verbose: + print(f"Processing {target_id}: {sequence}") + + # Create input JSON + input_json = create_input_json(sequence, target_id, verbose=verbose) + + # Save JSON to temporary file + os.makedirs(temp_dir, exist_ok=True) + input_json_path = os.path.join(temp_dir, f"{target_id}_input.json") + with open(input_json_path, "w") as f: + json.dump(input_json, f, indent=4) + + +def solution_to_submit_df(solution): + submit_df = [] + for k, s in solution.items(): + df = coord_to_df(s.sequence, s.coord, s.target_id) + submit_df.append(df) + + submit_df = pd.concat(submit_df) + return submit_df + + +def coord_to_df(sequence, coord, target_id): + L = len(sequence) + df = pd.DataFrame() + df["ID"] = [f"{target_id}_{i + 1}" for i in range(L)] + df["resname"] = [s for s in sequence] + df["resid"] = [i + 1 for i in range(L)] + + num_coord = len(coord) + for j in range(num_coord): + df[f"x_{j+1}"] = coord[j][:, 0] + df[f"y_{j+1}"] = coord[j][:, 1] + df[f"z_{j+1}"] = coord[j][:, 2] + return df + + +def update_inference_configs(configs, N_token): + # Setting the default inference configs for different N_token and N_atom + # when N_token is larger than 3000, the default config might OOM even on a + # A100 80G GPUS, + if N_token > 3840: + configs.skip_amp.confidence_head = False + configs.skip_amp.sample_diffusion = False + elif N_token > 2560: + configs.skip_amp.confidence_head = False + configs.skip_amp.sample_diffusion = True + else: + configs.skip_amp.confidence_head = True + configs.skip_amp.sample_diffusion = True + return configs + + +# data helper +def make_dummy_solution(valid_df): + solution = dotdict() + for i, row in valid_df.iterrows(): + target_id = row.target_id + sequence = row.sequence + solution[target_id] = dotdict( + target_id=target_id, + sequence=sequence, + coord=[], + ) + return solution diff --git a/rnapro_inference_example.sh b/rnapro_inference_example.sh index a3a8dca..21b72fe 100644 --- a/rnapro_inference_example.sh +++ b/rnapro_inference_example.sh @@ -1,8 +1,5 @@ export LAYERNORM_TYPE=torch # fast_layernorm, torch -# Kernel options: -# - triangle_attention: supports 'triattention', 'cuequivariance', 'deepspeed', 'torch' -# - triangle_multiplicative: supports 'cuequivariance'(default), 'torch' # Inference parameters (RNAPro) SEED=42 @@ -12,6 +9,7 @@ N_CYCLE=10 # Paths DUMP_DIR="./output" + # Set a valid checkpoint file path below CHECKPOINT_PATH="./rnapro_base.pt" @@ -20,8 +18,13 @@ TEMPLATE_DATA="./examples/test_templates.pt" # Note: template_idx supports 5 choices and maps to top-k: # 0->top1, 1->top2, 2->top3, 3->top4, 4->top5 TEMPLATE_IDX=0 + +# MSA directory RNA_MSA_DIR="./msa" + +# Sequences to process SEQUENCES_CSV="./examples/test_sequences.csv" + # RibonanzaNet2 path (keep as-is per request) RIBONANZA_PATH="./release_data/ribonanzanet2_checkpoint" @@ -30,23 +33,32 @@ MODEL_NAME="rnapro_base" mkdir -p "${DUMP_DIR}" python3 runner/inference.py \ ---model_name "${MODEL_NAME}" \ ---seeds ${SEED} \ ---dump_dir "${DUMP_DIR}" \ ---load_checkpoint_path "${CHECKPOINT_PATH}" \ ---use_msa true \ ---use_template "ca_precomputed" \ ---model.use_template "ca_precomputed" \ ---model.use_RibonanzaNet2 true \ ---model.template_embedder.n_blocks 2 \ ---model.ribonanza_net_path "${RIBONANZA_PATH}" \ ---template_data "${TEMPLATE_DATA}" \ ---template_idx ${TEMPLATE_IDX} \ ---rna_msa_dir "${RNA_MSA_DIR}" \ ---model.N_cycle ${N_CYCLE} \ ---sample_diffusion.N_sample ${N_SAMPLE} \ ---sample_diffusion.N_step ${N_STEP} \ ---load_strict true \ ---num_workers 0 \ ---triangle_attention "cuequivariance" \ ---triangle_multiplicative "cuequivariance" --sequences_csv "${SEQUENCES_CSV}" \ No newline at end of file + --model_name "${MODEL_NAME}" \ + --seeds ${SEED} \ + --dump_dir "${DUMP_DIR}" \ + --load_checkpoint_path "${CHECKPOINT_PATH}" \ + --use_msa true \ + --use_template "ca_precomputed" \ + --model.use_template "ca_precomputed" \ + --model.use_RibonanzaNet2 true \ + --model.template_embedder.n_blocks 2 \ + --model.ribonanza_net_path "${RIBONANZA_PATH}" \ + --template_data "${TEMPLATE_DATA}" \ + --template_idx ${TEMPLATE_IDX} \ + --rna_msa_dir "${RNA_MSA_DIR}" \ + --model.N_cycle ${N_CYCLE} \ + --sample_diffusion.N_sample ${N_SAMPLE} \ + --sample_diffusion.N_step ${N_STEP} \ + --load_strict true \ + --num_workers 0 \ + --triangle_attention "cuequivariance" \ + --triangle_multiplicative "cuequivariance" \ + --sequences_csv "${SEQUENCES_CSV}" \ + --max_len 5000 \ + --logger "logging" + +# Notes: +# --triangle_attention supports 'triattention', 'cuequivariance', 'deepspeed', 'torch' +# --triangle_multiplicative supports 'cuequivariance', 'torch' +# --max_len 1000: Sequences longer than max_len will be skipped to avoid oom +# --logger handles logging of the inference runner, supports "logging", "print" \ No newline at end of file diff --git a/runner/inference.py b/runner/inference.py index c1ac965..c150b6d 100644 --- a/runner/inference.py +++ b/runner/inference.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import os import shutil import logging import traceback @@ -21,12 +21,9 @@ from os.path import join as opjoin from typing import Any, Mapping -import json import torch import pandas as pd import numpy as np -from tqdm import tqdm -from biotite.structure.io import pdbx from configs.configs_base import configs as configs_base from configs.configs_data import data_configs @@ -40,24 +37,22 @@ from rnapro.utils.seed import seed_everything from rnapro.utils.torch_utils import to_device - -class dotdict(dict): - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def __getattr__(self, name): - try: - return self[name] - except KeyError: - raise AttributeError(name) - - -logger = logging.getLogger(__name__) +from rnapro.utils.inference import ( + update_inference_configs, + make_dummy_solution, + solution_to_submit_df, + process_sequence, + extract_c1_coordinates, +) class InferenceRunner(object): def __init__(self, configs: Any) -> None: + if configs.logger == "logging": + self.logger = logging.getLogger(__name__) + self.configs = configs + self.init_env() self.init_basics() self.init_model() @@ -78,7 +73,7 @@ def init_env(self) -> None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count())) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) - logging.info( + self.print( f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]" ) torch.cuda.set_device(self.device) @@ -91,16 +86,16 @@ def init_env(self) -> None: env is not None ), "if use ds4sci, set `CUTLASS_PATH` env as https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/" if env is not None: - logging.info( + self.print( "The kernels will be compiled when DS4Sci_EvoformerAttention is called for the first time." ) use_fastlayernorm = os.getenv("LAYERNORM_TYPE", None) if use_fastlayernorm == "fast_layernorm": - logging.info( + self.print( "The kernels will be compiled when fast_layernorm is called for the first time." ) - logging.info("Finished init ENV.") + self.print("Finished initializing environment !") def init_basics(self) -> None: self.dump_dir = self.configs.dump_dir @@ -109,17 +104,13 @@ def init_basics(self) -> None: os.makedirs(self.error_dir, exist_ok=True) def init_model(self) -> None: - self.model = RNAPro(self.configs).to(self.device) - print(self.model) num_params = sum(p.numel() for p in self.model.parameters()) - print(f"Total number of parameters: {num_params:,}") - + self.print(f"Built model ! Total number of parameters: {num_params:,}") def load_checkpoint(self) -> None: checkpoint_path = self.configs.load_checkpoint_path - print(checkpoint_path) - + if not os.path.exists(checkpoint_path): raise Exception(f"Given checkpoint path not exist [{checkpoint_path}]") self.print( @@ -128,7 +119,7 @@ def load_checkpoint(self) -> None: checkpoint = torch.load(checkpoint_path, self.device) sample_key = [k for k in checkpoint["model"].keys()][0] - self.print(f"Sampled key: {sample_key}") + # self.print(f"Sampled key: {sample_key}") if sample_key.startswith("module."): # DDP checkpoint has module. prefix checkpoint["model"] = { k[len("module."):]: v for k, v in checkpoint["model"].items() @@ -138,7 +129,7 @@ def load_checkpoint(self) -> None: strict=True, ) self.model.eval() - self.print(f"Finish loading checkpoint.") + # self.print("Finish loading checkpoint.") def init_dumper( self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True @@ -153,6 +144,9 @@ def print_dict(self, d): for k, v in d.items(): if isinstance(v, torch.Tensor): print(f"{k}: ", v.shape) + else: + pass + # print(f"{k}: {v}") # Adapted from runner.train.Trainer.evaluate @torch.no_grad() @@ -162,12 +156,14 @@ def predict(self, data: Mapping[str, Mapping[str, Any]]) -> dict[str, torch.Tens "bf16": torch.bfloat16, "fp16": torch.float16, }[self.configs.dtype] - print('eval_precision: ', eval_precision) + # print("eval_precision: ", eval_precision) enable_amp = ( torch.autocast(device_type="cuda", dtype=eval_precision) if torch.cuda.is_available() else nullcontext() ) + # print('input_feature_dict: ', self.print_dict(data["input_feature_dict"])) + # exit(0) data = to_device(data, self.device) with enable_amp: @@ -182,41 +178,34 @@ def predict(self, data: Mapping[str, Mapping[str, Any]]) -> dict[str, torch.Tens def print(self, msg: str): if DIST_WRAPPER.rank == 0: - logger.info(msg) + if self.configs.logger == "logging": + self.logger.info(msg) + elif self.configs.logger == "print": + print(msg) def update_model_configs(self, new_configs: Any) -> None: self.model.configs = new_configs -def update_inference_configs(configs: Any, N_token: int): - # Setting the default inference configs for different N_token and N_atom - # when N_token is larger than 3000, the default config might OOM even on a - # A100 80G GPUS, - if N_token > 3840: - configs.skip_amp.confidence_head = False - configs.skip_amp.sample_diffusion = False - elif N_token > 2560: - configs.skip_amp.confidence_head = False - configs.skip_amp.sample_diffusion = True - else: - configs.skip_amp.confidence_head = True - configs.skip_amp.sample_diffusion = True - return configs - - def infer_predict(runner: InferenceRunner, configs: Any) -> None: - # Data - logger.info(f"Loading data from\n{configs.input_json_path}") + """ + Infer the sequence for a given runner and configs. + + Args: + runner (InferenceRunner): The runner to be used. + configs (ConfigDict): The configurations for the inference. + """ + # Load the dataloader try: dataloader = get_inference_dataloader(configs=configs) except Exception as e: error_message = f"{e}:\n{traceback.format_exc()}" - logger.info(error_message) + runner.print(error_message) with open(opjoin(runner.error_dir, "error.txt"), "a") as f: f.write(error_message) return - num_data = len(dataloader.dataset) + # num_data = len(dataloader.dataset) for seed in configs.seeds: seed_everything(seed=seed, deterministic=configs.deterministic) for batch in dataloader: @@ -225,21 +214,25 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None: sample_name = data["sample_name"] if len(data_error_message) > 0: - logger.info(data_error_message) + runner.print(data_error_message) with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(data_error_message) continue - logger.info( - ( - f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: " - f"N_asym {data['N_asym'].item()}, N_token {data['N_token'].item()}, " - f"N_atom {data['N_atom'].item()}, N_msa {data['N_msa'].item()}" - ) - ) + # runner.print( + # ( + # f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: " + # f"N_asym {data['N_asym'].item()}, N_token {data['N_token'].item()}, " + # f"N_atom {data['N_atom'].item()}, N_msa {data['N_msa'].item()}" + # ) + # ) new_configs = update_inference_configs(configs, data["N_token"].item()) runner.update_model_configs(new_configs) + + # Predict prediction = runner.predict(data) + + # Save runner.dumper.dump( dataset_name="", pdb_id=sample_name, @@ -249,137 +242,51 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None: entity_poly_type=data["entity_poly_type"], ) - logger.info( - f"[Rank {DIST_WRAPPER.rank}] {data['sample_name']} succeeded.\n" - f"Results saved to {configs.dump_dir}" - ) + runner.print(f"Results saved to {configs.dump_dir}/{sample_name}/seed_{seed}/predictions/") torch.cuda.empty_cache() except Exception as e: error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}" - logger.info(error_message) + runner.print(error_message) # Save error info with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(error_message) if hasattr(torch.cuda, "empty_cache"): torch.cuda.empty_cache() -# data helper -def make_dummy_solution(valid_df): - solution=dotdict() - for i, row in valid_df.iterrows(): - target_id = row.target_id - sequence = row.sequence - solution[target_id]=dotdict( - target_id=target_id, - sequence=sequence, - coord=[], - ) - return solution - -def solution_to_submit_df(solution): - submit_df = [] - for k,s in solution.items(): - df = coord_to_df(s.sequence, s.coord, s.target_id) - submit_df.append(df) - - submit_df = pd.concat(submit_df) - return submit_df - - -def coord_to_df(sequence, coord, target_id): - L = len(sequence) - df = pd.DataFrame() - df['ID'] = [f'{target_id}_{i + 1}' for i in range(L)] - df['resname'] = [s for s in sequence] - df['resid'] = [i + 1 for i in range(L)] - - num_coord = len(coord) - for j in range(num_coord): - df[f'x_{j+1}'] = coord[j][:, 0] - df[f'y_{j+1}'] = coord[j][:, 1] - df[f'z_{j+1}'] = coord[j][:, 2] - return df - - -def main(configs: Any) -> None: - # Runner - runner = InferenceRunner(configs) - infer_predict(runner, configs) - -def create_input_json(sequence, target_id): - print('input_no_msa') - input_json = [{ - "sequences": [ - { - "rnaSequence": { - "sequence": sequence, - "count": 1, - } - } - ], - "name": target_id, - }] - return input_json +def run_ptx(target_id, sequence, configs, solution, template_idx, runner): + """ + Run the inference for a given target_id, sequence, configs, solution, and template_idx. -def extract_c1_coordinates(cif_file_path): - try: - # Read the CIF file using the correct biotite method - with open(cif_file_path, 'r') as f: - cif_data = pdbx.CIFFile.read(f) - - # Get structure from CIF data - atom_array = pdbx.get_structure(cif_data, model=1) - - # Clean atom names and find C1' atoms - atom_names_clean = np.char.strip(atom_array.atom_name.astype(str)) - mask_c1 = atom_names_clean == "C1'" - c1_atoms = atom_array[mask_c1] - - if len(c1_atoms) == 0: - print(f"Warning: No C1' atoms found in {cif_file_path}") - return None - - # Sort by residue ID and return coordinates - sort_indices = np.argsort(c1_atoms.res_id) - c1_atoms_sorted = c1_atoms[sort_indices] - c1_coords = c1_atoms_sorted.coord - - return c1_coords - except Exception as e: - print(f"Error extracting C1' coordinates from {cif_file_path}: {e}") - return None - -def process_sequence(sequence, target_id, temp_dir): - print(f"Processing {target_id}: {sequence}") - - # Create input JSON - input_json = create_input_json(sequence, target_id) - - # Save JSON to temporary file - os.makedirs(temp_dir, exist_ok=True) - input_json_path = os.path.join(temp_dir, f"{target_id}_input.json") - with open(input_json_path, "w") as f: - json.dump(input_json, f, indent=4) - - -def run_ptx(target_id, sequence, configs, solution, template_idx): + Args: + target_id (str): The target_id of the sequence. + sequence (str): The sequence to be inferred. + configs (ConfigDict): The configurations for the inference. + solution (DotDict): The solution to be updated. + template_idx (int): The template index to be used. + runner (InferenceRunner): The runner to be used. + """ # Create directories temp_dir = f"./{configs.dump_dir}/input" # Same as in kaggle_inference.py output_dir = f"./{configs.dump_dir}/output" # Same as in kaggle_inference.py os.makedirs(temp_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True) - + process_sequence(sequence=sequence, target_id=target_id, temp_dir=temp_dir) configs.input_json_path = os.path.join(temp_dir, f"{target_id}_input.json") configs.template_idx = int(template_idx) - runner = InferenceRunner(configs) + # Run the inference infer_predict(runner, configs) - cif_file_path = f'{configs.dump_dir}/{target_id}/seed_42/predictions/{target_id}_sample_0.cif' - cif_new_path = f'{configs.dump_dir}/{target_id}/seed_42/predictions/{target_id}_sample_{template_idx}_new.cif' + # Copy the CIF file to the new path + cif_file_path = ( + f"{configs.dump_dir}/{target_id}/seed_42/predictions/{target_id}_sample_0.cif" + ) + cif_new_path = f"{configs.dump_dir}/{target_id}/seed_42/predictions/{target_id}_sample_{template_idx}_new.cif" shutil.copy(cif_file_path, cif_new_path) + + # Extract the C1 coordinates coord = extract_c1_coordinates(cif_file_path) if coord is None: coord = np.zeros((len(sequence), 3), dtype=np.float32) @@ -387,6 +294,8 @@ def run_ptx(target_id, sequence, configs, solution, template_idx): pad_len = len(sequence) - coord.shape[0] pad = np.zeros((pad_len, 3), dtype=np.float32) coord = np.concatenate([coord, pad], axis=0) + + # Update the solution solution[target_id].coord.append(coord) @@ -394,10 +303,13 @@ def run() -> None: LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s" logging.basicConfig( format=LOG_FORMAT, - level=logging.INFO, + level=logging.WARNING, datefmt="%Y-%m-%d %H:%M:%S", filemode="w", ) + # Silence dataloader logging + logging.getLogger("rnapro.data").setLevel(logging.WARNING) + configs_base["use_deepspeed_evo_attention"] = ( os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true" ) @@ -409,23 +321,47 @@ def run() -> None: ) valid_df = pd.read_csv(configs.sequences_csv) - print(f"Loaded {len(valid_df)} valid sequences") + print(f"\n -> Loaded {len(valid_df)} sequence(s)") + + # Build model and load checkpoint once before looping over sequences + + print('\n -> Building model and loading checkpoint\n') + runner = InferenceRunner(configs) + print('\n -> Done, starting inference...') solution = make_dummy_solution(valid_df) - for idx, row in tqdm(valid_df.iterrows()): + for idx, row in valid_df.iterrows(): + print(f"\n -> Sequence {row.target_id}: {row.sequence}") + + if len(row.sequence) > configs.max_len: + print(f'Sequence is too long ({len(row.sequence)} > {configs.max_len}), skipping') + for template_idx in range(5): + coord = np.zeros((len(row.sequence), 3), dtype=np.float32) + solution[row.target_id].coord.append(coord) + continue + try: target_id = row.target_id sequence = row.sequence - for template_idx in tqdm(range(5)): - run_ptx(target_id=target_id, sequence=sequence, configs=configs, solution=solution, - template_idx=template_idx) + for template_idx in range(5): + print() + run_ptx( + target_id=target_id, + sequence=sequence, + configs=configs, + solution=solution, + template_idx=template_idx, + runner=runner, + ) except Exception as e: print(f"Error processing {row.target_id}: {e}") continue + + print('\n\n -> Inference done ! Saving to submission.csv') submit_df = solution_to_submit_df(solution) submit_df = submit_df.fillna(0.0) submit_df.to_csv("./submission.csv", index=False) - print(submit_df) + if __name__ == "__main__": run() From 4cf03cda29fe085f4eae968ebd7c7e950b91bdf1 Mon Sep 17 00:00:00 2001 From: Theo Viel Date: Fri, 16 Jan 2026 06:08:37 -0800 Subject: [PATCH 2/3] Templates arg, verbose, formatting train --- rnapro/config/config.py | 18 +++++---- runner/inference.py | 6 +-- runner/train.py | 86 +++++++++++++++++++++-------------------- 3 files changed, 58 insertions(+), 52 deletions(-) diff --git a/rnapro/config/config.py b/rnapro/config/config.py index 55bfcec..841f7e2 100644 --- a/rnapro/config/config.py +++ b/rnapro/config/config.py @@ -246,8 +246,6 @@ def parse_configs( required=False, help="Maximum length of the sequence. Longer sequences will be skipped during inference" ) - - # This is new parser.add_argument( "--logger", type=str, @@ -255,6 +253,13 @@ def parse_configs( required=False, help="Logger to use during inference. Supports 'logging' and 'print'" ) + parser.add_argument( + "--num_templates", + type=int, + default=5, + required=False, + help="Number of templates to use during inference" + ) # Register arguments for key, ( @@ -272,11 +277,10 @@ def parse_configs( vars(parser.parse_args(arg_str.split())) if arg_str else {} ) - max_len = parser.parse_args(arg_str.split()).max_len - merged_configs.max_len = max_len - - logger = parser.parse_args(arg_str.split()).logger - merged_configs.logger = logger + args = parser.parse_args(arg_str.split()) + merged_configs.max_len = args.max_len + merged_configs.logger = args.logger + merged_configs.num_templates = args.num_templates return merged_configs diff --git a/runner/inference.py b/runner/inference.py index c150b6d..e00950c 100644 --- a/runner/inference.py +++ b/runner/inference.py @@ -114,7 +114,7 @@ def load_checkpoint(self) -> None: if not os.path.exists(checkpoint_path): raise Exception(f"Given checkpoint path not exist [{checkpoint_path}]") self.print( - f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}" + f"Loading weights from {checkpoint_path} (strict={self.configs.load_strict})" ) checkpoint = torch.load(checkpoint_path, self.device) @@ -331,7 +331,7 @@ def run() -> None: solution = make_dummy_solution(valid_df) for idx, row in valid_df.iterrows(): - print(f"\n -> Sequence {row.target_id}: {row.sequence}") + print(f"\n -> Sequence {idx + 1}/{len(valid_df)} : {row.target_id} - {row.sequence}") if len(row.sequence) > configs.max_len: print(f'Sequence is too long ({len(row.sequence)} > {configs.max_len}), skipping') @@ -343,7 +343,7 @@ def run() -> None: try: target_id = row.target_id sequence = row.sequence - for template_idx in range(5): + for template_idx in range(configs.num_templates): print() run_ptx( target_id=target_id, diff --git a/runner/train.py b/runner/train.py index 94a4285..c0632a6 100644 --- a/runner/train.py +++ b/runner/train.py @@ -33,7 +33,6 @@ import os import glob import re -import time from argparse import Namespace from contextlib import nullcontext @@ -89,7 +88,7 @@ def init_basics(self): # Add for grad accumulation, it can increase real batch size self.iters_to_accumulate = self.configs.iters_to_accumulate - self.run_name = self.configs.run_name #+ "_" + time.strftime("%Y%m%d_%H%M%S") + self.run_name = self.configs.run_name # + "_" + time.strftime("%Y%m%d_%H%M%S") run_names = DIST_WRAPPER.all_gather_object( self.run_name if DIST_WRAPPER.rank == 0 else None ) @@ -183,7 +182,7 @@ def init_model(self): self.raw_model = RNAPro(self.configs).to(self.device) self.use_ddp = False if DIST_WRAPPER.world_size > 1: - self.print(f"Using DDP") + self.print("Using DDP") self.use_ddp = True # Fix DDP/checkpoint https://discuss.pytorch.org/t/ddp-and-gradient-checkpointing/132244 self.model = DDP( @@ -255,8 +254,9 @@ def save_checkpoint(self, ema_suffix=""): torch.save(checkpoint, path) self.print(f"Saved checkpoint to {path}") - - def find_checkpoint_pairs_in_directory(self, checkpoint_dir: str) -> list[tuple[int, str, str]]: + def find_checkpoint_pairs_in_directory( + self, checkpoint_dir: str + ) -> list[tuple[int, str, str]]: """ Find (step.pt, step_ema_0.995.pt) pairs in the given directory. Returns a list of tuples: (step, checkpoint_path, ema_checkpoint_path) @@ -265,51 +265,51 @@ def find_checkpoint_pairs_in_directory(self, checkpoint_dir: str) -> list[tuple[ if not os.path.exists(checkpoint_dir): self.print(f"Checkpoint directory not found: {checkpoint_dir}") return [] - + # Find all .pt files pt_files = glob.glob(os.path.join(checkpoint_dir, "*.pt")) if not pt_files: self.print(f"No .pt files found in {checkpoint_dir}") return [] - + # Parse step files and ema files step_files = {} # step -> path - ema_files = {} # step -> path - + ema_files = {} # step -> path + for file_path in pt_files: filename = os.path.basename(file_path) - + # Check for EMA files: {step}_ema_0.995.pt - ema_match = re.match(r'(\d+)_ema_0\.995\.pt$', filename) + ema_match = re.match(r"(\d+)_ema_0\.995\.pt$", filename) if ema_match: step = int(ema_match.group(1)) ema_files[step] = file_path continue - + # Check for regular checkpoint files: {step}.pt - step_match = re.match(r'(\d+)\.pt$', filename) + step_match = re.match(r"(\d+)\.pt$", filename) if step_match: step = int(step_match.group(1)) step_files[step] = file_path - + # Find steps that have both checkpoint and EMA files valid_pairs = [] for step in step_files: if step in ema_files: valid_pairs.append((step, step_files[step], ema_files[step])) - + # Sort by step number in descending order (newest first) valid_pairs.sort(key=lambda x: x[0], reverse=True) - - self.print(f"Found {len(valid_pairs)} valid checkpoint pairs in {checkpoint_dir}") - for step, checkpoint_path, ema_path in valid_pairs: - self.print(f" Step {step}: {os.path.basename(checkpoint_path)} + {os.path.basename(ema_path)}") - - return valid_pairs - - + self.print( + f"Found {len(valid_pairs)} valid checkpoint pairs in {checkpoint_dir}" + ) + for step, checkpoint_path, ema_path in valid_pairs: + self.print( + f" Step {step}: {os.path.basename(checkpoint_path)} + {os.path.basename(ema_path)}" + ) + return valid_pairs def try_load_checkpoint(self): @@ -324,7 +324,7 @@ def _load_checkpoint( if not os.path.exists(checkpoint_path): raise Exception(f"Given checkpoint path not exist [{checkpoint_path}]") self.print( - f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}" + f"Loading weights from {checkpoint_path} (strict={self.configs.load_strict})" ) checkpoint = torch.load(checkpoint_path, self.device) sample_key = [k for k in checkpoint["model"].keys()][0] @@ -332,25 +332,25 @@ def _load_checkpoint( if sample_key.startswith("module.") and not self.use_ddp: # DDP checkpoint has module. prefix checkpoint["model"] = { - k[len("module.") :]: v for k, v in checkpoint["model"].items() + k[len("module."):]: v for k, v in checkpoint["model"].items() } self.model.load_state_dict( state_dict=checkpoint["model"], strict=self.configs.load_strict, ) - print('#'*20, checkpoint_path, 'loaded') + print("#" * 20, checkpoint_path, "loaded") if not load_params_only: if not skip_load_optimizer: - self.print(f"Loading optimizer state") + self.print("Loading optimizer state") self.optimizer.load_state_dict(checkpoint["optimizer"]) if not skip_load_step: - self.print(f"Loading checkpoint step") + self.print("Loading checkpoint step") self.step = checkpoint["step"] + 1 self.start_step = self.step self.global_step = self.step * self.iters_to_accumulate if not skip_load_scheduler: - self.print(f"Loading scheduler state") + self.print("Loading scheduler state") self.lr_scheduler.load_state_dict(checkpoint["scheduler"]) elif load_step_for_scheduler: assert ( @@ -361,8 +361,10 @@ def _load_checkpoint( self.print(f"Finish loading checkpoint, current step: {self.step}") - if os.path.isfile(self.configs.load_checkpoint_path) and os.path.isfile(self.configs.load_ema_checkpoint_path): - print('#'*20, 'checkpoints', self.configs.load_checkpoint_path) + if os.path.isfile(self.configs.load_checkpoint_path) and os.path.isfile( + self.configs.load_ema_checkpoint_path + ): + print("#" * 20, "checkpoints", self.configs.load_checkpoint_path) # File path is directly given checkpoint_path = self.configs.load_checkpoint_path ema_checkpoint_path = self.configs.load_ema_checkpoint_path @@ -386,13 +388,14 @@ def _load_checkpoint( print(f"Loaded checkpoint: {checkpoint_path}") print(f"Loaded ema checkpoint: {ema_checkpoint_path}") - elif os.path.isdir(self.configs.load_checkpoint_path): # Directory is given - find highest step from all subdirectories - valid_pairs = self.find_checkpoint_pairs_in_directory(self.configs.load_checkpoint_path) + valid_pairs = self.find_checkpoint_pairs_in_directory( + self.configs.load_checkpoint_path + ) for step, checkpoint_path, ema_checkpoint_path in valid_pairs: - print('#'*20, step, checkpoint_path, ema_checkpoint_path) + print("#" * 20, step, checkpoint_path, ema_checkpoint_path) try: # Load EMA model parameters if ema_checkpoint_path: @@ -445,16 +448,16 @@ def print(self, msg: str): def model_forward(self, batch: dict, mode: str = "train") -> tuple[dict, dict]: assert mode in ["train", "eval"] batch["label_full_dict"] = { - 'entity_mol_id': batch["input_feature_dict"]["entity_mol_id"], - 'mol_id': batch["input_feature_dict"]["mol_id"], - 'mol_atom_index': batch["input_feature_dict"]["mol_atom_index"], + "entity_mol_id": batch["input_feature_dict"]["entity_mol_id"], + "mol_id": batch["input_feature_dict"]["mol_id"], + "mol_atom_index": batch["input_feature_dict"]["mol_atom_index"], } batch["label_dict"] = { "coordinate": batch["coordinate"], "coordinate_mask": batch["coordinate_mask"], } - if 'coordinate_multi' in batch.keys(): - batch["label_dict"]['coordinate_multi'] = batch["coordinate_multi"] + if "coordinate_multi" in batch.keys(): + batch["label_dict"]["coordinate_multi"] = batch["coordinate_multi"] batch["label_full_dict"].update(batch["label_dict"]) @@ -469,7 +472,6 @@ def model_forward(self, batch: dict, mode: str = "train") -> tuple[dict, dict]: return batch, log_dict - def get_loss( self, batch: dict, mode: str = "train" ) -> tuple[torch.Tensor, dict, dict]: @@ -534,7 +536,7 @@ def _evaluate(self, ema_suffix: str = "", mode: str = "eval"): total_batch_num = len(test_dl) for index, batch in enumerate(tqdm(test_dl)): if isinstance(batch, list): - print('len batch: ', len(batch)) + print("len batch: ", len(batch)) batch = batch[0] batch = to_device(batch, self.device) @@ -694,7 +696,7 @@ def run(self): step_need_save &= is_update_step if isinstance(batch, list): - print('len batch: ', len(batch)) + print("len batch: ", len(batch)) batch = batch[0] batch = to_device(batch, self.device) From 996b5cee7c12ca295019866dcad82dfa4ecdc11e Mon Sep 17 00:00:00 2001 From: Theo Viel Date: Fri, 16 Jan 2026 06:47:54 -0800 Subject: [PATCH 3/3] n_templates_inf argument --- README.md | 4 ++-- rnapro/config/config.py | 4 ++-- rnapro_inference_example.sh | 6 ++++-- runner/inference.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 215faad..0a41ef3 100644 --- a/README.md +++ b/README.md @@ -195,8 +195,7 @@ The script configures and forwards the following parameters to the CLI: - `--rna_msa_dir`: Directory containing precomputed MSAs. - `--use_template`: Template mode (use `ca_precomputed` for prepared templates). - `--template_data`: Path to `.pt` template file converted from submission.csv. -- `--template_idx`: Top-k template selection index: - - 0 -> top1, 1 -> top2, 2 -> top3, 3 -> top4, 4 -> top5 +- `--template_idx`: Top-k template selection index: 0 -> top1, 1 -> top2, 2 -> top3, 3 -> top4, 4 -> top5 - `--num_templates`: Number of templates to use (e.g., `10`). - `--model.N_cycle`: Diffusion cycles (e.g., `10`). - `--sample_diffusion.N_sample`: Number of samples per seed (e.g., `1`). @@ -207,6 +206,7 @@ The script configures and forwards the following parameters to the CLI: - `--sequences_csv`: Optional CSV with headers `sequence,target_id` for batched inference. - `--max_len`: Maximum length of the sequence. Longer sequences will be skipped during inference (default: `10000`). - `--logger`: Logger to use by the inference runner (default: `logging`). Supports `logging` and `print`. +- `--n_templates_inf`: Number of inferences to do with different template combinations (default: `5`). ### Acceleration diff --git a/rnapro/config/config.py b/rnapro/config/config.py index 841f7e2..e9a7b86 100644 --- a/rnapro/config/config.py +++ b/rnapro/config/config.py @@ -254,7 +254,7 @@ def parse_configs( help="Logger to use during inference. Supports 'logging' and 'print'" ) parser.add_argument( - "--num_templates", + "--n_templates_inf", type=int, default=5, required=False, @@ -280,7 +280,7 @@ def parse_configs( args = parser.parse_args(arg_str.split()) merged_configs.max_len = args.max_len merged_configs.logger = args.logger - merged_configs.num_templates = args.num_templates + merged_configs.n_templates_inf = args.n_templates_inf return merged_configs diff --git a/rnapro_inference_example.sh b/rnapro_inference_example.sh index 21b72fe..ad0a5cb 100644 --- a/rnapro_inference_example.sh +++ b/rnapro_inference_example.sh @@ -55,10 +55,12 @@ python3 runner/inference.py \ --triangle_multiplicative "cuequivariance" \ --sequences_csv "${SEQUENCES_CSV}" \ --max_len 5000 \ - --logger "logging" + --logger "logging" \ + --n_templates_inf 5 # Notes: # --triangle_attention supports 'triattention', 'cuequivariance', 'deepspeed', 'torch' # --triangle_multiplicative supports 'cuequivariance', 'torch' # --max_len 1000: Sequences longer than max_len will be skipped to avoid oom -# --logger handles logging of the inference runner, supports "logging", "print" \ No newline at end of file +# --logger handles logging of the inference runner, supports "logging", "print" +# --n_templates_inf sets the number of inferences to do with different template combinations \ No newline at end of file diff --git a/runner/inference.py b/runner/inference.py index e00950c..1ca079e 100644 --- a/runner/inference.py +++ b/runner/inference.py @@ -343,7 +343,7 @@ def run() -> None: try: target_id = row.target_id sequence = row.sequence - for template_idx in range(configs.num_templates): + for template_idx in range(configs.n_templates_inf): print() run_ptx( target_id=target_id,