mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge pull request #419 from aqlaboratory/setup-improvements_additional-scripts
Duplicate expansion support
This commit is contained in:
144
scripts/alignment_data_to_fasta.py
Normal file
144
scripts/alignment_data_to_fasta.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
This script generates a FASTA file for all chains in an alignment directory or
|
||||
alignment DB.
|
||||
"""
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def chain_dir_to_fasta(dir: Path) -> str:
|
||||
"""
|
||||
Generates a FASTA string from a chain directory.
|
||||
"""
|
||||
# take some alignment file
|
||||
for alignment_file_type in [
|
||||
"mgnify_hits.a3m",
|
||||
"uniref90_hits.a3m",
|
||||
"bfd_uniclust_hits.a3m",
|
||||
]:
|
||||
alignment_file = dir / alignment_file_type
|
||||
if alignment_file.exists():
|
||||
break
|
||||
|
||||
with open(alignment_file, "r") as f:
|
||||
next(f) # skip the first line
|
||||
seq = next(f).strip()
|
||||
|
||||
try:
|
||||
next_line = next(f)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
assert next_line.startswith(">") # ensure that sequence ended
|
||||
|
||||
chain_id = dir.name
|
||||
|
||||
return f">{chain_id}\n{seq}\n"
|
||||
|
||||
|
||||
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
|
||||
"""
|
||||
Generates a FASTA string from an alignment-db index entry.
|
||||
"""
|
||||
db_file = db_dir / index_entry["db"]
|
||||
|
||||
# look for an alignment file
|
||||
for alignment_file_type in [
|
||||
"mgnify_hits.a3m",
|
||||
"uniref90_hits.a3m",
|
||||
"bfd_uniclust_hits.a3m",
|
||||
]:
|
||||
for file_info in index_entry["files"]:
|
||||
if file_info[0] == alignment_file_type:
|
||||
start, size = file_info[1], file_info[2]
|
||||
break
|
||||
|
||||
with open(db_file, "rb") as f:
|
||||
f.seek(start)
|
||||
msa_lines = f.read(size).decode("utf-8").splitlines()
|
||||
seq = msa_lines[1]
|
||||
|
||||
try:
|
||||
next_line = msa_lines[2]
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
assert next_line.startswith(">") # ensure that sequence ended
|
||||
|
||||
return f">{chain_id}\n{seq}\n"
|
||||
|
||||
|
||||
def main(
|
||||
output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
|
||||
) -> None:
|
||||
"""
|
||||
Generate a FASTA file from either an alignment-db index or a chain directory using multi-threading.
|
||||
"""
|
||||
fasta = []
|
||||
|
||||
if alignment_dir and alignment_db_index:
|
||||
raise ValueError(
|
||||
"Only one of alignment_db_index and alignment_dir can be provided."
|
||||
)
|
||||
|
||||
if alignment_dir:
|
||||
print("Creating FASTA from alignment directory...")
|
||||
chain_dirs = list(alignment_dir.iterdir())
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(chain_dir_to_fasta, chain_dir)
|
||||
for chain_dir in chain_dirs
|
||||
]
|
||||
for future in tqdm(as_completed(futures), total=len(chain_dirs)):
|
||||
fasta.append(future.result())
|
||||
|
||||
elif alignment_db_index:
|
||||
print("Creating FASTA from alignment dbs...")
|
||||
|
||||
with open(alignment_db_index, "r") as f:
|
||||
index = json.load(f)
|
||||
|
||||
db_dir = alignment_db_index.parent
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(index_entry_to_fasta, index_entry, db_dir, chain_id)
|
||||
for chain_id, index_entry in index.items()
|
||||
]
|
||||
for future in tqdm(as_completed(futures), total=len(index)):
|
||||
fasta.append(future.result())
|
||||
else:
|
||||
raise ValueError("Either alignment_db_index or alignment_dir must be provided.")
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
f.write("".join(fasta))
|
||||
print(f"FASTA file written to {output_path}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"output_path",
|
||||
type=Path,
|
||||
help="Path to output FASTA file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alignment_db_index",
|
||||
type=Path,
|
||||
help="Path to alignment-db index file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alignment_dir",
|
||||
type=Path,
|
||||
help="Path to alignment directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args.output_path, args.alignment_db_index, args.alignment_dir)
|
||||
@@ -5,17 +5,19 @@ super index, meaning that "unify_alignment_db_indices.py" does not need to be
|
||||
run on the output index. Additionally this script uses threading and
|
||||
multiprocessing and is much faster than the old version.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
from math import ceil
|
||||
from multiprocessing import cpu_count
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def split_file_list(file_list, n_shards):
|
||||
def split_file_list(file_list: list[Path], n_shards: int):
|
||||
"""
|
||||
Split up the total file list into n_shards sublists.
|
||||
"""
|
||||
@@ -29,26 +31,25 @@ def split_file_list(file_list, n_shards):
|
||||
return split_list
|
||||
|
||||
|
||||
def chunked_iterator(lst, chunk_size):
|
||||
def chunked_iterator(lst: list, chunk_size: int):
|
||||
"""Iterate over a list in chunks of size chunk_size."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
yield lst[i : i + chunk_size]
|
||||
|
||||
|
||||
def read_chain_dir(chain_dir) -> dict:
|
||||
def read_chain_dir(chain_dir: Path) -> dict:
|
||||
"""
|
||||
Read all alignment files in a single chain directory and return a dict
|
||||
mapping chain name to file names and bytes.
|
||||
"""
|
||||
if not chain_dir.is_dir():
|
||||
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")
|
||||
|
||||
|
||||
# ensure that PDB IDs are all lowercase
|
||||
pdb_id, chain = chain_dir.name.split("_")
|
||||
pdb_id = pdb_id.lower()
|
||||
chain_name = f"{pdb_id}_{chain}"
|
||||
|
||||
|
||||
|
||||
file_data = []
|
||||
|
||||
for file_path in sorted(chain_dir.iterdir()):
|
||||
@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict:
|
||||
return {chain_name: file_data}
|
||||
|
||||
|
||||
def process_chunk(chain_files: List[Path]) -> dict:
|
||||
def process_chunk(chain_files: list[Path]) -> dict:
|
||||
"""
|
||||
Returns the file names and bytes for all chains in a chunk of files.
|
||||
"""
|
||||
@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict:
|
||||
|
||||
|
||||
def create_shard(
|
||||
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
|
||||
shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
|
||||
) -> dict:
|
||||
"""
|
||||
Creates a single shard of the alignment database, and returns the
|
||||
@@ -92,7 +93,7 @@ def create_shard(
|
||||
CHUNK_SIZE = 200
|
||||
shard_index = defaultdict(
|
||||
create_index_default_dict
|
||||
) # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
|
||||
) # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
|
||||
chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
|
||||
|
||||
pbar_desc = f"Shard {shard_num}"
|
||||
@@ -101,7 +102,11 @@ def create_shard(
|
||||
db_offset = 0
|
||||
db_file = open(output_path, "wb")
|
||||
for files_chunk in tqdm(
|
||||
chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False
|
||||
chunk_iter,
|
||||
total=ceil(len(shard_files) / CHUNK_SIZE),
|
||||
desc=pbar_desc,
|
||||
position=shard_num,
|
||||
leave=False,
|
||||
):
|
||||
# get processed files for one chunk
|
||||
chunk_data = process_chunk(files_chunk)
|
||||
@@ -125,9 +130,17 @@ def create_shard(
|
||||
def main(args):
|
||||
alignment_dir = args.alignment_dir
|
||||
output_dir = args.output_db_path
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
output_db_name = args.output_db_name
|
||||
n_shards = args.n_shards
|
||||
|
||||
n_cpus = cpu_count()
|
||||
if n_shards > n_cpus:
|
||||
print(
|
||||
f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
|
||||
"This may result in slower performance. Consider using a smaller number of shards."
|
||||
)
|
||||
|
||||
# get all chain dirs in alignment_dir
|
||||
print("Getting chain directories...")
|
||||
all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
|
||||
@@ -153,12 +166,36 @@ def main(args):
|
||||
super_index.update(shard_index)
|
||||
print("\nCreated all shards.")
|
||||
|
||||
if args.duplicate_chains_file:
|
||||
print("Extending super index with duplicate chains...")
|
||||
duplicates_added = 0
|
||||
with open(args.duplicate_chains_file, "r") as fp:
|
||||
duplicate_chains = [line.strip().split() for line in fp]
|
||||
|
||||
for chains in duplicate_chains:
|
||||
# find representative with alignment
|
||||
for chain in chains:
|
||||
if chain in super_index:
|
||||
representative_chain = chain
|
||||
break
|
||||
else:
|
||||
print(f"No representative chain found for {chains}, skipping...")
|
||||
continue
|
||||
|
||||
# add duplicates to index
|
||||
for chain in chains:
|
||||
if chain != representative_chain:
|
||||
super_index[chain] = super_index[representative_chain]
|
||||
duplicates_added += 1
|
||||
|
||||
print(f"Added {duplicates_added} duplicate chains to index.")
|
||||
|
||||
# write super index to file
|
||||
print("\nWriting super index...")
|
||||
index_path = output_dir / f"{output_db_name}.index"
|
||||
with open(index_path, "w") as fp:
|
||||
json.dump(super_index, fp, indent=4)
|
||||
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
@@ -179,13 +216,27 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"alignment_dir",
|
||||
type=Path,
|
||||
help="""Path to precomputed alignment directory, with one subdirectory
|
||||
per chain.""",
|
||||
help="""Path to precomputed flattened alignment directory, with one
|
||||
subdirectory per chain.""",
|
||||
)
|
||||
parser.add_argument("output_db_path", type=Path)
|
||||
parser.add_argument("output_db_name", type=str)
|
||||
parser.add_argument(
|
||||
"n_shards", type=int, help="Number of shards to split the database into"
|
||||
"--n_shards",
|
||||
type=int,
|
||||
help="Number of shards to split the database into",
|
||||
default=10,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duplicate_chains_file",
|
||||
type=Path,
|
||||
help="""
|
||||
Optional path to file containing duplicate chain information, where each
|
||||
line contains chains that are 100% sequence identical. If provided,
|
||||
duplicate chains will be added to the index and point to the same
|
||||
underlying database entry as their representatives in the alignment dir.
|
||||
""",
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
79
scripts/expand_alignment_duplicates.py
Normal file
79
scripts/expand_alignment_duplicates.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
The OpenProteinSet alignment database is non-redundant, meaning that it only
|
||||
stores one explicit representative alignment directory for all PDB chains in a
|
||||
100% sequence identity cluster. In order to add explicit alignments for all PDB
|
||||
chains, this script will add the missing chain directories and symlink them to
|
||||
their representative alignment directories. This is required in order to train
|
||||
OpenFold on the full PDB, not just one representative chain per cluster.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
|
||||
"""
|
||||
Create duplicate directory symlinks for all chains in the given duplicate lists.
|
||||
|
||||
Args:
|
||||
duplicate_lists (list[list[str]]): A list of lists, where each inner list
|
||||
contains chains that are 100% sequence identical.
|
||||
alignment_dir (Path): Path to flattened alignment directory, with one
|
||||
subdirectory per chain.
|
||||
"""
|
||||
print("Creating duplicate directory symlinks...")
|
||||
dirs_created = 0
|
||||
for chains in tqdm(duplicate_chains):
|
||||
# find the chain that has an alignment
|
||||
for chain in chains:
|
||||
if (alignment_dir / chain).exists():
|
||||
representative_chain = chain
|
||||
break
|
||||
else:
|
||||
print(f"No representative chain found for {chains}, skipping...")
|
||||
continue
|
||||
|
||||
# create symlinks for all other chains
|
||||
for chain in chains:
|
||||
if chain != representative_chain:
|
||||
target_path = alignment_dir / chain
|
||||
if target_path.exists():
|
||||
print(f"Chain {chain} already exists, skipping...")
|
||||
else:
|
||||
(target_path).symlink_to(alignment_dir / representative_chain)
|
||||
dirs_created += 1
|
||||
|
||||
print(f"Created directories for {dirs_created} duplicate chains.")
|
||||
|
||||
|
||||
def main(alignment_dir: Path, duplicate_chains_file: Path):
|
||||
# read duplicate chains file
|
||||
with open(duplicate_chains_file, "r") as fp:
|
||||
duplicate_chains = [list(line.strip().split()) for line in fp]
|
||||
|
||||
# convert to absolute path for symlink creation
|
||||
alignment_dir = alignment_dir.resolve()
|
||||
|
||||
create_duplicate_dirs(duplicate_chains, alignment_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"alignment_dir",
|
||||
type=Path,
|
||||
help="""Path to flattened alignment directory, with one subdirectory
|
||||
per chain.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"duplicate_chains_file",
|
||||
type=Path,
|
||||
help="""Path to file containing duplicate chains, where each line
|
||||
contains a space-separated list of chains that are 100%%
|
||||
sequence identical.
|
||||
""",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.alignment_dir, args.duplicate_chains_file)
|
||||
@@ -85,7 +85,7 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(
|
||||
description="Creates a sequence cluster file from a .fasta file using mmseqs2 with PDB settings."
|
||||
description=__doc__
|
||||
)
|
||||
parser.add_argument(
|
||||
"input_fasta",
|
||||
|
||||
Reference in New Issue
Block a user