This is a PyTorch implementation for RNA 3D Folding based on AlphaFold 3.
Much of this repo is based on the Protenix code.
Here are the parts deleted from Protenix due to the RNA-only inputs.
- InputFeatureEmbedder: This module distinguishes between tokens and atoms; since we only have RNA, it is not needed.
- RelativePositionEncoding: The features asym_id, entity_id, sym_id, and token_index are removed, and for simplicity the chunking part at inference is removed.
- MSAModule and TemplateEmbedder: For the first phase, there is no template and MSA.
- PairformerBlock: Inplace operations are removed. Also some flags: use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, chunk_size.
- TriangleMultiplicativeUpdate: Remove _inference_forward, in-place operations, and chunking. Use nn.LayerNorm instead of the customized LayerNorms (see the simplified sketch after this list).
- TriangleAttention: Remove some flags: use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, chunk_size. Inplace operations are removed. Use nn.LayerNorm instead of customized LayerNorms. Remove chunking.
- Attention: Remove flash attention, chunking, use_memory_efficient_kernel, use_deepspeed_evo_attention, and lma. Remove local attention.
- AttentionPairBias: Remove chunking, inplace, local attention.
- PairformerStack: No checkpointing is performed. No use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, inplace_safe, clear_cache_between_blocks, chunk_size.
- DiffusionModule: Remove chunking, checkpointing, and in-place operations, as well as all parameters and operations dealing with reference features.
- Protenix: Remove ConfidenceHead, InputFeatureEmbedder, TemplateEmbedder, MSAModule, DistogramHead, linear_no_bias_token_bond, symmetric_permutation, Mini-rollout.
- Losses: Remove PLDDTLoss, PDELoss, ExperimentallyResolvedLoss, PAELoss, BondLoss, SmoothLDDTLoss, DistogramLoss.
- Methods: Remove distances, chunking, train_confidence_only, has_valid_resolution, confidence_coordinate, diffusion_lddt_loss, distogram.
- MSELoss: Remove weight_ligand, weight_dna, weight_rna. Remove per-atom weighting.
- align_pred_to_true: Remove weight, allowing_reflection.
- Sources: Remove templates, MSA, and reference features.
- Methods: Remove shuffling, cropping (except continuous cropping).
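To illustrate the simplification of TriangleMultiplicativeUpdate mentioned above, here is a minimal sketch of the "outgoing edges" variant written with plain nn.LayerNorm, no in-place operations, and no chunking; the class and attribute names are assumptions, not the repo's actual code.

```python
import torch
import torch.nn as nn


class TriangleMultiplicationOutgoing(nn.Module):
    """Simplified triangle multiplicative update ("outgoing" edges).

    Uses plain nn.LayerNorm; no in-place ops, no chunking.
    """

    def __init__(self, c_z: int, c_hidden: int = 128):
        super().__init__()
        self.ln_in = nn.LayerNorm(c_z)
        self.proj_a = nn.Linear(c_z, c_hidden)
        self.gate_a = nn.Linear(c_z, c_hidden)
        self.proj_b = nn.Linear(c_z, c_hidden)
        self.gate_b = nn.Linear(c_z, c_hidden)
        self.ln_out = nn.LayerNorm(c_hidden)
        self.proj_out = nn.Linear(c_hidden, c_z)
        self.gate_out = nn.Linear(c_z, c_z)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: [..., N_token, N_token, c_z] pair representation
        z_ln = self.ln_in(z)
        a = torch.sigmoid(self.gate_a(z_ln)) * self.proj_a(z_ln)
        b = torch.sigmoid(self.gate_b(z_ln)) * self.proj_b(z_ln)
        # "Outgoing" update: combine edges i->k and j->k over the shared index k.
        x = torch.einsum("...ikc,...jkc->...ijc", a, b)
        g = torch.sigmoid(self.gate_out(z_ln))
        return g * self.proj_out(self.ln_out(x))
```

In a Pairformer block the output of such an update is typically added residually to the pair representation z, with dropout applied during training.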
There are model parameters that determine the model size; a configuration sketch follows the table.
| Parameter | Description | Default | Large | Medium | Small | Nano |
|---|---|---|---|---|---|---|
| c_s | single embedding | 384 | 192 | 96 | 48 | 24 |
| c_z | pair and atom embedding | 128 | 64 | 32 | 16 | 8 |
| c_token | token embedding | 768 | 384 | 192 | 96 | 48 |
| pairformer_n_block | number of pairformer blocks | 48 | 24 | 12 | 6 | 3 |
| diffusion_n_block | number of diffusion blocks | 24 | 12 | 6 | 3 | 2 |
| n_head | number of heads in pairformer and diffusion transformer | 16 | 16 | 8 | 4 | 4 |
| TOTAL | total parameters | 350.65M | 46.53M | 6.51M | 1.00M | 0.19M |
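For example, these presets could be kept in a plain dictionary and unpacked into the model constructor; the dictionary layout and the RNAFoldingModel name below are illustrative, not the repo's actual config API.

```python
# Hypothetical size presets mirroring the table above.
MODEL_PRESETS = {
    "default": dict(c_s=384, c_z=128, c_token=768, pairformer_n_block=48, diffusion_n_block=24, n_head=16),
    "large":   dict(c_s=192, c_z=64,  c_token=384, pairformer_n_block=24, diffusion_n_block=12, n_head=16),
    "medium":  dict(c_s=96,  c_z=32,  c_token=192, pairformer_n_block=12, diffusion_n_block=6,  n_head=8),
    "small":   dict(c_s=48,  c_z=16,  c_token=96,  pairformer_n_block=6,  diffusion_n_block=3,  n_head=4),
    "nano":    dict(c_s=24,  c_z=8,   c_token=48,  pairformer_n_block=3,  diffusion_n_block=2,  n_head=4),
}

# Usage sketch (model class name is an assumption):
# model = RNAFoldingModel(**MODEL_PRESETS["small"])
```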
There are also training parameters that affect the training process; a sketch of the recycling loop controlled by n_cycle follows the table.
| Parameter | Description | Train | Fine tune 1 | Fine tune 2 | Inference |
|---|---|---|---|---|---|
| crop_size | Crop size | 384 | 640 | 768 | -1 |
| batch_size | Diffusion batch size | 48 | 32 | 32 | 1 |
| n_cycle | Number of pairformer cycles | 3 | 3 | 3 | 3 |
| dropout | Dropout rate | 0.25 | 0.25 | 0.25 | 0 |
| n_step | Number of diffusion steps | 200 | 200 | 200 | 200 |
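n_cycle controls recycling of the trunk: the Pairformer runs several times and its outputs are fed back in as additional input. Below is a rough sketch, assuming AlphaFold-style recycling in which only the last cycle is backpropagated; all function and argument names here are hypothetical.

```python
import torch


def run_trunk_with_recycling(pairformer, recycle_s, recycle_z, s_init, z_init, n_cycle=3):
    """Run the Pairformer trunk n_cycle times, recycling its outputs.

    Earlier cycles run without gradients; only the last cycle is backpropagated.
    """
    s = torch.zeros_like(s_init)
    z = torch.zeros_like(z_init)
    for cycle in range(n_cycle):
        last = cycle == n_cycle - 1
        with torch.set_grad_enabled(last and torch.is_grad_enabled()):
            s_in = s_init + recycle_s(s)  # e.g. LayerNorm + Linear of the previous single rep
            z_in = z_init + recycle_z(z)  # e.g. LayerNorm + Linear of the previous pair rep
            s, z = pairformer(s_in, z_in)
    return s, z
```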
The idea of (gradient) checkpointing is to not store the intermediate activations of the model, but to recompute them during the backward pass. This is done by storing only the input and output of each block and recomputing the activations when needed (using torch.utils.checkpoint).
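A minimal sketch of this pattern with torch.utils.checkpoint; the wrapper class and its layout are illustrative, not the repo's actual modules.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Apply a stack of blocks, recomputing their activations in the backward pass."""

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training and x.requires_grad:
                # Only the block's input and output are kept alive;
                # everything inside `block` is recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```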
This technique reduces memory consumption without affecting accuracy, at the cost of some training speed. Some of the highest memory consumption in AlphaFold comes from the blocks that deal with the pair representation: calculating the attention weights in these blocks requires tensors whose size grows roughly with the cube of the number of tokens, which is what chunking addresses by processing the computation in smaller slices. These parts use chunking (parts that are not in this implementation are not listed); a simplified sketch of the idea follows the list:
- PairformerStack
  - PairformerBlock
    - TriangleAttention
      - Attention
        - chunk_layer
- DiffusionModule
  - AtomAttentionEncoder
    - AtomTransformer
      - DiffusionTransformer
  - DiffusionTransformer
    - DiffusionTransformerBlock
      - AttentionPairBias
        - Attention
  - AtomAttentionDecoder
    - AtomTransformer
      - DiffusionTransformer
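The real chunk_layer in Protenix/OpenFold is more general, but the core idea can be sketched as applying a function to slices of a tensor sequentially instead of all at once:

```python
import torch


def chunked_apply(fn, z: torch.Tensor, chunk_size: int, dim: int = -3) -> torch.Tensor:
    """Apply `fn` to slices of `z` along `dim` and concatenate the results.

    Peak memory is bounded by what `fn` needs for a single chunk rather than
    for the whole tensor; the result is unchanged, only speed may suffer.
    """
    parts = [fn(part) for part in torch.split(z, chunk_size, dim=dim)]
    return torch.cat(parts, dim=dim)


# Usage sketch (hypothetical): run an attention function over blocks of rows of
# the [..., N_token, N_token, c_z] pair representation instead of all rows at once.
# z_out = chunked_apply(row_attention, z, chunk_size=64, dim=-3)
```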
DeepSpeed made it possible to train the standard model with 2 T4 GPUs on Kaggle. The ZeRO-2 + bfloat16 config is in the ds_config.json file. The model is trained with 2 GPUs (15 GB each), a batch size of 32, and a crop size of 384.
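For reference, a ZeRO stage-2 + bfloat16 DeepSpeed config has roughly the following shape; this is a minimal sketch with placeholder values, not a copy of the repo's ds_config.json.

```python
import json

# Minimal sketch of a ZeRO stage-2 + bfloat16 DeepSpeed config.
# The values below are placeholders; see ds_config.json in the repo for the real settings.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                  # shard optimizer states and gradients across GPUs
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```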
The following figure plots the training loss and peak memory usage over 100 steps on val_set.
