Skip to content

Commit 7d9ef1a

Browse files
Add pydantic =1 support (#114)
1 parent fe5d679 commit 7d9ef1a

File tree

9 files changed

+25
-138
lines changed

9 files changed

+25
-138
lines changed

.git_archival.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node: $Format:%H$
2+
node-date: $Format:%cI$
3+
describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$

.gitattributes

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
smee/_version.py export-subst
1+
.git_archival.txt export-subst

.github/workflows/ci.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@ jobs:
2020
- name: Setup Conda Environment
2121
run: |
2222
apt update && apt install -y git make
23-
23+
2424
make env
2525
make lint
2626
make test
2727
make test-examples
2828
make docs-build
2929
30+
# TODO: Remove this line once pydantic 1.0 support is dropped
31+
mamba install --name smee --yes "pydantic <2"
32+
make test
33+
3034
- name: CodeCov
3135
uses: codecov/codecov-action@v4.1.1
3236
with:

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ format:
2323
$(CONDA_ENV_RUN) ruff check --fix --select I examples
2424

2525
test:
26-
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/
26+
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-append --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/
2727

2828
test-examples:
2929
$(CONDA_ENV_RUN) jupyter nbconvert --to notebook --execute $(EXAMPLES)

devtools/envs/base.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ dependencies:
1414
- openff-interchange-base >=0.3.17
1515

1616
- pytorch
17-
- pydantic
1817
- nnpops
1918

19+
- pydantic
20+
- pydantic-units
21+
2022
- networkx
2123

2224
# Optional packages

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ select = ["B","C","E","F","W","B9"]
2828
convention = "google"
2929

3030
[tool.coverage.run]
31-
omit = ["**/tests/*", "**/_version.py"]
31+
omit = ["**/tests/*"]
3232

3333
[tool.coverage.report]
3434
exclude_lines = [

smee/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
"""
2-
smee
3-
4-
Differentiably evaluate energies of molecules using SMIRNOFF force fields
5-
"""
1+
"""Differentiably evaluate energies of molecules using SMIRNOFF force fields"""
62

73
import importlib.metadata
84

smee/mm/_config.py

Lines changed: 10 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,25 @@
11
"""Configuration from MM simulations."""
22

3-
import functools
4-
import typing
5-
6-
import openff.units
73
import openmm.unit
84
import pydantic
9-
import pydantic_core
5+
from pydantic_units import OpenMMQuantity, quantity_serializer
106

117
_KCAL_PER_MOL = openmm.unit.kilocalories_per_mole
128
_ANGSTROM = openmm.unit.angstrom
139
_GRAMS_PER_ML = openmm.unit.grams / openmm.unit.milliliters
1410

1511

16-
def _quantity_validator(
17-
value: str | openmm.unit.Quantity | openff.units.unit.Quantity,
18-
expected_units: openmm.unit.Unit,
19-
) -> openmm.unit.Quantity:
20-
if isinstance(value, str):
21-
value = openff.units.Quantity(value)
22-
if isinstance(value, openff.units.Quantity):
23-
value = openff.units.openmm.to_openmm(value)
24-
25-
assert isinstance(value, openmm.unit.Quantity), f"invalid type - {type(value)}"
26-
27-
try:
28-
return value.in_units_of(expected_units)
29-
except TypeError as e:
30-
raise ValueError(
31-
f"invalid units {value.unit} - expected {expected_units}"
32-
) from e
33-
34-
35-
def _quantity_serializer(value: openmm.unit.Quantity) -> str:
36-
unit_str = openff.units.openmm.openmm_unit_to_string(value.unit)
37-
return f"{value.value_in_unit(value.unit):.8f} {unit_str}"
38-
39-
40-
class _OpenMMQuantityAnnotation:
41-
@classmethod
42-
def __get_pydantic_core_schema__(
43-
cls,
44-
_source_type: typing.Any,
45-
_handler: pydantic.GetCoreSchemaHandler,
46-
) -> pydantic_core.core_schema.CoreSchema:
47-
from_value_schema = pydantic_core.core_schema.no_info_plain_validator_function(
48-
lambda x: x
49-
)
50-
51-
return pydantic_core.core_schema.json_or_python_schema(
52-
json_schema=from_value_schema,
53-
python_schema=from_value_schema,
54-
serialization=pydantic_core.core_schema.plain_serializer_function_ser_schema(
55-
_quantity_serializer
56-
),
57-
)
58-
59-
@classmethod
60-
def __get_pydantic_json_schema__(
61-
cls,
62-
_core_schema: pydantic_core.core_schema.CoreSchema,
63-
handler: pydantic.GetJsonSchemaHandler,
64-
) -> "pydantic.json_schema.JsonSchemaValue":
65-
return handler(pydantic_core.core_schema.str_schema())
66-
67-
68-
class _OpenMMQuantityMeta(type):
69-
def __getitem__(cls, item: openmm.unit.Unit):
70-
validator = functools.partial(_quantity_validator, expected_units=item)
71-
return typing.Annotated[
72-
openmm.unit.Quantity,
73-
_OpenMMQuantityAnnotation,
74-
pydantic.BeforeValidator(validator),
75-
]
76-
77-
78-
class OpenMMQuantity(openmm.unit.Quantity, metaclass=_OpenMMQuantityMeta):
79-
"""A pydantic safe OpenMM quantity type validates unit compatibility."""
12+
if pydantic.__version__.startswith("1."):
8013

14+
class BaseModel(pydantic.BaseModel):
15+
class Config:
16+
json_encoders = {openmm.unit.Quantity: quantity_serializer}
8117

82-
if typing.TYPE_CHECKING:
83-
OpenMMQuantity = openmm.unit.Quantity # noqa: F811
18+
else:
19+
BaseModel = pydantic.BaseModel
8420

8521

86-
class GenerateCoordsConfig(pydantic.BaseModel):
22+
class GenerateCoordsConfig(BaseModel):
8723
"""Configure how coordinates should be generated for a system using PACKMOL."""
8824

8925
target_density: OpenMMQuantity[_GRAMS_PER_ML] = pydantic.Field(
@@ -113,7 +49,7 @@ class GenerateCoordsConfig(pydantic.BaseModel):
11349
)
11450

11551

116-
class MinimizationConfig(pydantic.BaseModel):
52+
class MinimizationConfig(BaseModel):
11753
"""Configure how a system should be energy minimized."""
11854

11955
tolerance: OpenMMQuantity[_KCAL_PER_MOL / _ANGSTROM] = pydantic.Field(
@@ -128,7 +64,7 @@ class MinimizationConfig(pydantic.BaseModel):
12864
)
12965

13066

131-
class SimulationConfig(pydantic.BaseModel):
67+
class SimulationConfig(BaseModel):
13268
temperature: OpenMMQuantity[openmm.unit.kelvin] = pydantic.Field(
13369
...,
13470
description="The temperature to simulate at.",

smee/tests/mm/test_config.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)