mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Flesh out the residues generation script to handle entire folders
This commit is contained in:
@@ -13,8 +13,14 @@ pip install git+https://github.com/facebookresearch/esm.git
|
||||
|
||||
# uses the following notebook as a reference:
|
||||
# https://colab.research.google.com/github/facebookresearch/esm/blob/main/examples/inverse_folding/notebook.ipynb
|
||||
import os
|
||||
import glob
|
||||
import functools
|
||||
import logging
|
||||
import argparse
|
||||
from typing import List, Optional
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# Verifies that the environment is set up correctly
|
||||
import torch
|
||||
@@ -28,6 +34,7 @@ import esm.inverse_folding
|
||||
|
||||
def write_fa(fname: str, seq: str):
|
||||
"""Write a fasta file"""
|
||||
assert fname.endswith(".fasta")
|
||||
with open(fname, "w") as f:
|
||||
f.write(">sampled\n")
|
||||
for chunk in [seq[i : i + 80] for i in range(0, len(seq), 80)]:
|
||||
@@ -35,14 +42,49 @@ def write_fa(fname: str, seq: str):
|
||||
return fname
|
||||
|
||||
|
||||
def generate_residues(
    fpath: str, model, chain_id: str = "A", n: int = 10, temperature: float = 1.0
) -> List[str]:
    """
    Sample n sequences from the model for one chain of a PDB structure.

    The chain chain_id is loaded from fpath, its coordinates are extracted,
    and model.sample is called n times on those coordinates at the given
    temperature. The sampled sequences are returned in the order drawn.
    """
    chain_structure = esm.inverse_folding.util.load_structure(fpath, chain_id)
    # Coords have shape (seq_len, 3, 3)
    extracted = esm.inverse_folding.util.extract_coords_from_structure(
        chain_structure
    )
    coords, native_seq = extracted
    logging.debug(f"Native sequence: {native_seq}")

    def _draw_one() -> str:
        # One sampling pass over the same coordinates, logged as it is drawn
        sampled_seq = model.sample(coords, temperature=temperature)
        logging.debug(f"Sampled sequence: {sampled_seq}")
        return sampled_seq

    return [_draw_one() for _ in range(n)]
|
||||
|
||||
|
||||
def update_fname(fname: str, i: int, new_dir: str = "") -> str:
    """
    Build the fasta output path for the i-th sequence generated from a pdb file.

    The input file's extension is dropped and "_esm_generated_{i}.fasta" is
    appended to its base name. The result is placed next to the input file
    unless new_dir is given, in which case it is placed in new_dir (which must
    already exist as a directory).
    """
    assert os.path.isfile(fname)
    out_dir = os.path.dirname(fname)
    basename = os.path.basename(fname)
    assert basename
    stem = os.path.splitext(basename)[0]
    assert stem
    if new_dir:
        assert os.path.isdir(new_dir), f"Expected {new_dir} to be a directory"
        out_dir = new_dir
    out_name = f"{stem}_esm_generated_{i}.fasta"
    return os.path.join(out_dir, out_name)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
"""Build a basic CLI"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument("fname", type=str, help="PDB file to generate residues for")
|
||||
parser.add_argument(
|
||||
"-c", "--chain", type=str, default="A", help="Chain to use within PDB file"
|
||||
"fname",
|
||||
type=str,
|
||||
help="PDB file to generate residues for, or a folder containing these",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
@@ -52,11 +94,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
help="Temperature to sample at. Lower values result in lower diversity but higher sequence recovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
type=str,
|
||||
default="",
|
||||
help="Output file (fasta format) to write to. If not provided, default to input + .fasta",
|
||||
"-n", type=int, default=10, help="Number of sequences to generate per structure"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
@@ -72,31 +110,38 @@ def main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load in the file
|
||||
fpath = args.fname
|
||||
chain_id = args.chain
|
||||
structure = esm.inverse_folding.util.load_structure(fpath, chain_id)
|
||||
# Coords have shape (seq_len, 3, 3)
|
||||
coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure(
|
||||
structure
|
||||
)
|
||||
logging.info(f"Native sequence: {native_seq}")
|
||||
|
||||
# Load the model
|
||||
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
|
||||
model = model.eval()
|
||||
|
||||
# Sample the residues
|
||||
torch.manual_seed(args.seed)
|
||||
sampled_seq = model.sample(coords, temperature=1.0)
|
||||
logging.info(f"Sampled sequence: {sampled_seq}")
|
||||
|
||||
# If output file is given, write it
|
||||
out = args.output
|
||||
if out == "":
|
||||
out = fpath.replace(".pdb", ".fasta")
|
||||
if args.output:
|
||||
write_fa(out, sampled_seq)
|
||||
pfunc = functools.partial(
|
||||
generate_residues, model=model, n=args.n, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Load in the file
|
||||
if os.path.isfile(args.fname):
|
||||
# If output file is given, write it
|
||||
sequences = pfunc(args.fname)
|
||||
for i, seq in enumerate(sequences):
|
||||
out_fname = update_fname(args.fname, i)
|
||||
write_fa(out_fname, seq)
|
||||
elif os.path.isdir(args.fname):
|
||||
# create a subdirectory to store the fastas
|
||||
outdir = os.path.join(args.fname, "esm_generated_fastas")
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
logging.info(f"Writing output to {outdir}")
|
||||
# Query for inputs and process them
|
||||
inputs = glob.glob(os.path.join(args.fname, "*.pdb"))
|
||||
logging.info(f"Found {len(inputs)} PDB files to process in {args.fname}")
|
||||
generated = [pfunc(f) for f in tqdm(inputs)]
|
||||
# Write outputs
|
||||
for orig_fname, seqs in zip(inputs, generated):
|
||||
for i, seq in enumerate(seqs):
|
||||
out_fname = update_fname(orig_fname, i, new_dir=outdir)
|
||||
write_fa(out_fname, seq)
|
||||
else:
|
||||
raise RuntimeError(f"Expected {args.fname} to be a file or directory")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user