# Mirror of https://github.com/aqlaboratory/openfold.git
# (synced 2026-06-04 12:44:26 +08:00; 268 lines, 8.9 KiB, Python)
import argparse
|
|
from functools import partial
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
from multiprocessing import cpu_count
|
|
from shutil import copyfile
|
|
import tempfile
|
|
|
|
import openfold.data.mmcif_parsing as mmcif_parsing
|
|
from openfold.data.data_pipeline import AlignmentRunner
|
|
from openfold.data.parsers import parse_fasta
|
|
from openfold.data.tools import hhsearch, hmmsearch
|
|
from openfold.np import protein, residue_constants
|
|
|
|
from utils import add_data_args
|
|
|
|
|
|
# Configure the root logger when this module is imported so that the
# logging.warning(...) calls made throughout this script are displayed.
logging.basicConfig(level="WARNING")
|
|
|
|
|
|
def run_seq_group_alignments(seq_groups, alignment_runner, args):
    """Run the alignment pipeline once per unique sequence and fan out results.

    Each element of ``seq_groups`` is a ``(seq, names)`` pair where ``names``
    lists every identifier sharing ``seq``. Alignments are computed only for
    ``names[0]`` (into ``args.output_dir/names[0]``). The files produced there
    are then copied into ``args.output_dir/<name>`` for every remaining name
    that did not already exist when this call started.

    Args:
        seq_groups: Iterable of ``(seq, names)`` pairs.
        alignment_runner: Object whose ``run(fasta_path, output_dir)`` method
            writes alignment files into ``output_dir``.
        args: Namespace providing ``output_dir``.

    Failure to create the first directory, or an exception raised by the
    runner, is logged and the group is skipped. The loop continues with the
    next group.
    """
    import shutil  # local import: rmtree is only needed for failure cleanup

    # Snapshot of existing output directories, used to skip names that are
    # already present on disk.
    dirs = set(os.listdir(args.output_dir))

    for seq, names in seq_groups:
        if not names:
            continue
        first_name = names[0]
        alignment_dir = os.path.join(args.output_dir, first_name)

        try:
            # No exist_ok: an existing directory causes this group to be
            # skipped (presumably it was already processed).
            os.makedirs(alignment_dir)
        except Exception as e:
            logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
            continue

        fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
        try:
            with os.fdopen(fd, 'w') as fp:
                fp.write(f'>query\n{seq}')
            alignment_runner.run(fasta_path, alignment_dir)
        except Exception as e:
            logging.warning(e)
            logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
            # The runner may have written partial outputs before failing, in
            # which case os.rmdir would raise on the non-empty directory.
            shutil.rmtree(alignment_dir, ignore_errors=True)
            continue
        finally:
            # Always remove the temporary query file, on success or failure.
            os.remove(fasta_path)

        # List the produced files once; the set is the same for every copy.
        produced = os.listdir(alignment_dir)
        for name in names[1:]:
            if name in dirs:
                logging.warning(
                    f'{name} has already been processed. Skipping...'
                )
                continue

            cp_dir = os.path.join(args.output_dir, name)
            os.makedirs(cp_dir, exist_ok=True)
            for fname in produced:
                copyfile(os.path.join(alignment_dir, fname), os.path.join(cp_dir, fname))
|
|
|
|
|
|
def _mmcif_seq_groups(path, f, file_id, args):
    """Group an mmCIF file's chains by SEQRES sequence.

    Returns ``{seq: ["<file_id>_<chain>", ...]}``, or ``None`` when the file
    fails to parse and ``args.raise_errors`` is false.
    """
    with open(path, 'r') as fp:
        mmcif_str = fp.read()
    parsed = mmcif_parsing.parse(
        file_id=file_id, mmcif_string=mmcif_str
    )
    if parsed.mmcif_object is None:
        logging.warning(f'Failed to parse {f}...')
        if args.raise_errors:
            errors = list(parsed.errors.values())
            # Guard against an empty error map so we still fail with a
            # meaningful error rather than an unrelated IndexError.
            if errors:
                raise errors[0]
            raise ValueError(f'Failed to parse {f}')
        return None

    seq_group_dict = {}
    for chain_letter, seq in parsed.mmcif_object.chain_to_seqres.items():
        seq_group_dict.setdefault(seq, []).append('_'.join([file_id, chain_letter]))
    return seq_group_dict


