Make MOLDIR_ZIP_CACHE multiprocessing-safe (https://github.com/HannesStark/boltzgen/issues/40)

This commit is contained in:
Anton Bushuiev
2025-12-08 18:58:41 -05:00
parent 291c4633b2
commit 743236150a

View File

@@ -1,8 +1,9 @@
import itertools
import os
import pickle
from pathlib import Path
import random
from typing import Dict
from typing import Dict, List
import zipfile
import numpy as np
@@ -14,10 +15,44 @@ from boltzgen.data import const
from boltzgen.data.pad import pad_dim
from boltzgen.model.loss.confidence import lddt_dist
MOLDIR_ZIP_CACHE = {} # moldir -> open zip file
MOLDIR_ZIP_CACHE: Dict[
tuple[int, Path], zipfile.ZipFile
] = {} # moldir -> open zip file
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
def _get_zipfile(moldir: Path) -> zipfile.ZipFile:
"""
Retrieve a cached ZipFile object for the given molecule directory zip file.
If the ZipFile for the provided path and current process does not exist in the cache,
it is created and stored. Otherwise, the cached instance is returned.
Parameters
----------
moldir : Path
Path to the .zip file containing molecule data.
Returns
-------
zipfile.ZipFile
An open ZipFile object for reading molecule data.
Notes
-----
The cache is keyed by both process ID and zip file path to prevent issues with forked processes.
"""
pid = os.getpid()
key = (pid, moldir)
zf = MOLDIR_ZIP_CACHE.get(key)
if zf is None:
zf = zipfile.ZipFile(moldir, "r")
MOLDIR_ZIP_CACHE[key] = zf
return zf
def load_molecules(moldir: str, molecules: List[str]) -> Dict[str, Mol]:
"""Load the given input data.
Parameters
@@ -33,12 +68,10 @@ def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
The loaded molecules.
"""
moldir = Path(moldir)
loaded_mols = {}
loaded_mols: Dict[str, Mol] = {}
if moldir.is_file() and moldir.suffix == ".zip":
if moldir not in MOLDIR_ZIP_CACHE:
MOLDIR_ZIP_CACHE[moldir] = zipfile.ZipFile(moldir, "r")
zip_file = MOLDIR_ZIP_CACHE[moldir]
zip_file = _get_zipfile(moldir)
for molecule in molecules:
pkl_filename = f"{molecule}.pkl"
with zip_file.open(pkl_filename, "r") as f: