mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
243 lines
7.8 KiB
Python
243 lines
7.8 KiB
Python
"""
Count secondary structures in a PDB file as determined by p-sea
https://www.biotite-python.org/apidoc/biotite.structure.annotate_sse.html
"""
|
|
|
|
# Examples:
# python ~/projects/protdiff/bin/annot_secondary_structures.py sampled_pdb/*.pdb plots/ss_cooccurrence_sampled.pdf
# python ~/projects/protdiff/bin/annot_secondary_structures.py model_snapshot/training_args.json plots/ss_cooccurrence_test.pdf
|
|
|
|
import json
|
|
import os, sys
|
|
from pathlib import Path
|
|
import logging
|
|
import warnings
|
|
import functools
|
|
import multiprocessing as mp
|
|
import argparse
|
|
from itertools import groupby
|
|
from collections import Counter
|
|
from typing import Tuple, Collection, Literal, Dict, Any
|
|
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
import biotite.structure as struc
|
|
from biotite.application import dssp
|
|
from biotite.structure.io.pdb import PDBFile
|
|
|
|
# Names of the supported secondary-structure annotation backends
SSE_BACKEND = Literal["dssp", "psea"]
|
|
|
|
from train import get_train_valid_test_sets
|
|
|
|
from foldingdiff.angles_and_coords import get_pdb_length
|
|
|
|
|
|
def build_datasets(training_args: Dict[str, Any]):
    """
    Recreate the train/valid/test datasets from a saved training configuration.

    The dataset keyword arguments are read out of ``training_args`` and passed
    to ``get_train_valid_test_sets`` together with ``train_only=False``.
    Returns the (train, valid, test) datasets in that order.
    """
    # Dataset keyword argument -> key under which its value is stored in the
    # training args; the two names differ for a handful of entries.
    arg_sources = {
        "timesteps": "timesteps",
        "variance_schedule": "variance_schedule",
        "max_seq_len": "max_seq_len",
        "min_seq_len": "min_seq_len",
        "var_scale": "variance_scale",
        "syn_noiser": "syn_noiser",
        "exhaustive_t": "exhaustive_validation_t",
        "single_angle_debug": "single_angle_debug",
        "single_time_debug": "single_timestep_debug",
        "toy": "subset",
        "angles_definitions": "angles_definitions",
    }
    dset_kwargs = {kw: training_args[key] for kw, key in arg_sources.items()}
    dset_kwargs["train_only"] = False

    splits = get_train_valid_test_sets(**dset_kwargs)
    train_dset, valid_dset, test_dset = splits
    logging.info(
        f"Training dset contains features: {train_dset.feature_names} - angular {train_dset.feature_is_angular}"
    )
    return train_dset, valid_dset, test_dset