def _fasta_seq_groups(path, f, file_id, args):
    """Return ``{seq: [file_id]}`` for the first sequence of a FASTA file.

    Returns ``None`` when the file has no sequence and ``args.raise_errors``
    is false. Extra sequences are reported and ignored, or raise
    ``ValueError`` when ``args.raise_errors`` is set.
    """
    with open(path, 'r') as fp:
        fasta_str = fp.read()
    input_seqs, _ = parse_fasta(fasta_str)

    if not input_seqs:
        msg = f'No input_sequence found in {f}'
        if args.raise_errors:
            raise ValueError(msg)
        logging.warning(msg)
        return None

    if len(input_seqs) > 1:
        msg = f'More than one input_sequence found in {f}'
        if args.raise_errors:
            raise ValueError(msg)
        logging.warning(msg)

    return {input_seqs[0]: [file_id]}


def _core_seq_groups(path, file_id):
    """Return ``{seq: [file_id]}`` for a ProteinNet ``.core`` file."""
    with open(path, 'r') as fp:
        core_str = fp.read()
    core_prot = protein.from_proteinnet_string(core_str)
    # Map each residue-type index to its one-letter code.
    seq = ''.join(
        residue_constants.restypes_with_x[a] for a in core_prot.aatype
    )
    return {seq: [file_id]}


def parse_and_align(files, alignment_runner, args):
    """Parse each input file and compute alignments for its sequences.

    Supported inputs (by extension, relative to ``args.input_dir``):
    ``.cif`` (one group per distinct chain SEQRES), ``.fasta``/``.fa`` (first
    sequence only) and ``.core`` (ProteinNet). Other files are skipped.

    Args:
        files: File names inside ``args.input_dir``.
        alignment_runner: Passed through to ``run_seq_group_alignments``.
        args: Namespace providing ``input_dir``, ``raise_errors`` and whatever
            ``run_seq_group_alignments`` reads (``output_dir``).

    Raises:
        Parse errors or ``ValueError`` for malformed FASTA input, but only
        when ``args.raise_errors`` is set.
    """
    for f in files:
        path = os.path.join(args.input_dir, f)
        file_id = os.path.splitext(f)[0]

        if f.endswith('.cif'):
            seq_group_dict = _mmcif_seq_groups(path, f, file_id, args)
        elif f.endswith(('.fasta', '.fa')):
            seq_group_dict = _fasta_seq_groups(path, f, file_id, args)
        elif f.endswith('.core'):
            seq_group_dict = _core_seq_groups(path, file_id)
        else:
            # Unsupported extension: ignored.
            continue

        if seq_group_dict is None:
            continue

        run_seq_group_alignments(
            list(seq_group_dict.items()), alignment_runner, args
        )
|
|
|
|
|
|
def _build_template_searcher(args):
    """Return the template searcher configured by ``args``, or ``None``.

    hmmsearch (with the PDB seqres database) takes precedence when both its
    binary and database are given. Otherwise HHSearch with PDB70 is used if
    configured.
    """
    if args.hmmsearch_binary_path is not None and args.pdb_seqres_database_path is not None:
        return hmmsearch.Hmmsearch(
            binary_path=args.hmmsearch_binary_path,
            hmmbuild_binary_path=args.hmmbuild_binary_path,
            database_path=args.pdb_seqres_database_path,
        )
    if args.hhsearch_binary_path is not None and args.pdb70_database_path is not None:
        return hhsearch.HHSearch(
            binary_path=args.hhsearch_binary_path,
            databases=[args.pdb70_database_path],
        )
    return None


def _split_up_arglist(arglist, no_tasks):
    """Split ``arglist`` across SLURM nodes, then into ``no_tasks`` strided chunks.

    When ``SLURM_JOB_NUM_NODES`` > 1, only every ``num_nodes``-th item
    starting at ``SLURM_NODEID`` is kept for this node.
    """
    if os.environ.get("SLURM_JOB_NUM_NODES", 0):
        num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
        if num_nodes > 1:
            node_id = int(os.environ["SLURM_NODEID"])
            logging.warning(f"Num nodes: {num_nodes}")
            logging.warning(f"Node ID: {node_id}")
            arglist = arglist[node_id::num_nodes]

    return [arglist[i::no_tasks] for i in range(no_tasks)]


