diff --git a/README.md b/README.md index 895b10b..2945a9b 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,9 @@ uv sync # Analyze a structure file uv run grodecoder path/to/structure.gro +# Analyze a pair topology file + coordinates +uv run grodecoder path/to/topology.psf /path/to/coordinates.coor + # Output to stdout with compact format uv run grodecoder structure.pdb --compact --stdout @@ -101,6 +104,40 @@ GROdecoder produces detailed JSON inventories with the following structure: ## 🔧 Advanced Features +### Read back a Grodecoder inventory file + +Reading a Grodecoder inventory file is essential to be able to access the different parts of a system +without having to identify them again: + +```python +from grodecoder import read_grodecoder_output + +gro_results = read_grodecoder_output("1BRS_grodecoder.json") + +# Print the sequence of protein segment only. +for segment in gro_results.decoded.inventory.segments: + if segment.is_protein(): + print(segment.sequence) +``` + +In conjunction with the structure file, we can use the grodecoder output file to access the different +parts of the system, as identified by grodecoder: + +```python +import MDAnalysis +from grodecoder import read_grodecoder_output + + +universe = MDAnalysis.Universe("tests/data/1BRS.pdb") +gro_results = read_grodecoder_output("1BRS_grodecoder.json") + +# Prints the center of mass of each protein segment. +for segment in gro_results.decoded.inventory.segments: + if segment.is_protein(): + seg: MDAnalysis.AtomGroup = universe.atoms[segment.atoms] + print(seg.center_of_mass()) +``` + ### Chain Detection GROdecoder uses sophisticated distance-based algorithms to detect protein and nucleic acid chains: diff --git a/pyproject.toml b/pyproject.toml index 7024310..1395ed7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,3 @@ line-length = 110 [tool.pytest.ini_options] addopts = "-ra" - -[tool.mypy] -ignore_missing_imports = true diff --git a/src/grodecoder/__init__.py b/src/grodecoder/__init__.py index dbe9651..52a7c9a 100644 --- a/src/grodecoder/__init__.py +++ b/src/grodecoder/__init__.py @@ -1,63 +1,13 @@ -import json -from datetime import datetime - -import MDAnalysis as mda -from loguru import logger - -from . import databases, toputils -from ._typing import AtomGroup, Json, PathLike, Residue, Universe, UniverseLike -from .identifier import identify +from .core import decode, decode_structure from .models import Decoded, GrodecoderRunOutput, GrodecoderRunOutputRead +from .io import read_grodecoder_output, read_universe __all__ = [ - "databases", - "identify", - "toputils", - "read_structure", - "AtomGroup", + "decode", + "decode_structure", + "read_grodecoder_output", + "read_universe", "Decoded", "GrodecoderRunOutput", "GrodecoderRunOutputRead", - "Json", - "PathLike", - "Residue", - "Universe", - "UniverseLike", ] - -__version__ = "0.0.1" - - -def _now() -> str: - """Returns the current date and time formatted string.""" - return datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - -def read_structure(path: PathLike, psf_path: PathLike | None = None) -> Universe: - """Reads a structure file.""" - if psf_path: - return mda.Universe(path, psf_path) - return mda.Universe(path) - - -def read_grodecoder_output(path: PathLike) -> GrodecoderRunOutputRead: - with open(path) as fileobj: - return GrodecoderRunOutputRead.model_validate(json.load(fileobj)) - - -def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: - """Decodes the universe into an inventory of segments.""" - return Decoded( - inventory=identify(universe, bond_threshold=bond_threshold), - resolution=toputils.guess_resolution(universe), - ) - - -def decode_structure( - path: PathLike, psf_path: PathLike | None = None, bond_threshold: float = 5.0 -) -> Decoded: - """Reads a structure file and decodes it into an inventory of segments.""" - universe = read_structure(path, psf_path) - assert universe.atoms is not None # required by type checker for some reason - logger.info(f"{path}: {len(universe.atoms):,d} atoms") - return decode(universe, bond_threshold=bond_threshold) diff --git a/src/grodecoder/cli.py b/src/grodecoder/cli.py deleted file mode 100644 index 1bdb1e5..0000000 --- a/src/grodecoder/cli.py +++ /dev/null @@ -1,126 +0,0 @@ -import hashlib -import json -import sys -import time -from pathlib import Path - -import click -from loguru import logger - -import grodecoder as gd - - -class PathToStructureFile(click.ParamType): - """Custom click parameter type for validating structure files.""" - - name = "structure file" - - def convert(self, value, param, ctx): - """Convert the input value to a Path object.""" - path = Path(value) - if not path.exists(): - self.fail(f"'{path}' does not exist", param, ctx) - if not path.is_file(): - self.fail(f"'{path}' is not a file", param, ctx) - valid_extensions = (".gro", ".pdb", ".coor", ".crd") - if path.suffix not in valid_extensions: - self.fail( - f"'{path}' has an invalid extension (valid extensions are {valid_extensions})", param, ctx - ) - return path - - -class DefaultFilenameGenerator: - """Generates output filenames based on the structure file name.""" - - DEFAULT_STEM_SUFFIX = "_grodecoder" - - def __init__(self, structure_path: Path, stem_suffix: str = DEFAULT_STEM_SUFFIX): - self._stem = structure_path.stem - self._stem_suffix = stem_suffix - self._output_stem = self._stem + self._stem_suffix - - def _generate_output_filename(self, extension: str) -> Path: - """Generic method to create an output filename with the given extension.""" - return Path(self._output_stem + extension) - - @property - def inventory_filename(self) -> Path: - """Generates the output JSON filename.""" - return self._generate_output_filename(".json") - - @property - def log_filename(self) -> Path: - """Generates the output log filename.""" - return self._generate_output_filename(".log") - - -def setup_logging(logfile: Path, debug: bool = False): - """Sets up logging configuration.""" - fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" - level = "DEBUG" if debug else "INFO" - logger.remove() - logger.add(sys.stderr, level=level, format=fmt, colorize=True) - logger.add(logfile, level=level, format=fmt, colorize=False, mode="w") - - -def _get_checksum(structure_path: gd.PathLike) -> str: - """Computes a checksum for the structure file.""" - with open(structure_path, "rb") as f: - return hashlib.md5(f.read()).hexdigest() - - -def main(structure_path: Path, bond_threshold: float, compact_serialization: bool, output_to_stdout: bool): - """Main function to process a structure file and count the molecules. - - Args: - structure_path (Path): Path to the structure file. - bond_threshold (float): Threshold for interchain bond detection. - compact_serialization (bool): If True, use compact serialization (no atom indices). - output_to_stdout (bool): Whether to output results to stdout. - """ - start_time = time.perf_counter_ns() - logger.info(f"Processing structure file: {structure_path}") - - # Decoding. - decoded = gd.decode_structure(structure_path, bond_threshold=bond_threshold) - - output = gd.GrodecoderRunOutput( - decoded=decoded, - structure_file_checksum=_get_checksum(structure_path), - database_version=gd.databases.__version__, - grodecoder_version=gd.__version__, - ) - - # Serialization. - serialization_mode = "compact" if compact_serialization else "full" - - # Updates run time as late as possible. - output_json = output.model_dump(context={"serialization_mode": serialization_mode}) - output_json["runtime_in_seconds"] = (time.perf_counter_ns() - start_time) / 1e9 - - # Output results: to stdout or writes to a file. - if output_to_stdout: - print(json.dumps(output_json, indent=2)) - else: - inventory_filename = DefaultFilenameGenerator(structure_path).inventory_filename - with open(inventory_filename, "w") as f: - f.write(json.dumps(output_json, indent=2)) - logger.info(f"Results written to {inventory_filename}") - - -@click.command() -@click.argument("structure_path", type=PathToStructureFile()) -@click.option( - "--bond-threshold", - default=5.0, - type=float, - help="Threshold for interchain bond detection (default: 5 Å)", -) -@click.option("--no-atom-ids", is_flag=True, help="do not output the atom indice array") -@click.option("-s", "--stdout", is_flag=True, help="Output the results to stdout in JSON format") -def cli(structure_path, bond_threshold, no_atom_ids, stdout): - """Command-line interface for processing structure files.""" - logfile = DefaultFilenameGenerator(structure_path).log_filename - setup_logging(logfile) - main(structure_path, bond_threshold, no_atom_ids, stdout) diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py new file mode 100644 index 0000000..fab9282 --- /dev/null +++ b/src/grodecoder/cli/__init__.py @@ -0,0 +1,41 @@ +import click + +from ..main import main as grodecoder_main +from .args import Arguments as CliArgs +from .args import CoordinatesFile, StructureFile +from ..logging import setup_logging + + +@click.command() +@click.argument("structure_file", type=StructureFile) +@click.argument("coordinates_file", type=CoordinatesFile, required=False) +@click.option( + "--bond-threshold", + default=5.0, + type=float, + help="Threshold for interchain bond detection (default: 5 Å)", +) +@click.option("--no-atom-ids", is_flag=True, help="do not output the atom indice array") +@click.option( + "-s", + "--stdout", + metavar="print_to_stdout", + is_flag=True, + help="Output the results to stdout in JSON format", +) +@click.option("-v", "--verbose", is_flag=True, help="show debug messages") +def cli(**kwargs): + """Command-line interface for processing structure files.""" + args = CliArgs( + structure_file=kwargs["structure_file"], + coordinates_file=kwargs["coordinates_file"], + no_atom_ids=kwargs["no_atom_ids"], + print_to_stdout=kwargs["stdout"], + ) + + logfile = args.get_log_filename() + setup_logging(logfile, kwargs["verbose"]) + grodecoder_main(args) + + +__all__ = ["cli"] diff --git a/src/grodecoder/cli/args.py b/src/grodecoder/cli/args.py new file mode 100644 index 0000000..68f7c75 --- /dev/null +++ b/src/grodecoder/cli/args.py @@ -0,0 +1,85 @@ +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar + +from loguru import logger + +DEFAULT_OUTPUT_STEM_SUFFIX = "_grodecoder" + + +def _fatal_error(msg: str, status: int = 1): + """Prints an error message and exits with status `status`.""" + logger.critical(msg) + sys.exit(status) + + +@dataclass +class InputFile: + path: Path + valid_extensions: ClassVar[set[str]] + + def __post_init__(self): + # Ensures paths are pathlib.Path instances. + self.path = Path(self.path) + + # Ensures paths are valid files. + path = self.path + if not path.exists(): + _fatal_error(f"'{path}' does not exist") + if not path.is_file(): + _fatal_error(f"'{path}' is not a file") + if path.suffix not in self.valid_extensions: + _fatal_error(f"'{path}' has an invalid extension (valid extensions are {self.valid_extensions})") + return path + + @property + def extension(self) -> str: + return self.path.suffix + + @property + def stem(self) -> str: + return self.path.stem + + +@dataclass +class StructureFile(InputFile): + valid_extensions: ClassVar[set[str]] = {".gro", ".pdb", ".tpr", ".psf"} + + +@dataclass +class CoordinatesFile(InputFile): + valid_extensions: ClassVar[set[str]] = {".gro", ".pdb", ".tpr", ".psf", ".coor"} + + +@dataclass +class Arguments: + """Holds command-line arguments. + + Attrs: + structure_file (Path): Path to the structure file. + coordinates_file (Path): Path to the coordinates file. + bond_threshold (float): Threshold for interchain bond detection. + no_atom_ids (bool): If True, use compact serialization (no atom indices). + print_to_stdout (bool): Whether to output results to stdout. + """ + + structure_file: StructureFile + coordinates_file: CoordinatesFile | None = None + bond_threshold: float = 5.0 + no_atom_ids: bool = True + print_to_stdout: bool = False + + def get_log_filename(self) -> Path: + return generate_output_log_path(self.structure_file.stem) + + def get_inventory_filename(self) -> Path: + return generate_output_inventory_path(self.structure_file.stem) + + +def generate_output_inventory_path(stem: str) -> Path: + return Path(stem + DEFAULT_OUTPUT_STEM_SUFFIX + ".json") + + +def generate_output_log_path(stem: str) -> Path: + return Path(stem + DEFAULT_OUTPUT_STEM_SUFFIX + ".log") diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py new file mode 100644 index 0000000..986e45b --- /dev/null +++ b/src/grodecoder/core.py @@ -0,0 +1,32 @@ +from datetime import datetime + +from loguru import logger + +from ._typing import PathLike, UniverseLike +from .identifier import identify +from .io import read_universe +from .models import Decoded +from .toputils import guess_resolution + + +def _now() -> str: + """Returns the current date and time formatted string.""" + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded: + """Decodes the universe into an inventory of segments.""" + return Decoded( + inventory=identify(universe, bond_threshold=bond_threshold), + resolution=guess_resolution(universe), + ) + + +def decode_structure( + structure_path: PathLike, coordinates_path: PathLike | None = None, bond_threshold: float = 5.0 +) -> Decoded: + """Reads a structure file and decodes it into an inventory of segments.""" + universe = read_universe(structure_path, coordinates_path) + assert universe.atoms is not None # required by type checker for some reason + logger.debug(f"Universe has {len(universe.atoms):,d} atoms") + return decode(universe, bond_threshold=bond_threshold) diff --git a/src/grodecoder/databases/__init__.py b/src/grodecoder/databases/__init__.py index de174db..082b311 100644 --- a/src/grodecoder/databases/__init__.py +++ b/src/grodecoder/databases/__init__.py @@ -1,29 +1,32 @@ +from functools import lru_cache + from .api import ( + DATABASES_DATA_PATH, get_amino_acid_definitions, - get_amino_acid_names, get_amino_acid_name_map, + get_amino_acid_names, get_ion_definitions, get_ion_names, get_lipid_definitions, get_lipid_names, get_nucleotide_definitions, - get_nucleotide_names, get_nucleotide_name_map, + get_nucleotide_names, get_other_definitions, get_solvent_definitions, get_solvent_names, - DATABASES_DATA_PATH, ) from .models import ( - ResidueDefinition, AminoAcid, Ion, Lipid, Nucleotide, + ResidueDefinition, Solvent, ) __all__ = [ + "get_database_version", "get_amino_acid_definitions", "get_amino_acid_names", "get_amino_acid_name_map", @@ -46,5 +49,7 @@ ] -with open(DATABASES_DATA_PATH / "version.txt") as f: - __version__ = f.read().strip() +@lru_cache(maxsize=1) +def get_database_version() -> str: + with open(DATABASES_DATA_PATH / "version.txt") as f: + return f.read().strip() diff --git a/src/grodecoder/databases/api.py b/src/grodecoder/databases/api.py index eb4b8ef..5a4df93 100644 --- a/src/grodecoder/databases/api.py +++ b/src/grodecoder/databases/api.py @@ -1,9 +1,8 @@ -import itertools import json from collections import Counter -from dataclasses import dataclass +from functools import lru_cache from pathlib import Path -from typing import Iterable, TypeVar +from typing import TypeVar from loguru import logger @@ -96,18 +95,32 @@ def read_csml_database() -> list[csml.Residue]: return _read_database(CSML_DB_PATH, csml.Residue) -ION_DB: list[Ion] = read_ion_database() -SOLVENT_DB: list[Solvent] = read_solvent_database() -AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database() -NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database() +# ION_DB: list[Ion] = read_ion_database() +# SOLVENT_DB: list[Solvent] = read_solvent_database() +# AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database() +# NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database() +# MAD_DB: list[mad.Residue] = read_mad_database() +# CSML_DB: list[csml.Residue] = read_csml_database() +# LIPID_DB: list[Lipid] = _build_lipid_db() +# OTHER_DB: list[Residue] = _build_other_db() -MAD_DB: list[mad.Residue] = read_mad_database() -CSML_DB: list[csml.Residue] = read_csml_database() + +@lru_cache(maxsize=1) +def _get_csml_databse(): + return read_csml_database() + + +@lru_cache(maxsize=1) +def _get_mad_databse(): + return read_mad_database() def _build_lipid_db() -> list[Lipid]: - _mad_lipid_resnames = {item.alias for item in MAD_DB if item.family == mad.ResidueFamily.LIPID} - _csml_lipid_resnames = {residue.name for residue in CSML_DB if residue.family == csml.ResidueFamily.LIPID} + mad_db = _get_mad_databse() + csml_db = _get_csml_databse() + + _mad_lipid_resnames = {item.alias for item in mad_db if item.family == mad.ResidueFamily.LIPID} + _csml_lipid_resnames = {residue.name for residue in csml_db if residue.family == csml.ResidueFamily.LIPID} if False: # IMPORTANT: @@ -121,27 +134,27 @@ def _build_lipid_db() -> list[Lipid]: db = { item.alias: Lipid(description=item.name, residue_name=item.alias) - for item in MAD_DB + for item in mad_db if item.family == mad.ResidueFamily.LIPID } db.update( { residue.name: Lipid(description=residue.description, residue_name=residue.name) - for residue in CSML_DB + for residue in csml_db if residue.family == csml.ResidueFamily.LIPID } ) return list(db.values()) -LIPID_DB: list[Lipid] = _build_lipid_db() - - def _build_other_db() -> list[Residue]: """Builds a database of other residues that are not ions, solvents, amino acids, or nucleotides.""" + mad_db = _get_mad_databse() + csml_db = _get_csml_databse() + csml_other = { residue.name: Residue(residue_name=residue.name, description=residue.description) - for residue in CSML_DB + for residue in csml_db if residue.family not in { csml.ResidueFamily.PROTEIN, @@ -154,7 +167,7 @@ def _build_other_db() -> list[Residue]: mad_other = { residue.alias: Residue(residue_name=residue.alias, description=residue.name) - for residue in MAD_DB + for residue in mad_db if residue.family not in { mad.ResidueFamily.PROTEIN, @@ -168,124 +181,72 @@ def _build_other_db() -> list[Residue]: return list(by_name.values()) -OTHER_DB: list[Residue] = _build_other_db() - - -def get_other_definitions() -> list[Residue]: - """Returns the definitions of other residues in the database.""" - return OTHER_DB - - +@lru_cache(maxsize=1) def get_ion_definitions() -> list[Ion]: """Returns the definitions of the ions in the database.""" - return ION_DB + return read_ion_database() +@lru_cache(maxsize=1) def get_solvent_definitions() -> list[Solvent]: """Returns the definitions of the solvents in the database.""" - return SOLVENT_DB + return read_solvent_database() +@lru_cache(maxsize=1) def get_amino_acid_definitions() -> list[AminoAcid]: """Returns the definitions of the amino acids in the database.""" - return AMINO_ACIDS_DB - - -def get_amino_acid_name_map() -> dict[str, str]: - """Returns a mapping of amino acid 3-letter names to 1-letter names.""" - return {aa.long_name: aa.short_name for aa in AMINO_ACIDS_DB} - - -def get_nucleotide_name_map() -> dict[str, str]: - """Returns a mapping of nucleotide 3-letter names to 1-letter names.""" - return {nucleotide.residue_name: nucleotide.short_name for nucleotide in NUCLEOTIDES_DB} + return read_amino_acid_database() +@lru_cache(maxsize=1) def get_nucleotide_definitions() -> list[Nucleotide]: """Returns the definitions of the nucleotides in the database.""" - return NUCLEOTIDES_DB + return read_nucleotide_database() +@lru_cache(maxsize=1) def get_lipid_definitions() -> list[Lipid]: """Returns the definitions of the lipids in the database.""" - return LIPID_DB + return _build_lipid_db() + + +@lru_cache(maxsize=1) +def get_other_definitions() -> list[Residue]: + """Returns the definitions of other residues in the database.""" + return _build_other_db() + + +def get_amino_acid_name_map() -> dict[str, str]: + """Returns a mapping of amino acid 3-letter names to 1-letter names.""" + return {aa.long_name: aa.short_name for aa in get_amino_acid_definitions()} + + +def get_nucleotide_name_map() -> dict[str, str]: + """Returns a mapping of nucleotide 3-letter names to 1-letter names.""" + return {nucleotide.residue_name: nucleotide.short_name for nucleotide in get_nucleotide_definitions()} def get_ion_names() -> set[str]: """Returns the names of the ions in the database.""" - return set(ion.residue_name for ion in ION_DB) + return set(ion.residue_name for ion in get_ion_definitions()) def get_solvent_names() -> set[str]: """Returns the names of the solvents in the database.""" - return set(solvent.residue_name for solvent in SOLVENT_DB) + return set(solvent.residue_name for solvent in get_solvent_definitions()) def get_amino_acid_names() -> set[str]: """Returns the names of the amino acids in the database.""" - return set(aa.long_name for aa in AMINO_ACIDS_DB) + return set(aa.long_name for aa in get_amino_acid_definitions()) def get_nucleotide_names() -> set[str]: """Returns the names of the nucleotides in the database.""" - return set(nucleotide.residue_name for nucleotide in NUCLEOTIDES_DB) + return set(nucleotide.residue_name for nucleotide in get_nucleotide_definitions()) def get_lipid_names() -> set[str]: """Returns the names of the lipids in the database.""" - return set(lipid.residue_name for lipid in LIPID_DB) - - -@dataclass(frozen=True) -class ResidueDatabase: - """Database of residues.""" - - ions: list[Ion] - solvents: list[Solvent] - amino_acids: list[AminoAcid] - nucleotides: list[Nucleotide] - - def __post_init__(self): - names = { - "ions": {ion.residue_name for ion in self.ions}, - "solvents": {solvent.residue_name for solvent in self.solvents}, - "amino_acids": {aa.long_name for aa in self.amino_acids}, - "nucleotides": {nucleotide.residue_name for nucleotide in self.nucleotides}, - } - - combinations = itertools.combinations(names.keys(), 2) - for lhs, rhs in combinations: - duplicates = names[lhs].intersection(names[rhs]) - if duplicates: - logger.warning( - f"Residue names {duplicates} are defined in multiple families: {lhs} and {rhs}" - ) - - -class ResidueNotFound(Exception): - """Raised when a residue with a given name and atom names is not found in the database.""" - - -class DuplicateResidue(Exception): - """Raised when a residue with a given name and atom names is defined multiple times in the database.""" - - -def _find_using_atom_names( - residue_name: str, atom_names: Iterable[str], database: list[Ion | Solvent] -) -> Ion | Solvent | None: - candidate_residues = [ion for ion in database if ion.residue_name == residue_name] - if not candidate_residues: - return None - - actual_atom_names = set(atom_names) - matching_residues = [ion for ion in candidate_residues if set(ion.atom_names) == actual_atom_names] - - if len(matching_residues) == 0: - raise ResidueNotFound(f"No residue '{residue_name}' found with atom names {actual_atom_names}") - - elif len(matching_residues) > 1: - raise DuplicateResidue( - f"Multiple residues '{residue_name}' found with atom names {actual_atom_names}" - ) - - return matching_residues[0] + return set(lipid.residue_name for lipid in get_lipid_definitions()) diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py index e9fca28..f7d4ee5 100644 --- a/src/grodecoder/identifier.py +++ b/src/grodecoder/identifier.py @@ -96,7 +96,7 @@ def _get_protein_segments(atoms: AtomGroup, bond_threshold: float = 5.0) -> list """Returns the protein segments in the universe.""" protein = _select_protein(atoms) return [ - Segment(atoms=atoms, sequence=toputils.sequence(atoms), molecular_type=MolecularType.PROTEIN) + Segment(atoms=atoms, molecular_type=MolecularType.PROTEIN) for atoms in _iter_chains(protein, bond_threshold) ] @@ -105,7 +105,7 @@ def _get_nucleic_segments(atoms: AtomGroup, bond_threshold: float = 5.0) -> list """Returns the nucleic acid segments in the universe.""" nucleic = _select_nucleic(atoms) return [ - Segment(atoms=atoms, sequence=toputils.sequence(atoms), molecular_type=MolecularType.NUCLEIC) + Segment(atoms=atoms, molecular_type=MolecularType.NUCLEIC) for atoms in _iter_chains(nucleic, bond_threshold) ] @@ -164,6 +164,10 @@ def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory: # Remove identified atoms from the universe along the way to avoid double counting (e.g. # 'MET' residues are counted first in the protein, then removed so not counted elsewhere). + # All ty: ignore[invalid-argument-type] in this block fix ty clear mistake: + # Expected `list[HasAtoms]`, found `list[Segment]` + # while `Segment` clearly satisfies the `HasAtoms` Protocol. + protein = _get_protein_segments(universe, bond_threshold=bond_threshold) _log_identified_segments(protein, "protein") universe = _remove_identified_atoms(universe, protein) # ty: ignore[invalid-argument-type] diff --git a/src/grodecoder/io.py b/src/grodecoder/io.py new file mode 100644 index 0000000..d8ff23b --- /dev/null +++ b/src/grodecoder/io.py @@ -0,0 +1,32 @@ +"""Grodecoder read/write functions""" + +import json + +import MDAnalysis as mda +from loguru import logger +from MDAnalysis.exceptions import NoDataError + +from ._typing import PathLike +from .models import GrodecoderRunOutputRead + + +def read_universe(structure_path: PathLike, coordinates_path: PathLike | None = None) -> mda.Universe: + """Reads a structure file.""" + source = (structure_path, coordinates_path) if coordinates_path else (structure_path,) + source_str = ", ".join(str(s) for s in source) + logger.debug(f"Reading universe from {source_str}") + universe: mda.Universe | None = None + try: + universe = mda.Universe(*source) + except Exception as e: + raise IOError("MDAnalysis error while reading universe") from e + + if not hasattr(universe, "trajectory"): + raise NoDataError(f"no coordinates read from {source_str}") + + return universe + + +def read_grodecoder_output(path: PathLike) -> GrodecoderRunOutputRead: + with open(path) as fileobj: + return GrodecoderRunOutputRead.model_validate(json.load(fileobj)) diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py new file mode 100644 index 0000000..a32df47 --- /dev/null +++ b/src/grodecoder/logging.py @@ -0,0 +1,34 @@ +"""grodecoder logging utilities.""" + +import sys +import warnings +from pathlib import Path + +from loguru import logger + + +def setup_logging(logfile: Path, debug: bool = False): + """Sets up logging configuration.""" + fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}" + level = "DEBUG" if debug else "INFO" + logger.remove() + logger.add(sys.stderr, level=level, format=fmt, colorize=True) + logger.add(logfile, level=level, format=fmt, colorize=False, mode="w") + + # Sets up loguru to capture warnings (typically MDAnalysis warnings) + def showwarning(message, *args, **kwargs): + logger.opt(depth=2).warning(message) + + warnings.showwarning = showwarning # ty: ignore invalid-assignment + + +def is_logging_debug() -> bool: + """Returns True if at least one logging handler is set to level DEBUG.""" + return "DEBUG" in get_logging_level() + + +def get_logging_level() -> list[str]: + """Returns the list of logging level names (one value per handler).""" + core_logger = logger._core # ty: ignore unresolved-attribute + level_dict = {level.no: level.name for level in core_logger.levels.values()} + return [level_dict[h.levelno] for h in core_logger.handlers.values()] diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py new file mode 100644 index 0000000..d9e99d2 --- /dev/null +++ b/src/grodecoder/main.py @@ -0,0 +1,58 @@ +import hashlib +import json +import time +from typing import TYPE_CHECKING + +from loguru import logger + +from ._typing import PathLike +from .core import decode_structure +from .databases import get_database_version +from .models import GrodecoderRunOutput +from .version import get_version + +if TYPE_CHECKING: + from .cli.args import Arguments as CliArgs + + +def _get_checksum(structure_path: PathLike) -> str: + """Computes a checksum for the structure file.""" + with open(structure_path, "rb") as f: + return hashlib.md5(f.read()).hexdigest() + + +def main(args: "CliArgs"): + """Main function to process a structure file and count the molecules.""" + start_time = time.perf_counter_ns() + structure_path = args.structure_file.path + coordinates_path = args.coordinates_file.path if args.coordinates_file else None + + logger.info(f"Processing structure file: {structure_path}") + + # Decoding. + decoded = decode_structure( + structure_path, coordinates_path=coordinates_path, bond_threshold=args.bond_threshold + ) + + output = GrodecoderRunOutput( + decoded=decoded, + structure_file_checksum=_get_checksum(structure_path), + database_version=get_database_version(), + grodecoder_version=get_version(), + ) + + # Serialization. + serialization_mode = "compact" if args.no_atom_ids else "full" + + # Updates run time as late as possible. + output_json = output.model_dump(context={"serialization_mode": serialization_mode}) + output_json["runtime_in_seconds"] = (time.perf_counter_ns() - start_time) / 1e9 + + # Output results: to stdout or writes to a file. + if args.print_to_stdout: + print(json.dumps(output_json, indent=2)) + else: + inventory_filename = args.get_inventory_filename() + with open(inventory_filename, "w") as f: + f.write(json.dumps(output_json, indent=2)) + logger.info(f"Results written to {inventory_filename}") diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py index 252d557..1803ea1 100644 --- a/src/grodecoder/models.py +++ b/src/grodecoder/models.py @@ -1,4 +1,5 @@ from __future__ import annotations +from pydantic import model_validator from enum import StrEnum from typing import Protocol @@ -15,6 +16,8 @@ model_serializer, ) +from . import toputils + class MolecularResolution(StrEnum): COARSE_GRAINED = "coarse-grained" @@ -111,11 +114,36 @@ def serialize(self, handler: SerializerFunctionWrapHandler, info: SerializationI return self_data -class SmallMolecule(FrozenWithAtoms): +class MolecularTypeMixin(BaseModel): + molecular_type: MolecularType + + def is_ion(self) -> bool: + return self.molecular_type == MolecularType.ION + + def is_lipid(self) -> bool: + return self.molecular_type == MolecularType.LIPID + + def is_solvent(self) -> bool: + return self.molecular_type == MolecularType.SOLVENT + + def is_unknown(self) -> bool: + return self.molecular_type == MolecularType.UNKNOWN + + def is_other(self) -> bool: + """Alias for MolecularTypeMixin.is_unknown()""" + return self.is_unknown() + + def is_protein(self) -> bool: + return self.molecular_type == MolecularType.PROTEIN + + def is_nucleic(self) -> bool: + return self.molecular_type == MolecularType.NUCLEIC + + +class SmallMolecule(MolecularTypeMixin, FrozenWithAtoms): """Small molecules are defined as residues with a single residue name.""" description: str - molecular_type: MolecularType @computed_field @property @@ -124,11 +152,13 @@ def name(self) -> str: return self.atoms.residues[0].resname -class Segment(FrozenWithAtoms): +class Segment(MolecularTypeMixin, FrozenWithAtoms): """A segment is a group of atoms that are connected.""" - sequence: str - molecular_type: MolecularType # likely to be protein or nucleic acid + @computed_field + @property + def sequence(self) -> str: + return toputils.sequence(self.atoms) class Inventory(FrozenModel): @@ -175,16 +205,22 @@ class BaseModelWithAtomsRead(FrozenModel): number_of_atoms: int number_of_residues: int + @model_validator(mode="after") + def check_number_of_atoms_is_valid(self): + if len(self.atoms) != self.number_of_atoms: + raise ValueError( + f"field `number_of_atoms` ({self.number_of_atoms}) does not match number of atom ids ({len(self.atoms)})" + ) + return self + -class SmallMoleculeRead(BaseModelWithAtomsRead): +class SmallMoleculeRead(MolecularTypeMixin, BaseModelWithAtomsRead): name: str description: str - molecular_type: MolecularType -class SegmentRead(BaseModelWithAtomsRead): +class SegmentRead(MolecularTypeMixin, BaseModelWithAtomsRead): sequence: str - molecular_type: MolecularType class InventoryRead(FrozenModel): @@ -192,6 +228,15 @@ class InventoryRead(FrozenModel): segments: list[SegmentRead] total_number_of_atoms: int + @model_validator(mode="after") + def check_total_number_of_atoms(self): + n = sum(item.number_of_atoms for item in self.small_molecules + self.segments) + if self.total_number_of_atoms != n: + raise ValueError( + f"field `total_number_of_atoms` ({self.total_number_of_atoms}) does not add up with the rest of the inventory (found {n} atoms)" + ) + return self + class DecodedRead(FrozenModel): inventory: InventoryRead diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py index abea32d..878c469 100644 --- a/src/grodecoder/toputils.py +++ b/src/grodecoder/toputils.py @@ -85,8 +85,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int Example ------- - >>> from grodecoder import read_structure - >>> universe = read_structure("3EAM.pdb") + >>> universe = MDAnalysis.Universe("path/to/structure_file.pdb") >>> protein = universe.select_atoms("protein") >>> chains = detect_chains(protein) >>> for start, end in chains: diff --git a/src/grodecoder/version.py b/src/grodecoder/version.py new file mode 100644 index 0000000..6c0c00c --- /dev/null +++ b/src/grodecoder/version.py @@ -0,0 +1,4 @@ +__version__ = "0.0.1" + +def get_version() -> str: + return __version__ diff --git a/tests/test_identifier_README.md b/tests/test_identifier_README.md deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..6b5eba9 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,53 @@ +"""Tests for grodecoder.models.""" + +import json +from pathlib import Path + +import pytest + +from grodecoder.models import BaseModelWithAtomsRead, InventoryRead + +TEST_DATA_ROOT_DIR = Path(__file__).parent / "data" / "regression_data" +TEST_DATA_EXPECTED_RESULTS_DIR = TEST_DATA_ROOT_DIR / "expected_results" +EXPECTED_INVENTORY_FILES = list(TEST_DATA_EXPECTED_RESULTS_DIR.glob("*.json")) + + +class TestBaseModelWithAtomsRead: + def test_number_of_atoms_validation_success(self): + try: + BaseModelWithAtomsRead( + atoms=[1, 2, 3], + number_of_atoms=3, + number_of_residues=1, + ) + except Exception as exc: + assert False, f"exception was raised: {exc}" + + def test_number_of_atoms_validation_fails(self): + with pytest.raises(ValueError): + BaseModelWithAtomsRead( + atoms=[1, 2, 3], + number_of_atoms=2, # wrong number of atoms + number_of_residues=1, + ) + + +def _read_json(path: str) -> dict: + with open(path, "rb") as f: + return json.load(f) + + +class TestInventoryRead: + def test_total_number_of_atoms_success(self): + raw_json = _read_json(EXPECTED_INVENTORY_FILES[0]) + try: + InventoryRead.model_validate(raw_json["inventory"]) + except Exception as exc: + assert False, f"exception was raised: {exc}" + + def test_total_number_of_atoms_fail(self): + raw_json = _read_json(EXPECTED_INVENTORY_FILES[0]) + assert "total_number_of_atoms" in raw_json.get("inventory", {}) + raw_json["inventory"]["total_number_of_atoms"] += 1000 # sets this field to wrong value + with pytest.raises(ValueError, match=r"field `total_number_of_atoms` \([0-9]+\) does not add up"): + InventoryRead.model_validate(raw_json["inventory"])