mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
Make MOLDIR_ZIP_CACHE multiprocessing-safe (https://github.com/HannesStark/boltzgen/issues/40)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import itertools
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import random
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
@@ -14,10 +15,44 @@ from boltzgen.data import const
|
||||
from boltzgen.data.pad import pad_dim
|
||||
from boltzgen.model.loss.confidence import lddt_dist
|
||||
|
||||
MOLDIR_ZIP_CACHE = {} # moldir -> open zip file
|
||||
|
||||
MOLDIR_ZIP_CACHE: Dict[
|
||||
tuple[int, Path], zipfile.ZipFile
|
||||
] = {} # moldir -> open zip file
|
||||
|
||||
|
||||
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
|
||||
def _get_zipfile(moldir: Path) -> zipfile.ZipFile:
|
||||
"""
|
||||
Retrieve a cached ZipFile object for the given molecule directory zip file.
|
||||
|
||||
If the ZipFile for the provided path and current process does not exist in the cache,
|
||||
it is created and stored. Otherwise, the cached instance is returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
moldir : Path
|
||||
Path to the .zip file containing molecule data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
zipfile.ZipFile
|
||||
An open ZipFile object for reading molecule data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The cache is keyed by both process ID and zip file path to prevent issues with forked processes.
|
||||
"""
|
||||
pid = os.getpid()
|
||||
key = (pid, moldir)
|
||||
|
||||
zf = MOLDIR_ZIP_CACHE.get(key)
|
||||
if zf is None:
|
||||
zf = zipfile.ZipFile(moldir, "r")
|
||||
MOLDIR_ZIP_CACHE[key] = zf
|
||||
return zf
|
||||
|
||||
|
||||
def load_molecules(moldir: str, molecules: List[str]) -> Dict[str, Mol]:
|
||||
"""Load the given input data.
|
||||
|
||||
Parameters
|
||||
@@ -33,12 +68,10 @@ def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
|
||||
The loaded molecules.
|
||||
"""
|
||||
moldir = Path(moldir)
|
||||
loaded_mols = {}
|
||||
loaded_mols: Dict[str, Mol] = {}
|
||||
|
||||
if moldir.is_file() and moldir.suffix == ".zip":
|
||||
if moldir not in MOLDIR_ZIP_CACHE:
|
||||
MOLDIR_ZIP_CACHE[moldir] = zipfile.ZipFile(moldir, "r")
|
||||
zip_file = MOLDIR_ZIP_CACHE[moldir]
|
||||
zip_file = _get_zipfile(moldir)
|
||||
for molecule in molecules:
|
||||
pkl_filename = f"{molecule}.pkl"
|
||||
with zip_file.open(pkl_filename, "r") as f:
|
||||
|
||||
Reference in New Issue
Block a user