mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-05 05:04:24 +08:00
194 lines
6.0 KiB
Python
194 lines
6.0 KiB
Python
"""
|
|
This is a modified version of the create_alignment_db.py script in OpenFold
which supports sharding into multiple files. The created index is already a
super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally, this script uses threading and
multiprocessing and is much faster than the old version.
|
|
"""
|
|
import argparse
|
|
from collections import defaultdict
|
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
|
import json
|
|
from pathlib import Path
|
|
from typing import List
|
|
from tqdm import tqdm
|
|
from math import ceil
|
|
|
|
|
|
def split_file_list(file_list, n_shards):
    """
    Split up the total file list into n_shards sublists.

    Items are dealt out round-robin: sublist i receives the elements at
    positions i, i + n_shards, i + 2 * n_shards, ... of file_list.

    Args:
        file_list: Sequence of items to distribute; must support slicing.
        n_shards: Number of sublists to create; must be at least 1.

    Returns:
        A list containing exactly n_shards sublists. Trailing sublists are
        empty when file_list has fewer items than n_shards.

    Raises:
        ValueError: If n_shards is smaller than 1.
    """
    # Validate explicitly instead of relying on an assert: asserts are
    # stripped under `python -O`, in which case a non-positive shard count
    # would silently produce no shards and drop every file.
    if n_shards < 1:
        raise ValueError(f"n_shards must be at least 1, but is {n_shards}")

    return [file_list[i::n_shards] for i in range(n_shards)]
|
|
|
|
|
|
def chunked_iterator(lst, chunk_size):
    """Yield consecutive slices of lst, each holding at most chunk_size items."""
    chunk_starts = range(0, len(lst), chunk_size)
    yield from (lst[begin : begin + chunk_size] for begin in chunk_starts)
|
|
|
|
|
|
def read_chain_dir(chain_dir) -> dict:
    """
    Load every entry found directly inside a single chain directory.

    The directory name is split on "_" into a PDB ID and a chain ID; the PDB ID
    is lowercased and the two parts are rejoined to form the returned key.

    Returns:
        A single-entry dict {chain_name: [(file_name, file_bytes), ...]} with
        the entries listed in sorted path order.

    Raises:
        ValueError: If chain_dir is not a directory, or if its name does not
            split into exactly two parts on "_".
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    # normalise the PDB ID to lowercase, leave the chain ID untouched
    raw_pdb_id, chain_id = chain_dir.name.split("_")
    key = f"{raw_pdb_id.lower()}_{chain_id}"

    contents = [
        (entry.name, entry.read_bytes()) for entry in sorted(chain_dir.iterdir())
    ]

    return {key: contents}
|
|
|
|
|
|
def process_chunk(chain_files: List[Path]) -> dict:
    """
    Read a chunk of chain directories with a thread pool and merge the results.

    Each path in chain_files is handed to read_chain_dir; the per-chain dicts
    it returns are folded into one dict in input order, so a later entry
    replaces an earlier one that produced the same chain name.
    """
    with ThreadPoolExecutor() as pool:
        per_chain_results = list(pool.map(read_chain_dir, chain_files))

    combined = {}
    for single_chain in per_chain_results:
        combined.update(single_chain)

    return combined
|
|
|
|
|
|
def create_index_default_dict() -> dict:
    """
    Build a fresh, empty index entry: no db assigned yet and no files listed.
    """
    return dict(db=None, files=[])
|
|
|
|
|
|
def create_shard(
    shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Creates a single shard of the alignment database, and returns the
    corresponding indices for the super index.

    The chain directories in shard_files are read in chunks of CHUNK_SIZE, so
    only one chunk's file contents are held at a time. All file bytes are
    appended back to back to "<output_name>_<shard_num>.db" inside output_dir.

    Args:
        shard_files: Chain directories that belong to this shard.
        output_dir: Directory in which the shard's .db file is written.
        output_name: Base name of the database; the shard number is appended.
        shard_num: Number of this shard; also used as the progress bar position.

    Returns:
        A mapping of the form
        {chain_name: {"db": db_file_name, "files": [(file_name, db_offset, file_length), ...]}}
        where db_offset is the byte position of the file inside the shard.
    """
    CHUNK_SIZE = 200
    shard_index = defaultdict(
        create_index_default_dict
    )  # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
    chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)

    pbar_desc = f"Shard {shard_num}"
    output_path = output_dir / f"{output_name}_{shard_num}.db"

    db_offset = 0
    # The context manager guarantees the shard file is flushed and closed even
    # if reading a chain directory or writing to disk raises part-way through.
    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunk_iter,
            total=ceil(len(shard_files) / CHUNK_SIZE),
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # get processed files for one chunk
            chunk_data = process_chunk(files_chunk)

            # write to db and store info in index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name

                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    # offsets are cumulative over the whole shard file
                    db_offset += file_length

    return shard_index
|
|
|
|
|
|
def main(args):
    """
    Build the sharded alignment database and its super index.

    Reads the following attributes from args:
        alignment_dir: Directory whose entries are the chain directories.
        output_db_path: Directory that receives the .db shards and the index;
            it is created if it does not exist yet.
        output_db_name: Base name used for the shard files and the index file.
        n_shards: Number of shard files to create.

    Each shard is built by create_shard in its own worker process; the
    per-shard indices are merged and written as JSON to
    "<output_db_name>.index" inside the output directory.
    """
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted(tqdm(alignment_dir.iterdir()))

    # split chain dirs into n_shards sublists
    chain_dir_shards = split_file_list(all_chain_dirs, n_shards)

    # Every worker opens its shard file inside output_dir, so make sure the
    # directory exists up front instead of failing in each subprocess.
    output_dir.mkdir(parents=True, exist_ok=True)

    # total index for all shards
    super_index = {}

    # create a shard for each sublist
    print(f"Creating {n_shards} alignment-db files...")
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_num
            )
            for shard_num, shard_files in enumerate(chain_dir_shards)
        ]

        # merge in whichever shard finishes first; result() re-raises any
        # exception that occurred in the worker
        for future in as_completed(futures):
            super_index.update(future.result())
    print("\nCreated all shards.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)

    print("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: four required positional arguments, parsed and
    # handed straight to main().
    cli_description = """
        This script creates an alignment database format from a directory of
        precomputed alignments. For better file system health, the total
        database is split into n_shards files, where each shard contains a
        subset of the total alignments. The output is a directory containing the
        n_shards database files, and a single index file mapping chain names to
        the database file and byte offsets for each alignment file.

        Note: For optimal performance, your machine should have at least as many
        cores as shards you want to create.
        """
    parser = argparse.ArgumentParser(description=cli_description)

    # input: directory holding one sub-directory per chain
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to precomputed alignment directory, with one subdirectory
        per chain.""",
    )

    # output: target directory and base name of the database files
    parser.add_argument("output_db_path", type=Path)
    parser.add_argument("output_db_name", type=str)

    parser.add_argument(
        "n_shards",
        type=int,
        help="Number of shards to split the database into",
    )

    args = parser.parse_args()

    main(args)
|