From ca86eaae1ffcfc4ec128f4ef61be044228f36e8a Mon Sep 17 00:00:00 2001 From: aismail3 Date: Tue, 28 Nov 2023 20:24:40 +0000 Subject: [PATCH] Use tempfile for platform-independent temporary file locations, including cached centering parameters; related to #8, #13, #19 --- chroma/data/protein.py | 10 +++++++--- chroma/layers/structure/protein_graph.py | 17 +++++++++-------- chroma/utility/model.py | 5 ++++- tests/data/test_system.py | 11 ++++++----- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/chroma/data/protein.py b/chroma/data/protein.py index 4e20405..11e2b6f 100644 --- a/chroma/data/protein.py +++ b/chroma/data/protein.py @@ -15,6 +15,7 @@ from __future__ import annotations import copy +import os import tempfile from typing import List, Optional, Tuple, Union @@ -230,7 +231,7 @@ class Protein: from chroma.utility.fetchdb import RCSB_file_download - file_cif = f"/tmp/{pdb_id}.cif" + file_cif = os.path.join(tempfile.gettempdir(), f"{pdb_id}.cif") RCSB_file_download(pdb_id, ".cif", file_cif) protein = cls.from_CIF(file_cif, canonicalize=canonicalize, device=device) unlink(file_cif) @@ -336,7 +337,8 @@ class Protein: return X, C, S def to_XCS_trajectory( - self, device: Optional[str] = None, + self, + device: Optional[str] = None, ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]: """ Convert the current Protein object to its XCS tensor representations over a trajectory. @@ -416,7 +418,9 @@ class Protein: This method processes the protein to ensure it conforms to a canonical form. """ self.sys.canonicalize_protein( - level=2, drop_coors_unknowns=True, drop_coors_missing_backbone=True, + level=2, + drop_coors_unknowns=True, + drop_coors_missing_backbone=True, ) def sequence(self, format: str = "one-letter-string") -> Union[List[str], str]: diff --git a/chroma/layers/structure/protein_graph.py b/chroma/layers/structure/protein_graph.py index f2ff0a2..4584e5d 100644 --- a/chroma/layers/structure/protein_graph.py +++ b/chroma/layers/structure/protein_graph.py @@ -23,6 +23,7 @@ natively in pytorch. import json import os +import tempfile from typing import Optional, Tuple import numpy as np @@ -244,12 +245,12 @@ class ProteinFeatureGraph(nn.Module): return node_h, edge_h, edge_idx, mask_i, mask_ij def _load_centering_params(self, reference_pdb: str): - basepath = os.path.dirname(os.path.abspath(__file__)) + "/params/" + basepath = os.path.join(tempfile.gettempdir(), "generate", "params") if not os.path.exists(basepath): os.makedirs(basepath) filename = f"centering_{reference_pdb}.params" - self.centering_file = basepath + filename + self.centering_file = os.path.join(basepath, filename) key = ( reference_pdb + ";" @@ -310,7 +311,7 @@ class ProteinFeatureGraph(nn.Module): std = std.view(-1) if verbose: - frac = (100.0 * std ** 2 / (mean ** 2 + std ** 2)).type(torch.int32) + frac = (100.0 * std**2 / (mean**2 + std**2)).type(torch.int32) print(f"Fraction of raw variance: {frac}") return mean, std @@ -1032,7 +1033,7 @@ class EdgeOrientation2mer(nn.Module): def _normed_vec(self, V): # Unit vector from i to j - mag_sq = (V ** 2).sum(dim=-1, keepdim=True) + mag_sq = (V**2).sum(dim=-1, keepdim=True) mag = torch.sqrt(mag_sq + self.distance_eps) V_norm = V / mag return V_norm @@ -1118,7 +1119,7 @@ class EdgeOrientationChain(nn.Module): def _normed_vec(self, V): # Unit vector from i to j - mag_sq = (V ** 2).sum(dim=-1, keepdim=True) + mag_sq = (V**2).sum(dim=-1, keepdim=True) mag = torch.sqrt(mag_sq + self.norm_eps) V_norm = V / mag return V_norm @@ -1183,7 +1184,7 @@ class EdgeOrientationChain(nn.Module): def _transformation_features(self, X_i, X_j, R_i, R_j, edge_idx, edges=True): # Distance and direction dX = X_j - X_i.unsqueeze(2).contiguous() - L = torch.sqrt((dX ** 2).sum(-1, keepdim=True) + self.distance_eps) + L = torch.sqrt((dX**2).sum(-1, keepdim=True) + self.distance_eps) u_ij = torch.einsum("niab,nija->nijb", R_i, dX / L) # Relative orientation @@ -1473,7 +1474,7 @@ class NodeCartesianCoords(nn.Module): self.num_atom_types = num_atom_types # Public attribute - self.dim_out = 3 * (num_atom_types ** 2) + self.dim_out = 3 * (num_atom_types**2) def forward( self, @@ -1524,7 +1525,7 @@ class EdgeCartesianCoords(nn.Module): self.num_atom_types = num_atom_types # Public attribute - self.dim_out = 3 * (num_atom_types ** 2) + self.dim_out = 3 * (num_atom_types**2) def forward( self, diff --git a/chroma/utility/model.py b/chroma/utility/model.py index e7fcfac..7f57557 100644 --- a/chroma/utility/model.py +++ b/chroma/utility/model.py @@ -18,6 +18,7 @@ Utilities to save and load models with metadata. import os import os.path as osp +import tempfile from pathlib import Path from urllib.parse import parse_qs, urlparse from uuid import uuid4 @@ -44,7 +45,9 @@ def save_model(model, weight_file, metadata=None): if metadata is not None: save_dict.update(metadata) local_path = str( - Path("/tmp", str(uuid4())[:8]) if weight_file.startswith("s3:") else weight_file + Path(tempfile.gettempdir(), str(uuid4())[:8]) + if weight_file.startswith("s3:") + else weight_file ) torch.save(save_dict, local_path) if weight_file.startswith("s3:"): diff --git a/tests/data/test_system.py b/tests/data/test_system.py index 72551e2..9d15747 100644 --- a/tests/data/test_system.py +++ b/tests/data/test_system.py @@ -1,5 +1,6 @@ import copy import filecmp +import os import random import tempfile import time @@ -167,7 +168,7 @@ def test_invalid_input(cif_file): def next_structure_file(num=100, cif=True): - tmp_file = "/tmp/_pdb_list.txt" + tmp_file = os.path.join(tempfile.gettempdir(), "_pdb_list.txt") download_file( "https://files.wwpdb.org/pub/pdb/derived_data/pdb_entry_type.txt", tmp_file ) @@ -176,16 +177,16 @@ def next_structure_file(num=100, cif=True): random.shuffle(pdb_ids) if cif: - file = "/tmp/_pdb_download.cif" + file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif") else: - file = "/tmp/_pdb_download.pdb" + file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb") for pdb_id in pdb_ids[:num]: # download CIF file if cif: - file = "/tmp/_pdb_download.cif" + file = os.path.join(tempfile.gettempdir(), "_pdb_download.cif") download_file(f"https://files.rcsb.org/download/{pdb_id}.cif", file) else: - file = "/tmp/_pdb_download.pdb" + file = os.path.join(tempfile.gettempdir(), "_pdb_download.pdb") download_file(f"https://files.rcsb.org/download/{pdb_id}.pdb", file) yield pdb_id, file