Use tempfile for platform-independent temporary file locations, including cached centering parameters; related to #8, #13, #19

This commit is contained in:
aismail3
2023-11-28 20:24:40 +00:00
parent d5f57876ea
commit ca86eaae1f
4 changed files with 26 additions and 17 deletions

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
import copy
import os
import tempfile
from typing import List, Optional, Tuple, Union
@@ -230,7 +231,7 @@ class Protein:
from chroma.utility.fetchdb import RCSB_file_download
file_cif = f"/tmp/{pdb_id}.cif"
file_cif = os.path.join(tempfile.gettempdir(), f"{pdb_id}.cif")
RCSB_file_download(pdb_id, ".cif", file_cif)
protein = cls.from_CIF(file_cif, canonicalize=canonicalize, device=device)
unlink(file_cif)
@@ -336,7 +337,8 @@ class Protein:
return X, C, S
def to_XCS_trajectory(
self, device: Optional[str] = None,
self,
device: Optional[str] = None,
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
"""
Convert the current Protein object to its XCS tensor representations over a trajectory.
@@ -416,7 +418,9 @@ class Protein:
This method processes the protein to ensure it conforms to a canonical form.
"""
self.sys.canonicalize_protein(
level=2, drop_coors_unknowns=True, drop_coors_missing_backbone=True,
level=2,
drop_coors_unknowns=True,
drop_coors_missing_backbone=True,
)
def sequence(self, format: str = "one-letter-string") -> Union[List[str], str]:

View File

@@ -23,6 +23,7 @@ natively in pytorch.
import json
import os
import tempfile
from typing import Optional, Tuple
import numpy as np
@@ -244,12 +245,12 @@ class ProteinFeatureGraph(nn.Module):
return node_h, edge_h, edge_idx, mask_i, mask_ij
def _load_centering_params(self, reference_pdb: str):
basepath = os.path.dirname(os.path.abspath(__file__)) + "/params/"
basepath = os.path.join(tempfile.gettempdir(), "generate", "params")
if not os.path.exists(basepath):
os.makedirs(basepath)
filename = f"centering_{reference_pdb}.params"
self.centering_file = basepath + filename
self.centering_file = os.path.join(basepath, filename)
key = (
reference_pdb
+ ";"
@@ -310,7 +311,7 @@ class ProteinFeatureGraph(nn.Module):
std = std.view(-1)
if verbose:
frac = (100.0 * std ** 2 / (mean ** 2 + std ** 2)).type(torch.int32)
frac = (100.0 * std**2 / (mean**2 + std**2)).type(torch.int32)
print(f"Fraction of raw variance: {frac}")
return mean, std
@@ -1032,7 +1033,7 @@ class EdgeOrientation2mer(nn.Module):
def _normed_vec(self, V):
    """Scale the vectors in ``V`` to (approximately) unit length.

    The squared magnitude is summed over the last dimension, and
    ``self.distance_eps`` is added under the square root so that a
    zero-length vector does not cause a division by zero.

    Args:
        V: Tensor of vectors; the last dimension holds the components.

    Returns:
        Tensor with the same shape as ``V``, each vector divided by
        ``sqrt(sum(v**2) + self.distance_eps)``.
    """
    # Unit vector from i to j
    mag_sq = (V**2).sum(dim=-1, keepdim=True)
    mag = torch.sqrt(mag_sq + self.distance_eps)
    return V / mag
@@ -1118,7 +1119,7 @@ class EdgeOrientationChain(nn.Module):
def _normed_vec(self, V):
    """Scale the vectors in ``V`` to (approximately) unit length.

    The squared magnitude is summed over the last dimension, and
    ``self.norm_eps`` is added under the square root so that a
    zero-length vector does not cause a division by zero.

    Args:
        V: Tensor of vectors; the last dimension holds the components.

    Returns:
        Tensor with the same shape as ``V``, each vector divided by
        ``sqrt(sum(v**2) + self.norm_eps)``.
    """
    # Unit vector from i to j
    mag_sq = (V**2).sum(dim=-1, keepdim=True)
    mag = torch.sqrt(mag_sq + self.norm_eps)
    return V / mag
@@ -1183,7 +1184,7 @@ class EdgeOrientationChain(nn.Module):
def _transformation_features(self, X_i, X_j, R_i, R_j, edge_idx, edges=True):
# Distance and direction
dX = X_j - X_i.unsqueeze(2).contiguous()
L = torch.sqrt((dX ** 2).sum(-1, keepdim=True) + self.distance_eps)
L = torch.sqrt((dX**2).sum(-1, keepdim=True) + self.distance_eps)
u_ij = torch.einsum("niab,nija->nijb", R_i, dX / L)
# Relative orientation
@@ -1473,7 +1474,7 @@ class NodeCartesianCoords(nn.Module):
self.num_atom_types = num_atom_types
# Public attribute
self.dim_out = 3 * (num_atom_types ** 2)
self.dim_out = 3 * (num_atom_types**2)
def forward(
self,
@@ -1524,7 +1525,7 @@ class EdgeCartesianCoords(nn.Module):
self.num_atom_types = num_atom_types
# Public attribute
self.dim_out = 3 * (num_atom_types ** 2)
self.dim_out = 3 * (num_atom_types**2)
def forward(
self,

View File

@@ -18,6 +18,7 @@ Utilities to save and load models with metadata.
import os
import os.path as osp
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, urlparse
from uuid import uuid4
@@ -44,7 +45,9 @@ def save_model(model, weight_file, metadata=None):
if metadata is not None:
save_dict.update(metadata)
local_path = str(
Path("/tmp", str(uuid4())[:8]) if weight_file.startswith("s3:") else weight_file
Path(tempfile.gettempdir(), str(uuid4())[:8])
if weight_file.startswith("s3:")
else weight_file
)
torch.save(save_dict, local_path)
if weight_file.startswith("s3:"):

View File

@@ -1,5 +1,6 @@
import copy
import filecmp
import os
import random
import tempfile
import time
@@ -167,7 +168,7 @@ def test_invalid_input(cif_file):
def next_structure_file(num=100, cif=True):
tmp_file = "/tmp/_pdb_list.txt"
tmp_file = os.path.join(tempfile.gettempdir(), "_pdb_list.txt")
download_file(
"https://files.wwpdb.org/pub/pdb/derived_data/pdb_entry_type.txt", tmp_file
)
@@ -176,16 +177,16 @@ def next_structure_file(num=100, cif=True):
random.shuffle(pdb_ids)
if cif:
file = "/tmp/_pdb_download.cif"
file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif")
else:
file = "/tmp/_pdb_download.pdb"
file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb")
for pdb_id in pdb_ids[:num]:
# download CIF file
if cif:
file = "/tmp/_pdb_download.cif"
file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif")
download_file(f"https://files.rcsb.org/download/{pdb_id}.cif", file)
else:
file = "/tmp/_pdb_download.pdb"
file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb")
download_file(f"https://files.rcsb.org/download/{pdb_id}.pdb", file)
yield pdb_id, file