def main(args):
    """Precompute alignments for every input in ``args.input_dir``.

    If ``args.mmcif_cache`` is given and ``args.filter`` is set, proteins
    whose chain directories all already exist in ``args.output_dir`` are
    skipped. If the cache stores per-chain ``seqs``, identical sequences are
    grouped directly from the cache. Otherwise each file is parsed. Work is
    divided across SLURM nodes (if any) and ``args.no_tasks`` threads.
    """
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniref30_database_path=args.uniref30_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        uniprot_database_path=args.uniprot_database_path,
        template_searcher=_build_template_searcher(args),
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus_per_task,
    )

    # os.listdir returns entries in arbitrary order. Every SLURM node strides
    # over this list independently, so sort it to make the split consistent.
    files = sorted(os.listdir(args.input_dir))

    if args.mmcif_cache is not None:
        with open(args.mmcif_cache, "r") as fp:
            cache = json.load(fp)
    else:
        cache = None

    # Names of already-produced chain directories (only populated when filtering).
    dirs = set()
    if cache is not None and args.filter:
        dirs = set(os.listdir(args.output_dir))

        def prot_is_done(f):
            # A protein counts as done only if it is in the cache and every
            # one of its chains already has an output directory.
            prot_id = os.path.splitext(f)[0]
            if prot_id not in cache:
                return False
            return all(
                prot_id + "_" + c in dirs for c in cache[prot_id]["chain_ids"]
            )

        files = [f for f in files if not prot_is_done(f)]

    # An empty cache has no entry to inspect (next() would raise
    # StopIteration), so fall back to parsing the files themselves.
    use_cached_seqs = bool(cache) and "seqs" in next(iter(cache.values()))

    if use_cached_seqs:
        seq_group_dict = {}
        for f in files:
            prot_id = os.path.splitext(f)[0]
            if prot_id not in cache:
                continue
            prot_cache = cache[prot_id]
            for chain, seq in zip(prot_cache["chain_ids"], prot_cache["seqs"]):
                chain_name = prot_id + "_" + chain
                if chain_name not in dirs:
                    seq_group_dict.setdefault(seq, []).append(chain_name)

        func = partial(run_seq_group_alignments,
            alignment_runner=alignment_runner,
            args=args,
        )

        # Sort them by group length so the tasks are approximately balanced
        seq_groups = sorted(seq_group_dict.items(), key=lambda x: len(x[1]))
        task_arglist = [[a] for a in _split_up_arglist(seq_groups, args.no_tasks)]
    else:
        func = partial(parse_and_align,
            alignment_runner=alignment_runner,
            args=args,
        )
        task_arglist = [[a] for a in _split_up_arglist(files, args.no_tasks)]

    threads = []
    for i, task_args in enumerate(task_arglist):
        print(f"Started thread {i}...")
        t = threading.Thread(target=func, args=task_args)
        threads.append(t)
        t.start()

    for t in threads:
        t.join()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"input_dir", type=str,
|
|
help="""Path to directory containing mmCIF, FASTA and/or ProteinNet
|
|
.core files"""
|
|
)
|
|
parser.add_argument(
|
|
"output_dir", type=str,
|
|
help="Directory in which to output alignments"
|
|
)
|
|
add_data_args(parser)
|
|
parser.add_argument(
|
|
"--raise_errors", action="store_true", default=False,
|
|
help="Whether to crash on parsing errors"
|
|
)
|
|
parser.add_argument(
|
|
"--cpus_per_task", type=int, default=cpu_count(),
|
|
help="Number of CPUs to use"
|
|
)
|
|
parser.add_argument(
|
|
"--mmcif_cache", type=str, default=None,
|
|
help="Path to mmCIF cache. Used to filter files to be parsed"
|
|
)
|
|
parser.add_argument(
|
|
"--no_tasks", type=int, default=1,
|
|
)
|
|
parser.add_argument(
|
|
"--filter", type=bool, default=True,
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
main(args)
|