Ishaan/infer oxygen (#283)

This commit is contained in:
Ishaan Mathur
2025-10-22 12:19:25 -04:00
committed by GitHub
parent 23b084b94f
commit 232a1041d8
10 changed files with 134 additions and 148 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
__version__ = "3.2.3"
__version__ = "3.2.4.a0"

View File

@@ -77,6 +77,7 @@ class ESMProtein(ProteinType):
sasa=protein_chain.sasa().tolist(),
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)
else:
return ESMProtein(
@@ -85,6 +86,7 @@ class ESMProtein(ProteinType):
sasa=None,
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)
@classmethod
@@ -104,6 +106,7 @@ class ESMProtein(ProteinType):
coordinates=torch.tensor(
protein_complex.atom37_positions, dtype=torch.float32
),
plddt=torch.tensor(protein_complex.confidence),
)
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
@@ -325,7 +328,9 @@ class GenerationConfig:
@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
temperature: float = 1.0
temperature: float = 0.1
seed: int | None = None
decode_in_residue_index_order: bool = False
## Low Level Endpoint Types

View File

@@ -119,6 +119,8 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
inverse_folding_config = {
"invalid_ids": config.invalid_ids,
"temperature": config.temperature,
"seed": config.seed,
"decode_in_residue_index_order": config.decode_in_residue_index_order,
}
request = {
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),

View File

@@ -707,8 +707,9 @@ class MolecularComplex:
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
atom_array.res_name = np.array(atom_res_names, dtype="U4")
atom_array.hetero = atom_hetero
atom_array.b_factor = atom_bfactors
atom_array.atom_name = np.array(atom_names, dtype="U4")
atom_array.add_annotation("b_factor", dtype=float)
atom_array.b_factor = atom_bfactors
# Use existing elements or infer them from atom names
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:

View File

@@ -1121,7 +1121,9 @@ class ProteinChain:
def infer_oxygen(self) -> ProteinChain:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

View File

@@ -562,7 +562,9 @@ class ProteinComplex:
def infer_oxygen(self) -> ProteinComplex:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

103
pixi.lock
View File

