diff --git a/Dockerfile b/Dockerfile index 835300319..fcaeb56b5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,20 @@ -FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 # metainformation -LABEL org.opencontainers.image.version = "1.0.0" -LABEL org.opencontainers.image.authors = "Gustaf Ahdritz" +LABEL org.opencontainers.image.version = "2.0.0" +LABEL org.opencontainers.image.authors = "OpenFold Team" LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold" LABEL org.opencontainers.image.licenses = "Apache License 2.0" -LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04" +LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" + +RUN apt-get update && apt-get install -y wget RUN apt-key del 7fa2af80 -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb +RUN dpkg -i cuda-keyring_1.0-1_all.deb + +RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git -RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN wget -P /tmp \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ diff --git a/docs/source/Aux_seq_files.md b/docs/source/Aux_seq_files.md index 820872fc3..41a94da1f 100644 --- a/docs/source/Aux_seq_files.md +++ b/docs/source/Aux_seq_files.md @@ -68,9 +68,9 @@ All together, the file directory would look like: └── 6kwc.cif └── alignment_db ├── alignment_db_0.db - ├── alignment_db_1.db - ... - ├── alignment_db_9.db + ├── alignment_db_1.db + ... + ├── alignment_db_9.db └── alignment_db.index ``` diff --git a/docs/source/Inference.md b/docs/source/Inference.md index b8cef9074..1e40f59ff 100644 --- a/docs/source/Inference.md +++ b/docs/source/Inference.md @@ -62,7 +62,7 @@ python3 run_pretrained_openfold.py \ $TEMPLATE_MMCIF_DIR --output_dir $OUTPUT_DIR \ --config_preset model_1_ptm \ - --uniref90_database_path $BASE_DATA_DIR/uniref90 \ + --uniref90_database_path $BASE_DATA_DIR/uniref90/uniref90.fasta \ --mgnify_database_path $BASE_DATA_DIR/mgnify/mgy_clusters_2018_12.fa \ --pdb70_database_path $BASE_DATA_DIR/pdb70 \ --uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ @@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view - `--data_random_seed`: Specifies a random seed to use. - `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads. - `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See for [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`. +- `--use_custom_template`: Uses all .cif files in `template_mmcif_dir` as template input. Make sure the chains of interest have the identifier _A_ and have the same length as the input sequence. The same templates will be read for all sequences that are passed for inference. ### Advanced Options for Increasing Efficiency @@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement) #### Long sequence inference To minimize memory usage during inference on long sequences, consider the following changes: -- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either. -- Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Powerusers can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint. -- Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time. -- As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model. +- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either. +- Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Powerusers can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint. +- Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time. +- As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model. - Disable FlashAttention, which seems unstable on long sequences. -Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficent AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option +Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficent AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option -Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). \ No newline at end of file +Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). \ No newline at end of file diff --git a/docs/source/Installation.md b/docs/source/Installation.md index 9e34d5f23..6b9599e83 100644 --- a/docs/source/Installation.md +++ b/docs/source/Installation.md @@ -4,7 +4,7 @@ In this guide, we will OpenFold and its dependencies. **Pre-requisites** -This package is currently supported for CUDA 11 and Pytorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml) +This package is currently supported for CUDA 12 and Pytorch 2. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). At this time, only Linux systems are supported. @@ -19,9 +19,16 @@ At this time, only Linux systems are supported. Mamba is recommended as the dependencies required by OpenFold are quite large and mamba can speed up the process. - Activate the environment, e.g `conda activate openfold_env` 1. Run the setup script to configure kernels and folding resources. - > scripts/install_third_party_dependencies.sh` -1. Prepend the conda environment to the `$LD_LIBRARY_PATH`., e.g. - `export $LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH`. You may optionally set this as a conda environment variable according to the [conda docs](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#saving-environment-variables) to activate each time the environment is used. + > scripts/install_third_party_dependencies.sh +1. Prepend the conda environment to the `$LD_LIBRARY_PATH` and `$LIBRARY_PATH`., e.g. + + ``` + export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH + export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH + ``` + + You may optionally set this as a conda environment variable according to the [conda docs](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#saving-environment-variables) to activate each time the environment is used. + 1. Download parameters. We recommend using a destination as `openfold/resources` as our unittests will look for the weights there. - For AlphaFold2 weights, use > ./scripts/download_alphafold_params.sh @@ -46,12 +53,6 @@ Certain tests perform equivalence comparisons with the AlphaFold implementation. ## Environment specific modifications -### CUDA 12 -To use OpenFold on CUDA 12 environment rather than a CUDA 11 environment. - In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the URL in step 1 with https://github.com/aqlaboratory/openfold/tree/pl_upgrades - Follow the rest of the steps of [Installation Guide](#Installation) - - ### MPI To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`. @@ -64,4 +65,4 @@ If you don't have access to `aws` on your system, you can use a different downlo ### Docker setup -A [`Dockerfile`] is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container). +A [`Dockerfile`](https://github.com/aqlaboratory/openfold/blob/main/Dockerfile) is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container). diff --git a/environment.yml b/environment.yml index c5cf4104c..448959007 100644 --- a/environment.yml +++ b/environment.yml @@ -3,36 +3,38 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - - python=3.9 - - libgcc=7.2 + - cuda + - gcc=12.4 + - python=3.10 - setuptools=59.5.0 - pip - - openmm=7.7 + - openmm - pdbfixer - pytorch-lightning - biopython - numpy - pandas - - PyYAML==5.4.1 + - PyYAML - requests - - scipy==1.7 - - tqdm==4.62.2 - - typing-extensions==4.0 + - scipy + - tqdm + - typing-extensions - wandb - modelcif==0.7 - awscli - ml-collections - aria2 - - mkl=2024.0 + - mkl - git - - bioconda::hmmer==3.3.2 - - bioconda::hhsuite==3.3.0 - - bioconda::kalign2==2.04 - - bioconda::mmseqs2 - - pytorch::pytorch=1.12.* + - bioconda::hmmer + - bioconda::hhsuite + - bioconda::kalign2 + - pytorch::pytorch=2.5 + - pytorch::pytorch-cuda=12.4 - pip: - - deepspeed==0.12.4 + - deepspeed==0.14.5 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/notebooks/OpenFold.ipynb b/notebooks/OpenFold.ipynb index de5d4539c..c1e80b012 100644 --- a/notebooks/OpenFold.ipynb +++ b/notebooks/OpenFold.ipynb @@ -107,11 +107,11 @@ "\n", "python_version = f\"{version_info.major}.{version_info.minor}\"\n", "\n", - "\n", - "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n", - "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n", + "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh\")\n", + "os.system(\"bash Miniforge3-Linux-x86_64.sh -bfp /usr/local\")\n", + "os.environ[\"PATH\"] = \"/usr/local/bin:\" + os.environ[\"PATH\"]\n", "os.system(\"mamba config --set auto_update_conda false\")\n", - "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n", + "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=8.2.0 python={python_version} pdbfixer biopython=1.83\")\n", "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n", "\n", "try:\n", @@ -127,7 +127,7 @@ "\n", " %shell mkdir -p /content/openfold/openfold/resources\n", "\n", - " commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n", + " commit = \"1ffd197489aa5f35a5fbce1f00d7dd49bce1bd2f\"\n", " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", "\n", " os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n", @@ -893,8 +893,7 @@ "metadata": { "colab": { "provenance": [], - "gpuType": "T4", - "toc_visible": true + "gpuType": "T4" }, "kernelspec": { "display_name": "Python 3", @@ -907,4 +906,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/openfold/config.py b/openfold/config.py index 7bf30e391..a738b9f07 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -660,7 +660,7 @@ def model_config( }, "relax": { "max_iterations": 0, # no max - "tolerance": 2.39, + "tolerance": 10.0, "stiffness": 10.0, "max_outer_iterations": 20, "exclude_residues": [], diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index adde0b73b..393c1cef3 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -23,8 +23,19 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union import numpy as np import torch -from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer -from openfold.data.templates import get_custom_template_features, empty_template_feats +from openfold.data import ( + templates, + parsers, + mmcif_parsing, + msa_identifiers, + msa_pairing, + feature_processing_multimer, +) +from openfold.data.templates import ( + get_custom_template_features, + empty_template_feats, + CustomHitFeaturizer, +) from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.np import residue_constants, protein @@ -38,7 +49,9 @@ def make_template_features( template_featurizer: Any, ) -> FeatureDict: hits_cat = sum(hits.values(), []) - if(len(hits_cat) == 0 or template_featurizer is None): + if template_featurizer is None or ( + len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer) + ): template_features = empty_template_feats(len(input_sequence)) else: templates_result = template_featurizer.get_templates( diff --git a/openfold/data/mmcif_parsing.py b/openfold/data/mmcif_parsing.py index 6a6c1fd5b..f83fec8f8 100644 --- a/openfold/data/mmcif_parsing.py +++ b/openfold/data/mmcif_parsing.py @@ -283,7 +283,7 @@ def parse( author_chain = mmcif_to_author_chain_id[chain_id] seq = [] for monomer in seq_info: - code = PDBData.protein_letters_3to1.get(monomer.id, "X") + code = PDBData.protein_letters_3to1_extended.get(monomer.id, "X") seq.append(code if len(code) == 1 else "X") seq = "".join(seq) author_chain_to_sequence[author_chain] = seq diff --git a/openfold/data/templates.py b/openfold/data/templates.py index 8c55b5f38..5c82b7025 100644 --- a/openfold/data/templates.py +++ b/openfold/data/templates.py @@ -22,6 +22,7 @@ import json import logging import os +from pathlib import Path import re from typing import Any, Dict, Mapping, Optional, Sequence, Tuple @@ -947,55 +948,71 @@ def _process_single_hit( def get_custom_template_features( - mmcif_path: str, - query_sequence: str, - pdb_id: str, - chain_id: str, - kalign_binary_path: str): - - with open(mmcif_path, "r") as mmcif_path: - cif_string = mmcif_path.read() - - mmcif_parse_result = mmcif_parsing.parse( - file_id=pdb_id, mmcif_string=cif_string - ) - template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id] - - - mapping = {x:x for x, _ in enumerate(query_sequence)} - - - features, warnings = _extract_template_features( - mmcif_object=mmcif_parse_result.mmcif_object, - pdb_id=pdb_id, - mapping=mapping, - template_sequence=template_sequence, - query_sequence=query_sequence, - template_chain_id=chain_id, - kalign_binary_path=kalign_binary_path, - _zero_center_positions=True - ) - features["template_sum_probs"] = [1.0] - - # TODO: clean up this logic - template_features = {} - for template_feature_name in TEMPLATE_FEATURES: - template_features[template_feature_name] = [] - - for k in template_features: - template_features[k].append(features[k]) - - for name in template_features: - template_features[name] = np.stack( - template_features[name], axis=0 - ).astype(TEMPLATE_FEATURES[name]) + mmcif_path: str, + query_sequence: str, + pdb_id: str, + chain_id: Optional[str] = "A", + kalign_binary_path: Optional[str] = None, +): + if os.path.isfile(mmcif_path): + template_paths = [Path(mmcif_path)] + elif os.path.isdir(mmcif_path): + template_paths = list(Path(mmcif_path).glob("*.cif")) + else: + logging.error("Custom template path %s does not exist", mmcif_path) + raise ValueError(f"Custom template path {mmcif_path} does not exist") + + warnings = [] + template_features = dict() + for template_path in template_paths: + logging.info("Featurizing template: %s", template_path) + # pdb_id only for error reporting, take file name + pdb_id = Path(template_path).stem + with open(template_path, "r") as mmcif_path: + cif_string = mmcif_path.read() + mmcif_parse_result = mmcif_parsing.parse( + file_id=pdb_id, mmcif_string=cif_string + ) + # mapping skipping "-" + mapping = { + x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum() + } + realigned_sequence, realigned_mapping = _realign_pdb_template_to_query( + old_template_sequence=query_sequence, + template_chain_id=chain_id, + mmcif_object=mmcif_parse_result.mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path, + ) + curr_features, curr_warnings = _extract_template_features( + mmcif_object=mmcif_parse_result.mmcif_object, + pdb_id=pdb_id, + mapping=realigned_mapping, + template_sequence=realigned_sequence, + query_sequence=query_sequence, + template_chain_id=chain_id, + kalign_binary_path=kalign_binary_path, + _zero_center_positions=True, + ) + curr_features["template_sum_probs"] = [ + 1.0 + ] # template given by user, 100% confident + template_features = { + curr_name: template_features.get(curr_name, []) + [curr_item] + for curr_name, curr_item in curr_features.items() + } + warnings.append(curr_warnings) + template_features = { + template_feature_name: np.stack( + template_features[template_feature_name], axis=0 + ).astype(template_feature_type) + for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items() + } return TemplateSearchResult( features=template_features, errors=None, warnings=warnings ) - - @dataclasses.dataclass(frozen=True) class TemplateSearchResult: features: Mapping[str, Any] @@ -1188,6 +1205,23 @@ def get_templates( ) +class CustomHitFeaturizer(TemplateHitFeaturizer): + """Featurizer for templates given in folder. + Chain of interest has to be chain A and of same sequence length as input sequence.""" + def get_templates( + self, + query_sequence: str, + hits: Sequence[parsers.TemplateHit], + ) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir) + return get_custom_template_features( + self._mmcif_dir, + query_sequence=query_sequence, + pdb_id="test", + chain_id="A", + kalign_binary_path=self._kalign_binary_path, + ) class HmmsearchHitFeaturizer(TemplateHitFeaturizer): def get_templates( self, diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index ea38cb34a..c35472539 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -808,10 +808,10 @@ def _flash_attn(q, k, v, kv_mask): # [B_flat, N, 2 * H * C] kv = kv.reshape(*kv.shape[:-3], -1) - kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) + kv_unpad, _, kv_cu_seqlens, kv_max_s, _ = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py index 02816bb81..43d9337e4 100644 --- a/openfold/np/relax/amber_minimize.py +++ b/openfold/np/relax/amber_minimize.py @@ -34,6 +34,7 @@ from openmm.app.internal.pdbstructure import PdbStructure ENERGY = unit.kilocalories_per_mole +FORCE = unit.kilojoules_per_mole / unit.nanometer LENGTH = unit.angstroms @@ -439,7 +440,7 @@ def _run_one_iteration( exclude_residues = exclude_residues or [] # Assign physical dimensions. - tolerance = tolerance * ENERGY + tolerance = tolerance * FORCE stiffness = stiffness * ENERGY / (LENGTH ** 2) start = time.perf_counter() diff --git a/openfold/utils/superimposition.py b/openfold/utils/superimposition.py index 9fe794fff..d1dca2718 100644 --- a/openfold/utils/superimposition.py +++ b/openfold/utils/superimposition.py @@ -35,10 +35,10 @@ def _superimpose_np(reference, coords): def _superimpose_single(reference, coords): - reference_np = reference.detach().to(torch.float).cpu().numpy() - coords_np = coords.detach().to(torch.float).cpu().numpy() - superimposed, rmsd = _superimpose_np(reference_np, coords_np) - return coords.new_tensor(superimposed), coords.new_tensor(rmsd) + reference_np = reference.detach().to(torch.float).cpu().numpy() + coords_np = coords.detach().to(torch.float).cpu().numpy() + superimposed, rmsd = _superimpose_np(reference_np, coords_np) + return coords.new_tensor(superimposed), coords.new_tensor(rmsd) def superimpose(reference, coords, mask): diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index eb4ad7c01..510610493 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -185,12 +185,7 @@ def main(args): use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention, ) - if args.experiment_config_json: - with open(args.experiment_config_json, 'r') as f: - custom_config_dict = json.load(f) - config.update_from_flattened_dict(custom_config_dict) - - if args.experiment_config_json: + if args.experiment_config_json: with open(args.experiment_config_json, 'r') as f: custom_config_dict = json.load(f) config.update_from_flattened_dict(custom_config_dict) @@ -202,8 +197,15 @@ def main(args): ) is_multimer = "multimer" in args.config_preset - - if is_multimer: + is_custom_template = "use_custom_template" in args and args.use_custom_template + if is_custom_template: + template_featurizer = templates.CustomHitFeaturizer( + mmcif_dir=args.template_mmcif_dir, + max_template_date="9999-12-31", # just dummy, not used + max_hits=-1, # just dummy, not used + kalign_binary_path=args.kalign_binary_path + ) + elif is_multimer: template_featurizer = templates.HmmsearchHitFeaturizer( mmcif_dir=args.template_mmcif_dir, max_template_date=args.max_template_date, @@ -221,11 +223,9 @@ def main(args): release_dates_path=args.release_dates_path, obsolete_pdbs_path=args.obsolete_pdbs_path ) - data_processor = data_pipeline.DataPipeline( template_featurizer=template_featurizer, ) - if is_multimer: data_processor = data_pipeline.DataPipelineMultimer( monomer_data_pipeline=data_processor, @@ -238,7 +238,6 @@ def main(args): np.random.seed(random_seed) torch.manual_seed(random_seed + 1) - feature_processor = feature_pipeline.FeaturePipeline(config.data) if not os.path.exists(output_dir_base): os.makedirs(output_dir_base) @@ -313,7 +312,6 @@ def main(args): ) feature_dicts[tag] = feature_dict - processed_feature_dict = feature_processor.process_features( feature_dict, mode='predict', is_multimer=is_multimer ) @@ -400,6 +398,10 @@ def main(args): help="""Path to alignment directory. If provided, alignment computation is skipped and database path arguments are ignored.""" ) + parser.add_argument( + "--use_custom_template", action="store_true", default=False, + help="""Use mmcif given with "template_mmcif_dir" argument as template input.""" + ) parser.add_argument( "--use_single_seq_mode", action="store_true", default=False, help="""Use single sequence embeddings instead of MSAs.""" @@ -494,5 +496,4 @@ def main(args): """The model is being run on CPU. Consider specifying --model_device for better performance""" ) - main(args) diff --git a/scripts/flatten_roda.sh b/scripts/flatten_roda.sh index 074736a69..788f9a71c 100755 --- a/scripts/flatten_roda.sh +++ b/scripts/flatten_roda.sh @@ -9,8 +9,8 @@ # output_dir: # The directory in which to construct the reformatted data -if [[ $# != 2 ]]; then - echo "usage: ./flatten_roda.sh " +if [ "$#" -ne 2 ]; then + echo "Usage: ./flatten_roda.sh " exit 1 fi @@ -23,25 +23,36 @@ ALIGNMENT_DIR="${OUTPUT_DIR}/alignments" mkdir -p "${DATA_DIR}" mkdir -p "${ALIGNMENT_DIR}" -for chain_dir in $(ls "${RODA_DIR}"); do - CHAIN_DIR_PATH="${RODA_DIR}/${chain_dir}" - for subdir in $(ls "${CHAIN_DIR_PATH}"); do - if [[ ! -d "$subdir" ]]; then - echo "$subdir is not directory" +for chain_dir in "${RODA_DIR}"/*; do + if [ ! -d "$chain_dir" ]; then + continue + fi + + chain_name=$(basename "$chain_dir") + + for subdir in "$chain_dir"/*; do + if [ ! -d "$subdir" ]; then + echo "$subdir is not a directory" continue - elif [[ -z $(ls "${subdir}")]]; then + fi + + if [ -z "$(ls -A "$subdir")" ]; then continue - elif [[ $subdir = "pdb" ]] || [[ $subdir = "cif" ]]; then - mv "${CHAIN_DIR_PATH}/${subdir}"/* "${DATA_DIR}" + fi + + subdir_name=$(basename "$subdir") + + if [ "$subdir_name" = "pdb" ] || [ "$subdir_name" = "cif" ]; then + mv "$subdir"/* "${DATA_DIR}/" else - CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_dir}" + CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_name}" mkdir -p "${CHAIN_ALIGNMENT_DIR}" - mv "${CHAIN_DIR_PATH}/${subdir}"/* "${CHAIN_ALIGNMENT_DIR}" + mv "$subdir"/* "${CHAIN_ALIGNMENT_DIR}/" fi done done NO_DATA_FILES=$(find "${DATA_DIR}" -type f | wc -l) -if [[ $NO_DATA_FILES = 0 ]]; then - rm -rf ${DATA_DIR} -fi +if [ "$NO_DATA_FILES" -eq 0 ]; then + rm -rf "${DATA_DIR}" +fi \ No newline at end of file diff --git a/scripts/install_third_party_dependencies.sh b/scripts/install_third_party_dependencies.sh index fe2a6a0ba..e9d91002a 100755 --- a/scripts/install_third_party_dependencies.sh +++ b/scripts/install_third_party_dependencies.sh @@ -14,7 +14,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats. python setup.py install echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel" -git clone https://github.com/NVIDIA/cutlass --depth 1 +git clone https://github.com/NVIDIA/cutlass --branch v3.6.0 --depth 1 conda env config vars set CUTLASS_PATH=$PWD/cutlass # This setting is used to fix a worker assignment issue during data loading diff --git a/setup.py b/setup.py index bec986254..3750d9fe9 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', @@ -52,9 +52,9 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor compute_capabilities = set([ - (3, 7), # K80, e.g. (5, 2), # Titan X (6, 1), # GeForce 1000-series + (9, 0), # Hopper ]) compute_capabilities.add((7, 0)) @@ -113,7 +113,7 @@ def get_cuda_bare_metal_version(cuda_dir): setup( name='openfold', - version='2.0.0', + version='2.2.0', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', author='OpenFold Team', author_email='jennifer.wei@omsf.io', @@ -130,7 +130,7 @@ def get_cuda_bare_metal_version(cuda_dir): classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.9,' + 'Programming Language :: Python :: 3.10,' 'Topic :: Scientific/Engineering :: Artificial Intelligence', ], ) diff --git a/tests/test_deepspeed_evo_attention.py b/tests/test_deepspeed_evo_attention.py index 5474f98f8..a65a76317 100644 --- a/tests/test_deepspeed_evo_attention.py +++ b/tests/test_deepspeed_evo_attention.py @@ -306,7 +306,6 @@ def test_compare_model(self): batch["residx_atom37_to_atom14"] = batch[ "residx_atom37_to_atom14" ].long() - # print(batch["target_feat"].shape) batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32) batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch.update( @@ -316,8 +315,9 @@ def test_compare_model(self): # Move the recycling dimension to the end move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) batch = tensor_tree_map(move_dim, batch) - with torch.no_grad(): - with torch.cuda.amp.autocast(dtype=torch.bfloat16): + # Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16 + # https://github.com/aqlaboratory/openfold/issues/532 + with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32): model = compare_utils.get_global_pretrained_openfold() model.globals.use_deepspeed_evo_attention = False out_repro = model(batch) diff --git a/tests/test_model.py b/tests/test_model.py index 3d19f14ed..ecf5af13f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -202,4 +202,4 @@ def run_alphafold(batch): out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro.squeeze(0) - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3) diff --git a/train_openfold.py b/train_openfold.py index 168a4b43f..c55de9db3 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -21,7 +21,6 @@ from openfold.model.model import AlphaFold from openfold.model.torchscript import script_preset_ from openfold.np import residue_constants -from openfold.utils.argparse_utils import remove_arguments from openfold.utils.callbacks import ( EarlyStoppingVerbose, ) @@ -55,7 +54,7 @@ def __init__(self, config): self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) - + self.cached_weights = None self.last_lr_step = -1 self.save_hyperparameters() @@ -73,7 +72,7 @@ def _log(self, loss_breakdown, batch, outputs, train=True): on_step=train, on_epoch=(not train), logger=True, sync_dist=False, ) - if(train): + if (train): self.log( f"{phase}/{loss_name}_epoch", indiv_loss, @@ -82,12 +81,12 @@ def _log(self, loss_breakdown, batch, outputs, train=True): with torch.no_grad(): other_metrics = self._compute_validation_metrics( - batch, + batch, outputs, superimposition_metrics=(not train) ) - for k,v in other_metrics.items(): + for k, v in other_metrics.items(): self.log( f"{phase}/{k}", torch.mean(v), @@ -96,7 +95,7 @@ def _log(self, loss_breakdown, batch, outputs, train=True): ) def training_step(self, batch, batch_idx): - if(self.ema.device != batch["aatype"].device): + if (self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) @@ -127,12 +126,13 @@ def on_before_zero_grad(self, *args, **kwargs): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): + if (self.cached_weights is None): # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling + # than copies. Therefore, we need to clone them before calling # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + def clone_param(t): return t.detach().clone() + self.cached_weights = tensor_tree_map( + clone_param, self.model.state_dict()) self.model.load_state_dict(self.ema.state_dict()["params"]) ground_truth = batch.pop('gt_features', None) @@ -160,17 +160,17 @@ def on_validation_epoch_end(self): self.model.load_state_dict(self.cached_weights) self.cached_weights = None - def _compute_validation_metrics(self, - batch, - outputs, - superimposition_metrics=False - ): + def _compute_validation_metrics(self, + batch, + outputs, + superimposition_metrics=False + ): metrics = {} - + gt_coords = batch["all_atom_positions"] pred_coords = outputs["final_atom_positions"] all_atom_mask = batch["all_atom_mask"] - + # This is super janky for superimposition. Fix later gt_coords_masked = gt_coords * all_atom_mask[..., None] pred_coords_masked = pred_coords * all_atom_mask[..., None] @@ -178,7 +178,7 @@ def _compute_validation_metrics(self, gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :] pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :] all_atom_mask_ca = all_atom_mask[..., ca_pos] - + lddt_ca_score = lddt_ca( pred_coords, gt_coords, @@ -186,18 +186,18 @@ def _compute_validation_metrics(self, eps=self.config.globals.eps, per_residue=False, ) - + metrics["lddt_ca"] = lddt_ca_score - + drmsd_ca_score = drmsd( pred_coords_masked_ca, gt_coords_masked_ca, - mask=all_atom_mask_ca, # still required here to compute n + mask=all_atom_mask_ca, # still required here to compute n ) - + metrics["drmsd_ca"] = drmsd_ca_score - - if(superimposition_metrics): + + if (superimposition_metrics): superimposed_pred, alignment_rmsd = superimpose( gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, ) @@ -211,7 +211,7 @@ def _compute_validation_metrics(self, metrics["alignment_rmsd"] = alignment_rmsd metrics["gdt_ts"] = gdt_ts_score metrics["gdt_ha"] = gdt_ha_score - + return metrics def configure_optimizers(self, @@ -220,8 +220,8 @@ def configure_optimizers(self, ) -> torch.optim.Adam: # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( - self.model.parameters(), - lr=learning_rate, + self.model.parameters(), + lr=learning_rate, eps=eps ) @@ -246,8 +246,9 @@ def configure_optimizers(self, def on_load_checkpoint(self, checkpoint): ema = checkpoint["ema"] - if(not self.model.template_config.enabled): - ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} + if (not self.model.template_config.enabled): + ema["params"] = {k: v for k, + v in ema["params"].items() if not "template" in k} self.ema.load_state_dict(ema) def on_save_checkpoint(self, checkpoint): @@ -258,13 +259,13 @@ def resume_last_lr_step(self, lr_step): def load_from_jax(self, jax_path): model_basename = os.path.splitext( - os.path.basename( - os.path.normpath(jax_path) - ) + os.path.basename( + os.path.normpath(jax_path) + ) )[0] model_version = "_".join(model_basename.split("_")[1:]) import_jax_weights_( - self.model, jax_path, version=model_version + self.model, jax_path, version=model_version ) def get_model_state_dict_from_ds_checkpoint(checkpoint_dir): @@ -331,30 +332,31 @@ def main(args): if args.resume_from_jax_params: model_module.load_from_jax(args.resume_from_jax_params) - logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") - + logging.info( + f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") + # TorchScript components of the model - if(args.script_modules): + if (args.script_modules): script_preset_(model_module) if "multimer" in args.config_preset: data_module = OpenFoldMultimerDataModule( - config=config.data, - batch_seed=args.seed, - **vars(args) - ) + config=config.data, + batch_seed=args.seed, + **vars(args) + ) else: data_module = OpenFoldDataModule( - config=config.data, + config=config.data, batch_seed=args.seed, **vars(args) ) data_module.prepare_data() data_module.setup() - + callbacks = [] - if(args.checkpoint_every_epoch): + if (args.checkpoint_every_epoch): mc = ModelCheckpoint( every_n_epochs=1, auto_insert_metric_name=False, @@ -362,7 +364,7 @@ def main(args): ) callbacks.append(mc) - if(args.early_stopping): + if (args.early_stopping): es = EarlyStoppingVerbose( monitor="val/lddt_ca", min_delta=args.min_delta, @@ -374,7 +376,7 @@ def main(args): ) callbacks.append(es) - if(args.log_performance): + if (args.log_performance): global_batch_size = args.num_nodes * args.gpus perf = PerformanceLoggingCallback( log_file=os.path.join(args.output_dir, "performance_log.json"), @@ -382,7 +384,7 @@ def main(args): ) callbacks.append(perf) - if(args.log_lr): + if (args.log_lr): lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) @@ -448,7 +450,7 @@ def main(args): ckpt_path = args.resume_from_ckpt trainer.fit( - model_module, + model_module, datamodule=data_module, ckpt_path=ckpt_path, ) @@ -680,22 +682,22 @@ def bool_type(bool_str: str): trainer_group.add_argument( "--reload_dataloaders_every_n_epochs", type=int, default=1, ) - - trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1, - help="Accumulate gradients over k batches before next optimizer step.") + trainer_group.add_argument( + "--accumulate_grad_batches", type=int, default=1, + help="Accumulate gradients over k batches before next optimizer step.") args = parser.parse_args() - if(args.seed is None and - ((args.gpus is not None and args.gpus > 1) or + if (args.seed is None and + ((args.gpus is not None and args.gpus > 1) or (args.num_nodes is not None and args.num_nodes > 1))): raise ValueError("For distributed training, --seed must be specified") - if(str(args.precision) == "16" and args.deepspeed_config_path is not None): + if (str(args.precision) == "16" and args.deepspeed_config_path is not None): raise ValueError("DeepSpeed and FP16 training are not compatible") - if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): - raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") - + if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): + raise ValueError( + "Choose between loading pretrained Jax-weights and a checkpoint-path") main(args)