
x-RNA

This is a PyTorch implementation for RNA 3D Folding based on AlphaFold 3.

Protenix

Much of the code in this repo is adapted from Protenix.

Deletions

The following parts of the Protenix code were removed because the inputs are RNA-only.

Modules

  • InputFeatureEmbedder: This module is needed to distinguish between tokens and atoms, but since the inputs are RNA only, it is not needed here.
  • RelativePositionEncoding: The features asym_id, entity_id, sym_id, and token_index are removed, and for simplicity the inference-time chunking is removed.
  • MSAModule and TemplateEmbedder: For the first phase, there are no templates and no MSA.
  • PairformerBlock: In-place operations are removed, along with the flags use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, chunk_size.
  • TriangleMultiplicativeUpdate: _inference_forward, in-place operations, and chunking are removed. nn.LayerNorm is used instead of the customized LayerNorms (see the sketch after this list).
  • TriangleAttention: The flags use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, chunk_size are removed. In-place operations and chunking are removed. nn.LayerNorm is used instead of the customized LayerNorms.
  • Attention: Flash attention, chunking, use_memory_efficient_kernel, use_deepspeed_evo_attention, and LMA are removed. Local attention is removed.
  • AttentionPairBias: Chunking, in-place operations, and local attention are removed.
  • PairformerStack: No checkpointing is performed. No use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, clear_cache_between_blocks, chunk_size.
  • DiffusionModule: Chunking, checkpointing, and in-place operations are removed, along with all parameters and operations dealing with reference features.
  • Protenix: ConfidenceHead, InputFeatureEmbedder, TemplateEmbedder, MSAModule, DistogramHead, linear_no_bias_token_bond, symmetric_permutation, and the mini-rollout are removed.
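For illustration, here is a minimal sketch of the "outgoing" triangle multiplicative update written in this simplified style: plain nn.LayerNorm, no in-place operations, no flags, no chunking. The class and attribute names are illustrative, not necessarily those used in the repo.

```python
import torch
import torch.nn as nn


class TriangleMultiplicationOutgoing(nn.Module):
    """Outgoing triangle multiplicative update (AlphaFold-style), simplified."""

    def __init__(self, c_z: int, c_hidden: int = 128):
        super().__init__()
        self.norm_in = nn.LayerNorm(c_z)
        self.norm_out = nn.LayerNorm(c_hidden)
        self.proj_a = nn.Linear(c_z, c_hidden)
        self.gate_a = nn.Linear(c_z, c_hidden)
        self.proj_b = nn.Linear(c_z, c_hidden)
        self.gate_b = nn.Linear(c_z, c_hidden)
        self.gate_out = nn.Linear(c_z, c_z)
        self.proj_out = nn.Linear(c_hidden, c_z)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: [*, N, N, c_z] pair representation
        z_in = self.norm_in(z)
        a = torch.sigmoid(self.gate_a(z_in)) * self.proj_a(z_in)
        b = torch.sigmoid(self.gate_b(z_in)) * self.proj_b(z_in)
        # "Outgoing" edges: combine rows i and j over the shared index k.
        x = torch.einsum("...ikc,...jkc->...ijc", a, b)
        g = torch.sigmoid(self.gate_out(z_in))
        return g * self.proj_out(self.norm_out(x))
```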

Losses

  • Losses: Remove PLDDTLoss, PDELoss, ExperimentallyResolvedLoss, PAELoss, BondLoss, SmoothLDDTLoss, DistogramLoss.
  • Methods: Remove distances, chunking, train_confidence_only, has_valid_resolution, confidence_coordinate, diffusion_lddt_loss, distogram.
  • MSELoss: Remove weight_ligand, weight_dna, weight_rna. Remove per-atom weighting.
  • align_pred_to_true: Remove weight, allowing_reflection (a minimal sketch of alignment followed by MSE is given after this list).
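As a rough illustration of what remains, the sketch below aligns predicted coordinates onto the ground truth with a Kabsch (rigid, rotation + translation only) fit and then takes a plain, unweighted MSE. The function names and shapes are assumptions for the example, not the repo's exact API.

```python
import torch


def align_pred_to_true(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Rigidly align pred [N, 3] onto true [N, 3] with the Kabsch algorithm
    (rotation + translation only; reflections are excluded)."""
    pred_c = pred - pred.mean(dim=0, keepdim=True)
    true_c = true - true.mean(dim=0, keepdim=True)
    # Cross-covariance and its SVD.
    H = pred_c.T @ true_c                      # [3, 3]
    U, _, Vt = torch.linalg.svd(H)
    V = Vt.T
    # Flip the last axis if the optimal transform would be a reflection.
    d = torch.sign(torch.linalg.det(V @ U.T))
    D = torch.diag(torch.stack([torch.ones_like(d), torch.ones_like(d), d]))
    R = V @ D @ U.T                            # rotation mapping pred -> true
    return pred_c @ R.T + true.mean(dim=0, keepdim=True)


def mse_loss(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Unweighted MSE on aligned coordinates (no per-atom/per-chain weights)."""
    aligned = align_pred_to_true(pred, true)
    return ((aligned - true) ** 2).sum(dim=-1).mean()
```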

Data

  • Sources: Remove templates, MSA, and reference features.
  • Methods: Remove shuffling and all cropping methods except continuous cropping (sketched below).
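Continuous cropping simply takes a contiguous window of residues. A minimal sketch (the function name and tensor layout are assumptions for the example):

```python
import torch


def continuous_crop(coords: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Take a contiguous window of at most crop_size residues.

    coords: [N_res, ...] per-residue features or coordinates.
    """
    n_res = coords.shape[0]
    if n_res <= crop_size:
        return coords
    start = torch.randint(0, n_res - crop_size + 1, (1,)).item()
    return coords[start:start + crop_size]
```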

Config

These are the model parameters that determine the model size.

| Parameter | Description | Default | Large | Medium | Small | Nano |
|---|---|---|---|---|---|---|
| c_s | single embedding | 384 | 192 | 96 | 48 | 24 |
| c_z | pair and atom embedding | 128 | 64 | 32 | 16 | 8 |
| c_token | token embedding | 768 | 384 | 192 | 96 | 48 |
| pairformer_n_block | number of pairformer blocks | 48 | 24 | 12 | 6 | 3 |
| diffusion_n_block | number of diffusion blocks | 24 | 12 | 6 | 3 | 2 |
| n_head | number of heads in pairformer and diffusion transformer | 16 | 16 | 8 | 4 | 4 |
| TOTAL | total parameters | 350.65M | 46.53M | 6.51M | 1.00M | 0.19M |

There are also training parameters that affect the training process.

| Parameter | Description | Train | Fine tune 1 | Fine tune 2 | Inference |
|---|---|---|---|---|---|
| crop_size | Train crop size | 384 | 640 | 768 | -1 |
| batch_size | Diffusion batch size | 48 | 32 | 32 | 1 |
| n_cycle | Number of pairformer cycles | 3 | 3 | 3 | 3 |
| dropout | Dropout rate | 0.25 | 0.25 | 0.25 | 0 |
| n_step | Number of diffusion steps | 200 | 200 | 200 | 200 |
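As an illustration of how these hyperparameters fit together, they could be grouped into small config objects like the following. The dataclass names and presets below are only an example built from the tables above, not the repo's actual config code.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    c_s: int = 384               # single embedding
    c_z: int = 128               # pair and atom embedding
    c_token: int = 768           # token embedding
    pairformer_n_block: int = 48
    diffusion_n_block: int = 24
    n_head: int = 16


@dataclass
class TrainConfig:
    crop_size: int = 384         # -1 disables cropping (inference)
    batch_size: int = 48         # diffusion batch size
    n_cycle: int = 3             # pairformer recycling iterations
    dropout: float = 0.25
    n_step: int = 200            # diffusion steps


# Example presets taken from the tables above.
nano = ModelConfig(c_s=24, c_z=8, c_token=48,
                   pairformer_n_block=3, diffusion_n_block=2, n_head=4)
fine_tune_2 = TrainConfig(crop_size=768, batch_size=32)
```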

Gradient Checkpointing

The idea is not to store the intermediate activations of the model, but to recompute them during the backward pass. Only the input and output of each block are stored, and the activations are recomputed when needed (using torch.utils.checkpoint).
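A minimal sketch of checkpointing a stack of blocks with torch.utils.checkpoint (the generic block stack here is illustrative, not the repo's PairformerStack):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Run a stack of blocks, recomputing each block's activations on backward."""

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training:
                # Only the block's inputs/outputs are kept; intermediate
                # activations are recomputed during the backward pass.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```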

Chunking

This technique reduces memory consumption without affecting accuracy, though it may cost some speed. Among the blocks with the highest memory consumption in AlphaFold are those that deal with the pair representation: computing the attention weights in these blocks needs $O(N_{res}^3)$ memory, where $N_{res}$ is the number of residues in the input. The trick is to find the dimension that acts "batch-like". For example, in row-wise Triangle Attention the computation is independent across the row dimension, so processing a small chunk of rows at a time (e.g. a chunk size of $4$; not too small, so parallelism is still exploited) reduces the memory consumption to $O(N_{res}^2)$.
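A minimal sketch of the idea, chunking a row-wise attention computation over the batch-like row dimension (the function, shapes, and scaling are illustrative, not the repo's chunk_layer implementation):

```python
import torch


def rowwise_attention_chunked(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                              chunk_size: int = 4) -> torch.Tensor:
    """Row-wise attention over a pair representation, chunked along dim 0.

    q, k, v: [N_res, N_res, c]. Each row attends independently, so the full
    [N_res, N_res, N_res] attention-weight tensor never has to exist at once;
    only [chunk_size, N_res, N_res] is materialized per step.
    """
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = q[start:start + chunk_size]              # [C, N, c]
        k_chunk = k[start:start + chunk_size]
        v_chunk = v[start:start + chunk_size]
        logits = torch.einsum("rqc,rkc->rqk", q_chunk, k_chunk)
        weights = torch.softmax(logits / q.shape[-1] ** 0.5, dim=-1)
        outputs.append(torch.einsum("rqk,rkc->rqc", weights, v_chunk))
    return torch.cat(outputs, dim=0)
```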

The following parts use chunking (parts not present in this implementation are omitted):

  • PairformerStack
    • PairformerBlock
      • TriangleAttention
        • Attention
          • chunk_layer
  • DiffusionModule
    • AtomAttentionEncoder
      • AtomTransformer
        • DiffusionTransformer
    • DiffusionTransformer
      • DiffusionTransformerBlock
        • AttentionPairBias
          • Attention
    • AtomAttentionDecoder
      • AtomTransformer
        • DiffusionTransformer

DeepSpeed

For this model, DeepSpeed made it possible to train the standard configuration on 2 T4 GPUs on Kaggle. The config (ZeRO-2 + bfloat16) is in the ds_config.json file. The model is trained with 2 GPUs (15 GB each), a batch size of 32, and a crop size of 384.
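For reference, a ZeRO-2 + bfloat16 DeepSpeed config of this kind typically looks roughly like the sketch below (illustrative values only; see ds_config.json in the repo for the actual settings):

```python
import deepspeed

# Roughly what a ZeRO-2 + bfloat16 DeepSpeed config looks like (illustrative values).
ds_config = {
    "train_micro_batch_size_per_gpu": 16,   # e.g. 32 total across 2 GPUs
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "gradient_clipping": 1.0,
}

# Typical wiring into training (model comes from the usual PyTorch setup):
# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
```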

Training

The following figure plots the training loss and memory peaks for 100 steps on val_set.

[Figure: training progress — loss and peak memory over 100 steps]
