diff --git a/bin/pdb_to_residues_esm.py b/bin/pdb_to_residues_esm.py index 109fc4b..9b2a7af 100644 --- a/bin/pdb_to_residues_esm.py +++ b/bin/pdb_to_residues_esm.py @@ -13,8 +13,14 @@ pip install git+https://github.com/facebookresearch/esm.git # uses the following notebook as a reference: # https://colab.research.google.com/github/facebookresearch/esm/blob/main/examples/inverse_folding/notebook.ipynb +import os +import glob +import functools import logging import argparse +from typing import List, Optional + +from tqdm.auto import tqdm # Verfies that the environment is set up correctly import torch @@ -28,6 +34,7 @@ import esm.inverse_folding def write_fa(fname: str, seq: str): """Write a fasta file""" + assert fname.endswith(".fasta") with open(fname, "w") as f: f.write(">sampled\n") for chunk in [seq[i : i + 80] for i in range(0, len(seq), 80)]: @@ -35,14 +42,49 @@ def write_fa(fname: str, seq: str): return fname +def generate_residues( + fpath: str, model, chain_id: str = "A", n: int = 10, temperature: float = 1.0 +) -> List[str]: + """Generate residues for the structure contained in the PDB file""" + structure = esm.inverse_folding.util.load_structure(fpath, chain_id) + # Coords have shape (seq_len, 3, 3) + coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure( + structure + ) + logging.debug(f"Native sequence: {native_seq}") + retval = [] + for _ in range(n): + sampled_seq = model.sample(coords, temperature=temperature) + logging.debug(f"Sampled sequence: {sampled_seq}") + retval.append(sampled_seq) + return retval + + +def update_fname(fname: str, i: int, new_dir: str = "") -> str: + """ + Update the pdb filename to include a numeric index and a .fasta extension. + If new_dir is given then we move the output filename to that directory. + """ + assert os.path.isfile(fname) + parent, child = os.path.split(fname) + assert child + child_base, _child_ext = os.path.splitext(child) + assert child_base + if new_dir: + assert os.path.isdir(new_dir), f"Expected {new_dir} to be a directory" + parent = new_dir + return os.path.join(parent, f"{child_base}_esm_generated_{i}.fasta") + + def build_parser() -> argparse.ArgumentParser: """Build a basic CLI""" parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("fname", type=str, help="PDB file to generate residues for") parser.add_argument( - "-c", "--chain", type=str, default="A", help="Chain to use within PDB file" + "fname", + type=str, + help="PDB file to generate residues for, or a folder containing these", ) parser.add_argument( "-t", @@ -52,11 +94,7 @@ def build_parser() -> argparse.ArgumentParser: help="Temperature to sample at. Lower values result in lower diversity but higher sequence recovery", ) parser.add_argument( - "-o", - "--output", - type=str, - default="", - help="Output file (fasta format) to write to. If not provided, default to input + .fasta", + "-n", type=int, default=10, help="Number of sequences to generate per structure" ) parser.add_argument( "-s", @@ -72,31 +110,38 @@ def main(): parser = build_parser() args = parser.parse_args() - # Load in the file - fpath = args.fname - chain_id = args.chain - structure = esm.inverse_folding.util.load_structure(fpath, chain_id) - # Coords have shape (seq_len, 3, 3) - coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure( - structure - ) - logging.info(f"Native sequence: {native_seq}") - # Load the model model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() model = model.eval() - - # Sample the residues torch.manual_seed(args.seed) - sampled_seq = model.sample(coords, temperature=1.0) - logging.info(f"Sampled sequence: {sampled_seq}") - # If output file is given, write it - out = args.output - if out == "": - out = fpath.replace(".pdb", ".fasta") - if args.output: - write_fa(out, sampled_seq) + pfunc = functools.partial( + generate_residues, model=model, n=args.n, temperature=args.temperature + ) + + # Load in the file + if os.path.isfile(args.fname): + # If output file is given, write it + sequences = pfunc(args.fname) + for i, seq in enumerate(sequences): + out_fname = update_fname(args.fname, i) + write_fa(out_fname, seq) + elif os.path.isdir(args.fname): + # create a subdirecotry to store the fastas + outdir = os.path.join(args.fname, "esm_generated_fastas") + os.makedirs(outdir, exist_ok=True) + logging.info(f"Writing output to {outdir}") + # Query for inputs and process them + inputs = glob.glob(os.path.join(args.fname, "*.pdb")) + logging.info(f"Found {len(inputs)} PDB files to process in {args.fname}") + generated = [pfunc(f) for f in tqdm(inputs)] + # Write outputs + for orig_fname, seqs in zip(inputs, generated): + for i, seq in enumerate(seqs): + out_fname = update_fname(orig_fname, i, new_dir=outdir) + write_fa(out_fname, seq) + else: + raise RuntimeError(f"Expected {args.fname} to be a file or directory") if __name__ == "__main__":