Flesh out the residues generation script to handle entire folders

This commit is contained in:
Kevin Wu
2022-09-01 14:59:19 -07:00
parent cd30f99176
commit 2d32054500

View File

@@ -13,8 +13,14 @@ pip install git+https://github.com/facebookresearch/esm.git
# uses the following notebook as a reference:
# https://colab.research.google.com/github/facebookresearch/esm/blob/main/examples/inverse_folding/notebook.ipynb
import os
import glob
import functools
import logging
import argparse
from typing import List, Optional
from tqdm.auto import tqdm
# Verifies that the environment is set up correctly
import torch
@@ -28,6 +34,7 @@ import esm.inverse_folding
def write_fa(fname: str, seq: str):
"""Write a fasta file"""
assert fname.endswith(".fasta")
with open(fname, "w") as f:
f.write(">sampled\n")
for chunk in [seq[i : i + 80] for i in range(0, len(seq), 80)]:
@@ -35,14 +42,49 @@ def write_fa(fname: str, seq: str):
return fname
def generate_residues(
    fpath: str, model, chain_id: str = "A", n: int = 10, temperature: float = 1.0
) -> List[str]:
    """
    Sample n sequences from the model for one chain of a structure file.

    The chain chain_id is loaded from fpath, its coordinates are extracted, and
    model.sample is called n times on those coordinates at the given
    temperature. The native sequence and every sampled sequence are logged at
    debug level. The samples are returned in the order they were drawn.
    """
    chain = esm.inverse_folding.util.load_structure(fpath, chain_id)
    # Coords have shape (seq_len, 3, 3)
    backbone, native_seq = esm.inverse_folding.util.extract_coords_from_structure(
        chain
    )
    logging.debug(f"Native sequence: {native_seq}")

    def _draw_one():
        # Sample and log together so log lines stay interleaved with sampling
        sampled_seq = model.sample(backbone, temperature=temperature)
        logging.debug(f"Sampled sequence: {sampled_seq}")
        return sampled_seq

    return [_draw_one() for _ in range(n)]
def update_fname(fname: str, i: int, new_dir: str = "") -> str:
    """
    Build the output fasta path for the i-th sequence generated from a file.

    The extension of fname is dropped and "_esm_generated_{i}.fasta" is
    appended to its base name. The result sits next to fname unless new_dir is
    given, in which case it is placed inside new_dir. Nothing is created or
    moved on disk; only the path string is returned.

    Args:
        fname: Path to an existing input file.
        i: Numeric index to embed in the output filename.
        new_dir: Optional existing directory to use as the output location.

    Returns:
        The output path ending in "_esm_generated_{i}.fasta".

    Raises:
        AssertionError: If fname is not an existing file, or if new_dir is
            given but is not an existing directory.
    """
    # The checks are explicit raises rather than assert statements so that they
    # still run under "python -O", which strips asserts. AssertionError is kept
    # so the exception type seen by callers does not change.
    if not os.path.isfile(fname):
        raise AssertionError(f"Expected {fname} to be a file")
    parent, child = os.path.split(fname)
    child_base, _child_ext = os.path.splitext(child)
    if not child or not child_base:
        raise AssertionError(f"Could not determine a base filename from {fname}")
    if new_dir:
        if not os.path.isdir(new_dir):
            raise AssertionError(f"Expected {new_dir} to be a directory")
        parent = new_dir
    return os.path.join(parent, f"{child_base}_esm_generated_{i}.fasta")
def build_parser() -> argparse.ArgumentParser:
"""Build a basic CLI"""
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("fname", type=str, help="PDB file to generate residues for")
parser.add_argument(
"-c", "--chain", type=str, default="A", help="Chain to use within PDB file"
"fname",
type=str,
help="PDB file to generate residues for, or a folder containing these",
)
parser.add_argument(
"-t",
@@ -52,11 +94,7 @@ def build_parser() -> argparse.ArgumentParser:
help="Temperature to sample at. Lower values result in lower diversity but higher sequence recovery",
)
parser.add_argument(
"-o",
"--output",
type=str,
default="",
help="Output file (fasta format) to write to. If not provided, default to input + .fasta",
"-n", type=int, default=10, help="Number of sequences to generate per structure"
)
parser.add_argument(
"-s",
@@ -72,31 +110,38 @@ def main():
parser = build_parser()
args = parser.parse_args()
# Load in the file
fpath = args.fname
chain_id = args.chain
structure = esm.inverse_folding.util.load_structure(fpath, chain_id)
# Coords have shape (seq_len, 3, 3)
coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure(
structure
)
logging.info(f"Native sequence: {native_seq}")
# Load the model
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()
# Sample the residues
torch.manual_seed(args.seed)
sampled_seq = model.sample(coords, temperature=1.0)
logging.info(f"Sampled sequence: {sampled_seq}")
# If output file is given, write it
out = args.output
if out == "":
out = fpath.replace(".pdb", ".fasta")
if args.output:
write_fa(out, sampled_seq)
pfunc = functools.partial(
generate_residues, model=model, n=args.n, temperature=args.temperature
)
# Load in the file
if os.path.isfile(args.fname):
# Generate sequences and write each one to its own fasta file
sequences = pfunc(args.fname)
for i, seq in enumerate(sequences):
out_fname = update_fname(args.fname, i)
write_fa(out_fname, seq)
elif os.path.isdir(args.fname):
# create a subdirectory to store the fastas
outdir = os.path.join(args.fname, "esm_generated_fastas")
os.makedirs(outdir, exist_ok=True)
logging.info(f"Writing output to {outdir}")
# Query for inputs and process them
inputs = glob.glob(os.path.join(args.fname, "*.pdb"))
logging.info(f"Found {len(inputs)} PDB files to process in {args.fname}")
generated = [pfunc(f) for f in tqdm(inputs)]
# Write outputs
for orig_fname, seqs in zip(inputs, generated):
for i, seq in enumerate(seqs):
out_fname = update_fname(orig_fname, i, new_dir=outdir)
write_fa(out_fname, seq)
else:
raise RuntimeError(f"Expected {args.fname} to be a file or directory")
if __name__ == "__main__":