16 changes: 11 additions & 5 deletions README.md
@@ -72,7 +72,7 @@ To train the model, you will need the CCD cache. The CCD cache is generated by p


```sh
-python3 scripts/gen_ccd_cache.py
+python3 preprocess/gen_ccd_cache.py
```

@@ -141,6 +141,7 @@ release_data/


### 2. Training

We provide the trained model checkpoint via [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/resources/rnapro) and HuggingFace ([Public best](https://huggingface.co/nvidia/RNAPro-Public-Best-500M) and [Private best](https://huggingface.co/nvidia/RNAPro-Private-Best-500M)).

We provide a convenience script for training. Please modify it according to your purpose:
@@ -155,7 +156,7 @@ sh rnapro_train_example.sh

For details on the input format and output format, please refer to the [overview](model_cards/overview.md).

-### 1. Prepare inputs
+### Prepare inputs

- Input CSV files
  - Prepare a CSV file with the columns `target_id` and `sequence`.
@@ -171,6 +172,10 @@ For details on the input format and output format, please refer to the [overview
- `python preprocess/convert_templates_to_pt_files.py --input_csv path/to/submission.csv --output_name path/to/template_features.pt --max_n 40`
- Use with `--use_template ca_precomputed --template_data path/to/template_features.pt`.

+- CCD cache (same as training)
+  - `python preprocess/gen_ccd_cache.py`
+
+- Model weights are available via [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/resources/rnapro) and HuggingFace ([Public best](https://huggingface.co/nvidia/RNAPro-Public-Best-500M) and [Private best](https://huggingface.co/nvidia/RNAPro-Private-Best-500M)).
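
The input CSV described above can be produced with a few lines of Python. This is a minimal sketch, assuming the two columns `target_id` and `sequence` are all that is needed; the helper name `write_input_csv`, the file name `sequences.csv`, and the example rows are hypothetical and not part of the repository.

```python
import csv
import os
import tempfile


def write_input_csv(path, rows):
    """Write (target_id, sequence) pairs to a CSV with the expected header."""
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["target_id", "sequence"])
        writer.writerows(rows)


# Hypothetical targets; replace with real IDs and RNA sequences.
example_rows = [("example_1", "GGGAAACUCC"), ("example_2", "ACGUACGUAC")]
csv_path = os.path.join(tempfile.gettempdir(), "sequences.csv")
write_input_csv(csv_path, example_rows)
```

The resulting path could then be passed wherever the README expects the input CSV.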

### Inference via Bash Script

@@ -190,8 +195,7 @@ The script configures and forwards the following parameters to the CLI:
- `--rna_msa_dir`: Directory containing precomputed MSAs.
- `--use_template`: Template mode (use `ca_precomputed` for prepared templates).
- `--template_data`: Path to `.pt` template file converted from submission.csv.
-- `--template_idx`: Top-k template selection index:
-  - 0 -> top1, 1 -> top2, 2 -> top3, 3 -> top4, 4 -> top5
+- `--template_idx`: Top-k template selection index: 0 -> top1, 1 -> top2, 2 -> top3, 3 -> top4, 4 -> top5
- `--num_templates`: Number of templates to use (e.g., `10`).
- `--model.N_cycle`: Diffusion cycles (e.g., `10`).
- `--sample_diffusion.N_sample`: Number of samples per seed (e.g., `1`).
@@ -200,7 +204,9 @@ The script configures and forwards the following parameters to the CLI:
- `--num_workers`: Data loader workers.
- `--triangle_attention` / `--triangle_multiplicative`: Kernel backends (`torch`, `cuequivariance`, etc.).
- `--sequences_csv`: Optional CSV with headers `sequence,target_id` for batched inference.

+- `--max_len`: Maximum sequence length. Longer sequences are skipped during inference (default: `10000`).
+- `--logger`: Logger used by the inference runner (default: `logging`). Supports `logging` and `print`.
+- `--n_templates_inf`: Number of inference runs to perform with different template combinations (default: `5`).
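
To illustrate how `--max_len` and `--logger` might interact, here is a rough sketch. It is not the repository's inference runner: the function names `make_log_fn` and `filter_by_length` are hypothetical, and it assumes "longer" means strictly greater than `max_len`.

```python
import logging


def make_log_fn(logger_name):
    """Return a callable that emits messages via `logging` or `print`."""
    if logger_name == "logging":
        return logging.getLogger("inference_sketch").info
    if logger_name == "print":
        return print
    raise ValueError(f"unsupported logger: {logger_name!r}")


def filter_by_length(records, max_len, log):
    """Keep (target_id, sequence) pairs whose sequence is at most max_len long."""
    kept = []
    for target_id, sequence in records:
        if len(sequence) > max_len:
            # Assumption: sequences strictly longer than max_len are skipped.
            log(f"Skipping {target_id}: length {len(sequence)} > {max_len}")
            continue
        kept.append((target_id, sequence))
    return kept


records = [("short", "ACGU"), ("long", "ACGU" * 5)]
kept = filter_by_length(records, max_len=10, log=make_log_fn("print"))
```

With the default of `10000`, a filter like this would only drop unusually long sequences.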

### Acceleration

2 changes: 1 addition & 1 deletion configs/configs_data.py
@@ -160,7 +160,7 @@
     or (not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH))
     or (not os.path.exists(PDB_CLUSTER_FILE_PATH))
 ):
-    print("Try to find the ccd cache data in the code directory for inference.")
+    # print("Try to find the ccd cache data in the code directory for inference.")
     current_file_path = os.path.abspath(__file__)
     current_directory = os.path.dirname(current_file_path)
     code_directory = os.path.dirname(current_directory)
32 changes: 31 additions & 1 deletion rnapro/config/config.py
@@ -222,7 +222,7 @@ def merge_configs(self, new_configs: dict) -> ConfigDict:

 def parse_configs(
     configs: dict, arg_str: str = None, fill_required_with_null: bool = False
-) -> ConfigDict:
+):
     """
     Parses and merges configuration settings from a dictionary and command-line arguments.

@@ -237,6 +237,30 @@ def parse_configs(
     """
     manager = ConfigManager(configs, fill_required_with_null=fill_required_with_null)
     parser = argparse.ArgumentParser()
+
+    # This is new
+    parser.add_argument(
+        "--max_len",
+        type=int,
+        default=10000,
+        required=False,
+        help="Maximum length of the sequence. Longer sequences will be skipped during inference"
+    )
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="logging",
+        required=False,
+        help="Logger to use during inference. Supports 'logging' and 'print'"
+    )
+    parser.add_argument(
+        "--n_templates_inf",
+        type=int,
+        default=5,
+        required=False,
+        help="Number of templates to use during inference"
+    )
+
     # Register arguments
     for key, (
         dtype,
@@ -252,6 +276,12 @@ def parse_configs(
     merged_configs = manager.merge_configs(
         vars(parser.parse_args(arg_str.split())) if arg_str else {}
     )
+
+    args = parser.parse_args(arg_str.split())
+    merged_configs.max_len = args.max_len
+    merged_configs.logger = args.logger
+    merged_configs.n_templates_inf = args.n_templates_inf
+
     return merged_configs
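
In the hunk above, `arg_str` defaults to `None`, and the added `parser.parse_args(arg_str.split())` call is not guarded the way the `merge_configs` call is. Calling `parse_configs` without `arg_str` would therefore raise `AttributeError`. One possible None-safe variant is sketched below. The names `build_extra_parser` and `parse_extra_args` are hypothetical and not part of the repository. The `choices` restriction and the use of `parse_known_args` are added assumptions, not what the PR does.

```python
import argparse


def build_extra_parser():
    """Parser holding only the three options added in this change."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_len", type=int, default=10000)
    parser.add_argument(
        "--logger", type=str, default="logging", choices=["logging", "print"]
    )
    parser.add_argument("--n_templates_inf", type=int, default=5)
    return parser


def parse_extra_args(arg_str=None):
    """Parse the extra options; fall back to defaults when arg_str is None or empty."""
    tokens = arg_str.split() if arg_str else []
    # parse_known_args ignores options this sketch does not register.
    args, _unknown = build_extra_parser().parse_known_args(tokens)
    return args


defaults = parse_extra_args()
custom = parse_extra_args("--max_len 512 --logger print")
```

Guarding the split this way keeps the defaults (`10000`, `logging`, `5`) available when no argument string is supplied.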