@@ -211,7 +211,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/36/63/0722e153fd27d64d5b0af45b5c8cb0e80b35a68cf0130303bc9a8bb095c7/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl
@@ -397,7 +397,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/08/b8/2bc2590a34c733ea0570f366e6ad7d889d05c7825bd3ccab01f36ece71c6/zstd-1.5.7.2-cp312-cp312-macosx_11_0_arm64.whl
@@ -698,7 +698,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/36/63/0722e153fd27d64d5b0af45b5c8cb0e80b35a68cf0130303bc9a8bb095c7/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl
@@ -911,7 +911,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/08/b8/2bc2590a34c733ea0570f366e6ad7d889d05c7825bd3ccab01f36ece71c6/zstd-1.5.7.2-cp312-cp312-macosx_11_0_arm64.whl
@@ -1726,13 +1726,13 @@ packages:
requires_python: '>=3.8'
- pypi: ./
name: esm
version: 3.2.3
sha256: 7f3df1026fb23f4812615d3c4968f643f04d9cbf7735000615b011620ac83007
version: 3.2.4a0
sha256: ef9ea6c382db370d0914aa4e9893c60d57a4a30c6a41307d7bb38f791ff8ecbd
requires_dist:
- torch>=2.2.0
- torchvision
- torchtext
- transformers<4.48.2
- transformers==4.52.4
- ipython
- einops
- biotite>=1.0.0
@@ -6306,35 +6306,28 @@ packages:
- pytest-mypy-testing ; extra == 'test'
- pytest>=7.0,<8.2 ; extra == 'test'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/7b/9f/92d3091c44cb19add044064af1bf1345cd35fbb84d32a3690f912800a295/transformers-4.48.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl
name: transformers
version: 4.48.1
sha256: 24be0564b0a36d9e433d9a65de248f1545b6f6edce1737669605eb6a8141bbbb
version: 4.52.4
sha256: 203f5c19416d5877e36e88633943761719538a25d9775977a24fe77a1e5adfc7
requires_dist:
- filelock
- huggingface-hub>=0.24.0,<1.0
- huggingface-hub>=0.30.0,<1.0
- numpy>=1.17
- packaging>=20.0
- pyyaml>=5.1
- regex!=2019.12.17
- requests
- tokenizers>=0.21,<0.22
- safetensors>=0.4.1
- safetensors>=0.4.3
- tqdm>=4.27
- accelerate>=0.26.0 ; extra == 'accelerate'
- diffusers ; extra == 'agents'
- accelerate>=0.26.0 ; extra == 'agents'
- datasets!=2.5.0 ; extra == 'agents'
- torch>=2.0 ; extra == 'agents'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'agents'
- opencv-python ; extra == 'agents'
- pillow>=10.0.1,<=15.0 ; extra == 'agents'
- tensorflow>2.9,<2.16 ; extra == 'all'
- onnxconverter-common ; extra == 'all'
- tf2onnx ; extra == 'all'
- tensorflow-text<2.16 ; extra == 'all'
- keras-nlp>=0.3.1,<0.14.0 ; extra == 'all'
- torch>=2.0 ; extra == 'all'
- torch>=2.1,<2.7 ; extra == 'all'
- accelerate>=0.26.0 ; extra == 'all'
- jax>=0.4.1,<=0.4.13 ; extra == 'all'
- jaxlib>=0.4.1,<=0.4.13 ; extra == 'all'
@@ -6350,13 +6343,15 @@ packages:
- phonemizer ; extra == 'all'
- kenlm ; extra == 'all'
- pillow>=10.0.1,<=15.0 ; extra == 'all'
- kernels>=0.4.4,<0.5 ; extra == 'all'
- optuna ; extra == 'all'
- ray[tune]>=2.7.0 ; extra == 'all'
- sigopt ; extra == 'all'
- timm<=1.0.11 ; extra == 'all'
- torchvision ; extra == 'all'
- codecarbon>=2.8.1 ; extra == 'all'
- av==9.2.0 ; extra == 'all'
- av ; extra == 'all'
- num2words ; extra == 'all'
- librosa ; extra == 'audio'
- pyctcdecode>=0.4.0 ; extra == 'audio'
- phonemizer ; extra == 'audio'
@@ -6367,10 +6362,12 @@ packages:
- accelerate>=0.26.0 ; extra == 'deepspeed'
- deepspeed>=0.9.3 ; extra == 'deepspeed-testing'
- accelerate>=0.26.0 ; extra == 'deepspeed-testing'
- pytest>=7.2.0,<8.0.0 ; extra == 'deepspeed-testing'
- pytest>=7.2.0 ; extra == 'deepspeed-testing'
- pytest-asyncio ; extra == 'deepspeed-testing'
- pytest-rich ; extra == 'deepspeed-testing'
- pytest-xdist ; extra == 'deepspeed-testing'
- pytest-order ; extra == 'deepspeed-testing'
- pytest-rerunfailures ; extra == 'deepspeed-testing'
- timeout-decorator ; extra == 'deepspeed-testing'
- parameterized ; extra == 'deepspeed-testing'
- psutil ; extra == 'deepspeed-testing'
@@ -6378,8 +6375,7 @@ packages:
- dill<0.3.5 ; extra == 'deepspeed-testing'
- evaluate>=0.2.0 ; extra == 'deepspeed-testing'
- pytest-timeout ; extra == 'deepspeed-testing'
- ruff==0.5.1 ; extra == 'deepspeed-testing'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'deepspeed-testing'
- ruff==0.11.2 ; extra == 'deepspeed-testing'
- rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1 ; extra == 'deepspeed-testing'
- nltk<=3.8.1 ; extra == 'deepspeed-testing'
- gitpython<3.1.19 ; extra == 'deepspeed-testing'
@@ -6389,6 +6385,7 @@ packages:
- tensorboard ; extra == 'deepspeed-testing'
- pydantic ; extra == 'deepspeed-testing'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'deepspeed-testing'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'deepspeed-testing'
- faiss-cpu ; extra == 'deepspeed-testing'
- cookiecutter==1.7.3 ; extra == 'deepspeed-testing'
- optuna ; extra == 'deepspeed-testing'
@@ -6398,7 +6395,7 @@ packages:
- tf2onnx ; extra == 'dev'
- tensorflow-text<2.16 ; extra == 'dev'
- keras-nlp>=0.3.1,<0.14.0 ; extra == 'dev'
- torch>=2.0 ; extra == 'dev'
- torch>=2.1,<2.7 ; extra == 'dev'
- accelerate>=0.26.0 ; extra == 'dev'
- jax>=0.4.1,<=0.4.13 ; extra == 'dev'
- jaxlib>=0.4.1,<=0.4.13 ; extra == 'dev'
@@ -6414,17 +6411,21 @@ packages:
- phonemizer ; extra == 'dev'
- kenlm ; extra == 'dev'
- pillow>=10.0.1,<=15.0 ; extra == 'dev'
- kernels>=0.4.4,<0.5 ; extra == 'dev'
- optuna ; extra == 'dev'
- ray[tune]>=2.7.0 ; extra == 'dev'
- sigopt ; extra == 'dev'
- timm<=1.0.11 ; extra == 'dev'
- torchvision ; extra == 'dev'
- codecarbon>=2.8.1 ; extra == 'dev'
- av==9.2.0 ; extra == 'dev'
- pytest>=7.2.0,<8.0.0 ; extra == 'dev'
- av ; extra == 'dev'
- num2words ; extra == 'dev'
- pytest>=7.2.0 ; extra == 'dev'
- pytest-asyncio ; extra == 'dev'
- pytest-rich ; extra == 'dev'
- pytest-xdist ; extra == 'dev'
- pytest-order ; extra == 'dev'
- pytest-rerunfailures ; extra == 'dev'
- timeout-decorator ; extra == 'dev'
- parameterized ; extra == 'dev'
- psutil ; extra == 'dev'
@@ -6432,8 +6433,7 @@ packages:
- dill<0.3.5 ; extra == 'dev'
- evaluate>=0.2.0 ; extra == 'dev'
- pytest-timeout ; extra == 'dev'
- ruff==0.5.1 ; extra == 'dev'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev'
- ruff==0.11.2 ; extra == 'dev'
- rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1 ; extra == 'dev'
- nltk<=3.8.1 ; extra == 'dev'
- gitpython<3.1.19 ; extra == 'dev'
@@ -6442,6 +6442,7 @@ packages:
- beautifulsoup4 ; extra == 'dev'
- tensorboard ; extra == 'dev'
- pydantic ; extra == 'dev'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev'
- faiss-cpu ; extra == 'dev'
- cookiecutter==1.7.3 ; extra == 'dev'
- isort>=5.5.4 ; extra == 'dev'
@@ -6456,10 +6457,12 @@ packages:
- sudachidict-core>=20220729 ; extra == 'dev'
- rhoknp>=1.1.0,<1.3.1 ; extra == 'dev'
- scikit-learn ; extra == 'dev'
- pytest>=7.2.0,<8.0.0 ; extra == 'dev-tensorflow'
- pytest>=7.2.0 ; extra == 'dev-tensorflow'
- pytest-asyncio ; extra == 'dev-tensorflow'
- pytest-rich ; extra == 'dev-tensorflow'
- pytest-xdist ; extra == 'dev-tensorflow'
- pytest-order ; extra == 'dev-tensorflow'
- pytest-rerunfailures ; extra == 'dev-tensorflow'
- timeout-decorator ; extra == 'dev-tensorflow'
- parameterized ; extra == 'dev-tensorflow'
- psutil ; extra == 'dev-tensorflow'
@@ -6467,8 +6470,7 @@ packages:
- dill<0.3.5 ; extra == 'dev-tensorflow'
- evaluate>=0.2.0 ; extra == 'dev-tensorflow'
- pytest-timeout ; extra == 'dev-tensorflow'
- ruff==0.5.1 ; extra == 'dev-tensorflow'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev-tensorflow'
- ruff==0.11.2 ; extra == 'dev-tensorflow'
- rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1 ; extra == 'dev-tensorflow'
- nltk<=3.8.1 ; extra == 'dev-tensorflow'
- gitpython<3.1.19 ; extra == 'dev-tensorflow'
@@ -6478,6 +6480,7 @@ packages:
- tensorboard ; extra == 'dev-tensorflow'
- pydantic ; extra == 'dev-tensorflow'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'dev-tensorflow'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev-tensorflow'
- faiss-cpu ; extra == 'dev-tensorflow'
- cookiecutter==1.7.3 ; extra == 'dev-tensorflow'
- tensorflow>2.9,<2.16 ; extra == 'dev-tensorflow'
@@ -6499,10 +6502,12 @@ packages:
- pyctcdecode>=0.4.0 ; extra == 'dev-tensorflow'
- phonemizer ; extra == 'dev-tensorflow'
- kenlm ; extra == 'dev-tensorflow'
- pytest>=7.2.0,<8.0.0 ; extra == 'dev-torch'
- pytest>=7.2.0 ; extra == 'dev-torch'
- pytest-asyncio ; extra == 'dev-torch'
- pytest-rich ; extra == 'dev-torch'
- pytest-xdist ; extra == 'dev-torch'
- pytest-order ; extra == 'dev-torch'
- pytest-rerunfailures ; extra == 'dev-torch'
- timeout-decorator ; extra == 'dev-torch'
- parameterized ; extra == 'dev-torch'
- psutil ; extra == 'dev-torch'
@@ -6510,8 +6515,7 @@ packages:
- dill<0.3.5 ; extra == 'dev-torch'
- evaluate>=0.2.0 ; extra == 'dev-torch'
- pytest-timeout ; extra == 'dev-torch'
- ruff==0.5.1 ; extra == 'dev-torch'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev-torch'
- ruff==0.11.2 ; extra == 'dev-torch'
- rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1 ; extra == 'dev-torch'
- nltk<=3.8.1 ; extra == 'dev-torch'
- gitpython<3.1.19 ; extra == 'dev-torch'
@@ -6521,9 +6525,10 @@ packages:
- tensorboard ; extra == 'dev-torch'
- pydantic ; extra == 'dev-torch'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'dev-torch'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'dev-torch'
- faiss-cpu ; extra == 'dev-torch'
- cookiecutter==1.7.3 ; extra == 'dev-torch'
- torch>=2.0 ; extra == 'dev-torch'
- torch>=2.1,<2.7 ; extra == 'dev-torch'
- accelerate>=0.26.0 ; extra == 'dev-torch'
- protobuf ; extra == 'dev-torch'
- tokenizers>=0.21,<0.22 ; extra == 'dev-torch'
@@ -6533,6 +6538,7 @@ packages:
- phonemizer ; extra == 'dev-torch'
- kenlm ; extra == 'dev-torch'
- pillow>=10.0.1,<=15.0 ; extra == 'dev-torch'
- kernels>=0.4.4,<0.5 ; extra == 'dev-torch'
- optuna ; extra == 'dev-torch'
- ray[tune]>=2.7.0 ; extra == 'dev-torch'
- sigopt ; extra == 'dev-torch'
@@ -6553,6 +6559,7 @@ packages:
- scikit-learn ; extra == 'dev-torch'
- onnxruntime>=1.4.0 ; extra == 'dev-torch'
- onnxruntime-tools>=1.4.2 ; extra == 'dev-torch'
- num2words ; extra == 'dev-torch'
- jax>=0.4.1,<=0.4.13 ; extra == 'flax'
- jaxlib>=0.4.1,<=0.4.13 ; extra == 'flax'
- flax>=0.4.1,<=0.7.0 ; extra == 'flax'
@@ -6563,6 +6570,9 @@ packages:
- phonemizer ; extra == 'flax-speech'
- kenlm ; extra == 'flax-speech'
- ftfy ; extra == 'ftfy'
- hf-xet ; extra == 'hf-xet'
- kernels>=0.4.4,<0.5 ; extra == 'hub-kernels'
- kernels>=0.4.4,<0.5 ; extra == 'integrations'
- optuna ; extra == 'integrations'
- ray[tune]>=2.7.0 ; extra == 'integrations'
- sigopt ; extra == 'integrations'
@@ -6575,6 +6585,7 @@ packages:
- rhoknp>=1.1.0,<1.3.1 ; extra == 'ja'
- cookiecutter==1.7.3 ; extra == 'modelcreation'
- natten>=0.14.6,<0.15.0 ; extra == 'natten'
- num2words ; extra == 'num2words'
- onnxconverter-common ; extra == 'onnx'
- tf2onnx ; extra == 'onnx'
- onnxruntime>=1.4.0 ; extra == 'onnx'
@@ -6584,7 +6595,7 @@ packages:
- optuna ; extra == 'optuna'
- datasets!=2.5.0 ; extra == 'quality'
- isort>=5.5.4 ; extra == 'quality'
- ruff==0.5.1 ; extra == 'quality'
- ruff==0.11.2 ; extra == 'quality'
- gitpython<3.1.19 ; extra == 'quality'
- urllib3<2.0.0 ; extra == 'quality'
- libcst ; extra == 'quality'
@@ -6592,7 +6603,7 @@ packages:
- ray[tune]>=2.7.0 ; extra == 'ray'
- faiss-cpu ; extra == 'retrieval'
- datasets!=2.5.0 ; extra == 'retrieval'
- ruff==0.5.1 ; extra == 'ruff'
- ruff==0.11.2 ; extra == 'ruff'
- sagemaker>=2.31.0 ; extra == 'sagemaker'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'sentencepiece'
- protobuf ; extra == 'sentencepiece'
@@ -6607,10 +6618,12 @@ packages:
- pyctcdecode>=0.4.0 ; extra == 'speech'
- phonemizer ; extra == 'speech'
- kenlm ; extra == 'speech'
- pytest>=7.2.0,<8.0.0 ; extra == 'testing'
- pytest>=7.2.0 ; extra == 'testing'
- pytest-asyncio ; extra == 'testing'
- pytest-rich ; extra == 'testing'
- pytest-xdist ; extra == 'testing'
- pytest-order ; extra == 'testing'
- pytest-rerunfailures ; extra == 'testing'
- timeout-decorator ; extra == 'testing'
- parameterized ; extra == 'testing'
- psutil ; extra == 'testing'
@@ -6618,8 +6631,7 @@ packages:
- dill<0.3.5 ; extra == 'testing'
- evaluate>=0.2.0 ; extra == 'testing'
- pytest-timeout ; extra == 'testing'
- ruff==0.5.1 ; extra == 'testing'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'testing'
- ruff==0.11.2 ; extra == 'testing'
- rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1 ; extra == 'testing'
- nltk<=3.8.1 ; extra == 'testing'
- gitpython<3.1.19 ; extra == 'testing'
@@ -6629,6 +6641,7 @@ packages:
- tensorboard ; extra == 'testing'
- pydantic ; extra == 'testing'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'testing'
- sacrebleu>=1.4.12,<2.0.0 ; extra == 'testing'
- faiss-cpu ; extra == 'testing'
- cookiecutter==1.7.3 ; extra == 'testing'
- tensorflow>2.9,<2.16 ; extra == 'tf'
@@ -6651,7 +6664,7 @@ packages:
- blobfile ; extra == 'tiktoken'
- timm<=1.0.11 ; extra == 'timm'
- tokenizers>=0.21,<0.22 ; extra == 'tokenizers'
- torch>=2.0 ; extra == 'torch'
- torch>=2.1,<2.7 ; extra == 'torch'
- accelerate>=0.26.0 ; extra == 'torch'
- torchaudio ; extra == 'torch-speech'
- librosa ; extra == 'torch-speech'
@@ -6661,7 +6674,7 @@ packages:
- torchvision ; extra == 'torch-vision'
- pillow>=10.0.1,<=15.0 ; extra == 'torch-vision'
- filelock ; extra == 'torchhub'
- huggingface-hub>=0.24.0,<1.0 ; extra == 'torchhub'
- huggingface-hub>=0.30.0,<1.0 ; extra == 'torchhub'
- importlib-metadata ; extra == 'torchhub'
- numpy>=1.17 ; extra == 'torchhub'
- packaging>=20.0 ; extra == 'torchhub'
@@ -6669,10 +6682,10 @@ packages:
- regex!=2019.12.17 ; extra == 'torchhub'
- requests ; extra == 'torchhub'
- sentencepiece>=0.1.91,!=0.1.92 ; extra == 'torchhub'
- torch>=2.0 ; extra == 'torchhub'
- torch>=2.1,<2.7 ; extra == 'torchhub'
- tokenizers>=0.21,<0.22 ; extra == 'torchhub'
- tqdm>=4.27 ; extra == 'torchhub'
- av==9.2.0 ; extra == 'video'
- av ; extra == 'video'
- pillow>=10.0.1,<=15.0 ; extra == 'vision'
requires_python: '>=3.9.0'
- pypi: https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.2.3"
version = "3.2.4.a0"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
@@ -24,7 +24,7 @@ dependencies = [
"torch>=2.2.0",
"torchvision",
"torchtext",
"transformers<4.48.2",
"transformers==4.52.4",
"ipython",
"einops",
"biotite>=1.0.0",

View File

@@ -3,7 +3,11 @@ DOCKER_TAG ?= dev
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}
build-oss-ci:
docker build -f oss_pytests/Dockerfile oss_pytests -t $(DOCKER_IMAGE_OSS)
docker build \
--output=type=docker \
-f oss_pytests/Dockerfile \
-t $(DOCKER_IMAGE_OSS) \
oss_pytests
start-docker-oss:
docker run \