|
|
|
|
|
|
def count_structures_in_pdb(
    fname: str, backend: SSE_BACKEND = "psea"
) -> Tuple[int, int]:
    """
    Count the secondary structure elements (# alpha, # beta) in the given pdb file.

    Every uninterrupted run of identically annotated residues counts as one
    element, so the result is a number of helices / strands, not of residues.

    Args:
        fname: Path to the PDB file to annotate.
        backend: "psea" annotates with biotite's ``annotate_sse``; "dssp" runs
            DSSP through biotite's ``DsspApp``.

    Returns:
        ``(num_alpha, num_beta)``, or ``(-1, -1)`` when the file contains more
        than one model (a sentinel meaning "skip this file").

    Raises:
        FileNotFoundError: if ``fname`` does not exist.
        ValueError: if the structure does not have exactly one chain, or if
            ``backend`` is not recognized.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(fname):
        raise FileNotFoundError(f"PDB file does not exist: {fname}")

    # Scope the warning filter to this call rather than permanently changing
    # the process-wide warning configuration.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*elements were guessed from atom_.*")

        source = PDBFile.read(fname)
        if source.get_model_count() > 1:
            return (-1, -1)
        source_struct = source.get_structure()[0]

        chain_ids = np.unique(source_struct.chain_id)
        if len(chain_ids) != 1:
            raise ValueError(
                f"Expected exactly one chain in {fname}, found {len(chain_ids)}"
            )
        chain_id = chain_ids[0]

        if backend == "psea":
            # a = alpha helix, b = beta sheet, c = coil
            ss = struc.annotate_sse(source_struct, chain_id)
            alpha_code, beta_code = "a", "b"
        elif backend == "dssp":
            # https://www.biotite-python.org/apidoc/biotite.application.dssp.DsspApp.html#biotite.application.dssp.DsspApp
            app = dssp.DsspApp(source_struct)
            app.start()
            app.join()
            ss = app.get_sse()
            # DSSP labels beta-sheet strands "E" (extended strand); "B" is only
            # a residue in an isolated beta-bridge, so counting "B" would miss
            # the strands that actually make up sheets.
            alpha_code, beta_code = "H", "E"
        else:
            raise ValueError(
                f"Unrecognized backend for calculating secondary structures: {backend}"
            )

    # Collapse consecutive identical annotations into runs and count the runs
    # per annotation code; a Counter yields 0 for codes that never occur.
    # https://stackoverflow.com/questions/6352425/whats-the-most-pythonic-way-to-identify-consecutive-duplicates-in-a-list
    run_counts = Counter(code for code, _ in groupby(ss))
    num_alpha = run_counts[alpha_code]
    num_beta = run_counts[beta_code]

    logging.debug(f"From {fname}:\t{num_alpha} {num_beta}")
    return num_alpha, num_beta
|
|
|
|
|
|
def make_ss_cooccurrence_plot(
    pdb_files: Collection[str],
    outpdf: str,
    json_file: str = "",
    max_seq_len: int = 0,
    backend: SSE_BACKEND = "psea",
    threads: int = 8,
    title: str = "Secondary structure co-occurrence",
    **kwargs,
):
    """
    Create a secondary structure co-occurrence plot.

    Counts the alpha/beta elements of every structure in parallel and draws a
    2D histogram of (number of alpha helices, number of beta sheets).

    Args:
        pdb_files: PDB files to annotate.
        outpdf: Path the figure is saved to.
        json_file: If non-empty, a JSON file mapping each file's basename to
            its ``[alpha, beta]`` counts is written here.
        max_seq_len: If positive, files longer than this (as reported by
            ``get_pdb_length``) are dropped before counting.
        backend: Secondary structure backend given to ``count_structures_in_pdb``.
        threads: Number of worker processes.
        title: Plot title; no title is set when empty.
        **kwargs: Passed to ``hist2d``.

    Raises:
        ValueError: if no structure yields a usable count.
    """
    if max_seq_len > 0:
        orig_len = len(pdb_files)
        pdb_files = [p for p in pdb_files if get_pdb_length(p) <= max_seq_len]
        logging.info(
            f"Filtering out sequences with more than {max_seq_len} residues: {orig_len} --> {len(pdb_files)}"
        )
    logging.info(f"Calculating {len(pdb_files)} structures using {backend}")
    pfunc = functools.partial(count_structures_in_pdb, backend=backend)
    # The context manager releases the workers even if a task raises
    with mp.Pool(threads) as pool:
        all_counts = pool.map(pfunc, pdb_files, chunksize=10)

    # (-1, -1) marks a skipped file. Filter filename and counts together so
    # each remaining count stays attached to the file it was computed from.
    kept = [
        (fname, counts)
        for fname, counts in zip(pdb_files, all_counts)
        if counts != (-1, -1)
    ]
    if not kept:
        raise ValueError("No structures with secondary structure counts to plot")
    alpha_beta_counts = [counts for _, counts in kept]
    alpha_counts, beta_counts = zip(*alpha_beta_counts)

    # Write a json file if specified
    if json_file:
        logging.info(f"Writing json of ss counts to {json_file}")
        with open(json_file, "w") as sink:
            json.dump(
                {os.path.basename(fname): counts for fname, counts in kept},
                sink,
                indent=4,
            )

    fig, ax = plt.subplots(dpi=300)
    h = ax.hist2d(
        alpha_counts,
        beta_counts,
        bins=np.arange(10),  # bin edges 0..9 on both axes
        density=True,
        vmin=0.0,
        **kwargs,
    )
    ax.set_xlabel(r"Number of $\alpha$ helices", fontsize=12)
    ax.set_ylabel(r"Number of $\beta$ sheets", fontsize=12)
    if title:
        ax.set_title(title.strip(), fontsize=14)
    # The last element returned by hist2d is the drawn image the colorbar describes
    cbar = fig.colorbar(h[-1], ax=ax)
    cbar.ax.set_ylabel("Frequency", fontsize=12)
    fig.savefig(outpdf, bbox_inches="tight")
|
|
|
|
|
|
def build_parser():
    """Assemble the command line parser for this script."""
    cli = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (names, add_argument options), registered in this order
    argument_specs = (
        (
            ("infiles",),
            dict(
                type=str,
                nargs="+",
                help="PDB files to compute secondary structures for, or json file containing config for which we take test set",
            ),
        ),
        (
            ("outpdf",),
            dict(
                type=str,
                help="PDF file to write plot of secondary structure co-occurrence frequencies",
            ),
        ),
        (
            ("--backend",),
            dict(
                type=str,
                choices=["dssp", "psea"],
                default="psea",
                help="Backend for calculating secondary structure",
            ),
        ),
        (
            ("-t", "--threads"),
            dict(
                type=int,
                default=mp.cpu_count(),
                help="Number of threads to use",
            ),
        ),
        (
            ("--title",),
            dict(type=str, default="", help="Title for plot"),
        ),
        (
            ("--freqlim",),
            dict(
                type=float,
                default=0.09,
                help="Upper limit for frequency in 2D histogram. Set to 0 to disable.",
            ),
        ),
        (
            ("--json",),
            dict(
                type=str,
                default="",
                help="JSON file to write co-occurences in (alpha, beta)",
            ),
        ),
    )
    for names, options in argument_specs:
        cli.add_argument(*names, **options)
    return cli
|
|
|
|
|
|
def main():
    """Parse the command line and produce the co-occurrence plot."""
    args = build_parser().parse_args()

    inputs = args.infiles
    length_cap = 0  # 0 leaves the length filter disabled
    # A lone .json argument is a training config: plot its test set instead
    if len(inputs) == 1 and inputs[0].endswith(".json"):
        with open(inputs[0]) as handle:
            config = json.load(handle)
        _train, _valid, test_split = build_datasets(config)
        inputs = test_split.filenames
        length_cap = test_split.dset.pad

    # A non-positive frequency limit means no upper color limit
    upper_limit = None
    if args.freqlim > 0:
        upper_limit = args.freqlim

    make_ss_cooccurrence_plot(
        pdb_files=inputs,
        outpdf=args.outpdf,
        json_file=args.json,
        backend=args.backend,
        threads=args.threads,
        title=args.title,
        max_seq_len=length_cap,
        vmax=upper_limit,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Show INFO-level progress messages when run as a script
    logging.basicConfig(level=logging.INFO)
    main()
|