mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Remove colabfold wrappers
This commit is contained in:
@@ -1,106 +0,0 @@
|
||||
"""
|
||||
Short script to parallelize colabfold to run across GPUs to speed up runtime. Make
|
||||
sure to set export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 before running!
|
||||
"""
|
||||
|
||||
# Use only standard libraries so we don't need to modify the env
|
||||
import functools
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import shutil
|
||||
import subprocess
|
||||
import multiprocessing as mp
|
||||
from pathlib import Path
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Build a basic CLI
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"foldername", type=str, help="Folder containing a3m msa files to run"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--outdir",
|
||||
type=str,
|
||||
default=os.path.abspath(os.path.join(os.getcwd(), "colabfold_predictions")),
|
||||
help="Output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gpus", type=int, nargs="*", default=[0, 1, 2, 3], help="GPUs to use"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def run_colabfold(input_a3m: Path, outdir: Path, gpu: int) -> None:
|
||||
"""Run colabfold on the given a3m MSA file"""
|
||||
# Example command: colabfold_batch msas predictions --use-gpu-relax --amber --num-recycle=3 --model-type=AlphaFold2-ptm
|
||||
executable = shutil.which("colabfold_batch")
|
||||
if not executable:
|
||||
raise FileNotFoundError("Could not find colabfold_batch in PATH")
|
||||
cmd = f"CUDA_VISIBLE_DEVICES={gpu} {executable} {input_a3m} {outdir} --use-gpu-relax --amber --num-recycle=3 --model-type=AlphaFold2-ptm"
|
||||
retval = subprocess.call(cmd, shell=True)
|
||||
assert (
|
||||
retval == 0
|
||||
), f"colabfold_batch on {input_a3m} failed with return code {retval}"
|
||||
|
||||
|
||||
def run_colabfold_multi(input_a3m_files: List[Path], gpu: int, outdir: Path) -> None:
|
||||
"""Runs each file in a different folder"""
|
||||
for f in input_a3m_files:
|
||||
this_outdir = outdir / os.path.splitext(os.path.basename(f))[0]
|
||||
os.makedirs(this_outdir, exist_ok=True) # Make sure the output directory exists
|
||||
run_colabfold(f, this_outdir, gpu)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run the script
|
||||
"""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
assert "XLA_FLAGS" in os.environ, "XLA_FLAGS not set!"
|
||||
|
||||
# Create output directory
|
||||
outdir = Path(args.outdir)
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
# Get all a3m files in the input directory
|
||||
input_files = sorted(glob.glob(os.path.join(args.foldername, "*.a3m")))
|
||||
assert input_files, f"No a3m files found in {args.foldername}"
|
||||
|
||||
# Split the input_files into chunks equal to the number of GPUs
|
||||
indices = np.array_split(np.arange(len(input_files)), len(args.gpus))
|
||||
input_files_split = []
|
||||
for idx in indices:
|
||||
input_files_split.append([input_files[i] for i in idx])
|
||||
assert len(input_files_split) == len(args.gpus)
|
||||
logging.info(
|
||||
f"Splitting input {len(input_files)} into sizes {[len(i) for i in indices]}"
|
||||
)
|
||||
|
||||
pfunc = functools.partial(run_colabfold_multi, outdir=outdir)
|
||||
# Create processes for each set of files
|
||||
processes = [
|
||||
mp.Process(target=pfunc, args=(input_files_split[i], args.gpus[i]))
|
||||
for i in range(len(args.gpus))
|
||||
]
|
||||
for p in processes:
|
||||
p.start()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
@@ -1,135 +0,0 @@
|
||||
"""
|
||||
Script for calculating the self tm score given the input as sampled structures
|
||||
and the structures resulting from running the structures via inverse folding
|
||||
residue generation and alphafold/colabfold. Expects the following directory
|
||||
structure:
|
||||
working_dir (where this script is run):
|
||||
- sampled_pdb (containing the original generated pdb structures)
|
||||
- msas (containing the a3m files from MSA generation)
|
||||
- colabfold_predictions (containing the results from colabfold)
|
||||
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import *
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
SRC_DIR = (Path(os.path.dirname(os.path.abspath(__file__))) / "../protdiff").resolve()
|
||||
assert SRC_DIR.is_dir()
|
||||
sys.path.append(str(SRC_DIR))
|
||||
import tmalign
|
||||
|
||||
|
||||
def get_sctm_score(orig_pdb: Path, folded_pdb_dirs: List[Path]) -> float:
|
||||
"""
|
||||
Get the scTM score given the original pdb file and list of dirs with folded pdbs
|
||||
"""
|
||||
if not folded_pdb_dirs:
|
||||
return 0.0
|
||||
folded_pdb_files = []
|
||||
for d in folded_pdb_dirs:
|
||||
matches = glob.glob(str(d / "*_relaxed_rank_*_model_*.pdb"))
|
||||
assert len(matches) <= 5
|
||||
folded_pdb_files.extend(matches)
|
||||
if not folded_pdb_files:
|
||||
return 0.0
|
||||
logging.debug(
|
||||
f"Matching {orig_pdb} against {len(folded_pdb_files)} folded structures"
|
||||
)
|
||||
|
||||
# Get the scTM score
|
||||
score = tmalign.max_tm_across_refs(orig_pdb, folded_pdb_files, chunksize=1)
|
||||
return score
|
||||
|
||||
|
||||
def seqname_from_a3m(a3m_path: str) -> str:
|
||||
"""
|
||||
Gets the original query sequence from a3m
|
||||
"""
|
||||
# Return the first header line
|
||||
with open(a3m_path) as source:
|
||||
for line in source:
|
||||
if line.startswith(">"):
|
||||
return line.split()[0][1:].strip()
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
"""Get the CLI parser"""
|
||||
parser = argparse.ArgumentParser(
|
||||
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results_dir",
|
||||
type=str,
|
||||
default=os.getcwd(),
|
||||
help="Directory containing the results",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the script"""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
results_dir = Path(args.results_dir).expanduser()
|
||||
assert os.path.isdir(results_dir)
|
||||
|
||||
# Get the list of generated structures
|
||||
gen_struct_subdir = results_dir / "sampled_pdb"
|
||||
assert os.path.isdir(gen_struct_subdir)
|
||||
gen_structs = glob.glob(str(gen_struct_subdir / "*.pdb"))
|
||||
assert gen_structs
|
||||
# Dictionary mapping names to path to structure, something like generated_123
|
||||
gen_structs = {os.path.splitext(os.path.basename(s))[0]: s for s in gen_structs}
|
||||
|
||||
# Get the list of inverse folding sequences
|
||||
msa_subdir = results_dir / "msas"
|
||||
msa_files = glob.glob(str(msa_subdir / "*.a3m"))
|
||||
assert msa_files
|
||||
# Create a mapping from the non-readable auto-generated MSA names to the readable normal names
|
||||
msa_name_to_human_name = {
|
||||
os.path.splitext(os.path.basename(s))[0]: seqname_from_a3m(s) for s in msa_files
|
||||
}
|
||||
|
||||
# Query the list of folded structures. This should contain folders corresponding to the
|
||||
# cryptic msa names
|
||||
fold_subdir = results_dir / "colabfold_predictions"
|
||||
fold_subdir_contents = [fold_subdir / d for d in os.listdir(fold_subdir)]
|
||||
fold_subdir_contents = [d for d in fold_subdir_contents if os.path.isdir(d)]
|
||||
# Walk through and map each to a generated structure name gen_struct_names
|
||||
gen_struct_to_folded_structs = {s: [] for s in gen_structs}
|
||||
for d in fold_subdir_contents:
|
||||
d_readable = msa_name_to_human_name[os.path.basename(d)]
|
||||
generated_base = "_".join(d_readable.split("_")[:2])
|
||||
assert generated_base in gen_structs
|
||||
gen_struct_to_folded_structs[generated_base].append(d)
|
||||
|
||||
# For each set of reference and generated structures, compute score
|
||||
sctm_scores = {}
|
||||
for s in tqdm(gen_structs):
|
||||
sctm_scores[s] = get_sctm_score(gen_structs[s], gen_struct_to_folded_structs[s])
|
||||
|
||||
# Write output
|
||||
with open("sctm_scores.json", "w") as sink:
|
||||
json.dump(sctm_scores, sink, indent=2)
|
||||
|
||||
fig, ax = plt.subplots(dpi=300)
|
||||
ax.hist(sctm_scores.values(), bins=20)
|
||||
ax.set(
|
||||
xlabel="scTM score",
|
||||
title=f"Self-consistency TM scores for {len(sctm_scores)} generated structures",
|
||||
)
|
||||
fig.savefig("sctm_scores.pdf", bbox_inches="tight")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
Reference in New Issue
Block a user