diff --git a/README.md b/README.md
index 895b10b..2945a9b 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,9 @@ uv sync
# Analyze a structure file
uv run grodecoder path/to/structure.gro
+# Analyze a topology file together with a coordinates file
+uv run grodecoder path/to/topology.psf path/to/coordinates.coor
+
# Output to stdout with compact format
uv run grodecoder structure.pdb --compact --stdout
@@ -101,6 +104,40 @@ GROdecoder produces detailed JSON inventories with the following structure:
## 🔧 Advanced Features
+### Read back a GROdecoder inventory file
+
+Reading a GROdecoder inventory file lets you access the different parts of a system
+without having to identify them again:
+
+```python
+from grodecoder import read_grodecoder_output
+
+gro_results = read_grodecoder_output("1BRS_grodecoder.json")
+
+# Print the sequence of protein segment only.
+for segment in gro_results.decoded.inventory.segments:
+ if segment.is_protein():
+ print(segment.sequence)
+```
+
+In conjunction with the structure file, we can use the GROdecoder output file to access the different
+parts of the system, as identified by GROdecoder:
+
+```python
+import MDAnalysis
+from grodecoder import read_grodecoder_output
+
+
+universe = MDAnalysis.Universe("tests/data/1BRS.pdb")
+gro_results = read_grodecoder_output("1BRS_grodecoder.json")
+
+# Prints the center of mass of each protein segment.
+for segment in gro_results.decoded.inventory.segments:
+ if segment.is_protein():
+ seg: MDAnalysis.AtomGroup = universe.atoms[segment.atoms]
+ print(seg.center_of_mass())
+```
+
### Chain Detection
GROdecoder uses sophisticated distance-based algorithms to detect protein and nucleic acid chains:
diff --git a/pyproject.toml b/pyproject.toml
index 7024310..1395ed7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,3 @@ line-length = 110
[tool.pytest.ini_options]
addopts = "-ra"
-
-[tool.mypy]
-ignore_missing_imports = true
diff --git a/src/grodecoder/__init__.py b/src/grodecoder/__init__.py
index dbe9651..52a7c9a 100644
--- a/src/grodecoder/__init__.py
+++ b/src/grodecoder/__init__.py
@@ -1,63 +1,13 @@
-import json
-from datetime import datetime
-
-import MDAnalysis as mda
-from loguru import logger
-
-from . import databases, toputils
-from ._typing import AtomGroup, Json, PathLike, Residue, Universe, UniverseLike
-from .identifier import identify
+from .core import decode, decode_structure
from .models import Decoded, GrodecoderRunOutput, GrodecoderRunOutputRead
+from .io import read_grodecoder_output, read_universe
__all__ = [
- "databases",
- "identify",
- "toputils",
- "read_structure",
- "AtomGroup",
+ "decode",
+ "decode_structure",
+ "read_grodecoder_output",
+ "read_universe",
"Decoded",
"GrodecoderRunOutput",
"GrodecoderRunOutputRead",
- "Json",
- "PathLike",
- "Residue",
- "Universe",
- "UniverseLike",
]
-
-__version__ = "0.0.1"
-
-
-def _now() -> str:
- """Returns the current date and time formatted string."""
- return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
-
-def read_structure(path: PathLike, psf_path: PathLike | None = None) -> Universe:
- """Reads a structure file."""
- if psf_path:
- return mda.Universe(path, psf_path)
- return mda.Universe(path)
-
-
-def read_grodecoder_output(path: PathLike) -> GrodecoderRunOutputRead:
- with open(path) as fileobj:
- return GrodecoderRunOutputRead.model_validate(json.load(fileobj))
-
-
-def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded:
- """Decodes the universe into an inventory of segments."""
- return Decoded(
- inventory=identify(universe, bond_threshold=bond_threshold),
- resolution=toputils.guess_resolution(universe),
- )
-
-
-def decode_structure(
- path: PathLike, psf_path: PathLike | None = None, bond_threshold: float = 5.0
-) -> Decoded:
- """Reads a structure file and decodes it into an inventory of segments."""
- universe = read_structure(path, psf_path)
- assert universe.atoms is not None # required by type checker for some reason
- logger.info(f"{path}: {len(universe.atoms):,d} atoms")
- return decode(universe, bond_threshold=bond_threshold)
diff --git a/src/grodecoder/cli.py b/src/grodecoder/cli.py
deleted file mode 100644
index 1bdb1e5..0000000
--- a/src/grodecoder/cli.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import hashlib
-import json
-import sys
-import time
-from pathlib import Path
-
-import click
-from loguru import logger
-
-import grodecoder as gd
-
-
-class PathToStructureFile(click.ParamType):
- """Custom click parameter type for validating structure files."""
-
- name = "structure file"
-
- def convert(self, value, param, ctx):
- """Convert the input value to a Path object."""
- path = Path(value)
- if not path.exists():
- self.fail(f"'{path}' does not exist", param, ctx)
- if not path.is_file():
- self.fail(f"'{path}' is not a file", param, ctx)
- valid_extensions = (".gro", ".pdb", ".coor", ".crd")
- if path.suffix not in valid_extensions:
- self.fail(
- f"'{path}' has an invalid extension (valid extensions are {valid_extensions})", param, ctx
- )
- return path
-
-
-class DefaultFilenameGenerator:
- """Generates output filenames based on the structure file name."""
-
- DEFAULT_STEM_SUFFIX = "_grodecoder"
-
- def __init__(self, structure_path: Path, stem_suffix: str = DEFAULT_STEM_SUFFIX):
- self._stem = structure_path.stem
- self._stem_suffix = stem_suffix
- self._output_stem = self._stem + self._stem_suffix
-
- def _generate_output_filename(self, extension: str) -> Path:
- """Generic method to create an output filename with the given extension."""
- return Path(self._output_stem + extension)
-
- @property
- def inventory_filename(self) -> Path:
- """Generates the output JSON filename."""
- return self._generate_output_filename(".json")
-
- @property
- def log_filename(self) -> Path:
- """Generates the output log filename."""
- return self._generate_output_filename(".log")
-
-
-def setup_logging(logfile: Path, debug: bool = False):
- """Sets up logging configuration."""
- fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}"
- level = "DEBUG" if debug else "INFO"
- logger.remove()
- logger.add(sys.stderr, level=level, format=fmt, colorize=True)
- logger.add(logfile, level=level, format=fmt, colorize=False, mode="w")
-
-
-def _get_checksum(structure_path: gd.PathLike) -> str:
- """Computes a checksum for the structure file."""
- with open(structure_path, "rb") as f:
- return hashlib.md5(f.read()).hexdigest()
-
-
-def main(structure_path: Path, bond_threshold: float, compact_serialization: bool, output_to_stdout: bool):
- """Main function to process a structure file and count the molecules.
-
- Args:
- structure_path (Path): Path to the structure file.
- bond_threshold (float): Threshold for interchain bond detection.
- compact_serialization (bool): If True, use compact serialization (no atom indices).
- output_to_stdout (bool): Whether to output results to stdout.
- """
- start_time = time.perf_counter_ns()
- logger.info(f"Processing structure file: {structure_path}")
-
- # Decoding.
- decoded = gd.decode_structure(structure_path, bond_threshold=bond_threshold)
-
- output = gd.GrodecoderRunOutput(
- decoded=decoded,
- structure_file_checksum=_get_checksum(structure_path),
- database_version=gd.databases.__version__,
- grodecoder_version=gd.__version__,
- )
-
- # Serialization.
- serialization_mode = "compact" if compact_serialization else "full"
-
- # Updates run time as late as possible.
- output_json = output.model_dump(context={"serialization_mode": serialization_mode})
- output_json["runtime_in_seconds"] = (time.perf_counter_ns() - start_time) / 1e9
-
- # Output results: to stdout or writes to a file.
- if output_to_stdout:
- print(json.dumps(output_json, indent=2))
- else:
- inventory_filename = DefaultFilenameGenerator(structure_path).inventory_filename
- with open(inventory_filename, "w") as f:
- f.write(json.dumps(output_json, indent=2))
- logger.info(f"Results written to {inventory_filename}")
-
-
-@click.command()
-@click.argument("structure_path", type=PathToStructureFile())
-@click.option(
- "--bond-threshold",
- default=5.0,
- type=float,
- help="Threshold for interchain bond detection (default: 5 Ã…)",
-)
-@click.option("--no-atom-ids", is_flag=True, help="do not output the atom indice array")
-@click.option("-s", "--stdout", is_flag=True, help="Output the results to stdout in JSON format")
-def cli(structure_path, bond_threshold, no_atom_ids, stdout):
- """Command-line interface for processing structure files."""
- logfile = DefaultFilenameGenerator(structure_path).log_filename
- setup_logging(logfile)
- main(structure_path, bond_threshold, no_atom_ids, stdout)
diff --git a/src/grodecoder/cli/__init__.py b/src/grodecoder/cli/__init__.py
new file mode 100644
index 0000000..fab9282
--- /dev/null
+++ b/src/grodecoder/cli/__init__.py
@@ -0,0 +1,42 @@
+import click
+
+from ..main import main as grodecoder_main
+from .args import Arguments as CliArgs
+from .args import CoordinatesFile, StructureFile
+from ..logging import setup_logging
+
+
+@click.command()
+@click.argument("structure_file", type=StructureFile)
+@click.argument("coordinates_file", type=CoordinatesFile, required=False)
+@click.option(
+    "--bond-threshold",
+    default=5.0,
+    type=float,
+    help="Threshold for interchain bond detection (default: 5 Å)",
+)
+@click.option("--no-atom-ids", is_flag=True, help="do not output the atom indices array")
+@click.option(
+    "-s",
+    "--stdout",
+    metavar="print_to_stdout",
+    is_flag=True,
+    help="Output the results to stdout in JSON format",
+)
+@click.option("-v", "--verbose", is_flag=True, help="show debug messages")
+def cli(**kwargs):
+    """Command-line interface for processing structure files."""
+    args = CliArgs(
+        structure_file=kwargs["structure_file"],
+        coordinates_file=kwargs["coordinates_file"],
+        bond_threshold=kwargs["bond_threshold"],
+        no_atom_ids=kwargs["no_atom_ids"],
+        print_to_stdout=kwargs["stdout"],
+    )
+
+    logfile = args.get_log_filename()
+    setup_logging(logfile, kwargs["verbose"])
+    grodecoder_main(args)
+
+
+__all__ = ["cli"]
diff --git a/src/grodecoder/cli/args.py b/src/grodecoder/cli/args.py
new file mode 100644
index 0000000..68f7c75
--- /dev/null
+++ b/src/grodecoder/cli/args.py
@@ -0,0 +1,85 @@
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar
+
+from loguru import logger
+
+DEFAULT_OUTPUT_STEM_SUFFIX = "_grodecoder"
+
+
+def _fatal_error(msg: str, status: int = 1):
+ """Prints an error message and exits with status `status`."""
+ logger.critical(msg)
+ sys.exit(status)
+
+
+@dataclass
+class InputFile:
+ path: Path
+ valid_extensions: ClassVar[set[str]]
+
+ def __post_init__(self):
+ # Ensures paths are pathlib.Path instances.
+ self.path = Path(self.path)
+
+ # Ensures paths are valid files.
+ path = self.path
+ if not path.exists():
+ _fatal_error(f"'{path}' does not exist")
+ if not path.is_file():
+ _fatal_error(f"'{path}' is not a file")
+ if path.suffix not in self.valid_extensions:
+ _fatal_error(f"'{path}' has an invalid extension (valid extensions are {self.valid_extensions})")
+ return path
+
+ @property
+ def extension(self) -> str:
+ return self.path.suffix
+
+ @property
+ def stem(self) -> str:
+ return self.path.stem
+
+
+@dataclass
+class StructureFile(InputFile):
+ valid_extensions: ClassVar[set[str]] = {".gro", ".pdb", ".tpr", ".psf"}
+
+
+@dataclass
+class CoordinatesFile(InputFile):
+ valid_extensions: ClassVar[set[str]] = {".gro", ".pdb", ".tpr", ".psf", ".coor"}
+
+
+@dataclass
+class Arguments:
+ """Holds command-line arguments.
+
+ Attrs:
+ structure_file (Path): Path to the structure file.
+ coordinates_file (Path): Path to the coordinates file.
+ bond_threshold (float): Threshold for interchain bond detection.
+ no_atom_ids (bool): If True, use compact serialization (no atom indices).
+ print_to_stdout (bool): Whether to output results to stdout.
+ """
+
+ structure_file: StructureFile
+ coordinates_file: CoordinatesFile | None = None
+ bond_threshold: float = 5.0
+    no_atom_ids: bool = False  # matches the CLI flag default (full output with atom ids)
+ print_to_stdout: bool = False
+
+ def get_log_filename(self) -> Path:
+ return generate_output_log_path(self.structure_file.stem)
+
+ def get_inventory_filename(self) -> Path:
+ return generate_output_inventory_path(self.structure_file.stem)
+
+
+def generate_output_inventory_path(stem: str) -> Path:
+ return Path(stem + DEFAULT_OUTPUT_STEM_SUFFIX + ".json")
+
+
+def generate_output_log_path(stem: str) -> Path:
+ return Path(stem + DEFAULT_OUTPUT_STEM_SUFFIX + ".log")
diff --git a/src/grodecoder/core.py b/src/grodecoder/core.py
new file mode 100644
index 0000000..986e45b
--- /dev/null
+++ b/src/grodecoder/core.py
@@ -0,0 +1,32 @@
+from datetime import datetime
+
+from loguru import logger
+
+from ._typing import PathLike, UniverseLike
+from .identifier import identify
+from .io import read_universe
+from .models import Decoded
+from .toputils import guess_resolution
+
+
+def _now() -> str:
+ """Returns the current date and time formatted string."""
+ return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+
+def decode(universe: UniverseLike, bond_threshold: float = 5.0) -> Decoded:
+ """Decodes the universe into an inventory of segments."""
+ return Decoded(
+ inventory=identify(universe, bond_threshold=bond_threshold),
+ resolution=guess_resolution(universe),
+ )
+
+
+def decode_structure(
+ structure_path: PathLike, coordinates_path: PathLike | None = None, bond_threshold: float = 5.0
+) -> Decoded:
+ """Reads a structure file and decodes it into an inventory of segments."""
+ universe = read_universe(structure_path, coordinates_path)
+ assert universe.atoms is not None # required by type checker for some reason
+ logger.debug(f"Universe has {len(universe.atoms):,d} atoms")
+ return decode(universe, bond_threshold=bond_threshold)
diff --git a/src/grodecoder/databases/__init__.py b/src/grodecoder/databases/__init__.py
index de174db..082b311 100644
--- a/src/grodecoder/databases/__init__.py
+++ b/src/grodecoder/databases/__init__.py
@@ -1,29 +1,32 @@
+from functools import lru_cache
+
from .api import (
+ DATABASES_DATA_PATH,
get_amino_acid_definitions,
- get_amino_acid_names,
get_amino_acid_name_map,
+ get_amino_acid_names,
get_ion_definitions,
get_ion_names,
get_lipid_definitions,
get_lipid_names,
get_nucleotide_definitions,
- get_nucleotide_names,
get_nucleotide_name_map,
+ get_nucleotide_names,
get_other_definitions,
get_solvent_definitions,
get_solvent_names,
- DATABASES_DATA_PATH,
)
from .models import (
- ResidueDefinition,
AminoAcid,
Ion,
Lipid,
Nucleotide,
+ ResidueDefinition,
Solvent,
)
__all__ = [
+ "get_database_version",
"get_amino_acid_definitions",
"get_amino_acid_names",
"get_amino_acid_name_map",
@@ -46,5 +49,7 @@
]
-with open(DATABASES_DATA_PATH / "version.txt") as f:
- __version__ = f.read().strip()
+@lru_cache(maxsize=1)
+def get_database_version() -> str:
+ with open(DATABASES_DATA_PATH / "version.txt") as f:
+ return f.read().strip()
diff --git a/src/grodecoder/databases/api.py b/src/grodecoder/databases/api.py
index eb4b8ef..5a4df93 100644
--- a/src/grodecoder/databases/api.py
+++ b/src/grodecoder/databases/api.py
@@ -1,9 +1,8 @@
-import itertools
import json
from collections import Counter
-from dataclasses import dataclass
+from functools import lru_cache
from pathlib import Path
-from typing import Iterable, TypeVar
+from typing import TypeVar
from loguru import logger
@@ -96,18 +95,32 @@ def read_csml_database() -> list[csml.Residue]:
return _read_database(CSML_DB_PATH, csml.Residue)
-ION_DB: list[Ion] = read_ion_database()
-SOLVENT_DB: list[Solvent] = read_solvent_database()
-AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database()
-NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database()
+# ION_DB: list[Ion] = read_ion_database()
+# SOLVENT_DB: list[Solvent] = read_solvent_database()
+# AMINO_ACIDS_DB: list[AminoAcid] = read_amino_acid_database()
+# NUCLEOTIDES_DB: list[Nucleotide] = read_nucleotide_database()
+# MAD_DB: list[mad.Residue] = read_mad_database()
+# CSML_DB: list[csml.Residue] = read_csml_database()
+# LIPID_DB: list[Lipid] = _build_lipid_db()
+# OTHER_DB: list[Residue] = _build_other_db()
-MAD_DB: list[mad.Residue] = read_mad_database()
-CSML_DB: list[csml.Residue] = read_csml_database()
+
+@lru_cache(maxsize=1)
+def _get_csml_databse():
+ return read_csml_database()
+
+
+@lru_cache(maxsize=1)
+def _get_mad_databse():
+ return read_mad_database()
def _build_lipid_db() -> list[Lipid]:
- _mad_lipid_resnames = {item.alias for item in MAD_DB if item.family == mad.ResidueFamily.LIPID}
- _csml_lipid_resnames = {residue.name for residue in CSML_DB if residue.family == csml.ResidueFamily.LIPID}
+ mad_db = _get_mad_databse()
+ csml_db = _get_csml_databse()
+
+ _mad_lipid_resnames = {item.alias for item in mad_db if item.family == mad.ResidueFamily.LIPID}
+ _csml_lipid_resnames = {residue.name for residue in csml_db if residue.family == csml.ResidueFamily.LIPID}
if False:
# IMPORTANT:
@@ -121,27 +134,27 @@ def _build_lipid_db() -> list[Lipid]:
db = {
item.alias: Lipid(description=item.name, residue_name=item.alias)
- for item in MAD_DB
+ for item in mad_db
if item.family == mad.ResidueFamily.LIPID
}
db.update(
{
residue.name: Lipid(description=residue.description, residue_name=residue.name)
- for residue in CSML_DB
+ for residue in csml_db
if residue.family == csml.ResidueFamily.LIPID
}
)
return list(db.values())
-LIPID_DB: list[Lipid] = _build_lipid_db()
-
-
def _build_other_db() -> list[Residue]:
"""Builds a database of other residues that are not ions, solvents, amino acids, or nucleotides."""
+ mad_db = _get_mad_databse()
+ csml_db = _get_csml_databse()
+
csml_other = {
residue.name: Residue(residue_name=residue.name, description=residue.description)
- for residue in CSML_DB
+ for residue in csml_db
if residue.family
not in {
csml.ResidueFamily.PROTEIN,
@@ -154,7 +167,7 @@ def _build_other_db() -> list[Residue]:
mad_other = {
residue.alias: Residue(residue_name=residue.alias, description=residue.name)
- for residue in MAD_DB
+ for residue in mad_db
if residue.family
not in {
mad.ResidueFamily.PROTEIN,
@@ -168,124 +181,72 @@ def _build_other_db() -> list[Residue]:
return list(by_name.values())
-OTHER_DB: list[Residue] = _build_other_db()
-
-
-def get_other_definitions() -> list[Residue]:
- """Returns the definitions of other residues in the database."""
- return OTHER_DB
-
-
+@lru_cache(maxsize=1)
def get_ion_definitions() -> list[Ion]:
"""Returns the definitions of the ions in the database."""
- return ION_DB
+ return read_ion_database()
+@lru_cache(maxsize=1)
def get_solvent_definitions() -> list[Solvent]:
"""Returns the definitions of the solvents in the database."""
- return SOLVENT_DB
+ return read_solvent_database()
+@lru_cache(maxsize=1)
def get_amino_acid_definitions() -> list[AminoAcid]:
"""Returns the definitions of the amino acids in the database."""
- return AMINO_ACIDS_DB
-
-
-def get_amino_acid_name_map() -> dict[str, str]:
- """Returns a mapping of amino acid 3-letter names to 1-letter names."""
- return {aa.long_name: aa.short_name for aa in AMINO_ACIDS_DB}
-
-
-def get_nucleotide_name_map() -> dict[str, str]:
- """Returns a mapping of nucleotide 3-letter names to 1-letter names."""
- return {nucleotide.residue_name: nucleotide.short_name for nucleotide in NUCLEOTIDES_DB}
+ return read_amino_acid_database()
+@lru_cache(maxsize=1)
def get_nucleotide_definitions() -> list[Nucleotide]:
"""Returns the definitions of the nucleotides in the database."""
- return NUCLEOTIDES_DB
+ return read_nucleotide_database()
+@lru_cache(maxsize=1)
def get_lipid_definitions() -> list[Lipid]:
"""Returns the definitions of the lipids in the database."""
- return LIPID_DB
+ return _build_lipid_db()
+
+
+@lru_cache(maxsize=1)
+def get_other_definitions() -> list[Residue]:
+ """Returns the definitions of other residues in the database."""
+ return _build_other_db()
+
+
+def get_amino_acid_name_map() -> dict[str, str]:
+ """Returns a mapping of amino acid 3-letter names to 1-letter names."""
+ return {aa.long_name: aa.short_name for aa in get_amino_acid_definitions()}
+
+
+def get_nucleotide_name_map() -> dict[str, str]:
+ """Returns a mapping of nucleotide 3-letter names to 1-letter names."""
+ return {nucleotide.residue_name: nucleotide.short_name for nucleotide in get_nucleotide_definitions()}
def get_ion_names() -> set[str]:
"""Returns the names of the ions in the database."""
- return set(ion.residue_name for ion in ION_DB)
+ return set(ion.residue_name for ion in get_ion_definitions())
def get_solvent_names() -> set[str]:
"""Returns the names of the solvents in the database."""
- return set(solvent.residue_name for solvent in SOLVENT_DB)
+ return set(solvent.residue_name for solvent in get_solvent_definitions())
def get_amino_acid_names() -> set[str]:
"""Returns the names of the amino acids in the database."""
- return set(aa.long_name for aa in AMINO_ACIDS_DB)
+ return set(aa.long_name for aa in get_amino_acid_definitions())
def get_nucleotide_names() -> set[str]:
"""Returns the names of the nucleotides in the database."""
- return set(nucleotide.residue_name for nucleotide in NUCLEOTIDES_DB)
+ return set(nucleotide.residue_name for nucleotide in get_nucleotide_definitions())
def get_lipid_names() -> set[str]:
"""Returns the names of the lipids in the database."""
- return set(lipid.residue_name for lipid in LIPID_DB)
-
-
-@dataclass(frozen=True)
-class ResidueDatabase:
- """Database of residues."""
-
- ions: list[Ion]
- solvents: list[Solvent]
- amino_acids: list[AminoAcid]
- nucleotides: list[Nucleotide]
-
- def __post_init__(self):
- names = {
- "ions": {ion.residue_name for ion in self.ions},
- "solvents": {solvent.residue_name for solvent in self.solvents},
- "amino_acids": {aa.long_name for aa in self.amino_acids},
- "nucleotides": {nucleotide.residue_name for nucleotide in self.nucleotides},
- }
-
- combinations = itertools.combinations(names.keys(), 2)
- for lhs, rhs in combinations:
- duplicates = names[lhs].intersection(names[rhs])
- if duplicates:
- logger.warning(
- f"Residue names {duplicates} are defined in multiple families: {lhs} and {rhs}"
- )
-
-
-class ResidueNotFound(Exception):
- """Raised when a residue with a given name and atom names is not found in the database."""
-
-
-class DuplicateResidue(Exception):
- """Raised when a residue with a given name and atom names is defined multiple times in the database."""
-
-
-def _find_using_atom_names(
- residue_name: str, atom_names: Iterable[str], database: list[Ion | Solvent]
-) -> Ion | Solvent | None:
- candidate_residues = [ion for ion in database if ion.residue_name == residue_name]
- if not candidate_residues:
- return None
-
- actual_atom_names = set(atom_names)
- matching_residues = [ion for ion in candidate_residues if set(ion.atom_names) == actual_atom_names]
-
- if len(matching_residues) == 0:
- raise ResidueNotFound(f"No residue '{residue_name}' found with atom names {actual_atom_names}")
-
- elif len(matching_residues) > 1:
- raise DuplicateResidue(
- f"Multiple residues '{residue_name}' found with atom names {actual_atom_names}"
- )
-
- return matching_residues[0]
+ return set(lipid.residue_name for lipid in get_lipid_definitions())
diff --git a/src/grodecoder/identifier.py b/src/grodecoder/identifier.py
index e9fca28..f7d4ee5 100644
--- a/src/grodecoder/identifier.py
+++ b/src/grodecoder/identifier.py
@@ -96,7 +96,7 @@ def _get_protein_segments(atoms: AtomGroup, bond_threshold: float = 5.0) -> list
"""Returns the protein segments in the universe."""
protein = _select_protein(atoms)
return [
- Segment(atoms=atoms, sequence=toputils.sequence(atoms), molecular_type=MolecularType.PROTEIN)
+ Segment(atoms=atoms, molecular_type=MolecularType.PROTEIN)
for atoms in _iter_chains(protein, bond_threshold)
]
@@ -105,7 +105,7 @@ def _get_nucleic_segments(atoms: AtomGroup, bond_threshold: float = 5.0) -> list
"""Returns the nucleic acid segments in the universe."""
nucleic = _select_nucleic(atoms)
return [
- Segment(atoms=atoms, sequence=toputils.sequence(atoms), molecular_type=MolecularType.NUCLEIC)
+ Segment(atoms=atoms, molecular_type=MolecularType.NUCLEIC)
for atoms in _iter_chains(nucleic, bond_threshold)
]
@@ -164,6 +164,10 @@ def _identify(universe: UniverseLike, bond_threshold: float = 5.0) -> Inventory:
# Remove identified atoms from the universe along the way to avoid double counting (e.g.
# 'MET' residues are counted first in the protein, then removed so not counted elsewhere).
+    # All `ty: ignore[invalid-argument-type]` comments in this block work around a clear ty mistake:
+ # Expected `list[HasAtoms]`, found `list[Segment]`
+ # while `Segment` clearly satisfies the `HasAtoms` Protocol.
+
protein = _get_protein_segments(universe, bond_threshold=bond_threshold)
_log_identified_segments(protein, "protein")
universe = _remove_identified_atoms(universe, protein) # ty: ignore[invalid-argument-type]
diff --git a/src/grodecoder/io.py b/src/grodecoder/io.py
new file mode 100644
index 0000000..d8ff23b
--- /dev/null
+++ b/src/grodecoder/io.py
@@ -0,0 +1,32 @@
+"""Grodecoder read/write functions"""
+
+import json
+
+import MDAnalysis as mda
+from loguru import logger
+from MDAnalysis.exceptions import NoDataError
+
+from ._typing import PathLike
+from .models import GrodecoderRunOutputRead
+
+
+def read_universe(structure_path: PathLike, coordinates_path: PathLike | None = None) -> mda.Universe:
+ """Reads a structure file."""
+ source = (structure_path, coordinates_path) if coordinates_path else (structure_path,)
+ source_str = ", ".join(str(s) for s in source)
+ logger.debug(f"Reading universe from {source_str}")
+ universe: mda.Universe | None = None
+ try:
+ universe = mda.Universe(*source)
+ except Exception as e:
+ raise IOError("MDAnalysis error while reading universe") from e
+
+ if not hasattr(universe, "trajectory"):
+ raise NoDataError(f"no coordinates read from {source_str}")
+
+ return universe
+
+
+def read_grodecoder_output(path: PathLike) -> GrodecoderRunOutputRead:
+ with open(path) as fileobj:
+ return GrodecoderRunOutputRead.model_validate(json.load(fileobj))
diff --git a/src/grodecoder/logging.py b/src/grodecoder/logging.py
new file mode 100644
index 0000000..a32df47
--- /dev/null
+++ b/src/grodecoder/logging.py
@@ -0,0 +1,34 @@
+"""grodecoder logging utilities."""
+
+import sys
+import warnings
+from pathlib import Path
+
+from loguru import logger
+
+
+def setup_logging(logfile: Path, debug: bool = False):
+ """Sets up logging configuration."""
+ fmt = "{time:YYYY-MM-DD HH:mm:ss} {level}: {message}"
+ level = "DEBUG" if debug else "INFO"
+ logger.remove()
+ logger.add(sys.stderr, level=level, format=fmt, colorize=True)
+ logger.add(logfile, level=level, format=fmt, colorize=False, mode="w")
+
+ # Sets up loguru to capture warnings (typically MDAnalysis warnings)
+ def showwarning(message, *args, **kwargs):
+ logger.opt(depth=2).warning(message)
+
+    warnings.showwarning = showwarning  # ty: ignore[invalid-assignment]
+
+
+def is_logging_debug() -> bool:
+ """Returns True if at least one logging handler is set to level DEBUG."""
+ return "DEBUG" in get_logging_level()
+
+
+def get_logging_level() -> list[str]:
+ """Returns the list of logging level names (one value per handler)."""
+    core_logger = logger._core  # ty: ignore[unresolved-attribute]
+ level_dict = {level.no: level.name for level in core_logger.levels.values()}
+ return [level_dict[h.levelno] for h in core_logger.handlers.values()]
diff --git a/src/grodecoder/main.py b/src/grodecoder/main.py
new file mode 100644
index 0000000..d9e99d2
--- /dev/null
+++ b/src/grodecoder/main.py
@@ -0,0 +1,58 @@
+import hashlib
+import json
+import time
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+from ._typing import PathLike
+from .core import decode_structure
+from .databases import get_database_version
+from .models import GrodecoderRunOutput
+from .version import get_version
+
+if TYPE_CHECKING:
+ from .cli.args import Arguments as CliArgs
+
+
+def _get_checksum(structure_path: PathLike) -> str:
+ """Computes a checksum for the structure file."""
+ with open(structure_path, "rb") as f:
+ return hashlib.md5(f.read()).hexdigest()
+
+
+def main(args: "CliArgs"):
+ """Main function to process a structure file and count the molecules."""
+ start_time = time.perf_counter_ns()
+ structure_path = args.structure_file.path
+ coordinates_path = args.coordinates_file.path if args.coordinates_file else None
+
+ logger.info(f"Processing structure file: {structure_path}")
+
+ # Decoding.
+ decoded = decode_structure(
+ structure_path, coordinates_path=coordinates_path, bond_threshold=args.bond_threshold
+ )
+
+ output = GrodecoderRunOutput(
+ decoded=decoded,
+ structure_file_checksum=_get_checksum(structure_path),
+ database_version=get_database_version(),
+ grodecoder_version=get_version(),
+ )
+
+ # Serialization.
+ serialization_mode = "compact" if args.no_atom_ids else "full"
+
+ # Updates run time as late as possible.
+ output_json = output.model_dump(context={"serialization_mode": serialization_mode})
+ output_json["runtime_in_seconds"] = (time.perf_counter_ns() - start_time) / 1e9
+
+ # Output results: to stdout or writes to a file.
+ if args.print_to_stdout:
+ print(json.dumps(output_json, indent=2))
+ else:
+ inventory_filename = args.get_inventory_filename()
+ with open(inventory_filename, "w") as f:
+ f.write(json.dumps(output_json, indent=2))
+ logger.info(f"Results written to {inventory_filename}")
diff --git a/src/grodecoder/models.py b/src/grodecoder/models.py
index 252d557..1803ea1 100644
--- a/src/grodecoder/models.py
+++ b/src/grodecoder/models.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from pydantic import model_validator
from enum import StrEnum
from typing import Protocol
@@ -15,6 +16,8 @@
model_serializer,
)
+from . import toputils
+
class MolecularResolution(StrEnum):
COARSE_GRAINED = "coarse-grained"
@@ -111,11 +114,36 @@ def serialize(self, handler: SerializerFunctionWrapHandler, info: SerializationI
return self_data
-class SmallMolecule(FrozenWithAtoms):
+class MolecularTypeMixin(BaseModel):
+ molecular_type: MolecularType
+
+ def is_ion(self) -> bool:
+ return self.molecular_type == MolecularType.ION
+
+ def is_lipid(self) -> bool:
+ return self.molecular_type == MolecularType.LIPID
+
+ def is_solvent(self) -> bool:
+ return self.molecular_type == MolecularType.SOLVENT
+
+ def is_unknown(self) -> bool:
+ return self.molecular_type == MolecularType.UNKNOWN
+
+ def is_other(self) -> bool:
+ """Alias for MolecularTypeMixin.is_unknown()"""
+ return self.is_unknown()
+
+ def is_protein(self) -> bool:
+ return self.molecular_type == MolecularType.PROTEIN
+
+ def is_nucleic(self) -> bool:
+ return self.molecular_type == MolecularType.NUCLEIC
+
+
+class SmallMolecule(MolecularTypeMixin, FrozenWithAtoms):
"""Small molecules are defined as residues with a single residue name."""
description: str
- molecular_type: MolecularType
@computed_field
@property
@@ -124,11 +152,13 @@ def name(self) -> str:
return self.atoms.residues[0].resname
-class Segment(FrozenWithAtoms):
+class Segment(MolecularTypeMixin, FrozenWithAtoms):
"""A segment is a group of atoms that are connected."""
- sequence: str
- molecular_type: MolecularType # likely to be protein or nucleic acid
+ @computed_field
+ @property
+ def sequence(self) -> str:
+ return toputils.sequence(self.atoms)
class Inventory(FrozenModel):
@@ -175,16 +205,22 @@ class BaseModelWithAtomsRead(FrozenModel):
number_of_atoms: int
number_of_residues: int
+ @model_validator(mode="after")
+ def check_number_of_atoms_is_valid(self):
+ if len(self.atoms) != self.number_of_atoms:
+ raise ValueError(
+ f"field `number_of_atoms` ({self.number_of_atoms}) does not match number of atom ids ({len(self.atoms)})"
+ )
+ return self
+
-class SmallMoleculeRead(BaseModelWithAtomsRead):
+class SmallMoleculeRead(MolecularTypeMixin, BaseModelWithAtomsRead):
name: str
description: str
- molecular_type: MolecularType
-class SegmentRead(BaseModelWithAtomsRead):
+class SegmentRead(MolecularTypeMixin, BaseModelWithAtomsRead):
sequence: str
- molecular_type: MolecularType
class InventoryRead(FrozenModel):
@@ -192,6 +228,15 @@ class InventoryRead(FrozenModel):
segments: list[SegmentRead]
total_number_of_atoms: int
+ @model_validator(mode="after")
+ def check_total_number_of_atoms(self):
+ n = sum(item.number_of_atoms for item in self.small_molecules + self.segments)
+ if self.total_number_of_atoms != n:
+ raise ValueError(
+ f"field `total_number_of_atoms` ({self.total_number_of_atoms}) does not add up with the rest of the inventory (found {n} atoms)"
+ )
+ return self
+
class DecodedRead(FrozenModel):
inventory: InventoryRead
diff --git a/src/grodecoder/toputils.py b/src/grodecoder/toputils.py
index abea32d..878c469 100644
--- a/src/grodecoder/toputils.py
+++ b/src/grodecoder/toputils.py
@@ -85,8 +85,7 @@ def detect_chains(universe: UniverseLike, cutoff: float = 5.0) -> list[tuple[int
Example
-------
- >>> from grodecoder import read_structure
- >>> universe = read_structure("3EAM.pdb")
+ >>> universe = MDAnalysis.Universe("path/to/structure_file.pdb")
>>> protein = universe.select_atoms("protein")
>>> chains = detect_chains(protein)
>>> for start, end in chains:
diff --git a/src/grodecoder/version.py b/src/grodecoder/version.py
new file mode 100644
index 0000000..6c0c00c
--- /dev/null
+++ b/src/grodecoder/version.py
@@ -0,0 +1,4 @@
+__version__ = "0.0.1"
+
+def get_version() -> str:
+ return __version__
diff --git a/tests/test_identifier_README.md b/tests/test_identifier_README.md
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..6b5eba9
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,53 @@
+"""Tests for grodecoder.models."""
+
+import json
+from pathlib import Path
+
+import pytest
+
+from grodecoder.models import BaseModelWithAtomsRead, InventoryRead
+
+TEST_DATA_ROOT_DIR = Path(__file__).parent / "data" / "regression_data"
+TEST_DATA_EXPECTED_RESULTS_DIR = TEST_DATA_ROOT_DIR / "expected_results"
+EXPECTED_INVENTORY_FILES = list(TEST_DATA_EXPECTED_RESULTS_DIR.glob("*.json"))
+
+
+class TestBaseModelWithAtomsRead:
+ def test_number_of_atoms_validation_success(self):
+ try:
+ BaseModelWithAtomsRead(
+ atoms=[1, 2, 3],
+ number_of_atoms=3,
+ number_of_residues=1,
+ )
+ except Exception as exc:
+ assert False, f"exception was raised: {exc}"
+
+ def test_number_of_atoms_validation_fails(self):
+ with pytest.raises(ValueError):
+ BaseModelWithAtomsRead(
+ atoms=[1, 2, 3],
+ number_of_atoms=2, # wrong number of atoms
+ number_of_residues=1,
+ )
+
+
+def _read_json(path: str) -> dict:
+ with open(path, "rb") as f:
+ return json.load(f)
+
+
+class TestInventoryRead:
+ def test_total_number_of_atoms_success(self):
+ raw_json = _read_json(EXPECTED_INVENTORY_FILES[0])
+ try:
+ InventoryRead.model_validate(raw_json["inventory"])
+ except Exception as exc:
+ assert False, f"exception was raised: {exc}"
+
+ def test_total_number_of_atoms_fail(self):
+ raw_json = _read_json(EXPECTED_INVENTORY_FILES[0])
+ assert "total_number_of_atoms" in raw_json.get("inventory", {})
+ raw_json["inventory"]["total_number_of_atoms"] += 1000 # sets this field to wrong value
+ with pytest.raises(ValueError, match=r"field `total_number_of_atoms` \([0-9]+\) does not add up"):
+ InventoryRead.model_validate(raw_json["inventory"])