Refactor network

This commit is contained in:
Rohith Krishna
2024-01-30 03:16:52 +00:00
parent 4975565df3
commit fd53fffe50
93 changed files with 4882 additions and 20765 deletions

2
.gitignore vendored
View File

@@ -12,3 +12,5 @@ __pycache__/
*/run_scripts/
*/tests/
unit_tests/
ruff.toml
*/scratch/

View File

@@ -25,7 +25,7 @@ import torch.distributed as dist
from abc import ABC
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from se3_transformer.runtime.utils import get_local_rank
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:

View File

@@ -31,9 +31,9 @@ from torch import Tensor
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm
from se3_transformer.data_loading.data_module import DataModule
from se3_transformer.model.basis import get_basis
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
from rf2aa.SE3Transformer.se3_transformer.data_loading.data_module import DataModule
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:

View File

@@ -31,7 +31,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.runtime.utils import degree_to_dim
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
@lru_cache(maxsize=None)

View File

@@ -29,7 +29,7 @@ from typing import Dict
import torch
from torch import Tensor
from se3_transformer.runtime.utils import degree_to_dim
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])

View File

@@ -30,10 +30,10 @@ from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range

View File

@@ -33,8 +33,8 @@ from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):

View File

@@ -29,7 +29,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module):

View File

@@ -29,7 +29,7 @@ import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module):

View File

@@ -29,14 +29,14 @@ import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis, update_basis_with_fused
from rf2aa.SE3Transformer.se3_transformer.model.layers.attention import AttentionBlockSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.norm import NormSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.pooling import GPooling
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential):

View File

@@ -24,9 +24,9 @@
import argparse
import pathlib
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')

View File

@@ -29,8 +29,8 @@ from typing import Optional
import numpy as np
import torch
from se3_transformer.runtime.loggers import Logger
from se3_transformer.runtime.metrics import MeanAbsoluteError
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.metrics import MeanAbsoluteError
class BaseCallback(ABC):

View File

@@ -29,11 +29,11 @@ from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.utils import to_cuda, get_local_rank
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import BaseCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import DLLogger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank
@torch.inference_mode()
@@ -57,10 +57,10 @@ def evaluate(model: nn.Module,
if __name__ == '__main__':
from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
from se3_transformer.runtime.utils import init_distributed, seed_everything
from se3_transformer.model import SE3TransformerPooled, Fiber
from se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import init_distributed, seed_everything
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled, Fiber
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
import torch.distributed as dist
import logging
import sys

View File

@@ -31,7 +31,7 @@ import torch.distributed as dist
import wandb
from dllogger import Verbosity
from se3_transformer.runtime.utils import rank_zero_only
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import rank_zero_only
class Logger(ABC):

View File

@@ -36,16 +36,16 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
from rf2aa.SE3Transformer.se3_transformer.runtime.inference import evaluate
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity

View File

@@ -23,8 +23,8 @@
import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )

View File

@@ -1,96 +0,0 @@
import torch
import torch.nn as nn
from icecream import ic
import inspect
import sys, os
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
sys.path.insert(0,script_dir+'SE3Transformer')
#sys.path.insert(0, '/home/ahern/projects/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer') # jue commented this -- might need to uncomment
from rf2aa.util_module import init_lecun_normal_param
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
se3_transformer_path = inspect.getfile(SE3Transformer)
se3_fiber_path = inspect.getfile(Fiber)
#ic(se3_transformer_path, se3_fiber_path)
assert 'rf2aa' in se3_transformer_path
class SE3TransformerWrapper(nn.Module):
"""SE(3) equivariant GCN with attention"""
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32):
super().__init__()
# Build the network
self.l1_in = l1_in_features
self.l1_out = l1_out_features
#
fiber_edge = Fiber({0: num_edge_features})
if l1_out_features > 0:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
self.se3 = SE3Transformer(num_layers=num_layers,
fiber_in=fiber_in,
fiber_hidden=fiber_hidden,
fiber_out = fiber_out,
num_heads=n_heads,
channels_div=div,
fiber_edge=fiber_edge,
populate_edge="arcsin",
final_layer="lin",
use_layer_norm=True)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
if self.l1_out > 0:
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
node_features = {'0': type_0_features, '1': type_1_features}
else:
node_features = {'0': type_0_features}
edge_features = {'0': edge_features}
return self.se3(G, node_features, edge_features)

View File

@@ -22,7 +22,7 @@ DATASET_PARAMS = [
"fraction_sm_compl_covale",
"fraction_sm",
"fraction_atomize_pdb",
"fraction_atomize_complex"
"fraction_atomize_complex",
"fraction_sm_compl_asmb",
"fraction_sm_compl_furthest_neg",
"fraction_sm_compl_permuted_neg",
@@ -51,6 +51,7 @@ DATASET_PARAMS = [
"n_valid_sm_compl_furthest_neg",
"n_valid_sm_compl_permuted_neg",
"n_valid_sm_compl_docked_neg",
"n_valid_atomize_complex",
"n_valid_dude_actives",
"n_valid_dude_inactives",
"p_short_crop",
@@ -268,6 +269,19 @@ def get_args(parser: Optional[argparse.ArgumentParser] = None, input_args: Optio
help="probability of a given non-standard residue being atomized, rather "
"than being converted to a standard equivalent [1.0]",
)
data_group.add_argument(
"-batch_by_dataset",
action="store_true",
default=False,
help="Batch examples by dataset, e.g., all nodes receive an example from the same dataset. [False]",
)
data_group.add_argument(
"-batch_by_length",
action="store_true",
default=False,
help="Batch examples by example length, e.g., all nodes receive a similarly-sized example. [False]",
)
# dataset parameters
dataset_group = parser.add_argument_group("data loading parameters")

View File

@@ -4,12 +4,12 @@ import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from rf2aa.parsers import parse_a3m, parse_fasta, read_template_pdb
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.data.parsers import parse_a3m, parse_fasta, read_template_pdb
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
from rf2aa.ffindex import *
from rf2aa.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
from rf2aa.data.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_init_xyz
from rf2aa.util_module import ComputeAllAtomCoords
from rf2aa.chemical import NTOTAL, NTOTALDOFS, NAATOKENS

View File

@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
import wandb
from torch.utils import data
from rf2aa.data_loader import (
from rf2aa.data.data_loader import (
get_train_valid_set,
loader_pdb,
loader_fb,
@@ -25,11 +25,11 @@ from rf2aa.data_loader import (
DatasetSMComplex,
DatasetSMComplexAssembly,
)
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.loss import *
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.loss.loss import *
from rf2aa.util import *
from rf2aa.train_multi_EMA import Trainer, EMA, count_parameters
from rf2aa.archived.train_multi_EMA import Trainer, EMA, count_parameters
# disable openbabel warnings
from openbabel import openbabel as ob

View File

@@ -11,18 +11,18 @@ script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(script_dir)
import rf2aa
import rf2aa.parsers as parsers
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
import rf2aa.data.parsers as parsers
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule as RoseTTAFoldModule
import rf2aa.util as util
from rf2aa.util import *
from rf2aa.loss import *
from collections import namedtuple, OrderedDict
from rf2aa.ffindex import *
from rf2aa.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo, merge_a3m_hetero
from rf2aa.data.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo, merge_a3m_hetero
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
from rf2aa.parsers import read_templates, parse_multichain_fasta, parse_mixed_fasta
from rf2aa.data.parsers import read_templates, parse_multichain_fasta, parse_mixed_fasta
from rf2aa.memory import mem_report
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs, update_symm_subs, get_symm_map, get_symmetry
@@ -72,8 +72,11 @@ def get_args():
parser.add_argument("-n_cycle", type=int, default=10, help='number of recycles')
parser.add_argument("-trunc_N", type=int, default=0, help='residues to truncate at N-term on MSA to match PDB')
parser.add_argument("-trunc_C", type=int, default=0, help='residues to truncate at C-term on MSA to match PDB')
parser.add_argument("-no_extra_l1", dest='use_extra_l1', default='True', action='store_false',
help="Turn off chirality and LJ grad inputs to SE3 layers (for backwards compatibility).")
parser.add_argument("-use_chiral_l1", type=bool, default=True,
help="use chiral L1 features (for backwards compatibility)")
parser.add_argument("-use_lj_l1", type=bool, default=True,
help="use LJ L1 features (for backwards compatibility)")
parser.add_argument("-no_atom_frames", dest='use_atom_frames', default='True', action='store_false',
help="Turn off l1 features from atom frames in SE3 layers (for backwards compatibility).")
@@ -370,8 +373,11 @@ class Predictor():
read_data(args.db+'_pdb.ffdata'))
# define model & load model
MODEL_PARAM['use_extra_l1'] = args.use_extra_l1
MODEL_PARAM['use_chiral_l1'] = args.use_chiral_l1
MODEL_PARAM['use_lj_l1'] = args.use_lj_l1
MODEL_PARAM['use_atom_frames'] = args.use_atom_frames
MODEL_PARAM['use_same_chain'] = True
MODEL_PARAM['recycling_type'] = 'all'
self.model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask = util.allatom_mask.to(self.device),

View File

@@ -3,7 +3,7 @@ import torch
from torch.utils import data
# from chemical import NFRAMES
from rf2aa.data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params, loader_atomize_pdb
from rf2aa.data.data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params, loader_atomize_pdb
from rf2aa.kinematics import xyz_to_c6d, xyz_to_t2d
from rf2aa.chemical import num2aa, aa2elt, aa2num, aabonds,aa2long, aabtypes, atomized_protein_frames
from rf2aa.loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss, calc_chiral_loss
@@ -263,9 +263,9 @@ class LossTestCase(unittest.TestCase):
print(true_crds)
break
def test_chiral_loss(self):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats, chirals, task, item in self.valid_sm_compl_loader:
print(calc_chiral_loss(true_crds[0][None],chirals)
def test_chiral_loss(self):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats, chirals, task, item in self.valid_sm_compl_loader:
print(calc_chiral_loss(true_crds[0][None],chirals))
class DataLoaderTestCase(unittest.TestCase):

View File

@@ -19,7 +19,7 @@ from tqdm import tqdm
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(script_dir)
from rf2aa.data_loader import (
from rf2aa.data.data_loader import (
get_train_valid_set, loader_pdb, loader_fb, loader_complex,
loader_na_complex, loader_distil_tf, loader_tf_complex, loader_dna_rna,
loader_sm, loader_atomize_pdb, loader_atomize_complex,
@@ -29,11 +29,11 @@ from rf2aa.data_loader import (
DistilledDataset, DistributedWeightedSampler, unbatch_item
)
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, xyz_to_bbtor
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule as RoseTTAFoldModule
from rf2aa.loss import *
from rf2aa.util import *
from rf2aa.util_module import XYZConverter
from rf2aa.scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup
from rf2aa.training.scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs, get_symm_map
from rf2aa.chemical import load_pdb_ideal_sdf_strings
@@ -55,6 +55,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
# limit thread counts
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
## To reproduce errors
import random
@@ -80,6 +81,20 @@ def add_weight_decay(model, l2_coeff):
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_recycle_schedule(n_epochs, n_train, n_max, world_size):
'''
get's the number of recycles per example.
'''
assert n_train % world_size == 0
# need to sync different gpus
recycle_schedules=[]
# make deterministic
np.random.seed(0)
for i in range(n_epochs):
recycle_schedule=[np.random.randint(1,n_max) for _ in range(n_train//world_size)]
recycle_schedules.append(torch.tensor(recycle_schedule))
return torch.stack(recycle_schedules, dim=0)
class EMA(nn.Module):
def __init__(self, model, decay):
super().__init__()
@@ -97,13 +112,13 @@ class EMA(nn.Module):
print("EMA update should only be called during training", file=stderr, flush=True)
return
model_params = OrderedDict(self.model.named_parameters())
self.model_params = OrderedDict(self.model.named_parameters())
shadow_params = OrderedDict(self.shadow.named_parameters())
# check if both model contains the same set of keys
assert model_params.keys() == shadow_params.keys()
assert self.model_params.keys() == shadow_params.keys()
for name, param in model_params.items():
for name, param in self.model_params.items():
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
if param.requires_grad:
@@ -128,7 +143,7 @@ class EMA(nn.Module):
class Trainer():
def __init__(self, model_name='BFF', checkpoint_path=None,
n_epoch=100, step_lr=100, lr=1.0e-4, l2_coeff=1.0e-2, port=None, interactive=False,
model_param={}, loader_param={}, loss_param={}, dataset_param={}, batch_size=1,
model_params={}, loader_param={}, loss_param={}, dataset_param={}, batch_size=1,
accum_step=1, maxcycle=4, eval=False, out_dir=None, wandb_prefix=None,
model_dir='models/', dataloader_kwargs = {}, **kwargs):
@@ -140,7 +155,7 @@ class Trainer():
self.port = port
self.interactive = interactive
self.eval = eval
self.model_param = model_param
self.model_params = model_params
self.loader_param = loader_param
self.loss_param = loss_param
self.dataset_param = dataset_param
@@ -158,7 +173,9 @@ class Trainer():
self.debug_mode = kwargs.get("debug_mode", False)
self.skip_valid = kwargs.get("skip_valid", 1)
self.start_epoch = kwargs.get("start_epoch", 0)
self.n_train = self.dataset_param["n_train"]
self.maxcycle = maxcycle
self.recycle_schedule=get_recycle_schedule(self.n_epoch, self.n_train, self.maxcycle+1, world_size)
# for all-atom str loss
#self.ti_dev = torsion_indices
#self.ti_flip = torsion_can_flip
@@ -184,7 +201,6 @@ class Trainer():
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
self.active_fn = nn.Softmax(dim=1)
self.maxcycle = maxcycle
self.pdb_counter=0
@@ -468,8 +484,9 @@ class Trainer():
else:
l_fape_sm_protein = torch.tensor(0).to(gpu)
frac_sm = torch.sum(frame_mask_BB_sm[:,res_mask[0]])/ torch.sum(frame_mask_BB[:,res_mask[0]])
inter_fape = frac_sm*l_fape_protein_sm + (1.0-frac_sm)*l_fape_sm_protein
#frac_sm = torch.sum(frame_mask_BB_sm[:,res_mask[0]])/ torch.sum(frame_mask_BB[:,res_mask[0]])
#inter_fape = frac_sm*l_fape_protein_sm + (1.0-frac_sm)*l_fape_sm_protein
inter_fape = l_fape_sm_protein
bb_l_fape_inter = (w_bb_fape*inter_fape).sum()
tot_loss += 0.5*w_inter_fape*bb_l_fape_inter
else:
@@ -601,25 +618,27 @@ class Trainer():
tot_loss += w_bond*bond_loss
loss_dict['bond_geom'] = bond_loss.detach()
if (pred_allatom.shape[0] > 1):
bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
if w_bond > 0.0:
tot_loss += w_bond*bond_loss.mean()
loss_dict['clash_loss'] = ( bond_loss.detach() )
else:
bond_loss = torch.tensor(0).to(gpu)
loss_dict['bond_loss'] = bond_loss.detach()
# if (pred_allatom.shape[0] > 1):
# bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
# if w_bond > 0.0:
# tot_loss += w_bond*bond_loss.mean()
# loss_dict['clash_loss'] = ( bond_loss.detach() )
# else:
# bond_loss = torch.tensor(0).to(gpu)
# loss_dict['bond_loss'] = bond_loss.detach()
# clash [use all atoms not just those in native]
clash_loss = calc_lj(
seq[0], pred_allatom,
self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
lj_lin=lj_lin
)
# clash_loss = calc_lj(
# seq[0], pred_allatom,
# self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
# lj_lin=lj_lin
# )
clash_loss, num_violations = calc_l1_clash_loss(pred_allatom, seq[0],\
self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, \
self.lj_correction_parameters, self.num_bonds)
if w_clash > 0.0:
tot_loss += w_clash*clash_loss.mean()
loss_dict['clash_loss'] = clash_loss[0].detach()
loss_dict['clash_loss'] = clash_loss.detach()
if torch.any(mask_BB[0]):
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
pred=pred_allatom[:,mask_BB[0]],
@@ -761,6 +780,7 @@ class Trainer():
def load_model(self, model, model_name, rank, suffix='last', checkpoint_path=None, resume_train=False,
optimizer=None, scheduler=None, scaler=None):
torch.cuda.empty_cache()
if self.debug_mode:
return -1, 99999999.9
if checkpoint_path==None:
@@ -849,15 +869,27 @@ class Trainer():
print("Running in DEBUG mode...")
world_size = 1
rank = 0
self.train_model(rank, world_size)
if (not self.interactive and "SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
world_size = int(os.environ["SLURM_NTASKS"])
rank = int (os.environ["SLURM_PROCID"])
print ("Launched from slurm", rank, world_size)
self.train_model(rank, world_size)
if torch.cuda.device_count() == int(os.environ["SLURM_NTASKS"]):
# If launching 1 job per node
world_size = int(os.environ["SLURM_NTASKS"])
rank = int (os.environ["SLURM_PROCID"])
print ("Launched from slurm", rank, world_size)
self.train_model(rank, world_size)
elif torch.cuda.device_count() > 1 and int(os.environ["SLURM_NTASKS"]) == 1:
# If launching all jobs from same node
world_size = torch.cuda.device_count()
print(f"Spawning all jobs from one node. World size: {world_size}")
mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
else:
# Raise error, since we either need one job per node or all jobs from one node
raise RuntimeError("Invalid distributed processing combination of nodes/tasks/gpus for SLURM.")
else:
print ("Launched from interactive")
world_size = torch.cuda.device_count()
if world_size == 1:
# No need for multiple processes with 1 GPU
self.train_model(0, world_size)
@@ -906,7 +938,7 @@ class Trainer():
)
all_param = {}
all_param.update(self.loader_param)
all_param.update(self.model_param)
all_param.update(self.self.model_params)
all_param.update(self.loss_param)
wandb.config = all_param
@@ -914,7 +946,7 @@ class Trainer():
#print ("running ddp on rank %d, world_size %d"%(rank, world_size))
gpu = rank % torch.cuda.device_count()
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
torch.cuda.set_device("cuda:%d"%gpu)
# Get ligand dictionary. This is used for loading negative examples.
@@ -1002,7 +1034,9 @@ class Trainer():
fractions=OrderedDict([(k, self.dataset_param['fraction_'+k]) for k in train_dict]),
num_replicas=world_size,
rank=rank,
replacement=True
lengths=self.loader_param["EXAMPLE_LENGTHS"],
batch_by_dataset=self.loader_param["BATCH_BY_DATASET"],
batch_by_length=self.loader_param["BATCH_BY_LENGTH"],
)
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **self.dataloader_kwargs)
@@ -1048,11 +1082,6 @@ class Trainer():
loader_distil_tf, valid_dict['distil_tf'],
self.loader_param, negative=False, native_NA_frac=0.0
),
atomize_pdb = Dataset(
valid_ID_dict['atomize_pdb'][:self.dataset_param['n_valid_atomize_pdb']],
loader_atomize_pdb, valid_dict['atomize_pdb'],
self.loader_param, homo, p_homo_cut=-1.0, n_res_atomize=3, flank=0, p_short_crop=-1.0
),
metal_compl = DatasetSMComplexAssembly(
valid_ID_dict['metal_compl'][:self.dataset_param['n_valid_metal_compl']],
loader_sm_compl_assembly, valid_dict['metal_compl'],
@@ -1302,10 +1331,16 @@ class Trainer():
self.cb_len = self.cb_len.to(gpu)
self.cb_ang = self.cb_ang.to(gpu)
self.cb_tor = self.cb_tor.to(gpu)
self.model_params['use_chiral_l1'] = True
self.model_params['use_lj_l1'] = False
self.model_params['use_atom_frames'] = True
self.model_params['use_same_chain'] = False
self.model_params['recycling_type'] = 'msa_pair'
self.model_params.pop('use_extra_l1')
# define model
model = EMA(RoseTTAFoldModule(
**self.model_param,
**self.model_params,
aamask=self.aamask,
atom_type_index=self.atom_type_index,
ljlk_parameters=self.ljlk_parameters,
@@ -1553,9 +1588,10 @@ class Trainer():
network_input['symmRs'] = symmRs
network_input['symmmeta'] = symmmeta
mask_recycle = mask_prev[:,:,:3].bool().all(dim=-1)
mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
mask_recycle = same_chain.float()*mask_recycle.float()
#mask_recycle = mask_prev[:,:,:3].bool().all(dim=-1)
#mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
#mask_recycle = same_chain.float()*mask_recycle.float()
mask_recycle = None
return task, item, network_input, xyz_prev, alpha_prev, mask_recycle, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
def _get_model_input(self, network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
@@ -1565,7 +1601,11 @@ class Trainer():
input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
else:
input_i[key] = network_input[key]
msa_prev, pair_prev, xyz_prev, alpha, mask_recycle = output_i
L = input_i["msa_latent"].shape[2]
msa_prev, pair_prev, _, alpha, mask_recycle = output_i
xyz_prev = INIT_CRDS.reshape(1,1,NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
input_i['msa_prev'] = msa_prev
input_i['pair_prev'] = pair_prev
input_i['xyz'] = xyz_prev
@@ -1579,7 +1619,7 @@ class Trainer():
self, output_i, true_crds, atom_mask, same_chain,
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label, ctrid=0
):
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _ = output_i
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
if (symmRs is not None):
#print ('a', pred_crds.shape, true_crds.shape, mask_crds.shape)
@@ -1682,7 +1722,10 @@ class Trainer():
counter = 0
for inputs in train_loader:
for train_idx, inputs in enumerate(train_loader):
regression = {
"dataloader_inputs": inputs
}
(
task, item, network_input, xyz_prev, alpha_prev, mask_recycle,
true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
@@ -1691,9 +1734,12 @@ class Trainer():
counter += 1
N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling
#N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling
# all examples in a pseudo batch have the same recycle
N_cycle = self.recycle_schedule[epoch, train_idx] # number of recycling
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
N_cycle = 1
for i_cycle in range(N_cycle):
with ExitStack() as stack:
if i_cycle < N_cycle -1:
@@ -1707,9 +1753,8 @@ class Trainer():
return_raw=False
use_checkpoint=True
input_i = self._get_model_input(network_input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
output_i = ddp_model(**input_i)
if i_cycle < N_cycle - 1:
continue
loss, loss_dict, acc_s, _, true_crds, pred_allatom, res_mask = self._get_loss_and_misc(
@@ -1720,10 +1765,21 @@ class Trainer():
unclamp, negative, task, item, symmRs, Lasu, ch_label,
len(train_loader)*rank+counter
)
regression.update( {
"model_input": input_i,
"model_output": output_i,
"loss": loss,
"loss_dict": loss_dict
})
torch.save(
regression,
"test_pickles/model_io.pt"
)
loss = loss / self.ACCUM_STEP
scaler.scale(loss).backward()
if counter%self.ACCUM_STEP == 0:
# gradient clipping
print("accumulation")
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.2)
scaler.step(optimizer)
@@ -1734,7 +1790,15 @@ class Trainer():
if not skip_lr_sched:
scheduler.step()
ddp_model.module.update() # apply EMA
torch.save({'epoch': epoch,
#'model_state_dict': ddp_model.state_dict(),
'model_state_dict': ddp_model.module.shadow.state_dict(),
'final_state_dict': ddp_model.module.model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict()},"test_pickles/optimizer_regression.pt")
raise UnboundLocalError("stopping after 1 iteration")
item_ = unbatch_item(item) # remove nested lists to make more readable when printed
save_pdbs = False
if torch.isnan(loss):
@@ -1886,6 +1950,7 @@ class Trainer():
return_raw=False
input_i = self._get_model_input(network_input, output_i, i_cycle, gpu, return_raw=return_raw)
output_i = ddp_model(**input_i)
if i_cycle < N_cycle - 1:
@@ -2233,7 +2298,7 @@ class Trainer():
if __name__ == "__main__":
from arguments import get_args
args, dataset_param, model_param, loader_param, loss_param = get_args()
args, dataset_param, model_params, loader_param, loss_param = get_args()
if args.debug:
DEBUG = True
@@ -2250,7 +2315,8 @@ if __name__ == "__main__":
"num_workers": args.dataloader_num_workers,
"pin_memory": not args.dont_pin_memory,
}
world_size = torch.cuda.device_count()
trainer_object = Trainer(
model_name=args.model_name,
checkpoint_path = args.checkpoint_path,
@@ -2259,7 +2325,7 @@ if __name__ == "__main__":
lr=args.lr,
l2_coeff=1.0e-2,
port=args.port,
model_param=model_param,
model_params=model_params,
loader_param=loader_param,
loss_param=loss_param,
batch_size=args.batch_size,
@@ -2274,5 +2340,6 @@ if __name__ == "__main__":
dataloader_kwargs=dataloader_kwargs,
debug_mode=args.debug,
skip_valid=args.skip_valid,
world_size=world_size
)
trainer_object.run_model_training(torch.cuda.device_count())
trainer_object.run_model_training(world_size)

View File

@@ -0,0 +1,129 @@
experiment:
name: rf2a-fd4-20231117
n_epoch: 800
output_dir: null
trainer: "legacy"
model:
embedding: null
blocks: null
refinment: null
auxiliary_predictors: {}
legacy_model: null
dataset_params:
fraction_pdb: 0
fraction_fb: 0
fraction_compl: 0
fraction_neg_compl: 0
fraction_na_compl: 0
fraction_neg_na_compl: 0
fraction_distil_tf: 0
fraction_tf: 0
fraction_neg_tf: 0
fraction_rna: 0
fraction_dna: 0
fraction_sm_compl: 1
fraction_metal_compl: 0
fraction_sm_compl_multi: 0
fraction_sm_compl_covale: 0
fraction_sm: 0
fraction_atomize_pdb: 0
fraction_atomize_complex: 0
fraction_sm_compl_asmb: 0
fraction_sm_compl_furthest_neg: 0
fraction_sm_compl_permuted_neg: 0
fraction_sm_compl_docked_neg: 0
n_train: 256000
n_valid_pdb: 0
n_valid_homo: 0
n_valid_dslf: 0
n_valid_compl: 0
n_valid_neg_compl: 0
n_valid_na_compl: 0
n_valid_neg_na_compl: 0
n_valid_distil_tf: 0
n_valid_tf: 0
n_valid_neg_tf: 0
n_valid_rna: 0
n_valid_dna: 0
n_valid_sm_compl: 0
n_valid_metal_compl: 0
n_valid_sm_compl_multi: 0
n_valid_sm_compl_covale: 0
n_valid_sm_compl_strict: 0
n_valid_sm: 0
n_valid_atomize_pdb: 0
n_valid_atomize_complex: 0
n_valid_sm_compl_asmb: 0
n_valid_sm_compl_furthest_neg: 0
n_valid_sm_compl_permuted_neg: 0
n_valid_sm_compl_docked_neg: 0
n_valid_dude_actives: 0
n_valid_dude_inactives: 0
p_short_crop: 0
p_dslf_crop: 0
dslf_fb_upsample: 1
loader_params:
maxseq: 1024
maxtoken: 1024
maxlat: 128
crop: 256
rescut: 4.5
mintplt: 1
maxtplt: 4
seqid: 150.0
p_msa_mask: 0.15
maxcycle: 4
nres_atomize_min: 3
nres_atomize_max: 5
atomize_flank: 0
p_metal: 1
p_atomize_modres: 1
batch_by_dataset: True
batch_by_length: True
dataloader_kwargs:
shuffle: False
num_workers: 0
pin_memory: True
ddp_params:
accum: 1
batch_size: 1
port: 12435
training_params:
resume_train: False
EMA: 0.99
weight_decay: 0.01
learning_rate: .001
learning_rate_schedule:
num_warmup_steps: 0
num_steps_decay: 5000
decay_rate: 0.95
grad_clip: 0.2
use_amp: False
loss_param:
w_dist: 1.0
w_str: 1.0
w_inter_fape: 0.0
w_lig_fape: 0.0
w_lddt: 0.1
w_aa: 3.0
w_bond: 0.0
w_bind: 0.0
binder_loss_label_smoothing: 0.0
w_clash: 0.0
w_atom_bond: 0.0
w_skip_bond: 0.0
w_rigid: 0.0
w_hb: 0.0
w_pae: 0.05
w_pde: 0.05
lj_lin: 0.75
log_params:
log_every_n_examples: 1
eval_params: null

View File

@@ -0,0 +1,61 @@
defaults:
- base
experiment:
name: frank_model-10rec
output_dir: null
legacy_model_param:
n_extra_block: 4
n_main_block: 32
n_ref_block: 4
n_finetune_block: 0
d_msa: 256
d_msa_full: 64
d_pair: 192
d_templ: 64
n_head_msa: 8
n_head_pair: 6
n_head_templ: 4
d_hidden_templ: 64
p_drop: 0.0
use_chiral_l1: True
use_lj_l1: True
use_atom_frames: True
recycling_type: "all"
use_same_chain: True
lj_lin: 0.75
SE3_param:
num_layers: 1
num_channels: 32
num_degrees: 2
l0_in_features: 64
l0_out_features: 64
l1_in_features: 3
l1_out_features: 2
num_edge_features: 64
n_heads: 4
div: 4
SE3_ref_param:
num_layers: 2
num_channels: 32
num_degrees: 2
l0_in_features: 64
l0_out_features: 64
l1_in_features: 3
l1_out_features: 2
num_edge_features: 64
n_heads: 4
div: 4
ddp_params:
port: 12375
batch_size: 1
loader_params:
crop: 800
maxcycle: 10
maxlat: 256
eval_params:
checkpoint_path: "/home/dimaio/RF2-allatom/rf2aa/models/RF2_26a_last.pt"

View File

@@ -0,0 +1,119 @@
defaults:
- base
experiment:
name: rf2aa-reproduction
trainer: "composed"
model:
global_params:
d_msa: 256
d_msa_full: 64
d_pair: 192
d_state: 64
embedding:
rf2aa:
params:
p_drop: 0.15
d_templ: 64
n_head_templ: 4
d_hidden_templ: 64
templ_p_drop: 0.25
symmetrize_repeats: False
repeat_length: null
symmsub_k: null
sym_method: null
main_block: null
copy_main_block_template: False
additional_dt1d: 0
recycling_type: "msa_pair"
use_same_chain: False
blocks:
RF2aa_full:
num_blocks: 4
params:
d_rbf: 64
p_drop_row: 0.25
p_drop_pair: 0.25
msa_transition_drop: 0.0
outer_product_channels: 16
p_drop_outer_product: 0.0
n_pair_head: 6
n_pair_channels: 32
n_msa_head: 8
n_msa_channels: 8
structure_bias_gate_channels: 16
structure_bias_channels: 64
n_se3_layers: 1
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 6
l1_out_features: 2
n_se3_edge_features: 64
sc_pred_d_hidden: 128
sc_pred_p_drop: 0.0
RF2aa:
num_blocks: 32
params:
d_rbf: 64
p_drop_row: 0.25
p_drop_pair: 0.25
msa_transition_drop: 0.0
outer_product_channels: 16
p_drop_outer_product: 0.0
n_pair_head: 6
n_pair_channels: 32
n_msa_head: 8
n_msa_channels: 32
structure_bias_gate_channels: 16
structure_bias_channels: 64
n_se3_layers: 1
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 6
l1_out_features: 2
n_se3_edge_features: 64
sc_pred_d_hidden: 128
sc_pred_p_drop: 0.0
refinement:
local:
params:
num_iterations: 4
d_rbf: 64
n_se3_layers: 2
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 6
l1_out_features: 2
n_se3_edge_features: 64
top_k: 64
sc_pred_d_hidden: 128
sc_pred_p_drop: 0.15
auxiliary_predictors:
c6d:
n_feat: 192
input_feature: "pair"
mlm:
n_feat: 256
input_feature: "msa"
pae:
n_feat: 192
input_feature: "pair"
plddt:
n_feat: 64
input_feature: "state"

View File

@@ -0,0 +1,111 @@
defaults:
- base
experiment:
name: rf2a-fd4-20240116
trainer: "composed"
model:
global_params:
d_msa: 256
d_msa_full: 64
d_pair: 192
d_state: 32
embedding:
rf2aa_nostate:
params:
p_drop: 0.15
d_templ: 64
n_head_templ: 4
d_hidden_templ: 64
templ_p_drop: 0.25
recycling_type: "msa_pair"
use_same_chain: False
blocks:
RF2_withgradients_full:
num_blocks: 4
params:
n_msa_head: 8
n_msa_channels: 64
p_drop_row: 0.25
p_drop_col: 0.25
structure_bias_type: "ungated"
structure_bias_channels: 64
n_pair_head: 6
n_pair_channels: 192
outer_product_channels: 16
intermediate_seq_channels: 192 #this is from AF2
n_se3_layers: 1
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 32
l1_in_features: 1
l1_out_features: 1
n_se3_edge_features: 64
RF2_withgradients:
num_blocks: 36
params:
n_msa_head: 8
n_msa_channels: 256
p_drop_row: 0.25
p_drop_col: 0.25
structure_bias_type: "ungated"
structure_bias_channels: 64
n_pair_head: 6
n_pair_channels: 192
outer_product_channels: 16
intermediate_seq_channels: 192 #this is from AF2
n_se3_layers: 1
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 32
l1_in_features: 1
l1_out_features: 1
n_se3_edge_features: 64
refinement:
local:
params:
num_iterations: 4
d_rbf: 64
n_se3_layers: 2
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 6
l1_out_features: 2
n_se3_edge_features: 64
top_k: 64
sc_pred_d_hidden: 128
sc_pred_p_drop: 0.15
auxiliary_predictors:
c6d:
n_feat: 192
input_feature: "pair"
mlm:
n_feat: 256
input_feature: "msa"
pae:
n_feat: 192
input_feature: "pair"
plddt:
n_feat: 64
input_feature: "state"
dataset_params:
fraction_sm_compl: 1
n_train: 25600
training_params:
resume_train: False

View File

@@ -0,0 +1,365 @@
import numpy as np
import pickle
import torch.utils.data as data
from collections import OrderedDict
from rf2aa.data.data_loader import get_train_valid_set, loader_pdb, loader_complex, loader_na_complex, \
loader_distil_tf, loader_tf_complex, loader_fb, loader_dna_rna, loader_sm_compl_assembly_single, \
loader_sm_compl_assembly, loader_sm, loader_atomize_pdb, loader_atomize_complex, DistilledDataset, \
DistributedWeightedSampler, Dataset, DatasetRNA, DatasetNAComplex, DatasetNAComplex, \
DatasetSMComplexAssembly, DatasetSM, _load_df
#### handle defaults
#TODO: shouldn't have to do this in the future, should all be handled in config
base_dir = "/projects/ml/TrRosetta/PDB-2021AUG02"
compl_dir = "/projects/ml/RoseTTAComplex"
na_dir = "/projects/ml/nucleic"
fb_dir = "/projects/ml/TrRosetta/fb_af"
sm_compl_dir = "/projects/ml/RF2_allatom"
mol_dir = "/projects/ml/RF2_allatom/rcsb/pkl" # for phase 3 dataloaders
# mol_dir = "/projects/ml/RF2_allatom/isdf" # for legacy datasets
tf_dir = "/projects/ml/prot_dna"
csd_dir = "/databases/csd543"
sample_lengths_dir = "/projects/ml/RF2_allatom"
default_dataloader_params = {
"COMPL_LIST" : "%s/list.hetero.csv"%compl_dir,
"HOMO_LIST" : "%s/list.homo.csv"%compl_dir,
"NEGATIVE_LIST" : "%s/list.negative.csv"%compl_dir,
"RNA_LIST" : "%s/list.rnaonly.csv"%na_dir,
"DNA_LIST" : "%s/list.dnaonly.v3.csv"%na_dir,
"NA_COMPL_LIST" : "%s/list.nucleic.v3.csv"%na_dir,
"NEG_NA_COMPL_LIST": "%s/list.na_negatives.v3.csv"%na_dir,
"TF_DISTIL_LIST" : "%s/prot_na_distill.v3.csv"%tf_dir,
"TF_COMPL_LIST" : "%s/tf_compl_list.v4.csv"%tf_dir,
"SM_LIST" : "%s/sm_compl_all_20230418.csv"%sm_compl_dir,
"PDB_LIST" : "%s/list_v02_w_taxid.csv"%sm_compl_dir, # on digs
"PDB_METADATA" : "%s/list_v00_w_taxid_20230201.csv"%sm_compl_dir, # on digs
"FB_LIST" : "%s/list_b1-3.csv"%fb_dir,
"CSD_LIST" : "%s/csd543_cleaned01.csv"%csd_dir,
"VAL_PDB" : "%s/valid_remapped"%sm_compl_dir,
"VAL_RNA" : "%s/rna_valid.csv"%na_dir,
"VAL_DNA" : "%s/dna_valid.csv"%na_dir,
"VAL_COMPL" : "%s/val_lists/xaa"%compl_dir,
"VAL_NEG" : "%s/val_lists/xaa.neg"%compl_dir,
"VAL_TF" : "%s/tf_valid_clusters_v4.txt"%tf_dir,
"VAL_SM_STRICT" : "%s/sm_compl_valid_strict_20230418.csv"%sm_compl_dir,
"TEST_SM" : "%s/sm_test_heldout_test_clusters.txt"%sm_compl_dir,
"DATAPKL" : "%s/dataset_20231116.pkl"%sm_compl_dir, # cache for faster loading
"DSLF_LIST" : "%s/list.dslf.csv"%na_dir,
"DSLF_FB_LIST" : "%s/list.dslf_fb.csv"%na_dir,
"DUDE_LIST" : "/home/dnan/projects/gald_distil_set/nbs/dude_dataset_cutoff_-5.csv", # on digs (dnan)
"DUDE_MSAS" : "/home/dnan/projects/gald_distil_set/DUDE/fastas", # on digs (dnan)
"DUDE_PDB_DIR" : "/home/dnan/projects/gald_distil_set/DUDE/pdbs_all",
"EXAMPLE_LENGTHS" : "%s/all_sample_lengths_crop_1K.pt"%sample_lengths_dir,
"PDB_DIR" : base_dir,
"FB_DIR" : fb_dir,
"COMPL_DIR" : compl_dir,
"NA_DIR" : na_dir,
"TF_DIR" : tf_dir,
"MOL_DIR" : mol_dir,
"CSD_DIR" : csd_dir,
"MINTPLT" : 0,
"MAXTPLT" : 5,
"MINSEQ" : 1,
"MAXSEQ" : 1024,
"MAXLAT" : 128,
"CROP" : 256,
"DATCUT" : "2021-Aug-1",
"RESCUT" : 4.5,
"BLOCKCUT" : 5,
"PLDDTCUT" : 70.0,
"SCCUT" : 90.0,
"ROWS" : 1,
"SEQID" : 95.0,
"MAXCYCLE" : 4,
"RMAX" : 5.0,
"MAXRES" : 1,
"MINATOMS" : 5,
"MAXATOMS" : 100,
"MAXSIM" : 0.85,
"MAXNSYMM" : 1024,
"NRES_ATOMIZE_MIN" : 5,
"NRES_ATOMIZE_MAX" : 15,
"ATOMIZE_FLANK" : 0,
"MAXPROTCHAINS" : 6,
"MAXLIGCHAINS" : 10,
"MAXMASKEDLIGATOMS": 30,
"P_METAL" : 0.75,
"P_ATOMIZE_MODRES" : 0.75,
"MAXMONOMERLENGTH" : None,
"ATOMIZE_CLUSTER" : True,
"P_ATOMIZE_TEMPLATE": 0.0,
"NUM_SEQS_SUBSAMPLE": 50,
"BLACK_HOLE_INIT" : False,
"SHOW_SM_TEMPLATES" : False,
"BATCH_BY_DATASET" : False,
"BATCH_BY_LENGTH" : False,
}
def set_data_loader_params(loader_params):
""" add things from config into default dataloader params """
for param in default_dataloader_params:
if hasattr(loader_params, param.lower()):
default_dataloader_params[param] = getattr(loader_params, param.lower())
# cursed but add things in the param but not the default params back
loader_params_dict = dict(loader_params)
for param, value in loader_params_dict.items():
if param.upper() not in default_dataloader_params:
default_dataloader_params[param] = value
return default_dataloader_params
def compose_dataset(dataset_params, loader_params, rank, world_size):
# define dataset & data loader
# this function overrides the default dataloader params with those in the config
#TODO: cache this in checkpoints so checkpoints use the same dataloder params as training
loader_params = set_data_loader_params(loader_params=loader_params)
train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict, homo, chid2hash, chid2taxid, chid2smpartners = \
get_train_valid_set(loader_params)
# define atomize_pdb train/valid sets, which use the same examples as pdb set
train_ID_dict['atomize_pdb'] = train_ID_dict['pdb']
valid_ID_dict['atomize_pdb'] = valid_ID_dict['pdb']
weights_dict['atomize_pdb'] = weights_dict['pdb']
train_dict['atomize_pdb'] = train_dict['pdb']
valid_dict['atomize_pdb'] = valid_dict['pdb']
# define atomize_pdb train/valid sets, which use the same examples as pdb set
train_ID_dict['atomize_complex'] = train_ID_dict['compl']
valid_ID_dict['atomize_complex'] = valid_ID_dict['compl']
weights_dict['atomize_complex'] = weights_dict['compl']
train_dict['atomize_complex'] = train_dict['compl']
valid_dict['atomize_complex'] = valid_dict['compl']
# reweight fb examples containing disulfide loops
to_reweight_ex = train_dict['fb']['HAS_DSLF_LOOP']
to_reweight_cluster = train_dict['fb'][to_reweight_ex].CLUSTER.unique()
reweight_mask = np.in1d(train_ID_dict['fb'],to_reweight_cluster)
weights_dict['fb'][ reweight_mask ] *= dataset_params['dslf_fb_upsample']
# set number of validation examples being used
for k in valid_dict:
if dataset_params['n_valid_'+k] is None:
dataset_params["n_valid_"+k] = len(valid_dict[k])
loader_dict = dict(
pdb = loader_pdb,
peptide = loader_pdb,
compl = loader_complex,
neg_compl = loader_complex,
na_compl = loader_na_complex,
neg_na_compl = loader_na_complex,
distil_tf = loader_distil_tf,
tf = loader_tf_complex,
neg_tf = loader_tf_complex,
fb = loader_fb,
rna = loader_dna_rna,
dna = loader_dna_rna,
sm_compl = loader_sm_compl_assembly_single,
metal_compl = loader_sm_compl_assembly_single,
sm_compl_multi = loader_sm_compl_assembly_single,
sm_compl_covale = loader_sm_compl_assembly_single,
sm_compl_asmb = loader_sm_compl_assembly,
sm = loader_sm,
atomize_pdb = loader_atomize_pdb,
atomize_complex = loader_atomize_complex,
sm_compl_furthest_neg = loader_sm_compl_assembly,
sm_compl_permuted_neg = loader_sm_compl_assembly,
sm_compl_docked_neg = loader_sm_compl_assembly,
)
train_set = DistilledDataset(
train_ID_dict, train_dict, loader_dict, homo, chid2hash, chid2taxid, chid2smpartners,
loader_params, native_NA_frac=0.25,
p_short_crop=dataset_params['p_short_crop'],
p_dslf_crop=dataset_params['p_dslf_crop'])
train_sampler = DistributedWeightedSampler(
train_set,
weights_dict,
num_example_per_epoch=dataset_params['n_train'],
fractions=OrderedDict([(k, dataset_params['fraction_'+k]) for k in train_dict]),
num_replicas=world_size,
rank=rank,
lengths=loader_params["EXAMPLE_LENGTHS"],
batch_by_dataset=loader_params["BATCH_BY_DATASET"],
batch_by_length=loader_params["BATCH_BY_LENGTH"],
)
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=1, **loader_params["dataloader_kwargs"])
valid_sets = dict(
atomize_pdb = Dataset(
valid_ID_dict['atomize_pdb'][:dataset_params['n_valid_atomize_pdb']],
loader_atomize_pdb, valid_dict['atomize_pdb'],
loader_params, homo, p_homo_cut=-1.0, n_res_atomize=9, flank=0, p_short_crop=-1.0
),
atomize_complex = Dataset(
valid_ID_dict['atomize_complex'][:dataset_params['n_valid_atomize_complex']],
loader_atomize_complex, valid_dict['atomize_complex'],
loader_params, homo, p_homo_cut=-1.0, n_res_atomize=9, flank=0, p_short_crop=-1.0
),
pdb = Dataset(
valid_ID_dict['pdb'][:dataset_params['n_valid_pdb']],
loader_pdb, valid_dict['pdb'],
loader_params, homo, p_homo_cut=-1.0, p_short_crop=-1.0, p_dslf_crop=-1.0
),
dslf = Dataset(
valid_ID_dict['dslf'][:dataset_params['n_valid_dslf']],
loader_pdb, valid_dict['dslf'],
loader_params, homo, p_homo_cut=-1.0, p_short_crop=-1.0, p_dslf_crop=1.0
),
homo = Dataset(
valid_ID_dict['homo'][:dataset_params['n_valid_homo']],
loader_pdb, valid_dict['homo'],
loader_params, homo, p_homo_cut=1.0, p_short_crop=-1.0, p_dslf_crop=-1.0
),
rna = DatasetRNA(
valid_ID_dict['rna'][:dataset_params['n_valid_rna']],
loader_dna_rna, valid_dict['rna'],
loader_params,
),
dna = DatasetRNA(
valid_ID_dict['dna'][:dataset_params['n_valid_dna']],
loader_dna_rna, valid_dict['dna'],
loader_params,
),
distil_tf = DatasetNAComplex(
valid_ID_dict['distil_tf'][:dataset_params['n_valid_distil_tf']],
loader_distil_tf, valid_dict['distil_tf'],
loader_params, negative=False, native_NA_frac=0.0
),
metal_compl = DatasetSMComplexAssembly(
valid_ID_dict['metal_compl'][:dataset_params['n_valid_metal_compl']],
loader_sm_compl_assembly, valid_dict['metal_compl'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='metal_compl',
num_protein_chains=1,
),
sm_compl = DatasetSMComplexAssembly(
valid_ID_dict['sm_compl'][:dataset_params['n_valid_sm_compl']],
loader_sm_compl_assembly, valid_dict['sm_compl'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl',
num_protein_chains=1,
),
sm_compl_multi = DatasetSMComplexAssembly(
valid_ID_dict['sm_compl_multi'][:dataset_params['n_valid_sm_compl_multi']],
loader_sm_compl_assembly, valid_dict['sm_compl_multi'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl_multi',
num_protein_chains=1,
),
sm_compl_covale = DatasetSMComplexAssembly(
valid_ID_dict['sm_compl_covale'][:dataset_params['n_valid_sm_compl_covale']],
loader_sm_compl_assembly, valid_dict['sm_compl_covale'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl_covale',
num_protein_chains=1,
),
sm_compl_strict = DatasetSMComplexAssembly(
valid_ID_dict['sm_compl_strict'][:dataset_params['n_valid_sm_compl_strict']],
loader_sm_compl_assembly, valid_dict['sm_compl_strict'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl_strict',
num_protein_chains=1,
),
sm_compl_asmb = DatasetSMComplexAssembly(
valid_ID_dict['sm_compl_asmb'][:dataset_params['n_valid_sm_compl_asmb']],
loader_sm_compl_assembly, valid_dict['sm_compl_asmb'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl_asmb'
),
sm = DatasetSM(
valid_ID_dict['sm'][:dataset_params['n_valid_sm']],
loader_sm, valid_dict['sm'],
loader_params,
),
)
valid_headers = dict(
distil_tf = 'TF_Distil',
pdb = 'Monomer',
dslf = 'Disulfide_loop',
homo = 'Homo',
rna = 'RNA',
dna = 'DNA',
sm_compl = 'SM_Compl',
metal_compl = 'Metal_ion',
sm_compl_multi = 'Multires_ligand',
sm_compl_covale = "Covalent_ligand",
sm_compl_strict = 'SM_Compl_(strict)',
sm = 'SM_CSD',
atomize_pdb = 'Monomer_atomize',
atomize_complex = 'Complex_atomize',
sm_compl_asmb = 'SMCompl_Assembly',
)
valid_samplers = OrderedDict([
(k, data.distributed.DistributedSampler(v, num_replicas=world_size, rank=rank))
for k,v in valid_sets.items()
])
valid_loaders = OrderedDict([
(k, data.DataLoader(v, sampler=valid_samplers[k], **loader_params["dataloader_kwargs"]))
for k,v in valid_sets.items()
])
return train_loader, train_sampler, valid_loaders, valid_samplers
def compose_posebusters(loader_params, rank, world_size):
loader_params = set_data_loader_params(loader_params=loader_params)
valid_ID_dict, valid_dict = {}, {}
valid_dict["benchmark"] = _load_df("/home/rohith/RF2ligand/posebusters_benchmark.csv", pad_hash=False, eval_cols=["LIGAND", "PARTNERS", "LIGXF"])
valid_ID_dict["benchmark"] = valid_dict["benchmark"]["CLUSTER"]
with open(
"/projects/ml/RF2_allatom/posebusters/posebusters_chid2hash_081723.pkl", "rb"
) as f:
chid2hash = pickle.load(f)
with open(
"/projects/ml/RF2_allatom/posebusters/posebusters_chid2taxid_081723.pkl", "rb"
) as f:
chid2taxid = pickle.load(f)
loader_params["MINTPLT"] = 0
loader_params["MAXTPLT"] = 0
loader_params["PDB_DIR"] = "/projects/ml/RF2_allatom/benchmark"
benchmark = DatasetSMComplexAssembly(
valid_ID_dict['benchmark'],
loader_sm_compl_assembly, valid_dict['benchmark'],
chid2hash, chid2taxid, # used for MSA generation of assemblies
loader_params,
task='sm_compl',
num_protein_chains=1,
num_ligand_chains=2,
)
sampler = data.distributed.DistributedSampler(benchmark, rank=rank, num_replicas=world_size)
loader = data.DataLoader(benchmark, sampler=sampler, **loader_params["dataloader_kwargs"])
return loader
def compose_single_item_dataset(item, loader_params, loader, loader_kwargs):
class SpoofDataset(data.Dataset):
def __init__(self, loader_params, loader, loader_kwargs) -> None:
super().__init__()
self.loader_params = loader_params
self.loader = loader
self.loader_kwargs = loader_kwargs
def __getitem__(self, idx):
return self.loader(item, self.loader_params, **self.loader_kwargs)
def __len__(self):
return 1
dataset = SpoofDataset(loader_params, loader, loader_kwargs)
loader = data.DataLoader(dataset, **loader_params["dataloader_kwargs"])
return loader

View File

@@ -1,7 +1,6 @@
import torch
import warnings
import time
import deepdiff
from icecream import ic
from torch.utils import data
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
@@ -19,121 +18,27 @@ sys.path.append(script_dir+'/../')
import numpy as np
import pandas as pd
import torch
from torch.utils import data
import scipy
from scipy.sparse.csgraph import shortest_path
import networkx as nx
from rf2aa import cifutils
from rf2aa.parsers import parse_a3m, parse_pdb, parse_fasta_if_exists, parse_mol, parse_mixed_fasta, get_dislf
import rf2aa.cifutils as cifutils
from rf2aa.data.parsers import parse_a3m, parse_pdb, parse_fasta_if_exists, parse_mol, parse_mixed_fasta, get_dislf
from rf2aa.chemical import INIT_CRDS, INIT_NA_CRDS, NAATOKENS, MASKINDEX, UNKINDEX, \
NTOTAL, NBTYPES, CHAIN_GAP, num2aa, METAL_RES_NAMES, aa2num, atomnum2atomtype, load_tanimoto_sim_matrix
NTOTAL, NBTYPES, CHAIN_GAP, num2aa, METAL_RES_NAMES, aa2num, atomnum2atomtype, load_tanimoto_sim_matrix, NPROTAAS
from rf2aa.kinematics import get_chirals
from rf2aa.symmetry import get_symmetry
from rf2aa.identical_ligands import get_extra_identical_copies_from_chains
from rf2aa.data.identical_ligands import get_extra_identical_copies_from_chains
from rf2aa.util import get_nxgraph, get_atom_frames, get_bond_feats, get_protein_bond_feats, \
center_and_realign_missing, random_rot_trans, allatom_mask, cif_prot_to_xyz, \
cif_ligand_to_xyz, cif_ligand_to_obmol, get_automorphs, get_ligand_atoms_bonds, \
map_identical_prot_chains, cartprodcat, idx_from_Ls, same_chain_2d_from_Ls, bond_feats_from_Ls, \
reindex_protein_feats_after_atomize, get_residue_contacts, atomize_discontiguous_residues, pop_protein_feats, \
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, is_nucleic
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, \
is_protein, is_nucleic, is_atom
# faster for remote/tukwila nodes
#base_dir = "/databases/TrRosetta/PDB-2021AUG02"
#compl_dir = "/databases/TrRosetta/RoseTTAComplex"
#na_dir = "/databases/TrRosetta/nucleic"
#sm_compl_dir = "/databases/TrRosetta/RF2_allatom"
#mol_dir = "/databases/TrRosetta/RF2_allatom/by-pdb"
csd_dir = "/databases/csd543"
# older paths, still good but best for local/UW nodes
base_dir = "/projects/ml/TrRosetta/PDB-2021AUG02"
compl_dir = "/projects/ml/RoseTTAComplex"
na_dir = "/projects/ml/nucleic"
fb_dir = "/projects/ml/TrRosetta/fb_af"
sm_compl_dir = "/projects/ml/RF2_allatom"
mol_dir = "/projects/ml/RF2_allatom/rcsb/pkl" # for phase 3 dataloaders
# mol_dir = "/projects/ml/RF2_allatom/isdf" # for legacy datasets
tf_dir = "/projects/ml/prot_dna"
default_dataloader_params = {
"COMPL_LIST" : "%s/list.hetero.csv"%compl_dir,
"HOMO_LIST" : "%s/list.homo.csv"%compl_dir,
"NEGATIVE_LIST" : "%s/list.negative.csv"%compl_dir,
"RNA_LIST" : "%s/list.rnaonly.csv"%na_dir,
"DNA_LIST" : "%s/list.dnaonly.v3.csv"%na_dir,
"NA_COMPL_LIST" : "%s/list.nucleic.v3.csv"%na_dir,
"NEG_NA_COMPL_LIST": "%s/list.na_negatives.v3.csv"%na_dir,
"TF_DISTIL_LIST" : "%s/prot_na_distill.v3.csv"%tf_dir,
"TF_COMPL_LIST" : "%s/tf_compl_list.v4.csv"%tf_dir,
"SM_LIST" : "%s/sm_compl_all_20230418.csv"%sm_compl_dir,
"PDB_LIST" : "%s/list_v02_w_taxid.csv"%sm_compl_dir, # on digs
"PDB_METADATA" : "%s/list_v00_w_taxid_20230201.csv"%sm_compl_dir, # on digs
"FB_LIST" : "%s/list_b1-3.csv"%fb_dir,
"CSD_LIST" : "%s/csd543_cleaned01.csv"%csd_dir,
"VAL_PDB" : "%s/valid_remapped"%sm_compl_dir,
"VAL_RNA" : "%s/rna_valid.csv"%na_dir,
"VAL_DNA" : "%s/dna_valid.csv"%na_dir,
"VAL_COMPL" : "%s/val_lists/xaa"%compl_dir,
"VAL_NEG" : "%s/val_lists/xaa.neg"%compl_dir,
"VAL_TF" : "%s/tf_valid_clusters_v4.txt"%tf_dir,
"VAL_SM_STRICT" : "%s/sm_compl_valid_strict_20230418.csv"%sm_compl_dir,
"TEST_SM" : "%s/sm_test_heldout_test_clusters.txt"%sm_compl_dir,
"DATAPKL" : "%s/dataset_20231116.pkl"%sm_compl_dir, # cache for faster loading
"DSLF_LIST" : "%s/list.dslf.csv"%na_dir,
"DSLF_FB_LIST" : "%s/list.dslf_fb.csv"%na_dir,
"DUDE_LIST" : "/home/dnan/projects/gald_distil_set/nbs/dude_dataset_cutoff_-5.csv", # on digs (dnan)
"DUDE_MSAS" : "/home/dnan/projects/gald_distil_set/DUDE/fastas", # on digs (dnan)
"DUDE_PDB_DIR" : "/home/dnan/projects/gald_distil_set/DUDE/pdbs_all",
"PDB_DIR" : base_dir,
"FB_DIR" : fb_dir,
"COMPL_DIR" : compl_dir,
"NA_DIR" : na_dir,
"TF_DIR" : tf_dir,
"MOL_DIR" : mol_dir,
"CSD_DIR" : csd_dir,
"MINTPLT" : 0,
"MAXTPLT" : 5,
"MINSEQ" : 1,
"MAXSEQ" : 1024,
"MAXLAT" : 128,
"CROP" : 256,
"DATCUT" : "2021-Aug-1",
"RESCUT" : 4.5,
"BLOCKCUT" : 5,
"PLDDTCUT" : 70.0,
"SCCUT" : 90.0,
"ROWS" : 1,
"SEQID" : 95.0,
"MAXCYCLE" : 4,
"RMAX" : 5.0,
"MAXRES" : 1,
"MINATOMS" : 5,
"MAXATOMS" : 100,
"MAXSIM" : 0.85,
"MAXNSYMM" : 1024,
"NRES_ATOMIZE_MIN" : 5,
"NRES_ATOMIZE_MAX" : 15,
"ATOMIZE_FLANK" : 0,
"MAXPROTCHAINS" : 6,
"MAXLIGCHAINS" : 10,
"MAXMASKEDLIGATOMS": 30,
"P_METAL" : 0.75,
"P_ATOMIZE_MODRES" : 0.75,
"MAXMONOMERLENGTH" : None,
"ATOMIZE_CLUSTER" : True,
"P_ATOMIZE_TEMPLATE": 1.0,
"NUM_SEQS_SUBSAMPLE": 50,
"BLACK_HOLE_INIT" : False,
"SHOW_SM_TEMPLATES" : True,
}
def set_data_loader_params(args):
for param in default_dataloader_params:
if hasattr(args, param.lower()):
default_dataloader_params[param] = getattr(args, param.lower())
return default_dataloader_params
assert "rf2aa" in os.path.abspath(cifutils.__file__)
def MSABlockDeletion(msa, ins, nb=5):
'''
@@ -197,6 +102,12 @@ def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
- insertion info (1)
- N-term or C-term? (2)
'''
# Truncate MSA (for efficiency when pre-computing lengths)
if params.get("MSA_LIMIT") is not None:
# Raise a warning that we are truncating the MSA
warnings.warn(f"Truncating MSA to {params['MSA_LIMIT']} sequences. Only to be used for length pre-computation, NOT training.")
msa = msa[:params["MSA_LIMIT"]]
if fixbb:
p_mask = 0
msa = msa[:1]
@@ -258,18 +169,28 @@ def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
# - 10%: aa replaced with an amino acid sampled from the MSA profile
# - 10%: not replaced
# - 70%: replaced with a special token ("mask")
seq = msa_clust[0]
random_aa = torch.tensor([[0.05]*20 + [0.0]*(NAATOKENS-20)], device=msa.device)
same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=NAATOKENS)
# explicitly remove probabilities from nucleic acids and atoms
same_aa[..., NPROTAAS:] = 0
raw_profile[...,NPROTAAS:] = 0
probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
#probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)
probs[...,MASKINDEX]=0.7
# explicitly set the probability of masking for nucleic acids and atoms
probs[...,is_protein(seq),MASKINDEX]=0.7
probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
probs[0,is_nucleic(seq), MASKINDEX] = 1.0
probs[0,is_atom(seq), aa2num["ATM"]] = 1.0
sampler = torch.distributions.categorical.Categorical(probs=probs)
mask_sample = sampler.sample()
mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
mask_pos[msa_clust>MASKINDEX]=False # no masking on NAs
#mask_pos[msa_clust>MASKINDEX]=False # no masking on NAs
use_seq = msa_clust
msa_masked = torch.where(mask_pos, mask_sample, use_seq)
b_seq.append(msa_masked[0].clone())
@@ -369,7 +290,6 @@ def blank_template(n_tmpl, L, random_noise=5.0):
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5):
seqID_cut = params['SEQID']
if npick_global == None:
@@ -764,6 +684,15 @@ def add_negative_sets(
valid_ID_dict["dude_inactives"] = dude_inactives_df["CLUSTER"].unique()
return train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict
def _load_df(filename, pad_hash=True, eval_cols=[]):
"""load dataframe, zero-pad hash string, parse columns as python objects"""
df = pd.read_csv(filename, na_filter=False) # prevents chain "NA" loading as NaN
if pad_hash: # restore leading zeros, make into string
df['HASH'] = df['HASH'].apply(lambda x: f'{x:06d}')
for col in eval_cols:
df[col] = df[col].apply(lambda x: ast.literal_eval(x)) # interpret as list of strings
return df
def get_train_valid_set(params, NEG_CLUSID_OFFSET=1000000, no_match_okay=False, diffusion_training=False, add_negatives: bool = True):
"""Loads training/validation sets as pandas DataFrames and returns them in
@@ -814,15 +743,6 @@ def get_train_valid_set(params, NEG_CLUSID_OFFSET=1000000, no_match_okay=False,
f"re-parsing train/valid metadata...")
# helper functions
def _load_df(filename, pad_hash=True, eval_cols=[]):
"""load dataframe, zero-pad hash string, parse columns as python objects"""
df = pd.read_csv(filename, na_filter=False) # prevents chain "NA" loading as NaN
if pad_hash: # restore leading zeros, make into string
df['HASH'] = df['HASH'].apply(lambda x: f'{x:06d}')
for col in eval_cols:
df[col] = df[col].apply(lambda x: ast.literal_eval(x)) # interpret as list of strings
return df
def _apply_date_res_cutoffs(df):
"""filter dataframe by date and resolution cutoffs"""
return df[(df.RESOLUTION <= params['RESCUT']) &
@@ -1359,6 +1279,22 @@ def get_na_crop(seq, xyz, mask, sel, len_s, params, negative=False, incl_protein
return sel
def adjust_samples_for_num_replicas(num_per_epoch_dict, num_replicas):
"""
Modifies the number of examples per epoch for each dataset to be divisible by num_replicas.
Args:
num_per_epoch_dict (dict): Mapping from dataset name to number of examples to sample from that dataset per epoch
num_replicas (int): The number of nodes/GPUs in the distributed training setup.
Returns:
dict: The modified num_per_epoch_dict where the number of examples per epoch for each dataset is divisible by num_replicas.
"""
adjusted_dict = {}
for dataset, num_per_epoch in num_per_epoch_dict.items():
# Round down to nearest multiple of num_replicas
adjusted_num = num_per_epoch // num_replicas * num_replicas
adjusted_dict[dataset] = adjusted_num
return adjusted_dict
def find_msa_hashes(protein_chain_info, params):
"""
@@ -1606,52 +1542,6 @@ def remove_all_gap_seqs(a3m):
a3m['ins'] = a3m['ins'][idx_seq_keep]
return a3m
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
ID, only the sequence with the highest sequence identity to the query (1st
sequence in MSA) will be paired.
Sequences that aren't paired will be padded and added to the bottom of the
joined MSA. If a subregion of the input MSAs overlap (represent the same
chain), the subregion residue indices can be given as `idx_overlap`, and
the overlap region of the unpaired sequences will be included in the joined
MSA.
Parameters
----------
a3mA : dict
First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
boolean tensor indicating whether each sequence is fully paired. Can be
a multi-MSA (contain >2 sub-MSAs).
a3mB : dict
2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
`is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
idx_overlap : tuple or list (optional)
Start and end indices of overlap region in 1st MSA, followed by the
same in 2nd MSA.
Returns
-------
a3m : dict
Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
"""
# preprocess & sanity check overlap region
L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
if idx_overlap is not None:
i1A, i2A, i1B, i2B = idx_overlap
i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
"When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
"(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
else:
i1B_new, i2B_new = (0, L_B)
# paired sequences
taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
i_pairedA, i_pairedB = [], []
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
@@ -2718,7 +2608,7 @@ def loader_na_complex(item, params, native_NA_frac=0.05, negative=False, pick_to
torch.load(params['PDB_DIR']+'/torch/pdb/'+pdb_ids[0][1:3]+'/'+pdb_ids[0]+'.pt'),
torch.load(params['PDB_DIR']+'/torch/pdb/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.pt')
]
filenameB1 = params['NA_DIR']+'/torch/'+pdb_ids[2][1:3]+'/'+pdb_ids[2]+'.pt'
filenameB1 = params['NA_DIR'] + '/torch/' + pdb_ids[2][1:3] + '/' + pdb_ids[2] + '.pt'
filenameB2 = params['NA_DIR']+'/torch/'+pdb_ids[3][1:3]+'/'+pdb_ids[3]+'.pt'
if os.path.exists(filenameB1+".v3"):
filenameB1 = filenameB1+".v3"
@@ -3230,7 +3120,7 @@ def loader_distil_tf(item, params, random_noise=5.0, pick_top=True, native_NA_fr
xyz, mask, _, pdbseq = parse_pdb(
params["TF_DIR"]+f'/distill_v2/filtered/{gene_id[:2]}/{gene_id}_{dnaseq}.pdb',
seq=True,
lddtmask=True
lddt_mask=True
)
xyz = torch.from_numpy(xyz)
@@ -3637,9 +3527,10 @@ def featurize_asmb_prot(pdb_id, partners, params, chains, asmb_xfs, modres,
chnum += 1
## protein templates
random_noise = 0.0
ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']+1)
if chid2hash is None or ntempl < 1:
xyz_t_ch, f1d_t_ch, mask_t_ch = \
xyz_t_ch, f1d_t_ch, mask_t_ch, tplt_ids_ch = \
blank_template(n_tmpl=1, L=xyz_ch.shape[1], random_noise=random_noise)
else:
pdb_hash = chid2hash[pdb_id+'_'+list(chlet_set)[0]] # chlet_set all have same hash
@@ -4020,11 +3911,15 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
"""
pdb_chain = item['CHAINID']
pdb_id = pdb_chain.split('_')[0]
# load pre-parsed cif assembly - requires cifutils.py in path for object definitions
chains, asmb, covale, modres = \
out = \
pickle.load(gzip.open(params['MOL_DIR']+f'/{pdb_id[1:3]}/{pdb_id}.pkl.gz'))
if len(out) == 4:
chains, asmb, covale, modres = out
elif len(out) == 5:
chains, asmb, covale, meta, modres = out
else:
raise ValueError(f"cif parser returns {len(out)} values")
# list of proteins and ligands to featurize
prot_partners = [p for p in item['PARTNERS'] if p[-1]=='polypeptide(L)']
prot_partners = prot_partners[:params['MAXPROTCHAINS']]
@@ -4075,6 +3970,7 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
# combine protein & ligand templates
N_tmpl = xyz_t_prot.shape[0]
random_noise = 0.0
if chid2smpartners is not None and params["SHOW_SM_TEMPLATES"]:
assert num_protein_chains == 1, "templating ligands not supported for multiple protein chains (complications in xyz_prev)"
xyz_t_sm, f1d_t_sm, mask_t_sm = generate_sm_template_feats(tplt_ids, resnames, akeys_sm, Ls_sm,chid2smpartners, params)
@@ -4165,7 +4061,8 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
sel = crop_sm_compl(xyz_prot, xyz_sm[0], Ls_prot + Ls_sm, params['CROP'], mask_prot,
seq_prot, select_farthest_residues=select_farthest_residues)
else:
sel = crop_sm_compl_assembly(xyz[0], mask[0], Ls_prot, Ls_sm, params['CROP'])
#sel = crop_sm_compl_assembly(xyz[0], mask[0], Ls_prot, Ls_sm, params['CROP'])
sel = crop_sm_compl_asmb_contig(xyz[0], mask[0], Ls_prot, Ls_sm, bond_feats, params['CROP'], use_partial_ligands=False)
mask = reassign_symmetry_after_cropping(sel, Ls_prot, ch_label, mask, item)
msa = msa[:, sel]
@@ -4202,7 +4099,7 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
if max_msa_seqs is not None:
msa = msa[:max_msa_seqs]
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = \
MSAFeaturize(msa.long(), ins.long(), params, term_info=term_info, fixbb=fixbb, seed_msa_clus=seed_msa_clus)
MSAFeaturize(msa.long(), ins.long(), params, p_mask=params["p_msa_mask"], term_info=term_info, fixbb=fixbb, seed_msa_clus=seed_msa_clus)
return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\
xyz.float(), mask, idx.long(), \
@@ -4693,6 +4590,7 @@ def crop_sm_compl(prot_xyz, lig_xyz, Ls, crop_size, mask_prot, seq_prot,
remaining_residues_to_select = crop_size - len(lig_xyz) - min_resolved_residues
assert remaining_residues_to_select > 0, f"For some reason I encountered a scenario in which we are cropping a protein but I was unable to select enough residues from that protein. This probably means you passed a protein that was too short into this function. {crop_size}, {len(lig_xyz)}, {min_resolved_residues}"
dist[sel_min_resolved] = nan_fill_value
remaining_residues_to_select = min(remaining_residues_to_select, len(dist))
_, sel_remaining = torch.topk(dist, remaining_residues_to_select)
sel_unsorted = torch.concat((sel_min_resolved, sel_remaining))
sel, _ = torch.sort(sel_unsorted)
@@ -4845,6 +4743,121 @@ def crop_sm_compl_assembly(all_xyz, all_mask, Ls_prot, Ls_sm, n_crop, use_partia
sel = prot_sel
return torch.from_numpy(sel).long()
def crop_sm_compl_asmb_contig(all_xyz, all_mask, Ls_prot, Ls_sm, bond_feats, n_crop, use_partial_ligands=False):
"""
instead of conducting a radial crop around a random atom, construct a crop with contiguous protein segments
the way this works is that a graph data structure is constructed where contiguous residues are connected,
close interchain contacts are connected and residues within a ligand are fully connected. each edge is weighted
and then the crop is chosen by selecting a random residue and traversing the graph to find the n_crop closest nodes
"""
def find_edges_based_on_distance(all_xyz, all_mask, chain_i_start_index, chain_i_end_index, chain_j_start_index, chain_j_end_index, dist_cutoff):
xyz_chain_i = all_xyz[chain_i_start_index:chain_i_end_index]
xyz_chain_j = all_xyz[chain_j_start_index:chain_j_end_index]
dist = torch.cdist(xyz_chain_i[:, 1], xyz_chain_j[:, 1]) # calpha distogram
chain_i_ca_mask = all_mask[chain_i_start_index:chain_i_end_index, 1]
chain_j_ca_mask = all_mask[chain_j_start_index:chain_j_end_index, 1]
mask_2d = chain_i_ca_mask[:, None] * chain_j_ca_mask[None, :]
dist[~mask_2d] = 99999
new_edges = (dist<dist_cutoff).nonzero()
return new_edges
L = all_xyz.shape[0]
num_prot_chains = len(Ls_prot)
num_sm_chains = len(Ls_sm)
# construct weighted graph
graph = np.full((L, L), n_crop, dtype=np.float32)
# set neighboring residues to have edge weight = 1
for chain_index, L_prot in enumerate(Ls_prot):
chain_start_index = sum(Ls_prot[:chain_index])
residues = torch.arange(L_prot-1) + chain_start_index
graph[residues, residues+1] = 1
graph[residues+1, residues] = 1
# set all intra ligand chain values to 0 so that if one atom is sampled the whole ligand is sampled (we will still confirm this later)
total_protein_L = sum(Ls_prot)
for chain_index in range(len(Ls_sm)):
chain_start_index = sum(Ls_sm[:chain_index])+ total_protein_L
chain_end_index = sum(Ls_sm[:chain_index+1])+ total_protein_L
graph[chain_start_index: chain_end_index, chain_start_index:chain_end_index] = 0.1
# set interchain edges between protein chains
for chain_i, chain_j in itertools.combinations(range(num_prot_chains), 2):
chain_i_start_index = sum(Ls_prot[:chain_i])
chain_i_end_index = sum(Ls_prot[:chain_i+1])
chain_j_start_index = sum(Ls_prot[:chain_j])
chain_j_end_index = sum(Ls_prot[:chain_j+1])
new_edges = find_edges_based_on_distance(all_xyz, all_mask, chain_i_start_index, chain_i_end_index, chain_j_start_index, chain_j_end_index, dist_cutoff=8)
for edge in new_edges:
start = edge[0] + chain_i_start_index
end = edge[1] +chain_j_start_index
graph[start,end] = 8
graph[end, start]= 8
# set interchain edges between proteins and small molecules (non_covalent)
for protein_chain, sm_chain in itertools.product(range(num_prot_chains), range(num_sm_chains)):
protein_chain_start_index = sum(Ls_prot[:protein_chain])
protein_chain_end_index = sum(Ls_prot[:protein_chain+1])
sm_chain_start_index = sum(Ls_sm[:sm_chain]) + total_protein_L
sm_chain_end_index = sum(Ls_sm[:sm_chain+1]) + total_protein_L
if torch.any(bond_feats[protein_chain_start_index:protein_chain_end_index][:, sm_chain_start_index: sm_chain_end_index] == 8): # skip chains that are covalently connected
continue
new_edges = find_edges_based_on_distance(all_xyz, all_mask, protein_chain_start_index, protein_chain_end_index, sm_chain_start_index, sm_chain_end_index, dist_cutoff=5)
for edge in new_edges:
start = edge[0] + protein_chain_start_index
end = edge[1] +sm_chain_start_index
graph[start,end] = 2
graph[end, start]= 2
# edges to covalent modifications should be similar to residue edges not ligand edges
covalent_bonds = (bond_feats==6).nonzero()
for bond in covalent_bonds:
graph[bond[0], bond[1]] = 1
graph[bond[1], bond[0]] = 1
# find an interface residue to start at by finding random residue near a ligand atom
starting_edges = find_edges_based_on_distance(all_xyz, all_mask, 0,Ls_prot[0],total_protein_L, total_protein_L+Ls_sm[0], dist_cutoff=10)
if starting_edges.numel() == 0:
startres = random.randint(0,Ls_prot[0])
else:
startres = random.choice(starting_edges[:, 0].unique()).item()
d_res = shortest_path(graph, directed = False, indices=startres)
n_crop = min(d_res.shape[0], n_crop)
_, idx = torch.topk(torch.from_numpy(d_res).to(device=all_xyz.device), n_crop, largest=False)
sel, _ = torch.sort(idx)
#make sure that all ligands were fully pulled into the crop
# print(f"total number of chain: {len(Ls_sm)}")
for sm_chain_index, L_sm in enumerate(Ls_sm):
sm_chain_start_index = sum(Ls_sm[:sm_chain_index]) + total_protein_L
chain_indices = torch.arange(L_sm) + sm_chain_start_index
chain_in_crop = torch.isin(chain_indices, sel) # tensor with length chain_indices indicating which elements from chain_indices are in sel
is_subset = torch.all(chain_in_crop)
has_overlap = torch.any(chain_in_crop)
if has_overlap == True:
if is_subset == False:
#if sm_chain_index == 0:
# print("WARNING: PART OF QUERY LIGAND WAS CROPPED; ADDING REST BACK IN")
# sel = torch.cat((sel, chain_indices), dim=0)
# sel = sel.unique()
# continue
print("WARNING: removing partially cropped small molecule")
print(f"chain: {sm_chain_index}")
if use_partial_ligands == False:
crop_in_chain = torch.isin(sel, chain_indices) # tensor with length of sel indicating which indices in sel are also in chain_indices
sel = sel[~crop_in_chain]
else:
sel = torch.cat((sel, chain_indices), dim=0)
sel = sel.unique()
return sel
def crop_chirals(chirals, atom_sel):
"""
this function returns only chiral centers that appear in molecules that are chosen after cropping
@@ -4905,7 +4918,6 @@ def sample_item_sm_compl(df, ID, dedup_ligand=True):
# uniformly sample from unique PDB chains
chid = np.random.choice(tmp_df.CHAINID.drop_duplicates().values)
tmp_df = tmp_df[tmp_df.CHAINID==chid]
if dedup_ligand and "LIGAND" in tmp_df:
# uniform sample from unique ligands
lignames = list(set([x[0][2] for x in tmp_df['LIGAND']]))
@@ -4914,7 +4926,7 @@ def sample_item_sm_compl(df, ID, dedup_ligand=True):
item = tmp_df.sample(1).to_dict(orient='records')[0] # choose 1 random row
return copy.deepcopy(item) # prevents dataframe from being modified by downstream changes
class Dataset(data.Dataset):
def __init__(
@@ -5401,7 +5413,6 @@ class DistilledDataset(data.Dataset):
except Exception as e:
print('error loading',item, '\n',repr(e), task)
raise e
return out
class DistributedWeightedSampler(data.Sampler):
@@ -5436,83 +5447,164 @@ class DistributedWeightedSampler(data.Sampler):
),
num_replicas=None,
rank=None,
replacement=False
datasets_with_replacement=["pdb", "fb", "compl", "neg_compl", "na_compl", "neg_na_compl", "distil_tf", "tf", "neg_tf", "rna", "dna"],
lengths=None,
batch_by_dataset=False,
batch_by_length=False,
):
if num_replicas is None:
if not dist.is_available():
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not dist.is_available():
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
rank = torch.distributed.get_rank()
assert num_example_per_epoch % num_replicas == 0
assert (num_example_per_epoch % num_replicas) == 0, "Please ensure that the number of examples per epoch is evenly divisible by the number of nodes"
assert (np.allclose(sum([v for k,v in fractions.items()]), 1.0)), \
f"Fractions of datasets add up to {sum([v for k,v in fractions.items()])}, should add up to 1.0"
# Load lengths into a tensor, if file exists
if lengths is not None and os.path.isfile(lengths):
lengths = torch.load(lengths)
else:
if lengths is not None:
warnings.warn(f"Lengths file {lengths} does not exist. Ignoring lengths file.")
lengths = None
if batch_by_length:
assert lengths is not None, "If batching by length, must pass a valid lengths tensor."
if not batch_by_dataset:
assert not batch_by_length, "Cannot batch by length without also batching by dataset."
self.dataset = dataset
self.weights_dict = weights_dict
self.num_replicas = num_replicas
self.num_per_epoch_dict = OrderedDict([
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
for dataset_name in self.dataset.dataset_dict.keys()
])
self.lengths = lengths
self.batch_by_length = batch_by_length
self.batch_by_dataset = batch_by_dataset
if batch_by_dataset:
# Ensure that all GPU's can process an example from the same dataset at once
self.num_per_epoch_dict = adjust_samples_for_num_replicas(
OrderedDict([
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
for dataset_name in self.dataset.dataset_dict.keys()
]),
num_replicas
)
else:
self.num_per_epoch_dict = OrderedDict([
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
for dataset_name in self.dataset.dataset_dict.keys()
])
# account for rounding error
# Account for rounding error
dataset_names = list(self.dataset.dataset_dict.keys())
nonzero_dataset_names = [name for name in dataset_names if self.num_per_epoch_dict[name] > 0]
# Calculate the actual number of examples that will be sampled (will be a multiple of num_replicas)
num_per_epoch_actual = sum([self.num_per_epoch_dict[name] for name in nonzero_dataset_names])
self.num_per_epoch_dict[nonzero_dataset_names[0]] += num_example_per_epoch - num_per_epoch_actual
# Handle remainders by rounding down to the nearest multiple of num_replicas and sampling from `pdb`
remainder = num_example_per_epoch - num_per_epoch_actual
remainder = remainder - (remainder % num_replicas)
self.num_per_epoch_dict[nonzero_dataset_names[0]] += remainder # The first dataset is the pdb
self.total_size = num_example_per_epoch
self.total_size = num_per_epoch_actual + remainder
self.num_samples = self.total_size // self.num_replicas
self.rank = rank
self.epoch = 0
self.replacement = replacement
# Sample the protein datasets with replacement to account for length weighting
# Other datasets (e.g., small molecule datasets) will be sampled WITHOUT replacement (since LEN_EXIST is not the appropriate weighting)
self.datasets_with_replacement = datasets_with_replacement
if (rank==0):
print(f"Training examples per epoch ({self.total_size} total):")
for k,v in self.num_per_epoch_dict.items():
print(' '+k, ':', v)
def _sample_from_dataset(self, dataset_name, g):
"""
Samples a specified number of sequences from the given dataset.
Samples with replacement based on the dataset type, forcing replacement if sampling more than dataset length.
Parameters:
dataset_name (str): The name of the dataset to sample from.
g (torch.Generator): A pre-seeded generator to ensure consistency across nodes.
Returns:
Tensor: A tensor of sampled indices from the dataset.
"""
# Throw warning if the number of sequences to be sampled is not more than the number of sequences in the dataset
if self.num_per_epoch_dict[dataset_name] > len(self.dataset.ID_dict[dataset_name]):
warnings.warn(f"Number of sequences to be sampled in one epoch is greater than the number of " \
f"sequences in the dataset. Must sample with replacement. Ensure that this is the desired behavior. Dataset: {dataset_name}, " \
f"Dataset length: {len(self.dataset.ID_dict[dataset_name])}, " \
f"# to be sampled: {self.num_per_epoch_dict[dataset_name]}")
# Determine if sampling with replacement based on the dataset type, forcing replacement if sampling more than dataset length
replacement = dataset_name in self.datasets_with_replacement or self.num_per_epoch_dict[dataset_name] > len(self.dataset.ID_dict[dataset_name])
# Sample indices from the dataset based on the weights (prefer longer sequences)
return torch.multinomial(self.weights_dict[dataset_name],
self.num_per_epoch_dict[dataset_name],
generator=g,
replacement=replacement)
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
# get indices (fb + pdb models)
# get indices (all models)
indices = torch.arange(len(self.dataset))
# weighted subsampling
# order of datasets in this loop should match order in DistilledDataset.__getitem__()
offset = 0
sel_indices = torch.tensor((),dtype=int)
print(self.dataset.dataset_dict.keys())
for dataset_name in self.dataset.dataset_dict.keys():
if self.batch_by_dataset:
for dataset_name in self.dataset.dataset_dict.keys():
if (self.num_per_epoch_dict[dataset_name]> 0):
# Sample and adjust for offset; _sample_from_dataset handles replacement
sampled_idx = self._sample_from_dataset(dataset_name, g) + offset
# Divide sampled_idx into num_replicas chunks and assign each chunk to a node
sampled_idx_split = torch.split(sampled_idx, len(sampled_idx) // self.num_replicas)
assert all([len(x) == len(sampled_idx_split[0]) for x in sampled_idx_split])
# If also batching by sequence length, sort the indices by length
if self.batch_by_length and self.lengths is not None:
sampled_idx_split = [x[torch.argsort(self.lengths[x])] for x in sampled_idx_split]
# Add the sampled indices to the running tensor based on the node rank
sel_indices = torch.cat((sel_indices, indices[sampled_idx_split[self.rank]]))
offset += len(self.dataset.ID_dict[dataset_name])
if (self.num_per_epoch_dict[dataset_name]> 0):
sampled_idx = torch.multinomial(self.weights_dict[dataset_name],
self.num_per_epoch_dict[dataset_name],
self.replacement,
generator=g)
sel_indices = torch.cat((sel_indices, indices[sampled_idx + offset]))
# For each node, the indices are shuffled with the same seed, and so will draw from the same datasets in the same order
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
else:
# Standard implementation of WeightedDistributedSampler without batching by dataset or length
for dataset_name in self.dataset.dataset_dict.keys():
if (self.num_per_epoch_dict[dataset_name]> 0):
sampled_idx = self._sample_from_dataset(dataset_name, g)
sel_indices = torch.cat((sel_indices, indices[sampled_idx + offset]))
offset += len(self.dataset.ID_dict[dataset_name])
# shuffle indices
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
# shuffle indices
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
# per each gpu
indices = indices[self.rank:self.total_size:self.num_replicas]
# per each gpu
indices = indices[self.rank:self.total_size:self.num_replicas]
#print('rank',self.rank,': expecting',self.num_samples,'examples, drew',len(indices),'examples')
assert len(indices) == self.num_samples # more stringent, switch with line above during debugging
assert len(indices) == self.num_samples # more stringent, switch with line above during debugging
return iter(indices.tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
self.epoch = epoch

View File

@@ -0,0 +1,158 @@
import torch
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs
from rf2aa.util import INIT_CRDS, NTOTAL, NTOTALDOFS, is_atom, \
Ls_from_same_chain_2d, xyz_t_to_frame_xyz, get_prot_sm_mask
def prepare_input(inputs, xyz_converter, gpu):
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
# transfer inputs to device
B, _, N, L = msa.shape
idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L)
true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3)
mask_crds = mask_crds.to(gpu, non_blocking=True) # (B, L, 27)
same_chain = same_chain.to(gpu, non_blocking=True)
xyz_t = xyz_t.to(gpu, non_blocking=True)
t1d = t1d.to(gpu, non_blocking=True)
mask_t = mask_t.to(gpu, non_blocking=True)
#fd --- use black hole initialization
xyz_prev = INIT_CRDS.reshape(1,1,NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
mask_prev = torch.zeros((1,L,NTOTAL), dtype=torch.bool).to(gpu, non_blocking=True)
atom_frames = atom_frames.to(gpu, non_blocking=True)
bond_feats = bond_feats.to(gpu, non_blocking=True)
dist_matrix = dist_matrix.to(gpu, non_blocking=True)
chirals = chirals.to(gpu, non_blocking=True)
assert (len(symmgp)==1)
symmgp = symmgp[0]
# symmetry - reprocess (many) inputs
if (symmgp != 'C1'):
Lasu = L//2 # msa contains intra/inter block
symmids, symmRs, symmmeta, symmoffset = symm_subunit_matrix(symmgp)
symmids = symmids.to(gpu, non_blocking=True)
symmRs = symmRs.to(gpu, non_blocking=True)
symmoffset = symmoffset.to(gpu, non_blocking=True)
symmmeta = (
[x.to(gpu, non_blocking=True) for x in symmmeta[0]],
symmmeta[1])
O = symmids.shape[0]
xyz_prev = xyz_prev + symmoffset*Lasu**(1/3)
# find contacting subunits
xyz_prev, symmsub = find_symm_subs(xyz_prev[:,:Lasu], symmRs, symmmeta)
symmsub = symmsub.to(gpu, non_blocking=True)
Osub = symmsub.shape[0]
mask_prev = mask_prev[:,:L].repeat(1,Osub,1)
# symmetrize msa
seq = torch.cat([seq[:,:,:Lasu],*[seq[:,:,Lasu:]]*(Osub-1)], dim=2)
msa = torch.cat([msa[:,:,:,:Lasu],*[msa[:,:,:,Lasu:]]*(Osub-1)], dim=3)
msa_masked = torch.cat([msa_masked[:,:,:,:Lasu],*[msa_masked[:,:,:,Lasu:]]*(Osub-1)], dim=3)
msa_full = torch.cat([msa_full[:,:,:,:Lasu],*[msa_full[:,:,:,Lasu:]]*(Osub-1)], dim=3)
mask_msa = torch.cat([mask_msa[:,:,:,:Lasu],*[mask_msa[:,:,:,Lasu:]]*(Osub-1)], dim=3)
# symmetrize templates
xyz_t = xyz_t[:,:,:Lasu].repeat(1,1,Osub,1,1)
mask_t = mask_t[:,:,:Lasu].repeat(1,1,Osub,1)
t1d = t1d[:,:,:Lasu].repeat(1,1,Osub,1)
# symmetrize atom_frames
atom_frames = torch.cat([atom_frames[:,:,:Lasu],*[atom_frames[:,:,Lasu:]]*(Osub-1)], dim=2)
# index, same chain, bond feats
idx_pdb = torch.arange(Osub*Lasu, device=gpu)[None,:]
same_chain = torch.zeros((1,Osub*Lasu,Osub*Lasu), device=gpu).long()
bond_feats_new = torch.zeros((1,Osub*Lasu,Osub*Lasu), device=gpu).long()
dist_matrix_new = torch.zeros((1,Osub*Lasu,Osub*Lasu), device=gpu).long()
for o_i in range(Osub):
same_chain[:,o_i*Lasu:(o_i+1)*Lasu,o_i*Lasu:(o_i+1)*Lasu] = 1
idx_pdb[:,o_i*Lasu:(o_i+1)*Lasu] += 100*o_i
bond_feats_new[:,o_i*Lasu:(o_i+1)*Lasu,o_i*Lasu:(o_i+1)*Lasu] = bond_feats
dist_matrix_new[:,o_i*Lasu:(o_i+1)*Lasu,o_i*Lasu:(o_i+1)*Lasu] = dist_matrix
bond_feats = bond_feats_new
dist_matrix = dist_matrix_new
else:
Lasu = L
Osub = 1
symmids = None
symmsub = None
symmRs = None
symmmeta = None
# processing template features
mask_t_2d = get_prot_sm_mask(mask_t, seq[0][0])
mask_t_2d = mask_t_2d[:,:,None]*mask_t_2d[:,:,:,None] # (B, T, L, L)
# we can provide sm_templates so we want to allow interchain templates bw protein chain 1 and sms
# specifically the templates are found for the query protein chain
Ls = Ls_from_same_chain_2d(same_chain)
prot_ch1_to_sm_2d = torch.zeros_like(same_chain)
prot_ch1_to_sm_2d[:, :Ls[0], is_atom(seq)[0][0]] = 1
prot_ch1_to_sm_2d[:, is_atom(seq)[0][0], :Ls[0]] = 1
is_possible_t2d = same_chain.clone()
is_possible_t2d[prot_ch1_to_sm_2d.bool()] = 1
mask_t_2d = mask_t_2d.float() * is_possible_t2d.float()[:,None] # (ignore inter-chain region between proteins)
xyz_t_frame = xyz_t_to_frame_xyz(xyz_t, msa[:, 0,0], atom_frames)
t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d)
# get torsion angles from templates
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,Lasu*Osub)
alpha, _, alpha_mask, _ = xyz_converter.get_torsions(xyz_t.reshape(-1,Lasu*Osub,NTOTAL,3), seq_tmp, mask_in=mask_t.reshape(-1,Lasu*Osub,NTOTAL))
alpha = alpha.reshape(B,-1,Lasu*Osub,NTOTALDOFS,2)
alpha_mask = alpha_mask.reshape(B,-1,Lasu*Osub,NTOTALDOFS,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, Lasu*Osub, 3*NTOTALDOFS)
alpha_prev = torch.zeros((B,Lasu*Osub,NTOTALDOFS,2))
network_input = {}
network_input['msa_latent'] = msa_masked
network_input['msa_full'] = msa_full
network_input['seq'] = seq
network_input['seq_unmasked'] = msa[:,0,0]
network_input['idx'] = idx_pdb
network_input['t1d'] = t1d
network_input['t2d'] = t2d
network_input['xyz_t'] = xyz_t[:,:,:,1]
network_input['alpha_t'] = alpha_t
network_input['mask_t'] = mask_t_2d
network_input['same_chain'] = same_chain
network_input['bond_feats'] = bond_feats
network_input['dist_matrix'] = dist_matrix
network_input['chirals'] = chirals
network_input['atom_frames'] = atom_frames
network_input['symmids'] = symmids
network_input['symmsub'] = symmsub
network_input['symmRs'] = symmRs
network_input['symmmeta'] = symmmeta
network_input["xyz_prev"] = xyz_prev
network_input["alpha_prev"] = alpha_prev
network_input["mask_recycle"] = None
return task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
def get_loss_calc_items(inputs,device="cpu"):
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
return seq.to(device), same_chain.to(device), idx_pdb.to(device), bond_feats.to(device), dist_matrix.to(device), atom_frames.to(device)

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
import networkx as nx
from typing import Dict, Optional, Tuple, List, Set, Any
from rf2aa import cifutils
import rf2aa.cifutils as cifutils
from rf2aa.util import get_ligand_atoms_bonds, cif_ligand_to_xyz, cif_ligand_to_obmol, get_automorphs

View File

@@ -479,18 +479,19 @@ def parse_a3m(filename, maxseq=8000, paired=False):
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False):
def parse_pdb(filename, seq=False, lddt_mask=False):
lines = open(filename,'r').readlines()
if seq:
return parse_pdb_lines_w_seq(lines)
return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines):
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
# indices of residues observed in the structure
res = [(l[21:22].strip(), l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] # (chain letter, res num, aa)
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
pdb_idx_s = [(r[0], int(r[1])) for r in res]
idx_s = [int(r[1]) for r in res]
plddt = [float(r[3]) for r in res]
seq = [aa2num[r[2]] if r[2] in aa2num.keys() else 20 for r in res]
# 4 BB + up to 10 SC atoms
@@ -525,6 +526,12 @@ def parse_pdb_lines_w_seq(lines):
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
if lddt_mask == True:
plddt = np.array(plddt)
mask_lddt = np.full_like(mask, False)
mask_lddt[plddt > .85, 5:] = True
mask_lddt[plddt > .70, :5] = True
mask = np.logical_and(mask, mask_lddt)
return xyz,mask,np.array(idx_s), np.array(seq)
@@ -532,7 +539,7 @@ def parse_pdb_lines_w_seq(lines):
def parse_pdb_lines(lines):
# indices of residues observed in the structure
res = [(l[21:22].strip(), l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] # (chain letter, res num, aa)
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
pdb_idx_s = [(r[0], int(r[1])) for r in res]
idx_s = [int(r[1]) for r in res]

18
rf2aa/debug.py Normal file
View File

@@ -0,0 +1,18 @@
import torch
def debug_nans(latent_feats):
for k, v in latent_feats.items():
if torch.is_tensor(v):
print(k)
print(torch.sum(v.isnan()))
def debug_unused_params(model):
for name, param in model.named_parameters():
if param.grad is None:
print(name)
def debug_used_params(model):
for name, param in model.named_parameters():
if param.grad is not None:
print(name)

125
rf2aa/debug_item.py Normal file
View File

@@ -0,0 +1,125 @@
import unittest
from hydra import compose, initialize
import torch
from rf2aa.chemical import NBTYPES, NTOTAL
from rf2aa.data.compose_dataset import compose_single_item_dataset, set_data_loader_params
from rf2aa.data.data_loader import loader_atomize_pdb
from rf2aa.data.dataloader_adaptor import prepare_input
from rf2aa.util import is_atom, writepdb
from rf2aa.tensor_util import assert_shape
from rf2aa.trainer_new import trainer_factory
from rf2aa.training.recycling import recycle_step_legacy
#### Setup test case hyperparams
ITEM = \
{'Unnamed: 0': 262672, 'CHAINID': '6ywe_UB', 'DEPOSITION': '2020-04-29', 'RESOLUTION': 2.9900, 'HASH': '072380', 'CLUSTER': 9905, 'SEQUENCE': 'MPNKPIRLPPLKQLRVRQANKAEENPCIAVMSSVLACWASAGYNSAGCATVENALRACMDAPKPAPKPNNTINYHLSRFQERLTQGKSKK', 'LEN_EXIST': 88, 'TAXID': '5141'}
CONFIG = "legacy_train"
LOADER_FN = loader_atomize_pdb
LOADER_KWARGS = {
"homo": None,
"n_res_atomize": 5,
"flank": 0
}
class DebugTestCase(unittest.TestCase):
def setUp(self) -> None:
with initialize(version_base=None, config_path="config/train"):
self.cfg = compose(config_name=CONFIG)
loader_params = set_data_loader_params(self.cfg.loader_params)
loader = compose_single_item_dataset(
ITEM,
loader_params,
LOADER_FN,
LOADER_KWARGS
)
self.loader = loader
def test_correct_shapes(self):
""" test shapes are all consistent with each other """
for inputs in self.loader:
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
B, recycles, N, L = msa.shape[:4]
num_atoms = (is_atom(seq[0,0]).sum()).item()
assert_shape(seq, (B, recycles, L))
assert_shape(msa, (B, recycles, N, L))
assert_shape(msa_masked, (B, recycles, N, L, 164)) #Hack: hardcoded for current featurization
N_full = msa_full.shape[2]
assert_shape(msa_full, (B, recycles, N_full, L, 83)) #HACK:: hardcoded for current features
assert_shape(mask_msa, (B, recycles, N, L))
N_symm = true_crds.shape[1]
assert_shape(true_crds, (B, N_symm, L, NTOTAL, 3))
assert_shape(mask_crds, (B, N_symm, L, NTOTAL))
assert_shape(idx_pdb, (B, L))
N_templ = xyz_t.shape[1]
assert_shape(xyz_t, (B, N_templ, L, NTOTAL, 3))
assert_shape(t1d, (B, N_templ, L, 80)) # hack hard coded dimension
assert_shape(mask_t, (B, N_templ, L, NTOTAL))
assert_shape(xyz_prev, (B, L, NTOTAL, 3))
assert_shape(mask_prev, (B, L, NTOTAL))
assert_shape(same_chain, (B, L, L))
assert type(unclamp.item()) == bool
assert type(negative.item()) == bool
assert_shape(atom_frames, (B, num_atoms, 3,2))
assert_shape(bond_feats, (B, L, L))
assert_shape(dist_matrix, (B, L, L))
n_chirals = chirals.shape[1]
assert_shape(chirals, (B, n_chirals, 5))
assert_shape(ch_label, (B, L))
assert symmgp[0] == "C1", f"{symmgp}"
def test_forward_pass(self):
trainer = trainer_factory[self.cfg.experiment.trainer](self.cfg)
trainer.construct_model()
trainer.model.device = "cpu"
trainer.move_constants_to_device(gpu="cpu")
for inputs in self.loader:
loss, loss_dict = trainer.train_step(inputs, 1)
def test_forward_pass_with_checkpoint(self):
trainer = trainer_factory[self.cfg.experiment.trainer](self.cfg)
trainer.construct_model()
trainer.model.device = "cpu"
trainer.move_constants_to_device(gpu="cpu")
checkpoint_path = "/home/rohith/rf2a-fd3/models/rf2a_fd3_20221125_714.pt"
trainer.checkpoint = torch.load(checkpoint_path, map_location="cpu")
trainer.model.model.load_state_dict(trainer.checkpoint["final_state_dict"])
trainer.model.shadow.load_state_dict(trainer.checkpoint["model_state_dict"])
for inputs in self.loader:
loss, loss_dict = trainer.train_step(inputs, 1)
#TODO: check something about the loss
def test_forward_pass_outputs(self):
trainer = trainer_factory[self.cfg.experiment.trainer](self.cfg)
trainer.construct_model()
trainer.model.device = "cpu"
trainer.move_constants_to_device(gpu="cpu")
checkpoint_path = "/home/rohith/rf2a-fd3/models/rf2a_fd3_20221125_714.pt"
trainer.checkpoint = torch.load(checkpoint_path, map_location="cpu")
trainer.model.model.load_state_dict(trainer.checkpoint["final_state_dict"])
trainer.model.shadow.load_state_dict(trainer.checkpoint["model_state_dict"])
for inputs in self.loader:
gpu = trainer.model.device
# HACK: certain features are constructed during the train step
# in the future this should only promote the constructed features onto gpu
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input(inputs, trainer.xyz_converter, gpu)
n_cycle = 1
output_i = recycle_step_legacy(trainer.model, network_input, n_cycle, trainer.config.training_params.use_amp)
c6d, mlm, pae, pde, p_bind, xyz, alphas, _, _, _, _, _ = output_i
seq_unmasked = network_input["seq_unmasked"]
writepdb("test.pdb", xyz[-1], seq_unmasked)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,40 +0,0 @@
from rf2aa.train_multi_EMA import Trainer
from rf2aa.evaluate import Evaluator
def trainer_factory(args, dataset_param, model_param, loader_param, loss_param):
dataloader_kwargs = {
"shuffle": args.shuffle_dataloader,
"num_workers": args.dataloader_num_workers,
"pin_memory": not args.dont_pin_memory,
}
trainer_class = Trainer
if args.mode == "eval":
trainer_class = Evaluator
args.eval = True
trainer_object = trainer_class(
model_name=args.model_name,
n_epoch=args.num_epochs,
step_lr=args.step_lr,
lr=args.lr,
l2_coeff=1.0e-2,
port=args.port,
model_param=model_param,
loader_param=loader_param,
loss_param=loss_param,
batch_size=args.batch_size,
accum_step=args.accum,
maxcycle=args.maxcycle,
eval=args.eval,
interactive=args.interactive,
out_dir=args.out_dir,
wandb_prefix=args.wandb_prefix,
model_dir=args.model_dir,
dataset_param=dataset_param,
dataloader_kwargs=dataloader_kwargs,
debug_mode=args.debug,
skip_valid=args.skip_valid,
start_epoch=args.start_epoch,
)
return trainer_object

View File

@@ -1,248 +0,0 @@
import sys, os, json
import time
import numpy as np
import torch
import torch.nn as nn
from loss_halluc import calc_entropy_loss, calc_pae_loss
from optimization import run_gradient_descent, run_mcmc
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0,script_dir+'/models/fold_and_dock3/')
import parsers
from RoseTTAFoldModel import RoseTTAFoldModule
from data_loader import merge_a3m_hetero
import util
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
from model_params import MODEL_PARAM
def get_args():
import argparse
parser = argparse.ArgumentParser(description="RF2-allatom hallucination: protein scaffold design with explicit modeling of small-molecule ligands")
parser.add_argument("-checkpoint",
default="/databases/TrRosetta/RF2_allatom/checkpoints/rf2a_fd3_20221125_115.pt",
help="Path to model weights")
parser.add_argument("-pdb", help='PDB of motif')
parser.add_argument("-parse_hetatm", action="store_true", default=False, help="parse ligand information from input pdb")
parser.add_argument("-mol2", help='mol2 of small molecule')
parser.add_argument("-num", type=int, default=1, help='number of designs')
parser.add_argument("-start_num", type=int, default=0, help='start number of output designs')
parser.add_argument("-grad_steps", type=int, required=True, help='number of gradient descent steps')
parser.add_argument("-mcmc_steps", type=int, required=True, help='number of mcmc steps')
parser.add_argument("-out", help='prefix of output files')
parser.add_argument("-L", type=int, help='length of hallucinated protein')
parser.add_argument("-T0", type=float, default=0.02, help='initial temperature for simulated annealing')
parser.add_argument("-mcmc_halflife", default=100, help='half-life of simulated annealing')
parser.add_argument("-cycles", type=str, default='10', help='number of recycles')
parser.add_argument("-seq_prob_type", default='hard', help='soft or hard probabilities of sequence')
parser.add_argument("-init_sd", type=float, default=1e-6, help='random initial logit standard deviation')
parser.add_argument("-template_ligand", action='store_true', default=False, help='input template features for ligand structure')
parser.add_argument("-init_ligand_xyz", action='store_true', default=False, help='input initial xyz coords for ligand structure')
parser.add_argument("-learning_rate", type=float, default=0.05, help='gradient descent learning rate')
parser.add_argument("-device", type=str, default='cuda:0', help='gpu to run on')
parser.add_argument("-w_ent", type=float, default=1.0, help='weight of entropy loss')
parser.add_argument("-w_pae", type=float, default=1.0, help='weight of pae loss')
parser.add_argument("-w_ipae", type=float, default=1.0, help='weight of inter-pae loss')
args = parser.parse_args()
return args
# compute expected value from binned lddt
def lddt_unbin(pred_lddt):
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
def load_model(args, MODEL_PARAM):
device = args.device
model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask = util.allatom_mask.to(device),
atom_type_index = util.atom_type_index.to(device),
ljlk_parameters = util.ljlk_parameters.to(device),
lj_correction_parameters = util.lj_correction_parameters.to(device),
num_bonds = util.num_bonds.to(device),
cb_len = util.cb_length_t.to(device),
cb_ang = util.cb_angle_t.to(device),
cb_tor = util.cb_torsion_t.to(device),
).to(device)
checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
return model
def prepare_inputs(args, init_seq=None, random_noise=5):
B = 1 # batch
N = 1 # msa depth
protein_L = args.L
device = args.device
if init_seq is None:
msa_prot = torch.randint(20, (1,protein_L))
else:
msa_prot = init_seq
ins_prot = torch.zeros((1,protein_L)).long()
idx_prot = torch.arange(protein_L)
if args.mol2 is not None:
a3m_prot = {"msa": msa_prot, "ins": ins_prot}
mol, msa_sm, ins_sm, xyz_sm, mask_sm = parsers.parse_mol(args.mol2)
a3m_sm = {"msa": msa_sm.unsqueeze(0), "ins": ins_sm.unsqueeze(0)}
G = util.get_nxgraph(mol)
atom_frames = util.get_atom_frames(msa_sm, G)
N_symmetry, sm_L, _ = xyz_sm.shape
Ls = [protein_L, sm_L]
a3m = merge_a3m_hetero(a3m_prot, a3m_sm, Ls)
msa = a3m['msa'].long()
ins = a3m['ins'].long()
chirals = get_chirals(mol, xyz_sm[0])
xyz = torch.full((N_symmetry, sum(Ls), NTOTAL, 3), np.nan).float()
mask = torch.full(xyz.shape[:-1], False).bool()
if args.pdb is not None:
xyz[:, :Ls[0], :nprotatoms, :] = xyz_prot.expand(N_symmetry, Ls[0], nprotatoms, 3)
mask[:, :protein_L, :nprotatoms] = mask_prot.expand(N_symmetry, Ls[0], nprotatoms)
if args.mol2 is not None:
xyz[:, Ls[0]:, 1, :] = xyz_sm
mask[:, protein_L:, 1] = mask_sm
idx_sm = torch.arange(max(idx_prot),max(idx_prot)+Ls[1])+200
idx_pdb = torch.concat([idx_prot, idx_sm])
chain_idx = torch.zeros((sum(Ls), sum(Ls))).long()
chain_idx[:Ls[0], :Ls[0]] = 1
chain_idx[Ls[0]:, Ls[0]:] = 1
bond_feats = torch.zeros((sum(Ls), sum(Ls))).long()
bond_feats[:Ls[0], :Ls[0]] = util.get_protein_bond_feats(Ls[0])
if args.mol2 is not None:
bond_feats[Ls[0]:, Ls[0]:] = util.get_bond_feats(mol)
# blank template
xyz_t = INIT_CRDS.reshape(1,1,NTOTAL,3).repeat(1,sum(Ls),1,1) \
+ torch.rand(1,sum(Ls),1,3)*random_noise - random_noise/2
f1d_t = torch.nn.functional.one_hot(torch.full((1, sum(Ls)), 20).long(), num_classes=NAATOKENS-1).float() # all gaps
conf = torch.zeros((1, sum(Ls), 1)).float()
f1d_t = torch.cat((f1d_t, conf), -1)
mask_t = torch.full((1,sum(Ls),NTOTAL), False)
if args.template_ligand: # input true s.m. xyz as template
xyz_t[0, Ls[0]:, 1] = xyz_sm[0] - xyz_sm[0].mean(-2) # centroid at origin
f1d_t[0, Ls[0]:] = torch.cat((
torch.nn.functional.one_hot(msa[0, Ls[0]:]-1, num_classes=NAATOKENS-1).float(),
torch.ones((Ls[1], 1)).float()
), -1) # (1, L_sm, NAATOKENS)
mask_t[0, Ls[0]:, 1] = mask_sm[0] # all symmetry variants have same mask
xyz_t = torch.nan_to_num(xyz_t)
# black-hole coordinates
init = INIT_CRDS.reshape(1,NTOTAL,3).repeat(sum(Ls),1,1)
xyz_prev = init + torch.rand(sum(Ls),1,3)*random_noise - random_noise/2
mask_prev = torch.full(xyz_prev.shape[:-1], False).bool()
if args.init_ligand_xyz:
xyz_prev[Ls[0]:, 1] = xyz_sm[0] - xyz_sm[0].mean(-2) # centroid at origin
mask_prev[Ls[0]:, 1] = mask_sm[0]
# transfer inputs to device
atom_frames = atom_frames[None].to(device, non_blocking=True)
atom_mask = mask[None].to(device, non_blocking=True) # (B, L, 27)
idx_pdb = idx_pdb[None].to(device, non_blocking=True) # (B, L)
xyz_t = xyz_t[None].to(device, non_blocking=True)
mask_t = mask_t[None].to(device, non_blocking=True)
t1d = f1d_t[None].to(device, non_blocking=True)
xyz_prev = xyz_prev[None].to(device, non_blocking=True)
mask_prev = mask_prev[None].to(device, non_blocking=True)
same_chain = chain_idx[None].to(device, non_blocking=True)
bond_feats = bond_feats[None].to(device, non_blocking=True)
chirals = chirals[None].to(device, non_blocking=True)
# processing template features
seq_unmasked = msa.clone() # (B, L)
mask_t_2d = util.get_prot_sm_mask(mask_t, seq_unmasked[0]) # (B, T, L)
mask_t_2d = mask_t_2d[:,:,None]*mask_t_2d[:,:,:,None] # (B, T, L, L)
mask_t_2d = mask_t_2d.float() * same_chain.float()[:,None] # (ignore inter-chain region)
mask_recycle = util.get_prot_sm_mask(mask_prev, seq_unmasked[0])
mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
mask_recycle = same_chain.float()*mask_recycle.float()
xyz_t_frames = util.xyz_t_to_frame_xyz(xyz_t, seq_unmasked, atom_frames)
t2d = xyz_to_t2d(xyz_t_frames, mask_t_2d)
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,sum(Ls))
alpha, _, alpha_mask, _ = util.get_torsions(
xyz_t.reshape(-1,sum(Ls),NTOTAL,3),
seq_tmp,
util.torsion_indices.to(device),
util.torsion_can_flip.to(device),
util.reference_angles.to(device)
)
alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
alpha[torch.isnan(alpha)] = 0.0
alpha = alpha.reshape(1,-1,sum(Ls),NTOTALDOFS,2)
alpha_mask = alpha_mask.reshape(1,-1,sum(Ls),NTOTALDOFS,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, sum(Ls), 3*NTOTALDOFS).to(device)
return dict(
Ls=Ls,
msa=msa,
ins=ins,
xyz_prev=xyz_prev,
idx_pdb=idx_pdb,
bond_feats=bond_feats,
chirals=chirals,
atom_frames=atom_frames,
t1d=t1d,
t2d=t2d,
xyz_t=xyz_t,
alpha_t=alpha_t,
mask_t=mask_t,
mask_t_2d=mask_t_2d,
same_chain=same_chain,
mask_recycle=mask_recycle,
)
def main():
args = get_args()
print('Loading network weights...')
model = load_model(args, MODEL_PARAM)
inputs = prepare_inputs(args)
loss_funcs = [
dict(w=args.w_ent, name='ent', func=calc_entropy_loss),
dict(w=args.w_pae, name='pae', func=calc_pae_loss),
dict(w=args.w_ipae, name='ipae', func=lambda out: calc_pae_loss(out, inter=True))
]
for i_des in range(args.start_num, args.start_num+args.num):
start_time = time.time()
cycles = [int(i) for i in args.cycles.split(',')]
if args.grad_steps > 0:
out = run_gradient_descent(model, args, inputs, cycles[0], loss_funcs)
inputs['msa'] = out['msa'][0]
if args.mcmc_steps > 0:
out = run_mcmc(model, args, inputs, cycles[min(len(cycles)-1,1)], loss_funcs)
inputs['msa'] = out['msa'][0]
out_prefix = args.out+f'_{i_des}'
best_lddt = lddt_unbin(out['pred_lddt_binned'])
util.writepdb(out_prefix+".pdb", out['pred_allatom'], out['msa'][0], bfacts=100*best_lddt[0].float(),
bond_feats=inputs['bond_feats'])
if __name__ == "__main__":
main()

View File

@@ -1,63 +0,0 @@
import sys, os, json
import time
import numpy as np
import torch
import torch.nn as nn
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0,script_dir+'/models/fold_and_dock3/')
import util
def get_c6d_dict(logits, grad=True, args=None):
if grad:
logits = [logit.float() for logit in logits]
else:
logits = [logit.float().detach() for logit in logits]
probs = [nn.functional.softmax(l, dim=1) for l in logits]
dict_pred = {}
dict_pred['p_dist'] = probs[0].permute([0,2,3,1])
dict_pred['p_omega'] = probs[1].permute([0,2,3,1])
dict_pred['p_theta'] = probs[2].permute([0,2,3,1])
dict_pred['p_phi'] = probs[3].permute([0,2,3,1])
return dict_pred, probs
def calc_entropy_loss(out):
dict_pred, probs = get_c6d_dict(out['logit_s'], grad=True)
probs = [dict_pred[key] for key in ['p_dist','p_omega','p_theta','p_phi']]
# exclude last bin, then renormalize
probs = [prob[...,:-1]/prob[...,:-1].sum(-1,keepdim=True) for prob in probs]
L_mask = probs[0].shape[1]
loss_mask = 1-torch.eye(L_mask)[None].to(probs[0].device).float() # (B, L, L)
def calc_entropy(p, mask, eps=1e-6):
S_ij = -(p * torch.log(p + eps)).sum(axis=-1)
S_ave = torch.sum(mask * S_ij, axis=(1,2)) / (torch.sum(mask, axis=(1,2)) + eps)
return S_ave
entropy_s = [calc_entropy(prob, loss_mask) for prob in probs]
loss = torch.stack(entropy_s, dim=0).mean(dim=0)
return loss
def pae_unbin(logits_pae, bin_step=0.5):
nbin = logits_pae.shape[1]
bins = torch.linspace(bin_step*0.5, bin_step*nbin-bin_step*0.5, nbin, dtype=logits_pae.dtype, device=logits_pae.device)
logits_pae = torch.nn.Softmax(dim=1)(logits_pae)
return torch.sum(bins[None,:,None,None]*logits_pae, dim=1)
def calc_pae_loss(out, inter=False):
pae = pae_unbin(out['logit_pae'])
if inter:
sm_mask = util.is_atom(out['msa'])[0,0]
inter_mask = sm_mask[None]*(~sm_mask[:,None]) + (~sm_mask[None])*sm_mask[:,None]
pae = pae_unbin(out['logit_pae'])
return (pae*inter_mask).sum(dim=[-1,-2]) / inter_mask.sum(dim=[-1,-2])
else:
return pae.mean(dim=[-1,-2])
return pae

View File

@@ -1,476 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from rf2aa.util_module import init_lecun_normal
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, r_ff, p_drop=0.1):
super(FeedForwardLayer, self).__init__()
self.norm = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, d_model*r_ff)
self.dropout = nn.Dropout(p_drop)
self.linear2 = nn.Linear(d_model*r_ff, d_model)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer right before ReLu: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear1.bias)
# initialize linear layer right before residual connection: zero initialize
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
def forward(self, src):
src = self.norm(src)
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
return src
class Attention(nn.Module):
# calculate multi-head attention
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
super(Attention, self).__init__()
self.h = n_head
self.dim = d_hidden
#
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
#
self.to_out = nn.Linear(n_head*d_hidden, d_out)
self.scaling = 1/math.sqrt(d_hidden)
#
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, query, key, value):
B, Q = query.shape[:2]
B, K = key.shape[:2]
#
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
key = self.to_k(key).reshape(B, K, self.h, self.dim)
value = self.to_v(value).reshape(B, K, self.h, self.dim)
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkhd->bqhd', attn, value)
out = out.reshape(B, Q, self.h*self.dim)
#
out = self.to_out(out)
return out
# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
super(SequenceWeight, self).__init__()
self.h = n_head
self.dim = d_hidden
self.scale = 1.0 / math.sqrt(self.dim)
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
self.dropout = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_query.weight)
nn.init.xavier_uniform_(self.to_key.weight)
def forward(self, msa):
B, N, L = msa.shape[:3]
tar_seq = msa[:,0]
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
q = q * self.scale
attn = einsum('bqihd,bkihd->bkihq', q, k)
attn = F.softmax(attn, dim=1)
return self.dropout(attn)
class MSARowAttentionWithBias(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
super(MSARowAttentionWithBias, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
#
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa, pair): # TODO: make this as tied-attention
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
#
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
key = key * self.scaling
attn = einsum('bsqhd,bskhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
#
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColAttention(nn.Module):
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
super(MSAColAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * self.scaling
attn = einsum('bqihd,bkihd->bihqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColGlobalAttention(nn.Module):
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
super(MSAColGlobalAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
query = query.mean(dim=1) # (B, L, h, dim)
key = self.to_k(msa) # (B, N, L, dim)
value = self.to_v(msa) # (B, N, L, dim)
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
#
query = query * self.scaling
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
out = gate * out # (B, N, L, h*dim)
#
out = self.to_out(out)
return out
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
class TriangleAttention(nn.Module):
def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
super(TriangleAttention, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.start_node=start_node
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
# input projection
query = self.to_q(pair).reshape(B, L, L, self.h, -1)
key = self.to_k(pair).reshape(B, L, L, self.h, -1)
value = self.to_v(pair).reshape(B, L, L, self.h, -1)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
# attention
query = query * self.scaling
if self.start_node:
attn = einsum('bijhd,bikhd->bijkh', query, key)
else:
attn = einsum('bijhd,bkjhd->bijkh', query, key)
attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh)
attn = F.softmax(attn, dim=-2)
if self.start_node:
out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1)
else:
out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1)
out = gate * out # gated attention
# output projection
out = self.to_out(out)
return out
class TriangleMultiplication(nn.Module):
def __init__(self, d_pair, d_hidden=128, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.left_proj = nn.Linear(d_pair, d_hidden)
self.right_proj = nn.Linear(d_pair, d_hidden)
self.left_gate = nn.Linear(d_pair, d_hidden)
self.right_gate = nn.Linear(d_pair, d_hidden)
#
self.gate = nn.Linear(d_pair, d_pair)
self.norm_out = nn.LayerNorm(d_hidden)
self.out_proj = nn.Linear(d_hidden, d_pair)
self.outgoing = outgoing
self.reset_parameter()
def reset_parameter(self):
# normal distribution for regular linear weights
self.left_proj = init_lecun_normal(self.left_proj)
self.right_proj = init_lecun_normal(self.right_proj)
# Set Bias of Linear layers to zeros
nn.init.zeros_(self.left_proj.bias)
nn.init.zeros_(self.right_proj.bias)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.left_gate.weight)
nn.init.ones_(self.left_gate.bias)
nn.init.zeros_(self.right_gate.weight)
nn.init.ones_(self.right_gate.bias)
nn.init.zeros_(self.gate.weight)
nn.init.ones_(self.gate.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
left = self.left_proj(pair) # (B, L, L, d_h)
left_gate = torch.sigmoid(self.left_gate(pair))
left = left_gate * left
right = self.right_proj(pair) # (B, L, L, d_h)
right_gate = torch.sigmoid(self.right_gate(pair))
right = right_gate * right
if self.outgoing:
out = einsum('bikd,bjkd->bijd', left, right/float(L))
else:
out = einsum('bkid,bkjd->bijd', left, right/float(L))
out = self.norm_out(out)
out = self.out_proj(out)
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
out = gate * out
return out
# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
class BiasedAxialAttention(nn.Module):
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
super(BiasedAxialAttention, self).__init__()
#
self.is_row = is_row
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_bias = nn.LayerNorm(d_bias)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_bias, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair, bias):
# pair: (B, L, L, d_pair)
B, L = pair.shape[:2]
if self.is_row:
pair = pair.permute(0,2,1,3)
bias = bias.permute(0,2,1,3)
pair = self.norm_pair(pair)
bias = self.norm_bias(bias)
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
query = query * self.scaling
key = key / L # normalize for tied attention
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
attn = attn + bias # apply bias
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(B, L, L, -1)
out = gate * out
out = self.to_out(out)
if self.is_row:
out = out.permute(0,2,1,3)
return out

View File

@@ -1,87 +0,0 @@
import torch
import torch.nn as nn
from rf2aa.chemical import NAATOKENS
class DistanceNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
super(DistanceNetwork, self).__init__()
#
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
self.proj_asymm = nn.Linear(n_feat, 37+19)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer for final logit prediction
nn.init.zeros_(self.proj_symm.weight)
nn.init.zeros_(self.proj_asymm.weight)
nn.init.zeros_(self.proj_symm.bias)
nn.init.zeros_(self.proj_asymm.bias)
def forward(self, x):
# input: pair info (B, L, L, C)
# predict theta, phi (non-symmetric)
logits_asymm = self.proj_asymm(x)
logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
# predict dist, omega
logits_symm = self.proj_symm(x)
logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
logits_dist = logits_symm[:,:,:,:61].permute(0,3,1,2)
logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)
return logits_dist, logits_omega, logits_theta, logits_phi
class MaskedTokenNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
super(MaskedTokenNetwork, self).__init__()
#fd note this predicts probability for the mask token (which is never in ground truth)
# it should be ok though(?)
self.proj = nn.Linear(n_feat, NAATOKENS)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
B, N, L = x.shape[:3]
logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
return logits
class LDDTNetwork(nn.Module):
def __init__(self, n_feat, n_bin_lddt=50):
super(LDDTNetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_lddt)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
logits = self.proj(x) # (B, L, 50)
return logits.permute(0,2,1)
class PAENetwork(nn.Module):
def __init__(self, n_feat, n_bin_pae=64):
super(PAENetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_pae)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
logits = self.proj(x) # (B, L, L, 64)
return logits.permute(0,3,1,2)

View File

@@ -1,283 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from util import *
from util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
from Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
from Track_module import PairStr2Pair, PositionalEncoding2D
from chemical import NAATOKENS,NTOTALDOFS, NBTYPES
# Module contains classes and functions to generate initial embeddings
class MSA_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
super(MSA_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
self.emb_left = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_right = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_state = nn.Embedding(NAATOKENS, d_state)
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
maxpos_atom=maxpos_atom, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
self.emb_q = init_lecun_normal(self.emb_q)
self.emb_left = init_lecun_normal(self.emb_left)
self.emb_right = init_lecun_normal(self.emb_right)
self.emb_state = init_lecun_normal(self.emb_state)
nn.init.zeros_(self.emb.bias)
def forward(self, msa, seq1hot, idx, bond_feats, same_chain):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# - bond_feats: Bond features (B, L, L)
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
# - pair: Initial Pair embedding (B, L, L, d_pair)
N = msa.shape[1] # number of sequenes in MSA
# msa embedding
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
tmp = (seq1hot @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
#msa = self.drop(msa)
# pair embedding
left = (seq1hot @ self.emb_left.weight)[:,None] # (B, 1, L, d_pair)
right = (seq1hot @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
pair = pair + self.pos(seq1hot.argmax(-1), idx, bond_feats, same_chain) # add relative position
# state embedding
state = (seq1hot @ self.emb_state.weight)
return msa, pair, state
class Extra_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_init=NAATOKENS+1+2, p_drop=0.1):
super(Extra_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence
#self.drop = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, msa, seq1hot, idx):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
N = msa.shape[1] # number of sequenes in MSA
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
seq_emb = (seq1hot @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
msa = msa + seq_emb.expand(-1, N, -1, -1) # adding query embedding to MSA
#return self.drop(msa)
return (msa)
class Bond_emb(nn.Module):
def __init__(self, d_pair=128, d_init=NBTYPES):
super(Bond_emb, self).__init__()
self.emb = nn.Linear(d_init, d_pair)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, bond_feats):
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
return self.emb(bond_feats.float())
class TemplatePairStack(nn.Module):
def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, d_t1d=22, d_state=32, p_drop=0.25):
super(TemplatePairStack, self).__init__()
self.n_block = n_block
self.proj_t1d = nn.Linear(d_t1d, d_state)
proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, d_state=d_state, p_drop=p_drop) for i in range(n_block)]
self.block = nn.ModuleList(proc_s)
self.norm = nn.LayerNorm(d_templ)
self.reset_parameter()
def reset_parameter(self):
self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)
def forward(self, templ, rbf_feat, t1d, use_checkpoint=False):
B, T, L = templ.shape[:3]
templ = templ.reshape(B*T, L, L, -1)
t1d = t1d.reshape(B*T, L, -1)
state = self.proj_t1d(t1d)
for i_block in range(self.n_block):
if use_checkpoint:
templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ,
rbf_feat, state)
else:
templ = self.block[i_block](templ, rbf_feat, state)
return self.norm(templ).reshape(B, T, L, L, -1)
class Templ_emb(nn.Module):
# Get template embedding
# Features are
# t2d:
# - 61 distogram bins + 6 orientations (67)
# - Mask (missing/unaligned) (1)
# t1d:
# - tiled AA sequence (20 standard aa + gap)
# - confidence (1)
#
def __init__(self, d_t1d=(NAATOKENS-1)+1, d_t2d=67+1, d_tor=3*NTOTALDOFS, d_pair=128, d_state=32,
n_block=2, d_templ=64,
n_head=4, d_hidden=16, p_drop=0.25):
super(Templ_emb, self).__init__()
# process 2D features
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
d_hidden=d_hidden, d_t1d=d_t1d, d_state=d_state, p_drop=p_drop)
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
# process torsion angles
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
self.proj_t1d = nn.Linear(d_templ, d_templ)
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
# d_hidden=d_hidden, p_drop=p_drop)
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
nn.init.zeros_(self.emb_t1d.bias)
self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)
def _get_templ_emb(self, t1d, t2d):
B, T, L, _ = t1d.shape
# Prepare 2D template features
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
#
templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88)
return self.emb(templ) # Template templures (B, T, L, L, d_templ)
def _get_templ_rbf(self, xyz_t, mask_t):
B, T, L = xyz_t.shape[:3]
# process each template features
xyz_t = xyz_t.reshape(B*T, L, 3).contiguous()
mask_t = mask_t.reshape(B*T, L, L)
assert(xyz_t.is_contiguous())
rbf_feat = rbf(torch.cdist(xyz_t, xyz_t)) * mask_t[...,None] # (B*T, L, L, d_rbf)
return rbf_feat
def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False):
# Input
# - t1d: 1D template info (B, T, L, 30)
# - t2d: 2D template info (B, T, L, L, 44)
# - alpha_t: torsion angle info (B, T, L, 30) - DOUBLE-CHECK
# - xyz_t: template CA coordinates (B, T, L, 3)
# - mask_t: is valid residue pair? (B, T, L, L)
# - pair: query pair features (B, L, L, d_pair)
# - state: query state features (B, L, d_state)
B, T, L, _ = t1d.shape
templ = self._get_templ_emb(t1d, t2d)
rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
# process each template pair feature
templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint) # (B, T, L,L, d_templ)
# Prepare 1D template torsion angle features
t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17)
# process each template features
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
# mixing query state features to template state features
state = state.reshape(B*L, 1, -1)
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d)
out = out.reshape(B, L, -1)
else:
out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
state = state.reshape(B, L, -1)
state = state + out
# mixing query pair features to template information (Template pointwise attention)
pair = pair.reshape(B*L*L, 1, -1)
templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ)
out = out.reshape(B, L, L, -1)
else:
out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
#
pair = pair.reshape(B, L, L, -1)
pair = pair + out
return pair, state
class Recycling(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
super(Recycling, self).__init__()
self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_state = nn.LayerNorm(d_state)
self.reset_parameter()
def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
self.proj_sctors = init_lecun_normal(self.proj_sctors)
nn.init.zeros_(self.proj_sctors.bias)
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = pair.shape[:2]
state = self.norm_state(state)
left = state.unsqueeze(2).expand(-1,-1,L,-1)
right = state.unsqueeze(1).expand(-1,L,-1,-1)
Ca_or_P = xyz[:,:,1].contiguous()
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
if mask_recycle != None:
dist = mask_recycle[...,None].float()*dist
dist = torch.cat((dist, left, right), dim=-1)
dist = self.proj_dist(dist)
pair = dist + self.norm_pair(pair)
sctors = self.proj_sctors(sctors.reshape(B,-1,2*NTOTALDOFS))
msa = sctors + self.norm_msa(msa)
return msa, pair, state

View File

@@ -1,139 +0,0 @@
import torch
import torch.nn as nn
from Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, Recycling
from Track_module import IterativeSimulator
from AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, LDDTNetwork, PAENetwork
from chemical import INIT_CRDS,NAATOKENS, NBTYPES
from util import Ls_from_same_chain_2d
from data_loader import get_term_feats
class RoseTTAFoldModule(nn.Module):
def __init__(
self, n_extra_block=4, n_main_block=8, n_ref_block=4, n_finetune_block=0,\
d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
n_head_msa=8, n_head_pair=4, n_head_templ=4,
d_hidden=32, d_hidden_templ=64, p_drop=0.15,
SE3_param={}, SE3_ref_param={},
atom_type_index=None, aamask=None, ljlk_parameters=None, lj_correction_parameters=None,
cb_len=None, cb_ang=None, cb_tor=None,
num_bonds=None, lj_lin=0.6, use_extra_l1=True, use_atom_frames=True
):
super(RoseTTAFoldModule, self).__init__()
#
# Input Embeddings
d_state = SE3_param['l0_out_features']
self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=NAATOKENS-1+4, p_drop=p_drop)
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)
self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, n_head=n_head_templ,
d_hidden=d_hidden_templ, p_drop=0.25)
# Update inputs with outputs from previous round
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
#
self.simulator = IterativeSimulator(
n_extra_block=n_extra_block,
n_main_block=n_main_block,
n_ref_block=n_ref_block,
n_finetune_block=n_finetune_block,
d_msa=d_msa,
d_msa_full=d_msa_full,
d_pair=d_pair,
d_hidden=d_hidden,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
SE3_param=SE3_param,
SE3_ref_param=SE3_ref_param,
p_drop=p_drop,
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
aamask=aamask,
ljlk_parameters=ljlk_parameters,
lj_correction_parameters=lj_correction_parameters,
num_bonds=num_bonds,
cb_len=cb_len,
cb_ang=cb_ang,
cb_tor=cb_tor,
lj_lin=lj_lin,
use_extra_l1=use_extra_l1,
)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.lddt_pred = LDDTNetwork(d_state)
if use_extra_l1: # extra l1 features introduced at the same time as pAE/pDE heads
self.pae_pred = PAENetwork(d_pair)
self.pde_pred = PAENetwork(d_pair) # distance error, but use same architecture as aligned error
self.use_extra_l1 = use_extra_l1
self.use_atom_frames = use_atom_frames
def forward(
self, msa_one_hot, seq_unmasked, xyz, sctors, idx, bond_feats, chirals,
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None,
return_raw=False, return_full=False,
use_checkpoint=False
):
B, N, L = msa_one_hot.shape[:3]
seq1hot = msa_one_hot[:,0]
# generate input msa features
Ls = Ls_from_same_chain_2d(same_chain)
term_feats = get_term_feats(L,Ls).to(msa_one_hot.device)
msa_feat = torch.cat([msa_one_hot,
msa_one_hot,
torch.zeros(B,N,L,2).to(msa_one_hot.device),
term_feats[None,None].expand(B,N,-1,-1)], dim=3)
extra_feat = torch.cat([msa_one_hot,
torch.zeros(B,N,L,1).to(msa_one_hot.device),
term_feats[None,None].expand(B,N,-1,-1)], dim=3)
# Get embeddings
msa_latent, pair, state = self.latent_emb(msa_feat, seq1hot, idx, bond_feats, same_chain)
msa_full = self.full_emb(extra_feat, seq1hot, idx)
pair = pair + self.bond_emb(bond_feats)
#
# Do recycling
if msa_prev == None:
msa_prev = torch.zeros_like(msa_latent[:,0])
pair_prev = torch.zeros_like(pair)
state_prev = torch.zeros_like(state)
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
state = state + state_recycle
# add template embedding
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint)
# Predict coordinates from given inputs
msa, pair, xyz, alpha_s, xyz_allatom, state = self.simulator(
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx, bond_feats, same_chain, chirals, atom_frames, use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames)
if return_raw:
# get last structure
xyz_last = xyz_allatom[-1].unsqueeze(0)
return msa[:,0], pair, xyz_last, state, alpha_s[-1], None
# predict masked amino acids
logits_aa = self.aa_pred(msa)
# predict distogram & orientograms
logits = self.c6d_pred(pair)
# Predict LDDT
lddt = self.lddt_pred(state)
# predict aligned error and distance error
if self.use_extra_l1:
logits_pae = self.pae_pred(pair)
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
else:
logits_pae = None
logits_pde = None
return logits, logits_aa, logits_pae, logits_pde, xyz, alpha_s, xyz_allatom, \
lddt, msa[:,0], pair, state

View File

@@ -1,88 +0,0 @@
import torch
import torch.nn as nn
import sys, os
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
sys.path.insert(0,script_dir+'../../../SE3Transformer/')
from util_module import init_lecun_normal_param
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
class SE3TransformerWrapper(nn.Module):
"""SE(3) equivariant GCN with attention"""
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32):
super().__init__()
# Build the network
self.l1_in = l1_in_features
self.l1_out = l1_out_features
#
fiber_edge = Fiber({0: num_edge_features})
if l1_out_features > 0:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
self.se3 = SE3Transformer(num_layers=num_layers,
fiber_in=fiber_in,
fiber_hidden=fiber_hidden,
fiber_out = fiber_out,
num_heads=n_heads,
channels_div=div,
fiber_edge=fiber_edge,
populate_edge="arcsin",
final_layer="lin",
use_layer_norm=True)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
if self.l1_out > 0:
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
node_features = {'0': type_0_features, '1': type_1_features}
else:
node_features = {'0': type_0_features}
edge_features = {'0': edge_features}
return self.se3(G, node_features, edge_features)

View File

@@ -1,788 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rf2aa.util_module import *
from rf2aa.Attention_module import *
from rf2aa.SE3_network import SE3TransformerWrapper
from rf2aa.resnet import ResidualNetwork
from rf2aa.util import INIT_CRDS, is_atom, xyz_frame_from_rotation_mask
from rf2aa.loss import (
calc_BB_bond_geom_grads, calc_lj_grads, calc_hb_grads, calc_cart_bonded_grads, calc_ljallatom_grads,
calc_lj, calc_cart_bonded, calc_chiral_grads
)
from rf2aa.chemical import NTOTALDOFS
# Components for three-track blocks
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
# 2. Pair -> Pair update (biased attention. bias from structure)
# 3. MSA -> Pair update (extract coevolution signal)
# 4. Str -> Str update (node from MSA, edge from Pair)
class PositionalEncoding2D(nn.Module):
# Add relative positional encoding to pair features
def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
super(PositionalEncoding2D, self).__init__()
self.minpos = minpos
self.maxpos = maxpos
self.maxpos_atom = maxpos_atom
self.nbin_res = abs(minpos)+maxpos+2 # include 0 and "unknown" value (maxpos+1)
self.nbin_atom = maxpos_atom+2 # include 0 and "unknown" token (maxpos_sm + 1)
self.emb_res = nn.Embedding(self.nbin_res, d_pair)
self.emb_atom = nn.Embedding(self.nbin_atom, d_pair)
self.emb_chain = nn.Embedding(2, d_pair)
def forward(self, seq, idx, bond_feats, same_chain=None):
sm_mask = is_atom(seq[0])
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, sm_mask,
minpos_res=self.minpos, maxpos_res=self.maxpos, maxpos_atom=self.maxpos_atom)
bins = torch.arange(self.minpos, self.maxpos+1, device=seq.device)
ib_res = torch.bucketize(res_dist, bins).long() # (B, L, L)
emb_res = self.emb_res(ib_res) #(B, L, L, d_pair)
bins = torch.arange(0, self.maxpos_atom+1, device=seq.device)
ib_atom = torch.bucketize(atom_dist, bins).long() # (B, L, L)
emb_atom = self.emb_atom(ib_atom) #(B, L, L, d_pair)
out = emb_res + emb_atom
if same_chain is not None:
emb_c = self.emb_chain(same_chain.long()) # this is used for MSA_emb but not in IterBlock
out += emb_c
return out
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, d_rbf=64,
d_hidden=32, p_drop=0.15, use_global_attn=False):
super(MSAPairStr2MSA, self).__init__()
self.norm_pair = nn.LayerNorm(d_pair)
self.emb_rbf = nn.Linear(d_rbf, d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.proj_state = nn.Linear(d_state, d_msa)
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
n_head=n_head, d_hidden=d_hidden)
if use_global_attn:
self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
else:
self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)
# Do proper initialization
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distrib
self.emb_rbf= init_lecun_normal(self.emb_rbf)
self.proj_state = init_lecun_normal(self.proj_state)
# initialize bias to zeros
nn.init.zeros_(self.emb_rbf.bias)
nn.init.zeros_(self.proj_state.bias)
def forward(self, msa, pair, rbf_feat, state):
'''
Inputs:
- msa: MSA feature (B, N, L, d_msa)
- pair: Pair feature (B, L, L, d_pair)
- rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)
- xyz: xyz coordinates (B, L, n_atom, 3)
- state: updated node features after SE(3)-Transformer layer (B, L, d_state)
Output:
- msa: Updated MSA feature (B, N, L, d_msa)
'''
B, N, L = msa.shape[:3]
# prepare input bias feature by combining pair & coordinate info
pair = self.norm_pair(pair)
pair = pair + self.emb_rbf(rbf_feat)
#
# update query sequence feature (first sequence in the MSA) with feedbacks (state) from SE3
state = self.norm_state(state)
state = self.proj_state(state).reshape(B, 1, L, -1)
msa = msa.type_as(state)
msa = msa.index_add(1, torch.tensor([0,], device=state.device), state)
#
# Apply row/column attention to msa & transform
msa = msa + self.drop_row(self.row_attn(msa, pair))
msa = msa + self.col_attn(msa)
msa = msa + self.ff(msa)
return msa
class PairStr2Pair(nn.Module):
def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_hidden_state=16, d_rbf=64, d_state=32, p_drop=0.15):
super(PairStr2Pair, self).__init__()
self.norm_state = nn.LayerNorm(d_state)
self.proj_left = nn.Linear(d_state, d_hidden_state)
self.proj_right = nn.Linear(d_state, d_hidden_state)
self.to_gate = nn.Linear(d_hidden_state*d_hidden_state, d_pair)
self.emb_rbf = nn.Linear(d_rbf, d_pair)
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
self.tri_mul_out = TriangleMultiplication(d_pair, d_hidden=d_hidden)
self.tri_mul_in = TriangleMultiplication(d_pair, d_hidden, outgoing=False)
self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)
self.ff = FeedForwardLayer(d_pair, 2)
self.reset_parameter()
def reset_parameter(self):
self.emb_rbf = init_lecun_normal(self.emb_rbf)
nn.init.zeros_(self.emb_rbf.bias)
self.proj_left = init_lecun_normal(self.proj_left)
nn.init.zeros_(self.proj_left.bias)
self.proj_right = init_lecun_normal(self.proj_right)
nn.init.zeros_(self.proj_right.bias)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_gate.weight)
nn.init.ones_(self.to_gate.bias)
def forward(self, pair, rbf_feat, state):
B, L = pair.shape[:2]
rbf_feat = self.emb_rbf(rbf_feat)
state = self.norm_state(state)
left = self.proj_left(state)
right = self.proj_right(state)
gate = einsum('bli,bmj->blmij', left, right).reshape(B,L,L,-1)
gate = torch.sigmoid(self.to_gate(gate))
rbf_feat = gate*rbf_feat
pair = pair + self.drop_row(self.tri_mul_out(pair))
pair = pair + self.drop_row(self.tri_mul_in(pair))
pair = pair + self.drop_row(self.row_attn(pair, rbf_feat))
pair = pair + self.drop_col(self.col_attn(pair, rbf_feat))
pair = pair + self.ff(pair)
return pair
class MSA2Pair(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
super(MSA2Pair, self).__init__()
self.norm = nn.LayerNorm(d_msa)
self.proj_left = nn.Linear(d_msa, d_hidden)
self.proj_right = nn.Linear(d_msa, d_hidden)
self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.proj_left = init_lecun_normal(self.proj_left)
self.proj_right = init_lecun_normal(self.proj_right)
nn.init.zeros_(self.proj_left.bias)
nn.init.zeros_(self.proj_right.bias)
# zero initialize output
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
def forward(self, msa, pair):
B, N, L = msa.shape[:3]
msa = self.norm(msa)
left = self.proj_left(msa)
right = self.proj_right(msa)
right = right / float(N)
out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
out = self.proj_out(out)
pair = pair + out
return pair
class Str2Str(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=16, d_rbf=64,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
nextra_l0=0, nextra_l1=0, p_drop=0.1
):
super(Str2Str, self).__init__()
# initial node & pair feature process
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.embed_node = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
self.ff_node = FeedForwardLayer(SE3_param['l0_in_features'], 2, p_drop=p_drop)
self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
self.embed_edge = nn.Linear(d_pair+d_rbf+1, SE3_param['num_edge_features'])
self.ff_edge = FeedForwardLayer(SE3_param['num_edge_features'], 2, p_drop=p_drop)
self.norm_edge = nn.LayerNorm(SE3_param['num_edge_features'])
SE3_param_temp = SE3_param.copy()
SE3_param_temp['l0_in_features'] += nextra_l0
SE3_param_temp['l1_in_features'] += nextra_l1
self.se3 = SE3TransformerWrapper(**SE3_param_temp)
self.sc_predictor = SCPred(
d_msa=d_msa,
d_state=SE3_param['l0_out_features'],
p_drop=p_drop)
self.nextra_l0 = nextra_l0
self.nextra_l1 = nextra_l1
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.embed_node = init_lecun_normal(self.embed_node)
self.embed_edge = init_lecun_normal(self.embed_edge)
# initialize bias to zeros
nn.init.zeros_(self.embed_node.bias)
nn.init.zeros_(self.embed_edge.bias)
@torch.cuda.amp.autocast(enabled=False)
def forward(self, msa, pair, xyz, state, idx, rotation_mask, bond_feats, atom_frames, extra_l0=None, extra_l1=None, use_atom_frames=True, top_k=128, eps=1e-5):
# process msa & pair features
B, N, L = msa.shape[:3]
seq = self.norm_msa(msa[:,0])
pair = self.norm_pair(pair)
state = self.norm_state(state)
node = torch.cat((seq, state), dim=-1)
node = self.embed_node(node)
node = node + self.ff_node(node)
node = self.norm_node(node)
neighbor = get_seqsep_protein_sm(idx, bond_feats, rotation_mask)
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas))
edge = torch.cat((pair, rbf_feat, neighbor), dim=-1)
edge = self.embed_edge(edge)
edge = edge + self.ff_edge(edge)
edge = self.norm_edge(edge)
# define graph
if top_k != 0:
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=top_k)
else:
G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx)
if use_atom_frames: # ligand l1 features are vectors to neighboring atoms
xyz_frame = xyz_frame_from_rotation_mask(xyz, rotation_mask, atom_frames)
l1_feats = xyz_frame - xyz_frame[:,:,1,:].unsqueeze(2)
else: # old (incorrect) behavior: vectors to random initial coords of virtual N and C
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
l1_feats = l1_feats.reshape(B*L, -1, 3)
if extra_l1 is not None:
l1_feats = torch.cat( (l1_feats,extra_l1), dim=1 )
if extra_l0 is not None:
node = torch.cat( (node,extra_l0), dim=2 )
# apply SE(3) Transformer & update coordinates
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
state = shift['0'].reshape(B, L, -1) # (B, L, C)
offset = shift['1'].reshape(B, L, 2, 3)
T = offset[:,:,0,:] / 10.0
R = offset[:,:,1,:] / 100.0
Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
v = xyz - xyz[:,:,1:2,:]
Rout = torch.zeros((B,L,3,3), device=xyz.device)
Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
Rout = torch.where(rotation_mask.reshape(B, L, 1,1), I, Rout)
xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]
alpha = self.sc_predictor(msa[:,0], state)
return xyz, state, alpha
class Allatom2Allatom(nn.Module):
def __init__(
self,
SE3_param
):
super(Allatom2Allatom, self).__init__()
self.se3 = SE3TransformerWrapper(**SE3_param)
@torch.cuda.amp.autocast(enabled=False)
def forward(self, seq, xyz, aamask, num_bonds, state, grads, top_k=24, eps=1e-5):
# seq (B,L)
# xyz (B,L,27,3)
# aamask (22,27) [per-amino-acid]
# num_bonds (22,27,27) [per-amino-acid]
# state (N,B,L,K) [K channels]
# grads (N,B,L,27,3) [N terms]
B, L = xyz.shape[:2]
mask = aamask[seq]
G, edge = make_atom_graph( xyz, mask, num_bonds[seq], top_k, maxbonds=4 )
node = state[mask]
node_l1 = grads[:,mask].permute(1,0,2)
# apply SE(3) Transformer & update coordinates
shift = self.se3(G, node[...,None], node_l1, edge)
state[mask] = shift['0'][...,0]
xyz[mask] = xyz[mask] + shift['1'].squeeze(1) / 100.0
return xyz, state
class AllatomEmbed(nn.Module):
def __init__(
self,
d_state_in=64,
d_state_out=32,
p_mask=0.15
):
super(AllatomEmbed, self).__init__()
self.p_mask = p_mask
# initial node & pair feature process
self.compress_embed = nn.Linear(d_state_in + 29, d_state_out) # 29->5 if using element
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state, seq, eltmap):
B,L = state.shape[:2]
mask = torch.rand(B,L) < self.p_mask
state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
state[mask] = 0.0
elements = F.one_hot(eltmap[seq], num_classes=29) # 29->5 if using element
state = self.compress_embed(
torch.cat( (state,elements), dim=-1 )
)
state = self.norm_state( state )
return state
# embed residue state + atomtype -> per-atom state
#
class AllatomEmbed(nn.Module):
def __init__(
self,
d_state_in=64,
d_state_out=32,
p_mask=0.15
):
super(AllatomEmbed, self).__init__()
self.p_mask = p_mask
# initial node & pair feature process
self.compress_embed = nn.Linear(d_state_in + 29, d_state_out) # 29->5 if using element
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state, seq, eltmap):
B,L = state.shape[:2]
mask = torch.rand(B,L) < self.p_mask
state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
state[mask] = 0.0
elements = F.one_hot(eltmap[seq], num_classes=29) # 29->5 if using element
state = self.compress_embed(
torch.cat( (state,elements), dim=-1 )
)
state = self.norm_state( state )
return state
# embed per-atom state -> residue state
class ResidueEmbed(nn.Module):
def __init__(
self,
d_state_in=16,
d_state_out=64
):
super(ResidueEmbed, self).__init__()
self.compress_embed = nn.Linear(27*d_state_in, d_state_out)
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state):
B,L = state.shape[:2]
state = self.compress_embed( state.reshape(B,L,-1) )
state = self.norm_state( state )
return state
class SCPred(nn.Module):
def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
super(SCPred, self).__init__()
self.norm_s0 = nn.LayerNorm(d_msa)
self.norm_si = nn.LayerNorm(d_state)
self.linear_s0 = nn.Linear(d_msa, d_hidden)
self.linear_si = nn.Linear(d_state, d_hidden)
# ResNet layers
self.linear_1 = nn.Linear(d_hidden, d_hidden)
self.linear_2 = nn.Linear(d_hidden, d_hidden)
self.linear_3 = nn.Linear(d_hidden, d_hidden)
self.linear_4 = nn.Linear(d_hidden, d_hidden)
# Final outputs
self.linear_out = nn.Linear(d_hidden, 2*NTOTALDOFS)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.linear_s0 = init_lecun_normal(self.linear_s0)
self.linear_si = init_lecun_normal(self.linear_si)
self.linear_out = init_lecun_normal(self.linear_out)
nn.init.zeros_(self.linear_s0.bias)
nn.init.zeros_(self.linear_si.bias)
nn.init.zeros_(self.linear_out.bias)
# right before relu activation: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_1.bias)
nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_3.bias)
# right before residual connection: zero initialize
nn.init.zeros_(self.linear_2.weight)
nn.init.zeros_(self.linear_2.bias)
nn.init.zeros_(self.linear_4.weight)
nn.init.zeros_(self.linear_4.bias)
def forward(self, seq, state):
'''
Predict side-chain torsion angles along with backbone torsions
Inputs:
- seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
- state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
Outputs:
- si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
'''
B, L = seq.shape[:2]
seq = self.norm_s0(seq)
state = self.norm_si(state)
si = self.linear_s0(seq) + self.linear_si(state)
si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))
si = self.linear_out(F.relu_(si))
return si.view(B, L, NTOTALDOFS, 2)
class IterBlock(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_rbf=64,
n_head_msa=8, n_head_pair=4,
use_global_attn=False,
d_hidden=32, d_hidden_msa=None,
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
nextra_l0=0, nextra_l1=0):
super(IterBlock, self).__init__()
if d_hidden_msa == None:
d_hidden_msa = d_hidden
self.pos = PositionalEncoding2D(d_rbf, minpos=minpos, maxpos=maxpos,
maxpos_atom=maxpos_atom, p_drop=p_drop)
self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
n_head=n_head_msa,
d_state=SE3_param['l0_out_features'],
use_global_attn=use_global_attn,
d_hidden=d_hidden_msa, p_drop=p_drop)
self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
d_hidden=d_hidden//2, p_drop=p_drop)
self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair, d_rbf=d_rbf,
d_state=SE3_param['l0_out_features'],
d_hidden=d_hidden, p_drop=p_drop)
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
d_state=SE3_param['l0_out_features'],
SE3_param=SE3_param,
p_drop=p_drop,
nextra_l0=nextra_l0,
nextra_l1=nextra_l1)
def forward(self, msa, pair, xyz, state, seq_unmasked, idx, bond_feats, same_chain, use_checkpoint=False, top_k=128, rotation_mask=None, atom_frames=None, extra_l0=None, extra_l1=None, use_atom_frames=True):
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, same_chain)
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat, state)
xyz, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=top_k),
msa.float(), pair.float(), xyz.detach().float(), state.float(), idx, rotation_mask, bond_feats, atom_frames, extra_l0, extra_l1, use_atom_frames)
else:
msa = self.msa2msa(msa, pair, rbf_feat, state)
pair = self.msa2pair(msa, pair)
pair = self.pair2pair(pair, rbf_feat, state)
xyz, state, alpha = self.str2str(msa.float(), pair.float(), xyz.detach().float(), state.float(), idx, rotation_mask, bond_feats, atom_frames, extra_l0, extra_l1, use_atom_frames, top_k=top_k)
return msa, pair, xyz, state, alpha
class IterativeSimulator(nn.Module):
def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, n_finetune_block=0,
d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
n_head_msa=8, n_head_pair=4,
SE3_param={}, SE3_ref_param={}, p_drop=0.15,
atom_type_index=None, aamask=None,
ljlk_parameters=None, lj_correction_parameters=None,
cb_len=None, cb_ang=None, cb_tor=None,
num_bonds=None, lj_lin=0.6, use_extra_l1=True
):
super(IterativeSimulator, self).__init__()
self.n_extra_block = n_extra_block
self.n_main_block = n_main_block
self.n_ref_block = n_ref_block
self.n_finetune_block = n_finetune_block
self.atom_type_index = atom_type_index
self.aamask = aamask
self.ljlk_parameters = ljlk_parameters
self.lj_correction_parameters = lj_correction_parameters
self.num_bonds = num_bonds
self.lj_lin = lj_lin
self.cb_len = cb_len
self.cb_ang = cb_ang
self.cb_tor = cb_tor
self.use_extra_l1 = use_extra_l1 # set to False to not use chiral & LJ grads
# Update with extra sequences
if n_extra_block > 0:
self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
d_hidden_msa=8,
d_hidden=d_hidden,
p_drop=p_drop,
use_global_attn=True,
SE3_param=SE3_param,
nextra_l1=3 if self.use_extra_l1 else 0)
for i in range(n_extra_block)])
# Update with seed sequences
if n_main_block > 0:
self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
d_hidden=d_hidden,
p_drop=p_drop,
use_global_attn=False,
SE3_param=SE3_param,
nextra_l1=3 if self.use_extra_l1 else 0)
for i in range(n_main_block)])
# Final SE(3) refinement
if n_ref_block > 0:
self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
d_state=SE3_param['l0_out_features'],
SE3_param=SE3_ref_param,
p_drop=p_drop,
nextra_l0=2*NTOTALDOFS if self.use_extra_l1 else 0,
nextra_l1=6 if self.use_extra_l1 else 0
)
# Fine-tuning all-atom SE(3) refinement
if n_finetune_block > 0:
d_state=16
self.allatom_embed = AllatomEmbed(
d_state_in = SE3_param['l0_out_features'],
d_state_out = d_state,
p_mask = 0.15
)
self.finetune_refiner = Allatom2Allatom(
SE3_param = {
'num_layers':1,
'num_channels':16,
'num_degrees':2,
'l0_in_features':d_state,
'l0_out_features':d_state,
'l1_in_features':2,
'l1_out_features':1,
'num_edge_features':4,
'n_heads':4,
'div':2,
}
)
self.residue_embed = ResidueEmbed(
d_state_in = d_state,
d_state_out = SE3_param['l0_out_features']
)
# To get all-atom coordinates
self.compute_allatom_coords = ComputeAllAtomCoords()
def forward(self, seq_unmasked, msa, msa_full, pair, xyz, state, idx, bond_feats, same_chain, chirals, atom_frames=None, use_checkpoint=False, use_atom_frames=True):
# input:
# msa: initial MSA embeddings (N, L, d_msa)
# pair: initial residue pair embeddings (L, L, d_pair)
rotation_mask = is_atom(seq_unmasked)
xyz_s = list()
alpha_s = list()
for i_m in range(self.n_extra_block):
extra_l0 = None
extra_l1 = None
if self.use_extra_l1:
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l1 = dchiraldxyz[0].detach()
msa_full, pair, xyz, state, alpha = self.extra_block[i_m](msa_full, pair,
xyz, state, seq_unmasked, idx, bond_feats,
same_chain,
use_checkpoint=use_checkpoint,
top_k=0, rotation_mask=rotation_mask,
atom_frames=atom_frames,
extra_l0=extra_l0,
extra_l1=extra_l1,
use_atom_frames=use_atom_frames)
xyz_s.append(xyz)
alpha_s.append(alpha)
for i_m in range(self.n_main_block):
extra_l0 = None
extra_l1 = None
if self.use_extra_l1:
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l1 = dchiraldxyz[0].detach()
msa, pair, xyz, state, alpha = self.main_block[i_m](msa, pair,
xyz, state, seq_unmasked, idx, bond_feats,
same_chain,
use_checkpoint=use_checkpoint,
top_k=0, rotation_mask=rotation_mask,
atom_frames=atom_frames,
extra_l0=extra_l0,
extra_l1=extra_l1,
use_atom_frames=use_atom_frames)
xyz_s.append(xyz)
alpha_s.append(alpha)
_, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha) # think about detach here...
# now use unmasked seq (no cross-talk for msa prediction)
for i_m in range(self.n_ref_block):
extra_l0 = None
extra_l1 = None
if self.use_extra_l1:
# dbonddxyz, = calc_BB_bond_geom_grads(seq_unmasked[0], xyz.detach(), idx)
dljdxyz, dljdalpha = calc_lj_grads(
seq_unmasked, xyz.detach(), alpha.detach(),
self.compute_allatom_coords, bond_feats,
self.aamask,
self.ljlk_parameters,
self.lj_correction_parameters,
self.num_bonds,
lj_lin=self.lj_lin)
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l0 = dljdalpha.reshape(1,-1,2*NTOTALDOFS).detach()
extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
xyz, state, alpha = self.str_refiner(
msa, pair, xyz.detach(), state, idx, rotation_mask, bond_feats, atom_frames,
extra_l0, extra_l1, top_k=128, use_atom_frames=use_atom_frames)
xyz_s.append(xyz)
alpha_s.append(alpha)
_, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha) # think about detach here...
xyzallatom_s = list()
xyzallatom_s.append(xyzallatom.clone())
if (self.n_finetune_block>0):
state = self.allatom_embed(state, seq_unmasked, self.atom_type_index)
for i_m in range(self.n_finetune_block):
# dbonddxyz, = calc_cart_bonded_grads(
# seq_unmasked, xyzallatom.detach(), idx,
# self.cb_len, self.cb_ang, self.cb_tor
# )
# dljdxyz, = calc_ljallatom_grads(
# seq_unmasked,
# xyzallatom.detach(),
# self.aamask,
# self.ljlk_parameters,
# self.lj_correction_parameters,
# self.num_bonds,
# lj_lin=self.lj_lin
# )
# extra_l1 = torch.stack((dbonddxyz.detach(), dljdxyz.detach()))
extra_l1 = None
xyzallatom, state = self.finetune_refiner(
seq_unmasked,
xyzallatom.detach().float(),
self.aamask,
self.num_bonds,
state,
extra_l1.float()
)
# cb_loss = calc_cart_bonded(
# seq_unmasked, xyzallatom.detach(), idx,
# self.cb_len, self.cb_ang, self.cb_tor
# )
# lj_loss = calc_lj(
# seq_unmasked[0],
# xyzallatom.detach(),
# self.aamask,
# self.ljlk_parameters,
# self.lj_correction_parameters,
# self.num_bonds,
# lj_lin=self.lj_lin
# )
xyzallatom_s.append(xyzallatom.clone())
state = self.residue_embed(state)
xyz = torch.stack(xyz_s, dim=0)
alpha_s = torch.stack(alpha_s, dim=0)
xyzallatom_s = torch.cat(xyzallatom_s, dim=0)
return msa, pair, xyz, alpha_s, xyzallatom_s, state

View File

@@ -1,44 +0,0 @@
import re
from collections import OrderedDict
import pandas as pd
def parse_training_log(filename):
headers = ['Local','Train','Monomer','Homo','Hetero','NA','NAfs','RNA','SM Compl']
records = []
with open(filename) as f:
for line in f:
if line.startswith("Header"):
# renaming a few things from an earlier version of training script
for src,tgt in [('# epochs','num_epochs'), ('processed','examples_seen_in_epoch'),
('examples in epoch','examples_per_epoch'), ('Max mem','max_mem'),
('seconds','time'), ('total_loss: loss','Total_loss: total_loss')]:
line = line.replace(src,tgt)
columns = re.findall('(\w+)',line)
for val in ['Batch','Time','Total_loss']:
if val in columns: columns.remove(val)
if any([line.startswith(h) for h in headers]):
values = [line.split(':')[0]]+[float(x) for x in re.findall('(\d+\.*\d*)',line)]
records.append(OrderedDict(zip(columns, values)))
df = pd.DataFrame.from_records(records)
df = df.drop_duplicates(['Header','epoch','examples_seen_in_epoch'])
df_s = []
offset = 0
for ep in df['epoch'].drop_duplicates():
tmp = df[df['epoch']==ep]
n_per_epoch = tmp['examples_per_epoch'].values[0]
tmp['example'] = tmp['examples_seen_in_epoch']+offset
offset += n_per_epoch
mask = tmp['Header']!='Local'
tmp.loc[mask,'example'] = offset
df_s.append(tmp)
df = pd.concat(df_s)
return df

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,78 +0,0 @@
import numpy as np
import scipy
import scipy.spatial
from rf2aa.chemical import generate_Cbeta
# calculate dihedral angles defined by 4 sets of points
def get_dihedrals(a, b, c, d):
b0 = -1.0*(b - a)
b1 = c - b
b2 = d - c
b1 /= np.linalg.norm(b1, axis=-1)[:,None]
v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
x = np.sum(v*w, axis=-1)
y = np.sum(np.cross(b1, v)*w, axis=-1)
return np.arctan2(y, x)
# calculate planar angles defined by 3 sets of points
def get_angles(a, b, c):
v = a - b
v /= np.linalg.norm(v, axis=-1)[:,None]
w = c - b
w /= np.linalg.norm(w, axis=-1)[:,None]
x = np.sum(v*w, axis=1)
#return np.arccos(x)
return np.arccos(np.clip(x, -1.0, 1.0))
# get 6d coordinates from x,y,z coords of N,Ca,C atoms
def get_coords6d(xyz, dmax):
nres = xyz.shape[1]
# three anchor atoms
N = xyz[0]
Ca = xyz[1]
C = xyz[2]
# recreate Cb given N,Ca,C
Cb = generate_Cbeta(N,Ca,C)
# fast neighbors search to collect all
# Cb-Cb pairs within dmax
kdCb = scipy.spatial.cKDTree(Cb)
indices = kdCb.query_ball_tree(kdCb, dmax)
# indices of contacting residues
idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T
idx0 = idx[0]
idx1 = idx[1]
# Cb-Cb distance matrix
dist6d = np.full((nres, nres),999.9, dtype=np.float32)
dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1)
# matrix of Ca-Cb-Cb-Ca dihedrals
omega6d = np.zeros((nres, nres), dtype=np.float32)
omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
# matrix of polar coord theta
theta6d = np.zeros((nres, nres), dtype=np.float32)
theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
# matrix of polar coord phi
phi6d = np.zeros((nres, nres), dtype=np.float32)
phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])
mask = np.zeros((nres, nres), dtype=np.float32)
mask[idx0, idx1] = 1.0
return dist6d, omega6d, theta6d, phi6d, mask

File diff suppressed because it is too large Load Diff

View File

@@ -1,91 +0,0 @@
#!/usr/bin/env python
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
'''
Created on Apr 30, 2014
@author: meiermark
'''
import sys
import mmap
from collections import namedtuple
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
def read_index(ffindex_filename):
entries = []
fh = open(ffindex_filename)
for line in fh:
tokens = line.split("\t")
entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
fh.close()
return entries
def read_data(ffdata_filename):
fh = open(ffdata_filename, "r+b")
data = mmap.mmap(fh.fileno(), 0)
fh.close()
return data
def get_entry_by_name(name, index):
#TODO: bsearch
for entry in index:
if(name == entry.name):
return entry
return None
def read_entry_lines(entry, data):
lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
return lines
def read_entry_data(entry, data):
return data[entry.offset:entry.offset + entry.length - 1]
def write_entry(entries, data_fh, entry_name, offset, data):
data_fh.write(data[:-1])
data_fh.write(bytearray(1))
entry = FFindexEntry(entry_name, offset, len(data))
entries.append(entry)
return offset + len(data)
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
with open(file_name, "rb") as fh:
data = bytearray(fh.read())
return write_entry(entries, data_fh, entry_name, offset, data)
def finish_db(entries, ffindex_filename, data_fh):
data_fh.close()
write_entries_to_db(entries, ffindex_filename)
def write_entries_to_db(entries, ffindex_filename):
sorted(entries, key=lambda x: x.name)
index_fh = open(ffindex_filename, "w")
for entry in entries:
index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
index_fh.close()
def write_entry_to_file(entry, data, file):
lines = read_lines(entry, data)
fh = open(file, "w")
for line in lines:
fh.write(line+"\n")
fh.close()

View File

@@ -1,292 +0,0 @@
import numpy as np
import torch
from openbabel import openbabel
from chemical import aachirals, NTOTAL, generate_Cbeta
PARAMS = {
'DMIN':1,
'DMID':4,
'DMAX':20.0,
'DBINS1':30,
'DBINS2':30,
'ABINS':36
}
# ============================================================
def get_pair_dist(a, b):
"""calculate pair distances between two sets of points
Parameters
----------
a,b : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of two sets of atoms
Returns
-------
dist : pytorch tensor of shape [batch,nres,nres]
stores paitwise distances between atoms in a and b
"""
dist = torch.cdist(a, b, p=2)
return dist
# ============================================================
def get_ang(a, b, c, eps=1e-6):
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
from Cartesian coordinates of three sets of atoms a,b,c
Parameters
----------
a,b,c : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of three sets of atoms
Returns
-------
ang : pytorch tensor of shape [batch,nres]
stores resulting planar angles
"""
v = a - b
w = c - b
vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps)
wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps)
vw = torch.sum(vn*wn, dim=-1)
return torch.acos(torch.clamp(vw,-0.999,0.999))
# ============================================================
def get_dih(a, b, c, d, eps=1e-6):
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
given Cartesian coordinates of four sets of atoms a,b,c,d
Parameters
----------
a,b,c,d : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of four sets of atoms
Returns
-------
dih : pytorch tensor of shape [batch,nres]
stores resulting dihedrals
"""
b0 = a - b
b1 = c - b
b2 = d - c
b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps)
v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n
w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n
x = torch.sum(v*w, dim=-1)
y = torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1)
return torch.atan2(y+eps, x+eps)
# ============================================================
def xyz_to_c6d(xyz, params=PARAMS):
"""convert cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz : pytorch tensor of shape [batch,nres,3,3]
stores Cartesian coordinates of backbone N,Ca,C atoms
Returns
-------
c6d : pytorch tensor of shape [batch,nres,nres,4]
stores stacked dist,omega,theta,phi 2D maps
"""
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
Cb = generate_Cbeta(N,Ca,C)
# 6d coordinates order: (dist,omega,theta,phi)
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
dist = get_pair_dist(Cb,Cb)
dist[torch.isnan(dist)] = 999.9
c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
b,i,j = torch.where(c6d[...,0]<params['DMAX'])
c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
# fix long-range distances
c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
c6d = torch.nan_to_num(c6d)
return c6d
def xyz_to_t2d(xyz_t, mask, params=PARAMS):
"""convert template cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
stores Cartesian coordinates of template backbone N,Ca,C atoms
mask : pytorch tensor [batch,templ,nres,nres]
indicates whether valid residue pairs or not
Returns
-------
t2d : pytorch tensor of shape [batch,nres,nres,37+6+3]
stores stacked dist,omega,theta,phi 2D maps
"""
B, T, L = xyz_t.shape[:3]
c6d = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params)
c6d = c6d.view(B, T, L, L, 4)
# dist to one-hot encoded
mask = mask[...,None]
dist = dist_to_onehot(c6d[...,0], params)*mask
orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6)
#
t2d = torch.cat((dist, orien, mask), dim=-1)
return t2d
def xyz_to_bbtor(xyz, params=PARAMS):
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
next_N = torch.roll(N, -1, dims=1)
prev_C = torch.roll(C, 1, dims=1)
phi = get_dih(prev_C, N, Ca, C)
psi = get_dih(N, Ca, C, next_N)
#
phi[:,0] = 0.0
psi[:,-1] = 0.0
#
astep = 2.0*np.pi / params['ABINS']
phi_bin = torch.round((phi+np.pi-astep/2)/astep)
psi_bin = torch.round((psi+np.pi-astep/2)/astep)
return torch.stack([phi_bin, psi_bin], axis=-1).long()
# ============================================================
def dist_to_onehot(dist, params=PARAMS):
db = dist_to_bins(dist, params)
dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS1'] + params['DBINS2']+1).float()
return dist
# ============================================================
def dist_to_bins(dist,params=PARAMS):
"""bin 2d distance maps
"""
dist[torch.isnan(dist)] = 999.9
dstep1 = (params['DMID'] - params['DMIN']) / params['DBINS1']
dstep2 = (params['DMAX'] - params['DMID']) / params['DBINS2']
dbins = torch.cat([
torch.linspace(params['DMIN']+dstep1, params['DMID'], params['DBINS1'],
dtype=dist.dtype,device=dist.device),
torch.linspace(params['DMID']+dstep2, params['DMAX'], params['DBINS2'],
dtype=dist.dtype,device=dist.device),
])
db = torch.bucketize(dist.contiguous(),dbins).long()
return db
# ============================================================
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
"""bin 2d distance and orientation maps
"""
db = dist_to_bins(c6d[...,0], params) # all dist < DMIN are in bin 0
astep = 2.0*np.pi / params['ABINS']
ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep)
tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep)
pb = torch.round((c6d[...,3]-astep/2)/astep)
# synchronize no-contact bins
params['DBINS'] = params['DBINS1'] + params['DBINS2']
ob[db==params['DBINS']] = params['ABINS']
tb[db==params['DBINS']] = params['ABINS']
pb[db==params['DBINS']] = params['ABINS']//2
if negative:
db = torch.where(same_chain.bool(), db.long(), params['DBINS'])
ob = torch.where(same_chain.bool(), ob.long(), params['ABINS'])
tb = torch.where(same_chain.bool(), tb.long(), params['ABINS'])
pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2)
return torch.stack([db,ob,tb,pb],axis=-1).long()
def get_chirals(obmol, xyz):
'''get all quadruples of atoms forming chiral centers
'''
# detect stereo centers
# stereo = [a for a in openbabel.OBMolAtomIter(obmol)
# if (a.GetHvyDegree()==3 or a.GetHvyDegree()==4) and a.GetHyb()!=2]
stereo = openbabel.OBStereoFacade(obmol)
stereo_atoms = [obmol.GetAtom(i+1) for i in range(obmol.NumAtoms()) if stereo.HasTetrahedralStereo(i)]
angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry
chirals = []
for o in stereo_atoms:
neigh = [b.GetIdx() for b in openbabel.OBAtomAtomIter(o)]
if len(neigh)==3:
chirals.append([o.GetIdx(),*neigh,angle])
elif len(neigh)==4:
a,b,c,d = neigh
chirals.extend([[o.GetIdx(),a,b,c,angle],
[o.GetIdx(),b,a,d,angle],
[o.GetIdx(),a,c,d,angle],
[o.GetIdx(),c,b,d,angle]])
n = len(chirals)
if n>0:
chirals = torch.tensor(chirals*3).float()
chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
chirals[:,:-1] -= 1
dih = get_dih(*xyz[chirals[:,:4].long()].split(split_size=1,dim=1))[:,0]
chirals[dih<0.0,-1] = -angle
else:
chirals = torch.Tensor()
return chirals
def get_atomize_protein_chirals(residues_atomize, lig_xyz, residue_atomize_mask, bond_feats):
"""
Enumerate chiral centers in residues and provide features for chiral centers
"""
angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry
chiral_atoms = aachirals[residues_atomize]
ra = residue_atomize_mask.nonzero()
r,a = ra.T
chiral_atoms = chiral_atoms[r,a].nonzero().squeeze(1) #num_chiral_centers
num_chiral_centers = chiral_atoms.shape[0]
chiral_bonds = bond_feats[chiral_atoms] # find bonds to each chiral atom
chiral_bonds_idx = chiral_bonds.nonzero() # find indices of each bonded neighbor to chiral atom
# in practice all chiral atoms in proteins have 3 heavy atom neighbors, so reshape to 3
chiral_bonds_idx = chiral_bonds_idx.reshape(num_chiral_centers, 3, 2)
chirals = torch.zeros((num_chiral_centers, 5))
chirals[:,0] = chiral_atoms.long()
chirals[:, 1:-1] = chiral_bonds_idx[...,-1].long()
chirals[:, -1] = angle
n = chirals.shape[0]
if n>0:
chirals = chirals.repeat(3,1).float()
chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
dih = get_dih(*lig_xyz[chirals[:,:4].long()].split(split_size=1,dim=1))[:,0]
chirals[dih<0.0,-1] = -angle
else:
chirals = torch.Tensor()
return chirals

View File

@@ -1,985 +0,0 @@
import torch
import numpy as np
import scipy
import networkx as nx
from util import (
rigid_from_3_points,
cb_lengths_CN,
cb_angles_CACN,
cb_angles_CNCA,
cb_torsions_CACNH,
cb_torsions_CANCO,
is_nucleic,
find_all_paths_of_length_n,
find_all_rigid_groups
)
from chemical import NFRAMES, NTOTAL
from kinematics import get_dih, get_ang
from scoring import HbHybType
# Loss functions for the training
# 1. BB rmsd loss
# 2. distance loss (or 6D loss?)
# 3. bond geometry loss
# 4. predicted lddt loss
#fd use improved coordinate frame generation
def get_t(N, Ca, C, eps=1e-5):
I,B,L=N.shape[:3]
Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), eps=eps)
Rs = Rs.view(I,B,L,3,3)
Ts = Ts.view(I,B,L,3)
t = Ts.unsqueeze(-2) - Ts.unsqueeze(-3)
return torch.einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3) **fixed
def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra=10.0, d_clamp_inter=30.0, A=10.0, gamma=0.99, eps=1e-6):
'''
Calculate Backbone FAPE loss
Input:
- pred: predicted coordinates (I, B, L, n_atom, 3)
- true: true coordinates (B, L, n_atom, 3)
Output: str loss
'''
I = pred.shape[0]
true = true.unsqueeze(0)
t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])
difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps)
clamp = torch.zeros_like(difference)
clamp[:,same_chain==1] = d_clamp_intra
clamp[:,same_chain==0] = d_clamp_inter
difference = torch.clamp(difference, max=clamp)
loss = difference / A # (I, B, L, L)
# Get a mask information (ignore missing residue + inter-chain residues)
# for positive cases, mask = mask_2d
# for negative cases (non-interacting pairs) mask = mask_2d*same_chain
if negative:
mask = mask_2d * same_chain
else:
mask = mask_2d
# calculate masked loss (ignore missing regions when calculate loss)
loss = (mask[None]*loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
# weighting loss
w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
w_loss = torch.flip(w_loss, (0,))
w_loss = w_loss / w_loss.sum()
tot_loss = (w_loss * loss).sum()
return tot_loss, loss.detach()
#resolve rotationally equivalent sidechains
def resolve_symmetry(xs, Rsnat_all, xsnat, Rsnat_all_alt, xsnat_alt, atm_mask):
dists = torch.linalg.norm( xs[:,:,None,:] - xs[atm_mask,:][None,None,:,:], dim=-1)
dists_nat = torch.linalg.norm( xsnat[:,:,None,:] - xsnat[atm_mask,:][None,None,:,:], dim=-1)
dists_natalt = torch.linalg.norm( xsnat_alt[:,:,None,:] - xsnat_alt[atm_mask,:][None,None,:,:], dim=-1)
drms_nat = torch.sum(torch.abs(dists_nat-dists),dim=(-1,-2))
drms_natalt = torch.sum(torch.abs(dists_nat-dists_natalt), dim=(-1,-2))
Rsnat_symm = Rsnat_all
xs_symm = xsnat
toflip = drms_natalt<drms_nat
Rsnat_symm[toflip,...] = Rsnat_all_alt[toflip,...]
xs_symm[toflip,...] = xsnat_alt[toflip,...]
return Rsnat_symm, xs_symm
# resolve "equivalent" natives
def resolve_equiv_natives(xs, natstack, maskstack):
if (len(natstack.shape)==4):
return natstack, maskstack
if (natstack.shape[1]==1):
return natstack[:,0,...], maskstack[:,0,...]
dx = torch.norm( xs[:,None,:,None,1,:]-xs[:,None,None,:,1,:], dim=-1)
dnat = torch.norm( natstack[:,:,:,None,1,:]-natstack[:,:,None,:,1,:], dim=-1)
delta = torch.sum( torch.abs(dnat-dx), dim=(-2,-1))
return natstack[:,torch.argmin(delta),...], maskstack[:,torch.argmin(delta),...]
#torsion angle predictor loss
def torsionAngleLoss( alpha, alphanat, alphanat_alt, tors_mask, tors_planar, eps=1e-8 ):
I = alpha.shape[0]
lnat = torch.sqrt( torch.sum( torch.square(alpha), dim=-1 ) + eps )
anorm = alpha / (lnat[...,None])
l_tors_ij = torch.min(
torch.sum(torch.square( anorm - alphanat[None] ),dim=-1),
torch.sum(torch.square( anorm - alphanat_alt[None] ),dim=-1)
)
l_tors = torch.sum( l_tors_ij*tors_mask[None] ) / (torch.sum( tors_mask )*I + eps)
l_norm = torch.sum( torch.abs(lnat-1.0)*tors_mask[None] ) / (torch.sum( tors_mask )*I + eps)
l_planar = torch.sum( torch.abs( alpha[...,0] )*tors_planar[None] ) / (torch.sum( tors_planar )*I + eps)
return l_tors+0.02*l_norm+0.02*l_planar
def compute_FAPE(Rs, Ts, xs, Rsnat, Tsnat, xsnat, Z=10.0, dclamp=10.0, eps=1e-4):
xij = torch.einsum('rji,rsj->rsi', Rs, xs[None,...] - Ts[:,None,...])
xij_t = torch.einsum('rji,rsj->rsi', Rsnat, xsnat[None,...] - Tsnat[:,None,...])
#torch.norm(xij-xij_t,dim=-1)
diff = torch.sqrt( torch.sum( torch.square(xij-xij_t), dim=-1 ) + eps )
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean()
return loss
def compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae, pae_bin_step=0.5, eps=1e-4):
# predicted aligned error: C-alpha (or sm. mol atom) distances in backbone frames
xij_ca = torch.einsum('rji,rsj->rsi', uX[-1,:,0], X[-1,:,None,1] - X_y[-1,None,:,0,:]) # last bb prediction
xij_ca_t = torch.einsum('rji,rsj->rsi', uY[0,:,0], Y[0,:,None,1] - Y_y[0,None,:,0,:]) # assumes B=1
eij_label = torch.sqrt(torch.square(xij_ca - xij_ca_t).sum(dim=-1)+eps).clone().detach()
nbin = logit_pae.shape[1]
pae_bins = torch.linspace(pae_bin_step, pae_bin_step*(nbin-1), nbin-1, dtype=logit_pae.dtype, device=logit_pae.device)
true_pae_label = torch.bucketize(eij_label, pae_bins, right=True).long()
return torch.nn.CrossEntropyLoss(reduction='mean')(logit_pae, true_pae_label[None]) # assumes B=1
def compute_pde_loss(X, Y, logit_pde, pde_bin_step=0.3):
# predicted distance error: C-alpha (or sm. mol atom) pairwise distances
dX = torch.cdist(X[-1,:,1], X[-1,:,1], compute_mode='donot_use_mm_for_euclid_dist')
dY = torch.cdist(Y[0,:,1], Y[0,:,1], compute_mode='donot_use_mm_for_euclid_dist')
dist_err = torch.abs(dX-dY).clone().detach()
nbin = logit_pde.shape[1]
pde_bins = torch.linspace(pde_bin_step, pde_bin_step*(nbin-1), nbin-1, dtype=logit_pde.dtype, device=logit_pde.device)
true_pde_label = torch.bucketize(dist_err, pde_bins, right=True).long()
return torch.nn.CrossEntropyLoss(reduction='mean')(logit_pde, true_pde_label[None]) # assumes B=1
# from Ivan: FAPE generalized over atom sets & frames
def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=None,
logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, gamma=0.99, eps=1e-4):
# X (predicted) N x L x natoms x 3
# Y (native) 1 x L x natoms x 3
# atom_mask 1 x L x natoms
# frames 1 x L x nframes x 3 x 2
# frame_mask 1 x L x nframes
# frame_atom_mask 1 x L x natoms
if frame_atom_mask is None:
frame_atom_mask = atom_mask
N, L, natoms, _ = X.shape
# flatten middle dims so can gather across residues
X_prime = X.reshape(N, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
Y_prime = Y.reshape(1, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
# reindex frames for flat X
frames_reindex = torch.zeros(frames.shape[:-1], device=frames.device)
for i in range(L):
frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1]
frames_reindex = frames_reindex.long()
frame_mask *= torch.all(
torch.gather(frame_atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3),
axis=-1)
X_x = torch.gather(X_prime, 1, frames_reindex[...,0:1].repeat(N,1,1,3))
X_y = torch.gather(X_prime, 1, frames_reindex[...,1:2].repeat(N,1,1,3))
X_z = torch.gather(X_prime, 1, frames_reindex[...,2:3].repeat(N,1,1,3))
uX,tX = rigid_from_3_points(X_x, X_y, X_z)
Y_x = torch.gather(Y_prime, 1, frames_reindex[...,0:1].repeat(1,1,1,3))
Y_y = torch.gather(Y_prime, 1, frames_reindex[...,1:2].repeat(1,1,1,3))
Y_z = torch.gather(Y_prime, 1, frames_reindex[...,2:3].repeat(1,1,1,3))
uY,tY = rigid_from_3_points(Y_x, Y_y, Y_z)
xij = torch.einsum(
'brji,brsj->brsi',
uX[:,frame_mask[0]], X[:,atom_mask[0]][:,None,...] - X_y[:,frame_mask[0]][:,:,None,...]
)
xij_t = torch.einsum('rji,rsj->rsi', uY[frame_mask], Y[atom_mask][None,...] - Y_y[frame_mask][:,None,...])
diff = torch.sqrt( torch.sum( torch.square(xij-xij_t[None,...]), dim=-1 ) + eps )
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2))
pae_loss = compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae) if logit_pae is not None \
else torch.tensor(0).to(frames.device)
pde_loss = compute_pde_loss(X, Y, logit_pde) if logit_pde is not None \
else torch.tensor(0).to(frames.device)
return loss, pae_loss, pde_loss
def calc_crd_rmsd(pred, true, atom_mask, rmsd_mask=None):
'''
Calculate coordinate RMSD
Input:
- pred: predicted coordinates (B, L, natoms, 3)
- true: true coordinates (B, L, natoms, 3)
- atom_mask: mask for seen coordinates (B, L, natoms)
Output: RMSD after superposition
'''
def rmsd(V, W, eps=1e-6):
L = V.shape[1]
return torch.sqrt(torch.sum((V-W)*(V-W), dim=(1,2)) / L + eps)
def centroid(X):
return X.mean(dim=-2, keepdim=True)
if rmsd_mask == None:
rmsd_mask = atom_mask.clone()
B, L, natoms = pred.shape[:3]
# center to centroid
pred_allatom = pred[atom_mask][None]
true_allatom = true[atom_mask][None]
pred_allatom_origin = pred_allatom - centroid(pred_allatom)
true_allatom_origin = true_allatom - centroid(true_allatom)
# reshape true crds to match the shape to pred crds
# true = true.unsqueeze(0).expand(I,-1,-1,-1,-1)
# pred = pred.view(B, L*natoms, 3)
# true = true.view(I*B, L*natoms, 3)
# Computation of the covariance matrix
C = torch.matmul(pred_allatom_origin.permute(0,2,1), true_allatom_origin)
# Compute optimal rotation matrix using SVD
V, S, W = torch.svd(C)
# get sign to ensure right-handedness
d = torch.ones([B,3,3], device=pred.device)
d[:,:,-1] = torch.sign(torch.det(V)*torch.det(W)).unsqueeze(1)
# Rotation matrix U
U = torch.matmul(d*V, W.permute(0,2,1)) # (IB, 3, 3)
pred_rms = pred[rmsd_mask][None] - centroid(pred_allatom)
true_rms = true[rmsd_mask][None] - centroid(true_allatom)
# Rotate pred
rP = torch.matmul(pred_rms, U) # (IB, L*3, 3)
# get RMS
rms = rmsd(rP, true_rms).reshape(B)
return rms
def angle(a, b, c, eps=1e-6):
'''
Calculate cos/sin angle between ab and cb
a,b,c have shape of (B, L, 3)
'''
B,L = a.shape[:2]
u1 = a-b
u2 = c-b
u1_norm = torch.norm(u1, dim=-1, keepdim=True) + eps
u2_norm = torch.norm(u2, dim=-1, keepdim=True) + eps
# normalize u1 & u2 --> make unit vector
u1 = u1 / u1_norm
u2 = u2 / u2_norm
u1 = u1.reshape(B*L, 3)
u2 = u2.reshape(B*L, 3)
# sin_theta = norm(a cross b)/(norm(a)*norm(b))
# cos_theta = norm(a dot b) / (norm(a)*norm(b))
sin_theta = torch.norm(torch.cross(u1, u2, dim=1), dim=1, keepdim=True).reshape(B, L, 1) # (B,L,1)
cos_theta = torch.matmul(u1[:,None,:], u2[:,:,None]).reshape(B, L, 1)
return torch.cat([cos_theta, sin_theta], axis=-1) # (B, L, 2)
def length(a, b):
return torch.norm(a-b, dim=-1)
def torsion(a,b,c,d, eps=1e-6):
#A function that takes in 4 atom coordinates:
# a - [B,L,3]
# b - [B,L,3]
# c - [B,L,3]
# d - [B,L,3]
# and returns cos and sin of the dihedral angle between those 4 points in order a, b, c, d
# output - [B,L,2]
u1 = b-a
u1 = u1 / (torch.norm(u1, dim=-1, keepdim=True) + eps)
u2 = c-b
u2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
u3 = d-c
u3 = u3 / (torch.norm(u3, dim=-1, keepdim=True) + eps)
#
t1 = torch.cross(u1, u2, dim=-1) #[B, L, 3]
t2 = torch.cross(u2, u3, dim=-1)
t1_norm = torch.norm(t1, dim=-1, keepdim=True)
t2_norm = torch.norm(t2, dim=-1, keepdim=True)
cos_angle = torch.matmul(t1[:,:,None,:], t2[:,:,:,None])[:,:,0]
sin_angle = torch.norm(u2, dim=-1,keepdim=True)*(torch.matmul(u1[:,:,None,:], t2[:,:,:,None])[:,:,0])
cos_sin = torch.cat([cos_angle, sin_angle], axis=-1)/(t1_norm*t2_norm+eps) #[B,L,2]
return cos_sin
# ideal N-C distance, ideal cos(CA-C-N angle), ideal cos(C-N-CA angle)
# for NA, we do not compute this as it is not computable from the stubs alone
def calc_BB_bond_geom(
seq, pred, idx, eps=1e-6,
ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255,
sig_len=0.02, sig_ang=0.05):
'''
Calculate backbone bond geometry (bond length and angle) and put loss on them
Input:
- pred: predicted coords (B, L, :, 3), 0; N / 1; CA / 2; C
- true: True coords (B, L, :, 3)
Output:
- bond length loss, bond angle loss
'''
def cosangle( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999)
B, L = pred.shape[:2]
bonded = (idx[:,1:] - idx[:,:-1])==1
is_prot = ~is_nucleic(seq)[:-1]
# bond length: C-N
blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(B,L-1) # (B, L-1)
CN_loss = torch.clamp( torch.abs(blen_CN_pred - ideal_NC) - sig_len, min=0.0 )
CN_loss = (bonded*is_prot*CN_loss).sum() / ((bonded*is_prot).sum() + eps)
blen_loss = CN_loss #fd squared loss
# bond angle: CA-C-N, C-N-CA
bang_CACN_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1)
bang_CNCA_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1)
CACN_loss = torch.clamp( torch.abs(bang_CACN_pred - ideal_CACN) - sig_ang, min=0.0 )
CACN_loss = (bonded*is_prot*CACN_loss).sum() / ((bonded*is_prot).sum() + eps)
CNCA_loss = torch.clamp( torch.abs(bang_CNCA_pred - ideal_CNCA) - sig_ang, min=0.0 )
CNCA_loss = (bonded*is_prot*CNCA_loss).sum() / ((bonded*is_prot).sum() + eps)
bang_loss = CACN_loss + CNCA_loss
return blen_loss+bang_loss
def calc_atom_bond_loss(pred, true, bond_feats, seq, beta=0.2, eps=1e-6):
"""
loss on distances between bonded atoms
"""
loss_func_sum = torch.nn.SmoothL1Loss(reduction='sum', beta=beta)
loss_func_mean = torch.nn.SmoothL1Loss(reduction='mean', beta=beta)
# intra-ligand bonds
atom_bonds = (bond_feats>0)*(bond_feats < 5)
b, i, j = torch.where(atom_bonds>0)
nat_dist = torch.sum(torch.square(true[:,i,1]-true[:,j,1]),dim=-1)
pred_dist = torch.sum(torch.square(pred[:,i,1]-pred[:,j,1]),dim=-1)
#lig_dist_loss = torch.sum(torch.clamp(torch.square(nat_dist-pred_dist), max=clamp)) # from EquiBind
lig_dist_loss = loss_func_sum(nat_dist, pred_dist)
# bonds between protein residues and ligand atoms (i.e. atomized protein)
inter_bonds = bond_feats==6
_, i, j = torch.where(inter_bonds)
a = (seq[:,i]<22) & (seq[:,j]==39) # res N - atom C: binary indicator
b = (seq[:,i]<22) & (seq[:,j]==55) # res C - atom N
c = (seq[:,i]==39) & (seq[:,j]<22) # atom C - res N
d = (seq[:,i]==55) & (seq[:,j]<22) # atom N - res C
i_atom = 0*a + 2*b + 1*c + 1*d # (B, N_bonds) : indexes of atom that is bonded (N:0, C:2, 1:ligand atom)
j_atom = 1*a + 1*b + 0*c + 2*d # (B, N_bonds)
nat_dist = torch.sum(torch.square(true[0,i,i_atom[0],:]-true[0,j,j_atom[0],:]), dim=-1) # assumes B=1
pred_dist = torch.sum(torch.square(pred[0,i,i_atom[0],:]-pred[0,j,j_atom[0],:]), dim=-1)
#inter_dist_loss = torch.sum(torch.clamp(torch.square(nat_dist-pred_dist), max=clamp))
inter_dist_loss = loss_func_sum(nat_dist, pred_dist)
bond_dist_loss = (lig_dist_loss + inter_dist_loss)/(atom_bonds.sum() + inter_bonds.sum() + eps)
# enforce LAS constraints between atoms 2 bonds away and aromatic groups
atom_bonds_np = atom_bonds[0].cpu().numpy()
G = nx.from_numpy_matrix(atom_bonds_np)
paths = find_all_paths_of_length_n(G,2)
if paths:
paths = torch.tensor(paths, device=pred.device)
nat_dist = torch.sum(torch.square(true[:,paths[:,0],1]-true[:,paths[:,2],1]),dim=-1)
pred_dist = torch.sum(torch.square(pred[:,paths[:,0],1]-pred[:,paths[:,2],1]),dim=-1)
#skip_bond_dist_loss = torch.sum(torch.clamp(torch.square(nat_dist-pred_dist),max=clamp))/(paths.shape[0]+eps)
skip_bond_dist_loss = loss_func_mean(nat_dist, pred_dist)
else:
skip_bond_dist_loss = torch.tensor(0, device=pred.device)
rigid_groups = find_all_rigid_groups(bond_feats)
if rigid_groups != None:
nat_dist = torch.sum(torch.square(true[:,rigid_groups[:,0],1]-true[:,rigid_groups[:,1],1]),dim=-1)
pred_dist = torch.sum(torch.square(pred[:,rigid_groups[:,0],1]-pred[:,rigid_groups[:,1],1]),dim=-1)
#rigid_group_dist_loss = torch.sum(torch.clamp(torch.square(nat_dist-pred_dist),max=clamp))/(rigid_groups.shape[0]+eps)
rigid_group_dist_loss = loss_func_mean(nat_dist, pred_dist)
else:
rigid_group_dist_loss = torch.tensor(0, device=pred.device)
return bond_dist_loss, skip_bond_dist_loss, rigid_group_dist_loss
def calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
# pred: N x L x 27 x 3
# idx: 1 x L
# seq: 1 x L
def gen_ang( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.acos( torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999) )
# quadratic from [-1,1], linear elsewhere
def boundfunc(X):
Y = torch.abs(X)
Y[Y<1.0] = torch.square(Y[Y<1.0])
#Y = torch.square(X)
return Y
N,L = pred.shape[:2]
cb_loss = torch.zeros(N, device=pred.device)
## intra-res
cblens = len_param[seq]
len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
len_all = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
len_mask = cblens[...,0]!=cblens[...,1]
E_cb_len = (
len_mask[None,...] *
cblens[None,...,3] *
boundfunc( length(len_all[...,0,:],len_all[...,1,:]) - cblens[...,2] )
).sum(dim=(0,3)) / len_mask.sum()
# figure out which his are his_d
cblens[seq==8] = len_param[-1]
len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
len_all_a = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
len_mask_a = cblens[...,0]!=cblens[...,1]
E_cb_len_a = (
len_mask_a[None,...] *
cblens[None,...,3] *
boundfunc( length(len_all_a[...,0,:],len_all_a[...,1,:]) - cblens[...,2] )
).sum(dim=(0,3)) / len_mask.sum() # N,L
is_his_d = (seq==8)*(E_cb_len_a<E_cb_len)
cb_loss += torch.min(E_cb_len_a,E_cb_len).sum(dim=1)
cbangs = ang_param[seq].repeat(N,1,1,1)
cbangs[is_his_d] = ang_param[-1]
ang_idx = cbangs[...,:3].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
ang_all = torch.gather(pred, 2, ang_idx).reshape(N,L,-1,3,3)
ang_mask = cbangs[...,0]!=cbangs[...,1]
E_cb_ang = (
ang_mask[None,...] *
cbangs[None,...,4] *
boundfunc( get_ang(ang_all[...,0,:],ang_all[...,1,:],ang_all[...,2,:]) - cbangs[None,...,3] )
).sum(dim=(0,2,3)) / ang_mask.sum()
cb_loss += E_cb_ang
cbtors = tor_param[seq].repeat(N,1,1,1)
cbtors[is_his_d] = tor_param[-1]
tor_idx = cbtors[...,:4].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
tor_all = torch.gather(pred, 2, tor_idx).reshape(N,L,-1,4,3)
tor_mask = cbtors[...,0]!=cbtors[...,1]
offset = 2*np.pi/cbtors[None,...,6]
tor_deltas = (
get_dih(
tor_all[...,0,:],tor_all[...,1,:],tor_all[...,2,:],tor_all[...,3,:]
) - cbtors[None,...,4] + 0.5*offset
) % offset - 0.5*offset
dihs = get_dih(
tor_all[...,0,:],tor_all[...,1,:],tor_all[...,2,:],tor_all[...,3,:]
)
E_cb_tor = (
tor_mask[None,...] *
cbtors[None,...,5] *
boundfunc( tor_deltas )
).sum(dim=(0,2,3)) / tor_mask.sum()
cb_loss += E_cb_tor
# inter-res
# bond length: C-N
bonded = (idx[:,1:] - idx[:,:-1])==1
blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(N,L-1) # (B, L-1)
CN_loss = cb_lengths_CN[1] * boundfunc(blen_CN_pred - cb_lengths_CN[0])
cb_loss += (bonded*CN_loss).sum(dim=1) / (bonded.sum())
# bond angle: CA-C-N, C-N-CA
bang_CACN_pred = get_ang(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(N,L-1)
CACN_loss = cb_angles_CACN[1] * boundfunc(bang_CACN_pred - cb_angles_CACN[0])
cb_loss += (bonded*CACN_loss).sum(dim=1) / (bonded.sum())
bang_CNCA_pred = get_ang(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(N,L-1)
CNCA_loss = cb_angles_CNCA[1] * boundfunc(bang_CNCA_pred - cb_angles_CNCA[0])
cb_loss += (bonded*CNCA_loss).sum(dim=1) / (bonded.sum())
# improper torsions CA-C-N-H (CD-C-N-CA), CA-N-C-O
# planarity around N (H for non-pro, CD for pro)
atom4idx = torch.full_like(seq, 14)
atom4idx[seq==14] = 6 # set to CD for proline
atom4 = torch.gather( pred, 2, atom4idx[:,:,None,None].repeat(1,1,1,3) )
btor_CACNH_delta = (
get_dih(
pred[:,:-1,1], pred[:,:-1,2], pred[:,1:,0], atom4[:,1:,0]
) - cb_torsions_CACNH[0] + np.pi/2
) % np.pi - np.pi/2
CACNH_loss = cb_torsions_CACNH[1] * boundfunc( btor_CACNH_delta )
cb_loss += (bonded*CACNH_loss).sum(dim=1) / (bonded.sum())
# planarity around C
btor_CANCO_delta = (
get_dih(
pred[:,:-1,1], pred[:,1:,0], pred[:,:-1,2], pred[:,:-1,3]
) - cb_torsions_CANCO[0] + np.pi/2
) % np.pi - np.pi/2
CANCO_loss = cb_torsions_CANCO[1] * boundfunc( btor_CANCO_delta )
cb_loss += (bonded*CANCO_loss).sum(dim=1) / (bonded.sum())
return cb_loss
# AF2-like version of clash score
def calc_clash(xs, mask):
DISTCUT=2.0 # (d_lit - tau) from AF2 MS
L = xs.shape[0]
dij = torch.sqrt(
torch.sum( torch.square( xs[:,:,None,None,:]-xs[None,None,:,:,:] ), dim=-1 ) + 1e-8
)
allmask = mask[:,:,None,None]*mask[None,None,:,:]
allmask[torch.arange(L),:,torch.arange(L),:] = False # ignore res-self
allmask[torch.arange(1,L),0,torch.arange(L-1),2] = False # ignore N->C
allmask[torch.arange(L-1),2,torch.arange(1,L),0] = False # ignore N->C
clash = torch.sum( torch.clamp(DISTCUT-dij[allmask],0.0) ) / torch.sum(mask)
return clash
# Rosetta-like version of LJ (fa_atr+fa_rep)
# lj_lin is switch from linear to 12-6. Smaller values more sharply penalize clashes
def calc_lj(
seq, xs, aamask, bond_feats, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
def ljV(dist, sigma, epsilon, lj_lin, lj_maxrad):
N = dist.shape[0]
linpart = dist<lj_lin*sigma[None]
deff = dist.clone()
deff[linpart] = lj_lin*sigma.repeat(N,1)[linpart]
sd = sigma[None] / deff
sd2 = sd*sd
sd6 = sd2 * sd2 * sd2
sd12 = sd6 * sd6
ljE = epsilon * (sd12 - 2 * sd6)
ljE[linpart] += epsilon.repeat(N,1)[linpart] * (
-12 * sd12[linpart]/deff[linpart] + 12 * sd6[linpart]/deff[linpart]
) * (dist[linpart]-deff[linpart])
if (lj_maxrad>0):
sdmax = sigma / lj_maxrad
sd2 = sd*sd
sd6 = sd2 * sd2 * sd2
sd12 = sd6 * sd6
ljE = ljE - epsilon * (sd12 - 2 * sd6)
return ljE
N, L = xs.shape[:2]
# mask keeps running total of what to compute
mask = aamask[seq][...,None,None]*aamask[seq][None,None,...]
idxes1r = torch.tril_indices(L,L,-1)
mask[idxes1r[0],:,idxes1r[1],:] = False
idxes2r = torch.arange(L)
idxes2a = torch.tril_indices(NTOTAL,NTOTAL,0)
mask[idxes2r[:,None],idxes2a[0:1],idxes2r[:,None],idxes2a[1:2]] = False
# "countpair" can be enforced by making this a weight
mask[idxes2r,:,idxes2r,:] *= num_bonds[seq,:,:] >= 4 #intra-res
mask[idxes2r[:-1],:,idxes2r[1:],:] *= (
num_bonds[seq[:-1],:,2:3] + num_bonds[seq[1:],0:1,:] + 1 >= 4 #inter-res
)
atom_bonds = (bond_feats > 0)*(bond_feats<5)
dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds[0].long().detach().cpu().numpy(), directed=False)
dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0), device=mask.device) # protein portion is inf and you don't want to mask it out
mask[:,1,:,1] *= dist_matrix >=4
si,ai,sj,aj = mask.nonzero(as_tuple=True)
ds = torch.sqrt( torch.sum ( torch.square( xs[:,si,ai]-xs[:,sj,aj] ), dim=-1 ) + eps )
# hbond correction
use_hb_dis = (
ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,1]
+ ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0] )
use_ohdon_dis = ( # OH are both donors & acceptors
ljcorr[seq[si],ai,0]*ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0]
+ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,0]*ljcorr[seq[sj],aj,1]
)
use_hb_hdis = (
ljcorr[seq[si],ai,2]*ljcorr[seq[sj],aj,1]
+ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,2]
)
# disulfide correction
potential_disulf = ljcorr[seq[si],ai,3]*ljcorr[seq[sj],aj,3]
ljrs = ljparams[seq[si],ai,0] + ljparams[seq[sj],aj,0]
ljrs[use_hb_dis] = lj_hb_dis
ljrs[use_ohdon_dis] = lj_OHdon_dis
ljrs[use_hb_hdis] = lj_hbond_hdis
ljss = torch.sqrt( ljparams[seq[si],ai,1] * ljparams[seq[sj],aj,1] + eps )
ljss [potential_disulf] = 0.0
ljval = ljV(ds,ljrs,ljss,lj_lin,lj_maxrad)
return (torch.sum( ljval, dim=-1 )/torch.sum(aamask[seq]))
def calc_hb(
seq, xs, aamask, hbtypes, hbbaseatoms, hbpolys,
hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
def evalpoly( ds, xrange, yrange, coeffs ):
v = coeffs[...,0]
for i in range(1,10):
v = v * ds + coeffs[...,i]
minmask = ds<xrange[...,0]
v[minmask] = yrange[minmask][...,0]
maxmask = ds>xrange[...,1]
v[maxmask] = yrange[maxmask][...,1]
return v
def cosangle( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999)
hbts = hbtypes[seq]
hbba = hbbaseatoms[seq]
rh,ah = (hbts[...,0]>=0).nonzero(as_tuple=True)
ra,aa = (hbts[...,1]>=0).nonzero(as_tuple=True)
D_xs = xs[rh,hbba[rh,ah,0]][:,None,:]
H_xs = xs[rh,ah][:,None,:]
A_xs = xs[ra,aa][None,:,:]
B_xs = xs[ra,hbba[ra,aa,0]][None,:,:]
B0_xs = xs[ra,hbba[ra,aa,1]][None,:,:]
hyb = hbts[ra,aa,2]
polys = hbpolys[hbts[rh,ah,0][:,None],hbts[ra,aa,1][None,:]]
AH = torch.sqrt( torch.sum( torch.square( H_xs-A_xs), axis=-1) + eps )
AHD = torch.acos( cosangle( B_xs, A_xs, H_xs) )
Es = polys[...,0,0]*evalpoly(
AH,polys[...,0,1:3],polys[...,0,3:5],polys[...,0,5:])
Es += polys[...,1,0] * evalpoly(
AHD,polys[...,1,1:3],polys[...,1,3:5],polys[...,1,5:])
Bm = 0.5*(B0_xs[:,hyb==HbHybType.RING]+B_xs[:,hyb==HbHybType.RING])
cosBAH = cosangle( Bm, A_xs[:,hyb==HbHybType.RING], H_xs )
Es[:,hyb==HbHybType.RING] += polys[:,hyb==HbHybType.RING,2,0] * evalpoly(
cosBAH,
polys[:,hyb==HbHybType.RING,2,1:3],
polys[:,hyb==HbHybType.RING,2,3:5],
polys[:,hyb==HbHybType.RING,2,5:])
cosBAH1 = cosangle( B_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs )
cosBAH2 = cosangle( B0_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs )
Esp3_1 = polys[:,hyb==HbHybType.SP3,2,0] * evalpoly(
cosBAH1,
polys[:,hyb==HbHybType.SP3,2,1:3],
polys[:,hyb==HbHybType.SP3,2,3:5],
polys[:,hyb==HbHybType.SP3,2,5:])
Esp3_2 = polys[:,hyb==HbHybType.SP3,2,0] * evalpoly(
cosBAH2,
polys[:,hyb==HbHybType.SP3,2,1:3],
polys[:,hyb==HbHybType.SP3,2,3:5],
polys[:,hyb==HbHybType.SP3,2,5:])
Es[:,hyb==HbHybType.SP3] += torch.log(
torch.exp(Esp3_1 * hb_sp3_softmax_fade)
+ torch.exp(Esp3_2 * hb_sp3_softmax_fade)
) / hb_sp3_softmax_fade
cosBAH = cosangle( B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs )
Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * evalpoly(
cosBAH,
polys[:,hyb==HbHybType.SP2,2,1:3],
polys[:,hyb==HbHybType.SP2,2,3:5],
polys[:,hyb==HbHybType.SP2,2,5:])
BAH = torch.acos( cosBAH )
B0BAH = get_dih(B0_xs[:,hyb==HbHybType.SP2], B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs)
d,m,l = hb_sp2_BAH180_rise, hb_sp2_range_span, hb_sp2_outer_width
Echi = torch.full_like( B0BAH, m-0.5 )
mask1 = BAH>np.pi * 2.0 / 3.0
H = 0.5 * (torch.cos(2 * B0BAH) + 1)
F = d / 2 * torch.cos(3 * (np.pi - BAH[mask1])) + d / 2 - 0.5
Echi[mask1] = H[mask1] * F + (1 - H[mask1]) * d - 0.5
mask2 = BAH>np.pi * (2.0 / 3.0 - l)
mask2 *= ~mask1
outer_rise = torch.cos(np.pi - (np.pi * 2 / 3 - BAH[mask2]) / l)
F = m / 2 * outer_rise + m / 2 - 0.5
G = (m - d) / 2 * outer_rise + (m - d) / 2 + d - 0.5
Echi[mask2] = H[mask2] * F + (1 - H[mask2]) * d - 0.5
Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * Echi
tosquish = torch.logical_and(Es > -0.1,Es < 0.1)
Es[tosquish] = -0.025 + 0.5 * Es[tosquish] - 2.5 * torch.square(Es[tosquish])
Es[Es > 0.1] = 0.
if (normalize):
return (torch.sum( Es ) / torch.sum(aamask[seq]))
else:
return torch.sum( Es )
def calc_chiral_loss(pred, chirals):
"""
calculate error in dihedral angles for chiral atoms
Input:
- pred: predicted coords (B, L, :, 3)
- chirals: True coords (B, nchiral, 5), skip if 0 chiral sites, 5 dimension are indices for 4 atoms that make dihedral and the ideal angle they should form
Output:
- mean squared error of chiral angles
"""
if chirals.shape[1] == 0:
return torch.tensor(0.0, device=pred.device)
chiral_dih = pred[:, chirals[..., :-1].long(), 1]
pred_dih = get_dih(chiral_dih[...,0, :], chiral_dih[...,1, :], chiral_dih[...,2, :], chiral_dih[...,3, :]) # n_symm, b, n, 36, 3
l = torch.square(pred_dih-chirals[...,-1]).mean()
return l
@torch.enable_grad()
def calc_BB_bond_geom_grads(seq, pred, idx, eps=1e-6, ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255, sig_len=0.02, sig_ang=0.05):
pred.requires_grad_(True)
Ebond = calc_BB_bond_geom(seq, pred, idx, eps, ideal_NC, ideal_CACN, ideal_CNCA, sig_len, sig_ang)
return torch.autograd.grad(Ebond, pred)
@torch.enable_grad()
def calc_cart_bonded_grads(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
pred.requires_grad_(True)
Ecb = calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps)
return torch.autograd.grad(Ecb, pred)
@torch.enable_grad()
def calc_ljallatom_grads(
seq, xyzaa,
aamask, bond_feats, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
xyzaa.requires_grad_(True)
Elj = calc_lj(
seq[0],
xyzaa[...,:3],
aamask,
bond_feats,
ljparams,
ljcorr,
num_bonds,
lj_lin,
lj_hb_dis,
lj_OHdon_dis,
lj_hbond_hdis,
lj_maxrad,
eps
)
return torch.autograd.grad(Elj, (xyzaa,))
@torch.enable_grad()
def calc_lj_grads(
seq, xyz, alpha, toaa, bond_feats,
aamask, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
xyz.requires_grad_(True)
alpha.requires_grad_(True)
_, xyzaa = toaa(seq, xyz, alpha)
Elj = calc_lj(
seq[0],
xyzaa[...,:3],
aamask,
bond_feats,
ljparams,
ljcorr,
num_bonds,
lj_lin,
lj_hb_dis,
lj_OHdon_dis,
lj_hbond_hdis,
lj_maxrad,
eps
)
return torch.autograd.grad(Elj, (xyz,alpha))
@torch.enable_grad()
def calc_hb_grads(
seq, xyz, alpha, toaa,
aamask, hbtypes, hbbaseatoms, hbpolys,
hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
xyz.requires_grad_(True)
alpha.requires_grad_(True)
_, xyzaa = toaa(seq, xyz, alpha)
Ehb = calc_hb(
seq,
xyzaa[0,...,:3],
aamask,
hbtypes,
hbbaseatoms,
hbpolys,
hb_sp2_range_span,
hb_sp2_BAH180_rise,
hb_sp2_outer_width,
hb_sp3_softmax_fade,
threshold_distance,
eps,
normalize)
return torch.autograd.grad(Ehb, xs)
@torch.enable_grad()
def calc_chiral_grads(xyz, chirals):
xyz.requires_grad_(True)
l = calc_chiral_loss(xyz, chirals)
if l.item() == 0.0:
return (torch.zeros(xyz.shape, device=xyz.device),) # autograd returns a tuple..
return torch.autograd.grad(l, xyz)
def calc_pseudo_dih(pred, true, eps=1e-6):
'''
calculate pseudo CA dihedral angle and put loss on them
Input:
- predicted & true CA coordinates (I,B,L,3) / (B, L, 3)
Output:
- dihedral angle loss
'''
I, B, L = pred.shape[:3]
pred = pred.reshape(I*B, L, -1)
true_dih = torsion(true[:,:-3,:],true[:,1:-2,:],true[:,2:-1,:],true[:,3:,:]) # (B, L', 2)
pred_dih = torsion(pred[:,:-3,:],pred[:,1:-2,:],pred[:,2:-1,:],pred[:,3:,:]) # (I*B, L', 2)
pred_dih = pred_dih.reshape(I, B, -1, 2)
dih_loss = torch.square(pred_dih - true_dih).sum(dim=-1).mean()
dih_loss = torch.sqrt(dih_loss + eps)
return dih_loss
def calc_lddt(pred_ca, true_ca, mask_crds, mask_2d, same_chain, negative=False, interface=False, eps=1e-6):
# Input
# pred_ca: predicted CA coordinates (I, B, L, 3)
# true_ca: true CA coordinates (B, L, 3)
# pred_lddt: predicted lddt values (I-1, B, L)
I, B, L = pred_ca.shape[:3]
pred_dist = torch.cdist(pred_ca, pred_ca) # (I, B, L, L)
true_dist = torch.cdist(true_ca, true_ca).unsqueeze(0) # (1, B, L, L)
mask = torch.logical_and(true_dist > 0.0, true_dist < 15.0) # (1, B, L, L)
# update mask information
mask *= mask_2d[None]
if negative:
mask *= same_chain.bool()[None]
elif interface:
# ignore atoms between the same chain
mask *= ~same_chain.bool()[None]
mask_crds = mask_crds * (mask[0].sum(dim=-1) != 0)
delta = torch.abs(pred_dist-true_dist) # (I, B, L, L)
true_lddt = torch.zeros((I,B,L), device=pred_ca.device)
for distbin in [0.5, 1.0, 2.0, 4.0]:
true_lddt += 0.25*torch.sum((delta<=distbin)*mask, dim=-1) / (torch.sum(mask, dim=-1) + eps)
true_lddt = mask_crds*true_lddt
true_lddt = true_lddt.sum(dim=(1,2)) / (mask_crds.sum() + eps)
return true_lddt
#fd allatom lddt
def calc_allatom_lddt(P, Q, idx, atm_mask, eps=1e-6):
# P - N x L x 27 x 3
# Q - L x 27 x 3
N, L = P.shape[:2]
# distance matrix
Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27)
Pij = torch.sqrt( Pij.sum(dim=-1) + eps)
Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27)
Qij = torch.sqrt( Qij.sum(dim=-1) + eps)
# get valid pairs
pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A
# ignore missing atoms
pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float()
# ignore atoms within same residue
pair_mask *= (idx[:,:,None,None,None] != idx[:,None,:,None,None]).float() # (1, L, L, 27, 27)
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
lddt = torch.zeros( (N,L,27), device=P.device ) # (N, L, 27)
for distbin in (0.5,1.0,2.0,4.0):
lddt += 0.25 * torch.sum( (delta_PQ<=distbin)*pair_mask, dim=(2,4)
) / ( torch.sum( pair_mask, dim=(2,4) ) + 1e-8)
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
return lddt
def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain, negative=False, interface=False, bin_scaling=1, eps=1e-6):
# P - N x L x 27 x 3
# Q - L x 27 x 3
# pred_lddt - 1 x nbucket x L
N, L, Natm = P.shape[:3]
# distance matrix
Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27)
Pij = torch.sqrt( Pij.sum(dim=-1) + eps)
Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27)
Qij = torch.sqrt( Qij.sum(dim=-1) + eps)
# get valid pairs
pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A
# ignore missing atoms
pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float()
# ignore atoms within same residue
pair_mask *= (idx[:,:,None,None,None] != idx[:,None,:,None,None]).float() # (1, L, L, 27, 27)
if negative:
# ignore atoms between different chains
pair_mask *= same_chain.bool()[:,:,:,None,None]
elif interface:
# ignore atoms between the same chain
pair_mask *= ~same_chain.bool()[:,:,:,None,None]
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
lddt = torch.zeros( (N,L,Natm), device=P.device ) # (N, L, 27)
for distbin in (0.5,1.0,2.0,4.0):
lddt += 0.25 * torch.sum( (delta_PQ<=distbin*bin_scaling)*pair_mask, dim=(2,4)
) / ( torch.sum( pair_mask, dim=(2,4) ) + eps)
final_lddt_by_res = torch.clamp(
(lddt[-1]*atm_mask[0]).sum(-1)
/ (atm_mask.sum(-1) + eps), min=0.0, max=1.0)
# calculate lddt prediction loss
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
pred_lddt, true_lddt_label[-1])
res_mask = atm_mask.any(dim=-1)
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
# method 1: average per-residue
#lddt = lddt.sum(dim=-1) / (atm_mask.sum(dim=-1)+1e-8) # L
#lddt = (res_mask*lddt).sum() / (res_mask.sum() + 1e-8)
# method 2: average per-atom
atm_mask = atm_mask * (pair_mask.sum(dim=(1,3)) != 0)
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
return lddt_loss, lddt

View File

@@ -1,57 +0,0 @@
import gc
import torch
## MEM utils ##
def mem_report():
'''Report the memory usage of the tensor.storage in pytorch
Both on CPUs and GPUs are reported'''
def _mem_report(tensors, mem_type):
'''Print the selected tensors of type
There are two major storage types in our major concern:
- GPU: tensors transferred to CUDA devices
- CPU: tensors remaining on the system memory (usually unimportant)
Args:
- tensors: the tensors of specified type
- mem_type: 'CPU' or 'GPU' in current implementation '''
print('Storage on %s' %(mem_type))
print('-'*LEN)
total_numel = 0
total_mem = 0
visited_data = []
for tensor in tensors:
if tensor.is_sparse:
continue
# a data_ptr indicates a memory block allocated
data_ptr = tensor.storage().data_ptr()
if data_ptr in visited_data:
continue
visited_data.append(data_ptr)
numel = tensor.storage().size()
total_numel += numel
element_size = tensor.storage().element_size()
mem = numel*element_size /1024/1024 # 32bit=4Byte, MByte
total_mem += mem
element_type = type(tensor).__name__
size = tuple(tensor.size())
print('%s\t\t%s\t\t%.2f' % (
element_type,
size,
mem) )
print('-'*LEN)
print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem) )
print('-'*LEN)
LEN = 65
print('='*LEN)
objects = gc.get_objects()
print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') )
tensors = [obj for obj in objects if torch.is_tensor(obj)]
cuda_tensors = [t for t in tensors if t.is_cuda]
host_tensors = [t for t in tensors if not t.is_cuda]
_mem_report(cuda_tensors, 'GPU')
_mem_report(host_tensors, 'CPU')
print('='*LEN)

View File

@@ -1,46 +0,0 @@
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 4,
"d_msa" : 256,
"d_pair" : 192,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 6,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.0,
"lj_lin" : 0.7
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
SE3_ref_param = {
"num_layers" : 2,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['SE3_ref_param'] = SE3_ref_param
MODEL_PARAM['use_extra_l1'] = True
MODEL_PARAM['use_atom_frames'] = True

View File

@@ -1,490 +0,0 @@
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import util
import gzip
from ffindex import *
import torch
from chemical import NAATOKENS, aa2num, aa2long, atomnum2atomtype, NTOTAL, CHAIN_GAP
from openbabel import openbabel
to1letter = {
"ALA":'A', "ARG":'R', "ASN":'N', "ASP":'D', "CYS":'C',
"GLN":'Q', "GLU":'E', "GLY":'G', "HIS":'H', "ILE":'I',
"LEU":'L', "LYS":'K', "MET":'M', "PHE":'F', "PRO":'P',
"SER":'S', "THR":'T', "TRP":'W', "TYR":'Y', "VAL":'V',
"DA":'a', "DC":'c', "DG":'g', "DT":'t',
"A":'b', "C":'d', "G":'h', "U":'u',
}
def read_template_pdb(L, pdb_fn, target_chain=None):
# get full sequence from given PDB
seq_full = list()
prev_chain=''
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
if line[12:16].strip() != "CA":
continue
if line[21] != prev_chain:
if len(seq_full) > 0:
L_s.append(len(seq_full)-offset)
offset = len(seq_full)
prev_chain = line[21]
aa = line[17:20]
seq_full.append(aa2num[aa] if aa in aa2num.keys() else 20)
seq_full = torch.tensor(seq_full).long()
xyz = torch.full((L, 36, 3), np.nan).float()
seq = torch.full((L,), 20).long()
conf = torch.zeros(L,1).float()
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
aa_idx = aa2num[aa] if aa in aa2num.keys() else 20
#
idx = resNo - 1
for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
if tgtatm == atom:
xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
break
seq[idx] = aa_idx
mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
mask = mask.all(dim=-1)[:,None]
conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
t1d = torch.cat((seq_1hot, conf), -1)
#return seq_full[None], ins[None], L_s, xyz[None], t1d[None]
return xyz[None], t1d[None]
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
msa = []
ins = []
fstream = open(filename,"r")
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line)
# sequence length
L = len(msa[-1])
i = np.zeros((L))
ins.append(i)
# convert letters into numbers
if rmsa_alphabet:
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
ins = np.array(ins, dtype=np.uint8)
return msa,ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
if (exists(filename)):
return parse_fasta(filename, maxseq, rmsa_alphabet)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask
seq = np.array([list(seq)], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
seq[seq == alphabet[i]] = i
return (seq, np.zeros_like(seq))
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, unzip=True, maxseq=10000):
msa = []
ins = []
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
# read file line by line
if (unzip):
fstream = gzip.open(filename,"rt")
else:
fstream = open(filename,"r")
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line.translate(table))
# sequence length
L = len(msa[-1])
# remove insertion at the end
if (not unzip):
n_remove = 0
for c in reversed(line):
if c.islower():
n_remove += 1
else:
break
line = line[:-n_remove]
# 0 - match or gap; 1 - insertion
a = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
i = np.zeros((L))
if np.sum(a) > 0:
# positions of insertions
pos = np.where(a==1)[0]
# shift by occurrence
a = pos - np.arange(pos.shape[0])
# position of insertions in cleaned sequence
# and their length
pos,num = np.unique(a, return_counts=True)
# append to the matrix of insetions
i[pos] = num
ins.append(i)
if (len(msa) >= maxseq):
break
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
# treat all unknown characters as gaps
msa[msa > 20] = 20
ins = np.array(ins, dtype=np.uint8)
return msa,ins
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False):
lines = open(filename,'r').readlines()
if seq:
return parse_pdb_lines_w_seq(lines, parse_hetatom=parse_hetatom)
return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines):
# indices of residues observed in the structure
#idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
res = [(l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
idx_s = [int(r[0]) for r in res]
seq = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), NTOTAL, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
idx = idx_s.index(resNo)
for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# parse ligand atoms
offset = max(idx_s)
res_lig = [l[12:16].strip() for l in lines if l[:6]=="HETATM"]
res_lig = [(i+offset+CHAIN_GAP,l) for i,l in enumerate(res_lig)]
idx_s_lig = [int(r[0]) for r in res_lig]
seq_lig = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res_lig]
xyz_s_lig = [[float(l[30:38]), float(l[38:46]), float(l[46:54])] for l in lines if l[:6]=='HETATM']
if len(xyz_s_lig)>0:
xyz_lig = np.full((len(idx_s_lig), NTOTAL, 3), np.nan, dtype=np.float32)
xyz_lig[:,1,:] = np.array(xyz_s_lig)
xyz = np.concatenate([xyz, xyz_lig],axis=0)
idx_s = idx_s + idx_s_lig
seq = seq + seq_lig
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
return xyz,mask,np.array(idx_s), np.array(seq)
#'''
def parse_pdb_lines(lines):
# indices of residues observed in the structure
idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), NTOTAL, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
idx = idx_s.index(resNo)
for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
return xyz,mask,np.array(idx_s)
def parse_templates(item, params):
# init FFindexDB of templates
### and extract template IDs
### present in the DB
ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
read_data(params['FFDB']+'_pdb.ffdata'))
#ffids = set([i.name for i in ffdb.index])
# process tabulated hhsearch output to get
# matched positions and positional scores
infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
hits = []
for l in open(infile, "r").readlines():
if l[0]=='>':
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(infile[:-4]+'hhr', "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
mask = np.vstack(mask).astype(np.bool)
qmap = np.vstack(qmap).astype(np.long)
f0d = np.vstack(f0d).astype(np.float32)
f1d = np.vstack(f1d).astype(np.float32)
ids = ids
return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn):
# process tabulated hhsearch output to get
# matched positions and positional scores
hits = []
for l in open(atab_fn, "r").readlines():
if l[0]=='>':
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(hhr_fn, "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines_w_seq(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
seq.append(data[-1][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
qmap = np.vstack(qmap).astype(np.long)
f1d = np.vstack(f1d).astype(np.float32)
seq = np.hstack(seq).astype(np.long)
ids = ids
return torch.from_numpy(xyz), torch.from_numpy(qmap), \
torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
xyz_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn)
npick = min(n_templ, len(ids))
if npick < 1: # no templates
xyz = torch.full((1,qlen,27,3),np.nan).float()
t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
return xyz, t1d
sample = torch.arange(npick)
#
xyz = torch.full((npick, qlen, 27, 3), np.nan).float()
f1d = torch.full((npick, qlen), 20).long()
f1d_val = torch.zeros((npick, qlen, 1)).float()
#
for i, nt in enumerate(sample):
sel = torch.where(qmap[:,1] == nt)[0]
pos = qmap[sel, 0]
xyz[i, pos] = xyz_t[sel]
f1d[i, pos] = seq[sel]
f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
f1d = torch.nn.functional.one_hot(f1d, num_classes=21).float()
f1d = torch.cat((f1d, f1d_val), dim=-1)
return xyz, f1d
def parse_mol(filename, filetype="mol2", string=False):
obConversion = openbabel.OBConversion()
obConversion.SetInFormat(filetype)
obmol = openbabel.OBMol()
if string:
obConversion.ReadString(obmol,filename)
else:
obConversion.ReadFile(obmol,filename)
obmol.DeleteHydrogens()
# the above sometimes fails to get all the hydrogens
i = 1
while i < obmol.NumAtoms()+1:
if obmol.GetAtom(i).GetAtomicNum()==1:
obmol.DeleteAtom(obmol.GetAtom(i))
else:
i += 1
atomtypes = [atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM') for i in range(1, obmol.NumAtoms()+1)]
msa = torch.tensor([aa2num[x] for x in atomtypes])
ins = torch.zeros_like(msa)
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
mask = torch.full(atom_coords.shape[:-1], True) # (1, natoms)
try:
automorphs = openbabel.vvpairUIntUInt()
openbabel.FindAutomorphisms(obmol,automorphs)
automorphs = torch.tensor(automorphs)
n_symmetry = automorphs.shape[0]
atom_coords = atom_coords.repeat(n_symmetry,1,1)
mask = mask.repeat(n_symmetry,1)
atom_coords = torch.scatter(atom_coords, 1, automorphs[:,:,0:1].repeat(1,1,3),
torch.gather(atom_coords,1,automorphs[:,:,1:2].repeat(1,1,3)))
mask = torch.scatter(mask, 1, automorphs[:,:,0],
torch.gather(mask, 1, automorphs[:,:,1]))
except Exception as e:
print(f"ERROR: automorphs for {filename} yielded invalid tensor")
return obmol, msa, ins, atom_coords, mask

View File

@@ -1,72 +0,0 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
# pre-activation bottleneck resblock
class ResBlock2D_bottleneck(nn.Module):
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
super(ResBlock2D_bottleneck, self).__init__()
padding = self._get_same_padding(kernel, dilation)
n_b = n_c // 2 # bottleneck channel
layer_s = list()
# pre-activation
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# project down to n_b
layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# convolution
layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# dropout
layer_s.append(nn.Dropout(p_drop))
# project up
layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
# make final layer initialize with zeros
#nn.init.zeros_(layer_s[-1].weight)
self.layer = nn.Sequential(*layer_s)
self.reset_parameter()
def reset_parameter(self):
# zero-initialize final layer right before residual connection
nn.init.zeros_(self.layer[-1].weight)
def _get_same_padding(self, kernel, dilation):
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
def forward(self, x):
out = self.layer(x)
return x + out
class ResidualNetwork(nn.Module):
def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
dilation=[1,2,4,8], p_drop=0.15):
super(ResidualNetwork, self).__init__()
layer_s = list()
# project to n_feat_block
if n_feat_in != n_feat_block:
layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
# add resblocks
for i_block in range(n_block):
d = dilation[i_block%len(dilation)]
res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
layer_s.append(res_block)
if n_feat_out != n_feat_block:
# project to n_feat_out
layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
self.layer = nn.Sequential(*layer_s)
def forward(self, x):
return self.layer(x)

View File

@@ -1,371 +0,0 @@
import json, os
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
##
## lk and lk term
#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME)
type2ljlk = {
"CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000),
"COO":(1.916661,0.141799,-3.332648,3.5000,14.653000),
"CH0":(2.011760,0.062642,1.409284,3.5000,8.998000),
"CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000),
"CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000),
"CH3":(2.011760,0.062642,7.292929,3.5000,25.855000),
"aroC":(2.016441,0.068775,1.797950,3.5000,16.704000),
"Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100),
"Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700),
"NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200),
"NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000),
"Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000),
"Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000),
"Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100),
"OH":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000),
"OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600),
"S":(1.975967,0.455970,-1.707229,3.5000,17.640000),
"SH1":(1.975967,0.455970,3.291643,3.5000,23.240000),
"Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000),
"CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000),
"CObb":(1.916661,0.141799,3.104248,3.5000,13.221000),
"OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000),
"Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000),
"Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000),
"HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000),
"Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000),
"Haro":(1.374914,0.015909,0.0000,3.5000,0.0000),
"Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000),
"HS":(0.363887,0.050836,0.0000,3.5000,0.0000),
"genAl":(1,0.1, 0.0000, 0.0000, 0.0000),
"genAs":(1, 0.1, 0.0000, 0.0000, 0.0000),
"genAu":(1, 0.1, 0.0000, 0.0000, 0.0000),
"genB": (1,0.1, 0.0000, 0.0000, 0.0000),
"genBe": (1,0.1, 0.0000, 0.0000, 0.0000),
"genBr": (2.1971, 0.1090, 2.7951, 3.5000, 19.6876),
"genC": (2.0067, 0.0689, 2.2256, 3.5000, 10.6860), # params from CT
"genCa": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCl": (2.0496, 0.1070, 2.3668, 3.5000, 17.5849),
"genCo": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCr": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCu": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genF": (1.6941, 0.0750, 1.6442, 3.5000, 12.2163),
"genFe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genHg": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genI": (2.3600, 0.1110, 3.1361, 3.5000, 22.0891),
"genIr": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genK": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genLi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMg": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMo": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genN": (1.7854, 0.1497, -6.3760, 3.5000, 9.5221), # params from NG2
"genNi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genO": (1.5492, 0.1576, -3.5363, 3.5000, 10.7220), # params for OG3
"genOs": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genP": (2.1290, 0.5838, -9.6272, 3.5000, 34.8000), # params for PG5
"genPb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genPd": (1, 0.1, 0.0000, 0.0000,0.0000),
"genPr": (1, 0.1, 0.0000, 0.0000,0.0000),
"genPt": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRe": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRh": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRu": (1, 0.1, 0.0000, 0.0000,0.0000),
"genS": (1.9893, 0.3634, -2.3560, 3.5000, 17.6400), # params for SG3
"genSb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genTb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genTe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genU": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genW": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genV": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genY": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genZn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genATM": (1, 0.0, 0.0000, 0.0000, 0.0000), # masked
}
# cartbonded
with open(script_dir+'cartbonded.json', 'r') as j:
cartbonded_data_raw = json.loads(j.read())
# hbond donor/acceptors
class HbAtom:
NO = 0
DO = 1 # donor
AC = 2 # acceptor
DA = 3 # donor & acceptor
HP = 4 # polar H
type2hb = {
"CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO,
"CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO,
"Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO,
"Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA,
"ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO,
"Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC,
"HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP,
"HS":HbAtom.HP, # HP in rosetta(?)
"Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC,
"genAl":HbAtom.NO, "genAs":HbAtom.NO, "genAu":HbAtom.NO, "genB": HbAtom.NO,
"genBe": HbAtom.NO, "genBr": HbAtom.NO, "genC": HbAtom.NO, "genCa": HbAtom.NO,
"genCl": HbAtom.NO, "genCo": HbAtom.NO, "genCr": HbAtom.NO, "genCu": HbAtom.NO,
"genF": HbAtom.DA, "genFe": HbAtom.NO, "genHg": HbAtom.NO, "genI": HbAtom.NO,
"genIr": HbAtom.NO, "genK": HbAtom.NO, "genLi": HbAtom.NO, "genMg": HbAtom.NO,
"genMn": HbAtom.NO, "genMo": HbAtom.NO, "genN": HbAtom.DA, "genNi": HbAtom.NO,
"genO": HbAtom.DA, "genOs": HbAtom.NO, "genP": HbAtom.NO, "genPb": HbAtom.NO,
"genPd": HbAtom.NO, "genPr": HbAtom.NO, "genPt": HbAtom.NO, "genRe": HbAtom.NO,
"genRh": HbAtom.NO, "genRu": HbAtom.NO, "genS": HbAtom.DA, "genSb": HbAtom.NO,
"genSe": HbAtom.NO, "genSi": HbAtom.NO,"genSn": HbAtom.NO,"genTb": HbAtom.NO,
"genTe": HbAtom.NO, "genU": HbAtom.NO, "genW": HbAtom.NO, "genV": HbAtom.NO,
"genY": HbAtom.NO, "genZn": HbAtom.NO, "genATM": HbAtom.NO, # masked
}
##
## hbond term
## TO DO: ADD DNA
class HbDonType:
PBA = 0
IND = 1
IME = 2
GDE = 3
CXA = 4
AMO = 5
HXL = 6
AHX = 7
NTYPES = 8
class HbAccType:
PBA = 0
CXA = 1
CXL = 2
HXL = 3
AHX = 4
IME = 5
NTYPES = 6
class HbHybType:
SP2 = 0
SP3 = 1
RING = 2
NTYPES = 3
type2dontype = {
"Nbb": HbDonType.PBA,
"Ntrp": HbDonType.IND,
"NtrR": HbDonType.GDE,
"Narg": HbDonType.GDE,
"NH2O": HbDonType.CXA,
"Nlys": HbDonType.AMO,
"OH": HbDonType.HXL,
"OHY": HbDonType.AHX,
}
type2acctype = {
"OCbb": HbAccType.PBA,
"ONH2": HbAccType.CXA,
"OOC": HbAccType.CXL,
"OH": HbAccType.HXL,
"OHY": HbAccType.AHX,
"Nhis": HbAccType.IME,
}
type2hybtype = {
"OCbb": HbHybType.SP2,
"ONH2": HbHybType.SP2,
"OOC": HbHybType.SP2,
"OHY": HbHybType.SP3,
"OH": HbHybType.SP3,
"Nhis": HbHybType.RING,
}
dontype2wt = {
HbDonType.PBA: 1.45,
HbDonType.IND: 1.15,
HbDonType.IME: 1.42,
HbDonType.GDE: 1.11,
HbDonType.CXA: 1.29,
HbDonType.AMO: 1.17,
HbDonType.HXL: 0.99,
HbDonType.AHX: 1.00,
}
acctype2wt = {
HbAccType.PBA: 1.19,
HbAccType.CXA: 1.21,
HbAccType.CXL: 1.10,
HbAccType.HXL: 1.15,
HbAccType.AHX: 1.15,
HbAccType.IME: 1.17,
}
class HbPolyType:
ahdist_aASN_dARG = 0
ahdist_aASN_dASN = 1
ahdist_aASN_dGLY = 2
ahdist_aASN_dHIS = 3
ahdist_aASN_dLYS = 4
ahdist_aASN_dSER = 5
ahdist_aASN_dTRP = 6
ahdist_aASN_dTYR = 7
ahdist_aASP_dARG = 8
ahdist_aASP_dASN = 9
ahdist_aASP_dGLY = 10
ahdist_aASP_dHIS = 11
ahdist_aASP_dLYS = 12
ahdist_aASP_dSER = 13
ahdist_aASP_dTRP = 14
ahdist_aASP_dTYR = 15
ahdist_aGLY_dARG = 16
ahdist_aGLY_dASN = 17
ahdist_aGLY_dGLY = 18
ahdist_aGLY_dHIS = 19
ahdist_aGLY_dLYS = 20
ahdist_aGLY_dSER = 21
ahdist_aGLY_dTRP = 22
ahdist_aGLY_dTYR = 23
ahdist_aHIS_dARG = 24
ahdist_aHIS_dASN = 25
ahdist_aHIS_dGLY = 26
ahdist_aHIS_dHIS = 27
ahdist_aHIS_dLYS = 28
ahdist_aHIS_dSER = 29
ahdist_aHIS_dTRP = 30
ahdist_aHIS_dTYR = 31
ahdist_aSER_dARG = 32
ahdist_aSER_dASN = 33
ahdist_aSER_dGLY = 34
ahdist_aSER_dHIS = 35
ahdist_aSER_dLYS = 36
ahdist_aSER_dSER = 37
ahdist_aSER_dTRP = 38
ahdist_aSER_dTYR = 39
ahdist_aTYR_dARG = 40
ahdist_aTYR_dASN = 41
ahdist_aTYR_dGLY = 42
ahdist_aTYR_dHIS = 43
ahdist_aTYR_dLYS = 44
ahdist_aTYR_dSER = 45
ahdist_aTYR_dTRP = 46
ahdist_aTYR_dTYR = 47
cosBAH_off = 48
cosBAH_7 = 49
cosBAH_6i = 50
AHD_1h = 51
AHD_1i = 52
AHD_1j = 53
AHD_1k = 54
# map donor:acceptor pairs to polynomials
hbtypepair2poly = {
(HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
}
# polynomials are triplets, (x_min, x_max), (y[x<x_min],y[x>x_max]), (c_9,...,c_0)
hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886
HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)),
HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)),
HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)),
HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)),
HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)),
HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)),
HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)),
HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)),
HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)),
HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)),
HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)),
HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)),
HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)),
HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)),
HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)),
HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)),
HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)),
HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)),
HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)),
HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)),
HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)),
HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)),
HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)),
HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)),
HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)),
HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)),
HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)),
HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)),
HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)),
HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)),
HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)),
HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)),
HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)),
HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)),
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)),
HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)),
HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)),
HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)),
HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)),
HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)),
HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)),
HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)),
HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)),
HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)),
HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)),
HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)),
HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)),
HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)),
HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)),
HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)),
HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)),
HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)),
HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)),
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,589 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import copy
import dgl
import networkx as nx
from util import base_indices, RTs_by_torsion, xyzs_in_base_frame, \
rigid_from_3_points, is_nucleic, is_atom
def init_lecun_normal(module, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
return module
def init_lecun_normal_param(weight, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
return weight
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
def custom_forward(*inputs):
return module(*inputs, **kwargs)
return custom_forward
def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Dropout(nn.Module):
# Dropout entire row or column
def __init__(self, broadcast_dim=None, p_drop=0.15):
super(Dropout, self).__init__()
# give ones with probability of 1-p_drop / zeros with p_drop
self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop]))
self.broadcast_dim=broadcast_dim
self.p_drop=p_drop
def forward(self, x):
if not self.training: # no drophead during evaluation mode
return x
shape = list(x.shape)
if not self.broadcast_dim == None:
shape[self.broadcast_dim] = 1
mask = self.sampler.sample(shape).to(x.device).view(shape)
x = mask * x / (1.0 - self.p_drop)
return x
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
# Distance radial basis function
D_max = D_min + (D_count-1) * D_sigma
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu[None,:]
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF
def get_seqsep(idx):
'''
Sequence separation feature for structure module. Protein-only.
Input:
- idx: residue indices of given sequence (B,L)
Output:
- seqsep: sequence separation feature with sign (B, L, L, 1)
Sergey found that having sign in seqsep features helps a little
'''
seqsep = idx[:,None,:] - idx[:,:,None]
sign = torch.sign(seqsep)
neigh = torch.abs(seqsep)
neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
neigh = sign * neigh
return neigh.unsqueeze(-1)
def get_seqsep_protein_sm(idx, bond_feats, sm_mask):
'''
Sequence separation features for protein-SM complex
Input:
- idx: residue indices of given sequence (B,L)
- bond_feats: bond features (B, L, L)
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
Output:
- seqsep: sequence separation feature with sign (B, L, L, 1)
-1 or 1 for bonded protein residues
1 for bonded SM atoms or residue-atom bonds
0 elsewhere
'''
sm_mask = sm_mask[0] # assume batch = 1
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, sm_mask)
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
res_dist[(res_dist > 1) | (res_dist < -1)] = 0.0
atom_dist[(atom_dist > 1)] = 0.0
seqsep = sm_mask_2d*atom_dist + prot_mask_2d*res_dist + inter_mask_2d*(bond_feats==6)
return seqsep.unsqueeze(-1)
def get_res_atom_dist(idx, bond_feats, sm_mask, minpos_res=-32, maxpos_res=32, maxpos_atom=8):
'''
Calculates residue and atom bond distances of protein/SM complex. Used for positional
embedding and structure module. 2nd version (2022-9-19); handles atomized proteins.
Input:
- idx: residue index (B, L)
- bond_feats: bond features (B, L, L)
- sm_mask: boolean feature (L). True if a position represents atom, False otherwise
- minpos_res: minimum value of residue distances
- maxpos_res: maximum value of residue distances
- maxpos_atom: maximum value of atom bond distances
Output:
- res_dist: residue distance (B, L, L)
- atom_dist: atom bond distance (B, L, L)
'''
bond_feats = bond_feats[0] # assume batch = 1
L = bond_feats.shape[0]
gpu = bond_feats.device
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
# protein residue distances
res_dist_prot = torch.clamp(idx[0,None,:] - idx[0,:,None],
min=minpos_res, max=maxpos_res).to(gpu) # (L, L) intra-protein
res_dist_sm = torch.full((L,L), maxpos_res+1).to(gpu) # (L, L) with "unknown" res. dist. token
# small molecule atom bond graph
sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask_2d*bond_feats
G = nx.from_numpy_matrix(sm_bond_feats.detach().cpu().numpy())
paths = dict(nx.all_pairs_shortest_path_length(G,cutoff=maxpos_atom))
paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
i,j,v = torch.tensor(paths).T
# small molecule atom bond distances
atom_dist_sm = torch.full((L,L), maxpos_atom).to(gpu) - maxpos_atom*torch.eye(L).to(gpu).long()
atom_dist_sm[i,j] = v.to(gpu)
atom_dist_prot = torch.full((L,L), maxpos_atom+1).to(gpu)
# s.m.-protein bonds
sm_idx = torch.where(sm_mask)[0]
prot_idx = torch.where(~sm_mask)[0]
i_s, j_s = torch.where(bond_feats==6)
i_prot = [j for i,j in zip(i_s,j_s) if i in sm_idx] # protein residues bonded to s.m. atoms
i_sm = [i for i in i_s if i in sm_idx] # s.m. atoms bonded to protein residues
# inter-protein-s.m. residue & atom distances
# atoms inherit residue distances from their nearest bonded residue
res_dist_inter = torch.full((L,L), maxpos_res).to(gpu)
if len(i_prot) > 0: # prot & s.m. are connected
for i in sm_idx:
i_closest_res = i_prot[torch.argmin(atom_dist_sm[i,i_sm])]
res_dist_inter[i,:] = res_dist_prot[i_closest_res,:]
res_dist_inter[:,i] = res_dist_prot[:,i_closest_res]
# residues inherit atom distances from their nearest bonded atom (+ 1 to count "boundary" res-atom bond)
atom_dist_inter = torch.full((L, L), maxpos_atom).to(gpu)
if len(i_prot) > 0: # prot & s.m. are connected
for i in prot_idx:
i_closest_atom = i_sm[torch.argmin(torch.abs(res_dist_prot[i,i_prot]))]
atom_dist_inter[i,:] = atom_dist_sm[i_closest_atom,:] + 1
atom_dist_inter[:,i] = atom_dist_sm[:,i_closest_atom] + 1
atom_dist_inter = torch.minimum(atom_dist_inter, torch.tensor(maxpos_atom))
res_dist = res_dist_prot * prot_mask_2d + res_dist_inter * inter_mask_2d + res_dist_sm * sm_mask_2d
atom_dist = atom_dist_prot * prot_mask_2d + atom_dist_inter * inter_mask_2d + atom_dist_sm * sm_mask_2d
return res_dist[None].to(gpu), atom_dist[None].to(gpu) # add batch dim.
def get_relpos(idx, bond_feats, sm_mask, inter_pos=32, maxpath=32):
'''
Relative position matrix of protein/SM complex. Used for positional
embedding and structure module. Simple version from 9/2/2022 that doesn't
handle atomized proteins.
Input:
- idx: residue index (B, L)
- bond_feats: bond features (B, L, L)
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
- inter_pos: value to assign as the protein-SM residue index differences
- maxpath: bond distances greater than this are clipped to this value
Output:
- relpos: relative position feature (B, L, L)
for intra-protein this is the residue index difference
for intra-SM this is the bond distance
for protein-SM this is user-defined value inter_pos
'''
bond_feats = bond_feats[0]
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
# intra-protein: residue # differences
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
# intra-small molecule: bond distances
sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask*bond_feats
G = nx.from_numpy_matrix(sm_bond_feats.detach().cpu().numpy())
paths = dict(nx.all_pairs_shortest_path_length(G,cutoff=maxpath))
paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
i,j,v = torch.tensor(paths).T
bond_separation = torch.full_like(bond_feats, maxpath) \
- maxpath*torch.eye(bond_feats.shape[0]).to(bond_feats.device).long()
bond_separation[i,j] = v.to(bond_feats.device)
# combine: protein-s.m. are always positive maximum distance apart
# assumes one small molecule per example
relpos = prot_mask_2d * seqsep + sm_mask_2d * bond_separation + inter_mask_2d * inter_pos # (B, L, L)
relpos = relpos.to(bond_feats.device)
return relpos
def make_full_graph(xyz, pair, idx):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
b,i,j = torch.where(sep.abs() > 0)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# distance map from current CA coordinates
D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*9999.9 # (B, L, L)
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*9999.9
if (topk_incl_local):
D = D + sep*eps
D[sep<nlocal] = 0.0
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
cond = topk_matrix > 0.0
else:
D = D + sep*eps
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
# put an edge if any of the 3 conditions are met:
# 1) |i-j| <= kmin (connect sequentially adjacent residues)
# 2) top_k neighbors
cond = torch.logical_or(topk_matrix > 0.0, sep < nlocal)
b,i,j = torch.where(cond)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
B,L,A = xyz.shape[:3]
device = xyz.device
D = torch.norm(
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
)
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
D[~mask2d] = 9999.
D[D==0] = 9999.
# select top K neighbors for each atom
# keep indices as batch/res/atm indices
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
bi,ri,ai = mask.nonzero(as_tuple=True)
bi = bi[:,None].repeat(1,top_k).reshape(-1)
ri = ri[:,None].repeat(1,top_k).reshape(-1)
ai = ai[:,None].repeat(1,top_k).reshape(-1)
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
edge = torch.full(ri.shape, maxbonds, device=device)
resmask = ri==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
resmask = ri+1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
resmask = ri-1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
edge = edge.clamp(0,maxbonds-1)
edge = F.one_hot(edge)[...,None]
natm = torch.sum(mask)
index = torch.zeros_like(mask, dtype=torch.long, device=device)
index[mask] = torch.arange(natm, device=device)
src=index[bi,ri,ai]
tgt=index[bi,rj,aj]
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
return G, edge
# rotate about the x axis
def make_rotX(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,1,1] = angs[:,:,0]/NORM
RTs[:,:,1,2] = -angs[:,:,1]/NORM
RTs[:,:,2,1] = angs[:,:,1]/NORM
RTs[:,:,2,2] = angs[:,:,0]/NORM
return RTs
# rotate about the x axis
def make_rotZ(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,0,0] = angs[:,:,0]/NORM
RTs[:,:,0,1] = -angs[:,:,1]/NORM
RTs[:,:,1,0] = angs[:,:,1]/NORM
RTs[:,:,1,1] = angs[:,:,0]/NORM
return RTs
# rotate about an arbitrary axis
def make_rot_axis(angs, u, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
ct = angs[:,:,0]/NORM
st = angs[:,:,1]/NORM
u0 = u[:,:,0]
u1 = u[:,:,1]
u2 = u[:,:,2]
RTs[:,:,0,0] = ct+u0*u0*(1-ct)
RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st
RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st
RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st
RTs[:,:,1,1] = ct+u1*u1*(1-ct)
RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st
RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st
RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st
RTs[:,:,2,2] = ct+u2*u2*(1-ct)
return RTs
# compute allatom structure from backbone frames and torsions
#
# alphas:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# cb/cg bend: 7-9
# eps(p)/zeta(p): 10-11
# alpha/beta/gamma/delta: 12-15
# nu2/nu1/nu0: 16-18
# chi_1(na): 19
#
# RTs_in_base_frame:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# eps(p)/zeta(p): 7-8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
# RT frames (output):
# origin: 0
# omega/phi/psi: 1-3
# chi_1-4(prot): 4-7
# cb bend: 8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
class ComputeAllAtomCoords(nn.Module):
def __init__(self):
super(ComputeAllAtomCoords, self).__init__()
self.base_indices = nn.Parameter(base_indices, requires_grad=False)
self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False)
self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)
def forward(self, seq, xyz, alphas):
B,L = xyz.shape[:2]
is_NA = is_nucleic(seq)
Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], is_NA)
RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device)
# bb
RTF0[:,:,:3,:3] = Rs
RTF0[:,:,:3,3] = Ts
# omega
RTF1 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:]))
# phi
RTF2 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:]))
# psi
RTF3 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:]))
# CB bend
basexyzs = self.xyzs_in_base_frame[seq]
NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3])
CAr = (basexyzs[:,:,1,:3])
CBr = (basexyzs[:,:,4,:3])
CBrotaxis1 = (CBr-CAr).cross(NCr-CAr)
CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8
# CB twist
NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3]
NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr
CBrotaxis2 = (CBr-CAr).cross(NCpp)
CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8
CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 )
CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 )
RTF8 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, CBrot1,CBrot2)
# chi1 + CG bend
RTF4 = torch.einsum(
'brij,brjk,brkl,brlm->brim',
RTF8,
self.RTs_in_base_frame[seq,3,:],
make_rotX(alphas[:,:,3,:]),
make_rotZ(alphas[:,:,9,:]))
# chi2
RTF5 = torch.einsum(
'brij,brjk,brkl->bril',
RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:]))
# chi3
RTF6 = torch.einsum(
'brij,brjk,brkl->bril',
RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:]))
# chi4
RTF7 = torch.einsum(
'brij,brjk,brkl->bril',
RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:]))
# ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]
# NA alpha
RTF9 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
# NA beta
RTF10 = torch.einsum(
'brij,brjk,brkl->bril',
RTF9, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
# NA gamma
RTF11 = torch.einsum(
'brij,brjk,brkl->bril',
RTF10, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
# NA delta
RTF12 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
# NA nu2 - from gamma frame
RTF13 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
# NA nu1
RTF14 = torch.einsum(
'brij,brjk,brkl->bril',
RTF13, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
# NA nu0
RTF15 = torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
# NA chi - from nu1 frame
RTF16= torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
RTframes = torch.stack((
RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8,
RTF9,RTF10,RTF11,RTF12,RTF13,RTF14,RTF15,RTF16
),dim=2)
xyzs = torch.einsum(
'brtij,brtj->brti',
RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs
)
return RTframes, xyzs[...,:3]

View File

@@ -1,268 +0,0 @@
import sys, os, json
import time
import numpy as np
import torch
import torch.nn as nn
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0,script_dir+'/models/fold_and_dock2/')
import parsers
from RoseTTAFoldModel import RoseTTAFoldModule
from data_loader import merge_a3m_hetero
import util
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
from model_params import MODEL_PARAM
alphabet = list("ARNDCQEGHILKMFPSTWYV-")
aa_N_1 = dict(zip(range(len(alphabet)),alphabet))
def model_wrapper(model, inputs):
logit_s, logit_aa_s, logit_pae, logit_pde, pred_crds, alpha_s, pred_allatom, pred_lddt_binned, \
msa_prev, pair_prev, state_prev = model(**inputs)
return dict(
logit_s=logit_s,
logit_aa_s=logit_aa_s,
logit_pae=logit_pae,
logit_pde=logit_pde,
pred_crds=pred_crds,
alpha_s=alpha_s,
pred_allatom=pred_allatom,
pred_lddt_binned=pred_lddt_binned,
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev
)
def run_gradient_descent(model, args, inputs, cycles, loss_funcs):
B, N = 1,1
device = args.device
torch.set_grad_enabled(True)
print()
print('Starting gradient descent...')
loss_header = ''.join([f'{loss["name"]:>12}' for loss in loss_funcs])
print(f' step{"best loss":>12}{"curr loss":>12}{loss_header} curr seq')
Ls = inputs['Ls']
msa = inputs['msa'][None].to(device) # (B, N, L)
msa_one_hot_sm = nn.functional.one_hot(msa[:,:,Ls[0]:], num_classes=NAATOKENS).float()
input_logits = args.init_sd*torch.randn([B,N,Ls[0],20]).to(device).float()
input_logits = input_logits.requires_grad_(True)
optimizer = NSGD([input_logits], lr=args.learning_rate*np.sqrt(Ls[0]), dim=[-1,-2])
best_loss = torch.full((B,),1e4).to(device)
for i_step in range(args.grad_steps):
optimizer.zero_grad()
# discretize protein tokens on protein residues
msa_one_hot_prot = logits_to_probs(input_logits, output_type=args.seq_prob_type)
# pad with non-protein tokens on protein residues
msa_one_hot = torch.concat([msa_one_hot_prot, torch.zeros((B,N,Ls[0],NAATOKENS-20)).to(device)], dim=-1)
# pad with ligand residues
msa_one_hot = torch.concat([msa_one_hot, msa_one_hot_sm], dim=-2)
# predict structure
xyz_prev = inputs['xyz_prev'].clone()
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,sum(Ls),NTOTALDOFS,2), device=device)
state_prev = None
mask_recycle = inputs['mask_recycle'].clone()
for i_cycle in range(cycles):
out = model_wrapper(model, dict(
msa_one_hot=msa_one_hot,
seq_unmasked=msa_one_hot.argmax(-1)[:,0].detach(), # (B,L)
xyz=inputs['xyz_prev'],
sctors=alpha_prev,
idx=inputs['idx_pdb'],
bond_feats=inputs['bond_feats'],
chirals=inputs['chirals'],
atom_frames=inputs['atom_frames'],
t1d=inputs['t1d'],
t2d=inputs['t2d'],
xyz_t=inputs['xyz_t'][...,1,:],
alpha_t=inputs['alpha_t'],
mask_t=inputs['mask_t_2d'],
same_chain=inputs['same_chain'],
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev,
mask_recycle=mask_recycle,
use_checkpoint=True
))
xyz_prev = out['pred_allatom'][-1].unsqueeze(0)
msa_prev = out['msa_prev']
pair_prev = out['pair_prev']
alpha_prev = out['alpha_s'][-1]
state_prev = out['state_prev']
mask_recycle = None
# calculate loss
curr_msa = msa_one_hot.argmax(-1).detach().clone() #(B=1,N=1,L)
out['msa'] = curr_msa
loss_s = [loss['func'](out) for loss in loss_funcs]
curr_loss = torch.sum(torch.stack(loss_s), dim=0)
# best design so far
if curr_loss < best_loss:
best_loss = curr_loss
best_loss_s = loss_s
best_out = out
msa = curr_msa
best_step = i_step
# update sequence
curr_loss.backward()
optimizer.step()
# print step info
seq_str = ''.join([aa_N_1[int(a)] for a in msa_one_hot_prot.argmax(-1)[0,0]])
loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in loss_s])
print(f'{i_step:>6}{float(best_loss):>12.3f}{float(curr_loss):>12.3f}{loss_str} {seq_str}')
seq_str = ''.join([aa_N_1[int(a)] for a in msa[0,0,:Ls[0]]])
loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in best_loss_s])
print(f' final{float(best_loss):>12.2f}{" "*(12)}{loss_str} {seq_str}')
best_out.update(dict(loss = best_loss, msa = msa, best_step=best_step))
torch.cuda.reset_peak_memory_stats()
return best_out
def run_mcmc(model, args, inputs, cycles, loss_funcs):
B, N = 1,1
model.eval()
torch.set_grad_enabled(False)
print()
print('Starting MCMC...')
loss_header = ''.join([f'{loss["name"]:>12}' for loss in loss_funcs])
print(f' step{"best loss":>12}{"curr loss":>12}{"accept?":>8}{loss_header} curr seq')
Ls = inputs['Ls']
device = args.device
msa = inputs['msa'][None].to(device) # (B, N, L)
best_loss = torch.full((B,),1e4).to(device)
for i_step in range(args.mcmc_steps):
# make mutation
i_pos = np.random.randint(Ls[0])
aa_new = np.random.randint(20)
msa_new = msa.clone()
msa_new[0,0,i_pos] = aa_new # assumes B=1
# predict structure
xyz_prev = inputs['xyz_prev'].clone()
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,sum(Ls),NTOTALDOFS,2), device=msa_new.device)
state_prev = None
mask_recycle = inputs['mask_recycle'].clone()
for i_cycle in range(cycles):
out = model_wrapper(model, dict(
msa_one_hot=nn.functional.one_hot(msa_new, num_classes=NAATOKENS).float(),
seq_unmasked=msa_new[:,0],
xyz=inputs['xyz_prev'],
sctors=alpha_prev,
idx=inputs['idx_pdb'],
bond_feats=inputs['bond_feats'],
chirals=inputs['chirals'],
atom_frames=inputs['atom_frames'],
t1d=inputs['t1d'],
t2d=inputs['t2d'],
xyz_t=inputs['xyz_t'][...,1,:],
alpha_t=inputs['alpha_t'],
mask_t=inputs['mask_t_2d'],
same_chain=inputs['same_chain'],
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev,
mask_recycle=mask_recycle
))
xyz_prev = out['pred_allatom'][-1].unsqueeze(0)
msa_prev = out['msa_prev']
pair_prev = out['pair_prev']
alpha_prev = out['alpha_s'][-1]
state_prev = out['state_prev']
mask_recycle = None
# calculate loss
curr_msa = msa_new.detach().clone()
out['msa'] = curr_msa
loss_s = [loss['func'](out) for loss in loss_funcs]
curr_loss = torch.sum(torch.stack(loss_s), dim=0)
# batch-wise Metropolis update
T = args.T0*(0.5)**(i_step/args.mcmc_halflife)
p_accept = torch.clamp(torch.exp(-(curr_loss-best_loss)/T), min=0.0, max=1.0)
accept = torch.rand(1).to(device) < p_accept
if accept:
best_loss = curr_loss
best_loss_s = loss_s
msa = curr_msa
best_out = out
seq_str = ''.join([aa_N_1[int(a)] for a in msa_new[0,~util.is_atom(msa_new)[0]]])
loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in loss_s])
print(f'{i_step:>6}{float(best_loss):>12.2f}{float(curr_loss):>12.2f}{int(accept):>8}{loss_str} {seq_str}')
seq_str = ''.join([aa_N_1[int(a)] for a in msa[0,0,:Ls[0]]])
loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in best_loss_s])
print(f' final{float(best_loss):>12.2f}{" "*(12+8)}{loss_str} {seq_str}')
best_out.update(dict(loss = best_loss, msa = msa, loss_terms = best_loss_s))
torch.cuda.reset_peak_memory_stats()
return best_out
def logits_to_probs(logits, output_type='hard', temp=1, add_gumbel_noise=False, eps=1e-8):
device = logits.device
B, N, L, A = logits.shape
if add_gumbel_noise:
U = torch.rand(logits.shape)
noise = -torch.log(-torch.log(U + eps) + eps)
noise = noise.to(device)
logits = logits + noise
y_soft = torch.nn.functional.softmax(logits/temp, -1)
if output_type == 'soft':
return y_soft
elif output_type == 'hard':
n_cat = y_soft.shape[-1]
y_oh = torch.nn.functional.one_hot(y_soft.argmax(-1), n_cat)
y_hard = (y_oh - y_soft).detach() + y_soft
return y_hard
else:
raise NotImplementedError('Output type must be "soft" or "hard"')
class NSGD(torch.optim.Optimizer):
def __init__(self, params, lr, dim):
defaults = dict(lr=lr)
super(NSGD, self).__init__(params, defaults)
self.dim=dim
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None: continue
d_p = p.grad / (torch.norm(p.grad, dim=self.dim, keepdim=True) + 1e-8)
p.add_(d_p, alpha=-group['lr'])
return loss

View File

@@ -1,345 +0,0 @@
import sys
import numpy as np
import pandas as pd
num2aa=[
'ALA','ARG','ASN','ASP','CYS',
'GLN','GLU','GLY','HIS','ILE',
'LEU','LYS','MET','PHE','PRO',
'SER','THR','TRP','TYR','VAL',
'UNK','MAS',
]
aa2num= {x:i for i,x in enumerate(num2aa)}
# full sc atom representation (Nx14)
aa2long=[
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
(" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
(" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
]
def parse_pdb(filename, **kwargs):
'''extract xyz coords for all heavy atoms'''
lines = open(filename,'r').readlines()
return parse_pdb_lines(lines, **kwargs)
def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True):
# indices of residues observed in the structure
res = [(l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
seq = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res]
pdb_idx = [( l[21:22].strip(), int(l[22:26].strip()) ) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] # chain letter, res num
# 4 BB + up to 10 SC atoms
xyz = np.full((len(res), 14, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
chain, resNo, atom, aa = l[21:22], int(l[22:26]), ' '+l[12:16].strip().ljust(3), l[17:20]
idx = pdb_idx.index((chain,resNo))
for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]):
if tgtatm is not None and tgtatm.strip() == atom.strip(): # ignore whitespace
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
# remove duplicated (chain, resi)
new_idx = []
i_unique = []
for i,idx in enumerate(pdb_idx):
if idx not in new_idx:
new_idx.append(idx)
i_unique.append(i)
pdb_idx = new_idx
xyz = xyz[i_unique]
mask = mask[i_unique]
seq = np.array(seq)[i_unique]
out = {'xyz':xyz, # cartesian coordinates, [Lx14]
'mask':mask, # mask showing which atoms are present in the PDB file, [Lx14]
'idx':np.array([i[1] for i in pdb_idx]), # residue numbers in the PDB file, [L]
'seq':np.array(seq), # amino acid sequence, [L]
'pdb_idx': pdb_idx, # list of (chain letter, residue number) in the pdb file, [L]
}
# heteroatoms (ligands, etc)
if parse_hetatom:
xyz_het, info_het = [], []
for l in lines:
if l[:6]=='HETATM' and not (ignore_het_h and l[77]=='H'):
info_het.append(dict(
idx=int(l[7:11]),
atom_id=l[12:16],
atom_type=l[77],
name=l[16:20]
))
xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])
out['xyz_het'] = np.array(xyz_het)
out['info_het'] = info_het
return out
class ContigMap():
'''
New class for doing mapping.
Supports multichain or multiple crops from a single receptor chain.
Also supports indexing jump (+200) or not, based on contig input.
Default chain outputs are inpainted chains as A (and B, C etc if multiple chains), and all fragments of receptor chain on the next one (generally B)
Output chains can be specified. Sequence must be the same number of elements as in contig string
'''
def __init__(self, pdb_idx, contigs=None, inpaint_seq=None, inpaint_str=None, length=None, ref_idx=None, hal_idx=None, idx_rf=None, inpaint_seq_tensor=None, inpaint_str_tensor=None, topo=False):
#sanity checks
if contigs is None and ref_idx is None:
sys.exit("Must either specify a contig string or precise mapping")
if idx_rf is not None or hal_idx is not None or ref_idx is not None:
if idx_rf is None or hal_idx is None or ref_idx is None:
sys.exit("If you're specifying specific contig mappings, the reference and output positions must be specified, AND the indexing for RoseTTAFold (idx_rf)")
self.chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ'
if length is not None:
if '-' not in length:
self.length = [int(length),int(length)+1]
else:
self.length = [int(length.split("-")[0]),int(length.split("-")[1])+1]
else:
self.length = None
self.ref_idx = ref_idx
self.hal_idx=hal_idx
self.idx_rf=idx_rf
self.inpaint_seq = ','.join(inpaint_seq).split(",") if inpaint_seq is not None else None
self.inpaint_str = ','.join(inpaint_str).split(",") if inpaint_str is not None else None
self.inpaint_seq_tensor=inpaint_seq_tensor
self.inpaint_str_tensor=inpaint_str_tensor
self.pdb_idx = pdb_idx
self.topo=topo
if ref_idx is None:
#using default contig generation, which outputs in rosetta-like format
self.contigs=contigs
self.sampled_mask,self.contig_length,self.n_inpaint_chains = self.get_sampled_mask()
self.receptor_chain = self.chain_order[self.n_inpaint_chains]
self.receptor, self.receptor_hal, self.receptor_rf, self.inpaint, self.inpaint_hal, self.inpaint_rf= self.expand_sampled_mask()
self.ref = self.inpaint + self.receptor
self.hal = self.inpaint_hal + self.receptor_hal
self.rf = self.inpaint_rf + self.receptor_rf
else:
#specifying precise mappings
self.ref=ref_idx
self.hal=hal_idx
self.rf = rf_idx
self.mask_1d = [False if i == ('_','_') else True for i in self.ref]
#take care of sequence and structure masking
if self.inpaint_seq_tensor is None:
if self.inpaint_seq is not None:
self.inpaint_seq = self.get_inpaint_seq_str(self.inpaint_seq)
else:
self.inpaint_seq = np.array([True if i != ('_','_') else False for i in self.ref])
else:
self.inpaint_seq = self.inpaint_seq_tensor
if self.inpaint_str_tensor is None:
if self.inpaint_str is not None:
self.inpaint_str = self.get_inpaint_seq_str(self.inpaint_str)
else:
self.inpaint_str = np.array([True if i != ('_','_') else False for i in self.ref])
else:
self.inpaint_str = self.inpaint_str_tensor
#get 0-indexed input/output (for trb file)
self.ref_idx0,self.hal_idx0, self.ref_idx0_inpaint, self.hal_idx0_inpaint, self.ref_idx0_receptor, self.hal_idx0_receptor=self.get_idx0()
def get_sampled_mask(self):
'''
Function to get a sampled mask from a contig.
'''
length_compatible=False
count = 0
while length_compatible is False:
inpaint_chains=0
contig_list = self.contigs[0].strip().split()
sampled_mask = []
sampled_mask_length = 0
#allow receptor chain to be last in contig string
if all([i[0].isalpha() for i in contig_list[-1].split(",")]):
contig_list[-1] = f'{contig_list[-1]},0'
for con in contig_list:
if ((all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0')) or self.topo is True:
#receptor chain
sampled_mask.append(con)
else:
inpaint_chains += 1
#chain to be inpainted. These are the only chains that count towards the length of the contig
subcons = con.split(",")
subcon_out = []
for subcon in subcons:
if subcon[0].isalpha():
subcon_out.append(subcon)
if '-' in subcon:
sampled_mask_length += (int(subcon.split("-")[1])-int(subcon.split("-")[0][1:])+1)
else:
sampled_mask_length += 1
else:
if '-' in subcon:
length_inpaint=np.random.randint(int(subcon.split("-")[0]),int(subcon.split("-")[1]))
subcon_out.append(f'{length_inpaint}-{length_inpaint}')
sampled_mask_length += length_inpaint
elif subcon == '0':
subcon_out.append('0')
else:
length_inpaint=int(subcon)
subcon_out.append(f'{length_inpaint}-{length_inpaint}')
sampled_mask_length += int(subcon)
sampled_mask.append(','.join(subcon_out))
#check length is compatible
if self.length is not None:
if sampled_mask_length >= self.length[0] and sampled_mask_length < self.length[1]:
length_compatible = True
else:
length_compatible = True
count+=1
if count == 100000: #contig string incompatible with this length
sys.exit("Contig string incompatible with --length range")
return sampled_mask, sampled_mask_length, inpaint_chains
def expand_sampled_mask(self):
chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ'
receptor = []
inpaint = []
receptor_hal = []
inpaint_hal = []
receptor_idx = 1
inpaint_idx = 1
inpaint_chain_idx=-1
receptor_chain_break=[]
inpaint_chain_break = []
for con in self.sampled_mask:
if (all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0') or self.topo is True:
#receptor chain
subcons = con.split(",")[:-1]
assert all([i[0] == subcons[0][0] for i in subcons]), "If specifying fragmented receptor in a single block of the contig string, they MUST derive from the same chain"
assert all(int(subcons[i].split("-")[0][1:]) < int(subcons[i+1].split("-")[0][1:]) for i in range(len(subcons)-1)), "If specifying multiple fragments from the same chain, pdb indices must be in ascending order!"
for idx, subcon in enumerate(subcons):
ref_to_add = [(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)]
receptor.extend(ref_to_add)
receptor_hal.extend([(self.receptor_chain,i) for i in np.arange(receptor_idx, receptor_idx+len(ref_to_add))])
receptor_idx += len(ref_to_add)
if idx != len(subcons)-1:
idx_jump = int(subcons[idx+1].split("-")[0][1:]) - int(subcon.split("-")[1]) -1
receptor_chain_break.append((receptor_idx-1,idx_jump)) #actual chain break in pdb chain
else:
receptor_chain_break.append((receptor_idx-1,200)) #200 aa chain break
else:
inpaint_chain_idx += 1
for subcon in con.split(","):
if subcon[0].isalpha():
ref_to_add=[(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)]
inpaint.extend(ref_to_add)
inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+len(ref_to_add))])
inpaint_idx += len(ref_to_add)
else:
inpaint.extend([('_','_')] * int(subcon.split("-")[0]))
inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+int(subcon.split("-")[0]))])
inpaint_idx += int(subcon.split("-")[0])
inpaint_chain_break.append((inpaint_idx-1,200))
if self.topo is True or inpaint_hal == []:
receptor_hal = [(i[0], i[1]) for i in receptor_hal]
else:
receptor_hal = [(i[0], i[1] + inpaint_hal[-1][1]) for i in receptor_hal] #rosetta-like numbering
#get rf indexes, with chain breaks
inpaint_rf = np.arange(0,len(inpaint))
receptor_rf = np.arange(len(inpaint)+200,len(inpaint)+len(receptor)+200)
for ch_break in inpaint_chain_break[:-1]:
receptor_rf[:] += 200
inpaint_rf[ch_break[0]:] += ch_break[1]
for ch_break in receptor_chain_break[:-1]:
receptor_rf[ch_break[0]:] += ch_break[1]
return receptor, receptor_hal, receptor_rf.tolist(), inpaint, inpaint_hal, inpaint_rf.tolist()
def get_inpaint_seq_str(self, inpaint_s):
'''
function to generate inpaint_str or inpaint_seq masks specific to this contig
'''
s_mask = np.copy(self.mask_1d)
inpaint_s_list = []
for i in inpaint_s:
if '-' in i:
inpaint_s_list.extend([(i[0],p) for p in range(int(i.split("-")[0][1:]), int(i.split("-")[1])+1)])
else:
inpaint_s_list.append((i[0],int(i[1:])))
for res in inpaint_s_list:
if res in self.ref:
s_mask[self.ref.index(res)] = False #mask this residue
return np.array(s_mask)
def get_idx0(self):
ref_idx0=[]
hal_idx0=[]
ref_idx0_inpaint=[]
hal_idx0_inpaint=[]
ref_idx0_receptor=[]
hal_idx0_receptor=[]
for idx, val in enumerate(self.ref):
if val != ('_','_'):
assert val in self.pdb_idx,f"{val} is not in pdb file!"
hal_idx0.append(idx)
ref_idx0.append(self.pdb_idx.index(val))
for idx, val in enumerate(self.inpaint):
if val != ('_','_'):
hal_idx0_inpaint.append(idx)
ref_idx0_inpaint.append(self.pdb_idx.index(val))
for idx, val in enumerate(self.receptor):
if val != ('_','_'):
hal_idx0_receptor.append(idx)
ref_idx0_receptor.append(self.pdb_idx.index(val))
return ref_idx0, hal_idx0, ref_idx0_inpaint, hal_idx0_inpaint, ref_idx0_receptor, hal_idx0_receptor
def get_mappings(rm):
mappings = {}
mappings['con_ref_pdb_idx'] = [i for i in rm.inpaint if i != ('_','_')]
mappings['con_hal_pdb_idx'] = [rm.inpaint_hal[i] for i in range(len(rm.inpaint_hal)) if rm.inpaint[i] != ("_","_")]
mappings['con_ref_idx0'] = rm.ref_idx0_inpaint
mappings['con_hal_idx0'] = rm.hal_idx0_inpaint
if rm.inpaint != rm.ref:
mappings['complex_con_ref_pdb_idx'] = [i for i in rm.ref if i != ("_","_")]
mappings['complex_con_hal_pdb_idx'] = [rm.hal[i] for i in range(len(rm.hal)) if rm.ref[i] != ("_","_")]
mappings['receptor_con_ref_pdb_idx'] = [i for i in rm.receptor if i != ("_","_")]
mappings['receptor_con_hal_pdb_idx'] = [rm.receptor_hal[i] for i in range(len(rm.receptor_hal)) if rm.receptor[i] != ("_","_")]
mappings['complex_con_ref_idx0'] = rm.ref_idx0
mappings['complex_con_hal_idx0'] = rm.hal_idx0
mappings['receptor_con_ref_idx0'] = rm.ref_idx0_receptor
mappings['receptor_con_hal_idx0'] = rm.hal_idx0_receptor
mappings['inpaint_str'] = rm.inpaint_str
mappings['inpaint_seq'] = rm.inpaint_seq
mappings['sampled_mask'] = rm.sampled_mask
mappings['mask_1d'] = rm.mask_1d
return mappings

View File

@@ -56,8 +56,13 @@ def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra
clamp = torch.zeros_like(difference)
clamp[:,same_chain==1] = d_clamp_intra
clamp[:,same_chain==0] = d_clamp_inter
difference = torch.clamp(difference, max=clamp)
loss = difference / A # (I, B, L, L)
mixing_factor = 0.9
unclamped_difference = difference.clone()
clamped_difference = torch.clamp(difference, max=clamp)
clamped_loss = clamped_difference / A
unclamped_loss = unclamped_difference/A # (I, B, L, L)
# Get a mask information (ignore missing residue + inter-chain residues)
# for positive cases, mask = mask_2d
@@ -66,8 +71,11 @@ def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra
mask = mask_2d * same_chain
else:
mask = mask_2d
# calculate masked loss (ignore missing regions when calculate loss)
loss = (mask[None]*loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
clamped_loss = (mask[None]*clamped_loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
unclamped_loss = (mask[None]*unclamped_loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
loss = mixing_factor *clamped_loss + (1-mixing_factor)*unclamped_loss
# weighting loss
w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
@@ -608,7 +616,7 @@ def compute_pde_loss(X, Y, logit_pde, atom_mask, pde_bin_step=0.3, frame_atom_ma
# from Ivan: FAPE generalized over atom sets & frames
def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=None, frame_atom_mask_2d=None,
logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, dclamp_2d=None, gamma=0.99, eps=1e-4):
logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, dclamp_2d=None, gamma=0.99, mixing_factor=0.9, eps=1e-4):
# X (predicted) N x L x natoms x 3
# Y (native) 1 x L x natoms x 3
@@ -617,7 +625,12 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
# frame_mask 1 x L x nframes
# frame_atom_mask 1 x L x natoms masks the frames over which fape is calculated
# frame_atom_mask_2d 1 x L x nframes x L x natoms 2d mask, 2nd dimension frames, 3rd/4th dimension atoms so fape can be taken over some atoms for some frames (only works for BB fape)
# logit_pae
# logit_pde
# dclamp int
# dclamp_2d
# gamma
# mixing_factor int in cases where the loss is clamped, allows gradients to flow by mixing the clamped and unclamped loss
if frame_atom_mask is None:
frame_atom_mask = atom_mask
@@ -652,6 +665,9 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
# multiply diff by frame_atom_mask_2d if frame_atom_mask_2d not None
if frame_atom_mask_2d is not None:
diff = diff*frame_atom_mask_2d[:,frame_mask[0]][:, :, atom_mask[0]]
N_values = torch.sum(frame_atom_mask_2d[:,frame_mask[0]][:, :, atom_mask[0]])
else:
N_values = diff.shape[1]*diff.shape[2] # frame dimension * atom dimension
assert dclamp is not None or dclamp_2d is not None, "need to provide either dclamp or dclamp_2d to compute_general_FAPE"
assert not (dclamp is not None and dclamp_2d is not None), "you provided both dclamp and dclamp_2d, please only provide one"
@@ -659,8 +675,10 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
if dclamp_2d is not None:
dclamp = dclamp_2d[:,frame_mask[0]][:, :, atom_mask[0]]
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2))
clamped_loss = (1.0/Z) * torch.sum(torch.clamp(diff, max=dclamp), dim=(1,2))/(N_values)
unclamped_loss = (1.0/Z) *torch.sum(diff, dim=(1,2))/(N_values)
loss = mixing_factor * clamped_loss + (1-mixing_factor) * unclamped_loss
pae_loss = compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae, frame_mask, atom_mask, frame_atom_mask_2d=frame_atom_mask_2d) \
if logit_pae is not None \
else torch.tensor(0).to(frames.device)
@@ -683,26 +701,26 @@ def mask_unresolved_frames(frames, frame_mask, atom_mask):
"""
B, L, natoms = atom_mask.shape
frame_mask_update = frame_mask.clone()
# reindex frames for flat X
frames_reindex = torch.zeros(frames.shape[:-1], device=frames.device)
for i in range(L):
frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1]
frames_reindex = frames_reindex.long()
frames_reindex = (
torch.arange(L, device=frames.device)[None,:,None,None] + frames[..., 0]
)*natoms + frames[..., 1]
masked_atom_frames = torch.any(frames_reindex>L*natoms, dim=-1) # find frames with atoms that aren't resolved
masked_atom_frames *= torch.any(frames_reindex<0, dim=-1)
frame_mask_update *= ~masked_atom_frames
# There are currently indices for frames that aren't in the coordinates bc they arent resolved, reset these indices to 0 to avoid
# indexing errors
frames_reindex[masked_atom_frames, :] = 0
frame_mask_update = frame_mask.clone()
frame_mask_update *= ~masked_atom_frames
frame_mask_update *= torch.all(
torch.gather(atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3),
axis=-1)
return frames_reindex, frame_mask_update
def calc_crd_rmsd(pred, true, atom_mask, rmsd_mask=None, alignment_radius=None):
'''
Calculate coordinate RMSD
@@ -1096,6 +1114,121 @@ def calc_clash(xs, mask):
clash = torch.sum( torch.clamp(DISTCUT-dij[allmask],0.0) ) / torch.sum(mask)
return clash
def calc_l1_clash_loss(xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds, lj_hb_dis=3.0, \
lj_OHdon_dis=2.6, lj_hbond_hdis=1.75, tolerance=1.5, agg="sum", eps=1e-8):
"""
The LJ potential doesnt work for large systems because it means the energy over all atoms, so 1-2 clashes wash
out of the loss. To remedy this, we want to only apply a loss on the violating elements. if doing this, the
contribution of the london dispersion forces no longer matter (you will never evaluate these because they
are not violations). Thus, switching the functional form of this to be a flat bottom L1 loss similar to the
other bond losses in RF.
"""
def compute_loss(deltas, ljrs, tolerance, eps=1e-8):
"""
deltas: differences between atom positions between valid pairs (natom_pairs, 3)
ljrs: lj radii of valid pairs
tolerance: how large the deviation can be (this could become a more complicated hyperparameter)
"""
dist = torch.sqrt( torch.sum ( torch.square( deltas ), dim=-1 ) + eps ) # compute distances
deviations = torch.clamp(ljrs - dist - tolerance, min=0)
return torch.sum(deviations), torch.sum(deviations > 0)
N, L = xs.shape[:2]
rs = torch.triu_indices(L,L,0, device=xs.device)
ri,rj = rs[0],rs[1]
# batch during inference for huge systems
BATCHSIZE = 65536//N
running_clash_val = 0
running_num_violations = 0
#NOTE: a lot of this is similar to calc_lj -- probably should move into its own fx
for i_batch in range((len(ri)-1)//BATCHSIZE + 1):
idx = torch.arange(
i_batch*BATCHSIZE,
min( (i_batch+1)*BATCHSIZE, len(ri)),
device=xs.device
)
rii,rjj = ri[idx],rj[idx] # residue pairs we consider
ridx,ai,aj = (
aamask[seq[rii]][:,:,None]*aamask[seq[rjj]][:,None,:]
).nonzero(as_tuple=True)
deltas = xs[:,rii,:,None,:]-xs[:,rjj,None,:,:] # N,BATCHSIZE,Natm,Natm,3
seqi,seqj = seq[rii[ridx]], seq[rjj[ridx]]
mask = torch.ones_like(ridx, dtype=torch.bool) # are atoms defined?
# mask out atom pairs from too-distant residues (C-alpha dist > 24A)
ca_dist = torch.linalg.norm(deltas[:,:,1,1],dim=-1)
mask *= (ca_dist[:,ridx]<24).any(dim=0) # will work for batch>1 but very inefficient
intrares = (rii[ridx]==rjj[ridx])
mask[intrares*(ai<aj)] = False # upper tri (atoms)
## count-pair
# a) intra-protein
mask[intrares] *= num_bonds[seqi[intrares],ai[intrares],aj[intrares]]>=4
pepbondres = ri[ridx]+1==rj[ridx]
mask[pepbondres] *= (
num_bonds[seqi[pepbondres],ai[pepbondres],2]
+ num_bonds[seqj[pepbondres],0,aj[pepbondres]]
+ 1) >=4
# b) intra-ligand
atommask = (ai==1)*(aj==1)
dist_matrix = torch.nan_to_num(dist_matrix, posinf=4.0) #NOTE: need to run nan_to_num to remove infinities
resmask = (dist_matrix[0,rii,rjj] >= 4) # * will only work for batch=1
mask[atommask] *= resmask[ ridx[atommask] ]
# c) protein/ligand
##fd NOTE1: changed 6->5 in masking (atom 5 is CG which should always be 4+ bonds away from connected atom)
##fd NOTE2: this does NOT work correctly for nucleic acids
##fd for NAs atoms 0-4 are masked, but also 5,7,8 and 9 should be masked!
bbatommask = (ai<5)*(aj<5)
resmask = (bond_feats[0,rii,rjj] != 6) # * will only work for batch=1
mask[bbatommask] *= resmask[ ridx[bbatommask] ]
# d) potential disulfide
# disulfide correction previously this was done by reducing well depth but
# since this loss does not have a well depth it is done by masking the pairs of atoms
potential_disulf = (ljcorr[seqi,ai,3]*ljcorr[seqj,aj,3] )
mask *= ~potential_disulf
# apply mask. only interactions to be scored remain
ai,aj,seqi,seqj,ridx = ai[mask],aj[mask],seqi[mask],seqj[mask],ridx[mask]
deltas = deltas[:,ridx,ai,aj]
# hbond correction
use_hb_dis = (
ljcorr[seqi,ai,0]*ljcorr[seqj,aj,1]
+ ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0] ).nonzero()
use_ohdon_dis = ( # OH are both donors & acceptors
ljcorr[seqi,ai,0]*ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0]
+ljcorr[seqi,ai,0]*ljcorr[seqj,aj,0]*ljcorr[seqj,aj,1]
).nonzero()
use_hb_hdis = (
ljcorr[seqi,ai,2]*ljcorr[seqj,aj,1]
+ljcorr[seqi,ai,1]*ljcorr[seqj,aj,2]
).nonzero()
ljrs = ljparams[seqi,ai,0] + ljparams[seqj,aj,0]
ljrs[use_hb_dis] = lj_hb_dis
ljrs[use_ohdon_dis] = lj_OHdon_dis
ljrs[use_hb_hdis] = lj_hbond_hdis
clash_val, num_violations = compute_loss(deltas, ljrs, tolerance, eps=eps)
running_clash_val += clash_val
running_num_violations += num_violations
# aggregate values from all batches
if agg=="sum":
return running_clash_val, running_num_violations
elif agg=="mean_over_viol":
return running_clash_val/running_num_violations, running_num_violations
else:
raise NotImplementedError("{agg} not an accepted aggregation method for clash loss")
#fd more efficient LJ loss
class LJLoss(torch.autograd.Function):

630
rf2aa/loss/loss_factory.py Normal file
View File

@@ -0,0 +1,630 @@
import torch
import torch.nn as nn
from collections import OrderedDict
from rf2aa.chemical import NAATOKENS
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins
from rf2aa.loss.loss import resolve_equiv_natives, resolve_equiv_natives_asmb, \
resolve_symmetry_predictions, resolve_symmetry, mask_unresolved_frames, \
compute_general_FAPE, torsionAngleLoss, calc_lddt, calc_allatom_lddt_loss, \
calc_crd_rmsd, calc_BB_bond_geom, calc_l1_clash_loss, calc_atom_bond_loss
from rf2aa.util import is_atom, is_protein, Ls_from_same_chain_2d, get_prot_sm_mask, \
xyz_to_frame_xyz, get_frames, NTOTALDOFS, NTOTAL
cce_loss = nn.CrossEntropyLoss(reduction='none')
def get_loss_and_misc(
trainer,
output_i, true_crds, atom_mask, same_chain,
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
loss_param
):
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
if pred_allatom is None:
_, pred_allatom = trainer.xyz_converter.compute_all_atom(msa[0][0][None],pred_crds[-1][None], alphas[-1][None])
pred_crds = pred_crds[:, None]
alphas = alphas[:, None]
if (symmRs is not None):
#print ('a', pred_crds.shape, true_crds.shape, mask_crds.shape)
###
# resolve symmetry
###
true_crds = true_crds[:,0]
atom_mask = atom_mask[:,0]
mapT2P = resolve_symmetry_predictions(pred_crds, true_crds, atom_mask, Lasu) # (Nlayer, Ltrue)
# update all derived data to only include subunits mapping to native
logit_s_new = []
for li in logit_s:
li=torch.gather(li,2,mapT2P[-1][None,None,:,None].repeat(1,li.shape[1],1,li.shape[-1]))
li=torch.gather(li,3,mapT2P[-1][None,None,None,:].repeat(1,li.shape[1],li.shape[2],1))
logit_s_new.append(li)
logit_s = tuple(logit_s_new)
logit_aa_s = logit_aa_s.view(1,NAATOKENS,msa.shape[-2],msa.shape[-1])
logit_aa_s = torch.gather(logit_aa_s,3,mapT2P[-1][None,None,None,:].repeat(1,NAATOKENS,logit_aa_s.shape[-2],1))
logit_aa_s = logit_aa_s.view(1,NAATOKENS,-1)
msa = torch.gather(msa,2,mapT2P[-1][None,None,:].repeat(1,msa.shape[-2],1))
mask_msa = torch.gather(mask_msa,2,mapT2P[-1][None,None,:].repeat(1,mask_msa.shape[-2],1))
logit_pae=torch.gather(logit_pae,2,mapT2P[-1][None,None,:,None].repeat(1,logit_pae.shape[1],1,logit_pae.shape[-1]))
logit_pae=torch.gather(logit_pae,3,mapT2P[-1][None,None,None,:].repeat(1,logit_pae.shape[1],logit_pae.shape[2],1))
logit_pde=torch.gather(logit_pde,2,mapT2P[-1][None,None,:,None].repeat(1,logit_pde.shape[1],1,logit_pde.shape[-1]))
logit_pde=torch.gather(logit_pde,3,mapT2P[-1][None,None,None,:].repeat(1,logit_pde.shape[1],logit_pde.shape[2],1))
pred_crds = torch.gather(pred_crds,2,mapT2P[:,None,:,None,None].repeat(1,1,1,3,3))
pred_allatom = torch.gather(pred_allatom,1,mapT2P[-1,None,:,None,None].repeat(1,1,NTOTAL,3))
alphas = torch.gather(alphas,2,mapT2P[:,None,:,None,None].repeat(1,1,1,NTOTALDOFS,2))
same_chain=torch.gather(same_chain,1,mapT2P[-1][None,:,None].repeat(1,1,same_chain.shape[-1]))
same_chain=torch.gather(same_chain,2,mapT2P[-1][None,None,:].repeat(1,same_chain.shape[1],1))
bond_feats=torch.gather(bond_feats,1,mapT2P[-1][None,:,None].repeat(1,1,bond_feats.shape[-1]))
bond_feats=torch.gather(bond_feats,2,mapT2P[-1][None,None,:].repeat(1,bond_feats.shape[1],1))
dist_matrix=torch.gather(dist_matrix,1,mapT2P[-1][None,:,None].repeat(1,1,dist_matrix.shape[-1]))
dist_matrix=torch.gather(dist_matrix,2,mapT2P[-1][None,None,:].repeat(1,dist_matrix.shape[1],1))
pred_lddts = torch.gather(pred_lddts,2,mapT2P[-1][None,None,:].repeat(1,pred_lddts.shape[-2],1))
idx_pdb = torch.gather(idx_pdb,1,mapT2P[-1][None,:])
elif 'sm_compl' in task[0] or 'metal_compl' in task[0]:
sm_mask = is_atom(seq[0,0])
Ls_prot = Ls_from_same_chain_2d(same_chain[:,~sm_mask][:,:,~sm_mask])
Ls_sm = Ls_from_same_chain_2d(same_chain[:,sm_mask][:,:,sm_mask])
true_crds, atom_mask = resolve_equiv_natives_asmb(
pred_allatom, true_crds, atom_mask, ch_label, Ls_prot, Ls_sm)
else:
true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask)
res_mask = get_prot_sm_mask(atom_mask, msa[0,0])
mask_2d = res_mask[:,None,:] * res_mask[:,:,None]
true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, 0], atom_frames)
c6d = xyz_to_c6d(true_crds_frame)
c6d = c6d_to_bins(c6d, same_chain, negative=negative)
# contact accuray not as useful to track anymore
#prob = self.active_fn(logit_s[0]) # distogram
#acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)
loss, loss_dict = calc_loss(
trainer, logit_s, c6d,
logit_aa_s, msa, mask_msa, logit_pae, logit_pde, p_bind,
pred_crds, alphas, pred_allatom, true_crds,
atom_mask, res_mask, mask_2d, same_chain,
pred_lddts, idx_pdb, bond_feats, dist_matrix,
atom_frames=atom_frames,unclamp=unclamp, negative=negative,
item=item, task=task, **loss_param
)
return loss, loss_dict
def calc_loss(trainer, logit_s, label_s,
logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
pred, pred_tors, pred_allatom, true,
mask_crds, mask_BB, mask_2d, same_chain,
pred_lddt, idx, bond_feats, dist_matrix, atom_frames=None, unclamp=False,
negative=False, interface=False,
w_dist=1.0, w_aa=1.0, w_str=1.0, w_inter_fape=0.0, w_lig_fape=1.0, w_lddt=1.0,
w_bond=1.0, w_clash=0.0, w_atom_bond=0.0, w_skip_bond=0.0, w_rigid=0.0, w_hb=0.0, w_bind=0.0,
w_pae=0.0, w_pde=0.0, lj_lin=0.85, eps=1e-6, binder_loss_label_smoothing = 0.0, item=None, task=None, out_dir='./'
):
gpu = pred.device
# track losses for printing to local log and uploading to WandB
loss_dict = OrderedDict()
B, L, natoms = true.shape[:3]
seq = label_aa_s[:,0].clone()
assert (B==1) # fd - code assumes a batch size of 1
tot_loss = 0.0
# set up frames
frames, frame_mask = get_frames(
pred_allatom[-1,None,...], mask_crds, seq, trainer.fi_dev, atom_frames)
# update frames and frames_mask to only include BB frames (have to update both for compatibility with compute_general_FAPE)
frames_BB = frames.clone()
frames_BB[..., 1:, :, :] = 0
frame_mask_BB = frame_mask.clone()
frame_mask_BB[...,1:] =False
# c6d loss
for i in range(4):
loss = cce_loss(logit_s[i], label_s[...,i]) # (B, L, L)
if i==0: # apply distogram loss to all residue pairs with valid BB atoms
mask_2d_ = mask_2d
else:
# apply anglegram loss only when both residues have valid BB frames (i.e. not metal ions, and not examples with unresolved atoms in frames)
_, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds) # (1, L, nframes)
bb_frame_good = bb_frame_good[...,0] # (1,L)
loss_mask_2d = bb_frame_good & bb_frame_good[...,None]
mask_2d_ = mask_2d & loss_mask_2d
if negative.item():
# Don't compute inter-chain distogram losses
# for negative examples.
mask_2d_ = mask_2d_ * same_chain
loss = (mask_2d_*loss).sum() / (mask_2d_.sum() + eps)
tot_loss += w_dist*loss
loss_dict[f'c6d_{i}'] = loss.detach()
# masked token prediction loss
loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
loss = loss * mask_aa_s.reshape(B, -1)
loss = loss.sum() / (mask_aa_s.sum() + 1e-8)
tot_loss += w_aa*loss
loss_dict['aa_cce'] = loss.detach()
# col 4: binder loss
# only apply binding loss to complexes
# note that this will apply loss to positive sets w/o a corresponding negative set
# (e.g., homomers). Maybe want to change this?
if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer =="legacy":
if (torch.sum(same_chain==0) > 0):
bce = torch.nn.BCELoss()
target = torch.tensor(
[abs(float(not negative) - binder_loss_label_smoothing)],
device=p_bind.device
)
loss = bce(p_bind,target)
else:
# avoid unused parameter error
loss = 0.0 * p_bind.sum()
tot_loss += w_bind * loss
loss_dict['binder_bce_loss'] = loss.detach()
### GENERAL LAYERS
# Structural loss (layer-wise backbone FAPE)
dclamp = 300.0 if unclamp else 30.0 # protein & NA FAPE distance cutoffs
dclamp_sm, Z_sm = 4, 4 # sm mol FAPE distance cutoffs
dclamp_prot = 10
# residue mask for FAPE calculation only masks unresolved protein backbone atoms
# whereas other losses also maks unresolved ligand atoms (mask_BB)
# frames with unresolved ligand atoms are masked in compute_general_FAPE
res_mask = ~((mask_crds[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))
# create 2d masks for intrachain and interchain fape calculations
nframes = frame_mask.shape[-1]
frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool() # B, L, nframes, L, natoms
frame_atom_mask_2d = frame_atom_mask_2d_allatom[:, :, :, :, :3]
frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :,None, :, None].bool().expand(-1,-1,nframes,-1, NTOTAL)
frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]
different_chain = ~same_chain.bool()
frame_atom_mask_2d_inter = frame_atom_mask_2d*different_chain[:, :,None, :, None].expand(-1,-1,nframes,-1, 3)
# ic(task, res_mask.sum(), pred.shape, true.shape)
if 'tf' in task[0] or res_mask.sum() == 0:
tot_str = 0.0 * pred.sum(axis=(1,2,3,4))
pae_loss = 0.0 * logit_pae.sum()
pde_loss = 0.0 * logit_pde.sum()
elif negative: # inter-chain fapes should be ignored for negative cases
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
dclamp=dclamp,
logit_pae=logit_pae,
logit_pde=logit_pde,
)
#fd pae/pde loss not computed correctly, zero for negatives
# Pascal: I think the above is no longer true. PAE/PDE should
# be computed correctly for intra chain
#pae_loss *= 0.0
#pde_loss *= 0.0
else:
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
# change clamp for intra protein to 10, leave rest at 30
dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
if not unclamp:
is_prot = is_protein(seq) # (1,L)
same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1,1,nframes,1, natoms)
# zero out rows and columns with small molecules
same_chain_clamp_mask[:, ~is_prot[0]] = 0
same_chain_clamp_mask[:,:, :, ~is_prot[0]] = 0
dclamp_2d *= ~same_chain_clamp_mask.bool()
dclamp_2d += same_chain_clamp_mask*dclamp_prot
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
dclamp=None,
dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0],:3],
logit_pae=logit_pae,
logit_pde=logit_pde,
)
# free up big intermediate data tensors
del dclamp_2d
if not unclamp:
del same_chain_clamp_mask
num_layers = pred.shape[0]
gamma = 1.0 # equal weighting of fape across all layers
w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
w_bb_fape = torch.flip(w_bb_fape, (0,))
w_bb_fape = w_bb_fape / w_bb_fape.sum()
bb_l_fape = (w_bb_fape*tot_str).sum()
tot_loss += 0.5*w_str*bb_l_fape
for i in range(len(tot_str)):
loss_dict[f'bb_fape_layer{i}'] = tot_str[i].detach()
loss_dict['bb_fape_full'] = bb_l_fape.detach()
tot_loss += w_pae*pae_loss + w_pde*pde_loss
loss_dict['pae_loss'] = pae_loss.detach()
loss_dict['pde_loss'] = pde_loss.detach()
## small-molecule ligands
sm_res_mask = is_atom(label_aa_s[0,0])*res_mask[0] # (L,)
#if not negative and bool(torch.any(~sm_res_mask)) and torch.any(frame_mask_BB[0,~sm_res_mask]):
## protein fape (layer-averaged fape on protein coordinates with protein frames)
#l_fape_prot_intra, _, _ = compute_general_FAPE(
#pred[:, ~sm_res_mask[None],:,:3],
#true[:,~sm_res_mask,:3,:3],
#atom_mask = mask_crds[:,~sm_res_mask, :3],
#frames = frames_BB[:,~sm_res_mask],
#frame_mask = frame_mask_BB[:,~sm_res_mask],
#frame_atom_mask_2d=frame_atom_mask_2d_intra[:, ~sm_res_mask][:, :, :, ~sm_res_mask],
#)
#prot_fape = l_fape_prot_intra.mean()
#l_fape_prot_inter, _, _ = compute_general_FAPE(
#pred[:, ~sm_res_mask[None],:,:3],
#true[:,~sm_res_mask,:3,:3],
#atom_mask = mask_crds[:,~sm_res_mask, :3],
#frames = frames_BB[:,~sm_res_mask],
#frame_mask = frame_mask_BB[:,~sm_res_mask],
#frame_atom_mask_2d=frame_atom_mask_2d_inter[:, ~sm_res_mask][:, :, :, ~sm_res_mask],
#)
#inter_prot_fape = l_fape_prot_inter.mean()
#else:
#prot_fape = torch.tensor(0).to(gpu)
#inter_prot_fape = torch.tensor(0).to(gpu)
#loss_dict['bb_fape_prot_intra'] = prot_fape.detach()
#loss_dict['bb_fape_prot_inter'] = inter_prot_fape.detach()
##if bool(torch.any(sm_res_mask)) and torch.any(frame_mask_BB[0,sm_res_mask]):
## ligand fape (layer-averaged fape on atom coordinates with atom frames)
#l_fape_sm_intra, _, _ = compute_general_FAPE(
#pred[:, sm_res_mask[None],:,:3],
#true[:,sm_res_mask,:3,:3],
#atom_mask = mask_crds[:,sm_res_mask, :3],
#frames = frames_BB[:,sm_res_mask],
#frame_mask = frame_mask_BB[:,sm_res_mask],
#frame_atom_mask_2d=frame_atom_mask_2d_intra[:, sm_res_mask][:, :, :, sm_res_mask],
#dclamp=dclamp_sm,
#Z=Z_sm
#)
#lig_fape = (w_bb_fape*l_fape_sm_intra).sum()
#tot_loss += 0.5*w_lig_fape*lig_fape
#l_fape_sm_inter, _, _ = compute_general_FAPE(
#pred[:, sm_res_mask[None],:,:3],
#true[:,sm_res_mask,:3,:3],
#atom_mask = mask_crds[:,sm_res_mask, :3],
#frames = frames_BB[:,sm_res_mask],
#frame_mask = frame_mask_BB[:,sm_res_mask],
#frame_atom_mask_2d=frame_atom_mask_2d_inter[:, sm_res_mask][:, :, :, sm_res_mask],
#dclamp=dclamp_sm,
#Z=Z_sm
#)
#inter_lig_fape = l_fape_sm_inter.mean()
#else:
#lig_fape = torch.tensor(0).to(gpu)
#inter_lig_fape = torch.tensor(0).to(gpu)
#loss_dict['bb_fape_lig_intra'] = lig_fape.detach()
#loss_dict['bb_fape_lig_inter'] = inter_lig_fape.detach()
#if not bool(torch.all(sm_res_mask)) and bool(torch.any(sm_res_mask)):
## calculate interchain fape
## fape of protein coordinates wrt ligand frames
#mask_crds_protein = mask_crds.clone()
#mask_crds_protein[:, sm_res_mask] = False
#frame_mask_BB_sm = frame_mask_BB.clone()
#frame_mask_BB_sm[:,~sm_res_mask] = False
#if torch.any(mask_crds_protein[:,res_mask[0], :3]) and torch.any(frame_mask_BB_sm[:,res_mask[0]]):
#l_fape_protein_sm, _, _ = compute_general_FAPE(
#pred[:, res_mask,:,:3],
#true[:, res_mask[0],:3,:3],
#atom_mask = mask_crds_protein[:,res_mask[0], :3],
#frames = frames_BB[:,res_mask[0]],
#frame_mask = frame_mask_BB_sm[:,res_mask[0]],
#frame_atom_mask = mask_crds[:,res_mask[0],:3],
#dclamp=dclamp
#)
#else:
#l_fape_protein_sm = torch.tensor(0).to(gpu)
## fape of ligand coordinates wrt protein frames
#mask_crds_sm = mask_crds.clone()
#mask_crds_sm[:, ~sm_res_mask] = False
#frame_mask_BB_protein = frame_mask_BB.clone()
#frame_mask_BB_protein[:,sm_res_mask] = False
#if torch.any(mask_crds_sm[:,res_mask[0], :3]) and torch.any(frame_mask_BB_protein[:,res_mask[0]]):
#l_fape_sm_protein, _, _ = compute_general_FAPE(
#pred[:, res_mask,:,:3],
#true[:, res_mask[0],:3,:3],
#atom_mask = mask_crds_sm[:,res_mask[0], :3],
#frames = frames_BB[:,res_mask[0]],
#frame_mask = frame_mask_BB_protein[:,res_mask[0]],
#frame_atom_mask = mask_crds[:,res_mask[0],:3],
#dclamp=dclamp
#)
#else:
#l_fape_sm_protein = torch.tensor(0).to(gpu)
##frac_sm = torch.sum(frame_mask_BB_sm[:,res_mask[0]])/ torch.sum(frame_mask_BB[:,res_mask[0]])
##inter_fape = frac_sm*l_fape_protein_sm + (1.0-frac_sm)*l_fape_sm_protein
#inter_fape = l_fape_sm_protein + l_fape_protein_sm
#bb_l_fape_inter = (w_bb_fape*inter_fape).sum()
#tot_loss += 0.5*w_inter_fape*bb_l_fape_inter
#else:
#bb_l_fape_inter = torch.tensor(0).to(gpu)
#loss_dict['bb_fape_inter'] = bb_l_fape_inter.detach()
## AllAtom loss
# get ground-truth torsion angles
true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
true, seq, mask_in=mask_crds)
tors_mask *= mask_BB[...,None]
# get alternative coordinates for ground-truth
true_alt = torch.zeros_like(true)
true_alt.scatter_(2, trainer.l2a[seq,:,None].repeat(1,1,1,3), true)
natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[...,:3,:], true_tors)
natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[...,:3,:], true_tors_alt)
predTs = pred[-1,...]
predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])
# - resolve symmetry
xs_mask = trainer.aamask[seq] # (B, L, 27)
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
xs_mask *= mask_crds # mask missing atoms & residues as well
natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
# torsion angle loss
l_tors = torsionAngleLoss(
pred_tors,
true_tors,
true_tors_alt,
tors_mask,
tors_planar,
eps = 1e-10)
tot_loss += w_str*l_tors
loss_dict['torsion'] = l_tors.detach()
### FINETUNING LAYERS
# lddts (CA)
ca_lddt = calc_lddt(pred[:,:,:,1].detach(), true[:,:,1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface)
loss_dict['ca_lddt'] = ca_lddt[-1].detach()
# lddts (allatom) + lddt loss
lddt_loss, allatom_lddt = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain,
negative=negative, interface=interface, N_stripe=10)
tot_loss += w_lddt*lddt_loss
loss_dict['lddt_loss'] = lddt_loss.detach()
loss_dict['allatom_lddt'] = allatom_lddt[0].detach()
#print (allatom_lddt[0].detach())
# FAPE losses
# allatom fape and torsion angle loss
# frames, frame_mask = get_frames(
# pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
if 'tf' in task[0] or res_mask.sum() == 0:
l_fape = torch.zeros((pred.shape[0])).to(gpu)
elif negative.item(): # inter-chain fapes should be ignored for negative cases
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
)
else:
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]]
)
tot_loss += w_str*l_fape[0]
loss_dict['allatom_fape'] = l_fape[0].detach()
# rmsd loss (for logging only)
if torch.any(mask_BB[0]):
rmsd = calc_crd_rmsd(
pred_allatom[:,mask_BB[0],:,:3],
nat_symm[None,mask_BB[0],:,:3],
xs_mask[:,mask_BB[0]]
)
loss_dict["rmsd"] = rmsd[0].detach()
else:
loss_dict["rmsd"] = torch.tensor(0, device=gpu)
# create protein and not protein masks; not protein could include nucleic acids
prot_mask_BB = is_protein(label_aa_s[0,0]) #*mask_BB[0] # (L,)
not_prot_mask_BB = ~prot_mask_BB.bool()
xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
xs_mask_prot[:,~prot_mask_BB] = False
xs_mask_lig[:,~not_prot_mask_BB] = False
if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_prot = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_prot[:,mask_BB[0]]
)
else:
rmsd_prot_prot = torch.tensor([0], device=pred.device)
if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_lig_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_lig[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]]
)
else:
rmsd_lig_lig = torch.tensor([0], device=pred.device)
if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]],
alignment_radius=10.0
)
else:
rmsd_prot_lig = torch.tensor([0], device=pred.device)
loss_dict["rmsd_prot_prot"]= rmsd_prot_prot[0].detach()
loss_dict["rmsd_lig_lig"]= rmsd_lig_lig[0].detach()
loss_dict["rmsd_prot_lig"]= rmsd_prot_lig[0].detach()
# cart bonded (bond geometry)
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
if w_bond > 0.0:
tot_loss += w_bond*bond_loss
loss_dict['bond_geom'] = bond_loss.detach()
# if (pred_allatom.shape[0] > 1):
# bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
# if w_bond > 0.0:
# tot_loss += w_bond*bond_loss.mean()
# loss_dict['clash_loss'] = ( bond_loss.detach() )
# else:
# bond_loss = torch.tensor(0).to(gpu)
# loss_dict['bond_loss'] = bond_loss.detach()
# clash [use all atoms not just those in native]
# clash_loss = calc_lj(
# seq[0], pred_allatom,
# self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
# lj_lin=lj_lin
# )
clash_loss, num_violations = calc_l1_clash_loss(pred_allatom, seq[0],\
trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters, \
trainer.lj_correction_parameters, trainer.num_bonds)
if w_clash > 0.0:
tot_loss += w_clash*clash_loss.mean()
loss_dict['clash_loss'] = clash_loss.detach()
if torch.any(mask_BB[0]):
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
pred=pred_allatom[:,mask_BB[0]],
true=nat_symm[None,mask_BB[0]],
bond_feats=bond_feats[:,mask_BB[0]][:,:,mask_BB[0]],
seq=seq[:,mask_BB[0]]
)
else:
atom_bond_loss = torch.tensor(0, device=gpu)
skip_bond_loss = torch.tensor(0, device=gpu)
rigid_loss = torch.tensor(0, device=gpu)
if w_atom_bond >= 0.0:
tot_loss += w_atom_bond*atom_bond_loss
loss_dict['atom_bond_loss'] = ( atom_bond_loss.detach() )
if w_skip_bond >= 0.0:
tot_loss += w_skip_bond*skip_bond_loss
loss_dict['skip_bond_loss'] = ( skip_bond_loss.detach() )
if w_rigid >= 0.0:
tot_loss += w_rigid*rigid_loss
loss_dict['rigid_loss'] = ( rigid_loss.detach() )
chain_prot = same_chain.clone()
protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)
_, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, negative=True, N_stripe=10)
loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()
_, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()
chain_lig = same_chain.clone()
not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
_, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()
_, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()
chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
chain_prot_lig_inter += protein_mask_2d
chain_prot_lig_inter += not_protein_mask_2d
_, allatom_lddt_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
chain_prot_lig_inter, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
loss_dict['total_loss'] = tot_loss.detach()
return tot_loss, loss_dict
### this file will contain specific calls to the loss function
class LossManager:
""" this class computes the loss and holds useful primitives for loss calc """
def __init__(self, config) -> None:
self.loss_list = []
self.loss_weights = []
self.loss_dict = {}
def compute_loss(self, rf_inputs, rf_outputs):
for loss in self.loss_list:
pass
def get_frames(self):
if self.frames is not None and self.frame_mask is not None:
return self.frames, self.frame_mask
else:
pass
loss_factory = {
"c6d": None,
"mlm": None,
"lddt": None,
"pae": None,
"bb_fape": None,
"allatom_fape": None,
}
def c6d_loss(loss_manager, trainer):
pass

View File

@@ -1,79 +0,0 @@
import os
import numpy as np
import subprocess
import torch
import torch.multiprocessing as mp
import argparse
from pprint import pprint
from pathlib import Path
from rf2aa.submitit_utils import add_slurm_args, create_executor
from rf2aa.arguments import get_args
from factory import trainer_factory
def call_fn(args, dataset_param, model_param, loader_param, loss_param):
master_port = str(np.random.randint(10000, 100000))
master_address = (
subprocess.check_output(
['scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1'], shell=True
)
.decode()
.strip()
)
print(f"Setting MASTER_PORT={master_port} and MASTER_ADDR={master_address}")
os.environ["MASTER_PORT"] = master_port
os.environ["MASTER_ADDR"] = master_address
print("============== INPUT ARGUMENTS ==============")
pprint(vars(args))
print("=============================================")
mp.freeze_support()
trainer_object = trainer_factory(args, dataset_param, model_param, loader_param, loss_param)
trainer_object.run_model_training(torch.cuda.device_count())
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"mode",
choices=["train", "eval"],
help="Set to eval to run evaluation only",
)
parser = add_slurm_args(parser)
print("Reading in arguments...")
args, dataset_param, model_param, loader_param, loss_param = get_args(parser)
print("Done reading in arguments.")
assert (
args.local or not args.interactive
), "When submiting via submitit you either have to launch it locally, or not in interactive mode."
if args.local:
call_fn(args, dataset_param, model_param, loader_param, loss_param)
else:
if args.mode == "train":
log_folder = Path(args.slurm_log_path) / args.model_name / "training_log/"
else:
if args.initialize_model_from_checkpoint is not None:
model_restore_path = Path(args.initialize_model_from_checkpoint)
model_name = model_restore_path.stem
else:
model_name = "no_checkpoint"
log_folder = Path(args.slurm_log_path) / args.model_name / f"eval_{model_name}/"
log_folder.parent.mkdir(parents=True, exist_ok=True)
job_name = f"{args.model_name}_training"
executor = create_executor(args, log_folder, job_name)
job = executor.submit(
call_fn, args, dataset_param, model_param, loader_param, loss_param
)
print(
f"Submitted job {job.job_id} with name {job_name} and log folder {log_folder}"
)
if __name__ == "__main__":
main()

View File

@@ -3,9 +3,9 @@ import torch.nn as nn
import assertpy
from assertpy import assert_that
from icecream import ic
from rf2aa.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, Recycling
from rf2aa.Track_module import IterativeSimulator
from rf2aa.AuxiliaryPredictor import (
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
from rf2aa.model.Track_module import IterativeSimulator
from rf2aa.model.layers.AuxiliaryPredictor import (
DistanceNetwork,
MaskedTokenNetwork,
LDDTNetwork,
@@ -26,7 +26,7 @@ def get_shape(t):
return type(t)
class RoseTTAFoldModule(nn.Module):
class LegacyRoseTTAFoldModule(nn.Module):
def __init__(
self,
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
@@ -50,6 +50,7 @@ class RoseTTAFoldModule(nn.Module):
d_hidden_templ=64,
p_drop=0.15,
additional_dt1d=0,
recycling_type="msa_pair",
SE3_param={}, SE3_ref_param={},
atom_type_index=None,
aamask=None,
@@ -60,22 +61,25 @@ class RoseTTAFoldModule(nn.Module):
cb_tor=None,
num_bonds=None,
lj_lin=0.6,
use_extra_l1=True,
use_chiral_l1=True,
use_lj_l1=False,
use_atom_frames=True,
use_same_chain=False,
# New for diffusion
freeze_track_motif=False,
assert_single_sequence_input=False,
fit=False,
tscale=1.0
):
super(RoseTTAFoldModule, self).__init__()
super(LegacyRoseTTAFoldModule, self).__init__()
self.freeze_track_motif = freeze_track_motif
self.assert_single_sequence_input = assert_single_sequence_input
self.recycling_type = recycling_type
#
# Input Embeddings
d_state = SE3_param["l0_out_features"]
self.latent_emb = MSA_emb(
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain
)
self.full_emb = Extra_emb(
d_msa=d_msa_full, d_init=NAATOKENS - 1 + 4, p_drop=p_drop
@@ -97,7 +101,8 @@ class RoseTTAFoldModule(nn.Module):
additional_dt1d=additional_dt1d)
# Update inputs with outputs from previous round
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
#
self.simulator = IterativeSimulator(
n_extra_block=n_extra_block,
@@ -122,23 +127,22 @@ class RoseTTAFoldModule(nn.Module):
cb_ang=cb_ang,
cb_tor=cb_tor,
lj_lin=lj_lin,
use_extra_l1=use_extra_l1,
use_lj_l1=use_lj_l1,
use_chiral_l1=use_chiral_l1,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method,
main_block=main_block
main_block=main_block,
use_same_chain=use_same_chain
)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.lddt_pred = LDDTNetwork(d_state)
if (
use_extra_l1
): # extra l1 features introduced at the same time as pAE/pDE heads
self.pae_pred = PAENetwork(d_pair)
self.pde_pred = PAENetwork(
self.pae_pred = PAENetwork(d_pair)
self.pde_pred = PAENetwork(
d_pair
) # distance error, but use same architecture as aligned error
@@ -147,8 +151,6 @@ class RoseTTAFoldModule(nn.Module):
# this prediction head.
# self.binder_network = BinderNetwork(d_pair, d_state)
self.use_extra_l1 = use_extra_l1
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
self.use_atom_frames = use_atom_frames
@@ -167,7 +169,7 @@ class RoseTTAFoldModule(nn.Module):
dist_matrix,
chirals,
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
msa_prev=None, pair_prev=None, mask_recycle=None, is_motif=None,
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
return_raw=False,
use_checkpoint=False,
return_infer=False, #fd ?
@@ -338,12 +340,15 @@ class RoseTTAFoldModule(nn.Module):
msa_prev = torch.zeros_like(msa_latent[:,0])
if pair_prev is None:
pair_prev = torch.zeros_like(pair)
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
state_prev = torch.zeros_like(state)
msa_recycle, pair_recycle = self.recycle(msa_prev, pair_prev, xyz, mask_recycle)
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
state = state + state_recycle # if state is not recycled these will be zeros
# add template embedding
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
@@ -383,26 +388,15 @@ class RoseTTAFoldModule(nn.Module):
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
)
# if return_infer:
# xyz_last = xyz_allatom[-1].unsqueeze(0)
# # return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
# # return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
# return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
logits_pae = logits_pde = p_bind = None
if not return_infer:
# predict aligned error and distance error
if self.use_extra_l1:
logits_pae = self.pae_pred(pair)
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
# predict aligned error and distance error
logits_pae = self.pae_pred(pair)
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
#fd predict bind/no-bind
p_bind = self.bind_pred(logits_pae,same_chain)
else:
logits_pae = None
logits_pde = None
#fd predict bind/no-bind
p_bind = self.bind_pred(logits_pae,same_chain)
return (
logits, logits_aa, logits_pae, logits_pde, p_bind,
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
)

View File

@@ -8,11 +8,10 @@ from icecream import ic
from contextlib import ExitStack, nullcontext
from rf2aa.util_module import *
from rf2aa.Attention_module import *
from rf2aa.SE3_network import SE3TransformerWrapper
from rf2aa.resnet import ResidualNetwork
from rf2aa.model.layers.Attention_module import *
from rf2aa.model.layers.SE3_network import SE3TransformerWrapper
from rf2aa.util import INIT_CRDS, is_atom, xyz_frame_from_rotation_mask
from rf2aa.loss import (
from rf2aa.loss.loss import (
calc_BB_bond_geom_grads, calc_lj_grads, calc_hb_grads, calc_cart_bonded_grads, calc_ljallatom_grads,
calc_lj, calc_cart_bonded, calc_chiral_grads
)
@@ -28,7 +27,7 @@ from rf2aa.symmetry import get_symm_map
class PositionalEncoding2D(nn.Module):
# Add relative positional encoding to pair features
def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15, use_same_chain=False):
super(PositionalEncoding2D, self).__init__()
self.minpos = minpos
self.maxpos = maxpos
@@ -37,8 +36,13 @@ class PositionalEncoding2D(nn.Module):
self.nbin_atom = maxpos_atom+2 # include 0 and "unknown" token (maxpos_sm + 1)
self.emb_res = nn.Embedding(self.nbin_res, d_pair)
self.emb_atom = nn.Embedding(self.nbin_atom, d_pair)
self.use_same_chain = use_same_chain
if use_same_chain:
self.emb_chain = nn.Embedding(2, d_pair)
def forward(self, seq, idx, bond_feats, dist_matrix):
def forward(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
sm_mask = is_atom(seq[0])
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask,
@@ -54,6 +58,10 @@ class PositionalEncoding2D(nn.Module):
out = emb_res + emb_atom
if self.use_same_chain and same_chain is not None:
emb_c = self.emb_chain(same_chain.long())
out += emb_c*0 # cursed but exists for backwards compatibility
return out
@@ -353,10 +361,10 @@ class PairStr2Pair(nn.Module):
return pair + (pairnew/countnew[...,None]).reshape(N,L,L,-1)
def forward(self, pair, rbf_feat, state, crop=64):
def forward(self, pair, rbf_feat, state, crop=-1):
B,L = pair.shape[:2]
rbf_feat = self.emb_rbf(rbf_feat)
rbf_feat = self.emb_rbf(rbf_feat) # B, L, L, d_pair
state = self.norm_state(state)
left = self.proj_left(state)
@@ -458,7 +466,7 @@ class Str2Str(nn.Module):
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.embed_node = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
self.ff_node = FeedForwardLayer(SE3_param['l0_in_features'], 2, p_drop=p_drop)
self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
@@ -887,7 +895,7 @@ class IterBlock(nn.Module):
d_hidden=32, d_hidden_msa=None,
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
nextra_l0=0, nextra_l1=0,
nextra_l0=0, nextra_l1=0, use_same_chain=False,
symmetrize_repeats=None, repeat_length=None,symmsub_k=None, sym_method=None, main_block=None,
fit=False, tscale=1.0
):
@@ -900,7 +908,7 @@ class IterBlock(nn.Module):
self.tscale = tscale
self.pos = PositionalEncoding2D(d_rbf, minpos=minpos, maxpos=maxpos,
maxpos_atom=maxpos_atom, p_drop=p_drop)
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain)
self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
n_head=n_head_msa,
@@ -932,7 +940,7 @@ class IterBlock(nn.Module):
crop=-1
):
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, dist_matrix)
rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, dist_matrix, same_chain)
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
@@ -966,7 +974,7 @@ class IterativeSimulator(nn.Module):
atom_type_index=None, aamask=None,
ljlk_parameters=None, lj_correction_parameters=None,
cb_len=None, cb_ang=None, cb_tor=None,
num_bonds=None, lj_lin=0.6, use_extra_l1=True,
num_bonds=None, lj_lin=0.6, use_same_chain=False, use_chiral_l1=True, use_lj_l1=False,
symmetrize_repeats=None,
repeat_length=None,
symmsub_k=None,
@@ -990,7 +998,8 @@ class IterativeSimulator(nn.Module):
self.cb_len = cb_len
self.cb_ang = cb_ang
self.cb_tor = cb_tor
self.use_extra_l1 = use_extra_l1 # set to False to not use chiral & LJ grads
self.use_chiral_l1 = use_chiral_l1
self.use_lj_l1 = use_lj_l1
self.fit = fit
self.tscale = tscale
@@ -1004,7 +1013,8 @@ class IterativeSimulator(nn.Module):
p_drop=p_drop,
use_global_attn=True,
SE3_param=SE3_param,
nextra_l1=3 if self.use_extra_l1 else 0,
nextra_l1=3 if self.use_chiral_l1 else 0,
use_same_chain=use_same_chain,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
@@ -1023,7 +1033,8 @@ class IterativeSimulator(nn.Module):
p_drop=p_drop,
use_global_attn=False,
SE3_param=SE3_param,
nextra_l1=3 if self.use_extra_l1 else 0,
nextra_l1=3 if self.use_chiral_l1 else 0,
use_same_chain=use_same_chain,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
@@ -1035,12 +1046,19 @@ class IterativeSimulator(nn.Module):
# Final SE(3) refinement
if n_ref_block > 0:
n_extra_l0 = 0
n_extra_l1 = 0
if self.use_chiral_l1:
n_extra_l1 += 3
if self.use_lj_l1:
n_extra_l0 += 2*NTOTALDOFS
n_extra_l1 += 3
self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
d_state=SE3_param['l0_out_features'],
SE3_param=SE3_ref_param,
p_drop=p_drop,
nextra_l0=2*NTOTALDOFS if self.use_extra_l1 else 0,
nextra_l1=6 if self.use_extra_l1 else 0
nextra_l0=n_extra_l0,
nextra_l1=n_extra_l1,
)
# # Fine-tuning all-atom SE(3) refinement
@@ -1096,7 +1114,7 @@ class IterativeSimulator(nn.Module):
for i_m in range(self.n_extra_block):
extra_l0 = None
extra_l1 = None
if self.use_extra_l1:
if self.use_chiral_l1:
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l1 = dchiraldxyz[0].detach()
@@ -1119,7 +1137,7 @@ class IterativeSimulator(nn.Module):
for i_m in range(self.n_main_block):
extra_l0 = None
extra_l1 = None
if self.use_extra_l1:
if self.use_chiral_l1:
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l1 = dchiraldxyz[0].detach()
msa, pair, xyz, state, alpha, symmsub = self.main_block[i_m](msa, pair,
@@ -1140,17 +1158,16 @@ class IterativeSimulator(nn.Module):
_, xyzallatom = self.xyzconverter.compute_all_atom(seq_unmasked, xyz, alpha) # think about detach here...
# memory savings: only backprop 1st and another random step
backprop = np.random.randint(1,self.n_ref_block)
backprop = torch.arange(self.n_ref_block) # backprop through everything
for i_m in range(self.n_ref_block):
with ExitStack() as stack:
if i_m != 0 and i_m != backprop:
if (backprop != i_m).all():
stack.enter_context(torch.no_grad())
extra_l0 = None
extra_l1 = None
extra_l1 = []
if self.use_extra_l1:
if self.use_lj_l1:
dljdxyz, dljdalpha = calc_lj_grads(
seq_unmasked, xyz.detach(), alpha.detach(),
self.xyzconverter.compute_all_atom,
@@ -1160,14 +1177,19 @@ class IterativeSimulator(nn.Module):
self.lj_correction_parameters,
self.num_bonds,
lj_lin=self.lj_lin)
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
extra_l0 = dljdalpha.reshape(1,-1,2*NTOTALDOFS).detach()
extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
extra_l1.append(dljdxyz[0].detach())
if self.use_chiral_l1:
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
#extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
extra_l1.append(dchiraldxyz[0].detach())
extra_l1 = torch.cat(extra_l1, dim=1)
xyz, state, alpha = self.str_refiner(
msa.float(), pair.float(), xyz.detach().float(), state.float(), idx,
rotation_mask, bond_feats, dist_matrix, atom_frames,
is_motif, extra_l0.float(), extra_l1.float(), top_k=64, use_atom_frames=use_atom_frames #fd 128->64
is_motif, extra_l0, extra_l1.float(), top_k=64, use_atom_frames=use_atom_frames #fd 128->64
)

View File

@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
from rf2aa.model.layers.Embeddings import MSA_emb, MSA_emb_nostate, \
Extra_emb, Bond_emb, Templ_emb, recycling_factory
from rf2aa.chemical import NBTYPES, NAATOKENS
class RF2_embedding(nn.Module):
def __init__(self, global_params, block_params):
super(RF2_embedding, self).__init__()
d_msa, d_msa_full, d_pair, d_state = global_params["d_msa"], global_params["d_msa_full"], global_params["d_pair"], global_params["d_state"]
self.latent_emb = MSA_emb(
d_msa=d_msa,
d_pair=d_pair,
d_state=d_state,
p_drop=block_params.p_drop,
use_same_chain=block_params.use_same_chain
)
self.full_emb = Extra_emb(
d_msa=d_msa_full,
d_init=NAATOKENS - 1 + 4, #HACK: should define this freom the config (4: ins/del,nterm/cterm feats)
p_drop=block_params.p_drop
)
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)
self.templ_emb = Templ_emb(d_pair=d_pair,
d_templ=block_params.d_templ,
d_state=d_state,
n_head=block_params.n_head_templ,
d_hidden=block_params.d_hidden_templ,
p_drop=block_params.templ_p_drop,
symmetrize_repeats=block_params.symmetrize_repeats, # repeat protein stuff
repeat_length=block_params.repeat_length,
symmsub_k=block_params.symmsub_k,
sym_method=block_params.sym_method,
main_block=block_params.main_block,
copy_main_block=block_params.copy_main_block_template,
additional_dt1d=block_params.additional_dt1d)
## Update inputs with outputs from previous forward pass
self.recycle = recycling_factory[block_params.recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
self.recycling_type = block_params.recycling_type
assert self.recycling_type == "msa_pair", "no backward compatibility to recycling state"
def _unpack_inputs(self, rf_inputs):
msa_latent, msa_full, seq, idx, bond_feats, dist_matrix = \
rf_inputs["msa_latent"], rf_inputs["msa_full"], rf_inputs["seq"], rf_inputs["idx"], rf_inputs["bond_feats"], \
rf_inputs["dist_matrix"]
## recycling inputs
msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle = rf_inputs["msa_prev"], rf_inputs["pair_prev"], None, \
rf_inputs["xyz"], rf_inputs["sctors"], rf_inputs["mask_recycle"]
return msa_latent, msa_full, seq, idx, bond_feats, dist_matrix, msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle
def _add_templ_features(self, rf_inputs, pair, state):
t1d, t2d, alpha_t, xyz_t, mask_t = rf_inputs["t1d"], rf_inputs["t2d"], \
rf_inputs["alpha_t"], rf_inputs["xyz_t"], \
rf_inputs["mask_t"]
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state)
return pair, state
def forward(self, rf_inputs):
msa_latent, msa_full, seq, idx, bond_feats, dist_matrix, msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle = \
self._unpack_inputs(rf_inputs)
B, N, L = msa_latent.shape[:3]
dtype = msa_latent.dtype
msa_latent, pair, state = self.latent_emb(
msa_latent, seq, idx, bond_feats, dist_matrix
)
msa_full = self.full_emb(msa_full, seq, idx)
pair = pair + self.bond_emb(bond_feats)
msa_latent, pair = msa_latent.to(dtype), pair.to(dtype)
msa_full = msa_full.to(dtype)
if state is not None:
state = state.to(dtype)
if msa_prev is None:
msa_prev = torch.zeros_like(msa_latent[:,0])
if pair_prev is None:
pair_prev = torch.zeros_like(pair)
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
state_prev = torch.zeros_like(msa_latent[:, 0])
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
# No support for recycling state
#state = state + state_recycle # if state is not recycled these will be zeros
# add template embedding
pair, state = self._add_templ_features(rf_inputs, pair, state)
return {
"msa": msa_latent,
"msa_full": msa_full,
"pair": pair,
"state": state
}
class RF2_embedding_nostate(RF2_embedding):
def __init__(self, global_params, block_params):
super(RF2_embedding_nostate, self).__init__(global_params, block_params)
d_msa, d_msa_full, d_pair, d_state = global_params["d_msa"], global_params["d_msa_full"], global_params["d_pair"], global_params["d_state"]
self.latent_emb = MSA_emb_nostate(
d_msa=d_msa,
d_pair=d_pair,
d_state=d_state,
p_drop=block_params.p_drop,
use_same_chain=block_params.use_same_chain
)
embedding_factory = {
"rf2aa": RF2_embedding,
"rf2aa_nostate": RF2_embedding_nostate
}

View File

@@ -85,7 +85,6 @@ class SequenceWeight(nn.Module):
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
self.dropout = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
@@ -453,7 +452,7 @@ class BiasedAxialAttention(nn.Module):
pair = self.norm_pair(pair)
bias = self.norm_bias(bias)
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)

View File

@@ -4,9 +4,9 @@ import torch.nn as nn
from rf2aa.chemical import NAATOKENS
class DistanceNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
def __init__(self, n_feat, p_drop=0.0):
super(DistanceNetwork, self).__init__()
#
#HACK: dimensions are hard coded here
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
self.proj_asymm = nn.Linear(n_feat, 37+19)
@@ -36,7 +36,7 @@ class DistanceNetwork(nn.Module):
return logits_dist, logits_omega, logits_theta, logits_phi
class MaskedTokenNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
def __init__(self, n_feat, p_drop=0.0):
super(MaskedTokenNetwork, self).__init__()
#fd note this predicts probability for the mask token (which is never in ground truth)
@@ -76,6 +76,7 @@ class PAENetwork(nn.Module):
super(PAENetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_pae)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
@@ -101,3 +102,10 @@ class BinderNetwork(nn.Module):
prob = torch.sigmoid( self.classify( logits_inter ) )
return prob
aux_predictor_factory = {
"c6d": DistanceNetwork,
"mlm": MaskedTokenNetwork,
"plddt": LDDTNetwork,
"pae": PAENetwork,
"binder": BinderNetwork
}

View File

@@ -1,3 +1,4 @@
from rf2aa.util import NAATOKENS
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -5,8 +6,8 @@ from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rf2aa.util import *
from rf2aa.util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
from rf2aa.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
from rf2aa.Track_module import PairStr2Pair, PositionalEncoding2D
from rf2aa.model.layers.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
from rf2aa.model.Track_module import PairStr2Pair, PositionalEncoding2D
from rf2aa.chemical import NAATOKENS,NTOTALDOFS, NBTYPES
# Module contains classes and functions to generate initial embeddings
@@ -14,7 +15,7 @@ from rf2aa.chemical import NAATOKENS,NTOTALDOFS, NBTYPES
class MSA_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
super(MSA_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
@@ -22,7 +23,7 @@ class MSA_emb(nn.Module):
self.emb_right = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_state = nn.Embedding(NAATOKENS, d_state)
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
maxpos_atom=maxpos_atom, p_drop=p_drop)
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain)
self.reset_parameter()
@@ -35,6 +36,26 @@ class MSA_emb(nn.Module):
nn.init.zeros_(self.emb.bias)
def _msa_emb(self, msa, seq):
N = msa.shape[1]
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
return msa
def _pair_emb(self, seq, idx, bond_feats, dist_matrix):
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix) # add relative position
return pair
def _state_emb(self, seq):
return self.emb_state(seq)
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
@@ -45,25 +66,24 @@ class MSA_emb(nn.Module):
# - msa: Initial MSA embedding (B, N, L, d_msa)
# - pair: Initial Pair embedding (B, L, L, d_pair)
N = msa.shape[1] # number of sequenes in MSA
# msa embedding
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
#msa = self.drop(msa)
msa = self._msa_emb(msa, seq)
# pair embedding
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix) # add relative position
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
# state embedding
state = self.emb_state(seq)
state = self._state_emb(seq)
return msa, pair, state
class MSA_emb_nostate(MSA_emb):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2 * NAATOKENS + 2 + 2, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
super().__init__(d_msa, d_pair, d_state, d_init, minpos, maxpos, maxpos_atom, p_drop, use_same_chain)
self.emb_state = None # emb state is just the identity
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
msa = self._msa_emb(msa, seq)
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
return msa, pair, None
class Extra_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_init=NAATOKENS-1+4, p_drop=0.1):
@@ -352,7 +372,7 @@ class Recycling(nn.Module):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
def forward(self, msa, pair, xyz, mask_recycle=None):
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = msa.shape[:2]
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
@@ -367,5 +387,47 @@ class Recycling(nn.Module):
pair = pair + self.proj_dist(dist_CA)
return msa, pair
return msa, pair, state # state is just zeros
class RecyclingAllFeatures(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
super(RecyclingAllFeatures, self).__init__()
self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_state = nn.LayerNorm(d_state)
self.reset_parameter()
def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
self.proj_sctors = init_lecun_normal(self.proj_sctors)
nn.init.zeros_(self.proj_sctors.bias)
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = pair.shape[:2]
state = self.norm_state(state)
left = state.unsqueeze(2).expand(-1,-1,L,-1)
right = state.unsqueeze(1).expand(-1,L,-1,-1)
Ca_or_P = xyz[:,:,1].contiguous()
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
if mask_recycle != None:
dist = mask_recycle[...,None].float()*dist
dist = torch.cat((dist, left, right), dim=-1)
dist = self.proj_dist(dist)
pair = dist + self.norm_pair(pair)
sctors = self.proj_sctors(sctors.reshape(B,-1,2*NTOTALDOFS))
msa = sctors + self.norm_msa(msa)
return msa, pair, state
recycling_factory = {
"msa_pair": Recycling,
"all": RecyclingAllFeatures
}

View File

@@ -0,0 +1,296 @@
import torch
import torch.nn as nn
from icecream import ic
import inspect
import sys, os
#script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
#sys.path.insert(0,script_dir+'SE3Transformer')
from rf2aa.util import xyz_frame_from_rotation_mask
from rf2aa.util_module import init_lecun_normal_param, \
make_full_graph, rbf, init_lecun_normal
from rf2aa.loss.loss import calc_chiral_grads
from rf2aa.model.layers.Attention_module import FeedForwardLayer
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.model.layers.resnet import SCPred
se3_transformer_path = inspect.getfile(SE3Transformer)
se3_fiber_path = inspect.getfile(Fiber)
assert 'rf2aa' in se3_transformer_path
class SE3TransformerWrapper(nn.Module):
"""SE(3) equivariant GCN with attention"""
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32):
super().__init__()
# Build the network
self.l1_in = l1_in_features
self.l1_out = l1_out_features
#
fiber_edge = Fiber({0: num_edge_features})
if l1_out_features > 0:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
self.se3 = SE3Transformer(num_layers=num_layers,
fiber_in=fiber_in,
fiber_hidden=fiber_hidden,
fiber_out = fiber_out,
num_heads=n_heads,
channels_div=div,
fiber_edge=fiber_edge,
populate_edge="arcsin",
final_layer="lin",
use_layer_norm=True)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
if self.l1_out > 0:
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
node_features = {'0': type_0_features, '1': type_1_features}
else:
node_features = {'0': type_0_features}
edge_features = {'0': edge_features}
return self.se3(G, node_features, edge_features)
class FullyConnectedSE3_noR(nn.Module):
def __init__(self,
d_msa,
d_pair,
d_rbf,
num_layers,
num_channels,
num_degrees,
n_heads,
div,
l0_in_features,
l0_out_features,
l1_in_features,
l1_out_features,
num_edge_features
):
super(FullyConnectedSE3_noR, self).__init__()
# initial node & pair feature process
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.embed_node = nn.Linear(d_msa, l0_in_features)
self.ff_node = FeedForwardLayer(l0_in_features, 2) #HACK: hardcoded value
self.norm_node = nn.LayerNorm(l0_in_features)
self.embed_edge = nn.Linear(d_pair+d_rbf, num_edge_features)
self.ff_edge = FeedForwardLayer(num_edge_features, 2)
self.norm_edge = nn.LayerNorm(num_edge_features)
self.se3 = SE3TransformerWrapper(
num_layers=num_layers,
num_channels=num_channels,
num_degrees=num_degrees,
n_heads=n_heads,
div=div,
l0_in_features=l0_in_features,
l0_out_features=l0_out_features,
l1_in_features=l1_in_features,
l1_out_features=l1_out_features,
num_edge_features=num_edge_features
)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.embed_node = init_lecun_normal(self.embed_node)
self.embed_edge = init_lecun_normal(self.embed_edge)
# initialize bias to zeros
nn.init.zeros_(self.embed_node.bias)
nn.init.zeros_(self.embed_edge.bias)
def embed_node_feats(self, msa, state):
seq = self.norm_msa(msa[:, 0])
node = self.embed_node(seq)
node = node + self.ff_node(node)
node = self.norm_node(node)
return node
def embed_edge_feats(self, pair, xyz):
pair = self.norm_pair(pair)
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas))
edge = torch.cat((pair, rbf_feat), dim=-1)
edge = self.embed_edge(edge)
edge = edge + self.ff_edge(edge)
edge = self.norm_edge(edge)
return edge
def construct_graph(self, xyz, edge):
B, L = xyz.shape[:2]
idx = torch.arange(L, device=edge.device).reshape(B, L) # NOTE: only works in B==1
G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx)
return G, edge_feats
def construct_l1_feats(self, xyz, is_atom, atom_frames, chirals):
l1_feats = get_chiral_vectors(xyz[...,:3,:], chirals)[..., 1:2, :] # only pass features from Calpha
return l1_feats
def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, is_atom):
B, L = xyz.shape[:2]
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
state = shift["0"].reshape(B, L, -1)
offset = shift["1"].reshape(B, L, 3)
T = offset / 10.0
xyz_update = xyz.clone()
xyz_update[...,1:2, :] = xyz[..., 1:2, :] +T[..., None, :]
return state, xyz_update
def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
#TODO: allow these functions to accept kwargs so we can pass
# different inputs when iterating
B, N, L = msa.shape[:3]
node = self.embed_node_feats(msa, state)
edge = self.embed_edge_feats(pair, xyz)
G, edge_feats = self.construct_graph(xyz, edge)
#TODO: get extra l1 feats automatically and populate the extra l1 dimension
l1_feats = self.construct_l1_feats(xyz, is_atom, atom_frames, chirals)
state, xyz_update = self.compute_structure_update(G, node, l1_feats, edge_feats, xyz, is_atom)
return state, xyz_update
class FullyConnectedSE3(FullyConnectedSE3_noR):
def __init__(self, d_msa, d_pair, d_state, d_rbf, num_layers, num_channels, num_degrees, n_heads, div, l0_in_features, l0_out_features, l1_in_features, l1_out_features, num_edge_features, sc_pred_d_hidden, sc_pred_p_drop):
super().__init__(d_msa, d_pair, d_rbf, num_layers, num_channels, num_degrees, n_heads, div, l0_in_features, l0_out_features, l1_in_features, l1_out_features, num_edge_features)
self.embed_node = nn.Linear(d_msa+d_state, l0_in_features)
self.norm_state = nn.LayerNorm(d_state)
self.sc_predictor = SCPred(
d_msa=d_msa,
d_state=l0_out_features,
d_hidden=sc_pred_d_hidden,
p_drop=sc_pred_p_drop
)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.embed_node = init_lecun_normal(self.embed_node)
self.embed_edge = init_lecun_normal(self.embed_edge)
# initialize bias to zeros
nn.init.zeros_(self.embed_node.bias)
nn.init.zeros_(self.embed_edge.bias)
nn.init.ones_(self.norm_msa.weight)
nn.init.ones_(self.norm_pair.weight)
def embed_node_feats(self, msa, state):
seq = self.norm_msa(msa[:, 0])
state = self.norm_state(state)
node = self.embed_node(torch.cat((seq, state), dim=-1))
node = node + self.ff_node(node)
node = self.norm_node(node)
return node
def construct_l1_feats(self, xyz, is_atom, atom_frames, chirals):
l1_feats = torch.cat(
[
get_backbone_offset_vectors(xyz, is_atom, atom_frames),
get_chiral_vectors(xyz, chirals)
], dim=1
)
return l1_feats
def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, is_atom):
B, L = node.shape[:2]
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
state = shift["0"].reshape(B, L, -1)
offset = shift["1"].reshape(B, L, 2, 3)
T = offset[:,:,0,:] / 10
R = offset[:,:,1,:] / 100.0
Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
v = xyz - xyz[:,:,1:2,:]
Rout = torch.zeros((B,L,3,3), device=xyz.device)
Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
Rout = torch.where(is_atom.reshape(B, L, 1,1), I, Rout)
xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]
return state, xyz
def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
state, xyz = super().forward(msa, pair, state, xyz, is_atom, atom_frames, chirals)
alpha = self.sc_predictor(msa[:, 0], state)
return {
"state": state,
"xyz": xyz,
"alpha": alpha
}
def get_backbone_offset_vectors(xyz, is_atom, atom_frames):
xyz_frame = xyz_frame_from_rotation_mask(xyz, is_atom, atom_frames)
l1_feats = xyz_frame - xyz_frame[:,:,1,:].unsqueeze(2)
return l1_feats[0][..., :3, :]
def get_chiral_vectors(xyz, chirals):
dchiraldxyz, = calc_chiral_grads(xyz,chirals)
extra_l1 = dchiraldxyz[0]
extra_l1_slice = extra_l1.clone()
return extra_l1_slice.detach()

View File

@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from rf2aa.util_module import init_lecun_normal
class OuterProductMean(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
super(OuterProductMean, self).__init__()
self.norm = nn.LayerNorm(d_msa)
self.proj_left = nn.Linear(d_msa, d_hidden)
self.proj_right = nn.Linear(d_msa, d_hidden)
self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.proj_left = init_lecun_normal(self.proj_left)
self.proj_right = init_lecun_normal(self.proj_right)
nn.init.zeros_(self.proj_left.bias)
nn.init.zeros_(self.proj_right.bias)
# zero initialize output
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
msa = self.norm(msa)
left = self.proj_left(msa)
right = self.proj_right(msa)
right = right / float(N)
out = torch.einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
out = self.proj_out(out)
return out

View File

@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from rf2aa.chemical import NTOTALDOFS
from rf2aa.util_module import init_lecun_normal
# pre-activation bottleneck resblock
class ResBlock2D_bottleneck(nn.Module):
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
super(ResBlock2D_bottleneck, self).__init__()
padding = self._get_same_padding(kernel, dilation)
n_b = n_c // 2 # bottleneck channel
layer_s = list()
# pre-activation
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# project down to n_b
layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# convolution
layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# dropout
layer_s.append(nn.Dropout(p_drop))
# project up
layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
# make final layer initialize with zeros
#nn.init.zeros_(layer_s[-1].weight)
self.layer = nn.Sequential(*layer_s)
self.reset_parameter()
def reset_parameter(self):
# zero-initialize final layer right before residual connection
nn.init.zeros_(self.layer[-1].weight)
def _get_same_padding(self, kernel, dilation):
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
def forward(self, x):
out = self.layer(x)
return x + out
class ResidualNetwork(nn.Module):
def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
dilation=[1,2,4,8], p_drop=0.15):
super(ResidualNetwork, self).__init__()
layer_s = list()
# project to n_feat_block
if n_feat_in != n_feat_block:
layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
# add resblocks
for i_block in range(n_block):
d = dilation[i_block%len(dilation)]
res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
layer_s.append(res_block)
if n_feat_out != n_feat_block:
# project to n_feat_out
layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
self.layer = nn.Sequential(*layer_s)
def forward(self, x):
return self.layer(x)
#TODO: get rid of this duplicated code, needed now to avoid circular import
class SCPred(nn.Module):
def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
super(SCPred, self).__init__()
self.norm_s0 = nn.LayerNorm(d_msa)
self.norm_si = nn.LayerNorm(d_state)
self.linear_s0 = nn.Linear(d_msa, d_hidden)
self.linear_si = nn.Linear(d_state, d_hidden)
# ResNet layers
self.linear_1 = nn.Linear(d_hidden, d_hidden)
self.linear_2 = nn.Linear(d_hidden, d_hidden)
self.linear_3 = nn.Linear(d_hidden, d_hidden)
self.linear_4 = nn.Linear(d_hidden, d_hidden)
# Final outputs
self.linear_out = nn.Linear(d_hidden, 2*NTOTALDOFS)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.linear_s0 = init_lecun_normal(self.linear_s0)
self.linear_si = init_lecun_normal(self.linear_si)
self.linear_out = init_lecun_normal(self.linear_out)
nn.init.zeros_(self.linear_s0.bias)
nn.init.zeros_(self.linear_si.bias)
nn.init.zeros_(self.linear_out.bias)
# right before relu activation: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_1.bias)
nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_3.bias)
# right before residual connection: zero initialize
nn.init.zeros_(self.linear_2.weight)
nn.init.zeros_(self.linear_2.bias)
nn.init.zeros_(self.linear_4.weight)
nn.init.zeros_(self.linear_4.bias)
def forward(self, seq, state):
'''
Predict side-chain torsion angles along with backbone torsions
Inputs:
- seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
- state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
Outputs:
- si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
'''
B, L = seq.shape[:2]
seq = self.norm_s0(seq)
state = self.norm_si(state)
si = self.linear_s0(seq) + self.linear_si(state)
si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))
si = self.linear_out(F.relu_(si))
return si.view(B, L, NTOTALDOFS, 2)

View File

@@ -0,0 +1,62 @@
from rf2aa.util_module import rbf, init_lecun_normal
import torch
import torch.nn as nn
from opt_einsum import contract as einsum
class StructureBias(torch.nn.Module):
def __init__(self, d_rbf, d_pair) -> None:
super(StructureBias, self).__init__()
self.proj_rbf = nn.Linear(d_rbf, d_pair)
self.reset_parameter()
def reset_parameter(self):
self.proj_rbf = init_lecun_normal(self.proj_rbf)
nn.init.zeros_(self.proj_rbf.bias)
def forward(self, xyz):
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas))
bias = self.proj_rbf(rbf_feat)
return bias
class GatedStructureBias(torch.nn.Module):
def __init__(self, d_rbf, d_state, d_pair, d_hidden_gate) -> None:
super(GatedStructureBias, self).__init__()
self.norm_state = nn.LayerNorm(d_state)
self.proj_rbf = nn.Linear(d_rbf, d_pair)
self.proj_left = nn.Linear(d_state, d_hidden_gate)
self.proj_right = nn.Linear(d_state, d_hidden_gate)
self.to_gate = nn.Linear(d_hidden_gate*d_hidden_gate, d_pair)
self.reset_parameter()
def reset_parameter(self):
pass
def forward(self, xyz, state):
B, L = xyz.shape[:2]
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas))
rbf_feat = self.proj_rbf(rbf_feat)
state = self.norm_state(state)
left = self.proj_left(state)
right = self.proj_right(state)
gate = einsum('bli,bmj->blmij', left, right).reshape(B,L,L,-1)
gate = torch.sigmoid(self.to_gate(gate))
rbf_feat = gate*rbf_feat
return rbf_feat
structure_bias_factory = {
"ungated": StructureBias,
"gated": GatedStructureBias
}

92
rf2aa/model/network.py Normal file
View File

@@ -0,0 +1,92 @@
import hydra
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from rf2aa.debug import debug_nans
from rf2aa.model.embedding_blocks import embedding_factory
from rf2aa.model.refinement_blocks import refinement_factory
from rf2aa.model.simulator_blocks import block_factory
from rf2aa.model.layers.AuxiliaryPredictor import aux_predictor_factory
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util import is_atom
class RosettaFold(nn.Module):
""" creates an instance of RosettaFold which includes an embedder, trunk, refinement layers and aux predictor"""
def __init__(self, config):
super(RosettaFold, self).__init__()
model_params = config.model
assert len(model_params.embedding.keys()) == 1, "only can have one embedder"
embedding_type = next(iter(model_params.embedding.keys()))
self.embedding = embedding_factory[embedding_type](model_params["global_params"], model_params.embedding[embedding_type]["params"])
## instantiate blocks of network
blocks = []
for block in model_params.blocks.keys():
if block not in block_factory:
raise ValueError(f"User specified {block} type, but this block is not registered in rf2aa.Trunk_blocks.")
blocks_to_add = [block_factory[block](
global_config=model_params["global_params"],
block_params=model_params.blocks[block]["params"])
for _ in range(model_params.blocks[block]["num_blocks"])]
blocks.extend(blocks_to_add)
self.simulator = nn.ModuleList(blocks)
assert len(model_params.refinement.keys()) == 1, "only can have one refinment block"
refinement_type = next(iter(model_params.refinement.keys()))
self.refinement = refinement_factory[refinement_type](
model_params["global_params"],
model_params.refinement[refinement_type]["params"]
)
aux_tasks = {}
for aux_task in model_params.auxiliary_predictors.keys():
aux_tasks.update({
aux_task:
aux_predictor_factory[aux_task](
model_params.auxiliary_predictors[aux_task]["n_feat"])
}
) #HACK: eventually this will just use the correct n_feat from the global config
self.auxiliary_predictors = nn.ModuleDict(aux_tasks)
self.auxiliary_predictor_input_feats = {
aux_task:model_params.auxiliary_predictors[aux_task]["input_feature"] \
for aux_task in model_params.auxiliary_predictors.keys()
}
def forward(self, rf_inputs, use_checkpoint):
latent_feats = self.embedding(rf_inputs)
#load useful primitives into latent_features
latent_feats.update(
{
"is_atom": is_atom(rf_inputs["seq_unmasked"]),
"atom_frames": rf_inputs["atom_frames"],
"chirals": rf_inputs["chirals"],
"xyz": rf_inputs["xyz"]
}
)
for block in self.simulator:
latent_feats = block(latent_feats,use_checkpoint)
rf_outputs = self.refinement(latent_feats)
for aux_task, aux_predictor in self.auxiliary_predictors.items():
input_feature = self.auxiliary_predictor_input_feats[aux_task]
auxiliary_predictions = aux_predictor(latent_feats[input_feature])
rf_outputs.update({aux_task: auxiliary_predictions})
return rf_outputs, latent_feats
@hydra.main(version_base=None, config_path='../config/train', config_name='base')
def main(config):
model = RosettaFold(config)
import pdb; pdb.set_trace()
if __name__ =="__main__":
main()

View File

@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors
from rf2aa.model.Track_module import SCPred
from rf2aa.util import NTOTAL, NTOTALDOFS
from rf2aa.util_module import rbf, make_topk_graph, init_lecun_normal
class LocalRefinementSE3(FullyConnectedSE3):
def __init__(self, global_config, block_params):
d_msa, d_pair, d_state = global_config.d_msa, global_config.d_pair, global_config.d_state
d_rbf, num_layers, num_channels, num_degrees, n_heads, div, \
l0_in_features, l0_out_features, l1_in_features, l1_out_features, \
num_edge_features, top_k, sc_pred_d_hidden, sc_pred_p_drop = \
block_params.d_rbf, block_params.n_se3_layers, block_params.n_se3_channels, \
block_params.n_se3_degrees, block_params.n_se3_head, block_params.n_div, \
block_params.l0_in_features, block_params.l0_out_features, \
block_params.l1_in_features, block_params.l1_out_features, \
block_params.n_se3_edge_features, block_params.top_k, \
block_params.sc_pred_d_hidden, block_params.sc_pred_p_drop
super(LocalRefinementSE3, self).__init__(d_msa,
d_pair,
d_state,
d_rbf,
num_layers,
num_channels,
num_degrees,
n_heads,
div,
l0_in_features,
l0_out_features,
l1_in_features,
l1_out_features,
num_edge_features,
sc_pred_d_hidden,
sc_pred_p_drop
)
self.top_k = top_k
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.embed_node = init_lecun_normal(self.embed_node)
self.embed_edge = init_lecun_normal(self.embed_edge)
# initialize bias to zeros
nn.init.zeros_(self.embed_node.bias)
nn.init.zeros_(self.embed_edge.bias)
nn.init.ones_(self.norm_msa.weight)
nn.init.ones_(self.norm_pair.weight)
def construct_graph(self, xyz, edge):
L = xyz.shape[1]
idx = torch.arange(L, device=edge.device)[None]
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
return G, edge_feats
#def reset_parameter(self):
## initialize weights to normal distribution
#self.embed_node = init_lecun_normal(self.embed_node)
#self.embed_edge = init_lecun_normal(self.embed_edge)
## initialize bias to zeros
#nn.init.zeros_(self.embed_node.bias)
#nn.init.zeros_(self.embed_edge.bias)
#nn.init.ones_(self.norm_msa.weight)
#nn.init.ones_(self.norm_pair.weight)
#def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
#B, N, L = msa.shape[:3]
#seq = self.norm_msa(msa[:, 0])
#pair = self.norm_pair(pair)
#node = self.embed_node(torch.cat((seq, state), dim=-1))
#node = node + self.ff_node(node)
#node = self.norm_node(node)
##NOTE: Ablating providing the positional encoding at every step
## we introduced this and I do not think it is in RF2
##neighbor = get_seqsep_protein_sm(idx, bond_feats, dist_matrix, rotation_mask)
#cas = xyz[:,:,1].contiguous()
#rbf_feat = rbf(torch.cdist(cas, cas))
#edge = torch.cat((pair, rbf_feat), dim=-1)
#edge = self.embed_edge(edge)
#edge = edge + self.ff_edge(edge)
#edge = self.norm_edge(edge)
#idx = torch.arange(L, device=edge.device)[None]
#G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
##TODO: get extra l1 feats automatically and populate the extra l1 dimension
#l1_feats = torch.cat(
#[
#get_backbone_offset_vectors(xyz, is_atom, atom_frames),
#get_chiral_vectors(xyz, chirals)
#], dim=1
#)
#shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
#state = shift["0"].reshape(B, L, -1)
#offset = shift["1"].reshape(B, L, 2, 3)
#T = offset[:,:,0,:] / 10
#R = offset[:,:,1,:] / 100.0
#Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
#qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
#v = xyz - xyz[:,:,1:2,:]
#Rout = torch.zeros((B,L,3,3), device=xyz.device)
#Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
#Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
#Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
#Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
#Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
#Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
#Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
#Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
#Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
#I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
#Rout = torch.where(is_atom.reshape(B, L, 1,1), I, Rout)
#xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]
#alpha = self.sc_predictor(msa[:,0], state)
#return {
#"state": state,
#"xyz": xyz,
#"alpha": alpha
#}
class RecurrentLocalRefinement(nn.Module):
def __init__(self, global_config, block_params):
super(RecurrentLocalRefinement, self).__init__()
self.num_iterations = block_params.num_iterations
self.se3 = LocalRefinementSE3(global_config, block_params)
def _unpack_inputs(self, latent_feats):
msa, pair, state, xyz, is_atom, atom_frames, chirals = \
latent_feats["msa"], latent_feats["pair"], \
latent_feats["state"], latent_feats["xyz"], latent_feats["is_atom"], \
latent_feats["atom_frames"], latent_feats["chirals"]
return msa, pair, state, xyz, is_atom, atom_frames, chirals
def forward(self, latent_feats):
B, N, L = latent_feats["msa"].shape[:3]
xyzs = torch.full((self.num_iterations, L, 3, 3 ), torch.nan, device=latent_feats["msa"].device)
alphas = torch.full((self.num_iterations, L, NTOTALDOFS, 2), torch.nan, device=latent_feats["msa"].device)
msa, pair, state, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)
for i in range(self.num_iterations):
output = self.se3(msa, pair, state, xyz, is_atom, atom_frames, chirals)
xyzs[i] = output["xyz"]
alphas[i] = output["alpha"]
latent_feats["state"] = output["state"]
return {
"xyzs": xyzs,
"alphas": alphas,
}
refinement_factory ={
"local": RecurrentLocalRefinement
}

View File

@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from functools import partial
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util_module import Dropout
class RF2_block(nn.Module):
"""
nearly faithful implementation of RF2aa blocks in new paradigm
unfaithful portions are:
- ablating adding the positional encodings as biases to the attentions
- the "is_bonded" boolean feature is no longer embedded with the edge features of the SE3 transformer
"""
def __init__(self, global_config, block_params, is_full, **kwargs):
super(RF2_block, self).__init__()
d_msa, d_msa_full, d_pair, d_state = global_config.d_msa, global_config.d_msa_full, global_config.d_pair, \
global_config.d_state
self.is_full = is_full
if self.is_full:
d_msa = d_msa_full
self.norm_pair_bias = nn.LayerNorm(d_pair)
self.norm_state_bias = nn.LayerNorm(d_state)
self.proj_state_bias = nn.Linear(d_state, d_msa)
self.msa_str_bias = structure_bias_factory["ungated"](block_params.d_rbf, d_pair)
self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_pair)
self.msa_row_attn = MSARowAttentionWithBias(
d_msa=d_msa, d_pair=d_pair, n_head=block_params.n_msa_head, d_hidden=block_params.n_msa_channels)
if self.is_full:
self.msa_col_attn = MSAColGlobalAttention(
d_msa=d_msa,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels
)
else:
self.msa_col_attn = MSAColAttention(
d_msa=d_msa,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels
)
self.msa_transition = FeedForwardLayer(d_msa, 4, p_drop=block_params.msa_transition_drop)
# Pair update parameters
self.outer_product = OuterProductMean(d_msa, d_pair, d_hidden=block_params.outer_product_channels, \
p_drop=block_params.p_drop_outer_product)
self.structure_bias = structure_bias_factory["gated"](block_params.d_rbf, d_state, d_pair, block_params.structure_bias_gate_channels)
self.tri_mul_outgoing = TriangleMultiplication(d_pair, d_hidden=block_params.n_pair_channels, outgoing=True)
self.tri_mul_incoming = TriangleMultiplication(d_pair, d_hidden=block_params.n_pair_channels, outgoing=False)
self.pair_row_attn = BiasedAxialAttention(d_pair, d_pair, block_params.n_pair_head, block_params.n_pair_channels, p_drop=block_params.p_drop_pair, is_row=True)
self.pair_col_attn = BiasedAxialAttention(d_pair, d_pair, block_params.n_pair_head, block_params.n_pair_channels, p_drop=block_params.p_drop_pair, is_row=False)
self.pair_transition = FeedForwardLayer(d_pair, 2) # HACK: hardcoded value for transition
self.structure_attn = FullyConnectedSE3(d_msa,
d_pair,
d_state,
block_params.d_rbf,
block_params.n_se3_layers,
block_params.n_se3_channels,
block_params.n_se3_degrees,
block_params.n_se3_head,
block_params.n_div,
block_params.l0_in_features,
block_params.l0_out_features,
block_params.l1_in_features,
block_params.l1_out_features,
block_params.n_se3_edge_features,
block_params.sc_pred_d_hidden,
block_params.sc_pred_p_drop
)
def _unpack_inputs(self, latent_feats):
pair, state, xyz, is_atom, atom_frames, chirals = \
latent_feats["pair"], latent_feats["state"], \
latent_feats["xyz"], latent_feats["is_atom"], \
latent_feats["atom_frames"], latent_feats["chirals"]
if self.is_full:
msa = latent_feats["msa_full"]
else:
msa = latent_feats["msa"]
return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals
def _pack_outputs(self, msa, pair, state, xyz, alpha, latent_feats):
if self.is_full:
latent_feats["msa_full"] = msa
else:
latent_feats["msa"] = msa
latent_feats["pair"] = pair
latent_feats["state"] = state
latent_feats["xyz"] = xyz
#HACK: appending to growing list, this could cause weird memory problems in pytorch
# eventually want to refactor this to make it more elegant
if "xyz_intermediate" not in latent_feats:
latent_feats["xyz_intermediate"] = [xyz]
else:
latent_feats["xyz_intermediate"].append(xyz)
if "alpha_intermediate" not in latent_feats:
latent_feats["alpha_intermediate"] = [alpha]
else:
latent_feats["alpha_intermediate"].append(alpha)
return latent_feats
def _1d_update(self, msa, pair, state, xyz):
pair = self.norm_pair_bias(pair)
pair = pair + self.msa_str_bias(xyz)
state = self.norm_state_bias(state)
state_update = self.proj_state_bias(state)
msa = msa.type_as(state_update)
msa = msa.index_add(1, torch.tensor([0,], device=state_update.device), state_update[None])
msa = msa + self.drop_row(self.msa_row_attn(msa, pair))
msa = msa + self.msa_col_attn(msa)
msa = msa + self.msa_transition(msa)
return msa
def _2d_update(self, msa, pair, state, xyz):
msa_bias = self.outer_product(msa)
pair = pair + msa_bias
str_bias = self.structure_bias(xyz, state)
pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
pair = pair + self.drop_row(self.tri_mul_incoming(pair))
pair = pair + self.drop_row(self.pair_row_attn(pair, str_bias))
pair = pair + self.drop_col(self.pair_col_attn(pair, str_bias))
pair = pair + self.pair_transition(pair)
return pair
def _3d_update(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
block_outputs = self.structure_attn(msa, pair, state, xyz, is_atom, atom_frames, chirals)
return block_outputs["state"], block_outputs["xyz"], block_outputs["alpha"]
def forward(self, latent_feats, use_checkpoint):
msa, pair, state, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self._1d_update), msa, pair, state, xyz)
pair = checkpoint.checkpoint(create_custom_forward(self._2d_update), msa, pair, state, xyz)
# 3D track cannot use re-entrant = False because of chiral features call to autograd
#TODO: allow this to happen since new versions of Pytorch will be using reentrant=False
state, xyz, alpha = checkpoint.checkpoint(create_custom_forward(self._3d_update), \
msa, pair, state, xyz, is_atom, atom_frames, chirals)
else:
msa= self._1d_update(msa, pair, state, xyz)
pair = self._2d_update(msa, pair, state, xyz)
state, xyz, alpha = self._3d_update(msa, pair, state, xyz, is_atom, atom_frames, chirals)
latent_feats = self._pack_outputs(msa, pair, state, xyz, alpha, latent_feats)
return latent_feats
class RF2_withgradients(nn.Module):
"""
this is an updated version of the RF2 block, without computing rotations
to allow gradients to flow through all blocks
"""
def __init__(self, global_config=None, block_params=None, is_full=False, **kwargs
) -> None:
super(RF2_withgradients, self).__init__()
d_msa, d_msa_full, d_pair = global_config.d_msa, global_config.d_msa_full, global_config.d_pair
self.is_full = is_full
if self.is_full:
d_msa = d_msa_full
self.msa_row_attn = MSARowAttentionWithBias(
d_msa=d_msa, d_pair=d_pair, n_head=block_params.n_msa_head, d_hidden=block_params.n_msa_channels)
if self.is_full:
self.msa_col_attn = MSAColGlobalAttention(
d_msa=d_msa,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels
)
else:
self.msa_col_attn = MSAColAttention(
d_msa=d_msa, n_head=block_params.n_msa_head, d_hidden=block_params.n_msa_channels
)
self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_col)
self.msa_transition = FeedForwardLayer(d_msa, r_ff=4)
self.compute_structure_bias = structure_bias_factory[block_params.structure_bias_type](
d_rbf=block_params.structure_bias_channels,
d_pair=d_pair
)
self.pair_row_attn = BiasedAxialAttention(
d_pair=d_pair,
d_bias=d_pair,
n_head=block_params.n_pair_head,
d_hidden=block_params.n_pair_channels,
is_row=True
)
self.pair_col_attn = BiasedAxialAttention(
d_pair=d_pair,
d_bias=d_pair,
n_head=block_params.n_pair_head,
d_hidden=block_params.n_pair_channels,
is_row=False
)
self.tri_mult_incoming = TriangleMultiplication(
d_pair=d_pair, d_hidden=block_params.n_pair_channels, outgoing=False
)
self.tri_mult_outgoing = TriangleMultiplication(
d_pair=d_pair, d_hidden=block_params.n_pair_channels, outgoing=True
)
self.pair_transition = FeedForwardLayer(
d_pair, r_ff=4
)
self.outer_product = OuterProductMean(d_msa, d_pair, d_hidden=block_params.outer_product_channels)
self.structure_attn = FullyConnectedSE3_noR(
d_msa=d_msa,
d_pair=d_pair,
d_rbf=block_params.structure_bias_channels,
num_layers=block_params.n_se3_layers,
num_channels=block_params.n_se3_channels,
num_degrees=block_params.n_se3_degrees,
n_heads=block_params.n_se3_head,
div=block_params.n_div,
l0_in_features=block_params.l0_in_features,
l0_out_features=block_params.l0_out_features,
l1_in_features=block_params.l1_in_features,
l1_out_features=block_params.l1_out_features,
num_edge_features=block_params.n_se3_edge_features
)
self.structure_transition = FeedForwardLayer(
block_params.l0_out_features, r_ff=4
)
self.proj_state = nn.Linear(block_params.l0_out_features, d_msa)
self.reset_parameter()
def reset_parameter(self):
pass
def _unpack_inputs(self, latent_feats):
pair, xyz, is_atom, atom_frames, chirals = \
latent_feats["pair"], \
latent_feats["xyz"], latent_feats["is_atom"], \
latent_feats["atom_frames"], latent_feats["chirals"]
if self.is_full:
msa = latent_feats["msa_full"]
else:
msa = latent_feats["msa"]
return msa, pair, xyz[..., :3, :], is_atom, atom_frames, chirals
def _pack_outputs(self, msa, pair, state, xyz, latent_feats):
if self.is_full:
latent_feats["msa_full"] = msa
else:
latent_feats["msa"] = msa
latent_feats["pair"] = pair
latent_feats["state"] = state
latent_feats["xyz"] = xyz
return latent_feats
def _1d_update(self, msa, pair):
msa = msa + self.drop_row(self.msa_row_attn(msa, pair))
msa = msa + self.drop_col(self.msa_col_attn(msa))
msa = msa + self.msa_transition(msa)
msa_bias = self.outer_product(msa)
pair = pair + msa_bias
return msa, pair
def _2d_update(self, pair, xyz):
# break 3d symmetries with bias from coordinates
structure_bias = self.compute_structure_bias(xyz)
pair = pair + self.drop_row(self.pair_row_attn(pair, structure_bias))
pair = pair + self.drop_col(self.pair_col_attn(pair, structure_bias))
# provide triangle inductive bias
pair = pair + self.drop_row(self.tri_mult_outgoing(pair))
pair = pair + self.drop_row(self.tri_mult_incoming(pair))
pair = pair + self.pair_transition(pair)
return pair
def _3d_update(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
# apply structure attention and update seq features
state, xyz = self.structure_attn(msa, pair, state, xyz, is_atom, atom_frames, chirals)
state = state + self.structure_transition(state)
# state features bias the msa first row features
state_update = self.proj_state(state)
msa = msa.type_as(state_update)
msa = msa.index_add(1, torch.tensor([0,], device=state_update.device), state_update.unsqueeze(1))
return msa, state, xyz
def forward(self, latent_feats, use_checkpoint):
msa, pair, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)
state = None
if use_checkpoint:
msa, pair = checkpoint.checkpoint(create_custom_forward(self._1d_update), msa, pair)
pair = checkpoint.checkpoint(create_custom_forward(self._2d_update), pair, xyz)
# 3D track cannot use re-entrant = False because of chiral features call to autograd
#TODO: allow this to happen since new versions of Pytorch will be using reentrant=False
msa, state, xyz = checkpoint.checkpoint(create_custom_forward(self._3d_update), \
msa, pair, state, xyz, is_atom, atom_frames, chirals)
else:
msa, pair = self._1d_update(msa, pair)
pair = self._2d_update(pair, xyz)
msa, state, xyz = self._3d_update(msa, pair, state, xyz, is_atom, atom_frames, chirals)
latent_feats = self._pack_outputs(msa, pair, state, xyz, latent_feats)
return latent_feats
block_factory = {
"RF2_withgradients": partial(RF2_withgradients, is_full=False),
"RF2_withgradients_full": partial(RF2_withgradients, is_full=True),
"RF2aa": partial(RF2_block, is_full=False),
"RF2aa_full": partial(RF2_block, is_full=True)
}

View File

@@ -1,72 +0,0 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
# pre-activation bottleneck resblock
class ResBlock2D_bottleneck(nn.Module):
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
super(ResBlock2D_bottleneck, self).__init__()
padding = self._get_same_padding(kernel, dilation)
n_b = n_c // 2 # bottleneck channel
layer_s = list()
# pre-activation
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# project down to n_b
layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# convolution
layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# dropout
layer_s.append(nn.Dropout(p_drop))
# project up
layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
# make final layer initialize with zeros
#nn.init.zeros_(layer_s[-1].weight)
self.layer = nn.Sequential(*layer_s)
self.reset_parameter()
def reset_parameter(self):
# zero-initialize final layer right before residual connection
nn.init.zeros_(self.layer[-1].weight)
def _get_same_padding(self, kernel, dilation):
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
def forward(self, x):
out = self.layer(x)
return x + out
class ResidualNetwork(nn.Module):
def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
dilation=[1,2,4,8], p_drop=0.15):
super(ResidualNetwork, self).__init__()
layer_s = list()
# project to n_feat_block
if n_feat_in != n_feat_block:
layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
# add resblocks
for i_block in range(n_block):
d = dilation[i_block%len(dilation)]
res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
layer_s.append(res_block)
if n_feat_out != n_feat_block:
# project to n_feat_out
layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
self.layer = nn.Sequential(*layer_s)
def forward(self, x):
return self.layer(x)

View File

@@ -1,180 +0,0 @@
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
#def get_cosine_with_hard_restarts_schedule_with_warmup(
# optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
#):
# """
# Create a schedule with a learning rate that decreases following the values of the cosine function between the
# initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
# linearly between 0 and the initial lr set in the optimizer.
#
# Args:
# optimizer (:class:`~torch.optim.Optimizer`):
# The optimizer for which to schedule the learning rate.
# num_warmup_steps (:obj:`int`):
# The number of steps for the warmup phase.
# num_training_steps (:obj:`int`):
# The total number of training steps.
# num_cycles (:obj:`int`, `optional`, defaults to 1):
# The number of hard restarts to use.
# last_epoch (:obj:`int`, `optional`, defaults to -1):
# The index of the last epoch when resuming training.
#
# Return:
# :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
# """
#
# def lr_lambda(current_step):
# if current_step < num_warmup_steps:
# return float(current_step) / float(max(1, num_warmup_steps))
# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
# if progress >= 1.0:
# return 0.0
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
#
# return LambdaLR(optimizer, lr_lambda, last_epoch)
#
class CosineAnnealingWarmupRestarts(_LRScheduler):
"""
optimizer (Optimizer): Wrapped optimizer.
first_cycle_steps (int): First cycle step size.
cycle_mult(float): Cycle steps magnification. Default: -1.
max_lr(float): First cycle's max learning rate. Default: 0.1.
min_lr(float): Min learning rate. Default: 0.001.
warmup_steps(int): Linear warmup step size. Default: 0.
gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
last_epoch (int): The index of last epoch. Default: -1.
"""
def __init__(self,
optimizer : torch.optim.Optimizer,
first_cycle_steps : int,
cycle_mult : float = 1.,
max_lr : float = 0.1,
min_lr : float = 0.001,
warmup_steps : int = 0,
gamma : float = 1.,
last_epoch : int = -1
):
assert warmup_steps < first_cycle_steps
self.first_cycle_steps = first_cycle_steps # first cycle step size
self.cycle_mult = cycle_mult # cycle steps magnification
self.base_max_lr = max_lr # first max learning rate
self.max_lr = max_lr # max learning rate in the current cycle
self.min_lr = min_lr # min learning rate
self.warmup_steps = warmup_steps # warmup step size
self.gamma = gamma # decrease rate of max learning rate by cycle
self.cur_cycle_steps = first_cycle_steps # first cycle step size
self.cycle = 0 # cycle count
self.step_in_cycle = last_epoch # step size of the current cycle
super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
# set learning rate min_lr
self.init_lr()
def init_lr(self):
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
self.base_lrs.append(self.min_lr)
def get_lr(self):
if self.step_in_cycle == -1:
return self.base_lrs
elif self.step_in_cycle < self.warmup_steps:
return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.max_lr - base_lr) \
* (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
/ (self.cur_cycle_steps - self.warmup_steps))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.step_in_cycle = self.step_in_cycle + 1
if self.step_in_cycle >= self.cur_cycle_steps:
self.cycle += 1
self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
else:
if epoch >= self.first_cycle_steps:
if self.cycle_mult == 1.:
self.step_in_cycle = epoch % self.first_cycle_steps
self.cycle = epoch // self.first_cycle_steps
else:
n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
self.cycle = n
self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
else:
self.cur_cycle_steps = self.first_cycle_steps
self.step_in_cycle = epoch
self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_ratio=0.001, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
min_ratio, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_stepwise_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_steps_decay, decay_rate, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_fades = (current_step-num_warmup_steps)//num_steps_decay
return (decay_rate**num_fades)
return LambdaLR(optimizer, lr_lambda, last_epoch)

View File

@@ -0,0 +1,218 @@
import sys
sys.path.append('../')
import argparse
import numpy as np
import matplotlib.pyplot as plt
from torch.utils import data
from tqdm import tqdm
from rf2aa.data.data_loader import default_dataloader_params, loader_tf_complex, loader_distil_tf, get_train_valid_set, DistilledDataset, loader_pdb, loader_complex, loader_na_complex, loader_fb, loader_dna_rna, loader_sm_compl_assembly_single, loader_sm_compl_assembly, loader_sm, loader_atomize_pdb, loader_atomize_complex
from rf2aa.chemical import load_pdb_ideal_sdf_strings
import random
import numpy as np
import torch
import pickle
class OrderedSampler(torch.utils.data.sampler.Sampler):
"""
Custom sampler that samples specific indices from a dataset.
"""
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return iter(self.indices)
def parse_arguments():
parser = argparse.ArgumentParser(description='Process crop size and MSA limit.')
parser.add_argument('--crop-size', type=int, help='Crop size for the data loader', default=256)
parser.add_argument('--msa-limit', type=int, help='MSA limit for the data loader', default=None)
parser.add_argument('--output', type=str, help='Output file path for the lengths', default='sample_lengths.pt')
parser.add_argument('--num-samples', type=int, help='Number of samples for the loop', default=None)
parser.add_argument('--num-workers', type=int, help='Number of DataLoader workers', default=4)
return parser.parse_args()
def seed_everything():
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def build_dataset(crop_size, msa_limit):
"""
Build a DistilledDataset with the given crop size and MSA limit overwriting the default parameters.
"""
# Load default parameters
loader_param = default_dataloader_params
# Override default parameters
loader_param['CROP'] = crop_size or loader_param.get('CROP')
loader_param['MSA_LIMIT'] = msa_limit or loader_param.get('MSA_LIMIT')
# Build dataset
(
train_ID_dict,
valid_ID_dict,
weights_dict,
train_dict,
valid_dict,
homo,
chid2hash,
chid2taxid,
chid2smpartners,
) = get_train_valid_set(loader_param)
# define atomize_pdb train/valid sets, which use the same examples as pdb set
train_ID_dict['atomize_pdb'] = train_ID_dict['pdb']
valid_ID_dict['atomize_pdb'] = valid_ID_dict['pdb']
weights_dict['atomize_pdb'] = weights_dict['pdb']
train_dict['atomize_pdb'] = train_dict['pdb']
valid_dict['atomize_pdb'] = valid_dict['pdb']
# define atomize_pdb train/valid sets, which use the same examples as pdb set
train_ID_dict['atomize_complex'] = train_ID_dict['compl']
valid_ID_dict['atomize_complex'] = valid_ID_dict['compl']
weights_dict['atomize_complex'] = weights_dict['compl']
train_dict['atomize_complex'] = train_dict['compl']
valid_dict['atomize_complex'] = valid_dict['compl']
# Assign loaders to each dataset
loader_dict = dict(
pdb = loader_pdb,
peptide = loader_pdb,
compl = loader_complex,
neg_compl = loader_complex,
na_compl = loader_na_complex,
neg_na_compl = loader_na_complex,
distil_tf = loader_distil_tf,
tf = loader_tf_complex,
neg_tf = loader_tf_complex,
fb = loader_fb,
rna = loader_dna_rna,
dna = loader_dna_rna,
sm_compl = loader_sm_compl_assembly_single,
metal_compl = loader_sm_compl_assembly_single,
sm_compl_multi = loader_sm_compl_assembly_single,
sm_compl_covale = loader_sm_compl_assembly_single,
sm_compl_asmb = loader_sm_compl_assembly,
sm = loader_sm,
atomize_pdb = loader_atomize_pdb,
atomize_complex = loader_atomize_complex,
sm_compl_furthest_neg = loader_sm_compl_assembly,
sm_compl_permuted_neg = loader_sm_compl_assembly,
sm_compl_docked_neg = loader_sm_compl_assembly,
)
# Get ligand dictionary. This is used for loading negative examples.
ligand_dictionary = load_pdb_ideal_sdf_strings(return_only_sdf_strings=True)
# Build dataset
train_set = DistilledDataset(
train_ID_dict,
train_dict,
loader_dict,
homo,
chid2hash,
chid2taxid,
chid2smpartners,
loader_param,
native_NA_frac=0.25,
ligand_dictionary=ligand_dictionary
)
return train_set
def main(crop_size, msa_limit, output_file, num_samples, num_workers):
print(f"Building dataset with CROP_SIZE={crop_size} and MSA_LIMIT={msa_limit}...")
# Setup the dataset
train_set = build_dataset(crop_size, msa_limit)
# Datasets where we will assume `LEN_EXIST` returns the appropriate length
fixed_length_datasets = ['pdb', 'fb']
# Datasets where we pass the sample through the DataLoader to measure the length
variable_length_datasets = ["compl", "neg_compl", "na_compl", "neg_na_compl", "distil_tf","tf","neg_tf","rna","dna", "sm_compl", "metal_compl", "sm_compl_multi", "sm_compl_covale", "sm_compl_asmb", "sm", "sm_compl_docked_neg", "sm_compl_permuted_neg", "sm_compl_furthest_neg", "atomize_pdb", "atomize_complex"]
# Assert that the monomer and variable dataset lists do not overlap
assert len(set(fixed_length_datasets) & set(variable_length_datasets)) == 0, "Monomer and variable datasets should not overlap"
# Create a tensor of zeros of the same length as the dataset
lengths = torch.zeros(len(train_set))
offset = 0
indices_to_process = []
for key, index_list in train_set.index_dict.items():
# Add offset to every value in list
adjusted_index_list = [x + offset for x in index_list]
# Either add index to list to process later or calculate length directly
if key in variable_length_datasets:
indices_to_process.extend(adjusted_index_list)
elif key in fixed_length_datasets:
# Calculate lengths and remove duplicates
if key in ['pdb', 'fb']:
train_set.dataset_dict[key]['LENGTH'] = train_set.dataset_dict[key]["SEQUENCE"].apply(len)
else:
train_set.dataset_dict[key]['LENGTH'] = train_set.dataset_dict[key]["LEN_EXIST"]
fixed_lengths = train_set.dataset_dict[key].drop_duplicates(subset=['CLUSTER'])
# Create a dictionary for mapping cluster ID to length
cluster_length_map = dict(zip(fixed_lengths['CLUSTER'], fixed_lengths['LENGTH']))
# Map the cluster ID to the index in the dataset
ids = train_set.ID_dict[key]
fixed_lengths_processed = [cluster_length_map.get(id) for id in ids]
# Replace the indices within lengths corresponding to adjusted_index_list with the fixed lengths
lengths[adjusted_index_list] = torch.tensor(fixed_lengths_processed, dtype=torch.float)
offset += len(index_list)
sampler = None
if num_samples:
print(f"Sampling {num_samples} examples...")
np.random.shuffle(indices_to_process)
indices_to_process = indices_to_process[:num_samples]
sampler = OrderedSampler(indices_to_process)
else:
num_samples = len(indices_to_process)
sampler = OrderedSampler(indices_to_process)
# Create the training loader
dataloader_kwargs = {
"shuffle": False,
"num_workers": num_workers,
"pin_memory": False,
"batch_size": 1,
}
print("num_workers:", dataloader_kwargs["num_workers"])
train_loader = data.DataLoader(train_set, sampler=sampler, **dataloader_kwargs)
print(f"Looping through {num_samples} batches.")
print("First ten indices:", indices_to_process[:10])
print("Last ten indices:", indices_to_process[-10:])
for batch_idx, batch in tqdm(enumerate(train_loader), total=num_samples, desc="Processing batches"):
try:
# Get the N_residues from the first tensor, which is the last dimension
length = batch[0].shape[-1]
# Store the length in a slot corresponding to the absolute index of the example
lengths[indices_to_process[batch_idx]] = length
except Exception as e:
print(f"An error occurred while processing batch {batch_idx}: {e}")
# Break if limiting examples
if batch_idx >= num_samples - 1:
break
# Save length tensor to pickle
print(f"Saving lengths as a tensor to {output_file}...")
torch.save(lengths, output_file)
if __name__ == "__main__":
args = parse_arguments()
seed_everything()
main(args.crop_size, args.msa_limit, args.output, args.num_samples, args.num_workers)

View File

@@ -1,91 +0,0 @@
import os
import argparse
import submitit
from typing import Dict, Optional
def add_slurm_args(parser: Optional[argparse.ArgumentParser] = None, prefix: str = "-") -> argparse.ArgumentParser:
if parser is None:
parser = argparse.ArgumentParser(
description="Submits a job to the digs cluster via submitit"
)
parser.add_argument(
f"{prefix}slurm_log_path",
default="/home/psturm/RF2-allatom/slurm_logs",
type=str,
help="Path where slurm logs will go."
)
parser.add_argument(
f"{prefix}local",
action="store_true",
help="Set to true to train locally rather than submitting to slurm",
)
parser.add_argument(
f"{prefix}slurm_partition",
type=str,
default="gpu",
help="Slurm partition to run job on",
)
parser.add_argument(
f"{prefix}gpu_type",
type=str,
default="a6000",
help="Which gpus to run on, slurm gres constraint",
)
parser.add_argument(
f"{prefix}cpu_memory",
type=int,
default=64,
help="Amount of cpu job memory to request for slurm submission",
)
parser.add_argument(
f"{prefix}cpus_per_task",
type=int,
default=4,
help="Number of cpu cores to request for slurm submission",
)
parser.add_argument(
f"{prefix}timeout_min",
type=int,
default=10000,
help="Maximum number of minutes for slurm job to run",
)
parser.add_argument(
f"{prefix}max_slurm_jobs_at_once",
type=int,
default=16,
help="Maximum number of array jobs to run at once",
)
parser.add_argument(
f"{prefix}num_gpus",
type=int,
default=4,
help="Number of GPUs to train on"
)
parser.add_argument(
f"{prefix}nodes",
type=int,
default=1,
help="Number of nodes to submit to"
)
return parser
def create_executor(args: Dict, log_folder: str, job_name: str) -> submitit.AutoExecutor:
executor = submitit.AutoExecutor(folder=log_folder)
executor.update_parameters(
slurm_partition=args.slurm_partition,
slurm_mem=f"{args.cpu_memory}gb",
slurm_job_name=job_name,
cpus_per_task=args.cpus_per_task,
slurm_ntasks_per_node=args.num_gpus,
slurm_array_parallelism=args.max_slurm_jobs_at_once,
nodes=args.nodes,
timeout_min=args.timeout_min,
)
if args.gpu_type != "none":
executor.update_parameters(
slurm_gres=f"gpu:{args.gpu_type}:{args.num_gpus}",
)
return executor

220
rf2aa/test_inference.py Normal file
View File

@@ -0,0 +1,220 @@
import torch
import numpy as np
import unittest, random, os, sys
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(script_dir)
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
import rf2aa.data.parsers as parsers
import rf2aa.util as util
from rf2aa.chemical import load_pdb_ideal_sdf_strings, NTOTALDOFS
from rf2aa.data.data_loader import MSAFeaturize, TemplFeaturize, \
center_and_realign_missing, generate_xyz_prev, get_bond_distances
from rf2aa.data.compose_dataset import default_dataloader_params
from rf2aa.kinematics import get_chirals, xyz_to_t2d
from rf2aa.tensor_util import assert_equal, cmp
MAXLAT=256
MAXSEQ=2048
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 4,
"d_msa" : 256,
"d_pair" : 192,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 6,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.0,
"lj_lin" : 0.7,
'symmetrize_repeats': False,
'repeat_length': float('nan'),
'symmsub_k': float('nan'),
'sym_method': float('nan'),
'main_block': float('nan'),
'copy_main_block_template': False
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
SE3_ref_param = {
"num_layers" : 2,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['SE3_ref_param'] = SE3_ref_param
params = default_dataloader_params
def make_deterministic(seed=0):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def featurize_input(three_letter="HEM", pickle_input=None):
"""
make dummy features for the model
"""
if pickle_input is not None:
if os.path.exists(pickle_input):
rf_input = torch.load(pickle_input)
return rf_input
ligands = load_pdb_ideal_sdf_strings(return_only_sdf_strings=True)
sdf = ligands[three_letter]
mol, msa_sm, ins_sm, xyz_sm, mask_sm = parsers.parse_mol(sdf, filetype="sdf", string=True)
a3m_sm = {"msa": msa_sm.unsqueeze(0), "ins": ins_sm.unsqueeze(0)}
G = util.get_nxgraph(mol)
N_symmetry, sm_L, _ = xyz_sm.shape
Ls = [ sm_L]
a3m = a3m_sm
msa = a3m['msa'].long()
ins = a3m['ins'].long()
bond_feats = util.get_bond_feats(mol)
chirals = get_chirals(mol, xyz_sm[0])
atom_frames = util.get_atom_frames(msa_sm, G)
idx = torch.arange(sm_L)
same_chain = torch.ones((sm_L, sm_L)).long()
dist_matrix = get_bond_distances(bond_feats)
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins,
p_mask=0.0, params={'MAXLAT': MAXLAT, 'MAXSEQ': MAXSEQ, 'MAXCYCLE': 1}, tocpu=True)
xyz_t, f1d_t, mask_t, _ = TemplFeaturize({"ids":[]}, sm_L, params, offset=0,
npick=0, pick_top=True)
xyz_t_frames = util.xyz_t_to_frame_xyz(xyz_t[None], seq, atom_frames[None])
mask_t_2d = util.get_prot_sm_mask(mask_t, seq[0])[None] # (B, T, L)
mask_t_2d = mask_t_2d[:,:,None]*mask_t_2d[:,:,:,None] # (B, T, L, L)
t2d = xyz_to_t2d(xyz_t_frames, mask_t_2d[None])
alpha_t = torch.zeros(1, sum(Ls), NTOTALDOFS*3)
ntempl = xyz_t.shape[0]
xyz_t = torch.stack(
[center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain) for i in range(ntempl)]
)
L = sum(Ls)
xyz_prev, _ = generate_xyz_prev(xyz_t, mask_t, params)
alpha_prev = torch.zeros((1,L,NTOTALDOFS,2))
rf_input = {
"msa_clust": msa_seed,
"msa_extra": msa_extra,
"seq": seq,
"seq_unmasked": msa_seed_orig[:, 0],
"xyz_prev": xyz_prev[None],
"alpha_prev": alpha_prev,
"idx_pdb": idx[None],
"bond_feats": bond_feats[None],
"dist_matrix": dist_matrix[None],
"chirals": chirals[None],
"atom_frames": atom_frames[None],
"t1d": f1d_t[None],
"t2d": t2d[0],
"xyz_t": xyz_t[...,1,:][None],
"alpha_t": alpha_t[None],
"mask_t": mask_t_2d,
"same_chain": same_chain[None],
"msa_prev": None,
"pair_prev": None,
"state_prev": None,
"mask_recycle": None
}
if pickle_input is not None:
torch.save(rf_input, pickle_input)
return rf_input
class FoldAndDock3TestCase(unittest.TestCase):
def setUp(self):
super(FoldAndDock3TestCase, self).__init__()
self.name = "fd3_inference"
pickle_input = str(os.path.join("test_pickles", f"{self.name}_input"))
#MODEL_PARAM['use_extra_l1'] = True
#MODEL_PARAM['use_atom_frames'] = True
# refactored model param allowing for backwards compatibility
MODEL_PARAM['use_chiral_l1'] = True
MODEL_PARAM['use_lj_l1'] = True
MODEL_PARAM['use_atom_frames'] = True
MODEL_PARAM['use_same_chain'] = True
MODEL_PARAM['recycling_type'] = 'all'
self.model = LegacyRoseTTAFoldModule(
**MODEL_PARAM,
aamask = util.allatom_mask,
atom_type_index = util.atom_type_index,
ljlk_parameters = util.ljlk_parameters,
lj_correction_parameters = util.lj_correction_parameters,
num_bonds = util.num_bonds,
cb_len = util.cb_length_t,
cb_ang = util.cb_angle_t,
cb_tor = util.cb_torsion_t,
)
self.rf_input = featurize_input(pickle_input=pickle_input)
def test_inference(self):
make_deterministic()
out = self.model(
msa_latent = self.rf_input["msa_clust"].float(),
msa_full = self.rf_input["msa_extra"].float(),
seq = self.rf_input["seq"],
seq_unmasked = self.rf_input["seq_unmasked"],
xyz = self.rf_input["xyz_prev"],
sctors = self.rf_input["alpha_prev"],
idx = self.rf_input["idx_pdb"],
bond_feats = self.rf_input["bond_feats"],
dist_matrix = self.rf_input["dist_matrix"],
chirals = self.rf_input["chirals"],
atom_frames = self.rf_input["atom_frames"],
t1d = self.rf_input["t1d"],
t2d = self.rf_input["t2d"],
xyz_t = self.rf_input["xyz_t"],
alpha_t = self.rf_input["alpha_t"],
mask_t = self.rf_input["mask_t"],
same_chain = self.rf_input["same_chain"],
msa_prev = self.rf_input["msa_prev"],
pair_prev = self.rf_input["pair_prev"],
state_prev = self.rf_input["state_prev"],
mask_recycle = self.rf_input["mask_recycle"]
)
output_names = ("logits_c6d", "logits_aa", "logits_pae", \
"logits_pde", "p_bind", "xyz", "alpha", "xyz_allatom", \
"lddt", "seq", "pair", "state")
output_dict = dict(zip(output_names, out))
test_out_path = os.path.join("test_pickles", f"{self.name}_out")
if not os.path.exists(test_out_path):
torch.save(output_dict, test_out_path)
print(f"saved output at {test_out_path}")
else:
reference_output_dict = torch.load(test_out_path)
for output in output_names:
want = reference_output_dict[output]
got = output_dict[output]
if output == "logits_c6d":
want = want[0]
got = got[0]
print(output)
cmp(got, want)
if __name__ == "__main__":
unittest.main()

Binary file not shown.

Binary file not shown.

357
rf2aa/trainer_new.py Normal file
View File

@@ -0,0 +1,357 @@
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import hydra
import os
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
from rf2aa.data.data_loader import loader_atomize_pdb
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items
from rf2aa.debug import debug_unused_params, debug_used_params
from rf2aa.training.EMA import EMA, count_parameters
from rf2aa.loss.loss_factory import get_loss_and_misc
from rf2aa.training.optimizer import add_weight_decay
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_sampling
from rf2aa.model.network import RosettaFold
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
from rf2aa.util import frame_indices, long2alt, allatom_mask, num_bonds, \
atom_type_index, ljlk_parameters, lj_correction_parameters, hbtypes, \
hbbaseatoms, hbpolys, cb_length_t, cb_angle_t, cb_torsion_t
import rf2aa.util as util
from rf2aa.util_module import XYZConverter
#TODO: control environment variables from config
# limit thread counts
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
## To reproduce errors
import random
def seed_all(seed=0):
random.seed(0)
torch.manual_seed(5924)
torch.cuda.manual_seed(5924)
np.random.seed(6636)
torch.set_num_threads(4)
#torch.autograd.set_detect_anomaly(True)
class Trainer:
def __init__(self, config) -> None:
self.config = config
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
if self.config.experiment.output_dir is not None:
self.output_dir = self.config.experiment.output_dir
else:
self.output_dir = "models/"
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
def construct_model(self):
raise NotImplementedError()
def construct_optimizer(self):
if self.config.training_params.weight_decay is not None:
opt_params = add_weight_decay(self.model, self.config.training_params.weight_decay)
else:
opt_params = self.model.parameters()
self.optimizer = torch.optim.AdamW(opt_params, lr=self.config.training_params.learning_rate)
def construct_scheduler(self):
self.scheduler = get_stepwise_decay_schedule_with_warmup(self.optimizer, \
**self.config.training_params.learning_rate_schedule)
def construct_scaler(self):
self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.training_params.use_amp)
def load_checkpoint(self, rank):
if self.config.training_params.resume_train:
checkpoint_path = f"{self.output_dir}/{self.config.experiment.name}_last.pt"
elif self.config.eval_params.checkpoint_path:
checkpoint_path = self.config.eval_params.checkpoint_path
map_location = {"cuda:0": f"cuda:{rank}"}
self.checkpoint = torch.load(checkpoint_path, map_location=map_location)
print(f"Loading checkpoint from {checkpoint_path} on rank:{rank}")
def load_model(self):
torch.cuda.empty_cache()
if self.config.training_params.resume_train is None:
raise ValueError("Should not load model when resume_train is True")
#TODO: check if model should load the final state dict and not the EMA
self.model.module.model.load_state_dict(self.checkpoint["final_state_dict"], strict=True)
self.model.module.shadow.load_state_dict(self.checkpoint["model_state_dict"], strict=False)
print("Checkpoint loaded into model")
def load_optimizer(self):
if self.config.training_params.resume_train is None:
raise ValueError("Should not load optimizer when resume_train is True")
self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
def load_scheduler(self):
if self.config.training_params.resume_train is None:
raise ValueError("Should not load scheduler when resume_train is True")
self.scheduler.load_state_dict(self.checkpoint['scheduler_state_dict'])
def load_scaler(self):
if self.config.training_params.resume_train is None:
raise ValueError("Should not load scaler when resume_train is True")
self.scaler.load_state_dict(self.checkpoint['scaler_state_dict'])
def construct_dataset(self, rank, world_size):
return compose_dataset(self.config.dataset_params, self.config.loader_params, rank, world_size)
def construct_loss_function(self):
raise NotImplementedError()
def move_constants_to_device(self, gpu):
self.fi_dev = frame_indices.to(gpu)
self.xyz_converter = XYZConverter().to(gpu)
self.l2a = long2alt.to(gpu)
self.aamask = allatom_mask.to(gpu)
self.num_bonds = num_bonds.to(gpu)
self.atom_type_index = atom_type_index.to(gpu)
self.ljlk_parameters = ljlk_parameters.to(gpu)
self.lj_correction_parameters = lj_correction_parameters.to(gpu)
self.hbtypes = hbtypes.to(gpu)
self.hbbaseatoms = hbbaseatoms.to(gpu)
self.hbpolys = hbpolys.to(gpu)
self.cb_len = cb_length_t.to(gpu)
self.cb_ang = cb_angle_t.to(gpu)
self.cb_tor = cb_torsion_t.to(gpu)
def checkpoint_model(self, epoch, metadata={}):
checkpoint_data = {
'epoch' : epoch,
'model_state_dict' : self.model.module.shadow.state_dict(),
'final_state_dict' : self.model.module.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'scaler_state_dict' : self.scaler.state_dict(),
'training_config' : dict(self.config),
}
checkpoint_data.update(metadata)
torch.save(checkpoint_data, f"{self.output_dir}/{self.config.experiment.name}_{epoch}.pt")
def launch_distributed_training(self):
world_size = torch.cuda.device_count()
if ('MASTER_ADDR' not in os.environ):
os.environ['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
if ('MASTER_PORT' not in os.environ):
os.environ['MASTER_PORT'] = '%d'%self.config.ddp_params.port
if ("SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
world_size = int(os.environ["SLURM_NTASKS"])
rank = int (os.environ["SLURM_PROCID"])
print ("Launched from slurm", rank, world_size)
self.train_model(rank, world_size)
#mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
else:
print ("Launched from interactive")
world_size = torch.cuda.device_count()
if world_size == 1:
# No need for multiple processes with 1 GPU
self.train_model(0, world_size)
else:
mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
def init_process_group(self, rank, world_size):
gpu = rank % torch.cuda.device_count()
dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
torch.cuda.set_device("cuda:%d"%gpu)
return gpu
def cleanup(self):
dist.destroy_process_group()
def train_model(self, rank, world_size):
""" runs model training on each gpu """
gpu = self.init_process_group(rank, world_size)
train_loader, train_sampler, valid_loaders, valid_sampler = self.construct_dataset(rank, world_size)
self.train_loader = train_loader
# move global information to device
self.move_constants_to_device(gpu)
self.construct_model(device=gpu)
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
if rank == 0:
print(f"Loading model with {count_parameters(self.model)} parameters")
self.construct_optimizer()
self.construct_scheduler()
self.construct_scaler()
if self.config.training_params.resume_train:
self.load_checkpoint(rank)
self.load_model()
self.load_optimizer()
self.load_scheduler()
self.load_scaler()
self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
self.config.experiment.n_epoch,
self.config.dataset_params.n_train,
world_size)
for epoch in range(self.config.experiment.n_epoch):
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
self.train_epoch(epoch, rank)
self.cleanup()
def train_epoch(self, epoch, rank):
""" train model """
# turn on gradients
self.model.train()
# clear gradients
self.optimizer.zero_grad()
for train_idx, inputs in enumerate(self.train_loader):
n_cycle = self.recycle_schedule[epoch, train_idx] # number of recycling
# run forward pass and compute loss
loss, loss_dict = self.train_step(inputs, n_cycle)
# aggregate loss and update parameters
loss = loss / self.config.ddp_params.accum
self.scaler.scale(loss).backward()
if train_idx%self.config.ddp_params.accum == 0:
self.update_parameters()
if train_idx % self.config.log_params.log_every_n_examples == 0 and rank == 0:
self.log_intermediate_losses(inputs, loss_dict, n_cycle)
torch.cuda.empty_cache()
if rank == 0:
self.checkpoint_model(epoch)
def train_step(self, inputs, n_cycle):
""" take an input from dataloader, run the model and compute a loss """
raise NotImplementedError()
def update_parameters(self):
""" scale, clip gradients and update parameters """
# gradient clipping
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training_params.grad_clip)
self.scaler.step(self.optimizer)
scale = self.scaler.get_scale()
self.scaler.update()
skip_lr_sched = (scale != self.scaler.get_scale())
self.optimizer.zero_grad()
if not skip_lr_sched:
self.scheduler.step()
self.model.module.update() # apply EMA
def log_intermediate_losses(self, inputs, loss_dict, n_cycle):
item = inputs[-1]
max_mem = torch.cuda.max_memory_allocated()/1e9
print(f"Example: {item} Max Memory:{max_mem} Recycle:{n_cycle}\n"+
"\t".join([f"{k}:{v}" for k,v in loss_dict.items()]))
torch.cuda.reset_peak_memory_stats()
class LegacyTrainer(Trainer):
""" trains Legacy versions of RFAA """
def __init__(self, config) -> None:
super().__init__(config)
def construct_model(self, device="cpu"):
self.model = LegacyRoseTTAFoldModule(
**self.config.legacy_model_param,
aamask = util.allatom_mask.to(device),
atom_type_index = util.atom_type_index.to(device),
ljlk_parameters = util.ljlk_parameters.to(device),
lj_correction_parameters = util.lj_correction_parameters.to(device),
num_bonds = util.num_bonds.to(device),
cb_len = util.cb_length_t.to(device),
cb_ang = util.cb_angle_t.to(device),
cb_tor = util.cb_torsion_t.to(device),
).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
def train_step(self, inputs, n_cycle):
""" take an input from dataloader, run the model and compute a loss """
gpu = self.model.device
# HACK: certain features are constructed during the train step
# in the future this should only promote the constructed features onto gpu
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input(inputs, self.xyz_converter, gpu)
output_i = recycle_step_legacy(self.model, network_input, n_cycle, self.config.training_params.use_amp)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(inputs, device=gpu)
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
msa = msa.to(gpu)
mask_msa = mask_msa.to(gpu)
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
return loss, loss_dict
class ComposedTrainer(Trainer):
""" trains composed versions of RFAA """
def __init__(self, config) -> None:
super().__init__(config)
def construct_model(self, device="cpu"):
self.model = RosettaFold(self.config).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
def train_step(self, inputs, n_cycle):
""" take an input from dataloader, run the model and compute a loss """
gpu = self.model.device
# HACK: certain features are constructed during the train step
# in the future this should only promote the constructed features onto gpu
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input(inputs, self.xyz_converter, gpu)
output_i = recycle_step_packed(self.model, network_input, n_cycle, self.config.training_params.use_amp)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(inputs, device=gpu)
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
msa = msa.to(gpu)
mask_msa = mask_msa.to(gpu)
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
return loss, loss_dict
@hydra.main(version_base=None, config_path='config/train')
def main(config):
seed_all()
trainer = trainer_factory[config.experiment.trainer](config=config)
trainer.launch_distributed_training()
trainer_factory = {
"legacy": LegacyTrainer,
"composed": ComposedTrainer,
}
if __name__ == "__main__":
main()

55
rf2aa/training/EMA.py Normal file
View File

@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
from collections import OrderedDict
from copy import deepcopy
class EMA(nn.Module):
def __init__(self, model, decay):
super().__init__()
self.decay = decay
self.model = model
self.shadow = deepcopy(self.model)
for param in self.shadow.parameters():
param.detach_()
@torch.no_grad()
def update(self):
if not self.training:
print("EMA update should only be called during training", file=stderr, flush=True)
return
model_params = OrderedDict(self.model.named_parameters())
shadow_params = OrderedDict(self.shadow.named_parameters())
# check if both model contains the same set of keys
assert model_params.keys() == shadow_params.keys()
for name, param in model_params.items():
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
if param.requires_grad:
shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))
model_buffers = OrderedDict(self.model.named_buffers())
shadow_buffers = OrderedDict(self.shadow.named_buffers())
# check if both model contains the same set of keys
assert model_buffers.keys() == shadow_buffers.keys()
for name, buffer in model_buffers.items():
# buffers are copied
shadow_buffers[name].copy_(buffer)
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
else:
return self.shadow(*args, **kwargs)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

View File

@@ -0,0 +1,5 @@
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
def custom_forward(*inputs):
return module(*inputs, **kwargs)
return custom_forward

View File

@@ -0,0 +1,11 @@
def add_weight_decay(model, l2_coeff):
decay, no_decay = [], []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
#if len(param.shape) == 1 or name.endswith(".bias"):
if "norm" in name or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
return [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': l2_coeff}]

136
rf2aa/training/recycling.py Normal file
View File

@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import numpy as np
from contextlib import ExitStack
from rf2aa.util import INIT_CRDS, NTOTAL
def recycle_step_legacy(ddp_model, input, n_cycle, use_amp):
gpu = ddp_model.device
xyz_prev, alpha_prev, mask_recycle = \
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
if i_cycle < n_cycle -1:
stack.enter_context(torch.no_grad())
stack.enter_context(ddp_model.no_sync())
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
return_raw=True
use_checkpoint=False
else:
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
return_raw=False
use_checkpoint=True
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
output_i = ddp_model(**input_i)
return output_i
def recycle_step_packed(ddp_model, input, n_cycle, use_amp):
""" exactly same logic as legacy recycling, except inputs and outputs are dictionaries"""
gpu = ddp_model.device
xyz_prev, alpha_prev, mask_recycle = \
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
if i_cycle < n_cycle -1:
stack.enter_context(torch.no_grad())
stack.enter_context(ddp_model.no_sync())
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
return_raw = True
use_checkpoint=False
else:
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
return_raw = False
use_checkpoint=True
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
rf_outputs, rf_latents = ddp_model(input_i, use_checkpoint=use_checkpoint)
output_i = unpack_outputs(rf_outputs, rf_latents, return_raw)
return output_i
def unpack_outputs(rf_outputs, rf_latents, return_raw):
#HACK: this just unpacks the outputs into the way the previous RFAA loss function accepts it
# in the future the loss function should accept rf_outputs and rf_latents
msa, pair, state = rf_latents["msa"], rf_latents["pair"], rf_latents["state"]
if return_raw:
xyz_prev = rf_outputs["xyzs"][-1][None]
alpha_prev = rf_outputs["alphas"][-1]
return msa[:, 0], pair, xyz_prev, alpha_prev, None # mask_recycle is always None
else:
c6d_logits, mlm_logits, pae_logits, plddt_logits = rf_outputs["c6d"], rf_outputs["mlm"], \
rf_outputs["pae"], rf_outputs["plddt"]
pde_logits = None
p_bind = None
xyz, alphas = rf_outputs["xyzs"], rf_outputs["alphas"]
if "xyz_intermediate" in rf_latents:
intermediate_xyzs = torch.cat(rf_latents["xyz_intermediate"], dim=0)
xyz = torch.cat((intermediate_xyzs, xyz), dim=0)
if "alpha_intermediate" in rf_latents:
alpha_intermediate = torch.cat(rf_latents["alpha_intermediate"], dim=0)
alphas = torch.cat((alpha_intermediate, alphas), dim=0)
xyz_allatom = None
return (c6d_logits, mlm_logits, pae_logits, pde_logits, p_bind,
xyz, alphas, xyz_allatom, plddt_logits, msa[:, 0], pair, state)
def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
input_i = {}
for key in network_input:
if key in ['msa_latent', 'msa_full', 'seq']:
input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
else:
input_i[key] = network_input[key]
L = input_i["msa_latent"].shape[2]
msa_prev, pair_prev, _, alpha, mask_recycle = output_i
xyz_prev = INIT_CRDS.reshape(1,1,NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
input_i['msa_prev'] = msa_prev
input_i['pair_prev'] = pair_prev
input_i['xyz'] = xyz_prev
input_i['mask_recycle'] = mask_recycle
input_i['sctors'] = alpha
input_i['return_raw'] = return_raw
input_i['use_checkpoint'] = use_checkpoint
input_i.pop('xyz_prev')
input_i.pop('alpha_prev')
return input_i
def get_recycle_schedule(max_cycle, n_epochs, n_train, world_size, **kwargs):
'''
get's the number of recycles per example.
'''
assert n_train % world_size == 0
# need to sync different gpus
recycle_schedules=[]
# make deterministic
np.random.seed(0)
for i in range(n_epochs):
recycle_schedule=[np.random.randint(1,max_cycle+1) for _ in range(n_train//world_size)]
recycle_schedules.append(torch.tensor(recycle_schedule))
return torch.stack(recycle_schedules, dim=0)
def get_recycle_schedule_opt(max_cycle, n_epochs, n_train, world_size, **kwargs):
assert n_train % world_size == 0
np.random.seed(0)
recycle_schedule = np.random.randint(1, max_cycle+1, (n_epochs, n_train // world_size))
return torch.tensor(recycle_schedule)
def get_random_recycle(max_cycle, **kwargs):
N_cycle = np.random.randint(1, max_cycle+1)
return N_cycle
recycle_sampling = {
"random": get_random_recycle,
"by_batch": get_recycle_schedule_opt
}

View File

@@ -287,7 +287,7 @@ def xyz_t_to_frame_xyz_sm_mask(xyz_t, is_sm, atom_frames):
return xyz_t_frame
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
B,L,natoms = xyz_in.shape[:3]
#B,L,natoms = xyz_in.shape[:3]
frames = frame_indices[seq]
atoms = is_atom(seq)
if torch.any(atoms):
@@ -1117,6 +1117,7 @@ def atomize_protein(i_start, msa, xyz, mask, n_res_atomize=5):
lig_mask = residue_atomize_mask[r, a].repeat(nat_symm.shape[0], 1)
bond_feats = get_atomize_protein_bond_feats(i_start, msa, ra, n_res_atomize=n_res_atomize)
#HACK: use networkx graph to make the atom frames, correct implementation will include frames with "residue atoms"
# NOTE: REQUIRES NETWORKX < 3.0
G = nx.from_numpy_matrix(bond_feats.numpy())
frames = get_atom_frames(lig_seq, G)
@@ -1628,7 +1629,7 @@ def get_alt_query_ligand(chains, ligand_name, partners, lig_akeys, asmb_xfs):
return xyz_alt_s, mask_alt_s
def get_automorphs(mol, xyz_sm, mask_sm):
def get_automorphs(mol, xyz_sm, mask_sm, max_symm=1000):
"""Enumerate atom symmetry permutations."""
try:
automorphs = openbabel.vvpairUIntUInt()
@@ -1647,7 +1648,9 @@ def get_automorphs(mol, xyz_sm, mask_sm):
except Exception as e:
xyz_sm = xyz_sm[None]
mask_sm = mask_sm[None]
if xyz_sm.shape[0] > max_symm:
xyz_sm = xyz_sm[:max_symm]
mask_sm = mask_sm[:max_symm]
return xyz_sm, mask_sm
def expand_xyz_sm_to_ntotal(xyz_sm, mask_sm, N_symmetry=None):

View File

@@ -262,8 +262,7 @@ def make_full_graph(xyz, pair, idx):
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):

69
rf2aa/validate.py Normal file
View File

@@ -0,0 +1,69 @@
import torch
import os
import hydra
import pandas as pd
from torch.nn.parallel import DistributedDataParallel as DDP
from rf2aa.trainer_new import LegacyTrainer
from rf2aa.data.compose_dataset import compose_posebusters
class Validator(LegacyTrainer):
def evaluate_model(self, rank, world_size):
raise NotImplementedError()
def compose_dataset(self):
raise NotImplementedError()
def valid_step(self, inputs, n_cycle):
pass
class PoseBustersBenchmark(Validator):
def construct_dataset(self, rank, world_size):
return compose_posebusters(self.config.loader_params, rank, world_size)
def evaluate_model(self, rank, world_size):
world_size = torch.cuda.device_count()
if ('MASTER_ADDR' not in os.environ):
os.environ['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
if ('MASTER_PORT' not in os.environ):
os.environ['MASTER_PORT'] = '%d'%self.config.ddp_params.port
gpu = self.init_process_group(rank, world_size)
benchmark_loader = self.construct_dataset(rank, world_size)
# move global information to device
self.move_constants_to_device(gpu)
self.construct_model(device=gpu)
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
self.load_checkpoint(rank)
self.load_model()
self.model.eval()
records = []
for inputs in benchmark_loader:
item = inputs[-1]
with torch.no_grad():
loss, loss_dict = self.train_step(inputs, self.config.loader_params.maxcycle)
loss_dict["CHAINID"] = item["CHAINID"][0]
for k, v in loss_dict.items():
if torch.is_tensor(v):
loss_dict[k] = v.item()
records.append(loss_dict)
df = pd.DataFrame(records)
df.to_csv(f"{self.output_dir}/{self.config.experiment.name}_posebusters.csv")
torch.cuda.empty_cache()
@hydra.main(version_base=None, config_path='config/train')
def main(config):
benchmarker = PoseBustersBenchmark(config=config)
benchmarker.evaluate_model(0, 1)
if __name__ == "__main__":
main()