# Files
# ReQFlow/models/flow_module.py
# zichongwang 02d1131783 modify prot
# 2025-02-25 16:16:11 +08:00
#
# 687 lines
# 28 KiB
# Python

from typing import Any
import torch
import time
import os
import random
import wandb
import numpy as np
import pandas as pd
import logging
import torch.distributed as dist
from pytorch_lightning import LightningModule
from analysis import metrics
from analysis import utils as au
from models.flow_model import FlowModel
from models import utils as mu
from data.interpolant import Interpolant
from data import utils as du
from data import all_atom
from data import so3_utils
from data import residue_constants
from experiments import utils as eu
from pytorch_lightning.loggers.wandb import WandbLogger
import esm
from data import utils as du
from data import residue_constants
from biotite.sequence.io import fasta
import subprocess
from typing import Optional
from analysis import metrics
import shutil
from datetime import datetime
from openfold.utils.rigid_utils import rot_to_quat, quat_to_rot
class FlowModule(LightningModule):
    """Lightning module that trains and samples a ``FlowModel``.

    Training corrupts a batch with the ``Interpolant`` and scores the model
    output with translation, quaternion vector-field and auxiliary backbone
    losses (see ``model_step``). Validation and prediction draw samples with
    ``Interpolant.sample`` and write them out as PDB files; prediction can
    additionally run a ProteinMPNN + ESMFold self-consistency evaluation.

    NOTE(review): ``predict_step`` and ``run_self_consistency`` read
    ``self._infer_cfg``, which is never assigned inside this class. It
    presumably has to be attached by the inference driver before predicting
    -- confirm against the caller.
    """

    def __init__(self, cfg):
        """Build the model, the interpolant and the bookkeeping state.

        Args:
            cfg: configuration object exposing ``experiment``, ``model``,
                ``data`` and ``interpolant`` sub-configs.
        """
        super().__init__()
        self._print_logger = logging.getLogger(__name__)
        self._exp_cfg = cfg.experiment
        self._model_cfg = cfg.model
        self._data_cfg = cfg.data
        self._interpolant_cfg = cfg.interpolant

        # Set-up vector field prediction model
        self.model = FlowModel(cfg.model)

        # Set-up interpolant
        self.interpolant = Interpolant(cfg.interpolant)

        # Filled by validation_step, drained by on_validation_epoch_end.
        self.validation_epoch_metrics = []
        self.validation_epoch_samples = []
        self.save_hyperparameters()

        # Lazily resolved by the checkpoint_dir / inference_dir properties.
        self._checkpoint_dir = None
        self._inference_dir = None

        # One row per generated sample, appended by predict_step.
        self._time_records = pd.DataFrame(
            columns=[
                'length', 'sample_path', 'start_time', 'sample_time',
                'eval_time', 'total_time', 'memory_allocated',
                'memory_reserved',
            ])

        print("self._exp_cfg.is_training", self._exp_cfg.is_training)
        if not self._exp_cfg.is_training:
            # ESMFold is only loaded outside of training; it is frozen below.
            self._folding_model = esm.pretrained.esmfold_v1().eval()
            # NOTE(review): second save_hyperparameters call kept from the
            # original; "_folding_model" is not an __init__ argument, so the
            # ignore list looks like a no-op -- confirm before removing.
            self.save_hyperparameters(ignore=["_folding_model"])
            for param in self._folding_model.parameters():
                param.requires_grad = False

    def _resolve_shared_dir(self, read_dirpath):
        """Return a directory path that every rank agrees on.

        Args:
            read_dirpath: zero-argument callable returning the configured
                path. With ``torch.distributed`` initialized it is invoked on
                rank 0 only and the result is broadcast to the other ranks;
                otherwise it is invoked on the calling process.
        """
        if not dist.is_initialized():
            return read_dirpath()
        holder = [read_dirpath() if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(holder, src=0)
        return holder[0]

    @property
    def checkpoint_dir(self):
        """Checkpoint directory (created on first access, then cached)."""
        if self._checkpoint_dir is None:
            self._checkpoint_dir = self._resolve_shared_dir(
                lambda: self._exp_cfg.checkpointer.dirpath)
            os.makedirs(self._checkpoint_dir, exist_ok=True)
        return self._checkpoint_dir

    @property
    def inference_dir(self):
        """Inference output directory (created on first access, then cached)."""
        if self._inference_dir is None:
            self._inference_dir = self._resolve_shared_dir(
                lambda: self._exp_cfg.inference_dir)
            os.makedirs(self._inference_dir, exist_ok=True)
        return self._inference_dir

    def on_train_start(self):
        """Start the wall-clock timer used for the per-epoch duration."""
        self._epoch_start_time = time.time()

    def on_train_epoch_end(self):
        """Log the epoch duration in minutes and restart the timer."""
        epoch_time = (time.time() - self._epoch_start_time) / 60.0
        self.log(
            'train/epoch_time_minutes',
            epoch_time,
            on_step=False,
            on_epoch=True,
            prog_bar=False
        )
        self._epoch_start_time = time.time()

    def model_step(self, noisy_batch: Any):
        """Run the model on a corrupted batch and compute the losses.

        Args:
            noisy_batch: mapping providing at least ``res_mask``,
                ``diffuse_mask``, ``trans_1``, ``rotmats_1``, ``rotquats_1``,
                ``rotquats_t``, ``r3_t`` and ``so3_t``, plus whatever
                ``self.model`` consumes.

        Returns:
            Dict with the keys ``trans_loss``, ``auxiliary_loss``,
            ``rots_quats_vf_loss`` and ``se3_vf_loss``. Every loss is reduced
            over the residue dimension, leaving the batch dimension.

        Raises:
            ValueError: if an example has an all-zero loss mask, or if a NaN
                appears in either quaternion vector field or the final loss.
        """
        training_cfg = self._exp_cfg.training

        # Only residues that are both present and diffused contribute.
        loss_mask = noisy_batch['res_mask'] * noisy_batch['diffuse_mask']
        if torch.any(torch.sum(loss_mask, dim=-1) < 1):
            raise ValueError('Empty batch encountered')
        num_batch, num_res = loss_mask.shape

        # Ground truth labels
        gt_trans_1 = noisy_batch['trans_1']
        gt_rotmats_1 = noisy_batch['rotmats_1']
        gt_rotquats_1 = noisy_batch['rotquats_1']
        rotquats_t = noisy_batch['rotquats_t']
        gt_rot_quat_vf = so3_utils.calc_quat_wt_qt_q1(
            rotquats_t, gt_rotquats_1.type(torch.float32))
        if torch.any(torch.isnan(gt_rot_quat_vf)):
            raise ValueError('NaN encountered in gt_rot_quat_vf')
        gt_bb_atoms = all_atom.to_atom37(gt_trans_1, gt_rotmats_1)[:, :, :3]

        # Timestep used for normalization. t is clipped from above at
        # t_normalize_clip so the (1 - t) denominator never reaches zero
        # for clip values below 1.
        r3_t = noisy_batch['r3_t']
        so3_t = noisy_batch['so3_t']
        r3_norm_scale = 1 - torch.min(
            r3_t[..., None], torch.tensor(training_cfg.t_normalize_clip))
        so3_norm_scale = 1 - torch.min(
            so3_t[..., None], torch.tensor(training_cfg.t_normalize_clip))

        # Model output predictions.
        model_output = self.model(noisy_batch)
        pred_trans_1 = model_output['pred_trans']
        pred_rotmats_1 = model_output['pred_rotmats']
        pred_rotquats_1 = model_output['pred_rotquats']
        # pred_rots_quat_vf = so3_utils.calc_quat_wt_qt_q1( #* Exp schedule
        #     rotquats_t, pred_rotquats_1) * 10 * torch.exp( -so3_t[..., None] * 10)
        pred_rots_quat_vf = so3_utils.calc_quat_wt_qt_q1(
            rotquats_t, pred_rotquats_1)
        if torch.any(torch.isnan(pred_rots_quat_vf)):
            raise ValueError('NaN encountered in pred_rots_quat_vf')

        # Backbone atom loss
        pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3]
        gt_bb_atoms *= training_cfg.bb_atom_scale / r3_norm_scale[..., None]
        pred_bb_atoms *= training_cfg.bb_atom_scale / r3_norm_scale[..., None]
        loss_denom = torch.sum(loss_mask, dim=-1) * 3
        bb_atom_loss = torch.sum(
            (gt_bb_atoms - pred_bb_atoms) ** 2 * loss_mask[..., None, None],
            dim=(-1, -2, -3)
        ) / loss_denom

        # Translation VF loss
        trans_error = (gt_trans_1 - pred_trans_1) / r3_norm_scale * training_cfg.trans_scale
        trans_loss = training_cfg.translation_loss_weight * torch.sum(
            trans_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom
        trans_loss = torch.clamp(trans_loss, max=5)

        # Quat VF loss
        rots_quats_vf_error = (gt_rot_quat_vf - pred_rots_quat_vf) / so3_norm_scale
        rots_quats_vf_loss = training_cfg.rotation_loss_weights * torch.sum(
            rots_quats_vf_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom

        # Pairwise distance loss over the 3 backbone atoms of every residue.
        gt_flat_atoms = gt_bb_atoms.reshape([num_batch, num_res*3, 3])
        gt_pair_dists = torch.linalg.norm(
            gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
        pred_flat_atoms = pred_bb_atoms.reshape([num_batch, num_res*3, 3])
        pred_pair_dists = torch.linalg.norm(
            pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

        # Repeat the residue mask once per backbone atom.
        flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
        flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3])

        gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
        pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
        # Row and column masks both derive from loss_mask, so one flattened
        # mask serves for both sides of the pair.
        pair_dist_mask = flat_loss_mask[..., None] * flat_loss_mask[:, None, :]

        dist_mat_loss = torch.sum(
            (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
            dim=(1, 2))
        # +1 keeps the denominator strictly positive.
        dist_mat_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)) + 1)

        se3_vf_loss = trans_loss + rots_quats_vf_loss
        auxiliary_loss = (
            bb_atom_loss * training_cfg.aux_loss_use_bb_loss
            + dist_mat_loss * training_cfg.aux_loss_use_pair_loss
        )
        # The auxiliary terms only count once both timesteps exceed
        # aux_loss_t_pass; earlier examples get a zero auxiliary loss.
        auxiliary_loss *= (
            (r3_t[:, 0] > training_cfg.aux_loss_t_pass)
            & (so3_t[:, 0] > training_cfg.aux_loss_t_pass)
        )
        auxiliary_loss *= self._exp_cfg.training.aux_loss_weight
        auxiliary_loss = torch.clamp(auxiliary_loss, max=5)

        se3_vf_loss += auxiliary_loss
        if torch.any(torch.isnan(se3_vf_loss)):
            raise ValueError('NaN loss encountered')
        return {
            "trans_loss": trans_loss,
            "auxiliary_loss": auxiliary_loss,
            "rots_quats_vf_loss": rots_quats_vf_loss,
            "se3_vf_loss": se3_vf_loss
        }

    def validation_step(self, batch: Any, batch_idx: int):
        """Sample one structure per validation example and collect metrics.

        Every sample is written to
        ``<checkpoint_dir>/sample_<csv_idx>_idx_<batch_idx>_len_<num_res>/sample.pdb``.
        The metrics of the batch are appended to
        ``self.validation_epoch_metrics`` as a DataFrame; with a
        ``WandbLogger`` the saved structures are also queued in
        ``self.validation_epoch_samples``.

        Args:
            batch: validation features. When ``data.rectify`` is set the
                features are read from ``batch['sample']``.
            batch_idx: index of the batch, used in the output directory name.
        """
        # Both data modes use the same feature keys; only their location
        # inside the batch differs.
        feats = batch['sample'] if self._data_cfg.rectify else batch

        res_mask = feats['res_mask']
        self.interpolant.set_device(res_mask.device)
        num_batch, num_res = res_mask.shape
        diffuse_mask = feats['diffuse_mask']
        csv_idx = feats['csv_idx']
        _, atom37_traj, _, _ = self.interpolant.sample(
            num_batch,
            num_res,
            self.model,
            trans_1=feats['trans_1'],
            rotmats_1=feats['rotmats_1'],
            diffuse_mask=diffuse_mask,
            chain_idx=feats['chain_idx'],
            res_idx=feats['res_idx'],
        )
        # Last trajectory entry is the final sample.
        samples = atom37_traj[-1].numpy()

        ca_idx = residue_constants.atom_order['CA']
        batch_metrics = []
        for i in range(num_batch):
            sample_dir = os.path.join(
                self.checkpoint_dir,
                f'sample_{csv_idx[i].item()}_idx_{batch_idx}_len_{num_res}'
            )
            os.makedirs(sample_dir, exist_ok=True)

            # Write out sample to PDB file
            final_pos = samples[i]
            saved_path = au.write_prot_to_pdb(
                final_pos,
                os.path.join(sample_dir, 'sample.pdb'),
                no_indexing=True
            )
            if isinstance(self.logger, WandbLogger):
                self.validation_epoch_samples.append(
                    [saved_path, self.global_step, wandb.Molecule(saved_path)]
                )

            mdtraj_metrics = metrics.calc_mdtraj_metrics(saved_path)
            ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, ca_idx])
            batch_metrics.append((mdtraj_metrics | ca_ca_metrics))

        batch_metrics = pd.DataFrame(batch_metrics)
        self.validation_epoch_metrics.append(batch_metrics)

    def on_validation_epoch_end(self):
        """Log the queued samples and the mean of every validation metric."""
        if len(self.validation_epoch_samples) > 0:
            self.logger.log_table(
                key='valid/samples',
                columns=["sample_path", "global_step", "Protein"],
                data=self.validation_epoch_samples)
            self.validation_epoch_samples.clear()

        if not self.validation_epoch_metrics:
            # pd.concat raises ValueError on an empty list; with no
            # validation batch there is simply nothing to report.
            return

        val_epoch_metrics = pd.concat(self.validation_epoch_metrics)
        for metric_name, metric_val in val_epoch_metrics.mean().to_dict().items():
            self._log_scalar(
                f'valid/{metric_name}',
                metric_val,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                batch_size=len(val_epoch_metrics),
            )
        self.validation_epoch_metrics.clear()

    def _log_scalar(
            self,
            key,
            value,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            batch_size=None,
            sync_dist=False,
            rank_zero_only=True
        ):
        """Forward a scalar to ``self.log`` with this module's defaults.

        Raises:
            ValueError: if ``sync_dist`` and ``rank_zero_only`` are both set.
        """
        if sync_dist and rank_zero_only:
            raise ValueError('Unable to sync dist when rank_zero_only=True')
        self.log(
            key,
            value,
            on_step=on_step,
            on_epoch=on_epoch,
            prog_bar=prog_bar,
            batch_size=batch_size,
            sync_dist=sync_dist,
            rank_zero_only=rank_zero_only
        )

    def training_step(self, batch: Any, stage: int):
        """Corrupt a batch, compute the losses and log training statistics.

        Args:
            batch: training features. If ``data.rectify`` is set, a batch is
                a dictionary with 'sample' and 'noise'.
            stage: unused.

        Returns:
            The batch mean of ``se3_vf_loss``.
        """
        step_start_time = time.time()
        if self._data_cfg.rectify:
            feats = batch['sample']
            self.interpolant.set_device(feats['res_mask'].device)
            noisy_batch = self.interpolant.rectify_corrupt_batch(batch)
        else:
            feats = batch
            self.interpolant.set_device(feats['res_mask'].device)
            noisy_batch = self.interpolant.corrupt_batch(batch)

        # Self-conditioning: on roughly half of the steps, feed a gradient-free
        # first prediction back in, keeping the ground truth on motif residues.
        if self._interpolant_cfg.self_condition and random.random() > 0.5:
            with torch.no_grad():
                model_sc = self.model(noisy_batch)
                noisy_batch['trans_sc'] = (
                    model_sc['pred_trans'] * noisy_batch['diffuse_mask'][..., None]
                    + noisy_batch['trans_1'] * (1 - noisy_batch['diffuse_mask'][..., None])
                )

        batch_losses = self.model_step(noisy_batch)
        num_batch = batch_losses['trans_loss'].shape[0]
        total_losses = {
            k: torch.mean(v) for k, v in batch_losses.items()
        }
        for k, v in total_losses.items():
            self._log_scalar(
                f"train/{k}", v, prog_bar=False, batch_size=num_batch)

        # Losses to track. Stratified across t.
        so3_t = torch.squeeze(noisy_batch['so3_t'])
        self._log_scalar(
            "train/so3_t",
            np.mean(du.to_numpy(so3_t)),
            prog_bar=False, batch_size=num_batch)
        r3_t = torch.squeeze(noisy_batch['r3_t'])
        self._log_scalar(
            "train/r3_t",
            np.mean(du.to_numpy(r3_t)),
            prog_bar=False, batch_size=num_batch)
        for loss_name, loss_dict in batch_losses.items():
            # The rotation loss is binned by the SO(3) timestep. The key must
            # match the one returned by model_step ('rots_quats_vf_loss');
            # the former 'rots_vf_loss' never matched, so the rotation loss
            # was silently binned by r3_t.
            if loss_name == 'rots_quats_vf_loss':
                batch_t = so3_t
            else:
                batch_t = r3_t
            stratified_losses = mu.t_stratified_loss(
                batch_t, loss_dict, loss_name=loss_name)
            for k, v in stratified_losses.items():
                self._log_scalar(
                    f"train/{k}", v, prog_bar=False, batch_size=num_batch)

        # Training throughput
        scaffold_percent = torch.mean(feats['diffuse_mask'].float()).item()
        self._log_scalar(
            "train/scaffolding_percent",
            scaffold_percent, prog_bar=False, batch_size=num_batch)
        motif_mask = 1 - feats['diffuse_mask'].float()
        num_motif_res = torch.sum(motif_mask, dim=-1)
        self._log_scalar(
            "train/motif_size",
            torch.mean(num_motif_res).item(), prog_bar=False, batch_size=num_batch)
        self._log_scalar(
            "train/length", feats['res_mask'].shape[1], prog_bar=False, batch_size=num_batch)
        self._log_scalar(
            "train/batch_size", num_batch, prog_bar=False)
        step_time = time.time() - step_start_time
        self._log_scalar(
            "train/examples_per_second", num_batch / step_time)
        train_loss = total_losses['se3_vf_loss']
        self._log_scalar(
            "train/loss", train_loss, batch_size=num_batch)
        return train_loss

    def configure_optimizers(self):
        """Return AdamW over ``self.model`` configured by ``experiment.optimizer``."""
        return torch.optim.AdamW(
            params=self.model.parameters(),
            **self._exp_cfg.optimizer
        )

    def predict_step(self, batch, batch_idx):
        '''
        The main inference function for the model.
        This function is called by the PyTorch Lightning Trainer.
        Each time this function is called, the model generates a batch of protein structure.
        Dataloader will be used automatically to generate the batch.

        Samples are generated one at a time. Each one is saved with
        ``eu.save_traj``; when ``samples.seq_per_sample > 0`` it is also
        passed through ``run_self_consistency``. Timing and CUDA memory
        figures are appended to ``self._time_records``.
        '''
        del batch_idx  # Unused

        device = f'cuda:{torch.cuda.current_device()}'
        interpolant = Interpolant(self._infer_cfg.interpolant)
        interpolant.set_device(device)

        # squeeze().tolist() yields a bare int for a single-sample batch.
        sample_ids = batch['sample_id'].squeeze().tolist()
        sample_ids = [sample_ids] if isinstance(sample_ids, int) else sample_ids
        num_batch = len(sample_ids)

        if 'diffuse_mask' in batch:  # motif-scaffolding
            target = batch['target'][0]
            trans_1 = batch['trans_1']
            rotmats_1 = batch['rotmats_1']
            diffuse_mask = batch['diffuse_mask']
            # NOTE(review): true_bb_pos is computed but not used afterwards.
            true_bb_pos = all_atom.atom37_from_trans_rot(trans_1, rotmats_1, 1 - diffuse_mask)
            true_bb_pos = true_bb_pos[..., :3, :].reshape(-1, 3).cpu().numpy()
            _, sample_length, _ = trans_1.shape
            sample_dirs = [os.path.join(
                self.inference_dir, target, f'sample_{str(sample_id)}')
                for sample_id in sample_ids]
        else:  # unconditional
            sample_length = batch['num_res'].item()
            true_bb_pos = None
            sample_dirs = [os.path.join(
                self.inference_dir, f'length_{sample_length}', f'sample_{str(sample_id)}')
                for sample_id in sample_ids]
            trans_1 = rotmats_1 = None
            # Without a motif every residue is diffused.
            diffuse_mask = torch.ones(1, sample_length, device=device)

        run_self_consistency = self._infer_cfg.samples.seq_per_sample > 0

        for i in range(num_batch):
            sample_dir = sample_dirs[i]
            os.makedirs(sample_dir, exist_ok=True)
            if 'aatype' in batch:
                aatype = du.to_numpy(batch['aatype'].long())[0]
            else:
                aatype = np.zeros(sample_length, dtype=int)

            start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            start_sample_time = time.time()
            _, atom37_traj, model_traj, _ = interpolant.sample(
                1, sample_length, self.model,
                trans_1=trans_1, rotmats_1=rotmats_1, diffuse_mask=diffuse_mask
            )
            finish_sample_time = time.time()
            sample_elapsed_time = finish_sample_time - start_sample_time

            bb_trajs = du.to_numpy(torch.stack(atom37_traj, dim=0).transpose(0, 1))
            bb_traj = bb_trajs[0]
            traj_paths = eu.save_traj(
                bb_traj[-1],
                bb_traj[0],
                np.flip(du.to_numpy(torch.concat(model_traj, dim=0)), axis=0),
                du.to_numpy(diffuse_mask)[0],
                output_dir=sample_dir,
                aatype=aatype,
            )
            pdb_path = traj_paths['sample_path']

            if run_self_consistency:
                sc_output_dir = os.path.join(sample_dir, 'self_consistency')
                os.makedirs(sc_output_dir, exist_ok=True)
                shutil.copy(pdb_path, os.path.join(
                    sc_output_dir, os.path.basename(pdb_path)))
                _ = self.run_self_consistency(
                    sc_output_dir,
                    pdb_path,
                    motif_mask=None
                )

            # Memory is measured after the optional evaluation.
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            if run_self_consistency:
                self._print_logger.info(f'Done sampling sample {i}: {pdb_path}, allocated {allocated:.2f}GB, reserved {reserved:.2f}GB')

            finish_eval_time = time.time()
            # eval_time stays empty when no self-consistency run happened.
            eval_time = (
                finish_eval_time - finish_sample_time
                if run_self_consistency else None
            )
            total_time = finish_eval_time - start_sample_time
            record = pd.DataFrame(
                [[sample_length, pdb_path, start_time, sample_elapsed_time,
                  eval_time, total_time, allocated, reserved]],
                columns=self._time_records.columns)
            self._time_records = pd.concat(
                [self._time_records, record],
                ignore_index=True
            )
            torch.cuda.empty_cache()

    def _append_time_records(self, records):
        """Append ``records`` to ``<inference_dir>/time_records.csv``.

        The header row is only written when the file does not exist yet.
        """
        csv_file_path = os.path.join(self.inference_dir, 'time_records.csv')
        if os.path.exists(csv_file_path):
            records.to_csv(csv_file_path, mode='a', header=False, index=False)
        else:
            records.to_csv(csv_file_path, index=False)

    def on_predict_epoch_end(self):
        """Persist the timing records collected by ``predict_step``.

        With ``torch.distributed`` initialized, the records of all ranks are
        gathered and rank 0 writes them sorted by ``start_time``; otherwise
        the local records are written as they are.
        """
        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            rank = self.global_rank
            gathered_time_records = [None for _ in range(world_size)]
            torch.distributed.all_gather_object(gathered_time_records, self._time_records)
            if rank == 0:
                # merge
                all_time_records = pd.concat(gathered_time_records, ignore_index=True)
                all_time_records['start_time'] = pd.to_datetime(all_time_records['start_time'])
                all_time_records.sort_values(by='start_time', inplace=True)
                # inference_dir is only touched on rank 0 here, as before:
                # resolving it for the first time involves a broadcast.
                self._append_time_records(all_time_records)
        else:
            self._append_time_records(self._time_records)

    def run_self_consistency(
            self,
            decoy_pdb_dir: str,
            reference_pdb_path: str,
            motif_mask: Optional[np.ndarray]=None):
        """Run self-consistency on design proteins against reference protein.

        Args:
            decoy_pdb_dir: directory where designed protein files are stored.
            reference_pdb_path: path to reference protein file
            motif_mask: Optional mask of which residues are the motif.

        Returns:
            None. As side effects it
            writes ProteinMPNN outputs to decoy_pdb_dir/seqs,
            writes ESMFold outputs to decoy_pdb_dir/esmf and
            writes results in decoy_pdb_dir/sc_results.csv.

        Raises:
            Exception: whatever launching ProteinMPNN raised on the fifth
                failed attempt.
            RuntimeError: if ProteinMPNN was terminated by a signal on the
                fifth failed attempt.
        """
        # Run ProteinMPNN
        self._print_logger.info(f'Running ProteinMPNN on {decoy_pdb_dir}')
        output_path = os.path.join(decoy_pdb_dir, "parsed_pdbs.jsonl")
        #* parse different chains in the pdb file, and save the results to a jsonl file
        process = subprocess.Popen([
            'python',
            f'{self._infer_cfg.pmpnn_dir}/helper_scripts/parse_multiple_chains.py',
            f'--input_path={decoy_pdb_dir}',
            f'--output_path={output_path}',
        ])
        _ = process.wait()

        pmpnn_args = [
            'python',
            f'{self._infer_cfg.pmpnn_dir}/protein_mpnn_run.py',
            '--out_folder',
            decoy_pdb_dir,
            '--jsonl_path',
            output_path,
            '--num_seq_per_target',
            str(self._infer_cfg.samples.seq_per_sample),
            '--sampling_temp',
            '0.1',
            '--seed',
            '38',
            '--batch_size',
            '1',
        ]
        gpu_id = torch.cuda.current_device()
        pmpnn_args.append('--device')
        pmpnn_args.append(str(gpu_id))

        # Retry while the run did not complete. A negative return code means
        # the process was killed by a signal; that raises nothing, so it has
        # to be counted as a failed attempt explicitly or the loop would
        # never terminate.
        max_tries = 5
        num_tries = 0
        ret = -1
        while ret < 0:
            launch_error = None
            try:
                process = subprocess.Popen(
                    pmpnn_args,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.STDOUT
                )
                ret = process.wait()
            except Exception as e:
                launch_error = e
            if ret >= 0:
                break
            num_tries += 1
            self._print_logger.info(f'Failed ProteinMPNN. Attempt {num_tries}/5')
            torch.cuda.empty_cache()
            if num_tries >= max_tries:
                if launch_error is not None:
                    raise launch_error
                raise RuntimeError(
                    f'ProteinMPNN terminated by signal {-ret} '
                    f'on {max_tries} consecutive attempts')
        if ret != 0:
            # Positive exit codes are not retried (unchanged behaviour), but
            # they should not go unnoticed either.
            self._print_logger.warning(
                f'ProteinMPNN exited with return code {ret}')

        mpnn_fasta_path = os.path.join(
            decoy_pdb_dir,
            'seqs',
            os.path.basename(reference_pdb_path).replace('.pdb', '.fa')
        )

        # Run ESMFold on each ProteinMPNN sequence and calculate metrics.
        self._print_logger.info(f'Running ESMFold on {mpnn_fasta_path}')
        mpnn_results = {
            'tm_score': [],
            'sample_path': [],
            'header': [],
            'sequence': [],
            'rmsd': [],
        }
        if motif_mask is not None:
            # Only calculate motif RMSD if mask is specified.
            mpnn_results['motif_rmsd'] = []
        esmf_dir = os.path.join(decoy_pdb_dir, 'esmf')
        os.makedirs(esmf_dir, exist_ok=True)
        fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
        sample_feats = du.parse_pdb_feats('sample', reference_pdb_path)
        # The reference sequence does not depend on the designed sequence.
        sample_seq = du.aatype_to_seq(sample_feats['aatype'])
        for i, (header, string) in enumerate(fasta_seqs.items()):
            # Run ESMFold
            esmf_sample_path = os.path.join(esmf_dir, f'sample_{i}.pdb')
            _ = self.run_folding(string, esmf_sample_path)
            esmf_feats = du.parse_pdb_feats('folded_sample', esmf_sample_path)

            # Calculate scTM of ESMFold outputs with reference protein
            _, tm_score = metrics.calc_tm_score(
                sample_feats['bb_positions'], esmf_feats['bb_positions'],
                sample_seq, sample_seq)
            rmsd = metrics.calc_aligned_rmsd(
                sample_feats['bb_positions'], esmf_feats['bb_positions'])
            if motif_mask is not None:
                sample_motif = sample_feats['bb_positions'][motif_mask]
                of_motif = esmf_feats['bb_positions'][motif_mask]
                motif_rmsd = metrics.calc_aligned_rmsd(
                    sample_motif, of_motif)
                mpnn_results['motif_rmsd'].append(motif_rmsd)
            mpnn_results['rmsd'].append(rmsd)
            mpnn_results['tm_score'].append(tm_score)
            mpnn_results['sample_path'].append(esmf_sample_path)
            mpnn_results['header'].append(header)
            mpnn_results['sequence'].append(string)

        # Save results to CSV
        csv_path = os.path.join(decoy_pdb_dir, 'sc_results.csv')
        mpnn_results = pd.DataFrame(mpnn_results)
        mpnn_results.to_csv(csv_path)

    def run_folding(self, sequence, save_path):
        """Run ESMFold on sequence.

        Moves the folding model to this module's device, writes the predicted
        PDB text to ``save_path`` and returns that text. Relies on
        ``self._folding_model``, which ``__init__`` only creates when
        ``experiment.is_training`` is false.
        """
        self._folding_model = self._folding_model.to(self.device)
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)
        with open(save_path, "w") as f:
            f.write(output)
        return output