
Commit 3e109e2

3.2.2.post2 (#276)
1 parent c896919 commit 3e109e2

24 files changed: +2199 -55 lines

cookbook/tutorials/2_embed.ipynb

Lines changed: 3 additions & 3 deletions
@@ -49,18 +49,18 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
+    "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "from getpass import getpass\n",
     "\n",
-    "token = getpass(\"Token from Forge console: \")"
+    "token = getpass(\"Token from Forge: \")"
    ]
   },
   {

cookbook/tutorials/3_gfp_design.ipynb

Lines changed: 3 additions & 3 deletions
@@ -80,18 +80,18 @@
    "\n",
    "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
    "\n",
-    "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
+    "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {
     "id": "zNrU9Q2SYonX"
    },
    "outputs": [],
    "source": [
-    "token = getpass(\"Token from Forge console: \")"
+    "token = getpass(\"Token from Forge: \")"
    ]
   },
   {

cookbook/tutorials/4_forge_generate.ipynb

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
+    "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
    ]
   },
   {
@@ -64,7 +64,7 @@
    "source": [
     "from getpass import getpass\n",
     "\n",
-    "token = getpass(\"Token from Forge console: \")\n",
+    "token = getpass(\"Token from Forge: \")\n",
     "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
    ]
   },

cookbook/tutorials/5_guided_generation.ipynb

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@
    "\n",
    "from esm.sdk import client\n",
    "\n",
-    "token = getpass(\"Token from Forge console: \")\n",
+    "token = getpass(\"Token from Forge: \")\n",
    "model = client(\n",
    "    model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
    ")"

esm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "3.2.2"
+__version__ = "3.2.2.post2"

esm/sdk/api.py

Lines changed: 24 additions & 8 deletions
@@ -148,20 +148,43 @@ def to_protein_complex(
             gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
         else:
             gt_chains = None
+
+        # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
+        # This handles the case where the server doesn't include chain breaks in pLDDT
+        # We should fix this in the server side.
+        if self.plddt is not None and len(self.plddt) != len(self.sequence):
+            # Only expand if there's a mismatch (likely due to chain breaks)
+            if "|" in self.sequence:
+                # Create expanded pLDDT with NaN at chain break positions
+                expanded_plddt = torch.full((len(self.sequence),), float("nan"))
+                plddt_idx = 0
+                for i, aa in enumerate(self.sequence):
+                    if aa != "|":
+                        if plddt_idx < len(self.plddt):
+                            expanded_plddt[i] = self.plddt[plddt_idx]
+                        plddt_idx += 1
+                plddt = expanded_plddt
+            else:
+                # Mismatch but no chain breaks - shouldn't happen but preserve original
+                plddt = self.plddt
+        else:
+            plddt = self.plddt
+
         pred_chains = []
         for i, (start, end) in enumerate(chain_boundaries):
             if i >= len(SINGLE_LETTER_CHAIN_IDS):
                 raise ValueError(
                     f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
                 )
+
             pred_chain = ProteinChain.from_atom37(
                 atom37_positions=coords[start:end],
                 sequence=self.sequence[start:end],
                 chain_id=gt_chains[i].chain_id
                 if gt_chains is not None
                 else SINGLE_LETTER_CHAIN_IDS[i],
                 entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
-                confidence=self.plddt[start:end] if self.plddt is not None else None,
+                confidence=plddt[start:end] if plddt is not None else None,
             )
             pred_chains.append(pred_chain)
         return ProteinComplex.from_chains(pred_chains)
@@ -298,13 +321,6 @@ def use_generative_unmasking_strategy(self):
         self.temperature_annealing = True


-@define
-class MSA:
-    # Paired MSA sequences.
-    # One would typically compute these using, for example, ColabFold.
-    sequences: list[str]
-
-
 @define
 class InverseFoldingConfig:
     invalid_ids: Sequence[int] = []
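
The new pLDDT-expansion branch above can be illustrated on toy data. This is a standalone restatement of the loop in the diff, not a call into the esm package:

import torch

sequence = "AAA|BB"  # "|" marks a chain break between two chains
plddt = torch.tensor([0.9, 0.8, 0.7, 0.95, 0.85])  # per-residue values, no entry for "|"

expanded = torch.full((len(sequence),), float("nan"))
idx = 0
for i, aa in enumerate(sequence):
    if aa != "|":
        expanded[i] = plddt[idx]
        idx += 1

print(expanded)  # tensor([0.9000, 0.8000, 0.7000,    nan, 0.9500, 0.8500])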

esm/sdk/forge.py

Lines changed: 88 additions & 20 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import base64
 import pickle
@@ -7,7 +9,6 @@
 import torch

 from esm.sdk.api import (
-    MSA,
     ESM3InferenceClient,
     ESMCInferenceClient,
     ESMProtein,
@@ -27,6 +28,15 @@
 from esm.sdk.retry import retry_decorator
 from esm.utils.constants.api import MIMETYPE_ES_PICKLE
 from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
+from esm.utils.msa import MSA
+from esm.utils.structure.input_builder import (
+    StructurePredictionInput,
+    serialize_structure_prediction_input,
+)
+from esm.utils.structure.molecular_complex import (
+    MolecularComplex,
+    MolecularComplexResult,
+)
 from esm.utils.types import FunctionAnnotation


@@ -36,10 +46,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
     return [FunctionAnnotation(*t) for t in l]


-def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
-    ret = data.get("logits", {}).get(track, None)
-    # TODO(s22chan): just return this when removing return_bytes
-    return ret if ret is None or not return_bytes else maybe_tensor(ret)
+def _maybe_logits(data: dict[str, Any], track: str):
+    return maybe_tensor(data.get("logits", {}).get(track, None))


 def _maybe_b64_decode(obj, return_bytes: bool):
@@ -137,7 +145,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA:
         data = await self._async_post(
             "msa", request={}, params={"sequence": sequence, "use_env": False}
         )
-        return MSA(sequences=data["msa"])
+        return MSA.from_sequences(sequences=data["msa"])

     def _fetch_msa(self, sequence: str) -> MSA:
         print("Fetching MSA ... this may take a few minutes")
@@ -146,7 +154,7 @@ def _fetch_msa(self, sequence: str) -> MSA:
         data = self._post(
             "msa", request={}, params={"sequence": sequence, "use_env": False}
         )
-        return MSA(sequences=data["msa"])
+        return MSA.from_sequences(sequences=data["msa"])

     @retry_decorator
     async def async_fold(
@@ -209,6 +217,70 @@ def fold(

         return self._process_fold_response(data, sequence)

+    @retry_decorator
+    async def async_fold_all_atom(
+        self, all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
+        """Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
+
+        Args:
+            all_atom_input: StructurePredictionInput containing sequences for different molecule types
+            model_name: Override the client level model name if needed
+        """
+        request = self._process_fold_all_atom_request(
+            all_atom_input, model_name if model_name is not None else self.model
+        )
+
+        try:
+            data = await self._async_post("fold_all_atom", request)
+        except ESMProteinError as e:
+            return e
+
+        return self._process_fold_all_atom_response(data)
+
+    @retry_decorator
+    def fold_all_atom(
+        self, all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
+        """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
+
+        Args:
+            all_atom_input: StructurePredictionInput containing sequences for different molecule types
+            model_name: Override the client level model name if needed
+        """
+        request = self._process_fold_all_atom_request(
+            all_atom_input, model_name if model_name is not None else self.model
+        )
+
+        try:
+            data = self._post("fold_all_atom", request)
+        except ESMProteinError as e:
+            return e
+
+        return self._process_fold_all_atom_response(data)
+
+    @staticmethod
+    def _process_fold_all_atom_request(
+        all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> dict[str, Any]:
+        request: dict[str, Any] = {
+            "all_atom_input": serialize_structure_prediction_input(all_atom_input),
+            "model": model_name,
+        }
+
+        return request
+
+    @staticmethod
+    def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
+        complex_data = data.get("complex")
+        molecular_complex = MolecularComplex.from_state_dict(complex_data)
+        return MolecularComplexResult(
+            complex=molecular_complex,
+            plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
+            ptm=data.get("ptm", None),
+            distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
+        )
+
     @retry_decorator
     async def async_inverse_fold(
         self,
@@ -602,19 +674,15 @@ def _process_logits_response(

         return LogitsOutput(
             logits=ForwardTrackData(
-                sequence=_maybe_logits(data, "sequence", return_bytes),
-                structure=_maybe_logits(data, "structure", return_bytes),
-                secondary_structure=_maybe_logits(
-                    data, "secondary_structure", return_bytes
-                ),
-                sasa=_maybe_logits(data, "sasa", return_bytes),
-                function=_maybe_logits(data, "function", return_bytes),
+                sequence=_maybe_logits(data, "sequence"),
+                structure=_maybe_logits(data, "structure"),
+                secondary_structure=_maybe_logits(data, "secondary_structure"),
+                sasa=_maybe_logits(data, "sasa"),
+                function=_maybe_logits(data, "function"),
             ),
             embeddings=maybe_tensor(data["embeddings"]),
             mean_embedding=data["mean_embedding"],
-            residue_annotation_logits=_maybe_logits(
-                data, "residue_annotation", return_bytes
-            ),
+            residue_annotation_logits=_maybe_logits(data, "residue_annotation"),
             hidden_states=maybe_tensor(data["hidden_states"]),
             mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
         )
@@ -965,6 +1033,7 @@ def _process_logits_request(
             "sequence": config.sequence,
             "return_embeddings": config.return_embeddings,
             "return_mean_embedding": config.return_mean_embedding,
+            "return_mean_hidden_states": config.return_mean_hidden_states,
             "return_hidden_states": config.return_hidden_states,
             "ith_hidden_layer": config.ith_hidden_layer,
         }
@@ -981,12 +1050,11 @@ def _process_logits_response(
         data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)

         output = LogitsOutput(
-            logits=ForwardTrackData(
-                sequence=_maybe_logits(data, "sequence", return_bytes)
-            ),
+            logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
             embeddings=maybe_tensor(data["embeddings"]),
             mean_embedding=data["mean_embedding"],
             hidden_states=maybe_tensor(data["hidden_states"]),
+            mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
         )
         return output
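
The headline addition in this file is the fold_all_atom endpoint. Below is a hedged usage sketch; the concrete Forge client class and the constructor arguments of StructurePredictionInput are not part of this diff, so `forge_client` and the way `all_atom_input` gets built are assumptions rather than documented API:

from esm.sdk.api import ESMProteinError
from esm.utils.structure.input_builder import StructurePredictionInput
from esm.utils.structure.molecular_complex import MolecularComplexResult


def fold_complex(forge_client, all_atom_input: StructurePredictionInput) -> MolecularComplexResult:
    # forge_client: an already-authenticated Forge client exposing the new methods (assumption);
    # all_atom_input: built elsewhere from protein/DNA/RNA/ligand inputs (constructor not shown in this diff).
    result = forge_client.fold_all_atom(all_atom_input)  # or: await forge_client.async_fold_all_atom(...)
    if isinstance(result, ESMProteinError):
        raise result  # the client returns errors instead of raising them
    return result  # carries .complex, .plddt, .ptm, and .distogram per the response handling above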

esm/sdk/retry.py

Lines changed: 14 additions & 5 deletions
@@ -2,10 +2,9 @@
 from contextvars import ContextVar
 from functools import wraps

-import httpx
 from tenacity import (
     retry,
-    retry_if_exception_type,
+    retry_if_exception,
     retry_if_result,
     stop_after_attempt,
     wait_incrementing,
@@ -30,8 +29,12 @@ def retry_if_specific_error(exception):


 def log_retry_attempt(retry_state):
+    try:
+        outcome = retry_state.outcome.result()
+    except Exception:
+        outcome = retry_state.outcome.exception()
     print(
-        f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
+        f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}"
     )


@@ -41,13 +44,18 @@ def retry_decorator(func):
     instance's retry settings.
     """

+    def return_last_value(retry_state):
+        """Return the result of the last call attempt."""
+        return retry_state.outcome.result()
+
     @wraps(func)
     async def async_wrapper(instance, *args, **kwargs):
         if skip_retries_var.get():
             return await func(instance, *args, **kwargs)
         retry_decorator = retry(
+            retry_error_callback=return_last_value,
             retry=retry_if_result(retry_if_specific_error)
-            | retry_if_exception_type(httpx.ConnectTimeout),  # ADDED
+            | retry_if_exception(retry_if_specific_error),
             wait=wait_incrementing(
                 increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
             ),
@@ -62,8 +70,9 @@ def wrapper(instance, *args, **kwargs):
         if skip_retries_var.get():
             return func(instance, *args, **kwargs)
         retry_decorator = retry(
+            retry_error_callback=return_last_value,
             retry=retry_if_result(retry_if_specific_error)
-            | retry_if_exception_type(httpx.ConnectTimeout),  # ADDED
+            | retry_if_exception(retry_if_specific_error),
             wait=wait_incrementing(
                 increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
             ),
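
This change swaps the hard-coded httpx.ConnectTimeout check for the shared retry_if_specific_error predicate and adds retry_error_callback so the last outcome is surfaced instead of tenacity.RetryError once attempts are exhausted. A standalone toy (illustration only, not esm code) showing those two tenacity features together:

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed


class TransientError(Exception):
    pass


def is_retryable(exc: BaseException) -> bool:
    return isinstance(exc, TransientError)


def return_last_value(retry_state):
    return retry_state.outcome.result()  # re-raises if the final attempt raised


@retry(
    retry=retry_if_exception(is_retryable),
    stop=stop_after_attempt(3),
    wait=wait_fixed(0),
    retry_error_callback=return_last_value,  # called once attempts are exhausted
)
def flaky():
    raise TransientError("still failing")


try:
    flaky()
except TransientError as e:
    print(f"gave up after 3 attempts: {e}")  # original error, not tenacity.RetryError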

esm/utils/generation.py

Lines changed: 3 additions & 1 deletion
@@ -43,7 +43,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):

     sliced = {}
     for k, v in attr.asdict(o, recurse=False).items():
-        if v is None:
+        if k in ["mean_hidden_state", "mean_embedding"]:
+            sliced[k] = v
+        elif v is None:
             sliced[k] = None
         elif isinstance(v, torch.Tensor):
             # Trim padding.
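
The generation.py change exempts mean_hidden_state and mean_embedding from sequence-length trimming, presumably because they carry no padded sequence dimension to trim. A toy restatement of that distinction (illustration only; the real tensor shapes depend on the model):

import torch

sequence_len = 5
per_residue = torch.randn(8, 960)   # padded along the sequence axis
mean_embedding = torch.randn(960)   # one vector per protein, no sequence axis

trimmed = per_residue[:sequence_len]   # shape (5, 960): padding removed
# mean_embedding[:sequence_len] would chop the embedding itself to shape (5,), hence the exemption.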
