mirror of
https://github.com/generatebio/chroma.git
synced 2026-06-04 13:30:34 +08:00
Use tempfile for platform-independent temporary file locations, including cached centering parameters; related to #8, #13, #19
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -230,7 +231,7 @@ class Protein:
|
||||
|
||||
from chroma.utility.fetchdb import RCSB_file_download
|
||||
|
||||
file_cif = f"/tmp/{pdb_id}.cif"
|
||||
file_cif = os.path.join(tempfile.gettempdir(), f"{pdb_id}.cif")
|
||||
RCSB_file_download(pdb_id, ".cif", file_cif)
|
||||
protein = cls.from_CIF(file_cif, canonicalize=canonicalize, device=device)
|
||||
unlink(file_cif)
|
||||
@@ -336,7 +337,8 @@ class Protein:
|
||||
return X, C, S
|
||||
|
||||
def to_XCS_trajectory(
|
||||
self, device: Optional[str] = None,
|
||||
self,
|
||||
device: Optional[str] = None,
|
||||
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert the current Protein object to its XCS tensor representations over a trajectory.
|
||||
@@ -416,7 +418,9 @@ class Protein:
|
||||
This method processes the protein to ensure it conforms to a canonical form.
|
||||
"""
|
||||
self.sys.canonicalize_protein(
|
||||
level=2, drop_coors_unknowns=True, drop_coors_missing_backbone=True,
|
||||
level=2,
|
||||
drop_coors_unknowns=True,
|
||||
drop_coors_missing_backbone=True,
|
||||
)
|
||||
|
||||
def sequence(self, format: str = "one-letter-string") -> Union[List[str], str]:
|
||||
|
||||
@@ -23,6 +23,7 @@ natively in pytorch.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -244,12 +245,12 @@ class ProteinFeatureGraph(nn.Module):
|
||||
return node_h, edge_h, edge_idx, mask_i, mask_ij
|
||||
|
||||
def _load_centering_params(self, reference_pdb: str):
|
||||
basepath = os.path.dirname(os.path.abspath(__file__)) + "/params/"
|
||||
basepath = os.path.join(tempfile.gettempdir(), "generate", "params")
|
||||
if not os.path.exists(basepath):
|
||||
os.makedirs(basepath)
|
||||
|
||||
filename = f"centering_{reference_pdb}.params"
|
||||
self.centering_file = basepath + filename
|
||||
self.centering_file = os.path.join(basepath, filename)
|
||||
key = (
|
||||
reference_pdb
|
||||
+ ";"
|
||||
@@ -310,7 +311,7 @@ class ProteinFeatureGraph(nn.Module):
|
||||
std = std.view(-1)
|
||||
|
||||
if verbose:
|
||||
frac = (100.0 * std ** 2 / (mean ** 2 + std ** 2)).type(torch.int32)
|
||||
frac = (100.0 * std**2 / (mean**2 + std**2)).type(torch.int32)
|
||||
print(f"Fraction of raw variance: {frac}")
|
||||
return mean, std
|
||||
|
||||
@@ -1032,7 +1033,7 @@ class EdgeOrientation2mer(nn.Module):
|
||||
|
||||
def _normed_vec(self, V):
|
||||
# Unit vector from i to j
|
||||
mag_sq = (V ** 2).sum(dim=-1, keepdim=True)
|
||||
mag_sq = (V**2).sum(dim=-1, keepdim=True)
|
||||
mag = torch.sqrt(mag_sq + self.distance_eps)
|
||||
V_norm = V / mag
|
||||
return V_norm
|
||||
@@ -1118,7 +1119,7 @@ class EdgeOrientationChain(nn.Module):
|
||||
|
||||
def _normed_vec(self, V):
|
||||
# Unit vector from i to j
|
||||
mag_sq = (V ** 2).sum(dim=-1, keepdim=True)
|
||||
mag_sq = (V**2).sum(dim=-1, keepdim=True)
|
||||
mag = torch.sqrt(mag_sq + self.norm_eps)
|
||||
V_norm = V / mag
|
||||
return V_norm
|
||||
@@ -1183,7 +1184,7 @@ class EdgeOrientationChain(nn.Module):
|
||||
def _transformation_features(self, X_i, X_j, R_i, R_j, edge_idx, edges=True):
|
||||
# Distance and direction
|
||||
dX = X_j - X_i.unsqueeze(2).contiguous()
|
||||
L = torch.sqrt((dX ** 2).sum(-1, keepdim=True) + self.distance_eps)
|
||||
L = torch.sqrt((dX**2).sum(-1, keepdim=True) + self.distance_eps)
|
||||
u_ij = torch.einsum("niab,nija->nijb", R_i, dX / L)
|
||||
|
||||
# Relative orientation
|
||||
@@ -1473,7 +1474,7 @@ class NodeCartesianCoords(nn.Module):
|
||||
self.num_atom_types = num_atom_types
|
||||
|
||||
# Public attribute
|
||||
self.dim_out = 3 * (num_atom_types ** 2)
|
||||
self.dim_out = 3 * (num_atom_types**2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1524,7 +1525,7 @@ class EdgeCartesianCoords(nn.Module):
|
||||
self.num_atom_types = num_atom_types
|
||||
|
||||
# Public attribute
|
||||
self.dim_out = 3 * (num_atom_types ** 2)
|
||||
self.dim_out = 3 * (num_atom_types**2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -18,6 +18,7 @@ Utilities to save and load models with metadata.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from uuid import uuid4
|
||||
@@ -44,7 +45,9 @@ def save_model(model, weight_file, metadata=None):
|
||||
if metadata is not None:
|
||||
save_dict.update(metadata)
|
||||
local_path = str(
|
||||
Path("/tmp", str(uuid4())[:8]) if weight_file.startswith("s3:") else weight_file
|
||||
Path(tempfile.gettempdir(), str(uuid4())[:8])
|
||||
if weight_file.startswith("s3:")
|
||||
else weight_file
|
||||
)
|
||||
torch.save(save_dict, local_path)
|
||||
if weight_file.startswith("s3:"):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import filecmp
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
@@ -167,7 +168,7 @@ def test_invalid_input(cif_file):
|
||||
|
||||
|
||||
def next_structure_file(num=100, cif=True):
|
||||
tmp_file = "/tmp/_pdb_list.txt"
|
||||
tmp_file = os.path.join(tempfile.gettempdir(), "_pdb_list.txt")
|
||||
download_file(
|
||||
"https://files.wwpdb.org/pub/pdb/derived_data/pdb_entry_type.txt", tmp_file
|
||||
)
|
||||
@@ -176,16 +177,16 @@ def next_structure_file(num=100, cif=True):
|
||||
random.shuffle(pdb_ids)
|
||||
|
||||
if cif:
|
||||
file = "/tmp/_pdb_download.cif"
|
||||
file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif")
|
||||
else:
|
||||
file = "/tmp/_pdb_download.pdb"
|
||||
file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb")
|
||||
for pdb_id in pdb_ids[:num]:
|
||||
# download CIF file
|
||||
if cif:
|
||||
file = "/tmp/_pdb_download.cif"
|
||||
file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif")
|
||||
download_file(f"https://files.rcsb.org/download/{pdb_id}.cif", file)
|
||||
else:
|
||||
file = "/tmp/_pdb_download.pdb"
|
||||
file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb")
|
||||
download_file(f"https://files.rcsb.org/download/{pdb_id}.pdb", file)
|
||||
yield pdb_id, file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user