mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Refactor network
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -12,3 +12,5 @@ __pycache__/
|
||||
*/run_scripts/
|
||||
*/tests/
|
||||
unit_tests/
|
||||
ruff.toml
|
||||
*/scratch/
|
||||
|
||||
@@ -25,7 +25,7 @@ import torch.distributed as dist
|
||||
from abc import ABC
|
||||
from torch.utils.data import DataLoader, DistributedSampler, Dataset
|
||||
|
||||
from se3_transformer.runtime.utils import get_local_rank
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank
|
||||
|
||||
|
||||
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
|
||||
|
||||
@@ -31,9 +31,9 @@ from torch import Tensor
|
||||
from torch.utils.data import random_split, DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from se3_transformer.data_loading.data_module import DataModule
|
||||
from se3_transformer.model.basis import get_basis
|
||||
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
|
||||
from rf2aa.SE3Transformer.se3_transformer.data_loading.data_module import DataModule
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
|
||||
|
||||
|
||||
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
|
||||
|
||||
@@ -31,7 +31,7 @@ import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from se3_transformer.runtime.utils import degree_to_dim
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
|
||||
@@ -29,7 +29,7 @@ from typing import Dict
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from se3_transformer.runtime.utils import degree_to_dim
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
|
||||
|
||||
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
||||
|
||||
|
||||
@@ -30,10 +30,10 @@ from dgl.ops import edge_softmax
|
||||
from torch import Tensor
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||
from se3_transformer.model.layers.linear import LinearSE3
|
||||
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ from dgl import DGLGraph
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
||||
|
||||
|
||||
class ConvSE3FuseLevel(Enum):
|
||||
|
||||
@@ -29,7 +29,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
|
||||
|
||||
class LinearSE3(nn.Module):
|
||||
|
||||
@@ -29,7 +29,7 @@ import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
|
||||
|
||||
class NormSE3(nn.Module):
|
||||
|
||||
@@ -29,14 +29,14 @@ import torch.nn as nn
|
||||
from dgl import DGLGraph
|
||||
from torch import Tensor
|
||||
|
||||
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
||||
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
||||
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||
from se3_transformer.model.layers.linear import LinearSE3
|
||||
from se3_transformer.model.layers.norm import NormSE3
|
||||
from se3_transformer.model.layers.pooling import GPooling
|
||||
from se3_transformer.runtime.utils import str2bool
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis, update_basis_with_fused
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.attention import AttentionBlockSE3
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.norm import NormSE3
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.pooling import GPooling
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
|
||||
|
||||
class Sequential(nn.Sequential):
|
||||
|
||||
@@ -24,9 +24,9 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
from se3_transformer.data_loading import QM9DataModule
|
||||
from se3_transformer.model import SE3TransformerPooled
|
||||
from se3_transformer.runtime.utils import str2bool
|
||||
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
|
||||
|
||||
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from se3_transformer.runtime.loggers import Logger
|
||||
from se3_transformer.runtime.metrics import MeanAbsoluteError
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import Logger
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.metrics import MeanAbsoluteError
|
||||
|
||||
|
||||
class BaseCallback(ABC):
|
||||
|
||||
@@ -29,11 +29,11 @@ from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from se3_transformer.runtime import gpu_affinity
|
||||
from se3_transformer.runtime.arguments import PARSER
|
||||
from se3_transformer.runtime.callbacks import BaseCallback
|
||||
from se3_transformer.runtime.loggers import DLLogger
|
||||
from se3_transformer.runtime.utils import to_cuda, get_local_rank
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import BaseCallback
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import DLLogger
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -57,10 +57,10 @@ def evaluate(model: nn.Module,
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
|
||||
from se3_transformer.runtime.utils import init_distributed, seed_everything
|
||||
from se3_transformer.model import SE3TransformerPooled, Fiber
|
||||
from se3_transformer.data_loading import QM9DataModule
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import init_distributed, seed_everything
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled, Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||
import torch.distributed as dist
|
||||
import logging
|
||||
import sys
|
||||
|
||||
@@ -31,7 +31,7 @@ import torch.distributed as dist
|
||||
import wandb
|
||||
from dllogger import Verbosity
|
||||
|
||||
from se3_transformer.runtime.utils import rank_zero_only
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import rank_zero_only
|
||||
|
||||
|
||||
class Logger(ABC):
|
||||
|
||||
@@ -36,16 +36,16 @@ from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from se3_transformer.data_loading import QM9DataModule
|
||||
from se3_transformer.model import SE3TransformerPooled
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from se3_transformer.runtime import gpu_affinity
|
||||
from se3_transformer.runtime.arguments import PARSER
|
||||
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
||||
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
||||
PerformanceCallback
|
||||
from se3_transformer.runtime.inference import evaluate
|
||||
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
||||
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.inference import evaluate
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
||||
using_tensor_cores, increase_l2_fetch_granularity
|
||||
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@
|
||||
|
||||
import torch
|
||||
|
||||
from se3_transformer.model import SE3Transformer
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
|
||||
|
||||
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from icecream import ic
|
||||
import inspect
|
||||
|
||||
import sys, os
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
sys.path.insert(0,script_dir+'SE3Transformer')
|
||||
#sys.path.insert(0, '/home/ahern/projects/rf_diffusion/RF2-allatom/rf2aa/SE3Transformer') # jue commented this -- might need to uncomment
|
||||
|
||||
from rf2aa.util_module import init_lecun_normal_param
|
||||
|
||||
from se3_transformer.model import SE3Transformer
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
se3_transformer_path = inspect.getfile(SE3Transformer)
|
||||
se3_fiber_path = inspect.getfile(Fiber)
|
||||
#ic(se3_transformer_path, se3_fiber_path)
|
||||
assert 'rf2aa' in se3_transformer_path
|
||||
|
||||
class SE3TransformerWrapper(nn.Module):
    """Thin wrapper that builds an ``SE3Transformer`` from plain feature counts.

    The wrapper assembles the input / hidden / output / edge ``Fiber`` objects,
    instantiates the underlying ``SE3Transformer`` and applies a custom
    parameter initialisation (see :meth:`reset_parameter`).
    """

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=2,
                 num_edge_features=32):
        """Build the fibers and the wrapped network.

        Degree-1 features are only included in the input (resp. output) fiber
        when ``l1_in_features`` (resp. ``l1_out_features``) is strictly positive.
        """
        super().__init__()

        # Remembered so forward() / reset_parameter() know whether degree-1
        # inputs / outputs exist.
        self.l1_in = l1_in_features
        self.l1_out = l1_out_features

        fiber_edge = Fiber({0: num_edge_features})

        # Input fiber: degree 0 always, degree 1 only if requested.
        in_structure = {0: l0_in_features}
        if l1_in_features > 0:
            in_structure[1] = l1_in_features
        fiber_in = Fiber(in_structure)

        # The hidden fiber does not depend on the in/out configuration.
        fiber_hidden = Fiber.create(num_degrees, num_channels)

        # Output fiber: degree 0 always, degree 1 only if requested.
        out_structure = {0: l0_out_features}
        if l1_out_features > 0:
            out_structure[1] = l1_out_features
        fiber_out = Fiber(out_structure)

        self.se3 = SE3Transformer(
            num_layers=num_layers,
            fiber_in=fiber_in,
            fiber_hidden=fiber_hidden,
            fiber_out=fiber_out,
            num_heads=n_heads,
            channels_div=div,
            fiber_edge=fiber_edge,
            populate_edge="arcsin",
            final_layer="lin",
            use_layer_norm=True,
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise the parameters of the wrapped ``SE3Transformer``.

        Rules, applied by parameter name / shape:

        * names containing ``"bias"``          -> zeros
        * other 1-D parameters                 -> left untouched
        * names without ``"radial_func"``      -> ``init_lecun_normal_param``
        * ``"radial_func"`` and ``"net.6"``    -> zeros
        * remaining ``"radial_func"`` weights  -> Kaiming normal (ReLU gain)

        Afterwards the weights of the last graph module are zeroed for
        degree 0, and for degree 1 when degree-1 outputs are enabled.
        """
        for name, param in self.se3.named_parameters():
            if "bias" in name:
                nn.init.zeros_(param)
                continue
            if param.dim() == 1:
                # Non-bias vectors keep whatever initialisation they already have.
                continue
            if "radial_func" not in name:
                init_lecun_normal_param(param)
                continue
            # Radial-function weights: layers feeding a ReLU get Kaiming normal;
            # "net.6" (presumably the final radial layer -- TODO confirm) is zeroed.
            if "net.6" in name:
                nn.init.zeros_(param)
            else:
                nn.init.kaiming_normal_(param, nonlinearity='relu')

        # Zero-initialise the last layer.
        last_module = self.se3.graph_modules[-1]
        nn.init.zeros_(last_module.weights['0'])
        if self.l1_out > 0:
            nn.init.zeros_(last_module.weights['1'])

    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
        """Run the wrapped network.

        ``type_1_features`` is forwarded only when the wrapper was built with
        ``l1_in_features > 0``. ``edge_features`` is always wrapped under key
        ``'0'`` (even when it is ``None``). ``G`` is passed through unchanged
        (presumably a DGL graph -- confirm against the caller).
        """
        node_features = {'0': type_0_features}
        if self.l1_in > 0:
            node_features['1'] = type_1_features
        return self.se3(G, node_features, {'0': edge_features})
|
||||
@@ -22,7 +22,7 @@ DATASET_PARAMS = [
|
||||
"fraction_sm_compl_covale",
|
||||
"fraction_sm",
|
||||
"fraction_atomize_pdb",
|
||||
"fraction_atomize_complex"
|
||||
"fraction_atomize_complex",
|
||||
"fraction_sm_compl_asmb",
|
||||
"fraction_sm_compl_furthest_neg",
|
||||
"fraction_sm_compl_permuted_neg",
|
||||
@@ -51,6 +51,7 @@ DATASET_PARAMS = [
|
||||
"n_valid_sm_compl_furthest_neg",
|
||||
"n_valid_sm_compl_permuted_neg",
|
||||
"n_valid_sm_compl_docked_neg",
|
||||
"n_valid_atomize_complex",
|
||||
"n_valid_dude_actives",
|
||||
"n_valid_dude_inactives",
|
||||
"p_short_crop",
|
||||
@@ -268,6 +269,19 @@ def get_args(parser: Optional[argparse.ArgumentParser] = None, input_args: Optio
|
||||
help="probability of a given non-standard residue being atomized, rather "
|
||||
"than being converted to a standard equivalent [1.0]",
|
||||
)
|
||||
data_group.add_argument(
|
||||
"-batch_by_dataset",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Batch examples by dataset, e.g., all nodes receive an example from the same dataset. [False]",
|
||||
)
|
||||
data_group.add_argument(
|
||||
"-batch_by_length",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Batch examples by example length, e.g., all nodes receive a similarly-sized example. [False]",
|
||||
)
|
||||
|
||||
|
||||
# dataset parameters
|
||||
dataset_group = parser.add_argument_group("data loading parameters")
|
||||
@@ -4,12 +4,12 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import data
|
||||
from rf2aa.parsers import parse_a3m, parse_fasta, read_template_pdb
|
||||
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from rf2aa.data.parsers import parse_a3m, parse_fasta, read_template_pdb
|
||||
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
import util
|
||||
from collections import namedtuple
|
||||
from rf2aa.ffindex import *
|
||||
from rf2aa.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
|
||||
from rf2aa.data.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
|
||||
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_init_xyz
|
||||
from rf2aa.util_module import ComputeAllAtomCoords
|
||||
from rf2aa.chemical import NTOTAL, NTOTALDOFS, NAATOKENS
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from torch.utils import data
|
||||
from rf2aa.data_loader import (
|
||||
from rf2aa.data.data_loader import (
|
||||
get_train_valid_set,
|
||||
loader_pdb,
|
||||
loader_fb,
|
||||
@@ -25,11 +25,11 @@ from rf2aa.data_loader import (
|
||||
DatasetSMComplex,
|
||||
DatasetSMComplexAssembly,
|
||||
)
|
||||
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from rf2aa.loss import *
|
||||
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from rf2aa.loss.loss import *
|
||||
from rf2aa.util import *
|
||||
|
||||
from rf2aa.train_multi_EMA import Trainer, EMA, count_parameters
|
||||
from rf2aa.archived.train_multi_EMA import Trainer, EMA, count_parameters
|
||||
|
||||
# disable openbabel warnings
|
||||
from openbabel import openbabel as ob
|
||||
@@ -11,18 +11,18 @@ script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(script_dir)
|
||||
|
||||
import rf2aa
|
||||
import rf2aa.parsers as parsers
|
||||
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
import rf2aa.data.parsers as parsers
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule as RoseTTAFoldModule
|
||||
import rf2aa.util as util
|
||||
from rf2aa.util import *
|
||||
from rf2aa.loss import *
|
||||
from collections import namedtuple, OrderedDict
|
||||
from rf2aa.ffindex import *
|
||||
from rf2aa.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo, merge_a3m_hetero
|
||||
from rf2aa.data.data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo, merge_a3m_hetero
|
||||
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
|
||||
from rf2aa.util_module import XYZConverter
|
||||
from rf2aa.chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
|
||||
from rf2aa.parsers import read_templates, parse_multichain_fasta, parse_mixed_fasta
|
||||
from rf2aa.data.parsers import read_templates, parse_multichain_fasta, parse_mixed_fasta
|
||||
from rf2aa.memory import mem_report
|
||||
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs, update_symm_subs, get_symm_map, get_symmetry
|
||||
|
||||
@@ -72,8 +72,11 @@ def get_args():
|
||||
parser.add_argument("-n_cycle", type=int, default=10, help='number of recycles')
|
||||
parser.add_argument("-trunc_N", type=int, default=0, help='residues to truncate at N-term on MSA to match PDB')
|
||||
parser.add_argument("-trunc_C", type=int, default=0, help='residues to truncate at C-term on MSA to match PDB')
|
||||
parser.add_argument("-no_extra_l1", dest='use_extra_l1', default='True', action='store_false',
|
||||
help="Turn off chirality and LJ grad inputs to SE3 layers (for backwards compatibility).")
|
||||
parser.add_argument("-use_chiral_l1", type=bool, default=True,
|
||||
help="use chiral L1 features (for backwards compatibility)")
|
||||
parser.add_argument("-use_lj_l1", type=bool, default=True,
|
||||
help="use LJ L1 features (for backwards compatibility)")
|
||||
|
||||
parser.add_argument("-no_atom_frames", dest='use_atom_frames', default='True', action='store_false',
|
||||
help="Turn off l1 features from atom frames in SE3 layers (for backwards compatibility).")
|
||||
|
||||
@@ -370,8 +373,11 @@ class Predictor():
|
||||
read_data(args.db+'_pdb.ffdata'))
|
||||
|
||||
# define model & load model
|
||||
MODEL_PARAM['use_extra_l1'] = args.use_extra_l1
|
||||
MODEL_PARAM['use_chiral_l1'] = args.use_chiral_l1
|
||||
MODEL_PARAM['use_lj_l1'] = args.use_lj_l1
|
||||
MODEL_PARAM['use_atom_frames'] = args.use_atom_frames
|
||||
MODEL_PARAM['use_same_chain'] = True
|
||||
MODEL_PARAM['recycling_type'] = 'all'
|
||||
self.model = RoseTTAFoldModule(
|
||||
**MODEL_PARAM,
|
||||
aamask = util.allatom_mask.to(self.device),
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
from torch.utils import data
|
||||
|
||||
# from chemical import NFRAMES
|
||||
from rf2aa.data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params, loader_atomize_pdb
|
||||
from rf2aa.data.data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params, loader_atomize_pdb
|
||||
from rf2aa.kinematics import xyz_to_c6d, xyz_to_t2d
|
||||
from rf2aa.chemical import num2aa, aa2elt, aa2num, aabonds,aa2long, aabtypes, atomized_protein_frames
|
||||
from rf2aa.loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss, calc_chiral_loss
|
||||
@@ -263,9 +263,9 @@ class LossTestCase(unittest.TestCase):
|
||||
print(true_crds)
|
||||
break
|
||||
|
||||
def test_chiral_loss(self):
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats, chirals, task, item in self.valid_sm_compl_loader:
|
||||
print(calc_chiral_loss(true_crds[0][None],chirals)
|
||||
def test_chiral_loss(self):
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats, chirals, task, item in self.valid_sm_compl_loader:
|
||||
print(calc_chiral_loss(true_crds[0][None],chirals))
|
||||
|
||||
class DataLoaderTestCase(unittest.TestCase):
|
||||
|
||||
@@ -19,7 +19,7 @@ from tqdm import tqdm
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(script_dir)
|
||||
|
||||
from rf2aa.data_loader import (
|
||||
from rf2aa.data.data_loader import (
|
||||
get_train_valid_set, loader_pdb, loader_fb, loader_complex,
|
||||
loader_na_complex, loader_distil_tf, loader_tf_complex, loader_dna_rna,
|
||||
loader_sm, loader_atomize_pdb, loader_atomize_complex,
|
||||
@@ -29,11 +29,11 @@ from rf2aa.data_loader import (
|
||||
DistilledDataset, DistributedWeightedSampler, unbatch_item
|
||||
)
|
||||
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, xyz_to_bbtor
|
||||
from rf2aa.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule as RoseTTAFoldModule
|
||||
from rf2aa.loss import *
|
||||
from rf2aa.util import *
|
||||
from rf2aa.util_module import XYZConverter
|
||||
from rf2aa.scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup
|
||||
from rf2aa.training.scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup
|
||||
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs, get_symm_map
|
||||
|
||||
from rf2aa.chemical import load_pdb_ideal_sdf_strings
|
||||
@@ -55,6 +55,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# limit thread counts
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||
|
||||
## To reproduce errors
|
||||
import random
|
||||
@@ -80,6 +81,20 @@ def add_weight_decay(model, l2_coeff):
|
||||
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
def get_recycle_schedule(n_epochs, n_train, n_max, world_size):
    """Draw the number of recycles for every example of every epoch.

    Returns a tensor with one row per epoch and ``n_train // world_size``
    entries per row, each drawn from ``np.random.randint(1, n_max)`` (so
    ``n_max`` itself is never produced).

    NOTE: the global NumPy RNG is reseeded with 0 on every call. This makes
    the schedule identical on every process, but it also resets NumPy's
    global random state for the caller.

    ``n_train`` must be divisible by ``world_size``.
    """
    assert n_train % world_size == 0
    per_rank = n_train // world_size

    # Fixed seed: every process must compute the very same schedule.
    np.random.seed(0)

    epoch_rows = []
    for _ in range(n_epochs):
        draws = []
        for _ in range(per_rank):
            draws.append(np.random.randint(1, n_max))
        epoch_rows.append(torch.tensor(draws))
    return torch.stack(epoch_rows, dim=0)
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(self, model, decay):
|
||||
super().__init__()
|
||||
@@ -97,13 +112,13 @@ class EMA(nn.Module):
|
||||
print("EMA update should only be called during training", file=stderr, flush=True)
|
||||
return
|
||||
|
||||
model_params = OrderedDict(self.model.named_parameters())
|
||||
self.model_params = OrderedDict(self.model.named_parameters())
|
||||
shadow_params = OrderedDict(self.shadow.named_parameters())
|
||||
|
||||
# check if both model contains the same set of keys
|
||||
assert model_params.keys() == shadow_params.keys()
|
||||
assert self.model_params.keys() == shadow_params.keys()
|
||||
|
||||
for name, param in model_params.items():
|
||||
for name, param in self.model_params.items():
|
||||
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
|
||||
if param.requires_grad:
|
||||
@@ -128,7 +143,7 @@ class EMA(nn.Module):
|
||||
class Trainer():
|
||||
def __init__(self, model_name='BFF', checkpoint_path=None,
|
||||
n_epoch=100, step_lr=100, lr=1.0e-4, l2_coeff=1.0e-2, port=None, interactive=False,
|
||||
model_param={}, loader_param={}, loss_param={}, dataset_param={}, batch_size=1,
|
||||
model_params={}, loader_param={}, loss_param={}, dataset_param={}, batch_size=1,
|
||||
accum_step=1, maxcycle=4, eval=False, out_dir=None, wandb_prefix=None,
|
||||
model_dir='models/', dataloader_kwargs = {}, **kwargs):
|
||||
|
||||
@@ -140,7 +155,7 @@ class Trainer():
|
||||
self.port = port
|
||||
self.interactive = interactive
|
||||
self.eval = eval
|
||||
self.model_param = model_param
|
||||
self.model_params = model_params
|
||||
self.loader_param = loader_param
|
||||
self.loss_param = loss_param
|
||||
self.dataset_param = dataset_param
|
||||
@@ -158,7 +173,9 @@ class Trainer():
|
||||
self.debug_mode = kwargs.get("debug_mode", False)
|
||||
self.skip_valid = kwargs.get("skip_valid", 1)
|
||||
self.start_epoch = kwargs.get("start_epoch", 0)
|
||||
|
||||
self.n_train = self.dataset_param["n_train"]
|
||||
self.maxcycle = maxcycle
|
||||
self.recycle_schedule=get_recycle_schedule(self.n_epoch, self.n_train, self.maxcycle+1, world_size)
|
||||
# for all-atom str loss
|
||||
#self.ti_dev = torsion_indices
|
||||
#self.ti_flip = torsion_can_flip
|
||||
@@ -184,7 +201,6 @@ class Trainer():
|
||||
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
|
||||
self.active_fn = nn.Softmax(dim=1)
|
||||
|
||||
self.maxcycle = maxcycle
|
||||
|
||||
self.pdb_counter=0
|
||||
|
||||
@@ -468,8 +484,9 @@ class Trainer():
|
||||
else:
|
||||
l_fape_sm_protein = torch.tensor(0).to(gpu)
|
||||
|
||||
frac_sm = torch.sum(frame_mask_BB_sm[:,res_mask[0]])/ torch.sum(frame_mask_BB[:,res_mask[0]])
|
||||
inter_fape = frac_sm*l_fape_protein_sm + (1.0-frac_sm)*l_fape_sm_protein
|
||||
#frac_sm = torch.sum(frame_mask_BB_sm[:,res_mask[0]])/ torch.sum(frame_mask_BB[:,res_mask[0]])
|
||||
#inter_fape = frac_sm*l_fape_protein_sm + (1.0-frac_sm)*l_fape_sm_protein
|
||||
inter_fape = l_fape_sm_protein
|
||||
bb_l_fape_inter = (w_bb_fape*inter_fape).sum()
|
||||
tot_loss += 0.5*w_inter_fape*bb_l_fape_inter
|
||||
else:
|
||||
@@ -601,25 +618,27 @@ class Trainer():
|
||||
tot_loss += w_bond*bond_loss
|
||||
loss_dict['bond_geom'] = bond_loss.detach()
|
||||
|
||||
if (pred_allatom.shape[0] > 1):
|
||||
bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
|
||||
if w_bond > 0.0:
|
||||
tot_loss += w_bond*bond_loss.mean()
|
||||
loss_dict['clash_loss'] = ( bond_loss.detach() )
|
||||
else:
|
||||
bond_loss = torch.tensor(0).to(gpu)
|
||||
loss_dict['bond_loss'] = bond_loss.detach()
|
||||
# if (pred_allatom.shape[0] > 1):
|
||||
# bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
|
||||
# if w_bond > 0.0:
|
||||
# tot_loss += w_bond*bond_loss.mean()
|
||||
# loss_dict['clash_loss'] = ( bond_loss.detach() )
|
||||
# else:
|
||||
# bond_loss = torch.tensor(0).to(gpu)
|
||||
# loss_dict['bond_loss'] = bond_loss.detach()
|
||||
|
||||
# clash [use all atoms not just those in native]
|
||||
clash_loss = calc_lj(
|
||||
seq[0], pred_allatom,
|
||||
self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
|
||||
lj_lin=lj_lin
|
||||
)
|
||||
|
||||
# clash_loss = calc_lj(
|
||||
# seq[0], pred_allatom,
|
||||
# self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
|
||||
# lj_lin=lj_lin
|
||||
# )
|
||||
clash_loss, num_violations = calc_l1_clash_loss(pred_allatom, seq[0],\
|
||||
self.aamask, bond_feats, dist_matrix, self.ljlk_parameters, \
|
||||
self.lj_correction_parameters, self.num_bonds)
|
||||
if w_clash > 0.0:
|
||||
tot_loss += w_clash*clash_loss.mean()
|
||||
loss_dict['clash_loss'] = clash_loss[0].detach()
|
||||
loss_dict['clash_loss'] = clash_loss.detach()
|
||||
if torch.any(mask_BB[0]):
|
||||
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
|
||||
pred=pred_allatom[:,mask_BB[0]],
|
||||
@@ -761,6 +780,7 @@ class Trainer():
|
||||
|
||||
def load_model(self, model, model_name, rank, suffix='last', checkpoint_path=None, resume_train=False,
|
||||
optimizer=None, scheduler=None, scaler=None):
|
||||
torch.cuda.empty_cache()
|
||||
if self.debug_mode:
|
||||
return -1, 99999999.9
|
||||
if checkpoint_path==None:
|
||||
@@ -849,15 +869,27 @@ class Trainer():
|
||||
print("Running in DEBUG mode...")
|
||||
world_size = 1
|
||||
rank = 0
|
||||
|
||||
self.train_model(rank, world_size)
|
||||
if (not self.interactive and "SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
|
||||
world_size = int(os.environ["SLURM_NTASKS"])
|
||||
rank = int (os.environ["SLURM_PROCID"])
|
||||
print ("Launched from slurm", rank, world_size)
|
||||
self.train_model(rank, world_size)
|
||||
if torch.cuda.device_count() == int(os.environ["SLURM_NTASKS"]):
|
||||
# If launching 1 job per node
|
||||
world_size = int(os.environ["SLURM_NTASKS"])
|
||||
rank = int (os.environ["SLURM_PROCID"])
|
||||
print ("Launched from slurm", rank, world_size)
|
||||
self.train_model(rank, world_size)
|
||||
elif torch.cuda.device_count() > 1 and int(os.environ["SLURM_NTASKS"]) == 1:
|
||||
# If launching all jobs from same node
|
||||
world_size = torch.cuda.device_count()
|
||||
print(f"Spawning all jobs from one node. World size: {world_size}")
|
||||
mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
|
||||
else:
|
||||
# Raise error, since we either need one job per node or all jobs from one node
|
||||
raise RuntimeError("Invalid distributed processing combination of nodes/tasks/gpus for SLURM.")
|
||||
else:
|
||||
print ("Launched from interactive")
|
||||
world_size = torch.cuda.device_count()
|
||||
|
||||
if world_size == 1:
|
||||
# No need for multiple processes with 1 GPU
|
||||
self.train_model(0, world_size)
|
||||
@@ -906,7 +938,7 @@ class Trainer():
|
||||
)
|
||||
all_param = {}
|
||||
all_param.update(self.loader_param)
|
||||
all_param.update(self.model_param)
|
||||
all_param.update(self.self.model_params)
|
||||
all_param.update(self.loss_param)
|
||||
|
||||
wandb.config = all_param
|
||||
@@ -914,7 +946,7 @@ class Trainer():
|
||||
|
||||
#print ("running ddp on rank %d, world_size %d"%(rank, world_size))
|
||||
gpu = rank % torch.cuda.device_count()
|
||||
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
|
||||
dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device("cuda:%d"%gpu)
|
||||
|
||||
# Get ligand dictionary. This is used for loading negative examples.
|
||||
@@ -1002,7 +1034,9 @@ class Trainer():
|
||||
fractions=OrderedDict([(k, self.dataset_param['fraction_'+k]) for k in train_dict]),
|
||||
num_replicas=world_size,
|
||||
rank=rank,
|
||||
replacement=True
|
||||
lengths=self.loader_param["EXAMPLE_LENGTHS"],
|
||||
batch_by_dataset=self.loader_param["BATCH_BY_DATASET"],
|
||||
batch_by_length=self.loader_param["BATCH_BY_LENGTH"],
|
||||
)
|
||||
|
||||
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **self.dataloader_kwargs)
|
||||
@@ -1048,11 +1082,6 @@ class Trainer():
|
||||
loader_distil_tf, valid_dict['distil_tf'],
|
||||
self.loader_param, negative=False, native_NA_frac=0.0
|
||||
),
|
||||
atomize_pdb = Dataset(
|
||||
valid_ID_dict['atomize_pdb'][:self.dataset_param['n_valid_atomize_pdb']],
|
||||
loader_atomize_pdb, valid_dict['atomize_pdb'],
|
||||
self.loader_param, homo, p_homo_cut=-1.0, n_res_atomize=3, flank=0, p_short_crop=-1.0
|
||||
),
|
||||
metal_compl = DatasetSMComplexAssembly(
|
||||
valid_ID_dict['metal_compl'][:self.dataset_param['n_valid_metal_compl']],
|
||||
loader_sm_compl_assembly, valid_dict['metal_compl'],
|
||||
@@ -1302,10 +1331,16 @@ class Trainer():
|
||||
self.cb_len = self.cb_len.to(gpu)
|
||||
self.cb_ang = self.cb_ang.to(gpu)
|
||||
self.cb_tor = self.cb_tor.to(gpu)
|
||||
self.model_params['use_chiral_l1'] = True
|
||||
self.model_params['use_lj_l1'] = False
|
||||
self.model_params['use_atom_frames'] = True
|
||||
self.model_params['use_same_chain'] = False
|
||||
self.model_params['recycling_type'] = 'msa_pair'
|
||||
self.model_params.pop('use_extra_l1')
|
||||
|
||||
# define model
|
||||
model = EMA(RoseTTAFoldModule(
|
||||
**self.model_param,
|
||||
**self.model_params,
|
||||
aamask=self.aamask,
|
||||
atom_type_index=self.atom_type_index,
|
||||
ljlk_parameters=self.ljlk_parameters,
|
||||
@@ -1553,9 +1588,10 @@ class Trainer():
|
||||
network_input['symmRs'] = symmRs
|
||||
network_input['symmmeta'] = symmmeta
|
||||
|
||||
mask_recycle = mask_prev[:,:,:3].bool().all(dim=-1)
|
||||
mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
|
||||
mask_recycle = same_chain.float()*mask_recycle.float()
|
||||
#mask_recycle = mask_prev[:,:,:3].bool().all(dim=-1)
|
||||
#mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
|
||||
#mask_recycle = same_chain.float()*mask_recycle.float()
|
||||
mask_recycle = None
|
||||
return task, item, network_input, xyz_prev, alpha_prev, mask_recycle, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
|
||||
|
||||
def _get_model_input(self, network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
|
||||
@@ -1565,7 +1601,11 @@ class Trainer():
|
||||
input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
|
||||
else:
|
||||
input_i[key] = network_input[key]
|
||||
msa_prev, pair_prev, xyz_prev, alpha, mask_recycle = output_i
|
||||
|
||||
L = input_i["msa_latent"].shape[2]
|
||||
msa_prev, pair_prev, _, alpha, mask_recycle = output_i
|
||||
xyz_prev = INIT_CRDS.reshape(1,1,NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
|
||||
|
||||
input_i['msa_prev'] = msa_prev
|
||||
input_i['pair_prev'] = pair_prev
|
||||
input_i['xyz'] = xyz_prev
|
||||
@@ -1579,7 +1619,7 @@ class Trainer():
|
||||
self, output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label, ctrid=0
|
||||
):
|
||||
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _ = output_i
|
||||
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
|
||||
|
||||
if (symmRs is not None):
|
||||
#print ('a', pred_crds.shape, true_crds.shape, mask_crds.shape)
|
||||
@@ -1682,7 +1722,10 @@ class Trainer():
|
||||
|
||||
counter = 0
|
||||
|
||||
for inputs in train_loader:
|
||||
for train_idx, inputs in enumerate(train_loader):
|
||||
regression = {
|
||||
"dataloader_inputs": inputs
|
||||
}
|
||||
(
|
||||
task, item, network_input, xyz_prev, alpha_prev, mask_recycle,
|
||||
true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
|
||||
@@ -1691,9 +1734,12 @@ class Trainer():
|
||||
|
||||
counter += 1
|
||||
|
||||
N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling
|
||||
#N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling
|
||||
# all examples in a pseudo batch have the same recycle
|
||||
N_cycle = self.recycle_schedule[epoch, train_idx] # number of recycling
|
||||
|
||||
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||
N_cycle = 1
|
||||
for i_cycle in range(N_cycle):
|
||||
with ExitStack() as stack:
|
||||
if i_cycle < N_cycle -1:
|
||||
@@ -1707,9 +1753,8 @@ class Trainer():
|
||||
return_raw=False
|
||||
use_checkpoint=True
|
||||
input_i = self._get_model_input(network_input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
|
||||
|
||||
output_i = ddp_model(**input_i)
|
||||
|
||||
|
||||
if i_cycle < N_cycle - 1:
|
||||
continue
|
||||
loss, loss_dict, acc_s, _, true_crds, pred_allatom, res_mask = self._get_loss_and_misc(
|
||||
@@ -1720,10 +1765,21 @@ class Trainer():
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
len(train_loader)*rank+counter
|
||||
)
|
||||
regression.update( {
|
||||
"model_input": input_i,
|
||||
"model_output": output_i,
|
||||
"loss": loss,
|
||||
"loss_dict": loss_dict
|
||||
})
|
||||
torch.save(
|
||||
regression,
|
||||
"test_pickles/model_io.pt"
|
||||
)
|
||||
loss = loss / self.ACCUM_STEP
|
||||
scaler.scale(loss).backward()
|
||||
if counter%self.ACCUM_STEP == 0:
|
||||
# gradient clipping
|
||||
print("accumulation")
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.2)
|
||||
scaler.step(optimizer)
|
||||
@@ -1734,7 +1790,15 @@ class Trainer():
|
||||
if not skip_lr_sched:
|
||||
scheduler.step()
|
||||
ddp_model.module.update() # apply EMA
|
||||
torch.save({'epoch': epoch,
|
||||
#'model_state_dict': ddp_model.state_dict(),
|
||||
'model_state_dict': ddp_model.module.shadow.state_dict(),
|
||||
'final_state_dict': ddp_model.module.model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'scaler_state_dict': scaler.state_dict()},"test_pickles/optimizer_regression.pt")
|
||||
|
||||
raise UnboundLocalError("stopping after 1 iteration")
|
||||
item_ = unbatch_item(item) # remove nested lists to make more readable when printed
|
||||
save_pdbs = False
|
||||
if torch.isnan(loss):
|
||||
@@ -1886,6 +1950,7 @@ class Trainer():
|
||||
return_raw=False
|
||||
|
||||
input_i = self._get_model_input(network_input, output_i, i_cycle, gpu, return_raw=return_raw)
|
||||
|
||||
output_i = ddp_model(**input_i)
|
||||
|
||||
if i_cycle < N_cycle - 1:
|
||||
@@ -2233,7 +2298,7 @@ class Trainer():
|
||||
|
||||
if __name__ == "__main__":
|
||||
from arguments import get_args
|
||||
args, dataset_param, model_param, loader_param, loss_param = get_args()
|
||||
args, dataset_param, model_params, loader_param, loss_param = get_args()
|
||||
|
||||
if args.debug:
|
||||
DEBUG = True
|
||||
@@ -2250,7 +2315,8 @@ if __name__ == "__main__":
|
||||
"num_workers": args.dataloader_num_workers,
|
||||
"pin_memory": not args.dont_pin_memory,
|
||||
}
|
||||
|
||||
|
||||
world_size = torch.cuda.device_count()
|
||||
trainer_object = Trainer(
|
||||
model_name=args.model_name,
|
||||
checkpoint_path = args.checkpoint_path,
|
||||
@@ -2259,7 +2325,7 @@ if __name__ == "__main__":
|
||||
lr=args.lr,
|
||||
l2_coeff=1.0e-2,
|
||||
port=args.port,
|
||||
model_param=model_param,
|
||||
model_params=model_params,
|
||||
loader_param=loader_param,
|
||||
loss_param=loss_param,
|
||||
batch_size=args.batch_size,
|
||||
@@ -2274,5 +2340,6 @@ if __name__ == "__main__":
|
||||
dataloader_kwargs=dataloader_kwargs,
|
||||
debug_mode=args.debug,
|
||||
skip_valid=args.skip_valid,
|
||||
world_size=world_size
|
||||
)
|
||||
trainer_object.run_model_training(torch.cuda.device_count())
|
||||
trainer_object.run_model_training(world_size)
|
||||
129
rf2aa/config/train/base.yaml
Normal file
129
rf2aa/config/train/base.yaml
Normal file
@@ -0,0 +1,129 @@
|
||||
experiment:
|
||||
name: rf2a-fd4-20231117
|
||||
n_epoch: 800
|
||||
output_dir: null
|
||||
trainer: "legacy"
|
||||
|
||||
model:
|
||||
embedding: null
|
||||
blocks: null
|
||||
refinment: null
|
||||
auxiliary_predictors: {}
|
||||
legacy_model: null
|
||||
dataset_params:
|
||||
fraction_pdb: 0
|
||||
fraction_fb: 0
|
||||
fraction_compl: 0
|
||||
fraction_neg_compl: 0
|
||||
fraction_na_compl: 0
|
||||
fraction_neg_na_compl: 0
|
||||
fraction_distil_tf: 0
|
||||
fraction_tf: 0
|
||||
fraction_neg_tf: 0
|
||||
fraction_rna: 0
|
||||
fraction_dna: 0
|
||||
fraction_sm_compl: 1
|
||||
fraction_metal_compl: 0
|
||||
fraction_sm_compl_multi: 0
|
||||
fraction_sm_compl_covale: 0
|
||||
fraction_sm: 0
|
||||
fraction_atomize_pdb: 0
|
||||
fraction_atomize_complex: 0
|
||||
fraction_sm_compl_asmb: 0
|
||||
fraction_sm_compl_furthest_neg: 0
|
||||
fraction_sm_compl_permuted_neg: 0
|
||||
fraction_sm_compl_docked_neg: 0
|
||||
n_train: 256000
|
||||
n_valid_pdb: 0
|
||||
n_valid_homo: 0
|
||||
n_valid_dslf: 0
|
||||
n_valid_compl: 0
|
||||
n_valid_neg_compl: 0
|
||||
n_valid_na_compl: 0
|
||||
n_valid_neg_na_compl: 0
|
||||
n_valid_distil_tf: 0
|
||||
n_valid_tf: 0
|
||||
n_valid_neg_tf: 0
|
||||
n_valid_rna: 0
|
||||
n_valid_dna: 0
|
||||
n_valid_sm_compl: 0
|
||||
n_valid_metal_compl: 0
|
||||
n_valid_sm_compl_multi: 0
|
||||
n_valid_sm_compl_covale: 0
|
||||
n_valid_sm_compl_strict: 0
|
||||
n_valid_sm: 0
|
||||
n_valid_atomize_pdb: 0
|
||||
n_valid_atomize_complex: 0
|
||||
n_valid_sm_compl_asmb: 0
|
||||
n_valid_sm_compl_furthest_neg: 0
|
||||
n_valid_sm_compl_permuted_neg: 0
|
||||
n_valid_sm_compl_docked_neg: 0
|
||||
n_valid_dude_actives: 0
|
||||
n_valid_dude_inactives: 0
|
||||
p_short_crop: 0
|
||||
p_dslf_crop: 0
|
||||
dslf_fb_upsample: 1
|
||||
|
||||
loader_params:
|
||||
maxseq: 1024
|
||||
maxtoken: 1024
|
||||
maxlat: 128
|
||||
crop: 256
|
||||
rescut: 4.5
|
||||
mintplt: 1
|
||||
maxtplt: 4
|
||||
seqid: 150.0
|
||||
p_msa_mask: 0.15
|
||||
maxcycle: 4
|
||||
nres_atomize_min: 3
|
||||
nres_atomize_max: 5
|
||||
atomize_flank: 0
|
||||
p_metal: 1
|
||||
p_atomize_modres: 1
|
||||
batch_by_dataset: True
|
||||
batch_by_length: True
|
||||
dataloader_kwargs:
|
||||
shuffle: False
|
||||
num_workers: 0
|
||||
pin_memory: True
|
||||
|
||||
ddp_params:
|
||||
accum: 1
|
||||
batch_size: 1
|
||||
port: 12435
|
||||
|
||||
training_params:
|
||||
resume_train: False
|
||||
EMA: 0.99
|
||||
weight_decay: 0.01
|
||||
learning_rate: .001
|
||||
learning_rate_schedule:
|
||||
num_warmup_steps: 0
|
||||
num_steps_decay: 5000
|
||||
decay_rate: 0.95
|
||||
grad_clip: 0.2
|
||||
use_amp: False
|
||||
|
||||
loss_param:
|
||||
w_dist: 1.0
|
||||
w_str: 1.0
|
||||
w_inter_fape: 0.0
|
||||
w_lig_fape: 0.0
|
||||
w_lddt: 0.1
|
||||
w_aa: 3.0
|
||||
w_bond: 0.0
|
||||
w_bind: 0.0
|
||||
binder_loss_label_smoothing: 0.0
|
||||
w_clash: 0.0
|
||||
w_atom_bond: 0.0
|
||||
w_skip_bond: 0.0
|
||||
w_rigid: 0.0
|
||||
w_hb: 0.0
|
||||
w_pae: 0.05
|
||||
w_pde: 0.05
|
||||
lj_lin: 0.75
|
||||
|
||||
log_params:
|
||||
log_every_n_examples: 1
|
||||
|
||||
eval_params: null
|
||||
61
rf2aa/config/train/legacy_train.yaml
Normal file
61
rf2aa/config/train/legacy_train.yaml
Normal file
@@ -0,0 +1,61 @@
|
||||
defaults:
|
||||
- base
|
||||
|
||||
experiment:
|
||||
name: frank_model-10rec
|
||||
output_dir: null
|
||||
|
||||
legacy_model_param:
|
||||
n_extra_block: 4
|
||||
n_main_block: 32
|
||||
n_ref_block: 4
|
||||
n_finetune_block: 0
|
||||
d_msa: 256
|
||||
d_msa_full: 64
|
||||
d_pair: 192
|
||||
d_templ: 64
|
||||
n_head_msa: 8
|
||||
n_head_pair: 6
|
||||
n_head_templ: 4
|
||||
d_hidden_templ: 64
|
||||
p_drop: 0.0
|
||||
use_chiral_l1: True
|
||||
use_lj_l1: True
|
||||
use_atom_frames: True
|
||||
recycling_type: "all"
|
||||
use_same_chain: True
|
||||
lj_lin: 0.75
|
||||
SE3_param:
|
||||
num_layers: 1
|
||||
num_channels: 32
|
||||
num_degrees: 2
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 3
|
||||
l1_out_features: 2
|
||||
num_edge_features: 64
|
||||
n_heads: 4
|
||||
div: 4
|
||||
SE3_ref_param:
|
||||
num_layers: 2
|
||||
num_channels: 32
|
||||
num_degrees: 2
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 3
|
||||
l1_out_features: 2
|
||||
num_edge_features: 64
|
||||
n_heads: 4
|
||||
div: 4
|
||||
|
||||
ddp_params:
|
||||
port: 12375
|
||||
batch_size: 1
|
||||
|
||||
loader_params:
|
||||
crop: 800
|
||||
maxcycle: 10
|
||||
maxlat: 256
|
||||
|
||||
eval_params:
|
||||
checkpoint_path: "/home/dimaio/RF2-allatom/rf2aa/models/RF2_26a_last.pt"
|
||||
119
rf2aa/config/train/rf2aa.yaml
Normal file
119
rf2aa/config/train/rf2aa.yaml
Normal file
@@ -0,0 +1,119 @@
|
||||
defaults:
|
||||
- base
|
||||
|
||||
experiment:
|
||||
name: rf2aa-reproduction
|
||||
trainer: "composed"
|
||||
|
||||
model:
|
||||
global_params:
|
||||
d_msa: 256
|
||||
d_msa_full: 64
|
||||
d_pair: 192
|
||||
d_state: 64
|
||||
|
||||
embedding:
|
||||
rf2aa:
|
||||
params:
|
||||
p_drop: 0.15
|
||||
d_templ: 64
|
||||
n_head_templ: 4
|
||||
d_hidden_templ: 64
|
||||
templ_p_drop: 0.25
|
||||
symmetrize_repeats: False
|
||||
repeat_length: null
|
||||
symmsub_k: null
|
||||
sym_method: null
|
||||
main_block: null
|
||||
copy_main_block_template: False
|
||||
additional_dt1d: 0
|
||||
recycling_type: "msa_pair"
|
||||
use_same_chain: False
|
||||
blocks:
|
||||
RF2aa_full:
|
||||
num_blocks: 4
|
||||
params:
|
||||
d_rbf: 64
|
||||
p_drop_row: 0.25
|
||||
p_drop_pair: 0.25
|
||||
msa_transition_drop: 0.0
|
||||
outer_product_channels: 16
|
||||
p_drop_outer_product: 0.0
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 32
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 8
|
||||
structure_bias_gate_channels: 16
|
||||
structure_bias_channels: 64
|
||||
n_se3_layers: 1
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 6
|
||||
l1_out_features: 2
|
||||
n_se3_edge_features: 64
|
||||
sc_pred_d_hidden: 128
|
||||
sc_pred_p_drop: 0.0
|
||||
RF2aa:
|
||||
num_blocks: 32
|
||||
params:
|
||||
d_rbf: 64
|
||||
p_drop_row: 0.25
|
||||
p_drop_pair: 0.25
|
||||
msa_transition_drop: 0.0
|
||||
outer_product_channels: 16
|
||||
p_drop_outer_product: 0.0
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 32
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 32
|
||||
structure_bias_gate_channels: 16
|
||||
structure_bias_channels: 64
|
||||
n_se3_layers: 1
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 6
|
||||
l1_out_features: 2
|
||||
n_se3_edge_features: 64
|
||||
sc_pred_d_hidden: 128
|
||||
sc_pred_p_drop: 0.0
|
||||
refinement:
|
||||
local:
|
||||
params:
|
||||
num_iterations: 4
|
||||
d_rbf: 64
|
||||
n_se3_layers: 2
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 6
|
||||
l1_out_features: 2
|
||||
n_se3_edge_features: 64
|
||||
top_k: 64
|
||||
sc_pred_d_hidden: 128
|
||||
sc_pred_p_drop: 0.15
|
||||
|
||||
auxiliary_predictors:
|
||||
c6d:
|
||||
n_feat: 192
|
||||
input_feature: "pair"
|
||||
mlm:
|
||||
n_feat: 256
|
||||
input_feature: "msa"
|
||||
pae:
|
||||
n_feat: 192
|
||||
input_feature: "pair"
|
||||
plddt:
|
||||
n_feat: 64
|
||||
input_feature: "state"
|
||||
|
||||
111
rf2aa/config/train/rf_with_gradients.yaml
Normal file
111
rf2aa/config/train/rf_with_gradients.yaml
Normal file
@@ -0,0 +1,111 @@
|
||||
defaults:
|
||||
- base
|
||||
|
||||
experiment:
|
||||
name: rf2a-fd4-20240116
|
||||
trainer: "composed"
|
||||
|
||||
model:
|
||||
global_params:
|
||||
d_msa: 256
|
||||
d_msa_full: 64
|
||||
d_pair: 192
|
||||
d_state: 32
|
||||
|
||||
embedding:
|
||||
rf2aa_nostate:
|
||||
params:
|
||||
p_drop: 0.15
|
||||
d_templ: 64
|
||||
n_head_templ: 4
|
||||
d_hidden_templ: 64
|
||||
templ_p_drop: 0.25
|
||||
recycling_type: "msa_pair"
|
||||
use_same_chain: False
|
||||
|
||||
blocks:
|
||||
RF2_withgradients_full:
|
||||
num_blocks: 4
|
||||
params:
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 64
|
||||
p_drop_row: 0.25
|
||||
p_drop_col: 0.25
|
||||
structure_bias_type: "ungated"
|
||||
structure_bias_channels: 64
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 192
|
||||
outer_product_channels: 16
|
||||
intermediate_seq_channels: 192 #this is from AF2
|
||||
n_se3_layers: 1
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 32
|
||||
l1_in_features: 1
|
||||
l1_out_features: 1
|
||||
n_se3_edge_features: 64
|
||||
RF2_withgradients:
|
||||
num_blocks: 36
|
||||
params:
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 256
|
||||
p_drop_row: 0.25
|
||||
p_drop_col: 0.25
|
||||
structure_bias_type: "ungated"
|
||||
structure_bias_channels: 64
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 192
|
||||
outer_product_channels: 16
|
||||
intermediate_seq_channels: 192 #this is from AF2
|
||||
n_se3_layers: 1
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 32
|
||||
l1_in_features: 1
|
||||
l1_out_features: 1
|
||||
n_se3_edge_features: 64
|
||||
refinement:
|
||||
local:
|
||||
params:
|
||||
num_iterations: 4
|
||||
d_rbf: 64
|
||||
n_se3_layers: 2
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 6
|
||||
l1_out_features: 2
|
||||
n_se3_edge_features: 64
|
||||
top_k: 64
|
||||
sc_pred_d_hidden: 128
|
||||
sc_pred_p_drop: 0.15
|
||||
|
||||
auxiliary_predictors:
|
||||
c6d:
|
||||
n_feat: 192
|
||||
input_feature: "pair"
|
||||
mlm:
|
||||
n_feat: 256
|
||||
input_feature: "msa"
|
||||
pae:
|
||||
n_feat: 192
|
||||
input_feature: "pair"
|
||||
plddt:
|
||||
n_feat: 64
|
||||
input_feature: "state"
|
||||
|
||||
dataset_params:
|
||||
fraction_sm_compl: 1
|
||||
n_train: 25600
|
||||
|
||||
training_params:
|
||||
resume_train: False
|
||||
365
rf2aa/data/compose_dataset.py
Normal file
365
rf2aa/data/compose_dataset.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch.utils.data as data
|
||||
from collections import OrderedDict
|
||||
|
||||
from rf2aa.data.data_loader import get_train_valid_set, loader_pdb, loader_complex, loader_na_complex, \
|
||||
loader_distil_tf, loader_tf_complex, loader_fb, loader_dna_rna, loader_sm_compl_assembly_single, \
|
||||
loader_sm_compl_assembly, loader_sm, loader_atomize_pdb, loader_atomize_complex, DistilledDataset, \
|
||||
DistributedWeightedSampler, Dataset, DatasetRNA, DatasetNAComplex, DatasetNAComplex, \
|
||||
DatasetSMComplexAssembly, DatasetSM, _load_df
|
||||
|
||||
#### handle defaults
|
||||
#TODO: shouldn't have to do this in the future, should all be handled in config
|
||||
|
||||
# Hard-coded dataset locations used to build the default dataloader params.
base_dir = "/projects/ml/TrRosetta/PDB-2021AUG02"
compl_dir = "/projects/ml/RoseTTAComplex"
na_dir = "/projects/ml/nucleic"
fb_dir = "/projects/ml/TrRosetta/fb_af"
sm_compl_dir = "/projects/ml/RF2_allatom"
mol_dir = "/projects/ml/RF2_allatom/rcsb/pkl" # for phase 3 dataloaders
# mol_dir = "/projects/ml/RF2_allatom/isdf" # for legacy datasets
tf_dir = "/projects/ml/prot_dna"
csd_dir = "/databases/csd543"
sample_lengths_dir = "/projects/ml/RF2_allatom"

# Baseline dataloader params. Keys are UPPER_CASE; set_data_loader_params
# overrides them from config attributes with the matching lower-case name.
default_dataloader_params = {
    "COMPL_LIST" : "%s/list.hetero.csv"%compl_dir,
    "HOMO_LIST" : "%s/list.homo.csv"%compl_dir,
    "NEGATIVE_LIST" : "%s/list.negative.csv"%compl_dir,
    "RNA_LIST" : "%s/list.rnaonly.csv"%na_dir,
    "DNA_LIST" : "%s/list.dnaonly.v3.csv"%na_dir,
    "NA_COMPL_LIST" : "%s/list.nucleic.v3.csv"%na_dir,
    "NEG_NA_COMPL_LIST": "%s/list.na_negatives.v3.csv"%na_dir,
    "TF_DISTIL_LIST" : "%s/prot_na_distill.v3.csv"%tf_dir,
    "TF_COMPL_LIST" : "%s/tf_compl_list.v4.csv"%tf_dir,
    "SM_LIST" : "%s/sm_compl_all_20230418.csv"%sm_compl_dir,
    "PDB_LIST" : "%s/list_v02_w_taxid.csv"%sm_compl_dir, # on digs
    "PDB_METADATA" : "%s/list_v00_w_taxid_20230201.csv"%sm_compl_dir, # on digs
    "FB_LIST" : "%s/list_b1-3.csv"%fb_dir,
    "CSD_LIST" : "%s/csd543_cleaned01.csv"%csd_dir,
    "VAL_PDB" : "%s/valid_remapped"%sm_compl_dir,
    "VAL_RNA" : "%s/rna_valid.csv"%na_dir,
    "VAL_DNA" : "%s/dna_valid.csv"%na_dir,
    "VAL_COMPL" : "%s/val_lists/xaa"%compl_dir,
    "VAL_NEG" : "%s/val_lists/xaa.neg"%compl_dir,
    "VAL_TF" : "%s/tf_valid_clusters_v4.txt"%tf_dir,
    "VAL_SM_STRICT" : "%s/sm_compl_valid_strict_20230418.csv"%sm_compl_dir,
    "TEST_SM" : "%s/sm_test_heldout_test_clusters.txt"%sm_compl_dir,
    "DATAPKL" : "%s/dataset_20231116.pkl"%sm_compl_dir, # cache for faster loading
    "DSLF_LIST" : "%s/list.dslf.csv"%na_dir,
    "DSLF_FB_LIST" : "%s/list.dslf_fb.csv"%na_dir,
    "DUDE_LIST" : "/home/dnan/projects/gald_distil_set/nbs/dude_dataset_cutoff_-5.csv", # on digs (dnan)
    "DUDE_MSAS" : "/home/dnan/projects/gald_distil_set/DUDE/fastas", # on digs (dnan)
    "DUDE_PDB_DIR" : "/home/dnan/projects/gald_distil_set/DUDE/pdbs_all",
    "EXAMPLE_LENGTHS" : "%s/all_sample_lengths_crop_1K.pt"%sample_lengths_dir,
    "PDB_DIR" : base_dir,
    "FB_DIR" : fb_dir,
    "COMPL_DIR" : compl_dir,
    "NA_DIR" : na_dir,
    "TF_DIR" : tf_dir,
    "MOL_DIR" : mol_dir,
    "CSD_DIR" : csd_dir,
    "MINTPLT" : 0,
    "MAXTPLT" : 5,
    "MINSEQ" : 1,
    "MAXSEQ" : 1024,
    "MAXLAT" : 128,
    "CROP" : 256,
    "DATCUT" : "2021-Aug-1",
    "RESCUT" : 4.5,
    "BLOCKCUT" : 5,
    "PLDDTCUT" : 70.0,
    "SCCUT" : 90.0,
    "ROWS" : 1,
    "SEQID" : 95.0,
    "MAXCYCLE" : 4,
    "RMAX" : 5.0,
    "MAXRES" : 1,
    "MINATOMS" : 5,
    "MAXATOMS" : 100,
    "MAXSIM" : 0.85,
    "MAXNSYMM" : 1024,
    "NRES_ATOMIZE_MIN" : 5,
    "NRES_ATOMIZE_MAX" : 15,
    "ATOMIZE_FLANK" : 0,
    "MAXPROTCHAINS" : 6,
    "MAXLIGCHAINS" : 10,
    "MAXMASKEDLIGATOMS": 30,
    "P_METAL" : 0.75,
    "P_ATOMIZE_MODRES" : 0.75,
    "MAXMONOMERLENGTH" : None,
    "ATOMIZE_CLUSTER" : True,
    "P_ATOMIZE_TEMPLATE": 0.0,
    "NUM_SEQS_SUBSAMPLE": 50,
    "BLACK_HOLE_INIT" : False,
    "SHOW_SM_TEMPLATES" : False,
    "BATCH_BY_DATASET" : False,
    "BATCH_BY_LENGTH" : False,
}


def set_data_loader_params(loader_params):
    """Merge config values into a copy of the default dataloader params.

    Args:
        loader_params: config object that supports attribute access (probed
            with ``hasattr``/``getattr``) and conversion via ``dict(...)``.

    Returns:
        A new dict. For every default key, an attribute on ``loader_params``
        named with the key's lower-case form replaces the default value.
        Config entries whose upper-cased name is not a default key are added
        under their original spelling (e.g. ``dataloader_kwargs``).

    The module-level ``default_dataloader_params`` is left untouched, so
    overrides from one call (or later mutation of the returned dict) cannot
    leak into subsequent calls.
    """
    merged = dict(default_dataloader_params)
    for param in default_dataloader_params:
        config_name = param.lower()
        if hasattr(loader_params, config_name):
            merged[param] = getattr(loader_params, config_name)

    # carry over config entries that have no counterpart in the defaults
    for param, value in dict(loader_params).items():
        if param.upper() not in default_dataloader_params:
            merged[param] = value
    return merged
|
||||
|
||||
def compose_dataset(dataset_params, loader_params, rank, world_size):
    """Build the training loader/sampler and the per-task validation loaders/samplers.

    Args:
        dataset_params: mapping providing ``fraction_<task>``, ``n_valid_<task>``,
            ``n_train``, ``p_short_crop``, ``p_dslf_crop`` and ``dslf_fb_upsample``.
            Modified in place: any ``n_valid_<task>`` that is None is replaced by
            ``len(valid_dict[task])``.
        loader_params: config overrides merged into the default dataloader params
            by ``set_data_loader_params``; the merged result must contain
            ``dataloader_kwargs``.
        rank: rank handed to the distributed samplers.
        world_size: replica count handed to the distributed samplers.

    Returns:
        ``(train_loader, train_sampler, valid_loaders, valid_samplers)`` where the
        last two are OrderedDicts keyed by validation-set name.
    """
    # this function overrides the default dataloader params with those in the config
    #TODO: cache this in checkpoints so checkpoints use the same dataloader params as training
    loader_params = set_data_loader_params(loader_params=loader_params)
    (
        train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict,
        homo, chid2hash, chid2taxid, chid2smpartners,
    ) = get_train_valid_set(loader_params)

    # atomize_pdb reuses the pdb examples; atomize_complex reuses the compl examples
    for alias, source in (("atomize_pdb", "pdb"), ("atomize_complex", "compl")):
        for table in (train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict):
            table[alias] = table[source]

    # reweight fb examples containing disulfide loops
    to_reweight_ex = train_dict['fb']['HAS_DSLF_LOOP']
    to_reweight_cluster = train_dict['fb'][to_reweight_ex].CLUSTER.unique()
    # np.isin replaces the deprecated np.in1d; the two agree for 1-D input
    # (assumes train_ID_dict['fb'] is 1-D -- TODO confirm)
    reweight_mask = np.isin(train_ID_dict['fb'], to_reweight_cluster)
    weights_dict['fb'][reweight_mask] *= dataset_params['dslf_fb_upsample']

    # set number of validation examples being used
    for k in valid_dict:
        if dataset_params['n_valid_'+k] is None:
            dataset_params["n_valid_"+k] = len(valid_dict[k])

    loader_dict = dict(
        pdb = loader_pdb,
        peptide = loader_pdb,
        compl = loader_complex,
        neg_compl = loader_complex,
        na_compl = loader_na_complex,
        neg_na_compl = loader_na_complex,
        distil_tf = loader_distil_tf,
        tf = loader_tf_complex,
        neg_tf = loader_tf_complex,
        fb = loader_fb,
        rna = loader_dna_rna,
        dna = loader_dna_rna,
        sm_compl = loader_sm_compl_assembly_single,
        metal_compl = loader_sm_compl_assembly_single,
        sm_compl_multi = loader_sm_compl_assembly_single,
        sm_compl_covale = loader_sm_compl_assembly_single,
        sm_compl_asmb = loader_sm_compl_assembly,
        sm = loader_sm,
        atomize_pdb = loader_atomize_pdb,
        atomize_complex = loader_atomize_complex,
        sm_compl_furthest_neg = loader_sm_compl_assembly,
        sm_compl_permuted_neg = loader_sm_compl_assembly,
        sm_compl_docked_neg = loader_sm_compl_assembly,
    )

    train_set = DistilledDataset(
        train_ID_dict, train_dict, loader_dict, homo, chid2hash, chid2taxid, chid2smpartners,
        loader_params, native_NA_frac=0.25,
        p_short_crop=dataset_params['p_short_crop'],
        p_dslf_crop=dataset_params['p_dslf_crop'])

    train_sampler = DistributedWeightedSampler(
        train_set,
        weights_dict,
        num_example_per_epoch=dataset_params['n_train'],
        fractions=OrderedDict([(k, dataset_params['fraction_'+k]) for k in train_dict]),
        num_replicas=world_size,
        rank=rank,
        lengths=loader_params["EXAMPLE_LENGTHS"],
        batch_by_dataset=loader_params["BATCH_BY_DATASET"],
        batch_by_length=loader_params["BATCH_BY_LENGTH"],
    )
    train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=1, **loader_params["dataloader_kwargs"])

    def valid_ids(key):
        # first n_valid_<key> validation IDs of the given task
        return valid_ID_dict[key][:dataset_params['n_valid_'+key]]

    def sm_assembly_set(key, **extra):
        # validation set served by loader_sm_compl_assembly, with task == key
        return DatasetSMComplexAssembly(
            valid_ids(key),
            loader_sm_compl_assembly, valid_dict[key],
            chid2hash, chid2taxid, # used for MSA generation of assemblies
            loader_params,
            task=key,
            **extra
        )

    # insertion order below fixes the order of valid_samplers / valid_loaders
    valid_sets = dict(
        atomize_pdb = Dataset(
            valid_ids('atomize_pdb'),
            loader_atomize_pdb, valid_dict['atomize_pdb'],
            loader_params, homo, p_homo_cut=-1.0, n_res_atomize=9, flank=0, p_short_crop=-1.0
        ),
        atomize_complex = Dataset(
            valid_ids('atomize_complex'),
            loader_atomize_complex, valid_dict['atomize_complex'],
            loader_params, homo, p_homo_cut=-1.0, n_res_atomize=9, flank=0, p_short_crop=-1.0
        ),
        pdb = Dataset(
            valid_ids('pdb'),
            loader_pdb, valid_dict['pdb'],
            loader_params, homo, p_homo_cut=-1.0, p_short_crop=-1.0, p_dslf_crop=-1.0
        ),
        dslf = Dataset(
            valid_ids('dslf'),
            loader_pdb, valid_dict['dslf'],
            loader_params, homo, p_homo_cut=-1.0, p_short_crop=-1.0, p_dslf_crop=1.0
        ),
        homo = Dataset(
            valid_ids('homo'),
            loader_pdb, valid_dict['homo'],
            loader_params, homo, p_homo_cut=1.0, p_short_crop=-1.0, p_dslf_crop=-1.0
        ),
        rna = DatasetRNA(
            valid_ids('rna'),
            loader_dna_rna, valid_dict['rna'],
            loader_params,
        ),
        dna = DatasetRNA(
            valid_ids('dna'),
            loader_dna_rna, valid_dict['dna'],
            loader_params,
        ),
        distil_tf = DatasetNAComplex(
            valid_ids('distil_tf'),
            loader_distil_tf, valid_dict['distil_tf'],
            loader_params, negative=False, native_NA_frac=0.0
        ),
        metal_compl = sm_assembly_set('metal_compl', num_protein_chains=1),
        sm_compl = sm_assembly_set('sm_compl', num_protein_chains=1),
        sm_compl_multi = sm_assembly_set('sm_compl_multi', num_protein_chains=1),
        sm_compl_covale = sm_assembly_set('sm_compl_covale', num_protein_chains=1),
        sm_compl_strict = sm_assembly_set('sm_compl_strict', num_protein_chains=1),
        sm_compl_asmb = sm_assembly_set('sm_compl_asmb'),
        sm = DatasetSM(
            valid_ids('sm'),
            loader_sm, valid_dict['sm'],
            loader_params,
        ),
    )

    valid_samplers = OrderedDict(
        (k, data.distributed.DistributedSampler(v, num_replicas=world_size, rank=rank))
        for k, v in valid_sets.items()
    )
    valid_loaders = OrderedDict(
        (k, data.DataLoader(v, sampler=valid_samplers[k], **loader_params["dataloader_kwargs"]))
        for k, v in valid_sets.items()
    )
    return train_loader, train_sampler, valid_loaders, valid_samplers
|
||||
|
||||
|
||||
def compose_posebusters(loader_params, rank, world_size):
    """
    Build the DataLoader used to evaluate on the PoseBusters benchmark set.

    Args:
        loader_params: loader configuration; it is passed through
            `set_data_loader_params` and then overridden in place so that no
            templates are used (MINTPLT = MAXTPLT = 0) and structures are read
            from the benchmark directory.
        rank: rank of the current process, forwarded to the DistributedSampler.
        world_size: number of replicas, forwarded to the DistributedSampler.

    Returns:
        A `torch.utils.data.DataLoader` over the benchmark
        `DatasetSMComplexAssembly`, sharded by a `DistributedSampler`.

    Note: every input path below is hard-coded to a cluster-local location.
    """
    def _unpickle(path):
        # Small helper: load one pickled lookup table from disk.
        with open(path, "rb") as f:
            return pickle.load(f)

    loader_params = set_data_loader_params(loader_params=loader_params)

    # Benchmark metadata; the LIGAND/PARTNERS/LIGXF columns hold Python literals.
    benchmark_df = _load_df(
        "/home/rohith/RF2ligand/posebusters_benchmark.csv",
        pad_hash=False,
        eval_cols=["LIGAND", "PARTNERS", "LIGXF"],
    )
    valid_dict = {"benchmark": benchmark_df}
    valid_ID_dict = {"benchmark": benchmark_df["CLUSTER"]}

    # Chain-ID lookups, used for MSA generation of assemblies.
    chid2hash = _unpickle(
        "/projects/ml/RF2_allatom/posebusters/posebusters_chid2hash_081723.pkl"
    )
    chid2taxid = _unpickle(
        "/projects/ml/RF2_allatom/posebusters/posebusters_chid2taxid_081723.pkl"
    )

    # Benchmark-specific overrides: no templates, benchmark structure directory.
    loader_params["MINTPLT"] = 0
    loader_params["MAXTPLT"] = 0
    loader_params["PDB_DIR"] = "/projects/ml/RF2_allatom/benchmark"

    benchmark = DatasetSMComplexAssembly(
        valid_ID_dict['benchmark'],
        loader_sm_compl_assembly,
        valid_dict['benchmark'],
        chid2hash,
        chid2taxid,
        loader_params,
        task='sm_compl',
        num_protein_chains=1,
        num_ligand_chains=2,
    )
    sampler = data.distributed.DistributedSampler(
        benchmark, rank=rank, num_replicas=world_size
    )
    return data.DataLoader(
        benchmark, sampler=sampler, **loader_params["dataloader_kwargs"]
    )
|
||||
|
||||
def compose_single_item_dataset(item, loader_params, loader, loader_kwargs):
    """
    Wrap a single example in a length-1 dataset and return a DataLoader over it.

    Args:
        item: the example descriptor handed to `loader` on every access.
        loader_params: loader configuration; passed to `loader` as its second
            positional argument, and its "dataloader_kwargs" entry is expanded
            into the DataLoader constructor.
        loader: callable invoked as `loader(item, loader_params, **loader_kwargs)`.
        loader_kwargs: extra keyword arguments for `loader`.

    Returns:
        A `torch.utils.data.DataLoader` that yields the featurized `item`.
    """
    class SpoofDataset(data.Dataset):
        """Dataset of length one that always featurizes the enclosing `item`."""

        def __init__(self, loader_params, loader, loader_kwargs) -> None:
            super().__init__()
            self.loader_params, self.loader, self.loader_kwargs = (
                loader_params,
                loader,
                loader_kwargs,
            )

        def __len__(self):
            return 1

        def __getitem__(self, idx):
            # `idx` is intentionally unused: there is only one item to serve.
            return self.loader(item, self.loader_params, **self.loader_kwargs)

    single_item_dataset = SpoofDataset(loader_params, loader, loader_kwargs)
    return data.DataLoader(single_item_dataset, **loader_params["dataloader_kwargs"])
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import warnings
|
||||
import time
|
||||
import deepdiff
|
||||
from icecream import ic
|
||||
from torch.utils import data
|
||||
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
|
||||
@@ -19,121 +18,27 @@ sys.path.append(script_dir+'/../')
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils import data
|
||||
import scipy
|
||||
from scipy.sparse.csgraph import shortest_path
|
||||
import networkx as nx
|
||||
|
||||
from rf2aa import cifutils
|
||||
from rf2aa.parsers import parse_a3m, parse_pdb, parse_fasta_if_exists, parse_mol, parse_mixed_fasta, get_dislf
|
||||
import rf2aa.cifutils as cifutils
|
||||
from rf2aa.data.parsers import parse_a3m, parse_pdb, parse_fasta_if_exists, parse_mol, parse_mixed_fasta, get_dislf
|
||||
from rf2aa.chemical import INIT_CRDS, INIT_NA_CRDS, NAATOKENS, MASKINDEX, UNKINDEX, \
|
||||
NTOTAL, NBTYPES, CHAIN_GAP, num2aa, METAL_RES_NAMES, aa2num, atomnum2atomtype, load_tanimoto_sim_matrix
|
||||
NTOTAL, NBTYPES, CHAIN_GAP, num2aa, METAL_RES_NAMES, aa2num, atomnum2atomtype, load_tanimoto_sim_matrix, NPROTAAS
|
||||
|
||||
from rf2aa.kinematics import get_chirals
|
||||
from rf2aa.symmetry import get_symmetry
|
||||
from rf2aa.identical_ligands import get_extra_identical_copies_from_chains
|
||||
from rf2aa.data.identical_ligands import get_extra_identical_copies_from_chains
|
||||
from rf2aa.util import get_nxgraph, get_atom_frames, get_bond_feats, get_protein_bond_feats, \
|
||||
center_and_realign_missing, random_rot_trans, allatom_mask, cif_prot_to_xyz, \
|
||||
cif_ligand_to_xyz, cif_ligand_to_obmol, get_automorphs, get_ligand_atoms_bonds, \
|
||||
map_identical_prot_chains, cartprodcat, idx_from_Ls, same_chain_2d_from_Ls, bond_feats_from_Ls, \
|
||||
reindex_protein_feats_after_atomize, get_residue_contacts, atomize_discontiguous_residues, pop_protein_feats, \
|
||||
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, is_nucleic
|
||||
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, \
|
||||
is_protein, is_nucleic, is_atom
|
||||
|
||||
# Data roots ---------------------------------------------------------------
# Alternative roots that are faster on remote/tukwila nodes:
#base_dir = "/databases/TrRosetta/PDB-2021AUG02"
#compl_dir = "/databases/TrRosetta/RoseTTAComplex"
#na_dir = "/databases/TrRosetta/nucleic"
#sm_compl_dir = "/databases/TrRosetta/RF2_allatom"
#mol_dir = "/databases/TrRosetta/RF2_allatom/by-pdb"
csd_dir = "/databases/csd543"

# Older paths, still good but best for local/UW nodes
base_dir = "/projects/ml/TrRosetta/PDB-2021AUG02"
compl_dir = "/projects/ml/RoseTTAComplex"
na_dir = "/projects/ml/nucleic"
fb_dir = "/projects/ml/TrRosetta/fb_af"
sm_compl_dir = "/projects/ml/RF2_allatom"
mol_dir = "/projects/ml/RF2_allatom/rcsb/pkl"  # for phase 3 dataloaders
# mol_dir = "/projects/ml/RF2_allatom/isdf"  # for legacy datasets
tf_dir = "/projects/ml/prot_dna"

# Default data-loader configuration: dataset list files, validation/test
# splits, data directories, and the numeric/boolean knobs read by the loaders.
default_dataloader_params = {
    # training list files
    "COMPL_LIST": f"{compl_dir}/list.hetero.csv",
    "HOMO_LIST": f"{compl_dir}/list.homo.csv",
    "NEGATIVE_LIST": f"{compl_dir}/list.negative.csv",
    "RNA_LIST": f"{na_dir}/list.rnaonly.csv",
    "DNA_LIST": f"{na_dir}/list.dnaonly.v3.csv",
    "NA_COMPL_LIST": f"{na_dir}/list.nucleic.v3.csv",
    "NEG_NA_COMPL_LIST": f"{na_dir}/list.na_negatives.v3.csv",
    "TF_DISTIL_LIST": f"{tf_dir}/prot_na_distill.v3.csv",
    "TF_COMPL_LIST": f"{tf_dir}/tf_compl_list.v4.csv",
    "SM_LIST": f"{sm_compl_dir}/sm_compl_all_20230418.csv",
    "PDB_LIST": f"{sm_compl_dir}/list_v02_w_taxid.csv",  # on digs
    "PDB_METADATA": f"{sm_compl_dir}/list_v00_w_taxid_20230201.csv",  # on digs
    "FB_LIST": f"{fb_dir}/list_b1-3.csv",
    "CSD_LIST": f"{csd_dir}/csd543_cleaned01.csv",
    # validation / test splits
    "VAL_PDB": f"{sm_compl_dir}/valid_remapped",
    "VAL_RNA": f"{na_dir}/rna_valid.csv",
    "VAL_DNA": f"{na_dir}/dna_valid.csv",
    "VAL_COMPL": f"{compl_dir}/val_lists/xaa",
    "VAL_NEG": f"{compl_dir}/val_lists/xaa.neg",
    "VAL_TF": f"{tf_dir}/tf_valid_clusters_v4.txt",
    "VAL_SM_STRICT": f"{sm_compl_dir}/sm_compl_valid_strict_20230418.csv",
    "TEST_SM": f"{sm_compl_dir}/sm_test_heldout_test_clusters.txt",
    "DATAPKL": f"{sm_compl_dir}/dataset_20231116.pkl",  # cache for faster loading
    "DSLF_LIST": f"{na_dir}/list.dslf.csv",
    "DSLF_FB_LIST": f"{na_dir}/list.dslf_fb.csv",
    "DUDE_LIST": "/home/dnan/projects/gald_distil_set/nbs/dude_dataset_cutoff_-5.csv",  # on digs (dnan)
    "DUDE_MSAS": "/home/dnan/projects/gald_distil_set/DUDE/fastas",  # on digs (dnan)
    "DUDE_PDB_DIR": "/home/dnan/projects/gald_distil_set/DUDE/pdbs_all",
    # data directories
    "PDB_DIR": base_dir,
    "FB_DIR": fb_dir,
    "COMPL_DIR": compl_dir,
    "NA_DIR": na_dir,
    "TF_DIR": tf_dir,
    "MOL_DIR": mol_dir,
    "CSD_DIR": csd_dir,
    # sampling, cropping and filtering settings
    "MINTPLT": 0,
    "MAXTPLT": 5,
    "MINSEQ": 1,
    "MAXSEQ": 1024,
    "MAXLAT": 128,
    "CROP": 256,
    "DATCUT": "2021-Aug-1",
    "RESCUT": 4.5,
    "BLOCKCUT": 5,
    "PLDDTCUT": 70.0,
    "SCCUT": 90.0,
    "ROWS": 1,
    "SEQID": 95.0,
    "MAXCYCLE": 4,
    "RMAX": 5.0,
    "MAXRES": 1,
    "MINATOMS": 5,
    "MAXATOMS": 100,
    "MAXSIM": 0.85,
    "MAXNSYMM": 1024,
    "NRES_ATOMIZE_MIN": 5,
    "NRES_ATOMIZE_MAX": 15,
    "ATOMIZE_FLANK": 0,
    "MAXPROTCHAINS": 6,
    "MAXLIGCHAINS": 10,
    "MAXMASKEDLIGATOMS": 30,
    "P_METAL": 0.75,
    "P_ATOMIZE_MODRES": 0.75,
    "MAXMONOMERLENGTH": None,
    "ATOMIZE_CLUSTER": True,
    "P_ATOMIZE_TEMPLATE": 1.0,
    "NUM_SEQS_SUBSAMPLE": 50,
    "BLACK_HOLE_INIT": False,
    "SHOW_SM_TEMPLATES": True,
}
|
||||
|
||||
def set_data_loader_params(args):
    """
    Override the default data-loader settings with values carried by `args`.

    For every key in `default_dataloader_params`, an attribute on `args` whose
    name is the lower-cased key replaces the default value.

    Note: the module-level `default_dataloader_params` dict is updated in place
    and that same object is returned (no copy is made).
    """
    for key in default_dataloader_params:
        attr_name = key.lower()
        if not hasattr(args, attr_name):
            continue
        default_dataloader_params[key] = getattr(args, attr_name)
    return default_dataloader_params
|
||||
assert "rf2aa" in os.path.abspath(cifutils.__file__)
|
||||
|
||||
def MSABlockDeletion(msa, ins, nb=5):
|
||||
'''
|
||||
@@ -197,6 +102,12 @@ def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
|
||||
- insertion info (1)
|
||||
- N-term or C-term? (2)
|
||||
'''
|
||||
# Truncate MSA (for efficiency when pre-computing lengths)
|
||||
if params.get("MSA_LIMIT") is not None:
|
||||
# Raise a warning that we are truncating the MSA
|
||||
warnings.warn(f"Truncating MSA to {params['MSA_LIMIT']} sequences. Only to be used for length pre-computation, NOT training.")
|
||||
msa = msa[:params["MSA_LIMIT"]]
|
||||
|
||||
if fixbb:
|
||||
p_mask = 0
|
||||
msa = msa[:1]
|
||||
@@ -258,18 +169,28 @@ def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
|
||||
# - 10%: aa replaced with an amino acid sampled from the MSA profile
|
||||
# - 10%: not replaced
|
||||
# - 70%: replaced with a special token ("mask")
|
||||
seq = msa_clust[0]
|
||||
|
||||
random_aa = torch.tensor([[0.05]*20 + [0.0]*(NAATOKENS-20)], device=msa.device)
|
||||
same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=NAATOKENS)
|
||||
# explicitly remove probabilities from nucleic acids and atoms
|
||||
same_aa[..., NPROTAAS:] = 0
|
||||
raw_profile[...,NPROTAAS:] = 0
|
||||
probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
|
||||
#probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)
|
||||
probs[...,MASKINDEX]=0.7
|
||||
|
||||
|
||||
# explicitly set the probability of masking for nucleic acids and atoms
|
||||
probs[...,is_protein(seq),MASKINDEX]=0.7
|
||||
probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
|
||||
probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
|
||||
probs[0,is_nucleic(seq), MASKINDEX] = 1.0
|
||||
probs[0,is_atom(seq), aa2num["ATM"]] = 1.0
|
||||
|
||||
sampler = torch.distributions.categorical.Categorical(probs=probs)
|
||||
mask_sample = sampler.sample()
|
||||
|
||||
mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
|
||||
mask_pos[msa_clust>MASKINDEX]=False # no masking on NAs
|
||||
|
||||
#mask_pos[msa_clust>MASKINDEX]=False # no masking on NAs
|
||||
use_seq = msa_clust
|
||||
msa_masked = torch.where(mask_pos, mask_sample, use_seq)
|
||||
b_seq.append(msa_masked[0].clone())
|
||||
@@ -369,7 +290,6 @@ def blank_template(n_tmpl, L, random_noise=5.0):
|
||||
|
||||
|
||||
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5):
|
||||
|
||||
seqID_cut = params['SEQID']
|
||||
|
||||
if npick_global == None:
|
||||
@@ -764,6 +684,15 @@ def add_negative_sets(
|
||||
valid_ID_dict["dude_inactives"] = dude_inactives_df["CLUSTER"].unique()
|
||||
return train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict
|
||||
|
||||
def _load_df(filename, pad_hash=True, eval_cols=[]):
    """
    Read a metadata CSV into a DataFrame.

    Args:
        filename: path of the CSV file.
        pad_hash: when true, the integer `HASH` column is rewritten as a
            zero-padded, six-character string (restores leading zeros).
        eval_cols: names of columns whose cells are Python literals stored as
            text; each cell is parsed with `ast.literal_eval`.

    Returns:
        The loaded (and post-processed) DataFrame.
    """
    # na_filter=False keeps a literal "NA" (a valid chain ID) from turning into NaN
    frame = pd.read_csv(filename, na_filter=False)

    if pad_hash:
        frame['HASH'] = frame['HASH'].apply(lambda value: f'{value:06d}')

    for column in eval_cols:
        # e.g. a cell holding "['A', 'B']" becomes the list ['A', 'B']
        frame[column] = frame[column].apply(ast.literal_eval)

    return frame
|
||||
|
||||
|
||||
def get_train_valid_set(params, NEG_CLUSID_OFFSET=1000000, no_match_okay=False, diffusion_training=False, add_negatives: bool = True):
|
||||
"""Loads training/validation sets as pandas DataFrames and returns them in
|
||||
@@ -814,15 +743,6 @@ def get_train_valid_set(params, NEG_CLUSID_OFFSET=1000000, no_match_okay=False,
|
||||
f"re-parsing train/valid metadata...")
|
||||
|
||||
# helper functions
|
||||
def _load_df(filename, pad_hash=True, eval_cols=[]):
|
||||
"""load dataframe, zero-pad hash string, parse columns as python objects"""
|
||||
df = pd.read_csv(filename, na_filter=False) # prevents chain "NA" loading as NaN
|
||||
if pad_hash: # restore leading zeros, make into string
|
||||
df['HASH'] = df['HASH'].apply(lambda x: f'{x:06d}')
|
||||
for col in eval_cols:
|
||||
df[col] = df[col].apply(lambda x: ast.literal_eval(x)) # interpret as list of strings
|
||||
return df
|
||||
|
||||
def _apply_date_res_cutoffs(df):
|
||||
"""filter dataframe by date and resolution cutoffs"""
|
||||
return df[(df.RESOLUTION <= params['RESCUT']) &
|
||||
@@ -1359,6 +1279,22 @@ def get_na_crop(seq, xyz, mask, sel, len_s, params, negative=False, incl_protein
|
||||
|
||||
return sel
|
||||
|
||||
def adjust_samples_for_num_replicas(num_per_epoch_dict, num_replicas):
    """
    Round each dataset's per-epoch sample count down to a multiple of `num_replicas`.

    Args:
        num_per_epoch_dict (dict): Mapping from dataset name to the number of
            examples to draw from that dataset per epoch.
        num_replicas (int): Number of nodes/GPUs in the distributed training setup.

    Returns:
        dict: A new mapping with the same keys, where every count is the largest
        multiple of `num_replicas` not exceeding the original count.
    """
    return {
        name: (count // num_replicas) * num_replicas
        for name, count in num_per_epoch_dict.items()
    }
|
||||
|
||||
def find_msa_hashes(protein_chain_info, params):
|
||||
"""
|
||||
@@ -1606,52 +1542,6 @@ def remove_all_gap_seqs(a3m):
|
||||
a3m['ins'] = a3m['ins'][idx_seq_keep]
|
||||
return a3m
|
||||
|
||||
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
|
||||
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
|
||||
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
|
||||
ID, only the sequence with the highest sequence identity to the query (1st
|
||||
sequence in MSA) will be paired.
|
||||
|
||||
Sequences that aren't paired will be padded and added to the bottom of the
|
||||
joined MSA. If a subregion of the input MSAs overlap (represent the same
|
||||
chain), the subregion residue indices can be given as `idx_overlap`, and
|
||||
the overlap region of the unpaired sequences will be included in the joined
|
||||
MSA.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a3mA : dict
|
||||
First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
|
||||
L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
|
||||
boolean tensor indicating whether each sequence is fully paired. Can be
|
||||
a multi-MSA (contain >2 sub-MSAs).
|
||||
a3mB : dict
|
||||
2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
|
||||
`is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
|
||||
idx_overlap : tuple or list (optional)
|
||||
Start and end indices of overlap region in 1st MSA, followed by the
|
||||
same in 2nd MSA.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a3m : dict
|
||||
Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
|
||||
"""
|
||||
# preprocess & sanity check overlap region
|
||||
L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
|
||||
if idx_overlap is not None:
|
||||
i1A, i2A, i1B, i2B = idx_overlap
|
||||
i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
|
||||
assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
|
||||
"When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
|
||||
"(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
|
||||
else:
|
||||
i1B_new, i2B_new = (0, L_B)
|
||||
|
||||
# paired sequences
|
||||
taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
|
||||
i_pairedA, i_pairedB = [], []
|
||||
|
||||
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
|
||||
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
|
||||
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
|
||||
@@ -2718,7 +2608,7 @@ def loader_na_complex(item, params, native_NA_frac=0.05, negative=False, pick_to
|
||||
torch.load(params['PDB_DIR']+'/torch/pdb/'+pdb_ids[0][1:3]+'/'+pdb_ids[0]+'.pt'),
|
||||
torch.load(params['PDB_DIR']+'/torch/pdb/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.pt')
|
||||
]
|
||||
filenameB1 = params['NA_DIR']+'/torch/'+pdb_ids[2][1:3]+'/'+pdb_ids[2]+'.pt'
|
||||
filenameB1 = params['NA_DIR'] + '/torch/' + pdb_ids[2][1:3] + '/' + pdb_ids[2] + '.pt'
|
||||
filenameB2 = params['NA_DIR']+'/torch/'+pdb_ids[3][1:3]+'/'+pdb_ids[3]+'.pt'
|
||||
if os.path.exists(filenameB1+".v3"):
|
||||
filenameB1 = filenameB1+".v3"
|
||||
@@ -3230,7 +3120,7 @@ def loader_distil_tf(item, params, random_noise=5.0, pick_top=True, native_NA_fr
|
||||
xyz, mask, _, pdbseq = parse_pdb(
|
||||
params["TF_DIR"]+f'/distill_v2/filtered/{gene_id[:2]}/{gene_id}_{dnaseq}.pdb',
|
||||
seq=True,
|
||||
lddtmask=True
|
||||
lddt_mask=True
|
||||
)
|
||||
|
||||
xyz = torch.from_numpy(xyz)
|
||||
@@ -3637,9 +3527,10 @@ def featurize_asmb_prot(pdb_id, partners, params, chains, asmb_xfs, modres,
|
||||
chnum += 1
|
||||
|
||||
## protein templates
|
||||
random_noise = 0.0
|
||||
ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']+1)
|
||||
if chid2hash is None or ntempl < 1:
|
||||
xyz_t_ch, f1d_t_ch, mask_t_ch = \
|
||||
xyz_t_ch, f1d_t_ch, mask_t_ch, tplt_ids_ch = \
|
||||
blank_template(n_tmpl=1, L=xyz_ch.shape[1], random_noise=random_noise)
|
||||
else:
|
||||
pdb_hash = chid2hash[pdb_id+'_'+list(chlet_set)[0]] # chlet_set all have same hash
|
||||
@@ -4020,11 +3911,15 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
|
||||
"""
|
||||
pdb_chain = item['CHAINID']
|
||||
pdb_id = pdb_chain.split('_')[0]
|
||||
|
||||
# load pre-parsed cif assembly - requires cifutils.py in path for object definitions
|
||||
chains, asmb, covale, modres = \
|
||||
out = \
|
||||
pickle.load(gzip.open(params['MOL_DIR']+f'/{pdb_id[1:3]}/{pdb_id}.pkl.gz'))
|
||||
|
||||
if len(out) == 4:
|
||||
chains, asmb, covale, modres = out
|
||||
elif len(out) == 5:
|
||||
chains, asmb, covale, meta, modres = out
|
||||
else:
|
||||
raise ValueError(f"cif parser returns {len(out)} values")
|
||||
# list of proteins and ligands to featurize
|
||||
prot_partners = [p for p in item['PARTNERS'] if p[-1]=='polypeptide(L)']
|
||||
prot_partners = prot_partners[:params['MAXPROTCHAINS']]
|
||||
@@ -4075,6 +3970,7 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
|
||||
|
||||
# combine protein & ligand templates
|
||||
N_tmpl = xyz_t_prot.shape[0]
|
||||
random_noise = 0.0
|
||||
if chid2smpartners is not None and params["SHOW_SM_TEMPLATES"]:
|
||||
assert num_protein_chains == 1, "templating ligands not supported for multiple protein chains (complications in xyz_prev)"
|
||||
xyz_t_sm, f1d_t_sm, mask_t_sm = generate_sm_template_feats(tplt_ids, resnames, akeys_sm, Ls_sm,chid2smpartners, params)
|
||||
@@ -4165,7 +4061,8 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
|
||||
sel = crop_sm_compl(xyz_prot, xyz_sm[0], Ls_prot + Ls_sm, params['CROP'], mask_prot,
|
||||
seq_prot, select_farthest_residues=select_farthest_residues)
|
||||
else:
|
||||
sel = crop_sm_compl_assembly(xyz[0], mask[0], Ls_prot, Ls_sm, params['CROP'])
|
||||
#sel = crop_sm_compl_assembly(xyz[0], mask[0], Ls_prot, Ls_sm, params['CROP'])
|
||||
sel = crop_sm_compl_asmb_contig(xyz[0], mask[0], Ls_prot, Ls_sm, bond_feats, params['CROP'], use_partial_ligands=False)
|
||||
mask = reassign_symmetry_after_cropping(sel, Ls_prot, ch_label, mask, item)
|
||||
|
||||
msa = msa[:, sel]
|
||||
@@ -4202,7 +4099,7 @@ def loader_sm_compl_assembly(item, params, chid2hash=None, chid2taxid=None, chid
|
||||
if max_msa_seqs is not None:
|
||||
msa = msa[:max_msa_seqs]
|
||||
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = \
|
||||
MSAFeaturize(msa.long(), ins.long(), params, term_info=term_info, fixbb=fixbb, seed_msa_clus=seed_msa_clus)
|
||||
MSAFeaturize(msa.long(), ins.long(), params, p_mask=params["p_msa_mask"], term_info=term_info, fixbb=fixbb, seed_msa_clus=seed_msa_clus)
|
||||
|
||||
return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\
|
||||
xyz.float(), mask, idx.long(), \
|
||||
@@ -4693,6 +4590,7 @@ def crop_sm_compl(prot_xyz, lig_xyz, Ls, crop_size, mask_prot, seq_prot,
|
||||
remaining_residues_to_select = crop_size - len(lig_xyz) - min_resolved_residues
|
||||
assert remaining_residues_to_select > 0, f"For some reason I encountered a scenario in which we are cropping a protein but I was unable to select enough residues from that protein. This probably means you passed a protein that was too short into this function. {crop_size}, {len(lig_xyz)}, {min_resolved_residues}"
|
||||
dist[sel_min_resolved] = nan_fill_value
|
||||
remaining_residues_to_select = min(remaining_residues_to_select, len(dist))
|
||||
_, sel_remaining = torch.topk(dist, remaining_residues_to_select)
|
||||
sel_unsorted = torch.concat((sel_min_resolved, sel_remaining))
|
||||
sel, _ = torch.sort(sel_unsorted)
|
||||
@@ -4845,6 +4743,121 @@ def crop_sm_compl_assembly(all_xyz, all_mask, Ls_prot, Ls_sm, n_crop, use_partia
|
||||
sel = prot_sel
|
||||
return torch.from_numpy(sel).long()
|
||||
|
||||
def crop_sm_compl_asmb_contig(all_xyz, all_mask, Ls_prot, Ls_sm, bond_feats, n_crop, use_partial_ligands=False):
    """
    Choose a crop made of contiguous protein segments (plus ligands) by graph distance.

    Instead of a radial crop around a random atom, a dense weighted graph over
    all positions is built and the `n_crop` positions closest (by shortest
    path) to a start residue are kept. Edge weights:

      * every pair of positions starts at weight `n_crop` (the "far" default)
      * sequence-adjacent residues within a protein chain: 1
      * all pairs inside one ligand chain: 0.1, so sampling one atom pulls in
        the whole ligand (0.1 rather than 0 -- presumably because a zero entry
        in a dense csgraph input means "no edge"; confirm against scipy docs)
      * inter-chain protein contacts (atom-1 distance < 8): 8
      * protein/ligand contacts (atom-1 distance < 5), for chain pairs without
        a bond feature equal to 8: 2
      * pairs whose bond feature equals 6: 1

    Args:
        all_xyz: coordinates; indexed as `all_xyz[pos, 1]` to get one 3-vector
            per position (the original comment calls this the C-alpha) --
            assumes shape (L, n_atoms, 3), TODO confirm.
        all_mask: boolean validity mask indexed as `all_mask[pos, 1]`.
        Ls_prot: lengths of the protein chains, in order (at least one chain).
        Ls_sm: lengths of the ligand chains, in order, stored after all
            proteins. May be empty.
        bond_feats: (L, L) tensor of bond-feature codes; only the values 6 and
            8 are inspected here (their meaning is defined elsewhere).
        n_crop: target number of positions in the crop.
        use_partial_ligands: controls what happens to a ligand that is only
            partly inside the crop -- see the review note near the end.

    Returns:
        1-D sorted LongTensor of selected position indices. It can be shorter
        than `n_crop` (partly cropped ligands removed) or longer (ligands
        completed).
    """
    def find_edges_based_on_distance(i_start, i_end, j_start, j_end, dist_cutoff):
        """Return (K, 2) index pairs, local to each range, closer than `dist_cutoff`."""
        dist = torch.cdist(all_xyz[i_start:i_end, 1], all_xyz[j_start:j_end, 1])  # calpha distogram
        mask_2d = all_mask[i_start:i_end, 1][:, None] * all_mask[j_start:j_end, 1][None, :]
        dist[~mask_2d] = 99999  # unresolved positions never form contacts
        return (dist < dist_cutoff).nonzero()

    def connect(rows, cols, weight):
        """Write a symmetric edge weight for every (rows[k], cols[k]) pair."""
        graph[rows, cols] = weight
        graph[cols, rows] = weight

    L = all_xyz.shape[0]
    # Chain boundaries: chain k spans [offsets[k], offsets[k + 1]).
    prot_offsets = list(itertools.accumulate((int(n) for n in Ls_prot), initial=0))
    total_protein_L = prot_offsets[-1]
    sm_offsets = list(itertools.accumulate((int(n) for n in Ls_sm), initial=total_protein_L))
    prot_ranges = list(zip(prot_offsets[:-1], prot_offsets[1:]))
    sm_ranges = list(zip(sm_offsets[:-1], sm_offsets[1:]))

    # construct weighted graph
    graph = np.full((L, L), n_crop, dtype=np.float32)

    # neighboring residues in the same protein chain: weight 1
    for start, end in prot_ranges:
        residues = np.arange(start, end - 1)
        connect(residues, residues + 1, 1)

    # fully connect each ligand chain with a small weight
    for start, end in sm_ranges:
        graph[start:end, start:end] = 0.1

    # contacts between different protein chains
    for (i_start, i_end), (j_start, j_end) in itertools.combinations(prot_ranges, 2):
        edges = find_edges_based_on_distance(i_start, i_end, j_start, j_end, dist_cutoff=8).cpu().numpy()
        connect(edges[:, 0] + i_start, edges[:, 1] + j_start, 8)

    # contacts between proteins and small molecules (non-covalent)
    for (p_start, p_end), (s_start, s_end) in itertools.product(prot_ranges, sm_ranges):
        if torch.any(bond_feats[p_start:p_end, s_start:s_end] == 8):  # skip chains that are covalently connected
            continue
        edges = find_edges_based_on_distance(p_start, p_end, s_start, s_end, dist_cutoff=5).cpu().numpy()
        connect(edges[:, 0] + p_start, edges[:, 1] + s_start, 2)

    # edges to covalent modifications should be similar to residue edges not ligand edges
    covalent_bonds = (bond_feats == 6).nonzero().cpu().numpy()
    connect(covalent_bonds[:, 0], covalent_bonds[:, 1], 1)

    # Start from a first-chain residue near the first ligand, if there is one.
    first_chain_L = int(Ls_prot[0])
    candidates = []
    if len(Ls_sm) > 0:  # the original indexed Ls_sm[0] unconditionally and crashed without ligands
        starting_edges = find_edges_based_on_distance(
            0, first_chain_L, total_protein_L, total_protein_L + int(Ls_sm[0]), dist_cutoff=10
        )
        candidates = starting_edges[:, 0].unique().tolist()
    if candidates:
        startres = random.choice(candidates)
    else:
        # random.randint(0, n) is inclusive of n and could pick an index just
        # past the first chain; randrange keeps the start inside it.
        startres = random.randrange(max(first_chain_L, 1))

    d_res = shortest_path(graph, directed=False, indices=startres)
    n_crop = min(d_res.shape[0], n_crop)

    _, idx = torch.topk(torch.from_numpy(d_res).to(device=all_xyz.device), n_crop, largest=False)
    sel, _ = torch.sort(idx)

    # make sure that no ligand is left half inside the crop
    # NOTE(review): indentation of the original was not recoverable in this
    # view; the final `else` is assumed to pair with the `use_partial_ligands`
    # check (so True completes a partly cropped ligand) -- confirm against the
    # repository before merging.
    for sm_chain_index, (s_start, s_end) in enumerate(sm_ranges):
        chain_indices = torch.arange(s_start, s_end, device=sel.device)
        chain_in_crop = torch.isin(chain_indices, sel)  # which ligand atoms are in the crop
        if not torch.any(chain_in_crop) or torch.all(chain_in_crop):
            continue  # ligand entirely outside or entirely inside: nothing to repair
        print("WARNING: removing partially cropped small molecule")
        print(f"chain: {sm_chain_index}")
        if not use_partial_ligands:
            sel = sel[~torch.isin(sel, chain_indices)]  # drop the ligand's atoms from the crop
        else:
            sel = torch.cat((sel, chain_indices), dim=0).unique()  # add the missing atoms back
    return sel
|
||||
|
||||
|
||||
def crop_chirals(chirals, atom_sel):
|
||||
"""
|
||||
this function returns only chiral centers that appear in molecules that are chosen after cropping
|
||||
@@ -4905,7 +4918,6 @@ def sample_item_sm_compl(df, ID, dedup_ligand=True):
|
||||
# uniformly sample from unique PDB chains
|
||||
chid = np.random.choice(tmp_df.CHAINID.drop_duplicates().values)
|
||||
tmp_df = tmp_df[tmp_df.CHAINID==chid]
|
||||
|
||||
if dedup_ligand and "LIGAND" in tmp_df:
|
||||
# uniform sample from unique ligands
|
||||
lignames = list(set([x[0][2] for x in tmp_df['LIGAND']]))
|
||||
@@ -4914,7 +4926,7 @@ def sample_item_sm_compl(df, ID, dedup_ligand=True):
|
||||
|
||||
item = tmp_df.sample(1).to_dict(orient='records')[0] # choose 1 random row
|
||||
return copy.deepcopy(item) # prevents dataframe from being modified by downstream changes
|
||||
|
||||
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(
|
||||
@@ -5401,7 +5413,6 @@ class DistilledDataset(data.Dataset):
|
||||
except Exception as e:
|
||||
print('error loading',item, '\n',repr(e), task)
|
||||
raise e
|
||||
|
||||
return out
|
||||
|
||||
class DistributedWeightedSampler(data.Sampler):
|
||||
@@ -5436,83 +5447,164 @@ class DistributedWeightedSampler(data.Sampler):
|
||||
),
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
replacement=False
|
||||
datasets_with_replacement=["pdb", "fb", "compl", "neg_compl", "na_compl", "neg_na_compl", "distil_tf", "tf", "neg_tf", "rna", "dna"],
|
||||
lengths=None,
|
||||
batch_by_dataset=False,
|
||||
batch_by_length=False,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
assert num_example_per_epoch % num_replicas == 0
|
||||
assert (num_example_per_epoch % num_replicas) == 0, "Please ensure that the number of examples per epoch is evenly divisible by the number of nodes"
|
||||
assert (np.allclose(sum([v for k,v in fractions.items()]), 1.0)), \
|
||||
f"Fractions of datasets add up to {sum([v for k,v in fractions.items()])}, should add up to 1.0"
|
||||
|
||||
|
||||
# Load lengths into a tensor, if file exists
|
||||
if lengths is not None and os.path.isfile(lengths):
|
||||
lengths = torch.load(lengths)
|
||||
else:
|
||||
if lengths is not None:
|
||||
warnings.warn(f"Lengths file {lengths} does not exist. Ignoring lengths file.")
|
||||
lengths = None
|
||||
|
||||
if batch_by_length:
|
||||
assert lengths is not None, "If batching by length, must pass a valid lengths tensor."
|
||||
|
||||
if not batch_by_dataset:
|
||||
assert not batch_by_length, "Cannot batch by length without also batching by dataset."
|
||||
|
||||
self.dataset = dataset
|
||||
self.weights_dict = weights_dict
|
||||
self.num_replicas = num_replicas
|
||||
self.num_per_epoch_dict = OrderedDict([
|
||||
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
|
||||
for dataset_name in self.dataset.dataset_dict.keys()
|
||||
])
|
||||
self.lengths = lengths
|
||||
self.batch_by_length = batch_by_length
|
||||
self.batch_by_dataset = batch_by_dataset
|
||||
|
||||
if batch_by_dataset:
|
||||
# Ensure that all GPU's can process an example from the same dataset at once
|
||||
self.num_per_epoch_dict = adjust_samples_for_num_replicas(
|
||||
OrderedDict([
|
||||
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
|
||||
for dataset_name in self.dataset.dataset_dict.keys()
|
||||
]),
|
||||
num_replicas
|
||||
)
|
||||
else:
|
||||
self.num_per_epoch_dict = OrderedDict([
|
||||
(dataset_name, int(round(num_example_per_epoch * fractions[dataset_name])))
|
||||
for dataset_name in self.dataset.dataset_dict.keys()
|
||||
])
|
||||
|
||||
# account for rounding error
|
||||
# Account for rounding error
|
||||
dataset_names = list(self.dataset.dataset_dict.keys())
|
||||
nonzero_dataset_names = [name for name in dataset_names if self.num_per_epoch_dict[name] > 0]
|
||||
|
||||
# Calculate the actual number of examples that will be sampled (will be a multiple of num_replicas)
|
||||
num_per_epoch_actual = sum([self.num_per_epoch_dict[name] for name in nonzero_dataset_names])
|
||||
self.num_per_epoch_dict[nonzero_dataset_names[0]] += num_example_per_epoch - num_per_epoch_actual
|
||||
|
||||
# Handle remainders by rounding down to the nearest multiple of num_replicas and sampling from `pdb`
|
||||
remainder = num_example_per_epoch - num_per_epoch_actual
|
||||
remainder = remainder - (remainder % num_replicas)
|
||||
self.num_per_epoch_dict[nonzero_dataset_names[0]] += remainder # The first dataset is the pdb
|
||||
|
||||
self.total_size = num_example_per_epoch
|
||||
self.total_size = num_per_epoch_actual + remainder
|
||||
self.num_samples = self.total_size // self.num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.replacement = replacement
|
||||
|
||||
# Sample the protein datasets with replacement to account for length weighting
|
||||
# Other datasets (e.g., small molecule datasets) will be sampled WITHOUT replacement (since LEN_EXIST is not the appropriate weighting)
|
||||
self.datasets_with_replacement = datasets_with_replacement
|
||||
if (rank==0):
|
||||
print(f"Training examples per epoch ({self.total_size} total):")
|
||||
for k,v in self.num_per_epoch_dict.items():
|
||||
print(' '+k, ':', v)
|
||||
|
||||
def _sample_from_dataset(self, dataset_name, g):
    """
    Draw this epoch's quota of example indices from a single dataset.

    Sampling is weighted by `self.weights_dict[dataset_name]`. It is done with
    replacement when the dataset is listed in `self.datasets_with_replacement`,
    and replacement is forced (with a warning) whenever the quota exceeds the
    number of entries in `self.dataset.ID_dict[dataset_name]`.

    Parameters:
        dataset_name (str): Key of the dataset to sample from.
        g (torch.Generator): Generator passed to `torch.multinomial`; the caller
            seeds it so that every rank draws the same indices.

    Returns:
        Tensor: Indices into the named dataset (no global offset applied here).
    """
    n_requested = self.num_per_epoch_dict[dataset_name]
    n_available = len(self.dataset.ID_dict[dataset_name])
    oversampling = n_requested > n_available

    # Warn when more examples are requested than the dataset contains
    if oversampling:
        warnings.warn(
            f"Number of sequences to be sampled in one epoch is greater than the number of "
            f"sequences in the dataset. Must sample with replacement. Ensure that this is the desired behavior. Dataset: {dataset_name}, "
            f"Dataset length: {n_available}, "
            f"# to be sampled: {n_requested}"
        )

    # Replacement is either configured per dataset or forced by oversampling
    replacement = dataset_name in self.datasets_with_replacement or oversampling

    # Weighted draw; the weights are whatever the sampler stored for this dataset
    return torch.multinomial(
        self.weights_dict[dataset_name],
        n_requested,
        generator=g,
        replacement=replacement,
    )
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
# get indices (fb + pdb models)
|
||||
# get indices (all models)
|
||||
indices = torch.arange(len(self.dataset))
|
||||
|
||||
# weighted subsampling
|
||||
# order of datasets in this loop should match order in DistilledDataset.__getitem__()
|
||||
offset = 0
|
||||
sel_indices = torch.tensor((),dtype=int)
|
||||
print(self.dataset.dataset_dict.keys())
|
||||
for dataset_name in self.dataset.dataset_dict.keys():
|
||||
|
||||
if self.batch_by_dataset:
|
||||
for dataset_name in self.dataset.dataset_dict.keys():
|
||||
if (self.num_per_epoch_dict[dataset_name]> 0):
|
||||
# Sample and adjust for offset; _sample_from_dataset handles replacement
|
||||
sampled_idx = self._sample_from_dataset(dataset_name, g) + offset
|
||||
|
||||
# Divide sampled_idx into num_replicas chunks and assign each chunk to a node
|
||||
sampled_idx_split = torch.split(sampled_idx, len(sampled_idx) // self.num_replicas)
|
||||
assert all([len(x) == len(sampled_idx_split[0]) for x in sampled_idx_split])
|
||||
|
||||
# If also batching by sequence length, sort the indices by length
|
||||
if self.batch_by_length and self.lengths is not None:
|
||||
sampled_idx_split = [x[torch.argsort(self.lengths[x])] for x in sampled_idx_split]
|
||||
|
||||
# Add the sampled indices to the running tensor based on the node rank
|
||||
sel_indices = torch.cat((sel_indices, indices[sampled_idx_split[self.rank]]))
|
||||
offset += len(self.dataset.ID_dict[dataset_name])
|
||||
|
||||
if (self.num_per_epoch_dict[dataset_name]> 0):
|
||||
sampled_idx = torch.multinomial(self.weights_dict[dataset_name],
|
||||
self.num_per_epoch_dict[dataset_name],
|
||||
self.replacement,
|
||||
generator=g)
|
||||
sel_indices = torch.cat((sel_indices, indices[sampled_idx + offset]))
|
||||
# For each node, the indices are shuffled with the same seed, and so will draw from the same datasets in the same order
|
||||
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
|
||||
else:
|
||||
# Standard implementation of WeightedDistributedSampler without batching by dataset or length
|
||||
for dataset_name in self.dataset.dataset_dict.keys():
|
||||
if (self.num_per_epoch_dict[dataset_name]> 0):
|
||||
sampled_idx = self._sample_from_dataset(dataset_name, g)
|
||||
sel_indices = torch.cat((sel_indices, indices[sampled_idx + offset]))
|
||||
offset += len(self.dataset.ID_dict[dataset_name])
|
||||
|
||||
# shuffle indices
|
||||
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
|
||||
# shuffle indices
|
||||
indices = sel_indices[torch.randperm(len(sel_indices), generator=g)]
|
||||
|
||||
# per each gpu
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
# per each gpu
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
|
||||
#print('rank',self.rank,': expecting',self.num_samples,'examples, drew',len(indices),'examples')
|
||||
assert len(indices) == self.num_samples # more stringent, switch with line above during debugging
|
||||
|
||||
assert len(indices) == self.num_samples # more stringent, switch with line above during debugging
|
||||
return iter(indices.tolist())
|
||||
|
||||
def __len__(self):
    """Return `self.num_samples`, the per-rank number of indices drawn each epoch."""
    return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch):
    """Store the epoch number; `__iter__` uses it to seed its random generator."""
    self.epoch = epoch
|
||||
158
rf2aa/data/dataloader_adaptor.py
Normal file
158
rf2aa/data/dataloader_adaptor.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import torch
|
||||
|
||||
from rf2aa.kinematics import xyz_to_t2d
|
||||
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs
|
||||
from rf2aa.util import INIT_CRDS, NTOTAL, NTOTALDOFS, is_atom, \
|
||||
Ls_from_same_chain_2d, xyz_t_to_frame_xyz, get_prot_sm_mask
|
||||
|
||||
|
||||
|
||||
def prepare_input(inputs, xyz_converter, gpu):
    """
    Unpack one data-loader item, move tensors to `gpu`, and build the network input dict.

    Parameters:
        inputs: The 24-field item unpacked on the first statement below.
        xyz_converter: Object providing `get_torsions(xyz, seq, mask_in=...)`.
        gpu: Target device for the tensors moved or created here.

    Returns:
        tuple: (task, item, network_input, true_crds, mask_crds, msa, mask_msa,
        unclamp, negative, symmRs, Lasu, ch_label), where `network_input` is a
        dict of model inputs.

    Notes:
        - The `xyz_prev` / `mask_prev` fields of `inputs` are discarded and replaced
          by an initialization built from INIT_CRDS.
        - seq, msa, msa_masked, msa_full and mask_msa are NOT moved to `gpu` here.
          NOTE(review): presumably the consumer transfers them -- confirm.
        - A batch size of 1 is assumed by the symmetry handling (`len(symmgp) == 1`).
    """
    (
        seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
        xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
        atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
    ) = inputs

    # transfer inputs to device
    B, _, _, L = msa.shape

    idx_pdb = idx_pdb.to(gpu, non_blocking=True)  # (B, L)
    true_crds = true_crds.to(gpu, non_blocking=True)  # (B, L, 27, 3)
    mask_crds = mask_crds.to(gpu, non_blocking=True)  # (B, L, 27)
    same_chain = same_chain.to(gpu, non_blocking=True)

    xyz_t = xyz_t.to(gpu, non_blocking=True)
    t1d = t1d.to(gpu, non_blocking=True)
    mask_t = mask_t.to(gpu, non_blocking=True)

    #fd --- use black hole initialization
    xyz_prev = INIT_CRDS.reshape(1, 1, NTOTAL, 3).repeat(1, L, 1, 1).to(gpu, non_blocking=True)
    mask_prev = torch.zeros((1, L, NTOTAL), dtype=torch.bool).to(gpu, non_blocking=True)

    atom_frames = atom_frames.to(gpu, non_blocking=True)
    bond_feats = bond_feats.to(gpu, non_blocking=True)
    dist_matrix = dist_matrix.to(gpu, non_blocking=True)
    chirals = chirals.to(gpu, non_blocking=True)
    assert (len(symmgp) == 1)
    symmgp = symmgp[0]

    # symmetry - reprocess (many) inputs
    if (symmgp != 'C1'):
        Lasu = L // 2  # msa contains intra/inter block
        symmids, symmRs, symmmeta, symmoffset = symm_subunit_matrix(symmgp)
        symmids = symmids.to(gpu, non_blocking=True)
        symmRs = symmRs.to(gpu, non_blocking=True)
        symmoffset = symmoffset.to(gpu, non_blocking=True)
        symmmeta = (
            [x.to(gpu, non_blocking=True) for x in symmmeta[0]],
            symmmeta[1])
        xyz_prev = xyz_prev + symmoffset * Lasu ** (1 / 3)

        # find contacting subunits
        xyz_prev, symmsub = find_symm_subs(xyz_prev[:, :Lasu], symmRs, symmmeta)
        symmsub = symmsub.to(gpu, non_blocking=True)
        Osub = symmsub.shape[0]
        mask_prev = mask_prev[:, :L].repeat(1, Osub, 1)

        # symmetrize msa: first half once, second half repeated for the other subunits
        seq = torch.cat([seq[:, :, :Lasu], *[seq[:, :, Lasu:]] * (Osub - 1)], dim=2)
        msa = torch.cat([msa[:, :, :, :Lasu], *[msa[:, :, :, Lasu:]] * (Osub - 1)], dim=3)
        msa_masked = torch.cat([msa_masked[:, :, :, :Lasu], *[msa_masked[:, :, :, Lasu:]] * (Osub - 1)], dim=3)
        msa_full = torch.cat([msa_full[:, :, :, :Lasu], *[msa_full[:, :, :, Lasu:]] * (Osub - 1)], dim=3)
        mask_msa = torch.cat([mask_msa[:, :, :, :Lasu], *[mask_msa[:, :, :, Lasu:]] * (Osub - 1)], dim=3)

        # symmetrize templates
        xyz_t = xyz_t[:, :, :Lasu].repeat(1, 1, Osub, 1, 1)
        mask_t = mask_t[:, :, :Lasu].repeat(1, 1, Osub, 1)
        t1d = t1d[:, :, :Lasu].repeat(1, 1, Osub, 1)

        # symmetrize atom_frames
        atom_frames = torch.cat([atom_frames[:, :, :Lasu], *[atom_frames[:, :, Lasu:]] * (Osub - 1)], dim=2)

        # index, same chain, bond feats
        idx_pdb = torch.arange(Osub * Lasu, device=gpu)[None, :]
        same_chain = torch.zeros((1, Osub * Lasu, Osub * Lasu), device=gpu).long()
        bond_feats_new = torch.zeros((1, Osub * Lasu, Osub * Lasu), device=gpu).long()
        dist_matrix_new = torch.zeros((1, Osub * Lasu, Osub * Lasu), device=gpu).long()
        for o_i in range(Osub):
            lo, hi = o_i * Lasu, (o_i + 1) * Lasu
            same_chain[:, lo:hi, lo:hi] = 1
            idx_pdb[:, lo:hi] += 100 * o_i
            bond_feats_new[:, lo:hi, lo:hi] = bond_feats
            dist_matrix_new[:, lo:hi, lo:hi] = dist_matrix

        bond_feats = bond_feats_new
        dist_matrix = dist_matrix_new

    else:
        Lasu = L
        Osub = 1
        symmids = None
        symmsub = None
        symmRs = None
        symmmeta = None

    # processing template features
    mask_t_2d = get_prot_sm_mask(mask_t, seq[0][0])
    mask_t_2d = mask_t_2d[:, :, None] * mask_t_2d[:, :, :, None]  # (B, T, L, L)

    # we can provide sm_templates so we want to allow interchain templates bw protein chain 1 and sms
    # specifically the templates are found for the query protein chain
    Ls = Ls_from_same_chain_2d(same_chain)
    prot_ch1_to_sm_2d = torch.zeros_like(same_chain)
    prot_ch1_to_sm_2d[:, :Ls[0], is_atom(seq)[0][0]] = 1
    prot_ch1_to_sm_2d[:, is_atom(seq)[0][0], :Ls[0]] = 1

    is_possible_t2d = same_chain.clone()
    is_possible_t2d[prot_ch1_to_sm_2d.bool()] = 1

    mask_t_2d = mask_t_2d.float() * is_possible_t2d.float()[:, None]  # (ignore inter-chain region between proteins)
    xyz_t_frame = xyz_t_to_frame_xyz(xyz_t, msa[:, 0, 0], atom_frames)
    t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d)

    # get torsion angles from templates
    Ltot = Lasu * Osub
    seq_tmp = t1d[..., :-1].argmax(dim=-1).reshape(-1, Ltot)
    alpha, _, alpha_mask, _ = xyz_converter.get_torsions(
        xyz_t.reshape(-1, Ltot, NTOTAL, 3), seq_tmp, mask_in=mask_t.reshape(-1, Ltot, NTOTAL)
    )
    alpha = alpha.reshape(B, -1, Ltot, NTOTALDOFS, 2)
    alpha_mask = alpha_mask.reshape(B, -1, Ltot, NTOTALDOFS, 1)
    alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, Ltot, 3 * NTOTALDOFS)
    # Allocate on the target device, consistent with the other tensors created in this function
    alpha_prev = torch.zeros((B, Ltot, NTOTALDOFS, 2), device=gpu)

    network_input = {}
    network_input['msa_latent'] = msa_masked
    network_input['msa_full'] = msa_full
    network_input['seq'] = seq
    network_input['seq_unmasked'] = msa[:, 0, 0]
    network_input['idx'] = idx_pdb
    network_input['t1d'] = t1d
    network_input['t2d'] = t2d
    network_input['xyz_t'] = xyz_t[:, :, :, 1]
    network_input['alpha_t'] = alpha_t
    network_input['mask_t'] = mask_t_2d
    network_input['same_chain'] = same_chain
    network_input['bond_feats'] = bond_feats
    network_input['dist_matrix'] = dist_matrix

    network_input['chirals'] = chirals
    network_input['atom_frames'] = atom_frames

    network_input['symmids'] = symmids
    network_input['symmsub'] = symmsub
    network_input['symmRs'] = symmRs
    network_input['symmmeta'] = symmmeta

    network_input["xyz_prev"] = xyz_prev
    network_input["alpha_prev"] = alpha_prev
    network_input["mask_recycle"] = None
    return task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
|
||||
|
||||
|
||||
def get_loss_calc_items(inputs, device="cpu"):
    """
    Pick the fields of a data-loader item that are returned for loss calculation.

    Parameters:
        inputs: The 24-field item unpacked below (same layout `prepare_input` unpacks).
        device: Device passed to `.to(...)` for each returned field.

    Returns:
        tuple: (seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames),
        each moved to `device`.

    Raises:
        ValueError: If `inputs` does not unpack into exactly 24 fields.
    """
    (
        seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
        xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
        atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
    ) = inputs

    selected = (seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames)
    return tuple(field.to(device) for field in selected)
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import torch
|
||||
import networkx as nx
|
||||
from typing import Dict, Optional, Tuple, List, Set, Any
|
||||
from rf2aa import cifutils
|
||||
import rf2aa.cifutils as cifutils
|
||||
from rf2aa.util import get_ligand_atoms_bonds, cif_ligand_to_xyz, cif_ligand_to_obmol, get_automorphs
|
||||
|
||||
|
||||
@@ -479,18 +479,19 @@ def parse_a3m(filename, maxseq=8000, paired=False):
|
||||
|
||||
# read and extract xyz coords of N,Ca,C atoms
|
||||
# from a PDB file
|
||||
def parse_pdb(filename, seq=False):
|
||||
def parse_pdb(filename, seq=False, lddt_mask=False):
|
||||
lines = open(filename,'r').readlines()
|
||||
if seq:
|
||||
return parse_pdb_lines_w_seq(lines)
|
||||
return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
|
||||
return parse_pdb_lines(lines)
|
||||
|
||||
def parse_pdb_lines_w_seq(lines):
|
||||
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
|
||||
|
||||
# indices of residues observed in the structure
|
||||
res = [(l[21:22].strip(), l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] # (chain letter, res num, aa)
|
||||
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
|
||||
pdb_idx_s = [(r[0], int(r[1])) for r in res]
|
||||
idx_s = [int(r[1]) for r in res]
|
||||
plddt = [float(r[3]) for r in res]
|
||||
seq = [aa2num[r[2]] if r[2] in aa2num.keys() else 20 for r in res]
|
||||
|
||||
# 4 BB + up to 10 SC atoms
|
||||
@@ -525,6 +526,12 @@ def parse_pdb_lines_w_seq(lines):
|
||||
# save atom mask
|
||||
mask = np.logical_not(np.isnan(xyz[...,0]))
|
||||
xyz[np.isnan(xyz[...,0])] = 0.0
|
||||
if lddt_mask == True:
|
||||
plddt = np.array(plddt)
|
||||
mask_lddt = np.full_like(mask, False)
|
||||
mask_lddt[plddt > .85, 5:] = True
|
||||
mask_lddt[plddt > .70, :5] = True
|
||||
mask = np.logical_and(mask, mask_lddt)
|
||||
|
||||
return xyz,mask,np.array(idx_s), np.array(seq)
|
||||
|
||||
@@ -532,7 +539,7 @@ def parse_pdb_lines_w_seq(lines):
|
||||
def parse_pdb_lines(lines):
|
||||
|
||||
# indices of residues observed in the structure
|
||||
res = [(l[21:22].strip(), l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] # (chain letter, res num, aa)
|
||||
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
|
||||
pdb_idx_s = [(r[0], int(r[1])) for r in res]
|
||||
idx_s = [int(r[1]) for r in res]
|
||||
|
||||
18
rf2aa/debug.py
Normal file
18
rf2aa/debug.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def debug_nans(latent_feats):
    """
    Print, for every tensor value in `latent_feats`, its key and its NaN count.

    Parameters:
        latent_feats (dict): Mapping of names to values; non-tensor values are skipped.
    """
    for name, value in latent_feats.items():
        if not torch.is_tensor(value):
            continue
        print(name)
        print(torch.sum(value.isnan()))
|
||||
|
||||
def debug_unused_params(model):
    """Print the name of every parameter of `model` whose `.grad` is None."""
    for name, param in model.named_parameters():
        if param.grad is None:
            print(name)
|
||||
|
||||
def debug_used_params(model):
    """Print the name of every parameter of `model` whose `.grad` is not None."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name)
|
||||
125
rf2aa/debug_item.py
Normal file
125
rf2aa/debug_item.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import unittest
|
||||
from hydra import compose, initialize
|
||||
import torch
|
||||
from rf2aa.chemical import NBTYPES, NTOTAL
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_single_item_dataset, set_data_loader_params
|
||||
from rf2aa.data.data_loader import loader_atomize_pdb
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input
|
||||
from rf2aa.util import is_atom, writepdb
|
||||
from rf2aa.tensor_util import assert_shape
|
||||
from rf2aa.trainer_new import trainer_factory
|
||||
from rf2aa.training.recycling import recycle_step_legacy
|
||||
|
||||
|
||||
|
||||
#### Setup test case hyperparams
|
||||
|
||||
# Single dataset row used by every test in this module.
ITEM = {
    'Unnamed: 0': 262672,
    'CHAINID': '6ywe_UB',
    'DEPOSITION': '2020-04-29',
    'RESOLUTION': 2.9900,
    'HASH': '072380',
    'CLUSTER': 9905,
    'SEQUENCE': 'MPNKPIRLPPLKQLRVRQANKAEENPCIAVMSSVLACWASAGYNSAGCATVENALRACMDAPKPAPKPNNTINYHLSRFQERLTQGKSKK',
    'LEN_EXIST': 88,
    'TAXID': '5141',
}

# Name of the hydra config composed in setUp
CONFIG = "legacy_train"
# Loader function and keyword arguments handed to compose_single_item_dataset
LOADER_FN = loader_atomize_pdb
LOADER_KWARGS = {
    "homo": None,
    "n_res_atomize": 5,
    "flank": 0,
}
|
||||
|
||||
class DebugTestCase(unittest.TestCase):
    """Debugging harness that loads the single example `ITEM` and runs it through a trainer on CPU."""

    # Checkpoint used by the checkpoint-based tests (machine-specific path)
    CHECKPOINT_PATH = "/home/rohith/rf2a-fd3/models/rf2a_fd3_20221125_714.pt"

    def setUp(self) -> None:
        """Compose the hydra config and build a loader that yields only `ITEM`."""
        with initialize(version_base=None, config_path="config/train"):
            self.cfg = compose(config_name=CONFIG)
        loader_params = set_data_loader_params(self.cfg.loader_params)
        self.loader = compose_single_item_dataset(
            ITEM,
            loader_params,
            LOADER_FN,
            LOADER_KWARGS
        )

    def _build_cpu_trainer(self, load_checkpoint=False):
        """Construct the configured trainer on CPU, optionally loading CHECKPOINT_PATH weights."""
        trainer = trainer_factory[self.cfg.experiment.trainer](self.cfg)
        trainer.construct_model()
        trainer.model.device = "cpu"
        trainer.move_constants_to_device(gpu="cpu")
        if load_checkpoint:
            trainer.checkpoint = torch.load(self.CHECKPOINT_PATH, map_location="cpu")
            trainer.model.model.load_state_dict(trainer.checkpoint["final_state_dict"])
            trainer.model.shadow.load_state_dict(trainer.checkpoint["model_state_dict"])
        return trainer

    def test_correct_shapes(self):
        """ test shapes are all consistent with each other """
        for inputs in self.loader:
            (
                seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
                xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
                atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
            ) = inputs
            B, recycles, N, L = msa.shape[:4]
            num_atoms = (is_atom(seq[0, 0]).sum()).item()
            assert_shape(seq, (B, recycles, L))
            assert_shape(msa, (B, recycles, N, L))
            assert_shape(msa_masked, (B, recycles, N, L, 164))  #Hack: hardcoded for current featurization
            N_full = msa_full.shape[2]
            assert_shape(msa_full, (B, recycles, N_full, L, 83))  #HACK:: hardcoded for current features
            assert_shape(mask_msa, (B, recycles, N, L))
            N_symm = true_crds.shape[1]
            assert_shape(true_crds, (B, N_symm, L, NTOTAL, 3))
            assert_shape(mask_crds, (B, N_symm, L, NTOTAL))
            assert_shape(idx_pdb, (B, L))
            N_templ = xyz_t.shape[1]
            assert_shape(xyz_t, (B, N_templ, L, NTOTAL, 3))
            assert_shape(t1d, (B, N_templ, L, 80))  # hack hard coded dimension
            assert_shape(mask_t, (B, N_templ, L, NTOTAL))
            assert_shape(xyz_prev, (B, L, NTOTAL, 3))
            assert_shape(mask_prev, (B, L, NTOTAL))
            assert_shape(same_chain, (B, L, L))
            assert isinstance(unclamp.item(), bool)
            assert isinstance(negative.item(), bool)
            assert_shape(atom_frames, (B, num_atoms, 3, 2))
            assert_shape(bond_feats, (B, L, L))
            assert_shape(dist_matrix, (B, L, L))
            n_chirals = chirals.shape[1]
            assert_shape(chirals, (B, n_chirals, 5))
            assert_shape(ch_label, (B, L))
            assert symmgp[0] == "C1", f"{symmgp}"

    def test_forward_pass(self):
        """Run one train step per loader item with freshly constructed weights."""
        trainer = self._build_cpu_trainer()
        for inputs in self.loader:
            loss, loss_dict = trainer.train_step(inputs, 1)

    def test_forward_pass_with_checkpoint(self):
        """Run one train step per loader item after loading checkpoint weights."""
        trainer = self._build_cpu_trainer(load_checkpoint=True)
        for inputs in self.loader:
            loss, loss_dict = trainer.train_step(inputs, 1)
            #TODO: check something about the loss

    def test_forward_pass_outputs(self):
        """Run a single recycle step with checkpoint weights and write the result to test.pdb."""
        trainer = self._build_cpu_trainer(load_checkpoint=True)
        for inputs in self.loader:
            gpu = trainer.model.device
            # HACK: certain features are constructed during the train step
            # in the future this should only promote the constructed features onto gpu
            task, item, network_input, true_crds, \
                atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
                = prepare_input(inputs, trainer.xyz_converter, gpu)
            n_cycle = 1
            output_i = recycle_step_legacy(trainer.model, network_input, n_cycle, trainer.config.training_params.use_amp)
            c6d, mlm, pae, pde, p_bind, xyz, alphas, _, _, _, _, _ = output_i
            seq_unmasked = network_input["seq_unmasked"]
            writepdb("test.pdb", xyz[-1], seq_unmasked)
|
||||
|
||||
|
||||
# Run the debug test case when this file is executed directly
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -1,40 +0,0 @@
|
||||
from rf2aa.train_multi_EMA import Trainer
|
||||
from rf2aa.evaluate import Evaluator
|
||||
|
||||
|
||||
def trainer_factory(args, dataset_param, model_param, loader_param, loss_param):
    """
    Build a `Trainer`, or an `Evaluator` when `args.mode == "eval"`, from parsed arguments.

    Parameters:
        args: Namespace-like object holding the attributes read below.
        dataset_param, model_param, loader_param, loss_param: Passed through
            unchanged to the constructor.

    Returns:
        The constructed trainer/evaluator object.

    Side effects:
        In eval mode `args.eval` is set to True on the caller's object.
    """
    dataloader_kwargs = {
        "shuffle": args.shuffle_dataloader,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": not args.dont_pin_memory,
    }
    trainer_class = Trainer
    if args.mode == "eval":
        trainer_class = Evaluator
        args.eval = True

    trainer_object = trainer_class(
        model_name=args.model_name,
        n_epoch=args.num_epochs,
        step_lr=args.step_lr,
        lr=args.lr,
        l2_coeff=1.0e-2,
        port=args.port,
        model_param=model_param,
        loader_param=loader_param,
        loss_param=loss_param,
        batch_size=args.batch_size,
        accum_step=args.accum,
        maxcycle=args.maxcycle,
        eval=args.eval,
        interactive=args.interactive,
        out_dir=args.out_dir,
        wandb_prefix=args.wandb_prefix,
        model_dir=args.model_dir,
        dataset_param=dataset_param,
        dataloader_kwargs=dataloader_kwargs,
        debug_mode=args.debug,
        skip_valid=args.skip_valid,
        start_epoch=args.start_epoch,
    )
    return trainer_object
|
||||
@@ -1,248 +0,0 @@
|
||||
import sys, os, json
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from loss_halluc import calc_entropy_loss, calc_pae_loss
|
||||
from optimization import run_gradient_descent, run_mcmc
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0,script_dir+'/models/fold_and_dock3/')
|
||||
import parsers
|
||||
from RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from data_loader import merge_a3m_hetero
|
||||
import util
|
||||
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
|
||||
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
|
||||
from model_params import MODEL_PARAM
|
||||
|
||||
def get_args():
    """
    Parse the command-line options of the hallucination script.

    Returns:
        argparse.Namespace: Parsed options. `-grad_steps` and `-mcmc_steps` are required.
    """
    import argparse
    parser = argparse.ArgumentParser(description="RF2-allatom hallucination: protein scaffold design with explicit modeling of small-molecule ligands")
    parser.add_argument("-checkpoint",
                        default="/databases/TrRosetta/RF2_allatom/checkpoints/rf2a_fd3_20221125_115.pt",
                        help="Path to model weights")
    parser.add_argument("-pdb", help='PDB of motif')
    parser.add_argument("-parse_hetatm", action="store_true", default=False, help="parse ligand information from input pdb")
    parser.add_argument("-mol2", help='mol2 of small molecule')
    parser.add_argument("-num", type=int, default=1, help='number of designs')
    parser.add_argument("-start_num", type=int, default=0, help='start number of output designs')
    parser.add_argument("-grad_steps", type=int, required=True, help='number of gradient descent steps')
    parser.add_argument("-mcmc_steps", type=int, required=True, help='number of mcmc steps')
    parser.add_argument("-out", help='prefix of output files')
    parser.add_argument("-L", type=int, help='length of hallucinated protein')
    parser.add_argument("-T0", type=float, default=0.02, help='initial temperature for simulated annealing')
    # type=float so a value given on the command line is numeric like the default (previously it stayed a str)
    parser.add_argument("-mcmc_halflife", type=float, default=100, help='half-life of simulated annealing')
    parser.add_argument("-cycles", type=str, default='10', help='number of recycles')
    parser.add_argument("-seq_prob_type", default='hard', help='soft or hard probabilities of sequence')
    parser.add_argument("-init_sd", type=float, default=1e-6, help='random initial logit standard deviation')
    parser.add_argument("-template_ligand", action='store_true', default=False, help='input template features for ligand structure')
    parser.add_argument("-init_ligand_xyz", action='store_true', default=False, help='input initial xyz coords for ligand structure')
    parser.add_argument("-learning_rate", type=float, default=0.05, help='gradient descent learning rate')
    parser.add_argument("-device", type=str, default='cuda:0', help='gpu to run on')
    parser.add_argument("-w_ent", type=float, default=1.0, help='weight of entropy loss')
    parser.add_argument("-w_pae", type=float, default=1.0, help='weight of pae loss')
    parser.add_argument("-w_ipae", type=float, default=1.0, help='weight of inter-pae loss')
    args = parser.parse_args()

    return args
|
||||
|
||||
# compute expected value from binned lddt
|
||||
def lddt_unbin(pred_lddt):
    """
    Compute the expected value from binned lddt logits.

    Parameters:
        pred_lddt (Tensor): Logits with the bin axis at dim 1; the broadcast below
            expects three dimensions.

    Returns:
        Tensor: Softmax-weighted mean of the bin values, with dim 1 reduced away.
    """
    nbin = pred_lddt.shape[1]
    bin_step = 1.0 / nbin
    # Bin values run from bin_step to 1.0 inclusive
    lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)

    probs = torch.softmax(pred_lddt, dim=1)
    return torch.sum(lddt_bins[None, :, None] * probs, dim=1)
|
||||
|
||||
def load_model(args, MODEL_PARAM):
    """
    Instantiate RoseTTAFoldModule on `args.device` and load weights from `args.checkpoint`.

    Parameters:
        args: Object providing `device` and `checkpoint`.
        MODEL_PARAM (dict): Keyword arguments expanded into the model constructor.

    Returns:
        The model with the checkpoint's 'model_state_dict' loaded.
    """
    device = args.device
    model = RoseTTAFoldModule(
        **MODEL_PARAM,
        aamask=util.allatom_mask.to(device),
        atom_type_index=util.atom_type_index.to(device),
        ljlk_parameters=util.ljlk_parameters.to(device),
        lj_correction_parameters=util.lj_correction_parameters.to(device),
        num_bonds=util.num_bonds.to(device),
        cb_len=util.cb_length_t.to(device),
        cb_ang=util.cb_angle_t.to(device),
        cb_tor=util.cb_torsion_t.to(device),
    ).to(device)

    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
|
||||
|
||||
|
||||
def prepare_inputs(args, init_seq=None, random_noise=5):
    """
    Build the model input tensors for hallucinating a protein around a small molecule.

    Parameters:
        args: Parsed options; reads `L`, `device`, `mol2`, `pdb`, `template_ligand`
            and `init_ligand_xyz`.
        init_seq (Tensor | None): Initial protein sequence used as the single MSA row;
            a random sequence is drawn when None.
        random_noise (float): Width of the uniform noise added to the initial coordinates.

    Returns:
        dict: Ls, msa, ins, xyz_prev, idx_pdb, bond_feats, chirals, atom_frames, t1d,
        t2d, xyz_t, alpha_t, mask_t, mask_t_2d, same_chain, mask_recycle.

    Raises:
        ValueError: If `args.mol2` is not given. Every downstream step uses the
            ligand-derived quantities, so the previous code failed with NameError.
        NotImplementedError: If `args.pdb` is given. The motif branch referenced
            variables that are never defined in this function.

    Notes:
        The ground-truth `xyz` / `mask` / `atom_mask` tensors the earlier version
        assembled were never used or returned, so they are no longer built.
    """
    if args.mol2 is None:
        raise ValueError("prepare_inputs requires -mol2: all inputs are built around the ligand")
    if args.pdb is not None:
        raise NotImplementedError("-pdb motif input is not supported by prepare_inputs")

    protein_L = args.L
    device = args.device

    if init_seq is None:
        msa_prot = torch.randint(20, (1, protein_L))
    else:
        msa_prot = init_seq
    ins_prot = torch.zeros((1, protein_L)).long()
    idx_prot = torch.arange(protein_L)

    # ligand features and merged protein + ligand MSA
    a3m_prot = {"msa": msa_prot, "ins": ins_prot}
    mol, msa_sm, ins_sm, xyz_sm, mask_sm = parsers.parse_mol(args.mol2)
    a3m_sm = {"msa": msa_sm.unsqueeze(0), "ins": ins_sm.unsqueeze(0)}
    G = util.get_nxgraph(mol)
    atom_frames = util.get_atom_frames(msa_sm, G)
    _, sm_L, _ = xyz_sm.shape

    Ls = [protein_L, sm_L]
    Ltot = sum(Ls)
    a3m = merge_a3m_hetero(a3m_prot, a3m_sm, Ls)
    msa = a3m['msa'].long()
    ins = a3m['ins'].long()
    chirals = get_chirals(mol, xyz_sm[0])

    idx_sm = torch.arange(max(idx_prot), max(idx_prot) + Ls[1]) + 200
    idx_pdb = torch.concat([idx_prot, idx_sm])

    chain_idx = torch.zeros((Ltot, Ltot)).long()
    chain_idx[:Ls[0], :Ls[0]] = 1
    chain_idx[Ls[0]:, Ls[0]:] = 1
    bond_feats = torch.zeros((Ltot, Ltot)).long()
    bond_feats[:Ls[0], :Ls[0]] = util.get_protein_bond_feats(Ls[0])
    bond_feats[Ls[0]:, Ls[0]:] = util.get_bond_feats(mol)

    # blank template
    xyz_t = INIT_CRDS.reshape(1, 1, NTOTAL, 3).repeat(1, Ltot, 1, 1) \
        + torch.rand(1, Ltot, 1, 3) * random_noise - random_noise / 2
    f1d_t = torch.nn.functional.one_hot(torch.full((1, Ltot), 20).long(), num_classes=NAATOKENS - 1).float()  # all gaps
    conf = torch.zeros((1, Ltot, 1)).float()
    f1d_t = torch.cat((f1d_t, conf), -1)
    mask_t = torch.full((1, Ltot, NTOTAL), False)

    if args.template_ligand:  # input true s.m. xyz as template
        xyz_t[0, Ls[0]:, 1] = xyz_sm[0] - xyz_sm[0].mean(-2)  # centroid at origin
        f1d_t[0, Ls[0]:] = torch.cat((
            torch.nn.functional.one_hot(msa[0, Ls[0]:] - 1, num_classes=NAATOKENS - 1).float(),
            torch.ones((Ls[1], 1)).float()
        ), -1)  # (1, L_sm, NAATOKENS)
        mask_t[0, Ls[0]:, 1] = mask_sm[0]  # all symmetry variants have same mask

    xyz_t = torch.nan_to_num(xyz_t)

    # black-hole coordinates
    init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(Ltot, 1, 1)
    xyz_prev = init + torch.rand(Ltot, 1, 3) * random_noise - random_noise / 2
    mask_prev = torch.full(xyz_prev.shape[:-1], False).bool()

    if args.init_ligand_xyz:
        xyz_prev[Ls[0]:, 1] = xyz_sm[0] - xyz_sm[0].mean(-2)  # centroid at origin
        mask_prev[Ls[0]:, 1] = mask_sm[0]

    # transfer inputs to device (adding a leading batch axis)
    atom_frames = atom_frames[None].to(device, non_blocking=True)
    idx_pdb = idx_pdb[None].to(device, non_blocking=True)  # (B, L)
    xyz_t = xyz_t[None].to(device, non_blocking=True)
    mask_t = mask_t[None].to(device, non_blocking=True)
    t1d = f1d_t[None].to(device, non_blocking=True)
    xyz_prev = xyz_prev[None].to(device, non_blocking=True)
    mask_prev = mask_prev[None].to(device, non_blocking=True)
    same_chain = chain_idx[None].to(device, non_blocking=True)
    bond_feats = bond_feats[None].to(device, non_blocking=True)
    chirals = chirals[None].to(device, non_blocking=True)

    # processing template features
    seq_unmasked = msa.clone()  # (B, L)
    mask_t_2d = util.get_prot_sm_mask(mask_t, seq_unmasked[0])  # (B, T, L)
    mask_t_2d = mask_t_2d[:, :, None] * mask_t_2d[:, :, :, None]  # (B, T, L, L)
    mask_t_2d = mask_t_2d.float() * same_chain.float()[:, None]  # (ignore inter-chain region)
    mask_recycle = util.get_prot_sm_mask(mask_prev, seq_unmasked[0])
    mask_recycle = mask_recycle[:, :, None] * mask_recycle[:, None, :]  # (B, L, L)
    mask_recycle = same_chain.float() * mask_recycle.float()

    xyz_t_frames = util.xyz_t_to_frame_xyz(xyz_t, seq_unmasked, atom_frames)
    t2d = xyz_to_t2d(xyz_t_frames, mask_t_2d)

    seq_tmp = t1d[..., :-1].argmax(dim=-1).reshape(-1, Ltot)

    alpha, _, alpha_mask, _ = util.get_torsions(
        xyz_t.reshape(-1, Ltot, NTOTAL, 3),
        seq_tmp,
        util.torsion_indices.to(device),
        util.torsion_can_flip.to(device),
        util.reference_angles.to(device)
    )
    # Torsions that came out NaN are masked out and zeroed
    alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[..., 0]))
    alpha[torch.isnan(alpha)] = 0.0
    alpha = alpha.reshape(1, -1, Ltot, NTOTALDOFS, 2)
    alpha_mask = alpha_mask.reshape(1, -1, Ltot, NTOTALDOFS, 1)
    alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, Ltot, 3 * NTOTALDOFS).to(device)

    return dict(
        Ls=Ls,
        msa=msa,
        ins=ins,
        xyz_prev=xyz_prev,
        idx_pdb=idx_pdb,
        bond_feats=bond_feats,
        chirals=chirals,
        atom_frames=atom_frames,
        t1d=t1d,
        t2d=t2d,
        xyz_t=xyz_t,
        alpha_t=alpha_t,
        mask_t=mask_t,
        mask_t_2d=mask_t_2d,
        same_chain=same_chain,
        mask_recycle=mask_recycle,
    )
|
||||
|
||||
def main():
    """Run the sequence-design pipeline end to end.

    Parses the command-line arguments, loads the network, prepares the model
    inputs and then, for every design index in
    ``[args.start_num, args.start_num + args.num)``, runs the enabled
    optimisation stages (gradient descent when ``args.grad_steps > 0``, then
    MCMC when ``args.mcmc_steps > 0``) and writes ``<args.out>_<index>.pdb``.

    Raises:
        ValueError: if designs are requested but both optimisation stages are
            disabled, because there would be no prediction to write out.
    """
    args = get_args()

    # Fail fast: with both stages disabled `out` is never assigned, and the
    # original code only crashed (UnboundLocalError) after loading the weights.
    if args.num > 0 and args.grad_steps <= 0 and args.mcmc_steps <= 0:
        raise ValueError(
            "Nothing to run: at least one of grad_steps or mcmc_steps must be > 0"
        )

    print('Loading network weights...')
    model = load_model(args, MODEL_PARAM)

    inputs = prepare_inputs(args)

    # Weighted loss terms handed to both optimisation stages.
    loss_funcs = [
        dict(w=args.w_ent, name='ent', func=calc_entropy_loss),
        dict(w=args.w_pae, name='pae', func=calc_pae_loss),
        dict(w=args.w_ipae, name='ipae', func=lambda out: calc_pae_loss(out, inter=True)),
    ]

    # The cycle counts do not depend on the design index, so parse them once.
    cycles = [int(i) for i in args.cycles.split(',')]
    grad_cycles = cycles[0]
    # MCMC uses the second entry when one is given, otherwise the only entry.
    mcmc_cycles = cycles[min(len(cycles) - 1, 1)]

    for i_des in range(args.start_num, args.start_num + args.num):
        if args.grad_steps > 0:
            out = run_gradient_descent(model, args, inputs, grad_cycles, loss_funcs)
            # The designed MSA is written back into `inputs`, so the next stage
            # (and the next design iteration) starts from it.
            inputs['msa'] = out['msa'][0]

        if args.mcmc_steps > 0:
            out = run_mcmc(model, args, inputs, mcmc_cycles, loss_funcs)
            inputs['msa'] = out['msa'][0]

        out_prefix = args.out + f'_{i_des}'
        best_lddt = lddt_unbin(out['pred_lddt_binned'])
        util.writepdb(
            out_prefix + ".pdb",
            out['pred_allatom'],
            out['msa'][0],
            bfacts=100 * best_lddt[0].float(),
            bond_feats=inputs['bond_feats'],
        )
|
||||
|
||||
|
||||
# Run the design pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import sys, os, json
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0,script_dir+'/models/fold_and_dock3/')
|
||||
import util
|
||||
|
||||
def get_c6d_dict(logits, grad=True, args=None):
    """Turn the four 6D-coordinate logit tensors into probabilities.

    Each logit tensor is cast to float32, optionally detached, and softmaxed
    over dim 1 (the bin dimension).

    Args:
        logits: sequence whose first four entries are the dist, omega, theta
            and phi logits, with the bin dimension at index 1.
        grad: when False the logits are detached from the autograd graph.
        args: unused; kept for signature compatibility.

    Returns:
        (dict_pred, probs): ``dict_pred`` maps 'p_dist', 'p_omega', 'p_theta'
        and 'p_phi' to the probabilities permuted to (0, 2, 3, 1), i.e. with
        the bin dimension last; ``probs`` is the list of un-permuted
        probabilities.
    """
    probs = []
    for logit in logits:
        logit = logit.float()
        if not grad:
            logit = logit.detach()
        probs.append(nn.functional.softmax(logit, dim=1))

    keys = ('p_dist', 'p_omega', 'p_theta', 'p_phi')
    dict_pred = {key: probs[i].permute([0, 2, 3, 1]) for i, key in enumerate(keys)}
    return dict_pred, probs
|
||||
|
||||
def calc_entropy_loss(out):
    """Mean off-diagonal entropy of the predicted 6D-coordinate distributions.

    For each of the dist/omega/theta/phi distributions the last bin is dropped
    and the rest renormalised. The per-pair entropy is then averaged over all
    off-diagonal (i != j) pairs, and the four averages are averaged again.

    Args:
        out: mapping with key 'logit_s' holding the four logit tensors that
            ``get_c6d_dict`` expects.

    Returns:
        Tensor with one loss value per batch element.
    """
    dict_pred, _ = get_c6d_dict(out['logit_s'], grad=True)
    eps = 1e-6

    # Mask that removes the i == j diagonal from the average.
    reference = dict_pred['p_dist']
    n_res = reference.shape[1]
    off_diag = 1 - torch.eye(n_res)[None].to(reference.device).float()  # (1, L, L)
    n_pairs = torch.sum(off_diag, dim=(1, 2)) + eps

    entropies = []
    for key in ('p_dist', 'p_omega', 'p_theta', 'p_phi'):
        # exclude last bin, then renormalize
        kept = dict_pred[key][..., :-1]
        p = kept / kept.sum(-1, keepdim=True)
        pair_entropy = -(p * torch.log(p + eps)).sum(dim=-1)
        entropies.append(torch.sum(off_diag * pair_entropy, dim=(1, 2)) / n_pairs)

    return torch.stack(entropies, dim=0).mean(dim=0)
|
||||
|
||||
|
||||
def pae_unbin(logits_pae, bin_step=0.5):
    """Convert binned PAE logits into an expected value per position pair.

    Bin ``k`` (0-based) is represented by its centre ``bin_step * (k + 0.5)``.
    The logits are softmaxed over dim 1 and the centres are averaged under
    that distribution.

    Args:
        logits_pae: 4-D tensor with the bin dimension at index 1.
        bin_step: width of one bin.

    Returns:
        Tensor with the bin dimension reduced away.
    """
    nbin = logits_pae.shape[1]
    half_step = bin_step * 0.5
    centers = torch.linspace(
        half_step,
        bin_step * nbin - half_step,
        nbin,
        dtype=logits_pae.dtype,
        device=logits_pae.device,
    )
    probs = torch.softmax(logits_pae, dim=1)
    return torch.sum(centers[None, :, None, None] * probs, dim=1)
|
||||
|
||||
def calc_pae_loss(out, inter=False):
    """Average the un-binned PAE, optionally only over mixed pairs.

    Args:
        out: mapping with 'logit_pae' (binned PAE logits) and, when ``inter``
            is True, 'msa' (tokens passed to ``util.is_atom``).
        inter: when True, average only over pairs where exactly one of the two
            positions is flagged by ``util.is_atom``; otherwise average over
            all pairs.

    Returns:
        Tensor with the last two dimensions reduced away. With ``inter=True``
        and no mixed pairs the denominator is zero, so the result is NaN.
    """
    # Un-bin once; the original recomputed this inside the `inter` branch.
    pae = pae_unbin(out['logit_pae'])

    if not inter:
        return pae.mean(dim=[-1, -2])

    # NOTE(review): assumes util.is_atom returns a boolean mask (it is negated
    # with `~`), taken here for the first batch element and first MSA row.
    sm_mask = util.is_atom(out['msa'])[0, 0]
    # True where exactly one of (i, j) is flagged by is_atom.
    inter_mask = sm_mask[None]*(~sm_mask[:,None]) + (~sm_mask[None])*sm_mask[:,None]
    return (pae*inter_mask).sum(dim=[-1,-2]) / inter_mask.sum(dim=[-1,-2])
|
||||
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from opt_einsum import contract as einsum
|
||||
from rf2aa.util_module import init_lecun_normal
|
||||
class FeedForwardLayer(nn.Module):
    """LayerNorm -> Linear -> ReLU -> Dropout -> Linear feed-forward block.

    The hidden width is ``d_model * r_ff``. No residual connection is added
    here; the caller adds the output to its own residual stream.
    """

    def __init__(self, d_model, r_ff, p_drop=0.1):
        super().__init__()
        d_inner = d_model * r_ff
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_inner, d_model)

        self.reset_parameter()

    def reset_parameter(self):
        """He-initialise the layer feeding the ReLU and zero the output layer."""
        # Layer right before the ReLU: Kaiming normal weights, zero bias.
        nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear1.bias)

        # Layer right before the residual connection: all zeros, so the block
        # initially outputs zeros.
        for param in (self.linear2.weight, self.linear2.bias):
            nn.init.zeros_(param)

    def forward(self, src):
        """Apply the block to ``src`` (last dimension ``d_model``)."""
        hidden = self.linear1(self.norm(src))
        hidden = self.dropout(F.relu_(hidden))
        return self.linear2(hidden)
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head dot-product attention with separate query and key/value inputs.

    ``p_drop`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
        super().__init__()
        self.h = n_head
        self.dim = d_hidden
        d_inner = n_head * d_hidden

        # Bias-free input projections.
        self.to_q = nn.Linear(d_query, d_inner, bias=False)
        self.to_k = nn.Linear(d_key, d_inner, bias=False)
        self.to_v = nn.Linear(d_key, d_inner, bias=False)

        # Output projection back to d_out.
        self.to_out = nn.Linear(d_inner, d_out)
        self.scaling = 1 / math.sqrt(d_hidden)

        self.reset_parameter()

    def reset_parameter(self):
        """Xavier-uniform q/k/v projections; zero-initialised output projection."""
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Zero output so the surrounding residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, query, key, value):
        """Attend from ``query`` (B, Q, d_query) over ``key``/``value`` (B, K, d_key)."""
        _, n_query = query.shape[:2]
        n_batch, n_key = key.shape[:2]

        q = self.to_q(query).reshape(n_batch, n_query, self.h, self.dim)
        k = self.to_k(key).reshape(n_batch, n_key, self.h, self.dim)
        v = self.to_v(value).reshape(n_batch, n_key, self.h, self.dim)

        # Scaled dot-product scores, normalised over the key axis.
        scores = einsum('bqhd,bkhd->bhqk', q * self.scaling, k)
        weights = scores.softmax(dim=-1)

        mixed = einsum('bhqk,bkhd->bqhd', weights, v)
        mixed = mixed.reshape(n_batch, n_query, self.h * self.dim)
        return self.to_out(mixed)
|
||||
|
||||
# MSA Attention (row/column) from AlphaFold architecture
|
||||
class SequenceWeight(nn.Module):
    """Per-position weights over MSA rows, scored against the first row."""

    def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
        super().__init__()
        self.h = n_head
        self.dim = d_hidden
        self.scale = 1.0 / math.sqrt(self.dim)

        d_inner = n_head * d_hidden
        self.to_query = nn.Linear(d_msa, d_inner)
        self.to_key = nn.Linear(d_msa, d_inner)
        self.dropout = nn.Dropout(p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """Xavier-uniform weights for both projections (biases keep their defaults)."""
        for proj in (self.to_query, self.to_key):
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, msa):
        """Return dropout-regularised row weights of shape (B, N, L, h, 1)."""
        n_batch, n_seq, n_res = msa.shape[:3]

        # The first MSA row provides the query.
        first_row = msa[:, 0]
        q = self.to_query(first_row).view(n_batch, 1, n_res, self.h, self.dim)
        k = self.to_key(msa).view(n_batch, n_seq, n_res, self.h, self.dim)

        scores = einsum('bqihd,bkihd->bkihq', q * self.scale, k)
        # Normalise over the sequence (row) axis.
        weights = scores.softmax(dim=1)
        return self.dropout(weights)
|
||||
|
||||
class MSARowAttentionWithBias(nn.Module):
    """Gated MSA row attention, tied across rows and biased by pair features."""

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        d_inner = n_head * d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)

        self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_inner, bias=False)
        self.to_v = nn.Linear(d_msa, d_inner, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, pair bias, gate and output layer."""
        # q/k/v projections: Xavier uniform.
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Pair-bias projection: LeCun normal.
        self.to_b = init_lecun_normal(self.to_b)

        # Gate: zero weights and unit bias, so it starts mostly open.
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa, pair):  # TODO: make this as tied-attention
        """Update ``msa`` (B, N, L, d_msa) using ``pair`` (B, L, L, d_pair) as bias."""
        n_batch, n_seq, n_res = msa.shape[:3]

        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        row_weight = self.seq_weight(msa)  # (B, N, L, h, 1)
        split = (n_batch, n_seq, n_res, self.h, self.dim)
        q = self.to_q(msa).reshape(split)
        k = self.to_k(msa).reshape(split)
        v = self.to_v(msa).reshape(split)
        pair_bias = self.to_b(pair)  # (B, L, L, h)
        gate = self.to_g(msa).sigmoid()

        # Weight each row's query, then sum scores over rows (index s).
        q = q * row_weight.expand(-1, -1, -1, -1, self.dim)
        k = k * self.scaling
        scores = einsum('bsqhd,bskhd->bqkh', q, k) + pair_bias
        weights = scores.softmax(dim=-2)

        mixed = einsum('bqkh,bskhd->bsqhd', weights, v).reshape(n_batch, n_seq, n_res, -1)
        return self.to_out(gate * mixed)
|
||||
|
||||
class MSAColAttention(nn.Module):
    """Gated MSA column attention: each position attends across the MSA rows."""

    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super().__init__()
        d_inner = n_head * d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)

        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_inner, bias=False)
        self.to_v = nn.Linear(d_msa, d_inner, bias=False)
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, gate and output layer."""
        # q/k/v projections: Xavier uniform.
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Gate: zero weights and unit bias, so it starts mostly open.
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        """Return the column-attention update for ``msa`` (B, N, L, d_msa)."""
        n_batch, n_seq, n_res = msa.shape[:3]

        msa = self.norm_msa(msa)

        split = (n_batch, n_seq, n_res, self.h, self.dim)
        q = self.to_q(msa).reshape(split)
        k = self.to_k(msa).reshape(split)
        v = self.to_v(msa).reshape(split)
        gate = self.to_g(msa).sigmoid()

        # Scores between rows q and k at each position i, per head.
        scores = einsum('bqihd,bkihd->bihqk', q * self.scaling, k)
        weights = scores.softmax(dim=-1)

        mixed = einsum('bihqk,bkihd->bqihd', weights, v).reshape(n_batch, n_seq, n_res, -1)
        return self.to_out(gate * mixed)
|
||||
|
||||
class MSAColGlobalAttention(nn.Module):
    """Gated global column attention.

    The query is averaged over the MSA rows, and keys/values use a single
    ``d_hidden``-wide projection shared by all heads.
    """

    def __init__(self, d_msa=64, n_head=8, d_hidden=8):
        super().__init__()
        d_inner = n_head * d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)

        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, gate and output layer."""
        # q/k/v projections: Xavier uniform.
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Gate: zero weights and unit bias, so it starts mostly open.
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        """Return the global column-attention update for ``msa`` (B, N, L, d_msa)."""
        n_batch, n_seq, n_res = msa.shape[:3]

        msa = self.norm_msa(msa)

        q = self.to_q(msa).reshape(n_batch, n_seq, n_res, self.h, self.dim)
        q = q.mean(dim=1)                # (B, L, h, dim): one query per position
        k = self.to_k(msa)               # (B, N, L, dim)
        v = self.to_v(msa)               # (B, N, L, dim)
        gate = self.to_g(msa).sigmoid()  # (B, N, L, h*dim)

        scores = einsum('bihd,bkid->bihk', q * self.scaling, k)  # (B, L, h, N)
        weights = scores.softmax(dim=-1)

        pooled = einsum('bihk,bkid->bihd', weights, v).reshape(n_batch, 1, n_res, -1)  # (B, 1, L, h*dim)
        # The pooled value is broadcast over rows by the per-row gate.
        return self.to_out(gate * pooled)
|
||||
|
||||
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
|
||||
class TriangleAttention(nn.Module):
    """Gated triangle attention over pair features.

    ``start_node`` selects which index the key/value share with the query
    (first index when True, second when False). ``p_drop`` is accepted for
    signature compatibility but is not used.
    """

    def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
        super().__init__()
        d_inner = n_head * d_hidden
        self.norm = nn.LayerNorm(d_pair)
        self.to_q = nn.Linear(d_pair, d_inner, bias=False)
        self.to_k = nn.Linear(d_pair, d_inner, bias=False)
        self.to_v = nn.Linear(d_pair, d_inner, bias=False)

        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, d_inner)

        self.to_out = nn.Linear(d_inner, d_pair)

        self.scaling = 1 / math.sqrt(d_hidden)

        self.h = n_head
        self.dim = d_hidden
        self.start_node = start_node

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, bias, gate and output layer."""
        # q/k/v projections: Xavier uniform.
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Bias projection: LeCun normal.
        self.to_b = init_lecun_normal(self.to_b)

        # Gate: zero weights and unit bias, so it starts mostly open.
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, pair):
        """Return the triangle-attention update for ``pair`` (B, L, L, d_pair)."""
        n_batch, n_res = pair.shape[:2]

        pair = self.norm(pair)

        # Input projections, split into heads.
        q = self.to_q(pair).reshape(n_batch, n_res, n_res, self.h, -1)
        k = self.to_k(pair).reshape(n_batch, n_res, n_res, self.h, -1)
        v = self.to_v(pair).reshape(n_batch, n_res, n_res, self.h, -1)
        tri_bias = self.to_b(pair)        # (B, L, L, h)
        gate = self.to_g(pair).sigmoid()  # (B, L, L, h*dim)

        # Pick the contraction pair once, depending on the attention direction.
        if self.start_node:
            score_eq, mix_eq = 'bijhd,bikhd->bijkh', 'bijkh,bikhd->bijhd'
        else:
            score_eq, mix_eq = 'bijhd,bkjhd->bijkh', 'bijkh,bkjhd->bijhd'

        scores = einsum(score_eq, q * self.scaling, k)
        scores = scores + tri_bias.unsqueeze(1).expand(-1, n_res, -1, -1, -1)  # (bijkh)
        weights = scores.softmax(dim=-2)

        mixed = einsum(mix_eq, weights, v).reshape(n_batch, n_res, n_res, -1)

        # Gated attention output, projected back to d_pair.
        return self.to_out(gate * mixed)
|
||||
|
||||
class TriangleMultiplication(nn.Module):
    """Gated triangle multiplicative update over pair features.

    ``outgoing`` selects which index the left and right projections are
    contracted over.
    """

    def __init__(self, d_pair, d_hidden=128, outgoing=True):
        super().__init__()
        self.norm = nn.LayerNorm(d_pair)
        self.left_proj = nn.Linear(d_pair, d_hidden)
        self.right_proj = nn.Linear(d_pair, d_hidden)
        self.left_gate = nn.Linear(d_pair, d_hidden)
        self.right_gate = nn.Linear(d_pair, d_hidden)

        self.gate = nn.Linear(d_pair, d_pair)
        self.norm_out = nn.LayerNorm(d_hidden)
        self.out_proj = nn.Linear(d_hidden, d_pair)

        self.outgoing = outgoing

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, gates and output layer."""
        # Regular projections: LeCun normal weights, zero biases.
        self.left_proj = init_lecun_normal(self.left_proj)
        self.right_proj = init_lecun_normal(self.right_proj)
        for proj in (self.left_proj, self.right_proj):
            nn.init.zeros_(proj.bias)

        # Gates: zero weights and unit biases, so they start mostly open.
        for gate in (self.left_gate, self.right_gate, self.gate):
            nn.init.zeros_(gate.weight)
            nn.init.ones_(gate.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, pair):
        """Return the multiplicative update for ``pair`` (B, L, L, d_pair)."""
        _, n_res = pair.shape[:2]
        pair = self.norm(pair)

        # Gated left/right projections, each (B, L, L, d_h).
        left = self.left_gate(pair).sigmoid() * self.left_proj(pair)
        right = self.right_gate(pair).sigmoid() * self.right_proj(pair)

        # Contract over the third index k; right is pre-divided by L.
        equation = 'bikd,bjkd->bijd' if self.outgoing else 'bkid,bkjd->bijd'
        update = einsum(equation, left, right / float(n_res))
        update = self.out_proj(self.norm_out(update))

        out_gate = self.gate(pair).sigmoid()  # (B, L, L, d_pair)
        return out_gate * update
|
||||
|
||||
# Instead of triangle attention, use tied axial attention with bias from coordinates..?
|
||||
class BiasedAxialAttention(nn.Module):
    """Gated, tied axial attention over pair features with an external bias.

    When ``is_row`` is True the first two spatial axes of the inputs are
    swapped before attention and the output is swapped back. ``p_drop`` is
    accepted for signature compatibility but is not used.
    """

    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super().__init__()
        d_inner = n_head * d_hidden

        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, d_inner, bias=False)
        self.to_k = nn.Linear(d_pair, d_inner, bias=False)
        self.to_v = nn.Linear(d_pair, d_inner, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, d_inner)
        self.to_out = nn.Linear(d_inner, d_pair)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise projections, bias, gate and output layer."""
        # q/k/v projections: Xavier uniform.
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # Bias projection: LeCun normal.
        self.to_b = init_lecun_normal(self.to_b)

        # Gate: zero weights and unit bias, so it starts mostly open.
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # Output: zeros, so the residual update starts as the identity.
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, pair, bias):
        """Return the attention update for ``pair`` (B, L, L, d_pair) given ``bias`` (B, L, L, d_bias)."""
        n_batch, n_res = pair.shape[:2]

        if self.is_row:
            pair = pair.permute(0, 2, 1, 3)
            bias = bias.permute(0, 2, 1, 3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        split = (n_batch, n_res, n_res, self.h, self.dim)
        q = self.to_q(pair).reshape(split)
        k = self.to_k(pair).reshape(split)
        v = self.to_v(pair).reshape(split)
        head_bias = self.to_b(bias)       # (B, L, L, h)
        gate = self.to_g(pair).sigmoid()  # (B, L, L, h*dim)

        q = q * self.scaling
        k = k / n_res  # normalize for tied attention
        # Tied attention: scores are summed over the first spatial axis n.
        scores = einsum('bnihk,bnjhk->bijh', q, k) + head_bias
        weights = scores.softmax(dim=-2)  # (B, L, L, h)

        mixed = einsum('bijh,bnjhd->bnihd', weights, v).reshape(n_batch, n_res, n_res, -1)
        out = self.to_out(gate * mixed)

        if self.is_row:
            out = out.permute(0, 2, 1, 3)
        return out
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.chemical import NAATOKENS
|
||||
|
||||
class DistanceNetwork(nn.Module):
    """Predict binned dist/omega/theta/phi logits from pair features.

    Dist and omega come from a symmetrised projection (61 + 37 channels);
    theta and phi come from an asymmetric projection (37 + 19 channels).
    ``p_drop`` is accepted for signature compatibility but is not used.
    """

    # Channel counts per head; must match bin counts defined in kinematics.py
    N_DIST_BINS = 61
    N_OMEGA_BINS = 37
    N_THETA_BINS = 37
    N_PHI_BINS = 19

    def __init__(self, n_feat, p_drop=0.1):
        super(DistanceNetwork, self).__init__()
        self.proj_symm = nn.Linear(n_feat, self.N_DIST_BINS + self.N_OMEGA_BINS)
        self.proj_asymm = nn.Linear(n_feat, self.N_THETA_BINS + self.N_PHI_BINS)

        self.reset_parameter()

    def reset_parameter(self):
        """Zero-initialise both final logit projections."""
        nn.init.zeros_(self.proj_symm.weight)
        nn.init.zeros_(self.proj_asymm.weight)
        nn.init.zeros_(self.proj_symm.bias)
        nn.init.zeros_(self.proj_asymm.bias)

    def forward(self, x):
        """Return (logits_dist, logits_omega, logits_theta, logits_phi).

        Args:
            x: pair features of shape (B, L, L, C).

        Returns:
            Four tensors shaped (B, n_bins, L, L) with 61, 37, 37 and 19 bins.
        """
        # predict theta, phi (non-symmetric)
        logits_asymm = self.proj_asymm(x)
        logits_theta = logits_asymm[:, :, :, :self.N_THETA_BINS].permute(0, 3, 1, 2)
        logits_phi = logits_asymm[:, :, :, self.N_THETA_BINS:].permute(0, 3, 1, 2)

        # predict dist, omega (symmetrised over the two residue axes)
        logits_symm = self.proj_symm(x)
        logits_symm = logits_symm + logits_symm.permute(0, 2, 1, 3)
        logits_dist = logits_symm[:, :, :, :self.N_DIST_BINS].permute(0, 3, 1, 2)
        # Omega occupies the channels after the distogram. The previous slice
        # started at 37, which returned 61 channels overlapping the distogram.
        logits_omega = logits_symm[:, :, :, self.N_DIST_BINS:].permute(0, 3, 1, 2)

        return logits_dist, logits_omega, logits_theta, logits_phi
|
||||
|
||||
class MaskedTokenNetwork(nn.Module):
    """Predict per-token logits over NAATOKENS classes for every MSA position.

    ``p_drop`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, n_feat, p_drop=0.1):
        super().__init__()

        # fd note: this also predicts a probability for the mask token (which
        # is never in the ground truth); it should be ok though(?)
        self.proj = nn.Linear(n_feat, NAATOKENS)

        self.reset_parameter()

    def reset_parameter(self):
        """Zero-initialise the logit projection."""
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Map ``x`` (B, N, L, n_feat) to logits of shape (B, NAATOKENS, N*L)."""
        n_batch, n_seq, n_res = x.shape[:3]
        per_token = self.proj(x)                      # (B, N, L, NAATOKENS)
        class_first = per_token.permute(0, 3, 1, 2)   # (B, NAATOKENS, N, L)
        return class_first.reshape(n_batch, -1, n_seq * n_res)
|
||||
|
||||
class LDDTNetwork(nn.Module):
    """Linear head producing binned lDDT logits per position."""

    def __init__(self, n_feat, n_bin_lddt=50):
        super().__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        """Zero-initialise the logit projection."""
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Map ``x`` (B, L, n_feat) to logits of shape (B, n_bin_lddt, L)."""
        per_residue = self.proj(x)  # (B, L, n_bin_lddt)
        return per_residue.permute(0, 2, 1)
|
||||
|
||||
class PAENetwork(nn.Module):
    """Linear head producing binned PAE logits per position pair."""

    def __init__(self, n_feat, n_bin_pae=64):
        super().__init__()
        self.proj = nn.Linear(n_feat, n_bin_pae)
        self.reset_parameter()

    def reset_parameter(self):
        """Zero-initialise the logit projection."""
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Map ``x`` (B, L, L, n_feat) to logits of shape (B, n_bin_pae, L, L)."""
        per_pair = self.proj(x)  # (B, L, L, n_bin_pae)
        return per_pair.permute(0, 3, 1, 2)
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from opt_einsum import contract as einsum
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from util import *
|
||||
from util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
|
||||
from Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
|
||||
from Track_module import PairStr2Pair, PositionalEncoding2D
|
||||
from chemical import NAATOKENS,NTOTALDOFS, NBTYPES
|
||||
|
||||
# Module contains classes and functions to generate initial embeddings
|
||||
|
||||
class MSA_emb(nn.Module):
    """Build the initial MSA, pair and state embeddings from the seed MSA."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
                 minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
        super().__init__()
        # Linear embedding of the raw MSA features.
        self.emb = nn.Linear(d_init, d_msa)
        # Token embeddings of the query sequence, one table per destination.
        self.emb_q = nn.Embedding(NAATOKENS, d_msa)        # added to the MSA embedding
        self.emb_left = nn.Embedding(NAATOKENS, d_pair)    # pair embedding, one side
        self.emb_right = nn.Embedding(NAATOKENS, d_pair)   # pair embedding, other side
        self.emb_state = nn.Embedding(NAATOKENS, d_state)  # state embedding
        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
                                        maxpos_atom=maxpos_atom, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal initialise every embedding and zero the linear bias."""
        self.emb = init_lecun_normal(self.emb)
        self.emb_q = init_lecun_normal(self.emb_q)
        self.emb_left = init_lecun_normal(self.emb_left)
        self.emb_right = init_lecun_normal(self.emb_right)
        self.emb_state = init_lecun_normal(self.emb_state)

        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq1hot, idx, bond_feats, same_chain):
        """Embed the inputs.

        Args:
            msa: input MSA features (B, N, L, d_init).
            seq1hot: one-hot query sequence; it is multiplied with the
                (NAATOKENS, d) embedding tables, so its last dim is NAATOKENS.
            idx: residue index.
            bond_feats: bond features (B, L, L).
            same_chain: passed through to the positional encoding.

        Returns:
            (msa, pair, state): initial MSA embedding (B, N, L, d_msa), pair
            embedding (B, L, L, d_pair) and state embedding.
        """
        n_seq = msa.shape[1]  # number of sequences in the MSA

        # MSA embedding: per-row features plus the query embedding on every row.
        msa_emb = self.emb(msa)
        query_emb = (seq1hot @ self.emb_q.weight).unsqueeze(1)  # (B, 1, L, d_msa)
        msa_emb = msa_emb + query_emb.expand(-1, n_seq, -1, -1)

        # Pair embedding: outer sum of two sequence embeddings.
        left = (seq1hot @ self.emb_left.weight)[:, None]       # (B, 1, L, d_pair)
        right = (seq1hot @ self.emb_right.weight)[:, :, None]  # (B, L, 1, d_pair)
        pair = left + right                                    # (B, L, L, d_pair)
        # Add the relative-position encoding.
        seq = seq1hot.argmax(-1)
        pair = pair + self.pos(seq, idx, bond_feats, same_chain)

        # State embedding.
        state = seq1hot @ self.emb_state.weight

        return msa_emb, pair, state
|
||||
|
||||
class Extra_emb(nn.Module):
    """Build the initial embedding of the extra MSA.

    ``p_drop`` is accepted but unused; the dropout layer is commented out in
    the original implementation.
    """

    def __init__(self, d_msa=256, d_init=NAATOKENS+1+2, p_drop=0.1):
        super().__init__()
        self.emb = nn.Linear(d_init, d_msa)          # embedding of the raw MSA features
        self.emb_q = nn.Embedding(NAATOKENS, d_msa)  # embedding of the query sequence

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal initialise the linear embedding and zero its bias.

        ``emb_q`` is left at its default initialisation.
        """
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq1hot, idx):
        """Embed the extra MSA.

        Args:
            msa: input MSA features (B, N, L, d_init).
            seq1hot: one-hot query sequence; it is multiplied with the
                (NAATOKENS, d_msa) embedding table.
            idx: residue index (not used here).

        Returns:
            Initial MSA embedding (B, N, L, d_msa).
        """
        n_seq = msa.shape[1]  # number of sequences in the MSA
        embedded = self.emb(msa)
        query_emb = (seq1hot @ self.emb_q.weight).unsqueeze(1)  # (B, 1, L, d_msa)
        # Add the query embedding to every MSA row.
        return embedded + query_emb.expand(-1, n_seq, -1, -1)
|
||||
|
||||
class Bond_emb(nn.Module):
    """Embed integer bond-type features into the pair dimension."""

    def __init__(self, d_pair=128, d_init=NBTYPES):
        super().__init__()
        self.emb = nn.Linear(d_init, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal initialise the embedding weights and zero the bias."""
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

    def forward(self, bond_feats):
        """One-hot encode ``bond_feats`` over NBTYPES classes and project it."""
        one_hot = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
        return self.emb(one_hot.float())
|
||||
|
||||
class TemplatePairStack(nn.Module):
    """Stack of PairStr2Pair blocks applied independently to each template."""

    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, d_t1d=22, d_state=32, p_drop=0.25):
        super().__init__()
        self.n_block = n_block
        self.proj_t1d = nn.Linear(d_t1d, d_state)
        blocks = [
            PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, d_state=d_state, p_drop=p_drop)
            for _ in range(n_block)
        ]
        self.block = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_templ)
        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal initialise the t1d projection and zero its bias."""
        self.proj_t1d = init_lecun_normal(self.proj_t1d)
        nn.init.zeros_(self.proj_t1d.bias)

    def forward(self, templ, rbf_feat, t1d, use_checkpoint=False):
        """Refine template pair features.

        Args:
            templ: template pair features (B, T, L, L, d_templ).
            rbf_feat: features passed unchanged to every block.
            t1d: per-residue template features (B, T, L, d_t1d).
            use_checkpoint: run each block under activation checkpointing.

        Returns:
            Layer-normalised template features (B, T, L, L, d_templ).
        """
        n_batch, n_templ, n_res = templ.shape[:3]

        # Fold the template axis into the batch axis.
        templ = templ.reshape(n_batch * n_templ, n_res, n_res, -1)
        t1d = t1d.reshape(n_batch * n_templ, n_res, -1)
        state = self.proj_t1d(t1d)

        for i_block in range(self.n_block):
            layer = self.block[i_block]
            if use_checkpoint:
                templ = checkpoint.checkpoint(create_custom_forward(layer), templ, rbf_feat, state)
            else:
                templ = layer(templ, rbf_feat, state)

        normed = self.norm(templ)
        return normed.reshape(n_batch, n_templ, n_res, n_res, -1)
|
||||
|
||||
|
||||
class Templ_emb(nn.Module):
    """Mix template information into the query pair and state features.

    Template inputs, as described by the notes in the original code:
      t2d - 61 distogram bins + 6 orientations (67) and a
            missing/unaligned mask (1)
      t1d - tiled AA sequence (20 standard aa + gap) and a confidence (1)
    """

    def __init__(self, d_t1d=(NAATOKENS-1)+1, d_t2d=67+1, d_tor=3*NTOTALDOFS, d_pair=128, d_state=32,
                 n_block=2, d_templ=64,
                 n_head=4, d_hidden=16, p_drop=0.25):
        super().__init__()

        # 2D branch: per-pair embedding, pair stack, attention over templates.
        self.emb = nn.Linear(d_t1d*2 + d_t2d, d_templ)
        self.templ_stack = TemplatePairStack(
            n_block=n_block, d_templ=d_templ, n_head=n_head,
            d_hidden=d_hidden, d_t1d=d_t1d, d_state=d_state, p_drop=p_drop,
        )
        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)

        # 1D branch: t1d concatenated with torsion features, then attention
        # over templates.
        self.emb_t1d = nn.Linear(d_t1d + d_tor, d_templ)
        self.proj_t1d = nn.Linear(d_templ, d_templ)
        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the three linear layers and zero their biases."""
        # Weight initialisers are called in the same order as before so the
        # random stream is consumed identically.
        self.emb = init_lecun_normal(self.emb)
        nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
        self.proj_t1d = init_lecun_normal(self.proj_t1d)

        for layer in (self.emb, self.emb_t1d, self.proj_t1d):
            nn.init.zeros_(layer.bias)

    def _get_templ_emb(self, t1d, t2d):
        """Embed ``t2d`` together with the row/column broadcast of ``t1d``.

        ``t1d`` must be 4-D; its third axis length is used as the size of
        both pair axes.  The result has ``d_templ`` channels.
        """
        _, _, n_res, _ = t1d.shape
        row_feat = t1d.unsqueeze(3).expand(-1, -1, -1, n_res, -1)
        col_feat = t1d.unsqueeze(2).expand(-1, -1, n_res, -1, -1)
        return self.emb(torch.cat((t2d, row_feat, col_feat), -1))

    def _get_templ_rbf(self, xyz_t, mask_t):
        """Return masked ``rbf`` features of template pairwise distances.

        The two leading axes of both inputs are merged, so the result has a
        single leading axis of size B*T.
        """
        n_batch, n_templ, n_res = xyz_t.shape[:3]
        n_flat = n_batch * n_templ

        coords = xyz_t.reshape(n_flat, n_res, 3).contiguous()
        pair_mask = mask_t.reshape(n_flat, n_res, n_res)
        assert(coords.is_contiguous())
        return rbf(torch.cdist(coords, coords)) * pair_mask[..., None]

    @staticmethod
    def _attend(attn, query, memory, use_checkpoint):
        """Call ``attn(query, memory, memory)``, optionally checkpointed."""
        if use_checkpoint:
            return checkpoint.checkpoint(create_custom_forward(attn), query, memory, memory)
        return attn(query, memory, memory)

    def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False):
        """Add template-derived updates to the query ``pair`` and ``state``.

        Shapes implied by the reshapes below (B batch, T templates,
        L positions):
          t1d      (B, T, L, C1)
          t2d      (B, T, L, L, C2)
          alpha_t  (B, T, L, C3), concatenated to t1d on the last axis
          xyz_t    (B, T, L, 3)
          mask_t   (B, T, L, L)
          pair     reshaped to (B, L, L, -1)
          state    reshaped to (B, L, -1)

        Returns the updated ``(pair, state)``.
        """
        B, T, L, _ = t1d.shape

        # Template pair features, refined by the pair stack.
        templ = self._get_templ_emb(t1d, t2d)
        rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
        templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint)

        # Template 1D features with torsion information appended.
        t1d = torch.cat((t1d, alpha_t), dim=-1)
        t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))

        # Each position's query state attends over its T template entries.
        state_query = state.reshape(B*L, 1, -1)
        templ_1d = t1d.permute(0, 2, 1, 3).reshape(B*L, T, -1)
        state_update = self._attend(self.attn_tor, state_query, templ_1d, use_checkpoint)
        state = state.reshape(B, L, -1) + state_update.reshape(B, L, -1)

        # Each query pair entry attends over its T template entries
        # (template pointwise attention).
        pair_query = pair.reshape(B*L*L, 1, -1)
        templ_2d = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
        pair_update = self._attend(self.attn, pair_query, templ_2d, use_checkpoint)
        pair = pair.reshape(B, L, L, -1) + pair_update.reshape(B, L, L, -1)

        return pair, state
|
||||
|
||||
|
||||
class Recycling(nn.Module):
    """Convert the previous iteration's outputs into input features."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
        super(Recycling, self).__init__()
        self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)

        self.reset_parameter()

    def reset_parameter(self):
        """Apply ``init_lecun_normal`` to both projections and zero their biases."""
        self.proj_dist = init_lecun_normal(self.proj_dist)
        nn.init.zeros_(self.proj_dist.bias)
        self.proj_sctors = init_lecun_normal(self.proj_sctors)
        nn.init.zeros_(self.proj_sctors.bias)

    def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
        """Build recycled msa / pair / state features.

        Args:
            msa: previous query-sequence features; layer-normalised and
                added to the projected ``sctors``.
            pair: previous pair features; its first two dims are read as
                (B, L).
            xyz: previous coordinates; atom index 1 of every position is
                used for the distance features.
            state: previous state features; layer-normalised and returned.
            sctors: previous torsion features, reshaped to
                (B, -1, 2*NTOTALDOFS) before projection.
            mask_recycle: optional pair mask multiplied into the distance
                features; ``None`` disables masking.

        Returns:
            Tuple ``(msa, pair, state)``.
        """
        B, L = pair.shape[:2]
        state = self.norm_state(state)

        # Broadcast the per-position state along rows and columns.
        left = state.unsqueeze(2).expand(-1,-1,L,-1)
        right = state.unsqueeze(1).expand(-1,L,-1,-1)

        Ca_or_P = xyz[:,:,1].contiguous()

        dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
        # Identity test: ``!= None`` on a tensor dispatches to the tensor's
        # comparison operator instead of checking for a missing argument.
        if mask_recycle is not None:
            dist = mask_recycle[...,None].float()*dist
        dist = torch.cat((dist, left, right), dim=-1)
        dist = self.proj_dist(dist)
        pair = dist + self.norm_pair(pair)

        sctors = self.proj_sctors(sctors.reshape(B,-1,2*NTOTALDOFS))
        msa = sctors + self.norm_msa(msa)

        return msa, pair, state
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, Recycling
|
||||
from Track_module import IterativeSimulator
|
||||
from AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, LDDTNetwork, PAENetwork
|
||||
from chemical import INIT_CRDS,NAATOKENS, NBTYPES
|
||||
from util import Ls_from_same_chain_2d
|
||||
from data_loader import get_term_feats
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
    """Top-level network: input embeddings, recycling, simulator and heads."""

    def __init__(
        self, n_extra_block=4, n_main_block=8, n_ref_block=4, n_finetune_block=0,
        d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
        n_head_msa=8, n_head_pair=4, n_head_templ=4,
        d_hidden=32, d_hidden_templ=64, p_drop=0.15,
        SE3_param={}, SE3_ref_param={},
        atom_type_index=None, aamask=None, ljlk_parameters=None, lj_correction_parameters=None,
        cb_len=None, cb_ang=None, cb_tor=None,
        num_bonds=None, lj_lin=0.6, use_extra_l1=True, use_atom_frames=True
    ):
        super(RoseTTAFoldModule, self).__init__()

        # Input embeddings.  NOTE: SE3_param must contain 'l0_out_features';
        # the empty default raises KeyError on the next line.
        d_state = SE3_param['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=NAATOKENS-1+4, p_drop=p_drop)
        self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25)

        # Update inputs with outputs from previous round
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)

        self.simulator = IterativeSimulator(
            n_extra_block=n_extra_block,
            n_main_block=n_main_block,
            n_ref_block=n_ref_block,
            n_finetune_block=n_finetune_block,
            d_msa=d_msa,
            d_msa_full=d_msa_full,
            d_pair=d_pair,
            d_hidden=d_hidden,
            n_head_msa=n_head_msa,
            n_head_pair=n_head_pair,
            SE3_param=SE3_param,
            SE3_ref_param=SE3_ref_param,
            p_drop=p_drop,
            atom_type_index=atom_type_index,  # change if encoding elements instead of atomtype
            aamask=aamask,
            ljlk_parameters=ljlk_parameters,
            lj_correction_parameters=lj_correction_parameters,
            num_bonds=num_bonds,
            cb_len=cb_len,
            cb_ang=cb_ang,
            cb_tor=cb_tor,
            lj_lin=lj_lin,
            use_extra_l1=use_extra_l1,
        )

        # Auxiliary prediction heads.
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)
        if use_extra_l1:  # extra l1 features introduced at the same time as pAE/pDE heads
            self.pae_pred = PAENetwork(d_pair)
            self.pde_pred = PAENetwork(d_pair)  # distance error, but use same architecture as aligned error
        self.use_extra_l1 = use_extra_l1
        self.use_atom_frames = use_atom_frames

    def forward(
        self, msa_one_hot, seq_unmasked, xyz, sctors, idx, bond_feats, chirals,
        atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
        msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None,
        return_raw=False, return_full=False,
        use_checkpoint=False
    ):
        """Run one pass of the network.

        When ``msa_prev`` is ``None`` all three recycled inputs
        (``msa_prev``, ``pair_prev``, ``state_prev``) are replaced by zeros.

        Returns:
            With ``return_raw``: ``(msa[:, 0], pair, xyz_last, state,
            alpha_s[-1], None)``.
            Otherwise: ``(logits, logits_aa, logits_pae, logits_pde, xyz,
            alpha_s, xyz_allatom, lddt, msa[:, 0], pair, state)``, where the
            pae/pde logits are ``None`` unless ``use_extra_l1`` was set.
        """
        B, N, L = msa_one_hot.shape[:3]
        device = msa_one_hot.device

        seq1hot = msa_one_hot[:,0]

        # generate input msa features
        Ls = Ls_from_same_chain_2d(same_chain)
        term_feats = get_term_feats(L,Ls).to(device)
        term_feats = term_feats[None,None].expand(B,N,-1,-1)

        # Zero padding is allocated directly on the target device.
        msa_feat = torch.cat([msa_one_hot,
                              msa_one_hot,
                              torch.zeros(B,N,L,2, device=device),
                              term_feats], dim=3)
        extra_feat = torch.cat([msa_one_hot,
                                torch.zeros(B,N,L,1, device=device),
                                term_feats], dim=3)

        # Get embeddings
        msa_latent, pair, state = self.latent_emb(msa_feat, seq1hot, idx, bond_feats, same_chain)
        msa_full = self.full_emb(extra_feat, seq1hot, idx)
        pair = pair + self.bond_emb(bond_feats)

        # Do recycling.  Identity test: ``== None`` on a tensor dispatches to
        # the tensor's comparison operator.
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
            pair_prev = torch.zeros_like(pair)
            state_prev = torch.zeros_like(state)

        msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # add template embedding
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint)

        # Predict coordinates from given inputs
        msa, pair, xyz, alpha_s, xyz_allatom, state = self.simulator(
            seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx, bond_feats,
            same_chain, chirals, atom_frames,
            use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames)

        if return_raw:
            # get last structure
            xyz_last = xyz_allatom[-1].unsqueeze(0)
            return msa[:,0], pair, xyz_last, state, alpha_s[-1], None

        # predict masked amino acids
        logits_aa = self.aa_pred(msa)

        # predict distogram & orientograms
        logits = self.c6d_pred(pair)

        # Predict LDDT
        lddt = self.lddt_pred(state)

        # predict aligned error and distance error
        if self.use_extra_l1:
            logits_pae = self.pae_pred(pair)
            logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3))  # symmetrize pair features
        else:
            logits_pae = None
            logits_pde = None

        return logits, logits_aa, logits_pae, logits_pde, xyz, alpha_s, xyz_allatom, \
            lddt, msa[:,0], pair, state
|
||||
@@ -1,88 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import sys, os
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
sys.path.insert(0,script_dir+'../../../SE3Transformer/')
|
||||
|
||||
from util_module import init_lecun_normal_param
|
||||
from se3_transformer.model import SE3Transformer
|
||||
from se3_transformer.model.fiber import Fiber
|
||||
|
||||
class SE3TransformerWrapper(nn.Module):
    """SE(3) equivariant GCN with attention"""

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=2,
                 num_edge_features=32):
        super().__init__()
        self.l1_in = l1_in_features
        self.l1_out = l1_out_features

        fiber_edge = Fiber({0: num_edge_features})

        # A degree-1 entry is included only when the matching feature count
        # is positive; the hidden fiber is the same in every case.
        in_spec = {0: l0_in_features}
        if l1_in_features > 0:
            in_spec[1] = l1_in_features
        out_spec = {0: l0_out_features}
        if l1_out_features > 0:
            out_spec[1] = l1_out_features

        fiber_in = Fiber(in_spec)
        fiber_hidden = Fiber.create(num_degrees, num_channels)
        fiber_out = Fiber(out_spec)

        self.se3 = SE3Transformer(
            num_layers=num_layers,
            fiber_in=fiber_in,
            fiber_hidden=fiber_hidden,
            fiber_out=fiber_out,
            num_heads=n_heads,
            channels_div=div,
            fiber_edge=fiber_edge,
            populate_edge="arcsin",
            final_layer="lin",
            use_layer_norm=True,
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the wrapped transformer's parameters.

        Per parameter: names containing "bias" are zeroed; other 1-D
        parameters are left alone; parameters outside "radial_func" go
        through ``init_lecun_normal_param``; inside "radial_func", names
        containing "net.6" are zeroed and the rest get Kaiming-normal (relu)
        initialisation.  Afterwards the last graph module's output weights
        are zeroed.
        """
        for name, param in self.se3.named_parameters():
            if "bias" in name:
                nn.init.zeros_(param)
            elif param.dim() == 1:
                continue
            elif "radial_func" not in name:
                init_lecun_normal_param(param)
            elif "net.6" in name:
                nn.init.zeros_(param)
            else:
                nn.init.kaiming_normal_(param, nonlinearity='relu')

        # make last layers to be zero-initialized
        final_weights = self.se3.graph_modules[-1].weights
        nn.init.zeros_(final_weights['0'])
        if self.l1_out > 0:
            nn.init.zeros_(final_weights['1'])

    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
        """Run the wrapped transformer on graph ``G``.

        Node features are keyed by degree; the '1' entry is supplied only
        when the module was built with ``l1_in_features > 0``.
        """
        node_features = {'0': type_0_features}
        if self.l1_in > 0:
            node_features['1'] = type_1_features
        return self.se3(G, node_features, {'0': edge_features})
|
||||
@@ -1,788 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from opt_einsum import contract as einsum
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from rf2aa.util_module import *
|
||||
from rf2aa.Attention_module import *
|
||||
from rf2aa.SE3_network import SE3TransformerWrapper
|
||||
from rf2aa.resnet import ResidualNetwork
|
||||
from rf2aa.util import INIT_CRDS, is_atom, xyz_frame_from_rotation_mask
|
||||
from rf2aa.loss import (
|
||||
calc_BB_bond_geom_grads, calc_lj_grads, calc_hb_grads, calc_cart_bonded_grads, calc_ljallatom_grads,
|
||||
calc_lj, calc_cart_bonded, calc_chiral_grads
|
||||
)
|
||||
from rf2aa.chemical import NTOTALDOFS
|
||||
|
||||
|
||||
# Components for three-track blocks
|
||||
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
|
||||
# 2. Pair -> Pair update (biased attention. bias from structure)
|
||||
# 3. MSA -> Pair update (extract coevolution signal)
|
||||
# 4. Str -> Str update (node from MSA, edge from Pair)
|
||||
|
||||
class PositionalEncoding2D(nn.Module):
    """Relative positional encoding that is added to pair features."""

    def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
        super().__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.maxpos_atom = maxpos_atom
        # include 0 and "unknown" value (maxpos+1)
        self.nbin_res = abs(minpos) + maxpos + 2
        # include 0 and "unknown" token (maxpos_atom + 1)
        self.nbin_atom = maxpos_atom + 2
        self.emb_res = nn.Embedding(self.nbin_res, d_pair)
        self.emb_atom = nn.Embedding(self.nbin_atom, d_pair)
        self.emb_chain = nn.Embedding(2, d_pair)

    def forward(self, seq, idx, bond_feats, same_chain=None):
        """Return the summed residue / atom (and optional chain) embeddings.

        Distances come from ``get_res_atom_dist``; they are bucketised on
        integer edges and looked up in the matching embedding tables.  The
        chain embedding is added only when ``same_chain`` is given.
        """
        device = seq.device
        sm_mask = is_atom(seq[0])
        res_dist, atom_dist = get_res_atom_dist(
            idx, bond_feats, sm_mask,
            minpos_res=self.minpos, maxpos_res=self.maxpos, maxpos_atom=self.maxpos_atom,
        )

        res_edges = torch.arange(self.minpos, self.maxpos + 1, device=device)
        res_emb = self.emb_res(torch.bucketize(res_dist, res_edges).long())

        atom_edges = torch.arange(0, self.maxpos_atom + 1, device=device)
        atom_emb = self.emb_atom(torch.bucketize(atom_dist, atom_edges).long())

        out = res_emb + atom_emb
        if same_chain is None:
            return out

        # this is used for MSA_emb but not in IterBlock
        out += self.emb_chain(same_chain.long())
        return out
|
||||
|
||||
|
||||
# Update MSA with biased self-attention. bias from Pair & Str
|
||||
class MSAPairStr2MSA(nn.Module):
    """MSA update by biased self-attention; the bias comes from pair and structure."""

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, d_rbf=64,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super().__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        self.emb_rbf = nn.Linear(d_rbf, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        col_attn_cls = MSAColGlobalAttention if use_global_attn else MSAColAttention
        self.col_attn = col_attn_cls(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the two linear layers and zero their biases."""
        self.emb_rbf = init_lecun_normal(self.emb_rbf)
        self.proj_state = init_lecun_normal(self.proj_state)
        for layer in (self.emb_rbf, self.proj_state):
            nn.init.zeros_(layer.bias)

    def forward(self, msa, pair, rbf_feat, state):
        '''
        Inputs:
            - msa: MSA feature (B, N, L, d_msa)
            - pair: Pair feature (B, L, L, d_pair)
            - rbf_feat: distance feature fed to emb_rbf (last dim d_rbf)
            - state: node features, reshaped below to (B, 1, L, d_msa)
              after projection
        Output:
            - msa: Updated MSA feature (B, N, L, d_msa)
        '''
        B, N, L = msa.shape[:3]

        # Attention bias: normalised pair features plus embedded distances.
        bias = self.norm_pair(pair) + self.emb_rbf(rbf_feat)

        # Add the projected state onto the first sequence of the MSA.
        state_proj = self.proj_state(self.norm_state(state)).reshape(B, 1, L, -1)
        first_row = torch.tensor([0,], device=state_proj.device)
        msa = msa.type_as(state_proj).index_add(1, first_row, state_proj)

        # Row attention (with dropout), column attention, feed-forward.
        msa = msa + self.drop_row(self.row_attn(msa, bias))
        msa = msa + self.col_attn(msa)
        msa = msa + self.ff(msa)
        return msa
|
||||
|
||||
class PairStr2Pair(nn.Module):
    """Pair update by triangle multiplication and gated, biased axial attention."""

    def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_hidden_state=16, d_rbf=64, d_state=32, p_drop=0.15):
        super().__init__()

        # State -> gate on the distance bias.
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_left = nn.Linear(d_state, d_hidden_state)
        self.proj_right = nn.Linear(d_state, d_hidden_state)
        self.to_gate = nn.Linear(d_hidden_state*d_hidden_state, d_pair)

        self.emb_rbf = nn.Linear(d_rbf, d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.tri_mul_out = TriangleMultiplication(d_pair, d_hidden=d_hidden)
        self.tri_mul_in = TriangleMultiplication(d_pair, d_hidden, outgoing=False)

        self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
        self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)

        self.ff = FeedForwardLayer(d_pair, 2)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the projections and open the gate.

        The gate gets zero weights and unit bias, so it starts mostly open.
        """
        self.emb_rbf = init_lecun_normal(self.emb_rbf)
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        for layer in (self.emb_rbf, self.proj_left, self.proj_right):
            nn.init.zeros_(layer.bias)

        nn.init.zeros_(self.to_gate.weight)
        nn.init.ones_(self.to_gate.bias)

    def forward(self, pair, rbf_feat, state):
        """Return ``pair`` after five residual updates.

        The embedded ``rbf_feat`` is scaled by a sigmoid gate computed from
        the outer product of two projections of ``state``; the gated result
        is the bias for both axial attentions.
        """
        B, L = pair.shape[:2]

        dist_emb = self.emb_rbf(rbf_feat)

        normed = self.norm_state(state)
        outer = einsum('bli,bmj->blmij', self.proj_left(normed), self.proj_right(normed))
        gate = torch.sigmoid(self.to_gate(outer.reshape(B, L, L, -1)))
        bias = gate * dist_emb

        # Both triangle updates use the row-broadcast dropout, as before.
        pair = pair + self.drop_row(self.tri_mul_out(pair))
        pair = pair + self.drop_row(self.tri_mul_in(pair))
        pair = pair + self.drop_row(self.row_attn(pair, bias))
        pair = pair + self.drop_col(self.col_attn(pair, bias))
        return pair + self.ff(pair)
|
||||
|
||||
class MSA2Pair(nn.Module):
    """Add an MSA outer-product update to the pair features."""

    def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
        super().__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the input projections; zero the output projection."""
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        for layer in (self.proj_left, self.proj_right):
            nn.init.zeros_(layer.bias)

        # Zero output weights and bias make the initial update zero.
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def forward(self, msa, pair):
        """Return ``pair`` plus the projected outer product of ``msa``.

        The right-hand projection is divided by the number of sequences N
        before the contraction over the sequence axis.
        """
        B, N, L = msa.shape[:3]
        normed = self.norm(msa)
        lhs = self.proj_left(normed)
        rhs = self.proj_right(normed) / float(N)
        outer = einsum('bsli,bsmj->blmij', lhs, rhs).reshape(B, L, L, -1)
        return pair + self.proj_out(outer)
|
||||
|
||||
class Str2Str(nn.Module):
    """Structure update: SE(3)-transformer pass plus rigid coordinate update."""

    def __init__(self, d_msa=256, d_pair=128, d_state=16, d_rbf=64,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 nextra_l0=0, nextra_l1=0, p_drop=0.1
                 ):
        """Build the node/edge embedders, the SE(3) wrapper and the torsion head.

        ``SE3_param`` is copied before being modified, so the caller's dict
        (and the shared default) is never mutated.

        Raises:
            KeyError: if ``nextra_l1`` is non-zero but ``SE3_param`` has no
                'l1_in_features' entry to extend.
        """
        super(Str2Str, self).__init__()

        # initial node & pair feature process
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.embed_node = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
        self.ff_node = FeedForwardLayer(SE3_param['l0_in_features'], 2, p_drop=p_drop)
        self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])

        self.embed_edge = nn.Linear(d_pair+d_rbf+1, SE3_param['num_edge_features'])
        self.ff_edge = FeedForwardLayer(SE3_param['num_edge_features'], 2, p_drop=p_drop)
        self.norm_edge = nn.LayerNorm(SE3_param['num_edge_features'])

        SE3_param_temp = SE3_param.copy()
        SE3_param_temp['l0_in_features'] += nextra_l0
        # The default SE3_param has no 'l1_in_features' key, so an
        # unconditional ``+=`` raised KeyError even with nextra_l1 == 0.
        # Extend the key only when present; otherwise the wrapper's own
        # default applies.
        if 'l1_in_features' in SE3_param_temp:
            SE3_param_temp['l1_in_features'] += nextra_l1
        elif nextra_l1:
            raise KeyError(
                "SE3_param must define 'l1_in_features' when nextra_l1 is non-zero"
            )

        self.se3 = SE3TransformerWrapper(**SE3_param_temp)

        self.sc_predictor = SCPred(
            d_msa=d_msa,
            d_state=SE3_param['l0_out_features'],
            p_drop=p_drop)

        self.nextra_l0 = nextra_l0
        self.nextra_l1 = nextra_l1

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the node/edge embedders and zero their biases."""
        # initialize weights to normal distribution
        self.embed_node = init_lecun_normal(self.embed_node)
        self.embed_edge = init_lecun_normal(self.embed_edge)

        # initialize bias to zeros
        nn.init.zeros_(self.embed_node.bias)
        nn.init.zeros_(self.embed_edge.bias)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, idx, rotation_mask, bond_feats, atom_frames, extra_l0=None, extra_l1=None, use_atom_frames=True, top_k=128, eps=1e-5):
        """Update coordinates, state and torsions.

        Runs with autocast disabled.  ``top_k == 0`` selects the full graph
        instead of the top-k graph.  ``eps`` is accepted but not used here.

        Returns:
            Tuple ``(xyz, state, alpha)``: updated coordinates, the
            degree-0 transformer output reshaped to (B, L, -1), and the
            side-chain predictor output.
        """
        # process msa & pair features
        B, N, L = msa.shape[:3]
        seq = self.norm_msa(msa[:,0])
        pair = self.norm_pair(pair)
        state = self.norm_state(state)

        node = torch.cat((seq, state), dim=-1)
        node = self.embed_node(node)
        node = node + self.ff_node(node)
        node = self.norm_node(node)

        neighbor = get_seqsep_protein_sm(idx, bond_feats, rotation_mask)
        cas = xyz[:,:,1].contiguous()
        rbf_feat = rbf(torch.cdist(cas, cas))
        edge = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        edge = self.embed_edge(edge)
        edge = edge + self.ff_edge(edge)
        edge = self.norm_edge(edge)

        # define graph
        if top_k != 0:
            G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=top_k)
        else:
            G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx)

        if use_atom_frames:  # ligand l1 features are vectors to neighboring atoms
            xyz_frame = xyz_frame_from_rotation_mask(xyz, rotation_mask, atom_frames)
            l1_feats = xyz_frame - xyz_frame[:,:,1,:].unsqueeze(2)
        else:  # old (incorrect) behavior: vectors to random initial coords of virtual N and C
            l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)

        if extra_l1 is not None:
            l1_feats = torch.cat( (l1_feats,extra_l1), dim=1 )
        if extra_l0 is not None:
            node = torch.cat( (node,extra_l0), dim=2 )

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)

        state = shift['0'].reshape(B, L, -1)  # (B, L, C)

        # Two degree-1 outputs per position: translation and rotation vector.
        offset = shift['1'].reshape(B, L, 2, 3)
        T = offset[:,:,0,:] / 10.0
        R = offset[:,:,1,:] / 100.0

        # (1, R) normalised to a unit quaternion (qA, qB, qC, qD).
        Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
        qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm

        # Rotation matrix built element-wise from the quaternion.
        v = xyz - xyz[:,:,1:2,:]
        Rout = torch.zeros((B,L,3,3), device=xyz.device)
        Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
        Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
        Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
        Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
        Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
        Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
        Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
        Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
        Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
        # Positions flagged in rotation_mask get the identity rotation.
        I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
        Rout = torch.where(rotation_mask.reshape(B, L, 1,1), I, Rout)
        # Rotate about atom index 1, then translate.
        xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]

        alpha = self.sc_predictor(msa[:,0], state)

        return xyz, state, alpha
|
||||
|
||||
|
||||
class Allatom2Allatom(nn.Module):
    """Per-atom SE(3)-transformer refinement of coordinates and state."""

    def __init__(
        self,
        SE3_param
    ):
        super().__init__()
        self.se3 = SE3TransformerWrapper(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, seq, xyz, aamask, num_bonds, state, grads, top_k=24, eps=1e-5):
        """Shift the atoms selected by ``aamask[seq]`` and refresh their state.

        Shape notes carried over from the original comments (not checked
        here): seq (B,L); xyz (B,L,27,3); aamask (22,27) per amino acid;
        num_bonds (22,27,27) per amino acid; grads (N,B,L,27,3) for N terms.
        NOTE(review): ``state`` is indexed with the (B,L,27) atom mask, so
        it is presumably per-atom -- confirm against the caller.

        ``xyz`` and ``state`` are modified in place and also returned.
        ``eps`` is accepted but not used here.
        """
        atom_mask = aamask[seq]
        graph, edge = make_atom_graph(xyz, atom_mask, num_bonds[seq], top_k, maxbonds=4)

        node_l0 = state[atom_mask][..., None]
        node_l1 = grads[:, atom_mask].permute(1, 0, 2)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(graph, node_l0, node_l1, edge)

        state[atom_mask] = shift['0'][..., 0]
        xyz[atom_mask] = xyz[atom_mask] + shift['1'].squeeze(1) / 100.0
        return xyz, state
|
||||
|
||||
# embed residue state + atomtype -> per-atom state
#
# NOTE: this class used to be defined twice, verbatim, back to back; the
# second definition silently replaced the first.  Only one copy is kept.
class AllatomEmbed(nn.Module):
    """Expand a per-residue state into a per-atom state tagged by atom type."""

    def __init__(
        self,
        d_state_in=64,
        d_state_out=32,
        p_mask=0.15
    ):
        super(AllatomEmbed, self).__init__()

        self.p_mask = p_mask

        # initial node & pair feature process
        self.compress_embed = nn.Linear(d_state_in + 29, d_state_out)  # 29->5 if using element
        self.norm_state = nn.LayerNorm(d_state_out)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the compression layer and zero its bias."""
        # initialize weights to normal distribution
        self.compress_embed = init_lecun_normal(self.compress_embed)
        # initialize bias to zeros
        nn.init.zeros_(self.compress_embed.bias)

    def forward(self, state, seq, eltmap):
        """Return the per-atom state.

        Args:
            state: per-residue features; the first two dims are (B, L).
            seq: integer indices used to look up rows of ``eltmap``.
            eltmap: lookup table; ``eltmap[seq]`` is one-hot encoded with
                29 classes and must line up with the 27 atom slots.

        Returns:
            Tensor of shape (B, L, 27, d_state_out).

        Each residue's state is zeroed with probability ``p_mask`` on every
        call (no training-mode check is made).
        """
        B,L = state.shape[:2]
        # The mask is drawn on the default device, as before, so seeded
        # runs keep the same random stream.
        mask = torch.rand(B,L) < self.p_mask
        state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
        state[mask] = 0.0
        elements = F.one_hot(eltmap[seq], num_classes=29)  # 29->5 if using element
        state = self.compress_embed(
            torch.cat( (state,elements), dim=-1 )
        )
        state = self.norm_state( state )

        return state
|
||||
|
||||
# embed per-atom state -> residue state
|
||||
class ResidueEmbed(nn.Module):
    """Collapse a per-atom state back into a per-residue state."""

    def __init__(
        self,
        d_state_in=16,
        d_state_out=64
    ):
        super().__init__()
        # The linear layer consumes 27 atom slots of d_state_in channels each.
        self.compress_embed = nn.Linear(27*d_state_in, d_state_out)
        self.norm_state = nn.LayerNorm(d_state_out)
        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the compression layer and zero its bias."""
        compress = init_lecun_normal(self.compress_embed)
        nn.init.zeros_(compress.bias)
        self.compress_embed = compress

    def forward(self, state):
        """Flatten everything after the first two dims, project, normalise."""
        n_batch, n_res = state.shape[:2]
        flat = state.reshape(n_batch, n_res, -1)
        return self.norm_state(self.compress_embed(flat))
|
||||
|
||||
class SCPred(nn.Module):
    """Two-block residual MLP predicting ``2 * NTOTALDOFS`` torsion values.

    NOTE: ``p_drop`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super().__init__()
        # input normalisation and projection of the two feature streams
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 2 * NTOTALDOFS)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise every linear layer (order of random draws is significant)."""
        # normal (LeCun) initialization for input and output projections
        self.linear_s0 = init_lecun_normal(self.linear_s0)
        self.linear_si = init_lecun_normal(self.linear_si)
        self.linear_out = init_lecun_normal(self.linear_out)
        for layer in (self.linear_s0, self.linear_si, self.linear_out):
            nn.init.zeros_(layer.bias)

        # right before relu activation: He initializer (kaiming normal)
        for layer in (self.linear_1, self.linear_3):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.zeros_(layer.bias)

        # right before residual connection: zero initialize
        for layer in (self.linear_2, self.linear_4):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, seq, state):
        """
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
        """
        n_batch, n_res = seq.shape[:2]
        si = self.linear_s0(self.norm_s0(seq)) + self.linear_si(self.norm_si(state))

        # F.relu_ is in place: `si` itself is rectified before it is re-used
        # as the residual term, exactly as in the single-expression form.
        for first, second in ((self.linear_1, self.linear_2),
                              (self.linear_3, self.linear_4)):
            branch = second(F.relu_(first(F.relu_(si))))
            si = si + branch

        si = self.linear_out(F.relu_(si))
        return si.view(n_batch, n_res, NTOTALDOFS, 2)
|
||||
|
||||
class IterBlock(nn.Module):
    """One iteration of MSA -> pair -> structure updates.

    Chains ``MSAPairStr2MSA``, ``MSA2Pair``, ``PairStr2Pair`` and ``Str2Str``.
    ``d_hidden_msa`` defaults to ``d_hidden`` when left as ``None``.
    ``SE3_param`` is only read, never mutated.
    """

    def __init__(self, d_msa=256, d_pair=128, d_rbf=64,
                 n_head_msa=8, n_head_pair=4,
                 use_global_attn=False,
                 d_hidden=32, d_hidden_msa=None,
                 minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 nextra_l0=0, nextra_l1=0):
        super(IterBlock, self).__init__()
        # identity comparison: `== None` would go through __eq__
        if d_hidden_msa is None:
            d_hidden_msa = d_hidden

        self.pos = PositionalEncoding2D(d_rbf, minpos=minpos, maxpos=maxpos,
                                        maxpos_atom=maxpos_atom, p_drop=p_drop)
        self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                                      n_head=n_head_msa,
                                      d_state=SE3_param['l0_out_features'],
                                      use_global_attn=use_global_attn,
                                      d_hidden=d_hidden_msa, p_drop=p_drop)
        self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
                                 d_hidden=d_hidden//2, p_drop=p_drop)
        self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair, d_rbf=d_rbf,
                                      d_state=SE3_param['l0_out_features'],
                                      d_hidden=d_hidden, p_drop=p_drop)
        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                               d_state=SE3_param['l0_out_features'],
                               SE3_param=SE3_param,
                               p_drop=p_drop,
                               nextra_l0=nextra_l0,
                               nextra_l1=nextra_l1)

    def forward(self, msa, pair, xyz, state, seq_unmasked, idx, bond_feats, same_chain,
                use_checkpoint=False, top_k=128, rotation_mask=None, atom_frames=None,
                extra_l0=None, extra_l1=None, use_atom_frames=True):
        """Run one update and return ``(msa, pair, xyz, state, alpha)``.

        With ``use_checkpoint=True`` every sub-module is wrapped in
        ``checkpoint.checkpoint``. The structure module always receives
        float-cast inputs and detached coordinates.
        """
        # radial-basis features of the pairwise distances between atom index 1
        # of every residue, plus the 2D positional encoding
        cas = xyz[:,:,1].contiguous()
        rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, same_chain)

        if use_checkpoint:
            msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
            pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
            pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat, state)

            xyz, state, alpha = checkpoint.checkpoint(
                create_custom_forward(self.str2str, top_k=top_k),
                msa.float(), pair.float(), xyz.detach().float(), state.float(),
                idx, rotation_mask, bond_feats, atom_frames,
                extra_l0, extra_l1, use_atom_frames)
        else:
            msa = self.msa2msa(msa, pair, rbf_feat, state)
            pair = self.msa2pair(msa, pair)
            pair = self.pair2pair(pair, rbf_feat, state)

            xyz, state, alpha = self.str2str(
                msa.float(), pair.float(), xyz.detach().float(), state.float(),
                idx, rotation_mask, bond_feats, atom_frames,
                extra_l0, extra_l1, use_atom_frames, top_k=top_k)

        return msa, pair, xyz, state, alpha
|
||||
|
||||
class IterativeSimulator(nn.Module):
    """Stack of extra-MSA blocks, main blocks, SE(3) refinement and optional
    all-atom fine-tuning.

    The sub-modules for a stage are only created when its block count is
    positive. ``SE3_param`` must provide ``'l0_out_features'`` whenever the
    refinement or fine-tune stage is enabled. ``SE3_param`` and
    ``SE3_ref_param`` are only read, never mutated.
    """

    def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, n_finetune_block=0,
                 d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param={}, SE3_ref_param={}, p_drop=0.15,
                 atom_type_index=None, aamask=None,
                 ljlk_parameters=None, lj_correction_parameters=None,
                 cb_len=None, cb_ang=None, cb_tor=None,
                 num_bonds=None, lj_lin=0.6, use_extra_l1=True
                 ):
        super(IterativeSimulator, self).__init__()
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block
        self.n_finetune_block = n_finetune_block

        self.atom_type_index = atom_type_index
        self.aamask = aamask
        self.ljlk_parameters = ljlk_parameters
        self.lj_correction_parameters = lj_correction_parameters
        self.num_bonds = num_bonds
        self.lj_lin = lj_lin
        self.cb_len = cb_len
        self.cb_ang = cb_ang
        self.cb_tor = cb_tor
        self.use_extra_l1 = use_extra_l1  # set to False to not use chiral & LJ grads

        # Update with extra sequences
        if n_extra_block > 0:
            self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
                                                        n_head_msa=n_head_msa,
                                                        n_head_pair=n_head_pair,
                                                        d_hidden_msa=8,
                                                        d_hidden=d_hidden,
                                                        p_drop=p_drop,
                                                        use_global_attn=True,
                                                        SE3_param=SE3_param,
                                                        nextra_l1=3 if self.use_extra_l1 else 0)
                                              for i in range(n_extra_block)])

        # Update with seed sequences
        if n_main_block > 0:
            self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
                                                       n_head_msa=n_head_msa,
                                                       n_head_pair=n_head_pair,
                                                       d_hidden=d_hidden,
                                                       p_drop=p_drop,
                                                       use_global_attn=False,
                                                       SE3_param=SE3_param,
                                                       nextra_l1=3 if self.use_extra_l1 else 0)
                                             for i in range(n_main_block)])

        # Final SE(3) refinement
        if n_ref_block > 0:
            self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
                                       d_state=SE3_param['l0_out_features'],
                                       SE3_param=SE3_ref_param,
                                       p_drop=p_drop,
                                       nextra_l0=2*NTOTALDOFS if self.use_extra_l1 else 0,
                                       nextra_l1=6 if self.use_extra_l1 else 0
                                       )

        # Fine-tuning all-atom SE(3) refinement
        if n_finetune_block > 0:
            d_state = 16
            self.allatom_embed = AllatomEmbed(
                d_state_in=SE3_param['l0_out_features'],
                d_state_out=d_state,
                p_mask=0.15
            )
            self.finetune_refiner = Allatom2Allatom(
                SE3_param={
                    'num_layers': 1,
                    'num_channels': 16,
                    'num_degrees': 2,
                    'l0_in_features': d_state,
                    'l0_out_features': d_state,
                    'l1_in_features': 2,
                    'l1_out_features': 1,
                    'num_edge_features': 4,
                    'n_heads': 4,
                    'div': 2,
                }
            )
            self.residue_embed = ResidueEmbed(
                d_state_in=d_state,
                d_state_out=SE3_param['l0_out_features']
            )

        # To get all-atom coordinates
        self.compute_allatom_coords = ComputeAllAtomCoords()

    def _chiral_extra_l1(self, xyz, chirals):
        """Return detached chiral gradients for batch item 0, or None when disabled."""
        if not self.use_extra_l1:
            return None
        dchiraldxyz, = calc_chiral_grads(xyz.detach(), chirals)
        return dchiraldxyz[0].detach()

    def forward(self, seq_unmasked, msa, msa_full, pair, xyz, state, idx, bond_feats,
                same_chain, chirals, atom_frames=None, use_checkpoint=False,
                use_atom_frames=True):
        """Run all stages and return
        ``(msa, pair, xyz, alpha_s, xyzallatom_s, state)``.

        ``xyz`` and ``alpha_s`` are the per-block outputs stacked on a new
        leading axis; ``xyzallatom_s`` concatenates the all-atom coordinates
        from after refinement and after every fine-tune block along axis 0.
        """
        # input:
        #   msa: initial MSA embeddings (N, L, d_msa)
        #   pair: initial residue pair embeddings (L, L, d_pair)
        rotation_mask = is_atom(seq_unmasked)
        xyz_s = list()
        alpha_s = list()

        for i_m in range(self.n_extra_block):
            extra_l1 = self._chiral_extra_l1(xyz, chirals)
            msa_full, pair, xyz, state, alpha = self.extra_block[i_m](
                msa_full, pair,
                xyz, state, seq_unmasked, idx, bond_feats,
                same_chain,
                use_checkpoint=use_checkpoint,
                top_k=0, rotation_mask=rotation_mask,
                atom_frames=atom_frames,
                extra_l0=None,
                extra_l1=extra_l1,
                use_atom_frames=use_atom_frames)
            xyz_s.append(xyz)
            alpha_s.append(alpha)

        for i_m in range(self.n_main_block):
            extra_l1 = self._chiral_extra_l1(xyz, chirals)
            msa, pair, xyz, state, alpha = self.main_block[i_m](
                msa, pair,
                xyz, state, seq_unmasked, idx, bond_feats,
                same_chain,
                use_checkpoint=use_checkpoint,
                top_k=0, rotation_mask=rotation_mask,
                atom_frames=atom_frames,
                extra_l0=None,
                extra_l1=extra_l1,
                use_atom_frames=use_atom_frames)
            xyz_s.append(xyz)
            alpha_s.append(alpha)

        _, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha)  # think about detach here...

        # now use unmasked seq (no cross-talk for msa prediction)
        for i_m in range(self.n_ref_block):
            extra_l0 = None
            extra_l1 = None
            if self.use_extra_l1:
                dljdxyz, dljdalpha = calc_lj_grads(
                    seq_unmasked, xyz.detach(), alpha.detach(),
                    self.compute_allatom_coords, bond_feats,
                    self.aamask,
                    self.ljlk_parameters,
                    self.lj_correction_parameters,
                    self.num_bonds,
                    lj_lin=self.lj_lin)
                dchiraldxyz, = calc_chiral_grads(xyz.detach(), chirals)
                extra_l0 = dljdalpha.reshape(1, -1, 2*NTOTALDOFS).detach()
                extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)

            xyz, state, alpha = self.str_refiner(
                msa, pair, xyz.detach(), state, idx, rotation_mask, bond_feats, atom_frames,
                extra_l0, extra_l1, top_k=128, use_atom_frames=use_atom_frames)

            xyz_s.append(xyz)
            alpha_s.append(alpha)

        _, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha)  # think about detach here...
        xyzallatom_s = list()
        xyzallatom_s.append(xyzallatom.clone())

        if self.n_finetune_block > 0:
            state = self.allatom_embed(state, seq_unmasked, self.atom_type_index)

            for i_m in range(self.n_finetune_block):
                # The cart-bonded / LJ all-atom gradient features are currently
                # disabled, so no extra l1 input is available.
                extra_l1 = None

                # Guard the cast: calling .float() on None raised
                # AttributeError on every fine-tune iteration.
                # NOTE(review): confirm Allatom2Allatom accepts None here.
                xyzallatom, state = self.finetune_refiner(
                    seq_unmasked,
                    xyzallatom.detach().float(),
                    self.aamask,
                    self.num_bonds,
                    state,
                    extra_l1.float() if extra_l1 is not None else None
                )

                xyzallatom_s.append(xyzallatom.clone())

            state = self.residue_embed(state)

        xyz = torch.stack(xyz_s, dim=0)
        alpha_s = torch.stack(alpha_s, dim=0)
        xyzallatom_s = torch.cat(xyzallatom_s, dim=0)

        return msa, pair, xyz, alpha_s, xyzallatom_s, state
|
||||
@@ -1,44 +0,0 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
import pandas as pd
|
||||
|
||||
def parse_training_log(filename):
    """Parse a training log into a DataFrame with a cumulative ``example`` column.

    A line starting with ``Header`` defines the column names; lines starting
    with one of the known row labels are parsed as records (label followed by
    every number found on the line). Rows are de-duplicated on
    ``(Header, epoch, examples_seen_in_epoch)``. ``example`` is the running
    example count across epochs; rows whose label is not ``Local`` are stamped
    with the end-of-epoch count.

    Returns an empty DataFrame when the log holds no records.
    Raises ValueError if a record line appears before any header line.
    """
    row_labels = ('Local', 'Train', 'Monomer', 'Homo', 'Hetero', 'NA', 'NAfs', 'RNA', 'SM Compl')
    # renaming a few things from an earlier version of training script
    renames = [('# epochs', 'num_epochs'), ('processed', 'examples_seen_in_epoch'),
               ('examples in epoch', 'examples_per_epoch'), ('Max mem', 'max_mem'),
               ('seconds', 'time'), ('total_loss: loss', 'Total_loss: total_loss')]

    columns = None
    records = []
    with open(filename) as f:
        for line in f:
            if line.startswith("Header"):
                for src, tgt in renames:
                    line = line.replace(src, tgt)

                columns = re.findall(r'(\w+)', line)
                for val in ['Batch', 'Time', 'Total_loss']:
                    if val in columns:
                        columns.remove(val)

            if line.startswith(row_labels):
                if columns is None:
                    raise ValueError(
                        "record line found before any 'Header' line in %s" % filename)
                values = [line.split(':')[0]] + [float(x) for x in re.findall(r'(\d+\.*\d*)', line)]
                records.append(OrderedDict(zip(columns, values)))

    if not records:
        return pd.DataFrame()

    df = pd.DataFrame.from_records(records)
    df = df.drop_duplicates(['Header', 'epoch', 'examples_seen_in_epoch'])

    df_s = []
    offset = 0
    for ep in df['epoch'].drop_duplicates():
        # explicit copy so the column assignments never target a view of df
        tmp = df[df['epoch'] == ep].copy()

        n_per_epoch = tmp['examples_per_epoch'].values[0]
        tmp['example'] = tmp['examples_seen_in_epoch'] + offset
        offset += n_per_epoch

        # non-'Local' rows are placed at the end of the epoch
        mask = tmp['Header'] != 'Local'
        tmp.loc[mask, 'example'] = offset

        df_s.append(tmp)
    df = pd.concat(df_s)

    return df
|
||||
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,78 +0,0 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.spatial
|
||||
from rf2aa.chemical import generate_Cbeta
|
||||
|
||||
# calculate dihedral angles defined by 4 sets of points
|
||||
def get_dihedrals(a, b, c, d):
    """Return the dihedral angle (radians) for each quadruple ``a-b-c-d``.

    Inputs are arrays whose last axis holds the coordinates; any leading
    shape is accepted and preserved in the result. Integer inputs are allowed
    because nothing is divided in place.
    """
    b0 = -1.0*(b - a)
    b1 = c - b
    b2 = d - c

    # unit vector along the central bond (out of place: works for int input)
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)

    # components of the outer bonds perpendicular to the central bond
    v = b0 - np.sum(b0*b1, axis=-1, keepdims=True)*b1
    w = b2 - np.sum(b2*b1, axis=-1, keepdims=True)*b1

    x = np.sum(v*w, axis=-1)
    y = np.sum(np.cross(b1, v)*w, axis=-1)

    return np.arctan2(y, x)
|
||||
|
||||
# calculate planar angles defined by 3 sets of points
|
||||
def get_angles(a, b, c):
    """Return the planar angle (radians) at ``b`` for each triple ``a-b-c``.

    Inputs are arrays whose last axis holds the coordinates; any leading
    shape is accepted. Integer inputs are allowed because nothing is divided
    in place.
    """
    v = a - b
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)

    w = c - b
    w = w / np.linalg.norm(w, axis=-1, keepdims=True)

    x = np.sum(v*w, axis=-1)

    # clip guards arccos against round-off slightly outside [-1, 1]
    return np.arccos(np.clip(x, -1.0, 1.0))
|
||||
|
||||
# get 6d coordinates from x,y,z coords of N,Ca,C atoms
|
||||
def get_coords6d(xyz, dmax):
    """Compute Cb-based distance and orientation maps for residue pairs.

    ``xyz[0]``, ``xyz[1]``, ``xyz[2]`` are read as the N, Ca and C coordinate
    sets and ``xyz.shape[1]`` as the residue count. Only ordered pairs
    ``i != j`` whose rebuilt Cb atoms lie within ``dmax`` are filled in.

    Returns ``(dist6d, omega6d, theta6d, phi6d, mask)``, each an
    ``(nres, nres)`` float32 array. Unfilled distances are 999.9, unfilled
    angles and mask entries are 0. When no pair is within ``dmax`` the
    default-filled maps are returned instead of raising.
    """
    nres = xyz.shape[1]

    # three anchor atoms
    N = xyz[0]
    Ca = xyz[1]
    C = xyz[2]

    # recreate Cb given N,Ca,C
    Cb = generate_Cbeta(N, Ca, C)

    # fast neighbors search to collect all
    # Cb-Cb pairs within dmax
    kdCb = scipy.spatial.cKDTree(Cb)
    indices = kdCb.query_ball_tree(kdCb, dmax)

    # indices of contacting residues; reshape keeps a (2, 0) layout when the
    # pair list is empty so the row unpacking below cannot fail
    pairs = [(i, j) for i, neighbours in enumerate(indices) for j in neighbours if i != j]
    idx0, idx1 = np.array(pairs, dtype=int).reshape(-1, 2).T

    # Cb-Cb distance matrix
    dist6d = np.full((nres, nres), 999.9, dtype=np.float32)
    dist6d[idx0, idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1)

    # matrix of Ca-Cb-Cb-Ca dihedrals
    omega6d = np.zeros((nres, nres), dtype=np.float32)
    omega6d[idx0, idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])

    # matrix of polar coord theta
    theta6d = np.zeros((nres, nres), dtype=np.float32)
    theta6d[idx0, idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])

    # matrix of polar coord phi
    phi6d = np.zeros((nres, nres), dtype=np.float32)
    phi6d[idx0, idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])

    mask = np.zeros((nres, nres), dtype=np.float32)
    mask[idx0, idx1] = 1.0
    return dist6d, omega6d, theta6d, phi6d, mask
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,91 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
|
||||
|
||||
'''
|
||||
Created on Apr 30, 2014
|
||||
|
||||
@author: meiermark
|
||||
'''
|
||||
|
||||
|
||||
import sys
|
||||
import mmap
|
||||
from collections import namedtuple
|
||||
|
||||
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
|
||||
|
||||
|
||||
def read_index(ffindex_filename):
    """Read an ffindex file into a list of ``FFindexEntry`` records.

    Each line holds tab-separated ``name``, ``offset`` and ``length``; the
    latter two are converted to int. The file is closed even if a line fails
    to parse.
    """
    entries = []

    with open(ffindex_filename) as fh:
        for line in fh:
            tokens = line.split("\t")
            entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))

    return entries
|
||||
|
||||
|
||||
def read_data(ffdata_filename):
    """Memory-map the whole ffdata file and return the ``mmap`` object.

    The file is opened read/write in binary mode. The handle is closed before
    returning, and also when mapping fails (for example on an empty file).
    """
    with open(ffdata_filename, "r+b") as fh:
        return mmap.mmap(fh.fileno(), 0)
|
||||
|
||||
|
||||
def get_entry_by_name(name, index):
    """Return the first entry in ``index`` whose ``name`` equals ``name``, else None."""
    # TODO: bsearch (linear scan for now)
    return next((candidate for candidate in index if name == candidate.name), None)
|
||||
|
||||
|
||||
def read_entry_lines(entry, data):
    """Decode an entry's payload as UTF-8 and split it on newlines.

    The last byte of the entry (``length - 1``) is left out of the slice.
    """
    start = entry.offset
    stop = start + entry.length - 1
    text = data[start:stop].decode("utf-8")
    return text.split("\n")
|
||||
|
||||
|
||||
def read_entry_data(entry, data):
    """Return the raw bytes of an entry, leaving out its last byte."""
    start = entry.offset
    stop = start + entry.length - 1
    return data[start:stop]
|
||||
|
||||
|
||||
def write_entry(entries, data_fh, entry_name, offset, data):
    """Append one record to the data file and register it in ``entries``.

    All of ``data`` except its last byte is written, followed by a single
    zero byte, so exactly ``len(data)`` bytes are emitted. Returns the offset
    at which the next record starts.
    """
    size = len(data)

    data_fh.write(data[:-1])
    data_fh.write(bytearray(1))  # one NUL byte in place of the dropped last byte

    entries.append(FFindexEntry(entry_name, offset, size))
    return offset + size
|
||||
|
||||
|
||||
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
    """Read ``file_name`` as bytes and store it via ``write_entry``.

    Returns the offset at which the next record starts.
    """
    with open(file_name, "rb") as source:
        payload = bytearray(source.read())
    return write_entry(entries, data_fh, entry_name, offset, payload)
|
||||
|
||||
|
||||
def finish_db(entries, ffindex_filename, data_fh):
    """Close the data file, then write the index file for ``entries``."""
    data_fh.close()
    write_entries_to_db(entries=entries, ffindex_filename=ffindex_filename)
|
||||
|
||||
|
||||
def write_entries_to_db(entries, ffindex_filename):
    """Write ``entries`` to an ffindex file, sorted by entry name.

    One line per entry: name (truncated to 64 characters), offset and length,
    tab-separated. The caller's list is not reordered. Previously the result
    of ``sorted`` was discarded, so the index came out unsorted.
    """
    ordered = sorted(entries, key=lambda x: x.name)

    with open(ffindex_filename, "w") as index_fh:
        for entry in ordered:
            index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
|
||||
|
||||
|
||||
def write_entry_to_file(entry, data, file):
    """Write the text lines of ``entry`` to the path ``file``, one per line.

    The lines come from ``read_entry_lines``. The previous call to
    ``read_lines`` referred to a name this module does not define.
    """
    lines = read_entry_lines(entry, data)

    with open(file, "w") as fh:
        for line in lines:
            fh.write(line+"\n")
|
||||
@@ -1,292 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from openbabel import openbabel
|
||||
from chemical import aachirals, NTOTAL, generate_Cbeta
|
||||
|
||||
PARAMS = {
|
||||
'DMIN':1,
|
||||
'DMID':4,
|
||||
'DMAX':20.0,
|
||||
'DBINS1':30,
|
||||
'DBINS2':30,
|
||||
'ABINS':36
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
def get_pair_dist(a, b):
    """calculate pair distances between two sets of points

    Parameters
    ----------
    a,b : pytorch tensors of shape [batch,nres,3]
          store Cartesian coordinates of two sets of atoms
    Returns
    -------
    dist : pytorch tensor of shape [batch,nres,nres]
           stores pairwise distances between atoms in a and b
    """
    # Euclidean (p=2) distance between every point of a and every point of b
    return torch.cdist(a, b, p=2)
|
||||
|
||||
# ============================================================
|
||||
def get_ang(a, b, c, eps=1e-6):
    """calculate planar angles for all consecutive triples (a[i],b[i],c[i])
    from Cartesian coordinates of three sets of atoms a,b,c

    Parameters
    ----------
    a,b,c : pytorch tensors of shape [batch,nres,3]
            store Cartesian coordinates of three sets of atoms
    Returns
    -------
    ang : pytorch tensor of shape [batch,nres]
          stores resulting planar angles
    """
    def _unit(vec):
        # eps keeps the division finite for zero-length vectors
        return vec / (torch.norm(vec, dim=-1, keepdim=True)+eps)

    cos_abc = torch.sum(_unit(a - b)*_unit(c - b), dim=-1)

    # the cosine is clamped to +/-0.999 before acos
    return torch.acos(torch.clamp(cos_abc, -0.999, 0.999))
|
||||
|
||||
# ============================================================
|
||||
def get_dih(a, b, c, d, eps=1e-6):
    """calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
    given Cartesian coordinates of four sets of atoms a,b,c,d

    Parameters
    ----------
    a,b,c,d : pytorch tensors of shape [batch,nres,3]
              store Cartesian coordinates of four sets of atoms
    Returns
    -------
    dih : pytorch tensor of shape [batch,nres]
          stores resulting dihedrals
    """
    first = a - b
    axis = c - b
    last = d - c

    # unit vector along the central bond
    axis_n = axis / (torch.norm(axis, dim=-1, keepdim=True) + eps)

    # parts of the outer bonds perpendicular to the central bond
    v = first - torch.sum(first*axis_n, dim=-1, keepdim=True)*axis_n
    w = last - torch.sum(last*axis_n, dim=-1, keepdim=True)*axis_n

    cos_part = torch.sum(v*w, dim=-1)
    sin_part = torch.sum(torch.cross(axis_n, v, dim=-1)*w, dim=-1)

    # eps is added to both arguments of atan2
    return torch.atan2(sin_part+eps, cos_part+eps)
|
||||
|
||||
|
||||
# ============================================================
|
||||
def xyz_to_c6d(xyz, params=PARAMS):
    """convert cartesian coordinates into 2d distance
    and orientation maps

    Parameters
    ----------
    xyz : pytorch tensor of shape [batch,nres,3,3]
          stores Cartesian coordinates of backbone N,Ca,C atoms
    Returns
    -------
    c6d : pytorch tensor of shape [batch,nres,nres,4]
          stores stacked dist,omega,theta,phi 2D maps
    """
    n_batch, n_res = xyz.shape[0], xyz.shape[1]

    # three anchor atoms
    N, Ca, C = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]

    # recreate Cb given N,Ca,C
    Cb = generate_Cbeta(N, Ca, C)

    # 6d coordinates order: (dist,omega,theta,phi)
    c6d = torch.zeros([n_batch, n_res, n_res, 4], dtype=xyz.dtype, device=xyz.device)

    dist = get_pair_dist(Cb, Cb)
    dist[torch.isnan(dist)] = 999.9
    # the added identity term pushes the diagonal beyond DMAX
    diag_offset = 999.9*torch.eye(n_res, device=xyz.device)[None,...]
    c6d[...,0] = dist + diag_offset

    # angles are only computed for pairs closer than DMAX
    b, i, j = torch.where(c6d[...,0] < params['DMAX'])
    omega = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
    theta = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
    phi = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
    c6d[b,i,j,torch.full_like(b,1)] = omega
    c6d[b,i,j,torch.full_like(b,2)] = theta
    c6d[b,i,j,torch.full_like(b,3)] = phi

    # fix long-range distances
    far = c6d[...,0] >= params['DMAX']
    c6d[...,0][far] = 999.9
    c6d = torch.nan_to_num(c6d)

    return c6d
|
||||
|
||||
def xyz_to_t2d(xyz_t, mask, params=PARAMS):
    """convert template cartesian coordinates into 2d distance
    and orientation maps

    Parameters
    ----------
    xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
            stores Cartesian coordinates of template backbone N,Ca,C atoms
    mask : pytorch tensor [batch,templ,nres,nres]
           indicates whether valid residue pairs or not
    Returns
    -------
    t2d : pytorch tensor of shape [batch,nres,nres,37+6+3]
          stores stacked dist,omega,theta,phi 2D maps
    """
    B, T, L = xyz_t.shape[:3]

    # fold templates into the batch axis for the 6d computation, then unfold
    backbone = xyz_t[:,:,:,:3].view(B*T, L, 3, 3)
    c6d = xyz_to_c6d(backbone, params=params).view(B, T, L, L, 4)

    # dist to one-hot encoded
    mask = mask[...,None]
    dist = dist_to_onehot(c6d[...,0], params)*mask

    # sin and cos of the three angle channels  (B, T, L, L, 6)
    angles = c6d[...,1:]
    orien = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)*mask

    return torch.cat((dist, orien, mask), dim=-1)
|
||||
|
||||
def xyz_to_bbtor(xyz, params=PARAMS):
    """Bin backbone phi/psi torsions into ``params['ABINS']`` angular bins.

    ``xyz[:,:,0]``, ``xyz[:,:,1]``, ``xyz[:,:,2]`` are read as N, Ca and C.
    Neighbouring residues are obtained by rolling along axis 1; the wrapped
    entries (phi of the first residue, psi of the last) are set to 0 before
    binning. Returns a long tensor with phi and psi bins stacked on the last
    axis.
    """
    # three anchor atoms
    N = xyz[:,:,0]
    Ca = xyz[:,:,1]
    C = xyz[:,:,2]

    # neighbouring-residue atoms needed for the backbone torsions
    next_N = torch.roll(N, -1, dims=1)
    prev_C = torch.roll(C, 1, dims=1)
    phi = get_dih(prev_C, N, Ca, C)
    psi = get_dih(N, Ca, C, next_N)

    # the rolled neighbours wrap around at the ends, so zero those torsions
    phi[:,0] = 0.0
    psi[:,-1] = 0.0

    astep = 2.0*np.pi / params['ABINS']
    phi_bin = torch.round((phi+np.pi-astep/2)/astep)
    psi_bin = torch.round((psi+np.pi-astep/2)/astep)
    return torch.stack([phi_bin, psi_bin], dim=-1).long()
|
||||
|
||||
# ============================================================
|
||||
def dist_to_onehot(dist, params=PARAMS):
    """One-hot encode binned distances as float.

    The number of classes is ``DBINS1 + DBINS2 + 1``.
    """
    n_classes = params['DBINS1'] + params['DBINS2']+1
    bin_idx = dist_to_bins(dist, params)
    return torch.nn.functional.one_hot(bin_idx, num_classes=n_classes).float()
|
||||
|
||||
# ============================================================
|
||||
def dist_to_bins(dist,params=PARAMS):
    """bin 2d distance maps

    Bin edges are ``DBINS1`` evenly spaced values ending at ``DMID`` followed
    by ``DBINS2`` evenly spaced values ending at ``DMAX``. NaN distances are
    treated as 999.9. The input tensor is not modified; the NaN replacement is
    done on a copy instead of writing into the caller's tensor.
    """
    dist = dist.masked_fill(torch.isnan(dist), 999.9)
    dstep1 = (params['DMID'] - params['DMIN']) / params['DBINS1']
    dstep2 = (params['DMAX'] - params['DMID']) / params['DBINS2']
    dbins = torch.cat([
        torch.linspace(params['DMIN']+dstep1, params['DMID'], params['DBINS1'],
                       dtype=dist.dtype,device=dist.device),
        torch.linspace(params['DMID']+dstep2, params['DMAX'], params['DBINS2'],
                       dtype=dist.dtype,device=dist.device),
    ])
    db = torch.bucketize(dist.contiguous(),dbins).long()

    return db
|
||||
|
||||
# ============================================================
|
||||
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
    """bin 2d distance and orientation maps

    Channels 0..3 of ``c6d`` are read as dist, omega, theta, phi. Pairs in the
    last distance bin get the no-contact angle bins. With ``negative=True``
    every pair where ``same_chain`` is false is also set to the no-contact
    bins. Returns the four bin maps stacked on the last axis as long.

    NOTE(review): this writes ``params['DBINS']`` into the dict it is given,
    which is the shared module-level default unless the caller passes its own.
    """
    db = dist_to_bins(c6d[...,0], params) # all dist < DMIN are in bin 0

    astep = 2.0*np.pi / params['ABINS']
    half = astep/2
    ob = torch.round((c6d[...,1]+np.pi-half)/astep)
    tb = torch.round((c6d[...,2]+np.pi-half)/astep)
    pb = torch.round((c6d[...,3]-half)/astep)

    # synchronize no-contact bins
    params['DBINS'] = params['DBINS1'] + params['DBINS2']
    no_contact = db==params['DBINS']
    ob[no_contact] = params['ABINS']
    tb[no_contact] = params['ABINS']
    pb[no_contact] = params['ABINS']//2

    if negative:
        intra = same_chain.bool()
        db = torch.where(intra, db.long(), params['DBINS'])
        ob = torch.where(intra, ob.long(), params['ABINS'])
        tb = torch.where(intra, tb.long(), params['ABINS'])
        pb = torch.where(intra, pb.long(), params['ABINS']//2)

    return torch.stack([db,ob,tb,pb],dim=-1).long()
|
||||
|
||||
def get_chirals(obmol, xyz):
    """get all quadruples of atoms forming chiral centers

    Returns a float tensor with one row per ordering,
    ``[center, n1, n2, n3, angle]`` with 0-based atom indices, or an empty
    tensor when the molecule has no tetrahedral stereo centers. Each
    quadruple is listed three times, once per cyclic rotation of its
    neighbours, and the angle sign is flipped where the dihedral measured on
    ``xyz`` is negative.
    """
    # detect stereo centers
    stereo = openbabel.OBStereoFacade(obmol)
    stereo_atoms = [obmol.GetAtom(i+1) for i in range(obmol.NumAtoms()) if stereo.HasTetrahedralStereo(i)]
    angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry

    chirals = []
    for center in stereo_atoms:
        center_idx = center.GetIdx()
        neigh = [nb.GetIdx() for nb in openbabel.OBAtomAtomIter(center)]
        if len(neigh)==3:
            chirals.append([center_idx,*neigh,angle])
        elif len(neigh)==4:
            a,b,c,d = neigh
            # four neighbour triples for a center with four neighbours
            for p,q,r in ((a,b,c),(b,a,d),(a,c,d),(c,b,d)):
                chirals.append([center_idx,p,q,r,angle])

    n = len(chirals)
    if n==0:
        return torch.Tensor()

    # three copies: original order plus the two cyclic rotations of the neighbours
    chirals = torch.tensor(chirals*3).float()
    chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
    chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
    chirals[:,:-1] -= 1  # 1-based atom indices -> 0-based
    quad_xyz = xyz[chirals[:,:4].long()]
    dih = get_dih(*quad_xyz.split(split_size=1,dim=1))[:,0]
    chirals[dih<0.0,-1] = -angle

    return chirals
|
||||
|
||||
def get_atomize_protein_chirals(residues_atomize, lig_xyz, residue_atomize_mask, bond_feats):
    """
    Enumerate chiral centers in residues and provide features for chiral centers

    Returns a float tensor of rows ``[center, n1, n2, n3, angle]`` (three
    cyclic rotations per center), or an empty tensor when there are none.
    The angle sign is flipped where the dihedral measured on ``lig_xyz`` is
    negative.
    """
    angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry
    chiral_atoms = aachirals[residues_atomize]
    r,a = residue_atomize_mask.nonzero().T

    chiral_atoms = chiral_atoms[r,a].nonzero().squeeze(1) #num_chiral_centers
    num_chiral_centers = chiral_atoms.shape[0]
    # find bonds to each chiral atom, then indices of each bonded neighbor
    chiral_bonds_idx = bond_feats[chiral_atoms].nonzero()
    # in practice all chiral atoms in proteins have 3 heavy atom neighbors, so reshape to 3
    chiral_bonds_idx = chiral_bonds_idx.reshape(num_chiral_centers, 3, 2)

    chirals = torch.zeros((num_chiral_centers, 5))
    chirals[:,0] = chiral_atoms.long()
    chirals[:, 1:-1] = chiral_bonds_idx[...,-1].long()
    chirals[:, -1] = angle

    n = chirals.shape[0]
    if n<=0:
        return torch.Tensor()

    # three copies: original order plus the two cyclic rotations of the neighbours
    chirals = chirals.repeat(3,1).float()
    chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
    chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
    quad_xyz = lig_xyz[chirals[:,:4].long()]
    dih = get_dih(*quad_xyz.split(split_size=1,dim=1))[:,0]
    chirals[dih<0.0,-1] = -angle
    return chirals
|
||||
@@ -1,985 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import scipy
|
||||
import networkx as nx
|
||||
|
||||
from util import (
|
||||
rigid_from_3_points,
|
||||
cb_lengths_CN,
|
||||
cb_angles_CACN,
|
||||
cb_angles_CNCA,
|
||||
cb_torsions_CACNH,
|
||||
cb_torsions_CANCO,
|
||||
is_nucleic,
|
||||
find_all_paths_of_length_n,
|
||||
find_all_rigid_groups
|
||||
)
|
||||
from chemical import NFRAMES, NTOTAL
|
||||
|
||||
from kinematics import get_dih, get_ang
|
||||
from scoring import HbHybType
|
||||
|
||||
# Loss functions for the training
|
||||
# 1. BB rmsd loss
|
||||
# 2. distance loss (or 6D loss?)
|
||||
# 3. bond geometry loss
|
||||
# 4. predicted lddt loss
|
||||
|
||||
#fd use improved coordinate frame generation
|
||||
def get_t(N, Ca, C, eps=1e-5):
    """Express pairwise frame-origin displacements in each residue's frame.

    ``N``, ``Ca``, ``C`` are viewed as (I, B, L, 3). Frames come from
    ``rigid_from_3_points`` on the flattened (I*B, L, 3) inputs.
    Returns a tensor of shape (I, B, L, L, 3).
    """
    I, B, L = N.shape[:3]
    flat = [atoms.view(I*B, L, 3) for atoms in (N, Ca, C)]
    Rs, Ts = rigid_from_3_points(*flat, eps=eps)
    Rs = Rs.view(I, B, L, 3, 3)
    Ts = Ts.view(I, B, L, 3)
    # t[..., l, m, :] = Ts[..., m, :] - Ts[..., l, :]
    t = Ts.unsqueeze(-2) - Ts.unsqueeze(-3)
    return torch.einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3) **fixed
|
||||
|
||||
def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra=10.0, d_clamp_inter=30.0, A=10.0, gamma=0.99, eps=1e-6):
    """Backbone FAPE-style loss over all residue pairs.

    Args:
        pred: predicted coordinates (I, B, L, n_atom, 3); atoms 0/1/2 are
            passed to ``get_t`` as N/CA/C.
        true: true coordinates (B, L, n_atom, 3).
        mask_2d: pair mask applied to the per-pair error.
        same_chain: pair indicator; pairs equal to 1 are clamped at
            ``d_clamp_intra`` and pairs equal to 0 at ``d_clamp_inter``.
        negative: when True the pair mask is ``mask_2d * same_chain``.
        A: length scale the clamped error is divided by.
        gamma: per-iteration decay; the last iteration gets the largest weight.
        eps: numerical stabiliser for the square root and the mask sum.

    Returns:
        (weighted total loss, detached per-iteration loss of shape (I,)).
    """
    n_iter = pred.shape[0]
    native = true[None]

    t_native = get_t(native[:, :, :, 0], native[:, :, :, 1], native[:, :, :, 2])
    t_pred = get_t(pred[:, :, :, 0], pred[:, :, :, 1], pred[:, :, :, 2])
    err = torch.sqrt(torch.square(t_native - t_pred).sum(dim=-1) + eps)

    # per-pair clamp distance: different caps within and between chains
    cap = torch.zeros_like(err)
    cap[:, same_chain == 1] = d_clamp_intra
    cap[:, same_chain == 0] = d_clamp_inter
    per_pair = torch.clamp(err, max=cap) / A  # (I, B, L, L)

    # positive cases use mask_2d; negative (non-interacting) cases only
    # score intra-chain pairs
    pair_mask = mask_2d * same_chain if negative else mask_2d
    loss = (pair_mask[None] * per_pair).sum(dim=(1, 2, 3)) / (pair_mask.sum() + eps)  # (I)

    # geometric weights gamma^(I-1-i), normalised to sum to one
    exponents = torch.arange(n_iter, device=pred.device)
    weights = torch.pow(torch.full((n_iter,), gamma, device=pred.device), exponents)
    weights = torch.flip(weights, (0,))
    weights = weights / weights.sum()

    return (weights * loss).sum(), loss.detach()
|
||||
|
||||
#resolve rotationally equivalent sidechains
|
||||
def resolve_symmetry(xs, Rsnat_all, xsnat, Rsnat_all_alt, xsnat_alt, atm_mask):
    """Pick, per residue, the native or its symmetric alternate.

    For every residue the distances from its atoms to all masked atoms are
    compared between the prediction and each of the two native labellings;
    the alternate is selected where it matches the prediction strictly
    better.

    Fixes relative to the previous version:
      * the alternate's score is now |d(alt) - d(pred)|; it used to be
        |d(nat) - d(alt)|, which does not measure agreement between the
        alternate and the prediction;
      * the result is written into copies, so the caller's ``Rsnat_all`` and
        ``xsnat`` are no longer modified in place.

    Args:
        xs: predicted coordinates, indexed as (L, natoms, 3).
        Rsnat_all, xsnat: native frames / coordinates (first dim L).
        Rsnat_all_alt, xsnat_alt: alternate native frames / coordinates.
        atm_mask: boolean (L, natoms) mask of atoms used as distance targets.

    Returns:
        (Rsnat_symm, xs_symm): frames and coordinates after the per-residue
        selection.
    """
    def masked_dists(crds):
        # distance from every atom to every masked atom: (L, natoms, n_masked)
        return torch.linalg.norm(
            crds[:, :, None, :] - crds[atm_mask, :][None, None, :, :], dim=-1
        )

    dists = masked_dists(xs)
    dists_nat = masked_dists(xsnat)
    dists_natalt = masked_dists(xsnat_alt)

    drms_nat = torch.sum(torch.abs(dists_nat - dists), dim=(-1, -2))
    drms_natalt = torch.sum(torch.abs(dists_natalt - dists), dim=(-1, -2))
    toflip = drms_natalt < drms_nat

    Rsnat_symm = Rsnat_all.clone()
    xs_symm = xsnat.clone()
    Rsnat_symm[toflip, ...] = Rsnat_all_alt[toflip, ...]
    xs_symm[toflip, ...] = xsnat_alt[toflip, ...]

    return Rsnat_symm, xs_symm
|
||||
|
||||
# resolve "equivalent" natives
|
||||
def resolve_equiv_natives(xs, natstack, maskstack):
    """Select the candidate native whose atom-1 distance map best matches ``xs``.

    Args:
        xs: predicted coordinates, indexed as (B, L, natoms, 3).
        natstack: either a single native (4-D, returned as is) or a stack of
            candidates indexed as (B, n_native, L, natoms, 3).
        maskstack: masks stacked the same way as ``natstack``.

    Returns:
        (native, mask) for the selected candidate. The winning index is a
        single flat ``argmin`` over the whole score tensor, so a batch size
        of one is presumably assumed -- confirm with callers.
    """
    if natstack.dim() == 4:
        return natstack, maskstack
    if natstack.shape[1] == 1:
        return natstack[:, 0, ...], maskstack[:, 0, ...]

    pred_ca = xs[:, :, 1, :]            # (B, L, 3)
    nat_ca = natstack[:, :, :, 1, :]    # (B, n_native, L, 3)

    dx = torch.norm(pred_ca[:, None, :, None, :] - pred_ca[:, None, None, :, :], dim=-1)
    dnat = torch.norm(nat_ca[:, :, :, None, :] - nat_ca[:, :, None, :, :], dim=-1)
    delta = torch.abs(dnat - dx).sum(dim=(-2, -1))

    best = torch.argmin(delta)
    return natstack[:, best, ...], maskstack[:, best, ...]
|
||||
|
||||
|
||||
#torsion angle predictor loss
|
||||
def torsionAngleLoss( alpha, alphanat, alphanat_alt, tors_mask, tors_planar, eps=1e-8 ):
    """Torsion-angle loss on 2-vector angle predictions.

    Combines three terms: squared error of the normalised prediction against
    the closer of the native / alternate-native vectors, a penalty on the
    prediction's norm deviating from 1, and a penalty on the first component
    for torsions flagged in ``tors_planar`` (the last two weighted by 0.02).

    Args:
        alpha: predictions with a leading iteration dimension and a trailing
            vector dimension.
        alphanat, alphanat_alt: native and alternate native vectors (no
            iteration dimension).
        tors_mask: mask of torsions that contribute to the first two terms.
        tors_planar: mask of torsions that contribute to the planarity term.
        eps: numerical stabiliser.

    Returns:
        Scalar loss tensor.
    """
    n_iter = alpha.shape[0]

    norm = torch.sqrt(torch.square(alpha).sum(dim=-1) + eps)
    unit = alpha / norm[..., None]

    err_nat = torch.square(unit - alphanat[None]).sum(dim=-1)
    err_alt = torch.square(unit - alphanat_alt[None]).sum(dim=-1)
    err = torch.min(err_nat, err_alt)

    l_tors = torch.sum(err * tors_mask[None]) / (torch.sum(tors_mask) * n_iter + eps)
    l_norm = torch.sum(torch.abs(norm - 1.0) * tors_mask[None]) / (torch.sum(tors_mask) * n_iter + eps)
    l_planar = torch.sum(torch.abs(alpha[..., 0]) * tors_planar[None]) / (torch.sum(tors_planar) * n_iter + eps)

    return l_tors + 0.02 * l_norm + 0.02 * l_planar
|
||||
|
||||
def compute_FAPE(Rs, Ts, xs, Rsnat, Tsnat, xsnat, Z=10.0, dclamp=10.0, eps=1e-4):
    """Frame-aligned point error between a prediction and a native.

    Every point is expressed in every frame (for both structures), the
    per-(frame, point) distance is clamped at ``dclamp``, averaged, and
    divided by ``Z``.

    Args:
        Rs, Ts: predicted frame rotations (r, 3, 3) and origins (r, 3).
        xs: predicted points (s, 3).
        Rsnat, Tsnat, xsnat: the same quantities for the native.
        Z: length scale of the loss.
        dclamp: clamp distance.
        eps: added under the square root.

    Returns:
        Scalar loss tensor.
    """
    def to_local(rots, origins, points):
        return torch.einsum('rji,rsj->rsi', rots, points[None, ...] - origins[:, None, ...])

    local_pred = to_local(Rs, Ts, xs)
    local_nat = to_local(Rsnat, Tsnat, xsnat)

    err = torch.sqrt(torch.square(local_pred - local_nat).sum(dim=-1) + eps)
    return (1.0 / Z) * torch.clamp(err, max=dclamp).mean()
|
||||
|
||||
def compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae, pae_bin_step=0.5, eps=1e-4):
    """Cross-entropy loss for the predicted-aligned-error head.

    The target error is computed from atom 1 of every residue and frame 0 of
    every residue, using the last prediction in ``X`` and the first entry of
    ``Y`` (batch size 1 is assumed). Errors are binned with edges
    ``pae_bin_step * [1, ..., nbin-1]``.

    Args:
        X, Y: predicted / native coordinates.
        X_y, Y_y: frame origins for prediction / native.
        uX, uY: frame rotations for prediction / native.
        logit_pae: logits with the bin dimension at index 1.
        pae_bin_step: bin width.
        eps: added under the square root.

    Returns:
        Mean cross-entropy between ``logit_pae`` and the binned errors.
    """
    pred_local = torch.einsum(
        'rji,rsj->rsi', uX[-1, :, 0], X[-1, :, None, 1] - X_y[-1, None, :, 0, :]
    )  # last backbone prediction
    true_local = torch.einsum(
        'rji,rsj->rsi', uY[0, :, 0], Y[0, :, None, 1] - Y_y[0, None, :, 0, :]
    )  # assumes B=1
    aligned_err = torch.sqrt(torch.square(pred_local - true_local).sum(dim=-1) + eps).clone().detach()

    n_bins = logit_pae.shape[1]
    edges = torch.linspace(
        pae_bin_step, pae_bin_step * (n_bins - 1), n_bins - 1,
        dtype=logit_pae.dtype, device=logit_pae.device,
    )
    labels = torch.bucketize(aligned_err, edges, right=True).long()

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    return criterion(logit_pae, labels[None])  # assumes B=1
|
||||
|
||||
def compute_pde_loss(X, Y, logit_pde, pde_bin_step=0.3):
    """Cross-entropy loss for the predicted-distance-error head.

    The target is the absolute difference between the atom-1 pairwise
    distance maps of the last prediction in ``X`` and the first entry of
    ``Y`` (batch size 1 is assumed), binned with edges
    ``pde_bin_step * [1, ..., nbin-1]``.

    Args:
        X, Y: predicted / native coordinates.
        logit_pde: logits with the bin dimension at index 1.
        pde_bin_step: bin width.

    Returns:
        Mean cross-entropy between ``logit_pde`` and the binned errors.
    """
    pred_ca = X[-1, :, 1]
    true_ca = Y[0, :, 1]
    d_pred = torch.cdist(pred_ca, pred_ca, compute_mode='donot_use_mm_for_euclid_dist')
    d_true = torch.cdist(true_ca, true_ca, compute_mode='donot_use_mm_for_euclid_dist')
    dist_err = torch.abs(d_pred - d_true).clone().detach()

    n_bins = logit_pde.shape[1]
    edges = torch.linspace(
        pde_bin_step, pde_bin_step * (n_bins - 1), n_bins - 1,
        dtype=logit_pde.dtype, device=logit_pde.device,
    )
    labels = torch.bucketize(dist_err, edges, right=True).long()

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    return criterion(logit_pde, labels[None])  # assumes B=1
|
||||
|
||||
|
||||
# from Ivan: FAPE generalized over atom sets & frames
|
||||
def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=None,
                         logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, gamma=0.99, eps=1e-4):
    """FAPE generalised over arbitrary atom sets and frames.

    Shapes (as documented by the original author):
        X (predicted)    N x L x natoms x 3
        Y (native)       1 x L x natoms x 3
        atom_mask        1 x L x natoms
        frames           1 x L x nframes x 3 x 2   (residue offset, atom index)
        frame_mask       1 x L x nframes
        frame_atom_mask  1 x L x natoms            (defaults to ``atom_mask``)

    Note:
        ``frame_mask`` is updated IN PLACE: frames that reference an atom
        missing from ``frame_atom_mask`` are switched off. That side effect
        existed before and is kept for callers that may rely on it.
        ``gamma`` is accepted but not used by this function.

    The flat frame indices are now built with one broadcast integer
    expression instead of a per-residue Python loop writing into a float32
    buffer (which could not represent indices above 2**24 exactly).

    Returns:
        (loss, pae_loss, pde_loss); ``loss`` has one entry per prediction in
        ``X``; the auxiliary losses are 0 when their logits are not given.
    """
    if frame_atom_mask is None:
        frame_atom_mask = atom_mask

    N, L, natoms, _ = X.shape

    # flatten middle dims so we can gather across residues
    X_prime = X.reshape(N, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
    Y_prime = Y.reshape(1, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)

    # reindex frames for flat X: (residue + offset) * natoms + atom
    res_idx = torch.arange(L, device=frames.device).view(L, 1, 1)
    frames_reindex = ((res_idx + frames[..., 0]) * natoms + frames[..., 1]).long()

    # a frame is only usable when all three of its atoms are present
    frame_mask *= torch.all(
        torch.gather(
            frame_atom_mask.reshape(1, L*natoms), 1, frames_reindex.reshape(1, L*NFRAMES*3)
        ).reshape(1, L, -1, 3),
        dim=-1)

    X_x = torch.gather(X_prime, 1, frames_reindex[...,0:1].repeat(N,1,1,3))
    X_y = torch.gather(X_prime, 1, frames_reindex[...,1:2].repeat(N,1,1,3))
    X_z = torch.gather(X_prime, 1, frames_reindex[...,2:3].repeat(N,1,1,3))
    uX,tX = rigid_from_3_points(X_x, X_y, X_z)

    Y_x = torch.gather(Y_prime, 1, frames_reindex[...,0:1].repeat(1,1,1,3))
    Y_y = torch.gather(Y_prime, 1, frames_reindex[...,1:2].repeat(1,1,1,3))
    Y_z = torch.gather(Y_prime, 1, frames_reindex[...,2:3].repeat(1,1,1,3))
    uY,tY = rigid_from_3_points(Y_x, Y_y, Y_z)

    # atoms expressed in every valid frame (origin = second frame atom)
    xij = torch.einsum(
        'brji,brsj->brsi',
        uX[:,frame_mask[0]], X[:,atom_mask[0]][:,None,...] - X_y[:,frame_mask[0]][:,:,None,...]
    )
    xij_t = torch.einsum('rji,rsj->rsi', uY[frame_mask], Y[atom_mask][None,...] - Y_y[frame_mask][:,None,...])
    diff = torch.sqrt( torch.sum( torch.square(xij-xij_t[None,...]), dim=-1 ) + eps )

    loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2))

    if logit_pae is not None:
        pae_loss = compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae)
    else:
        pae_loss = torch.tensor(0).to(frames.device)

    if logit_pde is not None:
        pde_loss = compute_pde_loss(X, Y, logit_pde)
    else:
        pde_loss = torch.tensor(0).to(frames.device)

    return loss, pae_loss, pde_loss
|
||||
|
||||
def calc_crd_rmsd(pred, true, atom_mask, rmsd_mask=None):
    '''
    Calculate coordinate RMSD after optimal superposition.

    The superposition (centroids + rotation from an SVD of the covariance
    matrix) is fitted on the atoms selected by ``atom_mask``; the RMSD is
    then evaluated on the atoms selected by ``rmsd_mask``.

    Input:
        - pred: predicted coordinates (B, L, natoms, 3)
        - true: true coordinates (B, L, natoms, 3)
        - atom_mask: mask for seen coordinates (B, L, natoms)
        - rmsd_mask: atoms to score; defaults to a copy of ``atom_mask``
    Output: RMSD after superposition, shape (B,)

    Note: the masked atoms of all batch entries are pooled into one point
    set, so the final ``reshape(B)`` only works for B == 1.
    '''
    def rmsd(V, W, eps=1e-6):
        n_pts = V.shape[1]
        return torch.sqrt(torch.sum((V-W)*(V-W), dim=(1,2)) / n_pts + eps)

    def centroid(X):
        return X.mean(dim=-2, keepdim=True)

    # identity test: ``== None`` on a tensor goes through tensor comparison
    if rmsd_mask is None:
        rmsd_mask = atom_mask.clone()

    B = pred.shape[0]

    # center to centroid
    pred_allatom = pred[atom_mask][None]
    true_allatom = true[atom_mask][None]

    pred_allatom_origin = pred_allatom - centroid(pred_allatom)
    true_allatom_origin = true_allatom - centroid(true_allatom)

    # Computation of the covariance matrix
    C = torch.matmul(pred_allatom_origin.permute(0,2,1), true_allatom_origin)

    # Compute optimal rotation matrix using SVD
    V, S, W = torch.svd(C)

    # get sign to ensure right-handedness
    d = torch.ones([B,3,3], device=pred.device)
    d[:,:,-1] = torch.sign(torch.det(V)*torch.det(W)).unsqueeze(1)

    # Rotation matrix U
    U = torch.matmul(d*V, W.permute(0,2,1))

    # score the (possibly different) rmsd_mask atoms in the fitted frame
    pred_rms = pred[rmsd_mask][None] - centroid(pred_allatom)
    true_rms = true[rmsd_mask][None] - centroid(true_allatom)

    # Rotate pred
    rP = torch.matmul(pred_rms, U)

    # get RMS
    rms = rmsd(rP, true_rms).reshape(B)
    return rms
|
||||
|
||||
def angle(a, b, c, eps=1e-6):
    '''
    Cosine and sine of the angle at ``b`` between the vectors b->a and b->c.

    a, b, c have shape (B, L, 3); the result has shape (B, L, 2) holding
    (cos, sin). The sine is a norm of a cross product and so is non-negative.
    '''
    B, L = a.shape[:2]

    v1 = a - b
    v2 = c - b

    # normalise (eps keeps zero-length vectors finite) and flatten to (B*L, 3)
    v1 = (v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)).reshape(B * L, 3)
    v2 = (v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps)).reshape(B * L, 3)

    # sin = |v1 x v2|, cos = v1 . v2 for unit vectors
    cross = torch.cross(v1, v2, dim=1)
    sin_theta = torch.norm(cross, dim=1, keepdim=True).reshape(B, L, 1)
    cos_theta = torch.matmul(v1[:, None, :], v2[:, :, None]).reshape(B, L, 1)

    return torch.cat([cos_theta, sin_theta], dim=-1)  # (B, L, 2)
|
||||
|
||||
def length(a, b):
    """Euclidean distance between ``a`` and ``b`` along the last dimension."""
    delta = a - b
    return delta.norm(dim=-1)
|
||||
|
||||
def torsion(a,b,c,d, eps=1e-6):
    """Cosine and sine of the dihedral angle defined by a, b, c, d (in order).

    Args:
        a, b, c, d: atom coordinates, each [B, L, 3].
        eps: numerical stabiliser for the normalisations.

    Returns:
        [B, L, 2] tensor holding (cos, sin) of the dihedral.
    """
    def unit(v):
        return v / (torch.norm(v, dim=-1, keepdim=True) + eps)

    bond1 = unit(b - a)
    bond2 = unit(c - b)
    bond3 = unit(d - c)

    # normals of the two planes spanned by consecutive bonds
    normal1 = torch.cross(bond1, bond2, dim=-1)  # [B, L, 3]
    normal2 = torch.cross(bond2, bond3, dim=-1)
    scale = (
        torch.norm(normal1, dim=-1, keepdim=True)
        * torch.norm(normal2, dim=-1, keepdim=True)
        + eps
    )

    cos_angle = torch.matmul(normal1[:, :, None, :], normal2[:, :, :, None])[:, :, 0]
    sin_angle = torch.norm(bond2, dim=-1, keepdim=True) * (
        torch.matmul(bond1[:, :, None, :], normal2[:, :, :, None])[:, :, 0]
    )

    return torch.cat([cos_angle, sin_angle], dim=-1) / scale  # [B, L, 2]
|
||||
|
||||
# ideal N-C distance, ideal cos(CA-C-N angle), ideal cos(C-N-CA angle)
|
||||
# for NA, we do not compute this as it is not computable from the stubs alone
|
||||
def calc_BB_bond_geom(
    seq, pred, idx, eps=1e-6,
    ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255,
    sig_len=0.02, sig_ang=0.05):
    '''
    Flat-bottom loss on inter-residue backbone geometry (peptide bond length
    and the two angles flanking it).

    Fix: the CA-C-N angle used to be computed from exactly the same atoms as
    the C-N-CA angle (C(i), N(i+1), CA(i+1)), so ``ideal_CACN`` was compared
    against the wrong angle. It is now computed from CA(i), C(i), N(i+1).

    Input:
        - seq: residue types; passed to ``is_nucleic`` so that only protein
          positions contribute
        - pred: predicted coords (B, L, :, 3), 0; N / 1; CA / 2; C
        - idx: residue numbering (B, L); consecutive entries that differ by
          exactly 1 are treated as bonded
        - ideal_NC: ideal C-N bond length
        - ideal_CACN / ideal_CNCA: ideal cosines of the two angles
        - sig_len / sig_ang: half-widths of the zero-penalty region
    Output:
        - bond length loss + bond angle loss (scalar)
    '''
    def cosangle(P, Q, R):
        # cosine of the angle at Q, clamped away from +-1
        QP = P - Q
        QR = R - Q
        QPn = torch.sqrt(torch.sum(torch.square(QP), dim=-1) + eps)
        QRn = torch.sqrt(torch.sum(torch.square(QR), dim=-1) + eps)
        return torch.clamp(torch.sum(QP*QR, dim=-1)/(QPn*QRn), -0.999, 0.999)

    B, L = pred.shape[:2]

    bonded = (idx[:,1:] - idx[:,:-1])==1
    is_prot = ~is_nucleic(seq)[:-1]
    pair_mask = bonded*is_prot
    n_pairs = pair_mask.sum() + eps

    CA_i, C_i = pred[:,:-1,1], pred[:,:-1,2]
    N_j, CA_j = pred[:,1:,0], pred[:,1:,1]

    # bond length: C-N
    blen_CN_pred = length(C_i, N_j).reshape(B,L-1)  # (B, L-1)
    CN_loss = torch.clamp( torch.abs(blen_CN_pred - ideal_NC) - sig_len, min=0.0 )
    blen_loss = (pair_mask*CN_loss).sum() / n_pairs

    # bond angle: CA-C-N (vertex C of residue i), C-N-CA (vertex N of residue i+1)
    bang_CACN_pred = cosangle(CA_i, C_i, N_j).reshape(B,L-1)
    bang_CNCA_pred = cosangle(C_i, N_j, CA_j).reshape(B,L-1)
    CACN_loss = torch.clamp( torch.abs(bang_CACN_pred - ideal_CACN) - sig_ang, min=0.0 )
    CACN_loss = (pair_mask*CACN_loss).sum() / n_pairs
    CNCA_loss = torch.clamp( torch.abs(bang_CNCA_pred - ideal_CNCA) - sig_ang, min=0.0 )
    CNCA_loss = (pair_mask*CNCA_loss).sum() / n_pairs
    bang_loss = CACN_loss + CNCA_loss

    return blen_loss+bang_loss
|
||||
|
||||
def calc_atom_bond_loss(pred, true, bond_feats, seq, beta=0.2, eps=1e-6):
    """
    Smooth-L1 losses on squared distances between bonded / nearly-bonded atoms.

    Three terms are returned:
      * bond_dist_loss: pairs with ``0 < bond_feats < 5`` (compared on atom
        slot 1) plus pairs with ``bond_feats == 6`` (residue-to-atom bonds,
        compared on the backbone atom chosen from the ``seq`` codes below),
        normalised by the number of such pairs;
      * skip_bond_dist_loss: atoms two bonds apart
        (``find_all_paths_of_length_n(G, 2)``);
      * rigid_group_dist_loss: pairs returned by ``find_all_rigid_groups``.

    Changes: ``nx.from_numpy_matrix`` (removed in networkx 3.0) is replaced
    by ``nx.from_numpy_array``; ``!= None`` became ``is not None``; the batch
    index from ``torch.where`` no longer shares a name with a bond-type flag.

    Batch size 1 is assumed in several places (marked below).
    """
    loss_func_sum = torch.nn.SmoothL1Loss(reduction='sum', beta=beta)
    loss_func_mean = torch.nn.SmoothL1Loss(reduction='mean', beta=beta)

    # intra-ligand bonds
    atom_bonds = (bond_feats>0)*(bond_feats < 5)
    _, i, j = torch.where(atom_bonds>0)
    nat_dist = torch.sum(torch.square(true[:,i,1]-true[:,j,1]),dim=-1)
    pred_dist = torch.sum(torch.square(pred[:,i,1]-pred[:,j,1]),dim=-1)
    lig_dist_loss = loss_func_sum(nat_dist, pred_dist)

    # bonds between protein residues and ligand atoms (i.e. atomized protein)
    inter_bonds = bond_feats==6
    _, i, j = torch.where(inter_bonds)
    res_atomC = (seq[:,i]<22) & (seq[:,j]==39)  # res N - atom C: binary indicator
    res_atomN = (seq[:,i]<22) & (seq[:,j]==55)  # res C - atom N
    atomC_res = (seq[:,i]==39) & (seq[:,j]<22)  # atom C - res N
    atomN_res = (seq[:,i]==55) & (seq[:,j]<22)  # atom N - res C
    # (B, N_bonds): index of the atom that is bonded (N:0, C:2, 1:ligand atom)
    i_atom = 0*res_atomC + 2*res_atomN + 1*atomC_res + 1*atomN_res
    j_atom = 1*res_atomC + 1*res_atomN + 0*atomC_res + 2*atomN_res
    nat_dist = torch.sum(torch.square(true[0,i,i_atom[0],:]-true[0,j,j_atom[0],:]), dim=-1)  # assumes B=1
    pred_dist = torch.sum(torch.square(pred[0,i,i_atom[0],:]-pred[0,j,j_atom[0],:]), dim=-1)
    inter_dist_loss = loss_func_sum(nat_dist, pred_dist)

    bond_dist_loss = (lig_dist_loss + inter_dist_loss)/(atom_bonds.sum() + inter_bonds.sum() + eps)

    # enforce LAS constraints between atoms 2 bonds away and aromatic groups
    atom_bonds_np = atom_bonds[0].cpu().numpy()  # assumes B=1
    G = nx.from_numpy_array(atom_bonds_np)
    paths = find_all_paths_of_length_n(G,2)
    if paths:
        paths = torch.tensor(paths, device=pred.device)
        nat_dist = torch.sum(torch.square(true[:,paths[:,0],1]-true[:,paths[:,2],1]),dim=-1)
        pred_dist = torch.sum(torch.square(pred[:,paths[:,0],1]-pred[:,paths[:,2],1]),dim=-1)
        skip_bond_dist_loss = loss_func_mean(nat_dist, pred_dist)
    else:
        skip_bond_dist_loss = torch.tensor(0, device=pred.device)

    rigid_groups = find_all_rigid_groups(bond_feats)
    if rigid_groups is not None:
        nat_dist = torch.sum(torch.square(true[:,rigid_groups[:,0],1]-true[:,rigid_groups[:,1],1]),dim=-1)
        pred_dist = torch.sum(torch.square(pred[:,rigid_groups[:,0],1]-pred[:,rigid_groups[:,1],1]),dim=-1)
        rigid_group_dist_loss = loss_func_mean(nat_dist, pred_dist)
    else:
        rigid_group_dist_loss = torch.tensor(0, device=pred.device)

    return bond_dist_loss, skip_bond_dist_loss, rigid_group_dist_loss
|
||||
|
||||
def calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
    """Cartesian bonded energy (lengths, angles, torsions) per model.

    Shapes (as documented by the original author):
        pred: N x L x 27 x 3
        idx:  1 x L
        seq:  1 x L

    ``len_param`` / ``ang_param`` / ``tor_param`` are indexed by residue type;
    their last entry is used as an alternative parameter set for residue
    type 8 (the original comment calls this "his_d"), chosen per residue
    where it gives the lower bond-length energy.

    Fixes relative to the previous version:
      * the inter-residue CA-C-N angle was computed from the same atoms as
        C-N-CA; it now uses CA(i), C(i), N(i+1);
      * the gather index for the fourth atom of the N-planarity torsion had
        batch size 1, so only model 0's atom was used for every model; the
        index is now expanded to N;
      * an unused duplicate ``get_dih`` evaluation and an unused local
        ``gen_ang`` helper were removed.

    ``eps`` is kept for interface compatibility; it is not used here.

    Returns:
        Tensor of shape (N,) with the summed energy of each model.
    """
    # quadratic from [-1,1], linear elsewhere
    def boundfunc(X):
        Y = torch.abs(X)
        Y[Y<1.0] = torch.square(Y[Y<1.0])
        return Y

    N,L = pred.shape[:2]
    cb_loss = torch.zeros(N, device=pred.device)

    ## intra-res
    cblens = len_param[seq]
    len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
    len_all = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
    len_mask = cblens[...,0]!=cblens[...,1]
    E_cb_len = (
        len_mask[None,...] *
        cblens[None,...,3] *
        boundfunc( length(len_all[...,0,:],len_all[...,1,:]) - cblens[...,2] )
    ).sum(dim=(0,3)) / len_mask.sum()

    # figure out which his are his_d
    cblens[seq==8] = len_param[-1]
    len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
    len_all_a = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
    len_mask_a = cblens[...,0]!=cblens[...,1]
    # NOTE(review): normalised by len_mask.sum() (not len_mask_a.sum()), as
    # before; this keeps both energies on the same scale for the comparison.
    E_cb_len_a = (
        len_mask_a[None,...] *
        cblens[None,...,3] *
        boundfunc( length(len_all_a[...,0,:],len_all_a[...,1,:]) - cblens[...,2] )
    ).sum(dim=(0,3)) / len_mask.sum() # N,L
    is_his_d = (seq==8)*(E_cb_len_a<E_cb_len)

    cb_loss += torch.min(E_cb_len_a,E_cb_len).sum(dim=1)

    cbangs = ang_param[seq].repeat(N,1,1,1)
    cbangs[is_his_d] = ang_param[-1]
    ang_idx = cbangs[...,:3].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
    ang_all = torch.gather(pred, 2, ang_idx).reshape(N,L,-1,3,3)
    ang_mask = cbangs[...,0]!=cbangs[...,1]
    E_cb_ang = (
        ang_mask[None,...] *
        cbangs[None,...,4] *
        boundfunc( get_ang(ang_all[...,0,:],ang_all[...,1,:],ang_all[...,2,:]) - cbangs[None,...,3] )
    ).sum(dim=(0,2,3)) / ang_mask.sum()
    cb_loss += E_cb_ang

    cbtors = tor_param[seq].repeat(N,1,1,1)
    cbtors[is_his_d] = tor_param[-1]
    tor_idx = cbtors[...,:4].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
    tor_all = torch.gather(pred, 2, tor_idx).reshape(N,L,-1,4,3)
    tor_mask = cbtors[...,0]!=cbtors[...,1]
    # periodic torsions: wrap the deviation into [-offset/2, offset/2)
    offset = 2*np.pi/cbtors[None,...,6]
    tor_deltas = (
        get_dih(
            tor_all[...,0,:],tor_all[...,1,:],tor_all[...,2,:],tor_all[...,3,:]
        ) - cbtors[None,...,4] + 0.5*offset
    ) % offset - 0.5*offset

    E_cb_tor = (
        tor_mask[None,...] *
        cbtors[None,...,5] *
        boundfunc( tor_deltas )
    ).sum(dim=(0,2,3)) / tor_mask.sum()
    cb_loss += E_cb_tor

    # inter-res
    # bond length: C-N
    bonded = (idx[:,1:] - idx[:,:-1])==1
    blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(N,L-1) # (N, L-1)
    CN_loss = cb_lengths_CN[1] * boundfunc(blen_CN_pred - cb_lengths_CN[0])
    cb_loss += (bonded*CN_loss).sum(dim=1) / (bonded.sum())

    # bond angle: CA-C-N (vertex C of residue i)
    bang_CACN_pred = get_ang(pred[:,:-1,1], pred[:,:-1,2], pred[:,1:,0]).reshape(N,L-1)
    CACN_loss = cb_angles_CACN[1] * boundfunc(bang_CACN_pred - cb_angles_CACN[0])
    cb_loss += (bonded*CACN_loss).sum(dim=1) / (bonded.sum())

    # bond angle: C-N-CA (vertex N of residue i+1)
    bang_CNCA_pred = get_ang(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(N,L-1)
    CNCA_loss = cb_angles_CNCA[1] * boundfunc(bang_CNCA_pred - cb_angles_CNCA[0])
    cb_loss += (bonded*CNCA_loss).sum(dim=1) / (bonded.sum())

    # improper torsions CA-C-N-H (CD-C-N-CA), CA-N-C-O
    # planarity around N (H for non-pro, CD for pro)
    atom4idx = torch.full_like(seq, 14)
    atom4idx[seq==14] = 6 # set to CD for proline
    # expand the index over all N models so each model uses its own atom
    atom4 = torch.gather( pred, 2, atom4idx[:,:,None,None].expand(N,-1,-1,3) )
    btor_CACNH_delta = (
        get_dih(
            pred[:,:-1,1], pred[:,:-1,2], pred[:,1:,0], atom4[:,1:,0]
        ) - cb_torsions_CACNH[0] + np.pi/2
    ) % np.pi - np.pi/2
    CACNH_loss = cb_torsions_CACNH[1] * boundfunc( btor_CACNH_delta )
    cb_loss += (bonded*CACNH_loss).sum(dim=1) / (bonded.sum())

    # planarity around C
    btor_CANCO_delta = (
        get_dih(
            pred[:,:-1,1], pred[:,1:,0], pred[:,:-1,2], pred[:,:-1,3]
        ) - cb_torsions_CANCO[0] + np.pi/2
    ) % np.pi - np.pi/2
    CANCO_loss = cb_torsions_CANCO[1] * boundfunc( btor_CANCO_delta )
    cb_loss += (bonded*CANCO_loss).sum(dim=1) / (bonded.sum())

    return cb_loss
|
||||
|
||||
# AF2-like version of clash score
|
||||
def calc_clash(xs, mask):
    """AF2-like clash score.

    Sums ``max(0, DISTCUT - d)`` over all ordered atom pairs from different
    residues, skipping the (atom 2 of residue i, atom 0 of residue i+1)
    pair in both orders, and divides by the number of masked atoms.

    Args:
        xs: coordinates indexed as (L, natoms, 3).
        mask: boolean (L, natoms) mask of atoms to consider.

    Returns:
        Scalar clash score.
    """
    DISTCUT = 2.0  # (d_lit - tau) from AF2 MS
    n_res = xs.shape[0]

    delta = xs[:, :, None, None, :] - xs[None, None, :, :, :]
    dij = torch.sqrt(torch.square(delta).sum(dim=-1) + 1e-8)

    pair_mask = mask[:, :, None, None] * mask[None, None, :, :]
    res = torch.arange(n_res)
    pair_mask[res, :, res, :] = False              # ignore res-self
    pair_mask[res[1:], 0, res[:-1], 2] = False     # ignore N(i+1) -> C(i)
    pair_mask[res[:-1], 2, res[1:], 0] = False     # ignore C(i) -> N(i+1)

    overlap = torch.clamp(DISTCUT - dij[pair_mask], min=0.0)
    return overlap.sum() / mask.sum()
|
||||
|
||||
# Rosetta-like version of LJ (fa_atr+fa_rep)
|
||||
# lj_lin is switch from linear to 12-6. Smaller values more sharply penalize clashes
|
||||
def calc_lj(
    seq, xs, aamask, bond_feats, ljparams, ljcorr, num_bonds,
    lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
    lj_maxrad=-1.0, eps=1e-8
):
    """Rosetta-like Lennard-Jones energy (fa_atr + fa_rep) per model.

    ``lj_lin`` is the switch from linear to 12-6 (in units of sigma); smaller
    values penalise clashes more sharply.

    Fix: when ``lj_maxrad > 0`` the energy is meant to be shifted by its
    value at ``lj_maxrad``. ``sdmax`` was computed but the shift was built
    from ``sd`` (the per-pair ratio), which cancelled the 12-6 term itself.
    The shift now uses ``sdmax``. The default ``lj_maxrad=-1`` is unaffected.

    Args:
        seq: residue types, used to index the per-type tables.
        xs: coordinates; the first two dims are read as (N models, L).
        aamask: per-type atom-presence table.
        bond_feats: bond features; entries in (0, 5) are treated as bonds
            (first batch entry only) to exclude atoms < 4 bonds apart.
        ljparams: per-type, per-atom (radius, well depth).
        ljcorr: per-type, per-atom boolean flags used for the h-bond
            (columns 0-2) and disulfide (column 3) corrections.
        num_bonds: per-type atom-atom bond-count table.
        lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis: replacement radii sums for
            pairs selected by the ``ljcorr`` flags.
        lj_maxrad: if > 0, distance at which the energy is shifted to zero.
        eps: numerical stabiliser.

    Returns:
        Tensor of shape (N,): summed pair energy divided by the number of
        atoms present.
    """
    def ljV(dist, sigma, epsilon, lj_lin, lj_maxrad):
        N = dist.shape[0]
        # below lj_lin*sigma the potential is continued linearly
        linpart = dist<lj_lin*sigma[None]
        deff = dist.clone()
        deff[linpart] = lj_lin*sigma.repeat(N,1)[linpart]
        sd = sigma[None] / deff
        sd2 = sd*sd
        sd6 = sd2 * sd2 * sd2
        sd12 = sd6 * sd6
        ljE = epsilon * (sd12 - 2 * sd6)
        # first-order extension using the slope at deff
        ljE[linpart] += epsilon.repeat(N,1)[linpart] * (
            -12 * sd12[linpart]/deff[linpart] + 12 * sd6[linpart]/deff[linpart]
        ) * (dist[linpart]-deff[linpart])
        if (lj_maxrad>0):
            # subtract the 12-6 value at the cutoff radius
            sdmax = sigma / lj_maxrad
            sdmax2 = sdmax*sdmax
            sdmax6 = sdmax2 * sdmax2 * sdmax2
            sdmax12 = sdmax6 * sdmax6
            ljE = ljE - epsilon * (sdmax12 - 2 * sdmax6)
        return ljE

    N, L = xs.shape[:2]

    # mask keeps running total of what to compute
    mask = aamask[seq][...,None,None]*aamask[seq][None,None,...]
    idxes1r = torch.tril_indices(L,L,-1)
    mask[idxes1r[0],:,idxes1r[1],:] = False
    idxes2r = torch.arange(L)
    idxes2a = torch.tril_indices(NTOTAL,NTOTAL,0)
    mask[idxes2r[:,None],idxes2a[0:1],idxes2r[:,None],idxes2a[1:2]] = False

    # "countpair" can be enforced by making this a weight
    mask[idxes2r,:,idxes2r,:] *= num_bonds[seq,:,:] >= 4 #intra-res
    mask[idxes2r[:-1],:,idxes2r[1:],:] *= (
        num_bonds[seq[:-1],:,2:3] + num_bonds[seq[1:],0:1,:] + 1 >= 4 #inter-res
    )
    atom_bonds = (bond_feats > 0)*(bond_feats<5)
    dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds[0].long().detach().cpu().numpy(), directed=False)
    # protein portion is inf and you don't want to mask it out
    dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0), device=mask.device)
    mask[:,1,:,1] *= dist_matrix >=4
    si,ai,sj,aj = mask.nonzero(as_tuple=True)

    ds = torch.sqrt( torch.sum ( torch.square( xs[:,si,ai]-xs[:,sj,aj] ), dim=-1 ) + eps )

    # hbond correction
    use_hb_dis = (
        ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,1]
        + ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0] )
    use_ohdon_dis = ( # OH are both donors & acceptors
        ljcorr[seq[si],ai,0]*ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0]
        +ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,0]*ljcorr[seq[sj],aj,1]
    )
    use_hb_hdis = (
        ljcorr[seq[si],ai,2]*ljcorr[seq[sj],aj,1]
        +ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,2]
    )

    # disulfide correction
    potential_disulf = ljcorr[seq[si],ai,3]*ljcorr[seq[sj],aj,3]

    ljrs = ljparams[seq[si],ai,0] + ljparams[seq[sj],aj,0]
    ljrs[use_hb_dis] = lj_hb_dis
    ljrs[use_ohdon_dis] = lj_OHdon_dis
    ljrs[use_hb_hdis] = lj_hbond_hdis

    ljss = torch.sqrt( ljparams[seq[si],ai,1] * ljparams[seq[sj],aj,1] + eps )
    ljss [potential_disulf] = 0.0

    ljval = ljV(ds,ljrs,ljss,lj_lin,lj_maxrad)
    return (torch.sum( ljval, dim=-1 )/torch.sum(aamask[seq]))
|
||||
|
||||
|
||||
def calc_hb(
    seq, xs, aamask, hbtypes, hbbaseatoms, hbpolys,
    hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
    hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
    """Rosetta-like hydrogen-bond energy over all donor-H / acceptor pairs.

    Polynomial terms are evaluated on the acceptor-hydrogen distance, on an
    angle, and on cos(BAH) (handled per acceptor hybridisation: RING, SP3,
    SP2); SP2 acceptors additionally get a BAH/chi-dependent term.

    Fix (SP2 chi term): with H the chi weight, the energy is the blend
    ``H*F + (1-H)*G``.
      * inner band (BAH > 2pi/3): G = d - 0.5, but the code evaluated
        ``(1-H)*d - 0.5`` (missing parentheses);
      * outer band: G was computed and then ignored in favour of
        ``(1-H)*d - 0.5``.
    Both made the energy discontinuous at BAH = 2pi/3 and at the outer edge
    (where F = G = m - 0.5 must match the default value). The blend now uses
    G in both bands.

    Args:
        seq: residue types used to index the per-type tables.
        xs: coordinates indexed as (L, natoms, 3).
        aamask: per-type atom-presence table (normalisation only).
        hbtypes: per-type, per-atom (donor type, acceptor type, hybridisation);
            negative donor/acceptor types mean "not a donor/acceptor".
        hbbaseatoms: per-type, per-atom indices of the base atoms.
        hbpolys: polynomial table indexed by (donor type, acceptor type).
        hb_sp2_*: shape parameters of the SP2 chi term.
        hb_sp3_softmax_fade: sharpness of the soft-max over the two SP3 bases.
        threshold_distance: accepted but not used by this function.
        eps: numerical stabiliser.
        normalize: divide the total by the number of atoms present.

    Returns:
        Scalar energy.
    """
    def evalpoly( ds, xrange, yrange, coeffs ):
        # Horner evaluation of a degree-9 polynomial, clamped outside xrange
        v = coeffs[...,0]
        for i in range(1,10):
            v = v * ds + coeffs[...,i]
        minmask = ds<xrange[...,0]
        v[minmask] = yrange[minmask][...,0]
        maxmask = ds>xrange[...,1]
        v[maxmask] = yrange[maxmask][...,1]
        return v

    def cosangle( A,B,C ):
        AB = A-B
        BC = C-B
        ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
        BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
        return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999)

    hbts = hbtypes[seq]
    hbba = hbbaseatoms[seq]

    rh,ah = (hbts[...,0]>=0).nonzero(as_tuple=True)
    ra,aa = (hbts[...,1]>=0).nonzero(as_tuple=True)
    # NOTE(review): D_xs (donor heavy atom) is never used below, and the
    # quantity named AHD is computed from B, A, H. Left unchanged -- confirm
    # whether the angle polynomial was meant to take the A-H-D angle.
    D_xs = xs[rh,hbba[rh,ah,0]][:,None,:]
    H_xs = xs[rh,ah][:,None,:]
    A_xs = xs[ra,aa][None,:,:]
    B_xs = xs[ra,hbba[ra,aa,0]][None,:,:]
    B0_xs = xs[ra,hbba[ra,aa,1]][None,:,:]
    hyb = hbts[ra,aa,2]
    polys = hbpolys[hbts[rh,ah,0][:,None],hbts[ra,aa,1][None,:]]

    is_ring = hyb==HbHybType.RING
    is_sp3 = hyb==HbHybType.SP3
    is_sp2 = hyb==HbHybType.SP2

    AH = torch.sqrt( torch.sum( torch.square( H_xs-A_xs), axis=-1) + eps )
    AHD = torch.acos( cosangle( B_xs, A_xs, H_xs) )

    Es = polys[...,0,0]*evalpoly(
        AH,polys[...,0,1:3],polys[...,0,3:5],polys[...,0,5:])
    Es += polys[...,1,0] * evalpoly(
        AHD,polys[...,1,1:3],polys[...,1,3:5],polys[...,1,5:])

    # ring acceptors: base is the midpoint of the two base atoms
    Bm = 0.5*(B0_xs[:,is_ring]+B_xs[:,is_ring])
    cosBAH = cosangle( Bm, A_xs[:,is_ring], H_xs )
    Es[:,is_ring] += polys[:,is_ring,2,0] * evalpoly(
        cosBAH,
        polys[:,is_ring,2,1:3],
        polys[:,is_ring,2,3:5],
        polys[:,is_ring,2,5:])

    # sp3 acceptors: soft-max over the two possible base atoms
    cosBAH1 = cosangle( B_xs[:,is_sp3], A_xs[:,is_sp3], H_xs )
    cosBAH2 = cosangle( B0_xs[:,is_sp3], A_xs[:,is_sp3], H_xs )
    Esp3_1 = polys[:,is_sp3,2,0] * evalpoly(
        cosBAH1,
        polys[:,is_sp3,2,1:3],
        polys[:,is_sp3,2,3:5],
        polys[:,is_sp3,2,5:])
    Esp3_2 = polys[:,is_sp3,2,0] * evalpoly(
        cosBAH2,
        polys[:,is_sp3,2,1:3],
        polys[:,is_sp3,2,3:5],
        polys[:,is_sp3,2,5:])
    Es[:,is_sp3] += torch.log(
        torch.exp(Esp3_1 * hb_sp3_softmax_fade)
        + torch.exp(Esp3_2 * hb_sp3_softmax_fade)
    ) / hb_sp3_softmax_fade

    # sp2 acceptors: cos(BAH) polynomial plus BAH/chi term
    cosBAH = cosangle( B_xs[:,is_sp2], A_xs[:,is_sp2], H_xs )
    Es[:,is_sp2] += polys[:,is_sp2,2,0] * evalpoly(
        cosBAH,
        polys[:,is_sp2,2,1:3],
        polys[:,is_sp2,2,3:5],
        polys[:,is_sp2,2,5:])

    BAH = torch.acos( cosBAH )
    B0BAH = get_dih(B0_xs[:,is_sp2], B_xs[:,is_sp2], A_xs[:,is_sp2], H_xs)

    d,m,l = hb_sp2_BAH180_rise, hb_sp2_range_span, hb_sp2_outer_width
    Echi = torch.full_like( B0BAH, m-0.5 )  # value outside both bands

    H = 0.5 * (torch.cos(2 * B0BAH) + 1)

    # inner band: BAH > 2pi/3
    mask1 = BAH>np.pi * 2.0 / 3.0
    F = d / 2 * torch.cos(3 * (np.pi - BAH[mask1])) + d / 2 - 0.5
    G = d - 0.5
    Echi[mask1] = H[mask1] * F + (1 - H[mask1]) * G

    # outer band: pi*(2/3 - l) < BAH <= 2pi/3
    mask2 = BAH>np.pi * (2.0 / 3.0 - l)
    mask2 *= ~mask1
    outer_rise = torch.cos(np.pi - (np.pi * 2 / 3 - BAH[mask2]) / l)
    F = m / 2 * outer_rise + m / 2 - 0.5
    G = (m - d) / 2 * outer_rise + (m - d) / 2 + d - 0.5
    Echi[mask2] = H[mask2] * F + (1 - H[mask2]) * G

    Es[:,is_sp2] += polys[:,is_sp2,2,0] * Echi

    # smooth the energy towards zero near 0 and drop repulsive values
    tosquish = torch.logical_and(Es > -0.1,Es < 0.1)
    Es[tosquish] = -0.025 + 0.5 * Es[tosquish] - 2.5 * torch.square(Es[tosquish])
    Es[Es > 0.1] = 0.

    if (normalize):
        return (torch.sum( Es ) / torch.sum(aamask[seq]))
    else:
        return torch.sum( Es )
|
||||
|
||||
def calc_chiral_loss(pred, chirals):
    """
    Mean squared error of the dihedral angles at chiral centres.

    Input:
        - pred: predicted coords (B, L, :, 3); atom slot 1 of each listed
          position is used
        - chirals: (B, nchiral, 5); the first four entries of the last
          dimension are the indices of the atoms forming the dihedral, the
          fifth is the ideal angle. With zero chiral sites the loss is 0.
    Output:
        - mean squared error of chiral angles
    """
    if chirals.shape[1] == 0:
        return torch.tensor(0.0, device=pred.device)

    atom_idx = chirals[..., :-1].long()
    ideal_dih = chirals[..., -1]

    quad = pred[:, atom_idx, 1]
    pred_dih = get_dih(quad[..., 0, :], quad[..., 1, :], quad[..., 2, :], quad[..., 3, :])

    return torch.square(pred_dih - ideal_dih).mean()
|
||||
|
||||
@torch.enable_grad()
def calc_BB_bond_geom_grads(seq, pred, idx, eps=1e-6, ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255, sig_len=0.02, sig_ang=0.05):
    """Gradient of ``calc_BB_bond_geom`` with respect to ``pred``.

    Gradient tracking is switched on for ``pred`` in place. Returns the
    tuple produced by ``torch.autograd.grad``.
    """
    pred.requires_grad_(True)
    bond_energy = calc_BB_bond_geom(
        seq, pred, idx,
        eps=eps,
        ideal_NC=ideal_NC,
        ideal_CACN=ideal_CACN,
        ideal_CNCA=ideal_CNCA,
        sig_len=sig_len,
        sig_ang=sig_ang,
    )
    return torch.autograd.grad(bond_energy, pred)
|
||||
|
||||
@torch.enable_grad()
def calc_cart_bonded_grads(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
    """Return the gradient of the calc_cart_bonded energy w.r.t. ``pred``.

    ``pred`` is flagged as requiring grad in place; the result is the tuple
    produced by ``torch.autograd.grad``.
    """
    pred.requires_grad_(True)
    energy = calc_cart_bonded(
        seq, pred, idx,
        len_param, ang_param, tor_param,
        eps,
    )
    grads = torch.autograd.grad(energy, pred)
    return grads
|
||||
|
||||
@torch.enable_grad()
def calc_ljallatom_grads(
    seq, xyzaa,
    aamask, bond_feats, ljparams, ljcorr, num_bonds,
    lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
    lj_maxrad=-1.0, eps=1e-8
):
    """Return the gradient of the calc_lj energy w.r.t. all-atom coordinates.

    ``xyzaa`` is flagged as requiring grad in place. calc_lj receives
    ``seq[0]`` and the first three components of the last axis of ``xyzaa``;
    the remaining arguments are forwarded positionally in their original order.
    Returns a 1-tuple (gradient w.r.t. ``xyzaa``).
    """
    xyzaa.requires_grad_(True)
    lj_options = (lj_lin, lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis, lj_maxrad, eps)
    energy = calc_lj(
        seq[0],
        xyzaa[..., :3],
        aamask, bond_feats, ljparams, ljcorr, num_bonds,
        *lj_options,
    )
    return torch.autograd.grad(energy, (xyzaa,))
|
||||
|
||||
@torch.enable_grad()
def calc_lj_grads(
    seq, xyz, alpha, toaa, bond_feats,
    aamask, ljparams, ljcorr, num_bonds,
    lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
    lj_maxrad=-1.0, eps=1e-8
):
    """Return gradients of the calc_lj energy w.r.t. ``xyz`` and ``alpha``.

    Both inputs are flagged as requiring grad in place. ``toaa(seq, xyz, alpha)``
    is called and its second return value is used as the all-atom coordinates
    handed to calc_lj. Returns a 2-tuple (d/dxyz, d/dalpha).
    """
    for leaf in (xyz, alpha):
        leaf.requires_grad_(True)

    xyzaa = toaa(seq, xyz, alpha)[1]
    lj_options = (lj_lin, lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis, lj_maxrad, eps)
    energy = calc_lj(
        seq[0],
        xyzaa[..., :3],
        aamask, bond_feats, ljparams, ljcorr, num_bonds,
        *lj_options,
    )
    return torch.autograd.grad(energy, (xyz, alpha))
|
||||
|
||||
@torch.enable_grad()
def calc_hb_grads(
    seq, xyz, alpha, toaa,
    aamask, hbtypes, hbbaseatoms, hbpolys,
    hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
    hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
    """Return gradients of the calc_hb energy w.r.t. ``xyz`` and ``alpha``.

    Both inputs are flagged as requiring grad in place. ``toaa(seq, xyz, alpha)``
    is called and its second return value is used as the all-atom coordinates;
    calc_hb receives the first batch entry, ``xyzaa[0, ..., :3]``.

    Returns:
        2-tuple (d/dxyz, d/dalpha) from ``torch.autograd.grad``.
    """
    xyz.requires_grad_(True)
    alpha.requires_grad_(True)
    _, xyzaa = toaa(seq, xyz, alpha)
    Ehb = calc_hb(
        seq,
        xyzaa[0, ..., :3],
        aamask,
        hbtypes,
        hbbaseatoms,
        hbpolys,
        hb_sp2_range_span,
        hb_sp2_BAH180_rise,
        hb_sp2_outer_width,
        hb_sp3_softmax_fade,
        threshold_distance,
        eps,
        normalize,
    )
    # Fix: the original differentiated w.r.t. an undefined name `xs`, which
    # raised NameError. The leaves set up above are xyz and alpha, matching
    # calc_lj_grads.
    return torch.autograd.grad(Ehb, (xyz, alpha))
|
||||
|
||||
@torch.enable_grad()
def calc_chiral_grads(xyz, chirals):
    """Return the gradient of calc_chiral_loss w.r.t. ``xyz`` as a 1-tuple.

    When the loss is exactly zero a zero tensor shaped like ``xyz`` is returned
    instead of calling autograd (calc_chiral_loss returns a graph-free constant
    when there are no chiral sites).
    """
    xyz.requires_grad_(True)
    loss = calc_chiral_loss(xyz, chirals)
    if loss.item() != 0.0:
        return torch.autograd.grad(loss, xyz)
    # wrap in a tuple to mirror what torch.autograd.grad returns
    return (torch.zeros(xyz.shape, device=xyz.device),)
|
||||
|
||||
def calc_pseudo_dih(pred, true, eps=1e-6):
    '''
    Loss on pseudo dihedral angles built from consecutive CA atoms.

    Input:
        - pred: predicted CA coordinates, documented upstream as (I, B, L, 3)
        - true: true CA coordinates, documented upstream as (B, L, 3)
    Output:
        - sqrt(mean over entries of the summed squared difference of the
          two-component torsion features + eps)
    '''
    def _ca_quads(ca):
        # four length-(L-3) windows, each shifted by one residue
        return ca[:, :-3, :], ca[:, 1:-2, :], ca[:, 2:-1, :], ca[:, 3:, :]

    n_iter, n_batch, n_res = pred.shape[:3]
    flat_pred = pred.reshape(n_iter * n_batch, n_res, -1)

    true_dih = torsion(*_ca_quads(true))            # (B, L', 2)
    pred_dih = torsion(*_ca_quads(flat_pred))       # (I*B, L', 2)
    pred_dih = pred_dih.reshape(n_iter, n_batch, -1, 2)

    mean_sq = torch.square(pred_dih - true_dih).sum(dim=-1).mean()
    return torch.sqrt(mean_sq + eps)
|
||||
|
||||
def calc_lddt(pred_ca, true_ca, mask_crds, mask_2d, same_chain, negative=False, interface=False, eps=1e-6):
    """CA-only lDDT of each predicted model against the true structure.

    Input (shapes as documented upstream):
        - pred_ca: predicted CA coordinates (I, B, L, 3)
        - true_ca: true CA coordinates (B, L, 3)
        - mask_crds: per-residue validity mask
        - mask_2d: per-pair validity mask; multiplied in place into a boolean
          mask, so it needs to be boolean-compatible
        - same_chain: per-pair same-chain indicator
        - negative: keep only same-chain pairs
        - interface: (only if not negative) keep only cross-chain pairs
    Output:
        - tensor of length I with the masked mean lDDT per model
    """
    n_iter, n_batch, n_res = pred_ca.shape[:3]

    d_pred = torch.cdist(pred_ca, pred_ca)                  # (I, B, L, L)
    d_true = torch.cdist(true_ca, true_ca).unsqueeze(0)     # (1, B, L, L)

    # pairs scored: distinct residues closer than 15 A in the true structure
    pair_ok = torch.logical_and(d_true > 0.0, d_true < 15.0)
    pair_ok *= mask_2d[None]
    if negative:
        pair_ok *= same_chain.bool()[None]
    elif interface:
        pair_ok *= ~same_chain.bool()[None]

    # residues left without any scored partner are dropped from the average
    mask_crds = mask_crds * (pair_ok[0].sum(dim=-1) != 0)

    err = torch.abs(d_pred - d_true)                        # (I, B, L, L)
    n_pairs = torch.sum(pair_ok, dim=-1) + eps

    score = torch.zeros((n_iter, n_batch, n_res), device=pred_ca.device)
    for cutoff in (0.5, 1.0, 2.0, 4.0):
        score += 0.25 * torch.sum((err <= cutoff) * pair_ok, dim=-1) / n_pairs

    score = mask_crds * score
    return score.sum(dim=(1, 2)) / (mask_crds.sum() + eps)
|
||||
|
||||
|
||||
#fd allatom lddt
|
||||
def calc_allatom_lddt(P, Q, idx, atm_mask, eps=1e-6):
    """All-atom lDDT of each model in P against reference Q.

    Input:
        - P: predicted coordinates, (N, L, Natm, 3)
        - Q: reference coordinates, (L, Natm, 3)
        - idx: residue indices, broadcast as (1, L); pairs whose residue
          indices are equal are excluded
        - atm_mask: atom validity mask, broadcast as (1, L, Natm)
        - eps: added under the square roots and to the final denominator
    Output:
        - tensor of length N: per-atom lDDT averaged over atm_mask

    The atom count is taken from P instead of the previous hard-coded 27, so
    other atom layouts work; results for 27 atoms are unchanged.
    """
    N, L, Natm = P.shape[:3]

    # distance matrix
    Pij = torch.square(P[:, :, None, :, None, :] - P[:, None, :, None, :, :])  # (N, L, L, Natm, Natm)
    Pij = torch.sqrt(Pij.sum(dim=-1) + eps)
    Qij = torch.square(Q[None, :, None, :, None, :] - Q[None, None, :, None, :, :])  # (1, L, L, Natm, Natm)
    Qij = torch.sqrt(Qij.sum(dim=-1) + eps)

    # get valid pairs: only consider atom pairs within 15A
    pair_mask = torch.logical_and(Qij > 0, Qij < 15).float()
    # ignore missing atoms
    pair_mask *= (atm_mask[:, :, None, :, None] * atm_mask[:, None, :, None, :]).float()
    # ignore atoms within same residue
    pair_mask *= (idx[:, :, None, None, None] != idx[:, None, :, None, None]).float()

    delta_PQ = torch.abs(Pij - Qij + eps)  # (N, L, L, Natm, Natm)

    lddt = torch.zeros((N, L, Natm), device=P.device)
    for distbin in (0.5, 1.0, 2.0, 4.0):
        lddt += 0.25 * torch.sum(
            (delta_PQ <= distbin) * pair_mask, dim=(2, 4)
        ) / (torch.sum(pair_mask, dim=(2, 4)) + 1e-8)

    lddt = (lddt * atm_mask).sum(dim=(1, 2)) / (atm_mask.sum() + eps)
    return lddt
|
||||
|
||||
|
||||
def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain, negative=False, interface=False, bin_scaling=1, eps=1e-6):
    """All-atom lDDT plus a cross-entropy loss on the predicted lDDT bins.

    Input (shapes as documented upstream):
        - P: predicted coordinates, N x L x Natm x 3
        - Q: reference coordinates, L x Natm x 3
        - pred_lddt: predicted lDDT logits, 1 x nbucket x L
        - idx: residue indices; equal-index pairs are excluded
        - atm_mask: atom validity mask
        - mask_2d: accepted but not used by this function
        - same_chain: per-residue-pair same-chain indicator
        - negative / interface: keep only same-chain / only cross-chain pairs
        - bin_scaling: multiplier applied to the 0.5/1/2/4 thresholds
    Output:
        - (lddt_loss, lddt): masked mean cross-entropy for the last model in P,
          and the per-model, per-atom averaged lDDT (length N)
    """
    N, L, Natm = P.shape[:3]

    def _atom_distances(coords):
        # coords: (n, L, Natm, 3) -> (n, L, L, Natm, Natm)
        sq = torch.square(coords[:, :, None, :, None, :] - coords[:, None, :, None, :, :])
        return torch.sqrt(sq.sum(dim=-1) + eps)

    Pij = _atom_distances(P)
    Qij = _atom_distances(Q[None])

    # valid pairs: within 15A in the reference, both atoms present,
    # and not in the same residue
    pair_mask = torch.logical_and(Qij > 0, Qij < 15).float()
    pair_mask *= (atm_mask[:, :, None, :, None] * atm_mask[:, None, :, None, :]).float()
    pair_mask *= (idx[:, :, None, None, None] != idx[:, None, :, None, None]).float()
    chain_sel = same_chain.bool()[:, :, :, None, None]
    if negative:
        # ignore atoms between different chains
        pair_mask *= chain_sel
    elif interface:
        # ignore atoms between the same chain
        pair_mask *= ~chain_sel

    delta_PQ = torch.abs(Pij - Qij + eps)
    n_valid = torch.sum(pair_mask, dim=(2, 4)) + eps

    lddt = torch.zeros((N, L, Natm), device=P.device)
    for distbin in (0.5, 1.0, 2.0, 4.0):
        hits = torch.sum((delta_PQ <= distbin * bin_scaling) * pair_mask, dim=(2, 4))
        lddt += 0.25 * hits / n_valid

    # per-residue lDDT of the last model, used as the classification target
    per_res = (lddt[-1] * atm_mask[0]).sum(-1) / (atm_mask.sum(-1) + eps)
    final_lddt_by_res = torch.clamp(per_res, min=0.0, max=1.0)

    # calculate lddt prediction loss
    nbin = pred_lddt.shape[1]
    lddt_bins = torch.linspace(1.0 / nbin, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
    true_lddt_label = torch.bucketize(final_lddt_by_res[None, ...], lddt_bins).long()
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    lddt_loss = criterion(pred_lddt, true_lddt_label[-1])

    res_mask = atm_mask.any(dim=-1)
    lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)

    # average per-atom, skipping atoms that have no valid partner at all
    atm_mask = atm_mask * (pair_mask.sum(dim=(1, 3)) != 0)
    lddt = (lddt * atm_mask).sum(dim=(1, 2)) / (atm_mask.sum() + eps)

    return lddt_loss, lddt
|
||||
@@ -1,57 +0,0 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
## MEM utils ##
|
||||
def mem_report():
    '''Report the memory usage of the tensor.storage in pytorch
    Both on CPUs and GPUs are reported

    Every tensor currently tracked by the garbage collector is listed, split
    into CUDA and host tensors. Output goes to stdout; nothing is returned.'''
    LEN = 65  # width of the separator rules

    def _mem_report(tensors, mem_type):
        '''Print the selected tensors of type
        There are two major storage types in our major concern:
            - GPU: tensors transferred to CUDA devices
            - CPU: tensors remaining on the system memory (usually unimportant)
        Args:
            - tensors: the tensors of specified type
            - mem_type: 'CPU' or 'GPU' in current implementation '''
        print('Storage on %s' %(mem_type))
        print('-'*LEN)
        total_numel = 0
        total_mem = 0
        # a set keeps the "already counted" check O(1); the original list made
        # the whole report quadratic in the number of tensors
        visited_data = set()
        for tensor in tensors:
            if tensor.is_sparse:
                continue
            storage = tensor.storage()
            # a data_ptr indicates a memory block allocated
            data_ptr = storage.data_ptr()
            if data_ptr in visited_data:
                continue
            visited_data.add(data_ptr)

            numel = storage.size()
            total_numel += numel
            element_size = storage.element_size()
            mem = numel*element_size /1024/1024 # bytes -> MByte
            total_mem += mem
            element_type = type(tensor).__name__
            size = tuple(tensor.size())

            print('%s\t\t%s\t\t%.2f' % (
                element_type,
                size,
                mem) )
        print('-'*LEN)
        # note: the first number printed is the summed element count
        print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem) )
        print('-'*LEN)

    print('='*LEN)
    objects = gc.get_objects()
    print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') )
    tensors = [obj for obj in objects if torch.is_tensor(obj)]
    cuda_tensors = [t for t in tensors if t.is_cuda]
    host_tensors = [t for t in tensors if not t.is_cuda]
    _mem_report(cuda_tensors, 'GPU')
    _mem_report(host_tensors, 'CPU')
    print('='*LEN)
|
||||
@@ -1,46 +0,0 @@
|
||||
# Hyper-parameters of the SE(3) block (single layer)
SE3_param = dict(
    num_layers=1,
    num_channels=32,
    num_degrees=2,
    l0_in_features=64,
    l0_out_features=64,
    l1_in_features=3,
    l1_out_features=2,
    num_edge_features=64,
    div=4,
    n_heads=4,
)

# Same layout as SE3_param, but with two layers
SE3_ref_param = dict(
    num_layers=2,
    num_channels=32,
    num_degrees=2,
    l0_in_features=64,
    l0_out_features=64,
    l1_in_features=3,
    l1_out_features=2,
    num_edge_features=64,
    div=4,
    n_heads=4,
)

# Top-level model configuration; the two SE3 dicts are referenced, not copied
MODEL_PARAM = dict(
    n_extra_block=4,
    n_main_block=32,
    n_ref_block=4,
    d_msa=256,
    d_pair=192,
    d_templ=64,
    n_head_msa=8,
    n_head_pair=6,
    n_head_templ=4,
    d_hidden=32,
    d_hidden_templ=64,
    p_drop=0.0,
    lj_lin=0.7,
    SE3_param=SE3_param,
    SE3_ref_param=SE3_ref_param,
    use_extra_l1=True,
    use_atom_frames=True,
)
|
||||
|
||||
@@ -1,490 +0,0 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.spatial
|
||||
import string
|
||||
import os,re
|
||||
from os.path import exists
|
||||
import random
|
||||
import util
|
||||
import gzip
|
||||
from ffindex import *
|
||||
import torch
|
||||
from chemical import NAATOKENS, aa2num, aa2long, atomnum2atomtype, NTOTAL, CHAIN_GAP
|
||||
from openbabel import openbabel
|
||||
|
||||
# Residue name -> one-letter code. The entries for DA/DC/DG/DT and A/C/G/U use
# lowercase letters so they do not collide with the amino-acid codes.
to1letter = dict(zip(
    (
        "ALA ARG ASN ASP CYS "
        "GLN GLU GLY HIS ILE "
        "LEU LYS MET PHE PRO "
        "SER THR TRP TYR VAL "
        "DA DC DG DT "
        "A C G U"
    ).split(),
    "ARNDC" "QEGHI" "LKMFP" "STWYV" "acgt" "bdhu",
))
|
||||
|
||||
def read_template_pdb(L, pdb_fn, target_chain=None):
    """Read a PDB file into template coordinate and 1D feature tensors.

    Args:
        L: number of residue slots in the output.
        pdb_fn: path of the PDB file; only "ATOM" records are used.
        target_chain: accepted for interface compatibility; not used.

    Returns:
        xyz: (1, L, 36, 3) float tensor, NaN where no atom was read.
        t1d: (1, L, 33) float tensor: 32-class one-hot of the residue type
             (20 where unknown/unset) followed by a confidence column that is
             0.1 when the first three atom slots are all present, else 0.

    Residues are placed at row ``resNo - 1`` (PDB numbering assumed 1-based).

    Fix: the original first scanned the file to build a per-chain sequence
    using the never-initialised names ``L_s`` and ``offset`` (NameError as soon
    as a second chain was encountered). That result was never returned, so the
    dead pass has been removed.
    """
    xyz = torch.full((L, 36, 3), np.nan).float()
    seq = torch.full((L,), 20).long()

    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = aa2num[aa] if aa in aa2num else 20
            idx = resNo - 1
            # store the coordinates in the slot whose name matches exactly
            for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor(
                        [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    )
                    break
            seq[idx] = aa_idx

    mask = torch.logical_not(torch.isnan(xyz[:, :3, 0]))  # (L, 3)
    mask = mask.all(dim=-1)[:, None]
    conf = torch.where(mask, torch.full((L, 1), 0.1), torch.zeros(L, 1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    return xyz[None], t1d[None]
|
||||
|
||||
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
    """Read an aligned FASTA file into integer-encoded MSA and insertion arrays.

    Args:
        filename: path of the FASTA file; '>' label lines and blank lines are
            skipped, every other line is one aligned sequence.
        maxseq: accepted for interface compatibility; not enforced here.
        rmsa_alphabet: if True use the alphabet in which only '-', 'A', 'C',
            'G', 'T', 'N' have dedicated codes; otherwise the full
            protein/nucleic alphabet.

    Returns:
        msa: uint8 array (Nseq, L); each known character is replaced by its
             position in the alphabet, other characters keep their byte value.
        ins: uint8 array of zeros with one row per sequence.

    The file is now opened in a ``with`` block so the handle is closed.
    """
    msa = []
    ins = []

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            msa.append(line)
            # FASTA carries no insertion information: all zeros
            ins.append(np.zeros((len(line))))

    # convert letters into numbers
    if rmsa_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for code in range(alphabet.shape[0]):
        msa[msa == alphabet[code]] = code

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins
|
||||
|
||||
# parse a fasta alignment IF it exists
|
||||
# otherwise return single-sequence msa
|
||||
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
    """Parse a FASTA alignment if the file exists, else encode ``seq`` alone.

    When ``filename`` is missing, the result is a single-sequence MSA built
    from ``seq`` with the full alphabet (``rmsa_alphabet`` is not consulted on
    this path) together with an all-zero array of the same shape.
    """
    if not exists(filename):
        # '-' and '0' are the UNK/mask symbols of this alphabet
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
        encoded = np.array([list(seq)], dtype='|S1').view(np.uint8)
        for code, letter in enumerate(alphabet):
            encoded[encoded == letter] = code
        return (encoded, np.zeros_like(encoded))

    return parse_fasta(filename, maxseq, rmsa_alphabet)
|
||||
|
||||
# read A3M and convert letters into
|
||||
# integers in the 0..20 range,
|
||||
# also keep track of insertions
|
||||
def parse_a3m(filename, unzip=True, maxseq=10000):
    """Read an A3M alignment into integer-encoded MSA and insertion-count arrays.

    Args:
        filename: path of the A3M file.
        unzip: if True the file is read through gzip, otherwise as plain text.
        maxseq: stop after this many sequences.

    Returns:
        msa: uint8 array (Nseq, L) with codes 0..20 (index into
             "ARNDCQEGHILKMFPSTWYV-"; any other character becomes 20).
        ins: uint8 array (Nseq, L); ins[n, j] is the number of lowercase
             (insertion) characters found immediately before aligned column j.

    Fixes relative to the original:
        * trailing-insertion stripping used ``line[:-n_remove]``, which for
          ``n_remove == 0`` yields an empty string and silently zeroed every
          insertion count of that sequence; the slice is now skipped when
          there is nothing to strip.
        * the file handle is closed via a ``with`` block.
    """
    msa = []
    ins = []

    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    # read file line by line
    opener = gzip.open if unzip else open
    with opener(filename, "rt") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))

            # sequence length
            L = len(msa[-1])

            # remove insertion at the end
            # NOTE(review): only done for plain-text input, as in the original;
            # a gzipped line ending in lowercase would index past column L-1
            # below -- confirm whether that asymmetry is intended.
            if not unzip:
                n_remove = 0
                for c in reversed(line):
                    if c.islower():
                        n_remove += 1
                    else:
                        break
                if n_remove > 0:
                    line = line[:-n_remove]

            # 0 - match or gap; 1 - insertion
            a = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
            counts = np.zeros((L))

            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a == 1)[0]

                # shift by occurrence
                a = pos - np.arange(pos.shape[0])

                # position of insertions in cleaned sequence
                # and their length
                pos, num = np.unique(a, return_counts=True)

                # append to the matrix of insertions
                counts[pos] = num

            ins.append(counts)

            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for code in range(alphabet.shape[0]):
        msa[msa == alphabet[code]] = code

    # treat all unknown characters as gaps
    msa[msa > 20] = 20

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins
|
||||
|
||||
|
||||
# read and extract xyz coords of N,Ca,C atoms
|
||||
# from a PDB file
|
||||
def parse_pdb(filename, seq=False):
    """Read a PDB file and extract atom coordinates.

    Args:
        filename: path of the PDB file.
        seq: if True, delegate to parse_pdb_lines_w_seq (which also returns the
            sequence); otherwise delegate to parse_pdb_lines.

    Returns:
        Whatever the selected line parser returns.

    Fixes: the ``seq=True`` branch passed ``parse_hetatom=parse_hetatom``, an
    undefined name and a keyword parse_pdb_lines_w_seq does not accept, so it
    always raised; the file is now closed after reading.
    """
    with open(filename, 'r') as fh:
        lines = fh.readlines()
    if seq:
        return parse_pdb_lines_w_seq(lines)
    return parse_pdb_lines(lines)
|
||||
|
||||
def parse_pdb_lines_w_seq(lines):
    """Parse PDB lines into coordinates, atom mask, residue indices and sequence.

    "ATOM" records define the residues (one per CA atom); each "HETATM" record
    is appended as its own pseudo-residue after them.

    Args:
        lines: sequence of PDB text lines (iterated more than once).

    Returns:
        xyz:   float32 array (Nres + Nhet, NTOTAL, 3); missing atoms are 0.
        mask:  bool array (Nres + Nhet, NTOTAL); True where an atom was read.
        idx_s: int array of residue numbers; HETATM entries are numbered
               ``max(residue number) + CHAIN_GAP + k``.
        seq:   int array of residue/atom type codes (20 when the name is not
               in aa2num).

    Fix: the coordinate loop indexed ``aa2num[aa]`` directly and raised
    KeyError for residue names that the sequence construction just above
    deliberately maps to 20; it now uses the same fallback.
    """
    # indices of residues observed in the structure
    res = [(l[22:26], l[17:20]) for l in lines if l[:4] == "ATOM" and l[12:16].strip() == "CA"]
    idx_s = [int(r[0]) for r in res]
    seq = [aa2num[r[1]] if r[1] in aa2num else 20 for r in res]

    # one row per residue, NTOTAL atom slots each
    xyz = np.full((len(idx_s), NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
        idx = idx_s.index(resNo)
        aa_idx = aa2num[aa] if aa in aa2num else 20
        for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # parse ligand atoms: numbered after the last protein residue plus a gap
    offset = max(idx_s)
    res_lig = [l[12:16].strip() for l in lines if l[:6] == "HETATM"]
    res_lig = [(i + offset + CHAIN_GAP, l) for i, l in enumerate(res_lig)]
    idx_s_lig = [int(r[0]) for r in res_lig]
    seq_lig = [aa2num[r[1]] if r[1] in aa2num else 20 for r in res_lig]

    xyz_s_lig = [[float(l[30:38]), float(l[38:46]), float(l[46:54])] for l in lines if l[:6] == 'HETATM']

    if len(xyz_s_lig) > 0:
        xyz_lig = np.full((len(idx_s_lig), NTOTAL, 3), np.nan, dtype=np.float32)
        # each ligand atom is stored in atom slot 1 of its pseudo-residue
        xyz_lig[:, 1, :] = np.array(xyz_s_lig)
        xyz = np.concatenate([xyz, xyz_lig], axis=0)
        idx_s = idx_s + idx_s_lig
        seq = seq + seq_lig

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    return xyz, mask, np.array(idx_s), np.array(seq)
|
||||
|
||||
#'''
|
||||
def parse_pdb_lines(lines):
    """Parse "ATOM" records of a PDB into coordinates, atom mask and residue indices.

    Returns:
        xyz:   float32 array (Nres, NTOTAL, 3); missing atoms are set to 0.
        mask:  bool array (Nres, NTOTAL); True where an atom was read.
        idx_s: int array with the residue number of every residue having a CA.
    """
    # residue numbers observed in the structure (one per CA atom)
    idx_s = [
        int(rec[22:26])
        for rec in lines
        if rec[:4] == "ATOM" and rec[12:16].strip() == "CA"
    ]

    xyz = np.full((len(idx_s), NTOTAL, 3), np.nan, dtype=np.float32)
    for rec in lines:
        if rec[:4] != "ATOM":
            continue
        resNo, atom, aa = int(rec[22:26]), rec[12:16], rec[17:20]
        row = idx_s.index(resNo)
        for slot, slot_name in enumerate(aa2long[aa2num[aa]]):
            if slot_name != atom:
                continue
            xyz[row, slot, :] = [float(rec[30:38]), float(rec[38:46]), float(rec[46:54])]
            break

    # atoms never assigned are still NaN: record them, then zero them out
    missing = np.isnan(xyz[..., 0])
    mask = np.logical_not(missing)
    xyz[missing] = 0.0

    return xyz, mask, np.array(idx_s)
|
||||
|
||||
|
||||
def parse_templates(item, params):
    """Load template hits for ``item`` from hhsearch output and an FFindex DB.

    Args:
        item: target identifier; files are looked up under
            ``params['DIR']/hhr/<last two chars of item>/<item>.atab`` and the
            matching ``.hhr``.
        params: dict providing 'FFDB' (FFindex prefix) and 'DIR'.

    Returns:
        xyz, mask, qmap, f0d, f1d, ids: stacked template coordinates, atom
        masks, (query position, template number) pairs, per-hit statistics,
        per-position scores, and the list of hit identifiers.

    Fixes: ``np.bool`` / ``np.long`` (aliases removed from NumPy) are replaced
    by ``bool`` / ``np.int64``; files are closed after reading; ``== None`` is
    replaced by ``is None``.
    """
    # init FFindexDB of templates
    ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
                     read_data(params['FFDB']+'_pdb.ffdata'))

    # process tabulated hhsearch output to get
    # matched positions and positional scores
    infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
    hits = []
    with open(infile, "r") as fh:
        for l in fh:
            if l[0] == '>':
                key = l[1:].split()[0]
                hits.append([key, [], []])
            elif "score" in l or "dssp" in l:
                continue
            else:
                hi = l.split()[:5]+[0.0, 0.0, 0.0]
                hits[-1][1].append([int(hi[0]), int(hi[1])])
                hits[-1][2].append([float(hi[2]), float(hi[3]), float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(infile[:-4]+'hhr', "r") as fh:
        lines = fh.readlines()
    pos = [i+1 for i, l in enumerate(lines) if l[0] == '>']
    for i, posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]', ' ', lines[posi]).split()[1::2]])

    # parse templates from FFDB; hits missing from the DB stay short and are
    # filtered out by the length check below
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines(data))

    # process hits
    counter = 0
    xyz, qmap, mask, f0d, f1d, ids = [], [], [], [], [], []
    for data in hits:
        if len(data) < 7:
            continue

        qi, ti = np.array(data[1]).T
        _, sel1, sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        # skip hits with fewer than 10 aligned, structurally resolved columns
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        qmap.append(np.stack([qi[sel1]-1, [counter]*ncol], axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)

    return xyz, mask, qmap, f0d, f1d, ids
|
||||
|
||||
def parse_templates_raw(ffdb, hhr_fn, atab_fn):
    """Load template hits from explicit .hhr / .atab files and an FFindex DB.

    Args:
        ffdb: opened FFindex database (``.index`` and ``.data`` are used).
        hhr_fn: path of the hhsearch .hhr file.
        atab_fn: path of the tabulated alignment (.atab) file.

    Returns:
        (xyz, qmap, f1d, seq, ids): torch tensors with the stacked template
        coordinates, (query position, template number) pairs, per-position
        scores and template residue codes, plus the list of hit identifiers.

    Fixes: ``np.long`` (alias removed from NumPy) is replaced by ``np.int64``;
    files are closed after reading; ``== None`` is replaced by ``is None``.
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    with open(atab_fn, "r") as fh:
        for l in fh:
            if l[0] == '>':
                key = l[1:].split()[0]
                hits.append([key, [], []])
            elif "score" in l or "dssp" in l:
                continue
            else:
                hi = l.split()[:5]+[0.0, 0.0, 0.0]
                hits[-1][1].append([int(hi[0]), int(hi[1])])
                hits[-1][2].append([float(hi[2]), float(hi[3]), float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(hhr_fn, "r") as fh:
        lines = fh.readlines()
    pos = [i+1 for i, l in enumerate(lines) if l[0] == '>']
    for i, posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]', ' ', lines[posi]).split()[1::2]])

    # parse templates from FFDB; hits missing from the DB stay short and are
    # filtered out by the length check below
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines_w_seq(data))

    # process hits
    counter = 0
    xyz, qmap, mask, f0d, f1d, ids, seq = [], [], [], [], [], [], []
    for data in hits:
        if len(data) < 7:
            continue

        qi, ti = np.array(data[1]).T
        _, sel1, sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        # skip hits with fewer than 10 aligned, structurally resolved columns
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        seq.append(data[-1][sel2])
        qmap.append(np.stack([qi[sel1]-1, [counter]*ncol], axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    qmap = np.vstack(qmap).astype(np.int64)
    f1d = np.vstack(f1d).astype(np.float32)
    seq = np.hstack(seq).astype(np.int64)

    return torch.from_numpy(xyz), torch.from_numpy(qmap), \
        torch.from_numpy(f1d), torch.from_numpy(seq), ids
|
||||
|
||||
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Build per-template coordinate and 1D feature tensors for a query.

    The first ``min(n_templ, number of hits)`` templates returned by
    parse_templates_raw are scattered onto the query positions they align to.

    Returns:
        xyz: (npick, qlen, 27, 3) float tensor, NaN where nothing aligns.
        f1d: (npick, qlen, 22) float tensor: 21-class one-hot of the template
             residue (20 where nothing aligns) plus one score column taken
             from column 2 of the per-position scores.
        When there are no templates, a single all-NaN / all-gap template
        is returned instead.
    """
    xyz_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn)
    npick = min(n_templ, len(ids))
    one_hot = torch.nn.functional.one_hot

    if npick < 1:  # no templates
        blank_xyz = torch.full((1, qlen, 27, 3), np.nan).float()
        all_gaps = one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float()
        blank_t1d = torch.cat((all_gaps, torch.zeros((1, qlen, 1)).float()), -1)
        return blank_xyz, blank_t1d

    xyz = torch.full((npick, qlen, 27, 3), np.nan).float()
    f1d = torch.full((npick, qlen), 20).long()
    f1d_val = torch.zeros((npick, qlen, 1)).float()

    for i_templ in range(npick):
        # rows of the stacked arrays belonging to this template
        rows = torch.where(qmap[:, 1] == i_templ)[0]
        query_pos = qmap[rows, 0]
        xyz[i_templ, query_pos] = xyz_t[rows]
        f1d[i_templ, query_pos] = seq[rows]
        f1d_val[i_templ, query_pos] = t1d[rows, 2].unsqueeze(-1)

    f1d = torch.cat((one_hot(f1d, num_classes=21).float(), f1d_val), dim=-1)

    return xyz, f1d
|
||||
|
||||
def parse_mol(filename, filetype="mol2", string=False):
    """Read a small molecule with OpenBabel and derive atom-level tensors.

    Args:
        filename: path of the molecule file, or the file contents themselves
            when ``string`` is True.
        filetype: OpenBabel input format name.
        string: whether ``filename`` holds the molecule text instead of a path.

    Returns:
        obmol: the OpenBabel molecule with hydrogens removed.
        msa: 1D tensor of atom-type codes (one per heavy atom).
        ins: zeros shaped like msa.
        atom_coords: (n_symmetry, natoms, 3) coordinates, one permuted copy per
            automorphism found; (1, natoms, 3) if the automorphism step fails.
        mask: boolean tensor matching the leading two dims of atom_coords.
    """
    conv = openbabel.OBConversion()
    conv.SetInFormat(filetype)

    obmol = openbabel.OBMol()
    reader = conv.ReadString if string else conv.ReadFile
    reader(obmol, filename)

    obmol.DeleteHydrogens()
    # the above sometimes fails to get all the hydrogens: sweep once more,
    # advancing only when the current atom was kept (indices are 1-based)
    i = 1
    while i < obmol.NumAtoms()+1:
        if obmol.GetAtom(i).GetAtomicNum() == 1:
            obmol.DeleteAtom(obmol.GetAtom(i))
        else:
            i += 1

    atom_ids = range(1, obmol.NumAtoms()+1)
    atomtypes = [atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM') for i in atom_ids]
    msa = torch.tensor([aa2num[x] for x in atomtypes])
    ins = torch.zeros_like(msa)

    coords = [
        [obmol.GetAtom(i).x(), obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
        for i in atom_ids
    ]
    atom_coords = torch.tensor(coords).unsqueeze(0)   # (1, natoms, 3)
    mask = torch.full(atom_coords.shape[:-1], True)   # (1, natoms)

    try:
        automorphs = openbabel.vvpairUIntUInt()
        openbabel.FindAutomorphisms(obmol, automorphs)

        automorphs = torch.tensor(automorphs)
        n_symmetry = automorphs.shape[0]

        atom_coords = atom_coords.repeat(n_symmetry, 1, 1)
        mask = mask.repeat(n_symmetry, 1)

        # permute each copy according to its automorphism
        atom_coords = torch.scatter(atom_coords, 1, automorphs[:, :, 0:1].repeat(1, 1, 3),
                                    torch.gather(atom_coords, 1, automorphs[:, :, 1:2].repeat(1, 1, 3)))
        mask = torch.scatter(mask, 1, automorphs[:, :, 0],
                             torch.gather(mask, 1, automorphs[:, :, 1]))

    except Exception:
        # best effort: report and fall through with whatever was built so far
        print(f"ERROR: automorphs for {filename} yielded invalid tensor")

    return obmol, msa, ins, atom_coords, mask
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
# pre-activation bottleneck resblock
|
||||
class ResBlock2D_bottleneck(nn.Module):
    """Pre-activation bottleneck residual block for 2D feature maps.

    norm/ELU -> 1x1 conv down to n_c // 2 channels -> norm/ELU -> dilated
    kernel x kernel conv -> norm/ELU -> dropout -> 1x1 conv back to n_c.
    The last conv is zero-initialised, so a fresh block is the identity.
    """

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super().__init__()
        pad = self._get_same_padding(kernel, dilation)
        n_b = n_c // 2  # bottleneck channel

        self.layer = nn.Sequential(
            # pre-activation
            nn.InstanceNorm2d(n_c, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # project down to n_b
            nn.Conv2d(n_c, n_b, 1, bias=False),
            nn.InstanceNorm2d(n_b, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # convolution
            nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=pad, bias=False),
            nn.InstanceNorm2d(n_b, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # dropout
            nn.Dropout(p_drop),
            # project up
            nn.Conv2d(n_b, n_c, 1, bias=False),
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Zero the weights of the final conv, right before the residual sum."""
        final_conv = self.layer[-1]
        nn.init.zeros_(final_conv.weight)

    def _get_same_padding(self, kernel, dilation):
        """Padding that keeps the spatial size for the given kernel/dilation."""
        effective_kernel = kernel + (kernel - 1) * (dilation - 1)
        return (effective_kernel - 1) // 2

    def forward(self, x):
        residual = self.layer(x)
        return x + residual
|
||||
|
||||
class ResidualNetwork(nn.Module):
    """Stack of ResBlock2D_bottleneck blocks with optional 1x1 projections.

    Args:
        n_block: number of residual blocks.
        n_feat_in: input channels; projected to n_feat_block by a bias-free
            1x1 conv when the two differ.
        n_feat_block: channels used inside the residual blocks.
        n_feat_out: output channels; a final 1x1 conv is added when it differs
            from n_feat_block.
        dilation: sequence of dilation rates, cycled over the blocks. The
            default is now an immutable tuple instead of a shared list.
        p_drop: dropout probability handed to every block.
    """

    def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
                 dilation=(1, 2, 4, 8), p_drop=0.15):
        super(ResidualNetwork, self).__init__()

        layer_s = list()
        # project to n_feat_block
        if n_feat_in != n_feat_block:
            layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))

        # add resblocks
        for i_block in range(n_block):
            d = dilation[i_block % len(dilation)]
            res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
            layer_s.append(res_block)

        if n_feat_out != n_feat_block:
            # project to n_feat_out
            layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))

        self.layer = nn.Sequential(*layer_s)

    def forward(self, x):
        return self.layer(x)
|
||||
@@ -1,371 +0,0 @@
|
||||
import json, os
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
|
||||
##
|
||||
## lk and lk term
|
||||
#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME)
|
||||
type2ljlk = {
|
||||
"CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000),
|
||||
"COO":(1.916661,0.141799,-3.332648,3.5000,14.653000),
|
||||
"CH0":(2.011760,0.062642,1.409284,3.5000,8.998000),
|
||||
"CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000),
|
||||
"CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000),
|
||||
"CH3":(2.011760,0.062642,7.292929,3.5000,25.855000),
|
||||
"aroC":(2.016441,0.068775,1.797950,3.5000,16.704000),
|
||||
"Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100),
|
||||
"Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700),
|
||||
"NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200),
|
||||
"NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000),
|
||||
"Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000),
|
||||
"Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000),
|
||||
"Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100),
|
||||
"OH":(1.542743,0.161947,-8.133520,3.5000,10.722000),
|
||||
"OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000),
|
||||
"ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000),
|
||||
"OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600),
|
||||
"S":(1.975967,0.455970,-1.707229,3.5000,17.640000),
|
||||
"SH1":(1.975967,0.455970,3.291643,3.5000,23.240000),
|
||||
"Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000),
|
||||
"CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000),
|
||||
"CObb":(1.916661,0.141799,3.104248,3.5000,13.221000),
|
||||
"OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000),
|
||||
"Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil
|
||||
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000),
|
||||
"Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000),
|
||||
"HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000),
|
||||
"Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000),
|
||||
"Haro":(1.374914,0.015909,0.0000,3.5000,0.0000),
|
||||
"Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000),
|
||||
"HS":(0.363887,0.050836,0.0000,3.5000,0.0000),
|
||||
"genAl":(1,0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genAs":(1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genAu":(1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genB": (1,0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genBe": (1,0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genBr": (2.1971, 0.1090, 2.7951, 3.5000, 19.6876),
|
||||
"genC": (2.0067, 0.0689, 2.2256, 3.5000, 10.6860), # params from CT
|
||||
"genCa": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genCl": (2.0496, 0.1070, 2.3668, 3.5000, 17.5849),
|
||||
"genCo": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genCr": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genCu": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genF": (1.6941, 0.0750, 1.6442, 3.5000, 12.2163),
|
||||
"genFe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genHg": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genI": (2.3600, 0.1110, 3.1361, 3.5000, 22.0891),
|
||||
"genIr": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genK": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genLi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genMg": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genMn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genMo": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genN": (1.7854, 0.1497, -6.3760, 3.5000, 9.5221), # params from NG2
|
||||
"genNi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genO": (1.5492, 0.1576, -3.5363, 3.5000, 10.7220), # params for OG3
|
||||
"genOs": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genP": (2.1290, 0.5838, -9.6272, 3.5000, 34.8000), # params for PG5
|
||||
"genPb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genPd": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genPr": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genPt": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genRe": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genRh": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genRu": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||
"genS": (1.9893, 0.3634, -2.3560, 3.5000, 17.6400), # params for SG3
|
||||
"genSb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genSe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genSi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genSn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genTb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genTe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genU": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genW": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genV": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genY": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genZn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||
"genATM": (1, 0.0, 0.0000, 0.0000, 0.0000), # masked
|
||||
}
|
||||
|
||||
# cartbonded
|
||||
# Parse the cartbonded parameter file that sits next to this module.
with open(script_dir + 'cartbonded.json', 'r') as j:
    cartbonded_data_raw = json.load(j)
|
||||
|
||||
# hbond donor/acceptors
|
||||
class HbAtom:
    """Integer codes for the hydrogen-bonding role of an atom type."""
    # NO: no role listed (original carries no comment), DO: donor, AC: acceptor,
    # DA: donor & acceptor, HP: polar H
    NO, DO, AC, DA, HP = range(5)
|
||||
|
||||
type2hb = {
|
||||
"CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO,
|
||||
"CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO,
|
||||
"Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO,
|
||||
"Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA,
|
||||
"ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO,
|
||||
"Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC,
|
||||
"HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP,
|
||||
"HS":HbAtom.HP, # HP in rosetta(?)
|
||||
"Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC,
|
||||
"genAl":HbAtom.NO, "genAs":HbAtom.NO, "genAu":HbAtom.NO, "genB": HbAtom.NO,
|
||||
"genBe": HbAtom.NO, "genBr": HbAtom.NO, "genC": HbAtom.NO, "genCa": HbAtom.NO,
|
||||
"genCl": HbAtom.NO, "genCo": HbAtom.NO, "genCr": HbAtom.NO, "genCu": HbAtom.NO,
|
||||
"genF": HbAtom.DA, "genFe": HbAtom.NO, "genHg": HbAtom.NO, "genI": HbAtom.NO,
|
||||
"genIr": HbAtom.NO, "genK": HbAtom.NO, "genLi": HbAtom.NO, "genMg": HbAtom.NO,
|
||||
"genMn": HbAtom.NO, "genMo": HbAtom.NO, "genN": HbAtom.DA, "genNi": HbAtom.NO,
|
||||
"genO": HbAtom.DA, "genOs": HbAtom.NO, "genP": HbAtom.NO, "genPb": HbAtom.NO,
|
||||
"genPd": HbAtom.NO, "genPr": HbAtom.NO, "genPt": HbAtom.NO, "genRe": HbAtom.NO,
|
||||
"genRh": HbAtom.NO, "genRu": HbAtom.NO, "genS": HbAtom.DA, "genSb": HbAtom.NO,
|
||||
"genSe": HbAtom.NO, "genSi": HbAtom.NO,"genSn": HbAtom.NO,"genTb": HbAtom.NO,
|
||||
"genTe": HbAtom.NO, "genU": HbAtom.NO, "genW": HbAtom.NO, "genV": HbAtom.NO,
|
||||
"genY": HbAtom.NO, "genZn": HbAtom.NO, "genATM": HbAtom.NO, # masked
|
||||
}
|
||||
|
||||
##
|
||||
## hbond term
|
||||
## TO DO: ADD DNA
|
||||
class HbDonType:
    """Integer codes for hydrogen-bond donor types; ``NTYPES`` is the number of codes."""
    PBA, IND, IME, GDE, CXA, AMO, HXL, AHX, NTYPES = range(9)
|
||||
|
||||
class HbAccType:
    """Integer codes for hydrogen-bond acceptor types; ``NTYPES`` is the number of codes."""
    PBA, CXA, CXL, HXL, AHX, IME, NTYPES = range(7)
|
||||
|
||||
class HbHybType:
    """Integer codes for acceptor hybridization classes; ``NTYPES`` is the number of codes."""
    SP2, SP3, RING, NTYPES = range(4)
|
||||
|
||||
type2dontype = {
|
||||
"Nbb": HbDonType.PBA,
|
||||
"Ntrp": HbDonType.IND,
|
||||
"NtrR": HbDonType.GDE,
|
||||
"Narg": HbDonType.GDE,
|
||||
"NH2O": HbDonType.CXA,
|
||||
"Nlys": HbDonType.AMO,
|
||||
"OH": HbDonType.HXL,
|
||||
"OHY": HbDonType.AHX,
|
||||
}
|
||||
|
||||
type2acctype = {
|
||||
"OCbb": HbAccType.PBA,
|
||||
"ONH2": HbAccType.CXA,
|
||||
"OOC": HbAccType.CXL,
|
||||
"OH": HbAccType.HXL,
|
||||
"OHY": HbAccType.AHX,
|
||||
"Nhis": HbAccType.IME,
|
||||
}
|
||||
|
||||
type2hybtype = {
|
||||
"OCbb": HbHybType.SP2,
|
||||
"ONH2": HbHybType.SP2,
|
||||
"OOC": HbHybType.SP2,
|
||||
"OHY": HbHybType.SP3,
|
||||
"OH": HbHybType.SP3,
|
||||
"Nhis": HbHybType.RING,
|
||||
}
|
||||
|
||||
dontype2wt = {
|
||||
HbDonType.PBA: 1.45,
|
||||
HbDonType.IND: 1.15,
|
||||
HbDonType.IME: 1.42,
|
||||
HbDonType.GDE: 1.11,
|
||||
HbDonType.CXA: 1.29,
|
||||
HbDonType.AMO: 1.17,
|
||||
HbDonType.HXL: 0.99,
|
||||
HbDonType.AHX: 1.00,
|
||||
}
|
||||
|
||||
acctype2wt = {
|
||||
HbAccType.PBA: 1.19,
|
||||
HbAccType.CXA: 1.21,
|
||||
HbAccType.CXL: 1.10,
|
||||
HbAccType.HXL: 1.15,
|
||||
HbAccType.AHX: 1.15,
|
||||
HbAccType.IME: 1.17,
|
||||
}
|
||||
|
||||
class HbPolyType:
    """Integer identifiers (0-54) of the hydrogen-bond polynomials.

    Codes 0-47 are the ``ahdist_a<X>_d<Y>`` polynomials, laid out as six
    groups of eight (one group per ``a<X>`` label, one column per ``d<Y>``
    label); codes 48-50 are the ``cosBAH_*`` polynomials and 51-54 the
    ``AHD_*`` polynomials.
    """
    (
        # aASN: 0-7
        ahdist_aASN_dARG, ahdist_aASN_dASN, ahdist_aASN_dGLY, ahdist_aASN_dHIS,
        ahdist_aASN_dLYS, ahdist_aASN_dSER, ahdist_aASN_dTRP, ahdist_aASN_dTYR,
        # aASP: 8-15
        ahdist_aASP_dARG, ahdist_aASP_dASN, ahdist_aASP_dGLY, ahdist_aASP_dHIS,
        ahdist_aASP_dLYS, ahdist_aASP_dSER, ahdist_aASP_dTRP, ahdist_aASP_dTYR,
        # aGLY: 16-23
        ahdist_aGLY_dARG, ahdist_aGLY_dASN, ahdist_aGLY_dGLY, ahdist_aGLY_dHIS,
        ahdist_aGLY_dLYS, ahdist_aGLY_dSER, ahdist_aGLY_dTRP, ahdist_aGLY_dTYR,
        # aHIS: 24-31
        ahdist_aHIS_dARG, ahdist_aHIS_dASN, ahdist_aHIS_dGLY, ahdist_aHIS_dHIS,
        ahdist_aHIS_dLYS, ahdist_aHIS_dSER, ahdist_aHIS_dTRP, ahdist_aHIS_dTYR,
        # aSER: 32-39
        ahdist_aSER_dARG, ahdist_aSER_dASN, ahdist_aSER_dGLY, ahdist_aSER_dHIS,
        ahdist_aSER_dLYS, ahdist_aSER_dSER, ahdist_aSER_dTRP, ahdist_aSER_dTYR,
        # aTYR: 40-47
        ahdist_aTYR_dARG, ahdist_aTYR_dASN, ahdist_aTYR_dGLY, ahdist_aTYR_dHIS,
        ahdist_aTYR_dLYS, ahdist_aTYR_dSER, ahdist_aTYR_dTRP, ahdist_aTYR_dTYR,
        # cosBAH polynomials: 48-50
        cosBAH_off, cosBAH_7, cosBAH_6i,
        # AHD polynomials: 51-54
        AHD_1h, AHD_1i, AHD_1j, AHD_1k,
    ) = range(55)
|
||||
|
||||
# map donor:acceptor pairs to polynomials
|
||||
hbtypepair2poly = {
|
||||
(HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||
(HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||
(HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||
(HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||
(HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||
(HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||
(HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||
(HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||
(HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||
(HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||
(HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||
(HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||
(HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||
(HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||
(HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||
(HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||
(HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
|
||||
(HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
|
||||
(HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
(HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||
}
|
||||
|
||||
|
||||
# polynomials are triplets, (x_min, x_max), (y[x<x_min],y[x>x_max]), (c_9,...,c_0)
|
||||
hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886
|
||||
HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)),
|
||||
HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)),
|
||||
HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)),
|
||||
HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)),
|
||||
HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
|
||||
HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)),
|
||||
HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)),
|
||||
HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)),
|
||||
HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)),
|
||||
HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)),
|
||||
HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)),
|
||||
HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)),
|
||||
HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)),
|
||||
HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
|
||||
HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)),
|
||||
HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)),
|
||||
HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)),
|
||||
HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)),
|
||||
HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)),
|
||||
HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)),
|
||||
HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)),
|
||||
HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)),
|
||||
HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)),
|
||||
HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)),
|
||||
HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)),
|
||||
HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)),
|
||||
HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)),
|
||||
HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)),
|
||||
HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)),
|
||||
HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)),
|
||||
HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)),
|
||||
HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)),
|
||||
HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)),
|
||||
HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)),
|
||||
HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)),
|
||||
HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)),
|
||||
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)),
|
||||
HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)),
|
||||
HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)),
|
||||
HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)),
|
||||
HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)),
|
||||
HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)),
|
||||
HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)),
|
||||
HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)),
|
||||
HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)),
|
||||
HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)),
|
||||
HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)),
|
||||
HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)),
|
||||
HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)),
|
||||
HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)),
|
||||
HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)),
|
||||
HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)),
|
||||
HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)),
|
||||
HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)),
|
||||
HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)),
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,589 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from opt_einsum import contract as einsum
|
||||
import copy
|
||||
import dgl
|
||||
import networkx as nx
|
||||
from util import base_indices, RTs_by_torsion, xyzs_in_base_frame, \
|
||||
rigid_from_3_points, is_nucleic, is_atom
|
||||
|
||||
# Standard deviation of a unit normal truncated to [-2, 2]
# (sqrt(1 - 4*pdf(2) / (2*cdf(2) - 1))); dividing by it restores the
# requested variance after truncation.
_TRUNCATED_NORMAL_STD = .87962566103423978


def _truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
    """Map uniform samples in [0, 1) to normal(mu, sigma) samples truncated to [a, b]."""
    normal = torch.distributions.normal.Normal(0, 1)

    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma

    # inverse-CDF sampling restricted to the probability mass inside [a, b]
    alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
    p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform

    v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
    # final clamp guards against erfinv overshooting at the interval ends
    return torch.clamp(x, a, b)


def _sample_lecun_normal(shape, scale=1.0):
    """Draw a tensor of ``shape`` with variance ``scale / shape[-1]`` (truncated normal)."""
    stddev = np.sqrt(scale / shape[-1]) / _TRUNCATED_NORMAL_STD  # shape[-1] = fan_in
    return stddev * _truncated_normal(torch.rand(shape))


def init_lecun_normal(module, scale=1.0):
    """Replace ``module.weight`` with a freshly sampled LeCun-normal parameter.

    Args:
        module: object with a ``weight`` attribute; the last dimension of the
            weight is used as fan-in.
        scale: variance multiplier; the sampled variance is ``scale / fan_in``.
            (Previously accepted but ignored; it is now honoured. The default
            of 1.0 gives the same distribution as before.)

    Returns:
        The same ``module``, whose ``weight`` is a new ``torch.nn.Parameter``.
    """
    module.weight = torch.nn.Parameter(_sample_lecun_normal(module.weight.shape, scale))
    return module


def init_lecun_normal_param(weight, scale=1.0):
    """Return a new LeCun-normal ``torch.nn.Parameter`` shaped like ``weight``.

    The argument is not modified; callers must use the returned parameter.

    Args:
        weight: tensor/parameter whose shape is copied; its last dimension is
            used as fan-in.
        scale: variance multiplier; the sampled variance is ``scale / fan_in``.
            (Previously accepted but ignored; it is now honoured.)
    """
    return torch.nn.Parameter(_sample_lecun_normal(weight.shape, scale))
|
||||
|
||||
# for gradient checkpointing
|
||||
def create_custom_forward(module, **kwargs):
    """Return a positional-only callable that invokes ``module`` with ``kwargs`` pre-bound."""
    def run_with_bound_kwargs(*inputs):
        return module(*inputs, **kwargs)

    return run_with_bound_kwargs
|
||||
|
||||
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)
|
||||
|
||||
class Dropout(nn.Module):
    """Dropout whose mask can be shared along one dimension (drops entire rows or columns).

    Args:
        broadcast_dim: dimension along which a single mask value is shared;
            ``None`` samples an independent mask value per element.
        p_drop: probability of zeroing; kept values are rescaled by
            ``1 / (1 - p_drop)``.
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        # give ones with probability of 1-p_drop / zeros with p_drop
        self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1 - p_drop]))
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        """Return ``x`` unchanged in eval mode, otherwise the masked and rescaled ``x``."""
        if not self.training:  # no drophead during evaluation mode
            return x

        shape = list(x.shape)
        if self.broadcast_dim is not None:
            # size 1 here makes the mask broadcast along that dimension
            shape[self.broadcast_dim] = 1

        # sample() appends the sampler's own size-1 dimension; view() removes it
        mask = self.sampler.sample(shape).to(x.device).view(shape)

        return mask * x / (1.0 - self.p_drop)
|
||||
|
||||
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
    """Expand distances ``D`` into ``D_count`` Gaussian radial basis features.

    Centers are evenly spaced ``D_sigma`` apart starting at ``D_min``; the
    result has the shape of ``D`` with a trailing dimension of size ``D_count``.
    """
    last_center = D_min + (D_count - 1) * D_sigma
    centers = torch.linspace(D_min, last_center, D_count).to(D.device)
    scaled_offset = (torch.unsqueeze(D, -1) - centers[None, :]) / D_sigma
    return torch.exp(-scaled_offset**2)
|
||||
|
||||
def get_seqsep(idx):
    '''
    Signed nearest-neighbour sequence separation feature for the structure module.
    Protein-only.

    Input:
        - idx: residue indices of given sequence (B,L)
    Output:
        - seqsep: signed separation feature (B, L, L, 1); entry [b, i, j] is
                  idx[b, j] - idx[b, i] when its magnitude is at most 1 and
                  0 otherwise (keeping the sign was found to help a little)
    '''
    offset = idx[:, None, :] - idx[:, :, None]
    magnitude = torch.abs(offset)
    # keep only directly adjacent positions: if bonded -- 1 / else 0
    magnitude = magnitude.masked_fill(magnitude > 1, 0)
    signed_neighbor = torch.sign(offset) * magnitude
    return signed_neighbor.unsqueeze(-1)
|
||||
|
||||
def get_seqsep_protein_sm(idx, bond_feats, sm_mask):
    '''
    Sequence separation features for protein-SM complex

    Input:
        - idx: residue indices of given sequence (B,L)
        - bond_feats: bond features (B, L, L)
        - sm_mask: boolean feature True if a position represents atom, False if residue (B, L)

    Output:
        - seqsep: sequence separation feature with sign (B, L, L, 1)
                  -1 or 1 for bonded protein residues
                  1 for bonded SM atoms or residue-atom bonds
                  0 elsewhere
    '''
    is_atom = sm_mask[0]  # assume batch = 1
    res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, is_atom)

    # pairwise masks: atom-atom, residue-residue, and mixed residue/atom pairs
    atom_col, atom_row = is_atom[None, :], is_atom[:, None]
    both_atoms = atom_col * atom_row
    both_residues = (~atom_col) * (~atom_row)
    mixed = (~atom_col) * atom_row + atom_col * (~atom_row)

    # only directly adjacent pairs survive; everything further apart becomes 0
    res_dist[res_dist.abs() > 1] = 0.0
    atom_dist[atom_dist > 1] = 0.0

    # mixed pairs are flagged where bond_feats == 6
    seqsep = both_atoms * atom_dist + both_residues * res_dist + mixed * (bond_feats == 6)

    return seqsep.unsqueeze(-1)
|
||||
|
||||
def get_res_atom_dist(idx, bond_feats, sm_mask, minpos_res=-32, maxpos_res=32, maxpos_atom=8):
    '''
    Calculates residue and atom bond distances of protein/SM complex. Used for positional
    embedding and structure module. 2nd version (2022-9-19); handles atomized proteins.

    Input:
        - idx: residue index (B, L)
        - bond_feats: bond features (B, L, L)
        - sm_mask: boolean feature (L). True if a position represents atom, False otherwise
        - minpos_res: minimum value of residue distances
        - maxpos_res: maximum value of residue distances
        - maxpos_atom: maximum value of atom bond distances

    Output:
        - res_dist: residue distance (B, L, L)
        - atom_dist: atom bond distance (B, L, L)

    Only batch element 0 of idx and bond_feats is used (batch = 1 is assumed).
    '''
    bond_feats = bond_feats[0]  # assume batch = 1
    L = bond_feats.shape[0]
    gpu = bond_feats.device

    sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
    prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
    inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])

    # protein residue distances
    res_dist_prot = torch.clamp(idx[0,None,:] - idx[0,:,None],
                                min=minpos_res, max=maxpos_res).to(gpu)  # (L, L) intra-protein
    res_dist_sm = torch.full((L,L), maxpos_res+1).to(gpu)  # (L, L) with "unknown" res. dist. token

    # small molecule atom bond graph
    sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask_2d*bond_feats
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the replacement and builds the same graph from a 2-D ndarray.
    G = nx.from_numpy_array(sm_bond_feats.detach().cpu().numpy())
    paths = dict(nx.all_pairs_shortest_path_length(G, cutoff=maxpos_atom))
    paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
    i,j,v = torch.tensor(paths).T

    # small molecule atom bond distances
    # (maxpos_atom everywhere, 0 on the diagonal, then shortest path lengths where found)
    atom_dist_sm = torch.full((L,L), maxpos_atom).to(gpu) - maxpos_atom*torch.eye(L).to(gpu).long()
    atom_dist_sm[i,j] = v.to(gpu)
    atom_dist_prot = torch.full((L,L), maxpos_atom+1).to(gpu)

    # s.m.-protein bonds
    sm_idx = torch.where(sm_mask)[0]
    prot_idx = torch.where(~sm_mask)[0]
    i_s, j_s = torch.where(bond_feats==6)
    # keep the bond entries whose row index is an atom position
    row_is_atom = sm_mask[i_s]
    i_prot = j_s[row_is_atom]  # protein residues bonded to s.m. atoms
    i_sm = i_s[row_is_atom]    # s.m. atoms bonded to protein residues
    connected = i_prot.numel() > 0  # prot & s.m. are connected

    # inter-protein-s.m. residue & atom distances
    # atoms inherit residue distances from their nearest bonded residue
    res_dist_inter = torch.full((L,L), maxpos_res).to(gpu)
    if connected:
        for i in sm_idx:
            i_closest_res = i_prot[torch.argmin(atom_dist_sm[i,i_sm])]
            res_dist_inter[i,:] = res_dist_prot[i_closest_res,:]
            res_dist_inter[:,i] = res_dist_prot[:,i_closest_res]

    # residues inherit atom distances from their nearest bonded atom (+ 1 to count "boundary" res-atom bond)
    atom_dist_inter = torch.full((L, L), maxpos_atom).to(gpu)
    if connected:
        for i in prot_idx:
            i_closest_atom = i_sm[torch.argmin(torch.abs(res_dist_prot[i,i_prot]))]
            atom_dist_inter[i,:] = atom_dist_sm[i_closest_atom,:] + 1
            atom_dist_inter[:,i] = atom_dist_sm[:,i_closest_atom] + 1
        # the "+ 1" above may exceed the cap; clip back to maxpos_atom
        atom_dist_inter = atom_dist_inter.clamp(max=maxpos_atom)

    res_dist = res_dist_prot * prot_mask_2d + res_dist_inter * inter_mask_2d + res_dist_sm * sm_mask_2d
    atom_dist = atom_dist_prot * prot_mask_2d + atom_dist_inter * inter_mask_2d + atom_dist_sm * sm_mask_2d

    return res_dist[None].to(gpu), atom_dist[None].to(gpu)  # add batch dim.
|
||||
|
||||
def get_relpos(idx, bond_feats, sm_mask, inter_pos=32, maxpath=32):
    '''
    Relative position matrix of protein/SM complex. Used for positional
    embedding and structure module. Simple version from 9/2/2022 that doesn't
    handle atomized proteins.

    Input:
        - idx: residue index (B, L)
        - bond_feats: bond features (B, L, L)
        - sm_mask: boolean feature True if a position represents atom, False if residue.
                   NOTE(review): originally documented as (B, L), but the outer
                   products below only form (L, L) pair masks for a 1-D (L,)
                   mask -- confirm against callers.
        - inter_pos: value to assign as the protein-SM residue index differences
        - maxpath: bond distances greater than this are clipped to this value

    Output:
        - relpos: relative position feature (B, L, L)
                  for intra-protein this is the residue index difference
                  for intra-SM this is the bond distance
                  for protein-SM this is user-defined value inter_pos
    '''
    bond_feats = bond_feats[0]  # only batch element 0 is used

    sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
    prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
    inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])

    # intra-protein: residue # differences
    seqsep = idx[:,None,:] - idx[:,:,None]  # (B, L, L)

    # intra-small molecule: bond distances
    # NOTE(review): this multiplies by sm_mask (broadcast over the last axis)
    # whereas get_res_atom_dist uses the pairwise sm_mask_2d; left unchanged.
    sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask*bond_feats
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the replacement and builds the same graph from a 2-D ndarray.
    G = nx.from_numpy_array(sm_bond_feats.detach().cpu().numpy())
    paths = dict(nx.all_pairs_shortest_path_length(G, cutoff=maxpath))
    paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
    i,j,v = torch.tensor(paths).T

    # maxpath everywhere, 0 on the diagonal, then shortest path lengths where found
    bond_separation = torch.full_like(bond_feats, maxpath) \
        - maxpath*torch.eye(bond_feats.shape[0]).to(bond_feats.device).long()
    bond_separation[i,j] = v.to(bond_feats.device)

    # combine: protein-s.m. are always positive maximum distance apart
    # assumes one small molecule per example
    relpos = prot_mask_2d * seqsep + sm_mask_2d * bond_separation + inter_mask_2d * inter_pos  # (B, L, L)
    relpos = relpos.to(bond_feats.device)

    return relpos
|
||||
|
||||
def make_full_graph(xyz, pair, idx):
    '''
    Build a DGL graph with an edge for every ordered pair of positions whose
    residue indices differ.

    Input:
        - xyz: current backbone coordinates; the leading dims are (B, L)
               (the original docstring lists (B, L, 3, 3) -- TODO confirm)
        - pair: pair features from Trunk (B, L, L, E)
        - idx: residue index from ground truth pdb
    Output:
        - G: defined graph with B*L nodes; edge data 'rel_pos' holds
             xyz[target] - xyz[source], detached
        - pair features gathered on the edges, with a trailing singleton dim
    '''
    n_batch, n_res = xyz.shape[:2]
    dev = xyz.device

    # seq sep: keep every pair with a non-zero index difference
    offset = idx[:, None, :] - idx[:, :, None]
    b, i, j = torch.where(offset.abs() > 0)

    # nodes are numbered batch-major: node id = batch * L + position
    graph = dgl.graph((b * n_res + i, b * n_res + j), num_nodes=n_batch * n_res).to(dev)
    graph.edata['rel_pos'] = (xyz[b, j, :] - xyz[b, i, :]).detach()  # no gradient through basis function

    return graph, pair[b, i, j][..., None]
|
||||
|
||||
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
    '''
    Build a DGL graph connecting each position to its nearest neighbors and
    to its sequence-local neighbors.

    Input:
        - xyz: current backbone coordinates; the leading dims are (B, L)
               (the original docstring lists (B, L, 3, 3) -- TODO confirm)
        - pair: pair features from Trunk (B, L, L, E)
        - idx: residue index from ground truth pdb
        - top_k: number of nearest neighbors per position (capped at L-1)
        - nlocal: pairs with |idx_i - idx_j| < nlocal count as sequence-local
        - topk_incl_local: if True, local pairs are forced into the top_k set
              (their distance is set to 0 before selection); if False, local
              pairs are added on top of the top_k set
        - eps: weight of the sequence separation added to the distance
    Output:
        - G: defined graph with B*L nodes; edge data 'rel_pos' holds
             xyz[target] - xyz[source], detached
        - pair features gathered on the edges, with a trailing singleton dim
    '''
    n_batch, n_res = xyz.shape[:2]
    dev = xyz.device

    # large value on the diagonal so a position never selects itself
    self_block = torch.eye(n_res, device=dev).unsqueeze(0) * 9999.9

    # distance map from current coordinates
    dist = torch.cdist(xyz, xyz) + self_block  # (B, L, L)

    # seq sep
    sep = (idx[:, None, :] - idx[:, :, None]).abs() + self_block
    is_local = sep < nlocal

    dist = dist + sep * eps
    if topk_incl_local:
        dist[is_local] = 0.0

    # get top_k neighbors
    _, nbr_idx = torch.topk(dist, min(top_k, n_res - 1), largest=False)  # (B, L, top_k)
    in_topk = torch.zeros((n_batch, n_res, n_res), device=dev)
    in_topk.scatter_(2, nbr_idx, 1.0)

    cond = in_topk > 0.0
    if not topk_incl_local:
        # put an edge if either condition is met:
        #   1) sequence-local pair (sep < nlocal)
        #   2) top_k neighbors
        cond = torch.logical_or(cond, is_local)

    b, i, j = torch.where(cond)

    # nodes are numbered batch-major: node id = batch * L + position
    graph = dgl.graph((b * n_res + i, b * n_res + j), num_nodes=n_batch * n_res).to(dev)
    graph.edata['rel_pos'] = (xyz[b, j, :] - xyz[b, i, :]).detach()  # no gradient through basis function

    return graph, pair[b, i, j][..., None]
|
||||
|
||||
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
    '''
    Build an atom-level DGL graph: every atom selected by `mask` is connected
    to its top_k nearest atoms.

    Input:
        - xyz: atom coordinates; the leading dims are (B, L, A) and the last
               dim is reduced by the norm
        - mask: boolean (B, L, A), True for atoms that exist
        - num_bonds: indexed as [b, res, atom_i, atom_j]; presumably the number
               of bonds separating two atoms of a residue -- TODO confirm
        - top_k: neighbors per atom (capped at the L*A candidates available)
        - maxbonds: number of classes of the one-hot bond-separation feature
    Output:
        - G: defined graph with one node per masked atom; edge data 'rel_pos'
             holds xyz[source] - xyz[target], detached
        - edge: one-hot bond separation per edge, shape (n_edges, maxbonds, 1)
    '''
    B,L,A = xyz.shape[:3]
    device = xyz.device
    # topk cannot return more neighbors than there are candidate atoms
    top_k = min(top_k, L*A)

    D = torch.norm(
        xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
    )
    mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
    D[~mask2d] = 9999.
    D[D==0] = 9999.  # excludes self-pairs (and any exactly coincident atoms)

    # select top K neighbors for each atom
    # keep indices as batch/res/atm indices
    D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False)  # shape of E_idx: (B, L, A, top_k)
    Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
    bi,ri,ai = mask.nonzero(as_tuple=True)
    bi = bi[:,None].repeat(1,top_k).reshape(-1)
    ri = ri[:,None].repeat(1,top_k).reshape(-1)
    ai = ai[:,None].repeat(1,top_k).reshape(-1)
    rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)

    # on each edge, 1-hot encode the number of bonds (up to maxbonds) separating each atom
    edge = torch.full(ri.shape, maxbonds, device=device)
    resmask = ri==rj  # same residue
    edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
    resmask = ri+1==rj  # neighbor is in the next residue: path goes via atom 2 of i and atom 0 of j
    edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
    resmask = ri-1==rj  # neighbor is in the previous residue: path goes via atom 0 of i and atom 2 of j
    edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
    edge = edge.clamp(0,maxbonds-1)
    # Fix the class count explicitly: without num_classes, one_hot infers it
    # from the largest value present, so the feature width would vary per input.
    edge = F.one_hot(edge, num_classes=maxbonds)[...,None]

    # map (batch, residue, atom) triples to consecutive node ids over masked atoms
    natm = torch.sum(mask)
    index = torch.zeros_like(mask, dtype=torch.long, device=device)
    index[mask] = torch.arange(natm, device=device)
    src=index[bi,ri,ai]
    tgt=index[bi,rj,aj]

    G = dgl.graph((src, tgt), num_nodes=natm).to(device)
    G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach()  # no gradient through basis function

    return G, edge
|
||||
|
||||
|
||||
# rotate about the x axis
|
||||
def make_rotX(angs, eps=1e-6):
    '''
    Build homogeneous 4x4 transforms that rotate about the x axis.

    Input:
        - angs: (B, L, 2); component 0 fills the diagonal (cosine slot) and
                component 1 the off-diagonal (sine slot) after both are
                divided by (norm of the pair + eps)
    Output:
        - RTs: (B, L, 4, 4) identity matrices whose [1:3, 1:3] block holds
               the rotation
    '''
    n_batch, n_res = angs.shape[:2]
    scale = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / scale
    sin_t = angs[:, :, 1] / scale

    RTs = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)
    RTs[:, :, 1, 1] = cos_t
    RTs[:, :, 1, 2] = -sin_t
    RTs[:, :, 2, 1] = sin_t
    RTs[:, :, 2, 2] = cos_t
    return RTs
|
||||
|
||||
# rotate about the z axis
|
||||
def make_rotZ(angs, eps=1e-6):
    '''
    Build homogeneous 4x4 transforms that rotate about the z axis.

    Input:
        - angs: (B, L, 2); component 0 fills the diagonal (cosine slot) and
                component 1 the off-diagonal (sine slot) after both are
                divided by (norm of the pair + eps)
    Output:
        - RTs: (B, L, 4, 4) identity matrices whose [0:2, 0:2] block holds
               the rotation
    '''
    n_batch, n_res = angs.shape[:2]
    scale = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / scale
    sin_t = angs[:, :, 1] / scale

    RTs = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)
    RTs[:, :, 0, 0] = cos_t
    RTs[:, :, 0, 1] = -sin_t
    RTs[:, :, 1, 0] = sin_t
    RTs[:, :, 1, 1] = cos_t
    return RTs
|
||||
|
||||
# rotate about an arbitrary axis
|
||||
def make_rot_axis(angs, u, eps=1e-6):
    '''
    Build homogeneous 4x4 transforms that rotate about an arbitrary axis.

    Input:
        - angs: (B, L, 2); components 0 and 1 become the cosine and sine
                terms after division by (norm of the pair + eps)
        - u: (B, L, 3) rotation axis; used as given (not normalized here)
    Output:
        - RTs: (B, L, 4, 4) identity matrices whose upper-left 3x3 block
               holds the axis-angle rotation
    '''
    n_batch, n_res = angs.shape[:2]
    scale = torch.linalg.norm(angs, dim=-1) + eps

    ct = angs[:, :, 0] / scale
    st = angs[:, :, 1] / scale
    omc = 1 - ct  # "one minus cosine" factor shared by every entry
    ux, uy, uz = u[:, :, 0], u[:, :, 1], u[:, :, 2]

    RTs = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)

    # row 0
    RTs[:, :, 0, 0] = ct + ux * ux * omc
    RTs[:, :, 0, 1] = ux * uy * omc - uz * st
    RTs[:, :, 0, 2] = ux * uz * omc + uy * st
    # row 1
    RTs[:, :, 1, 0] = ux * uy * omc + uz * st
    RTs[:, :, 1, 1] = ct + uy * uy * omc
    RTs[:, :, 1, 2] = uy * uz * omc - ux * st
    # row 2
    RTs[:, :, 2, 0] = ux * uz * omc - uy * st
    RTs[:, :, 2, 1] = uy * uz * omc + ux * st
    RTs[:, :, 2, 2] = ct + uz * uz * omc
    return RTs
|
||||
|
||||
|
||||
# compute allatom structure from backbone frames and torsions
|
||||
#
|
||||
# alphas:
|
||||
# omega/phi/psi: 0-2
|
||||
# chi_1-4(prot): 3-6
|
||||
# cb/cg bend: 7-9
|
||||
# eps(p)/zeta(p): 10-11
|
||||
# alpha/beta/gamma/delta: 12-15
|
||||
# nu2/nu1/nu0: 16-18
|
||||
# chi_1(na): 19
|
||||
#
|
||||
# RTs_in_base_frame:
|
||||
# omega/phi/psi: 0-2
|
||||
# chi_1-4(prot): 3-6
|
||||
# eps(p)/zeta(p): 7-8
|
||||
# alpha/beta/gamma/delta: 9-12
|
||||
# nu2/nu1/nu0: 13-15
|
||||
# chi_1(na): 16
|
||||
#
|
||||
# RT frames (output):
|
||||
# origin: 0
|
||||
# omega/phi/psi: 1-3
|
||||
# chi_1-4(prot): 4-7
|
||||
# cb bend: 8
|
||||
# alpha/beta/gamma/delta: 9-12
|
||||
# nu2/nu1/nu0: 13-15
|
||||
# chi_1(na): 16
|
||||
#
|
||||
class ComputeAllAtomCoords(nn.Module):
    '''
    Compute all-atom coordinates from backbone frames and torsions.

    The module stores three lookup tables (taken from module-level tensors)
    as non-trainable parameters and chains per-torsion rigid transforms to
    place every atom. See the index tables in the comment block preceding
    this class for the alphas / RTs_in_base_frame / output frame numbering.
    '''

    def __init__(self):
        super(ComputeAllAtomCoords, self).__init__()

        # lookup tables, all indexed by residue type (seq) in forward()
        self.base_indices = nn.Parameter(base_indices, requires_grad=False)
        self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False)
        self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)

    def _child_frame(self, parent, seq, alphas, frame_idx, alpha_idx):
        # parent frame @ fixed base transform of the torsion @ x-rotation by the torsion angle
        return torch.einsum(
            'brij,brjk,brkl->bril',
            parent,
            self.RTs_in_base_frame[seq, frame_idx, :],
            make_rotX(alphas[:, :, alpha_idx, :]))

    def forward(self, seq, xyz, alphas):
        '''
        Input:
            - seq: residue types, used to index the lookup tables; leading dims (B, L)
            - xyz: coordinates; atoms 0, 1, 2 of each residue define the backbone frame
            - alphas: torsions indexed as [:, :, torsion, :], each passed to the
                      make_rot* helpers as a 2-vector
        Output:
            - RTframes: 17 stacked 4x4 frames per residue (B, L, 17, 4, 4)
            - xyzs: atom coordinates placed by their frames (last dim 3)
        '''
        B,L = xyz.shape[:2]

        is_NA = is_nucleic(seq)
        Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], is_NA)

        RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device)

        # bb
        RTF0[:,:,:3,:3] = Rs
        RTF0[:,:,:3,3] = Ts

        # omega / phi / psi
        RTF1 = self._child_frame(RTF0, seq, alphas, 0, 0)
        RTF2 = self._child_frame(RTF0, seq, alphas, 1, 1)
        RTF3 = self._child_frame(RTF0, seq, alphas, 2, 2)

        # CB bend
        # dim=-1 is explicit: without it, cross() uses the first dimension of
        # size 3, which is the wrong one whenever B or L happens to equal 3.
        basexyzs = self.xyzs_in_base_frame[seq]
        NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3])
        CAr = (basexyzs[:,:,1,:3])
        CBr = (basexyzs[:,:,4,:3])
        CBrotaxis1 = torch.cross(CBr-CAr, NCr-CAr, dim=-1)
        CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8

        # CB twist
        NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3]
        NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr
        CBrotaxis2 = torch.cross(CBr-CAr, NCpp, dim=-1)
        CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8

        CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 )
        CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 )

        RTF8 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, CBrot1,CBrot2)

        # chi1 + CG bend
        RTF4 = torch.einsum(
            'brij,brjk,brkl,brlm->brim',
            RTF8,
            self.RTs_in_base_frame[seq,3,:],
            make_rotX(alphas[:,:,3,:]),
            make_rotZ(alphas[:,:,9,:]))

        # chi2 / chi3 / chi4
        RTF5 = self._child_frame(RTF4, seq, alphas, 4, 4)
        RTF6 = self._child_frame(RTF5, seq, alphas, 5, 5)
        RTF7 = self._child_frame(RTF6, seq, alphas, 6, 6)

        # ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]

        # NA alpha / beta / gamma / delta
        RTF9 = self._child_frame(RTF0, seq, alphas, 9, 12)
        RTF10 = self._child_frame(RTF9, seq, alphas, 10, 13)
        RTF11 = self._child_frame(RTF10, seq, alphas, 11, 14)
        RTF12 = self._child_frame(RTF11, seq, alphas, 12, 15)

        # NA nu2 - from gamma frame
        RTF13 = self._child_frame(RTF11, seq, alphas, 13, 16)

        # NA nu1 / nu0
        RTF14 = self._child_frame(RTF13, seq, alphas, 14, 17)
        RTF15 = self._child_frame(RTF14, seq, alphas, 15, 18)

        # NA chi - from nu1 frame
        RTF16 = self._child_frame(RTF14, seq, alphas, 16, 19)

        RTframes = torch.stack((
            RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8,
            RTF9,RTF10,RTF11,RTF12,RTF13,RTF14,RTF15,RTF16
        ),dim=2)

        # each atom picks the frame named by base_indices and is transformed by it
        xyzs = torch.einsum(
            'brtij,brtj->brti',
            RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs
        )

        return RTframes, xyzs[...,:3]
|
||||
@@ -1,268 +0,0 @@
|
||||
import sys, os, json
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0,script_dir+'/models/fold_and_dock2/')
|
||||
import parsers
|
||||
from RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from data_loader import merge_a3m_hetero
|
||||
import util
|
||||
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_chirals
|
||||
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS, INIT_CRDS
|
||||
from model_params import MODEL_PARAM
|
||||
|
||||
alphabet = list("ARNDCQEGHILKMFPSTWYV-")
|
||||
aa_N_1 = dict(zip(range(len(alphabet)),alphabet))
|
||||
|
||||
def model_wrapper(model, inputs):
    '''
    Call `model(**inputs)` and label its 11 positional outputs.

    The model must return exactly 11 values; they are returned as a dict
    keyed by name, in the order the model produced them.
    '''
    (
        logit_s, logit_aa_s, logit_pae, logit_pde,
        pred_crds, alpha_s, pred_allatom, pred_lddt_binned,
        msa_prev, pair_prev, state_prev,
    ) = model(**inputs)

    return {
        'logit_s': logit_s,
        'logit_aa_s': logit_aa_s,
        'logit_pae': logit_pae,
        'logit_pde': logit_pde,
        'pred_crds': pred_crds,
        'alpha_s': alpha_s,
        'pred_allatom': pred_allatom,
        'pred_lddt_binned': pred_lddt_binned,
        'msa_prev': msa_prev,
        'pair_prev': pair_prev,
        'state_prev': state_prev,
    }
|
||||
|
||||
def run_gradient_descent(model, args, inputs, cycles, loss_funcs):
    '''
    Optimize the protein part of the input sequence by gradient descent on
    input logits.

    Input:
        - model: network called through model_wrapper
        - args: namespace providing device, init_sd, learning_rate, grad_steps
                and seq_prob_type
        - inputs: dict of model inputs; 'Ls' gives the segment lengths, and the
                first Ls[0] positions are the ones being designed
        - cycles: number of recycling passes per step
        - loss_funcs: list of dicts with 'name' and 'func' (func maps the
                output dict to a loss tensor)
    Output:
        - output dict of the best step, extended with 'loss', 'msa' and 'best_step'

    Raises:
        - RuntimeError if no step produced a loss below the initial 1e4
          sentinel (including grad_steps == 0), so there is no design to return.

    Side effects: enables autograd globally, prints a progress table, and
    resets CUDA peak memory statistics.
    '''
    B, N = 1,1
    device = args.device
    torch.set_grad_enabled(True)

    print()
    print('Starting gradient descent...')
    loss_header = ''.join([f'{loss["name"]:>12}' for loss in loss_funcs])
    print(f' step{"best loss":>12}{"curr loss":>12}{loss_header} curr seq')

    Ls = inputs['Ls']
    msa = inputs['msa'][None].to(device)  # (B, N, L)
    # positions after the first Ls[0] are kept fixed as one-hot tokens
    msa_one_hot_sm = nn.functional.one_hot(msa[:,:,Ls[0]:], num_classes=NAATOKENS).float()

    input_logits = args.init_sd*torch.randn([B,N,Ls[0],20]).to(device).float()
    input_logits = input_logits.requires_grad_(True)

    optimizer = NSGD([input_logits], lr=args.learning_rate*np.sqrt(Ls[0]), dim=[-1,-2])

    best_loss = torch.full((B,),1e4).to(device)
    # stay None until some step improves on the sentinel above
    best_loss_s = None
    best_out = None
    best_step = None

    for i_step in range(args.grad_steps):
        optimizer.zero_grad()

        # discretize protein tokens on protein residues
        msa_one_hot_prot = logits_to_probs(input_logits, output_type=args.seq_prob_type)

        # pad with non-protein tokens on protein residues
        msa_one_hot = torch.concat([msa_one_hot_prot, torch.zeros((B,N,Ls[0],NAATOKENS-20)).to(device)], dim=-1)

        # pad with ligand residues
        msa_one_hot = torch.concat([msa_one_hot, msa_one_hot_sm], dim=-2)

        # predict structure
        # NOTE(review): xyz_prev is updated after every cycle but the model is
        # always given inputs['xyz_prev']; confirm whether the recycled
        # coordinates were meant to be fed back. Behavior left unchanged.
        xyz_prev = inputs['xyz_prev'].clone()
        msa_prev = None
        pair_prev = None
        alpha_prev = torch.zeros((1,sum(Ls),NTOTALDOFS,2), device=device)
        state_prev = None
        mask_recycle = inputs['mask_recycle'].clone()

        for i_cycle in range(cycles):
            out = model_wrapper(model, dict(
                msa_one_hot=msa_one_hot,
                seq_unmasked=msa_one_hot.argmax(-1)[:,0].detach(),  # (B,L)
                xyz=inputs['xyz_prev'],
                sctors=alpha_prev,
                idx=inputs['idx_pdb'],
                bond_feats=inputs['bond_feats'],
                chirals=inputs['chirals'],
                atom_frames=inputs['atom_frames'],
                t1d=inputs['t1d'],
                t2d=inputs['t2d'],
                xyz_t=inputs['xyz_t'][...,1,:],
                alpha_t=inputs['alpha_t'],
                mask_t=inputs['mask_t_2d'],
                same_chain=inputs['same_chain'],
                msa_prev=msa_prev,
                pair_prev=pair_prev,
                state_prev=state_prev,
                mask_recycle=mask_recycle,
                use_checkpoint=True
            ))

            xyz_prev = out['pred_allatom'][-1].unsqueeze(0)
            msa_prev = out['msa_prev']
            pair_prev = out['pair_prev']
            alpha_prev = out['alpha_s'][-1]
            state_prev = out['state_prev']
            mask_recycle = None

        # calculate loss
        curr_msa = msa_one_hot.argmax(-1).detach().clone()  #(B=1,N=1,L)
        out['msa'] = curr_msa
        loss_s = [loss['func'](out) for loss in loss_funcs]
        curr_loss = torch.sum(torch.stack(loss_s), dim=0)

        # best design so far
        if curr_loss < best_loss:
            best_loss = curr_loss
            best_loss_s = loss_s
            best_out = out
            msa = curr_msa
            best_step = i_step

        # update sequence
        curr_loss.backward()
        optimizer.step()

        # print step info
        seq_str = ''.join([aa_N_1[int(a)] for a in msa_one_hot_prot.argmax(-1)[0,0]])
        loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in loss_s])
        print(f'{i_step:>6}{float(best_loss):>12.3f}{float(curr_loss):>12.3f}{loss_str} {seq_str}')

    if best_out is None:
        # previously surfaced as an UnboundLocalError on best_loss_s / best_out
        raise RuntimeError(
            'run_gradient_descent: no step produced a loss below the initial '
            'sentinel (1e4); nothing to return')

    seq_str = ''.join([aa_N_1[int(a)] for a in msa[0,0,:Ls[0]]])
    loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in best_loss_s])
    print(f' final{float(best_loss):>12.2f}{" "*(12)}{loss_str} {seq_str}')

    best_out.update(dict(loss = best_loss, msa = msa, best_step=best_step))
    torch.cuda.reset_peak_memory_stats()

    return best_out
|
||||
|
||||
def run_mcmc(model, args, inputs, cycles, loss_funcs):
    '''
    Optimize the protein part of the input sequence by single-site mutation
    with Metropolis acceptance.

    Input:
        - model: network called through model_wrapper (put into eval mode here)
        - args: namespace providing device, mcmc_steps, T0 and mcmc_halflife
        - inputs: dict of model inputs; 'Ls' gives the segment lengths, and
                mutations are drawn within the first Ls[0] positions
        - cycles: number of recycling passes per step
        - loss_funcs: list of dicts with 'name' and 'func' (func maps the
                output dict to a loss tensor)
    Output:
        - output dict of the last accepted step, extended with 'loss', 'msa'
          and 'loss_terms'

    Raises:
        - RuntimeError if no proposal was ever accepted (including
          mcmc_steps == 0), so there is no design to return.

    Side effects: disables autograd globally, prints a progress table, and
    resets CUDA peak memory statistics.
    '''
    B, N = 1,1
    model.eval()
    torch.set_grad_enabled(False)

    print()
    print('Starting MCMC...')
    loss_header = ''.join([f'{loss["name"]:>12}' for loss in loss_funcs])
    print(f' step{"best loss":>12}{"curr loss":>12}{"accept?":>8}{loss_header} curr seq')

    Ls = inputs['Ls']
    device = args.device
    msa = inputs['msa'][None].to(device)  # (B, N, L)

    # NOTE: despite the name, best_loss tracks the last *accepted* state,
    # which is what the Metropolis criterion compares against.
    best_loss = torch.full((B,),1e4).to(device)
    # stay None until a proposal is accepted
    best_loss_s = None
    best_out = None

    for i_step in range(args.mcmc_steps):

        # make mutation
        i_pos = np.random.randint(Ls[0])
        aa_new = np.random.randint(20)
        msa_new = msa.clone()
        msa_new[0,0,i_pos] = aa_new  # assumes B=1

        # predict structure
        # NOTE(review): xyz_prev is updated after every cycle but the model is
        # always given inputs['xyz_prev']; confirm whether the recycled
        # coordinates were meant to be fed back. Behavior left unchanged.
        xyz_prev = inputs['xyz_prev'].clone()
        msa_prev = None
        pair_prev = None
        alpha_prev = torch.zeros((1,sum(Ls),NTOTALDOFS,2), device=msa_new.device)
        state_prev = None
        mask_recycle = inputs['mask_recycle'].clone()

        for i_cycle in range(cycles):
            out = model_wrapper(model, dict(
                msa_one_hot=nn.functional.one_hot(msa_new, num_classes=NAATOKENS).float(),
                seq_unmasked=msa_new[:,0],
                xyz=inputs['xyz_prev'],
                sctors=alpha_prev,
                idx=inputs['idx_pdb'],
                bond_feats=inputs['bond_feats'],
                chirals=inputs['chirals'],
                atom_frames=inputs['atom_frames'],
                t1d=inputs['t1d'],
                t2d=inputs['t2d'],
                xyz_t=inputs['xyz_t'][...,1,:],
                alpha_t=inputs['alpha_t'],
                mask_t=inputs['mask_t_2d'],
                same_chain=inputs['same_chain'],
                msa_prev=msa_prev,
                pair_prev=pair_prev,
                state_prev=state_prev,
                mask_recycle=mask_recycle
            ))

            xyz_prev = out['pred_allatom'][-1].unsqueeze(0)
            msa_prev = out['msa_prev']
            pair_prev = out['pair_prev']
            alpha_prev = out['alpha_s'][-1]
            state_prev = out['state_prev']
            mask_recycle = None

        # calculate loss
        curr_msa = msa_new.detach().clone()
        out['msa'] = curr_msa
        loss_s = [loss['func'](out) for loss in loss_funcs]
        curr_loss = torch.sum(torch.stack(loss_s), dim=0)

        # batch-wise Metropolis update
        # temperature halves every mcmc_halflife steps
        T = args.T0*(0.5)**(i_step/args.mcmc_halflife)
        p_accept = torch.clamp(torch.exp(-(curr_loss-best_loss)/T), min=0.0, max=1.0)
        accept = torch.rand(1).to(device) < p_accept
        if accept:
            best_loss = curr_loss
            best_loss_s = loss_s
            msa = curr_msa
            best_out = out

        seq_str = ''.join([aa_N_1[int(a)] for a in msa_new[0,~util.is_atom(msa_new)[0]]])
        loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in loss_s])
        print(f'{i_step:>6}{float(best_loss):>12.2f}{float(curr_loss):>12.2f}{int(accept):>8}{loss_str} {seq_str}')

    if best_out is None:
        # previously surfaced as an UnboundLocalError on best_loss_s / best_out
        raise RuntimeError(
            'run_mcmc: no proposal was accepted; nothing to return')

    seq_str = ''.join([aa_N_1[int(a)] for a in msa[0,0,:Ls[0]]])
    loss_str = ''.join([f'{float(loss_val):>12.3f}' for loss_val in best_loss_s])
    print(f' final{float(best_loss):>12.2f}{" "*(12+8)}{loss_str} {seq_str}')
    best_out.update(dict(loss = best_loss, msa = msa, loss_terms = best_loss_s))
    torch.cuda.reset_peak_memory_stats()

    return best_out
|
||||
|
||||
|
||||
def logits_to_probs(logits, output_type='hard', temp=1, add_gumbel_noise=False, eps=1e-8):
    '''
    Convert logits to per-position category probabilities.

    Input:
        - logits: 4-D tensor (B, N, L, A); the softmax runs over the last dim
        - output_type: 'soft' returns the softmax; 'hard' returns a one-hot
              argmax in the forward pass whose gradient is that of the
              softmax (straight-through)
        - temp: softmax temperature (logits are divided by it)
        - add_gumbel_noise: if True, add Gumbel noise to the logits first
        - eps: small constant guarding the logs of the noise
    Raises:
        - NotImplementedError for any other output_type
    '''
    target_device = logits.device
    B, N, L, A = logits.shape  # also enforces a 4-D input

    if add_gumbel_noise:
        uniform = torch.rand(logits.shape)
        gumbel = -torch.log(-torch.log(uniform + eps) + eps)
        logits = logits + gumbel.to(target_device)

    probs = torch.nn.functional.softmax(logits / temp, -1)

    if output_type == 'soft':
        return probs

    if output_type != 'hard':
        raise NotImplementedError('Output type must be "soft" or "hard"')

    one_hot = torch.nn.functional.one_hot(probs.argmax(-1), probs.shape[-1])
    # value equals one_hot; gradient flows through probs only
    return (one_hot - probs).detach() + probs
|
||||
|
||||
|
||||
class NSGD(torch.optim.Optimizer):
    '''
    Normalized gradient descent: each parameter moves by lr times its
    gradient divided by the gradient norm taken over `dim`.

    Args:
        - params: parameters (or parameter groups) to optimize
        - lr: step size
        - dim: dimension(s) over which the gradient norm is computed
    '''

    def __init__(self, params, lr, dim):
        super(NSGD, self).__init__(params, dict(lr=lr))
        self.dim = dim

    @torch.no_grad()
    def step(self, closure=None):
        '''Apply one update; returns closure() if a closure is given, else None.'''
        loss = None
        if closure is not None:
            # the closure may need gradients even though step() runs under no_grad
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size = group['lr']
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                grad_norm = torch.norm(grad, dim=self.dim, keepdim=True) + 1e-8
                param.add_(grad / grad_norm, alpha=-step_size)

        return loss
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
num2aa=[
|
||||
'ALA','ARG','ASN','ASP','CYS',
|
||||
'GLN','GLU','GLY','HIS','ILE',
|
||||
'LEU','LYS','MET','PHE','PRO',
|
||||
'SER','THR','TRP','TYR','VAL',
|
||||
'UNK','MAS',
|
||||
]
|
||||
|
||||
aa2num= {x:i for i,x in enumerate(num2aa)}
|
||||
|
||||
# full sc atom representation (Nx14)
|
||||
aa2long=[
|
||||
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
|
||||
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
|
||||
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
||||
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
|
||||
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
|
||||
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
|
||||
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
|
||||
(" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
|
||||
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
|
||||
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
|
||||
]
|
||||
|
||||
def parse_pdb(filename, **kwargs):
    '''extract xyz coords for all heavy atoms

    Reads the file at `filename` and passes its lines, together with any
    keyword arguments, to parse_pdb_lines; returns that function's result.
    '''
    # context manager closes the handle, which the bare open().readlines() leaked
    with open(filename, 'r') as f:
        lines = f.readlines()
    return parse_pdb_lines(lines, **kwargs)
|
||||
|
||||
def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True):
    '''Parse PDB-format records into per-residue arrays.

    A residue is any ATOM record carrying a CA atom. Atoms are placed into the
    14 per-residue slots according to the aa2long table of the residue type.

    Args:
        lines: re-iterable sequence of PDB record strings (it is scanned more
            than once).
        parse_hetatom: when True, also collect HETATM records.
        ignore_het_h: when collecting HETATM records, skip those whose
            character at index 77 is 'H'.

    Returns:
        dict with keys
          'xyz'     float32 array [L, 14, 3], missing atoms set to 0
          'mask'    bool array [L, 14], True where an atom was read
          'idx'     array [L] of residue numbers
          'seq'     array [L] of residue-type indices (20 for unknown names)
          'pdb_idx' list [L] of (chain letter, residue number)
        plus 'xyz_het' and 'info_het' when parse_hetatom is True.
    '''
    # indices of residues observed in the structure (one per CA atom)
    ca_lines = [l for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
    seq = [aa2num[l[17:20]] if l[17:20] in aa2num else 20 for l in ca_lines]
    pdb_idx = [( l[21:22].strip(), int(l[22:26].strip()) ) for l in ca_lines] # chain letter, res num

    # (chain, resNo) -> row of its FIRST occurrence; replaces repeated
    # list.index() scans and doubles as the de-duplication table below
    row_of = {}
    for i, key in enumerate(pdb_idx):
        row_of.setdefault(key, i)

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(pdb_idx), 14, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        # chain is stripped exactly like the pdb_idx entries, otherwise files
        # with a blank chain column can never be matched
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16].strip(), l[17:20]
        idx = row_of.get((chain, resNo))
        if idx is None:
            # residue has no CA record, so it has no row in the output
            continue
        # unknown residue names use the same fallback index as `seq`
        aa_idx = aa2num[aa] if aa in aa2num else 20
        for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
            if tgtatm is not None and tgtatm.strip() == atom: # ignore whitespace
                xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[...,0]))
    xyz[np.isnan(xyz[...,0])] = 0.0

    # remove duplicated (chain, resi), keeping the first occurrence
    # (dict preserves insertion order)
    i_unique = list(row_of.values())
    pdb_idx = list(row_of.keys())
    xyz = xyz[i_unique]
    mask = mask[i_unique]
    seq = np.array(seq)[i_unique]

    out = {'xyz':xyz, # cartesian coordinates, [Lx14]
           'mask':mask, # mask showing which atoms are present in the PDB file, [Lx14]
           'idx':np.array([i[1] for i in pdb_idx]), # residue numbers in the PDB file, [L]
           'seq':np.array(seq), # amino acid sequence, [L]
           'pdb_idx': pdb_idx,  # list of (chain letter, residue number) in the pdb file, [L]
          }

    # heteroatoms (ligands, etc)
    if parse_hetatom:
        xyz_het, info_het = [], []
        for l in lines:
            if l[:6] != 'HETATM':
                continue
            # slice instead of index: short records yield '' rather than
            # raising IndexError
            atom_type = l[77:78]
            if ignore_het_h and atom_type == 'H':
                continue
            info_het.append(dict(
                idx=int(l[7:11]),
                atom_id=l[12:16],
                atom_type=atom_type,
                name=l[16:20]
            ))
            xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])

        out['xyz_het'] = np.array(xyz_het)
        out['info_het'] = info_het

    return out
|
||||
|
||||
class ContigMap():
    '''
    Mapping between residues of a reference PDB and positions of the output.

    Supports multichain or multiple crops from a single receptor chain.
    Also supports indexing jump (+200) or not, based on contig input.
    Default chain outputs are inpainted chains as A (and B, C etc if multiple chains), and all fragments of receptor chain on the next one (generally B)
    Output chains can be specified. Sequence must be the same number of elements as in contig string

    Positions that have no reference residue are stored as the placeholder
    ('_','_') in the reference-side lists.
    '''
    def __init__(self, pdb_idx, contigs=None, inpaint_seq=None, inpaint_str=None, length=None, ref_idx=None, hal_idx=None, idx_rf=None, inpaint_seq_tensor=None, inpaint_str_tensor=None, topo=False):
        '''
        Args:
            pdb_idx: list of (chain, residue number) present in the reference PDB.
            contigs: list whose first element is the contig string; blocks are
                separated by whitespace, elements inside a block by commas.
            inpaint_seq / inpaint_str: lists of comma-separated residue specs
                ("A5" or "A5-10") to mask in the sequence / structure masks.
            length: "N" or "lo-hi"; accepted total sampled length (inclusive).
            ref_idx, hal_idx, idx_rf: explicit mapping; all three or none.
            inpaint_seq_tensor / inpaint_str_tensor: ready-made masks that
                override the string specs.
            topo: when True every contig block is handled as a receptor block.

        Exits the interpreter (sys.exit) when neither contigs nor an explicit
        mapping is given, or when the explicit mapping is incomplete.
        '''
        #sanity checks
        if contigs is None and ref_idx is None:
            sys.exit("Must either specify a contig string or precise mapping")
        if idx_rf is not None or hal_idx is not None or ref_idx is not None:
            if idx_rf is None or hal_idx is None or ref_idx is None:
                sys.exit("If you're specifying specific contig mappings, the reference and output positions must be specified, AND the indexing for RoseTTAFold (idx_rf)")

        self.chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        # stored as a half-open interval [lo, hi+1)
        if length is not None:
            if '-' not in length:
                self.length = [int(length),int(length)+1]
            else:
                self.length = [int(length.split("-")[0]),int(length.split("-")[1])+1]
        else:
            self.length = None
        self.ref_idx = ref_idx
        self.hal_idx=hal_idx
        self.idx_rf=idx_rf
        # flatten ["A1-3,B5", "C2"] into ["A1-3", "B5", "C2"]
        self.inpaint_seq = ','.join(inpaint_seq).split(",") if inpaint_seq is not None else None
        self.inpaint_str = ','.join(inpaint_str).split(",") if inpaint_str is not None else None
        self.inpaint_seq_tensor=inpaint_seq_tensor
        self.inpaint_str_tensor=inpaint_str_tensor
        self.pdb_idx = pdb_idx
        self.topo=topo
        if ref_idx is None:
            #using default contig generation, which outputs in rosetta-like format
            self.contigs=contigs
            self.sampled_mask,self.contig_length,self.n_inpaint_chains = self.get_sampled_mask()
            self.receptor_chain = self.chain_order[self.n_inpaint_chains]
            self.receptor, self.receptor_hal, self.receptor_rf, self.inpaint, self.inpaint_hal, self.inpaint_rf= self.expand_sampled_mask()
            self.ref = self.inpaint + self.receptor
            self.hal = self.inpaint_hal + self.receptor_hal
            self.rf = self.inpaint_rf + self.receptor_rf
        else:
            #specifying precise mappings
            self.ref=ref_idx
            self.hal=hal_idx
            self.rf = idx_rf  # was the undefined name `rf_idx` (NameError)
            # NOTE(review): an explicit mapping carries no inpaint/receptor
            # split, yet get_idx0 (and get_mappings) read these attributes.
            # The whole mapping is exposed as the inpainted part with an empty
            # receptor so this code path no longer fails -- confirm this is
            # the intended interpretation.
            self.inpaint, self.inpaint_hal, self.inpaint_rf = self.ref, self.hal, self.rf
            self.receptor, self.receptor_hal, self.receptor_rf = [], [], []
            self.sampled_mask = None
            self.contig_length = len(self.ref)
        self.mask_1d = [False if i == ('_','_') else True for i in self.ref]

        #take care of sequence and structure masking
        if self.inpaint_seq_tensor is None:
            if self.inpaint_seq is not None:
                self.inpaint_seq = self.get_inpaint_seq_str(self.inpaint_seq)
            else:
                self.inpaint_seq = np.array([True if i != ('_','_') else False for i in self.ref])
        else:
            self.inpaint_seq = self.inpaint_seq_tensor

        if self.inpaint_str_tensor is None:
            if self.inpaint_str is not None:
                self.inpaint_str = self.get_inpaint_seq_str(self.inpaint_str)
            else:
                self.inpaint_str = np.array([True if i != ('_','_') else False for i in self.ref])
        else:
            self.inpaint_str = self.inpaint_str_tensor
        #get 0-indexed input/output (for trb file)
        self.ref_idx0,self.hal_idx0, self.ref_idx0_inpaint, self.hal_idx0_inpaint, self.ref_idx0_receptor, self.hal_idx0_receptor=self.get_idx0()

    def get_sampled_mask(self):
        '''
        Function to get a sampled mask from a contig.

        Free lengths given as "lo-hi" are drawn at random; sampling is
        repeated until the total length of the inpainted blocks falls inside
        self.length (when set). After 100000 attempts the interpreter exits.

        Returns:
            (sampled_mask, sampled_mask_length, inpaint_chains): the list of
            contig blocks with every free length fixed as "n-n", the summed
            length of the inpainted blocks, and the number of inpainted blocks.
        '''
        length_compatible=False
        count = 0
        while length_compatible is False:
            inpaint_chains=0
            contig_list = self.contigs[0].strip().split()
            sampled_mask = []
            sampled_mask_length = 0
            #allow receptor chain to be last in contig string
            if all([i[0].isalpha() for i in contig_list[-1].split(",")]):
                contig_list[-1] = f'{contig_list[-1]},0'
            for con in contig_list:
                if ((all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0')) or self.topo is True:
                    #receptor chain
                    sampled_mask.append(con)
                else:
                    inpaint_chains += 1
                    #chain to be inpainted. These are the only chains that count towards the length of the contig
                    subcons = con.split(",")
                    subcon_out = []
                    for subcon in subcons:
                        if subcon[0].isalpha():
                            # fixed segment copied from the reference
                            subcon_out.append(subcon)
                            if '-' in subcon:
                                sampled_mask_length += (int(subcon.split("-")[1])-int(subcon.split("-")[0][1:])+1)
                            else:
                                sampled_mask_length += 1

                        else:
                            if '-' in subcon:
                                low = int(subcon.split("-")[0])
                                high = int(subcon.split("-")[1])
                                # np.random.randint excludes `high`; the
                                # degenerate range "n-n" would make it raise,
                                # so it is taken as the fixed length n
                                length_inpaint = low if low == high else np.random.randint(low, high)
                                subcon_out.append(f'{length_inpaint}-{length_inpaint}')
                                sampled_mask_length += length_inpaint
                            elif subcon == '0':
                                subcon_out.append('0')
                            else:
                                length_inpaint=int(subcon)
                                subcon_out.append(f'{length_inpaint}-{length_inpaint}')
                                sampled_mask_length += int(subcon)
                    sampled_mask.append(','.join(subcon_out))
            #check length is compatible
            if self.length is not None:
                if sampled_mask_length >= self.length[0] and sampled_mask_length < self.length[1]:
                    length_compatible = True
            else:
                length_compatible = True
            count+=1
            if count == 100000: #contig string incompatible with this length
                sys.exit("Contig string incompatible with --length range")
        return sampled_mask, sampled_mask_length, inpaint_chains

    def expand_sampled_mask(self):
        '''
        Expand self.sampled_mask into per-residue lists.

        Returns:
            (receptor, receptor_hal, receptor_rf, inpaint, inpaint_hal,
            inpaint_rf): reference (chain, resnum) entries, output
            (chain, resnum) entries and integer indices with chain-break
            offsets, for the receptor and the inpainted part respectively.
        '''
        chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        receptor = []
        inpaint = []
        receptor_hal = []
        inpaint_hal = []
        receptor_idx = 1
        inpaint_idx = 1
        inpaint_chain_idx=-1
        receptor_chain_break=[]
        inpaint_chain_break = []
        for con in self.sampled_mask:
            if (all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0') or self.topo is True:
                #receptor chain
                subcons = con.split(",")[:-1]
                assert all([i[0] == subcons[0][0] for i in subcons]), "If specifying fragmented receptor in a single block of the contig string, they MUST derive from the same chain"
                assert all(int(subcons[i].split("-")[0][1:]) < int(subcons[i+1].split("-")[0][1:]) for i in range(len(subcons)-1)), "If specifying multiple fragments from the same chain, pdb indices must be in ascending order!"
                for idx, subcon in enumerate(subcons):
                    ref_to_add = [(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)]
                    receptor.extend(ref_to_add)
                    receptor_hal.extend([(self.receptor_chain,i) for i in np.arange(receptor_idx, receptor_idx+len(ref_to_add))])
                    receptor_idx += len(ref_to_add)
                    if idx != len(subcons)-1:
                        idx_jump = int(subcons[idx+1].split("-")[0][1:]) - int(subcon.split("-")[1]) -1
                        receptor_chain_break.append((receptor_idx-1,idx_jump)) #actual chain break in pdb chain
                    else:
                        receptor_chain_break.append((receptor_idx-1,200)) #200 aa chain break
            else:
                inpaint_chain_idx += 1
                for subcon in con.split(","):
                    if subcon[0].isalpha():
                        ref_to_add=[(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)]
                        inpaint.extend(ref_to_add)
                        inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+len(ref_to_add))])
                        inpaint_idx += len(ref_to_add)

                    else:
                        # free segment: placeholders on the reference side
                        inpaint.extend([('_','_')] * int(subcon.split("-")[0]))
                        inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+int(subcon.split("-")[0]))])
                        inpaint_idx += int(subcon.split("-")[0])
                inpaint_chain_break.append((inpaint_idx-1,200))

        if self.topo is True or inpaint_hal == []:
            receptor_hal = [(i[0], i[1]) for i in receptor_hal]
        else:
            receptor_hal = [(i[0], i[1] + inpaint_hal[-1][1]) for i in receptor_hal] #rosetta-like numbering
        #get rf indexes, with chain breaks
        inpaint_rf = np.arange(0,len(inpaint))
        receptor_rf = np.arange(len(inpaint)+200,len(inpaint)+len(receptor)+200)
        for ch_break in inpaint_chain_break[:-1]:
            receptor_rf[:] += 200
            inpaint_rf[ch_break[0]:] += ch_break[1]
        for ch_break in receptor_chain_break[:-1]:
            receptor_rf[ch_break[0]:] += ch_break[1]

        return receptor, receptor_hal, receptor_rf.tolist(), inpaint, inpaint_hal, inpaint_rf.tolist()

    def get_inpaint_seq_str(self, inpaint_s):
        '''
        function to generate inpaint_str or inpaint_seq masks specific to this contig

        Args:
            inpaint_s: list of residue specs, each "A5" or "A5-10".

        Returns:
            Array copy of self.mask_1d with every listed residue that occurs
            in self.ref set to False.
        '''
        s_mask = np.copy(self.mask_1d)
        inpaint_s_list = []
        for i in inpaint_s:
            if '-' in i:
                inpaint_s_list.extend([(i[0],p) for p in range(int(i.split("-")[0][1:]), int(i.split("-")[1])+1)])
            else:
                inpaint_s_list.append((i[0],int(i[1:])))
        for res in inpaint_s_list:
            if res in self.ref:
                s_mask[self.ref.index(res)] = False #mask this residue

        return np.array(s_mask)

    def get_idx0(self):
        '''
        Build 0-indexed lookups for every position that has a reference residue.

        For the full mapping, the inpainted part and the receptor part this
        pairs the position inside that list (hal_*) with the index of the
        residue inside self.pdb_idx (ref_*).

        Returns:
            (ref_idx0, hal_idx0, ref_idx0_inpaint, hal_idx0_inpaint,
            ref_idx0_receptor, hal_idx0_receptor)
        '''
        ref_idx0=[]
        hal_idx0=[]
        ref_idx0_inpaint=[]
        hal_idx0_inpaint=[]
        ref_idx0_receptor=[]
        hal_idx0_receptor=[]
        for idx, val in enumerate(self.ref):
            if val != ('_','_'):
                assert val in self.pdb_idx,f"{val} is not in pdb file!"
                hal_idx0.append(idx)
                ref_idx0.append(self.pdb_idx.index(val))
        for idx, val in enumerate(self.inpaint):
            if val != ('_','_'):
                hal_idx0_inpaint.append(idx)
                ref_idx0_inpaint.append(self.pdb_idx.index(val))
        for idx, val in enumerate(self.receptor):
            if val != ('_','_'):
                hal_idx0_receptor.append(idx)
                ref_idx0_receptor.append(self.pdb_idx.index(val))

        return ref_idx0, hal_idx0, ref_idx0_inpaint, hal_idx0_inpaint, ref_idx0_receptor, hal_idx0_receptor
|
||||
|
||||
def get_mappings(rm):
    """Gather the mapping bookkeeping of a contig-map object into one dict.

    Args:
        rm: object exposing the attributes read below (inpaint, inpaint_hal,
            ref, hal, receptor, receptor_hal, the *_idx0 lists, inpaint_str,
            inpaint_seq, sampled_mask, mask_1d).

    Returns:
        dict with the 'con_*' entries for the inpainted part; when the
        inpainted part differs from the full reference list, also the
        'complex_con_*' and 'receptor_con_*' entries; and finally the masks.
    """
    placeholder = ('_', '_')

    def with_reference(out_side, ref_side):
        # keep the output-side entries whose reference-side partner is a
        # real residue rather than the placeholder
        return [entry for pos, entry in enumerate(out_side) if ref_side[pos] != placeholder]

    mappings = {
        'con_ref_pdb_idx': [res for res in rm.inpaint if res != placeholder],
        'con_hal_pdb_idx': with_reference(rm.inpaint_hal, rm.inpaint),
        'con_ref_idx0': rm.ref_idx0_inpaint,
        'con_hal_idx0': rm.hal_idx0_inpaint,
    }

    # receptor/complex entries only exist when there is more than the
    # inpainted part
    if rm.inpaint != rm.ref:
        mappings.update({
            'complex_con_ref_pdb_idx': [res for res in rm.ref if res != placeholder],
            'complex_con_hal_pdb_idx': with_reference(rm.hal, rm.ref),
            'receptor_con_ref_pdb_idx': [res for res in rm.receptor if res != placeholder],
            'receptor_con_hal_pdb_idx': with_reference(rm.receptor_hal, rm.receptor),
            'complex_con_ref_idx0': rm.ref_idx0,
            'complex_con_hal_idx0': rm.hal_idx0,
            'receptor_con_ref_idx0': rm.ref_idx0_receptor,
            'receptor_con_hal_idx0': rm.hal_idx0_receptor,
        })

    mappings.update({
        'inpaint_str': rm.inpaint_str,
        'inpaint_seq': rm.inpaint_seq,
        'sampled_mask': rm.sampled_mask,
        'mask_1d': rm.mask_1d,
    })
    return mappings
|
||||
@@ -56,8 +56,13 @@ def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra
|
||||
clamp = torch.zeros_like(difference)
|
||||
clamp[:,same_chain==1] = d_clamp_intra
|
||||
clamp[:,same_chain==0] = d_clamp_inter
|
||||
difference = torch.clamp(difference, max=clamp)
|
||||
loss = difference / A # (I, B, L, L)
|
||||
|
||||
mixing_factor = 0.9
|
||||
unclamped_difference = difference.clone()
|
||||
clamped_difference = torch.clamp(difference, max=clamp)
|
||||
|
||||
clamped_loss = clamped_difference / A
|
||||
unclamped_loss = unclamped_difference/A # (I, B, L, L)
|
||||
|
||||
# Get a mask information (ignore missing residue + inter-chain residues)
|
||||
# for positive cases, mask = mask_2d
|
||||
@@ -66,8 +71,11 @@ def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra
|
||||
mask = mask_2d * same_chain
|
||||
else:
|
||||
mask = mask_2d
|
||||
|
||||
# calculate masked loss (ignore missing regions when calculate loss)
|
||||
loss = (mask[None]*loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
|
||||
clamped_loss = (mask[None]*clamped_loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
|
||||
unclamped_loss = (mask[None]*unclamped_loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
|
||||
loss = mixing_factor *clamped_loss + (1-mixing_factor)*unclamped_loss
|
||||
|
||||
# weighting loss
|
||||
w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
|
||||
@@ -608,7 +616,7 @@ def compute_pde_loss(X, Y, logit_pde, atom_mask, pde_bin_step=0.3, frame_atom_ma
|
||||
|
||||
# from Ivan: FAPE generalized over atom sets & frames
|
||||
def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=None, frame_atom_mask_2d=None,
|
||||
logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, dclamp_2d=None, gamma=0.99, eps=1e-4):
|
||||
logit_pae=None, logit_pde=None, Z=10.0, dclamp=10.0, dclamp_2d=None, gamma=0.99, mixing_factor=0.9, eps=1e-4):
|
||||
|
||||
# X (predicted) N x L x natoms x 3
|
||||
# Y (native) 1 x L x natoms x 3
|
||||
@@ -617,7 +625,12 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
|
||||
# frame_mask 1 x L x nframes
|
||||
# frame_atom_mask 1 x L x natoms masks the frames over which fape is calculated
|
||||
# frame_atom_mask_2d 1 x L x nframes x L x natoms 2d mask, 2nd dimension frames, 3rd/4th dimension atoms so fape can be taken over some atoms for some frames (only works for BB fape)
|
||||
|
||||
# logit_pae
|
||||
# logit_pde
|
||||
# dclamp int
|
||||
# dclamp_2d
|
||||
# gamma
|
||||
# mixing_factor float weight of the clamped loss; the unclamped loss gets (1 - mixing_factor), so gradients still flow where the loss is clamped
|
||||
if frame_atom_mask is None:
|
||||
frame_atom_mask = atom_mask
|
||||
|
||||
@@ -652,6 +665,9 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
|
||||
# multiply diff by frame_atom_mask_2d if frame_atom_mask_2d not None
|
||||
if frame_atom_mask_2d is not None:
|
||||
diff = diff*frame_atom_mask_2d[:,frame_mask[0]][:, :, atom_mask[0]]
|
||||
N_values = torch.sum(frame_atom_mask_2d[:,frame_mask[0]][:, :, atom_mask[0]])
|
||||
else:
|
||||
N_values = diff.shape[1]*diff.shape[2] # frame dimension * atom dimension
|
||||
|
||||
assert dclamp is not None or dclamp_2d is not None, "need to provide either dclamp or dclamp_2d to compute_general_FAPE"
|
||||
assert not (dclamp is not None and dclamp_2d is not None), "you provided both dclamp and dclamp_2d, please only provide one"
|
||||
@@ -659,8 +675,10 @@ def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, frame_atom_mask=No
|
||||
if dclamp_2d is not None:
|
||||
dclamp = dclamp_2d[:,frame_mask[0]][:, :, atom_mask[0]]
|
||||
|
||||
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2))
|
||||
clamped_loss = (1.0/Z) * torch.sum(torch.clamp(diff, max=dclamp), dim=(1,2))/(N_values)
|
||||
unclamped_loss = (1.0/Z) *torch.sum(diff, dim=(1,2))/(N_values)
|
||||
|
||||
loss = mixing_factor * clamped_loss + (1-mixing_factor) * unclamped_loss
|
||||
pae_loss = compute_pae_loss(X, X_y, uX, Y, Y_y, uY, logit_pae, frame_mask, atom_mask, frame_atom_mask_2d=frame_atom_mask_2d) \
|
||||
if logit_pae is not None \
|
||||
else torch.tensor(0).to(frames.device)
|
||||
@@ -683,26 +701,26 @@ def mask_unresolved_frames(frames, frame_mask, atom_mask):
|
||||
"""
|
||||
B, L, natoms = atom_mask.shape
|
||||
|
||||
frame_mask_update = frame_mask.clone()
|
||||
# reindex frames for flat X
|
||||
frames_reindex = torch.zeros(frames.shape[:-1], device=frames.device)
|
||||
for i in range(L):
|
||||
frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1]
|
||||
frames_reindex = frames_reindex.long()
|
||||
frames_reindex = (
|
||||
torch.arange(L, device=frames.device)[None,:,None,None] + frames[..., 0]
|
||||
)*natoms + frames[..., 1]
|
||||
|
||||
masked_atom_frames = torch.any(frames_reindex>L*natoms, dim=-1) # find frames with atoms that aren't resolved
|
||||
masked_atom_frames *= torch.any(frames_reindex<0, dim=-1)
|
||||
frame_mask_update *= ~masked_atom_frames
|
||||
# There are currently indices for frames that aren't in the coordinates bc they arent resolved, reset these indices to 0 to avoid
|
||||
# indexing errors
|
||||
frames_reindex[masked_atom_frames, :] = 0
|
||||
|
||||
frame_mask_update = frame_mask.clone()
|
||||
frame_mask_update *= ~masked_atom_frames
|
||||
frame_mask_update *= torch.all(
|
||||
torch.gather(atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3),
|
||||
axis=-1)
|
||||
|
||||
return frames_reindex, frame_mask_update
|
||||
|
||||
|
||||
def calc_crd_rmsd(pred, true, atom_mask, rmsd_mask=None, alignment_radius=None):
|
||||
'''
|
||||
Calculate coordinate RMSD
|
||||
@@ -1096,6 +1114,121 @@ def calc_clash(xs, mask):
|
||||
clash = torch.sum( torch.clamp(DISTCUT-dij[allmask],0.0) ) / torch.sum(mask)
|
||||
return clash
|
||||
|
||||
def calc_l1_clash_loss(xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds, lj_hb_dis=3.0, \
    lj_OHdon_dis=2.6, lj_hbond_hdis=1.75, tolerance=1.5, agg="sum", eps=1e-8):
    """
    The LJ potential doesnt work for large systems because it means the energy over all atoms, so 1-2 clashes wash
    out of the loss. To remedy this, we want to only apply a loss on the violating elements. if doing this, the
    contribution of the london dispersion forces no longer matter (you will never evaluate these because they
    are not violations). Thus, switching the functional form of this to be a flat bottom L1 loss similar to the
    other bond losses in RF.

    Args:
        xs: coordinates; the code reads it as (N, L, natoms, 3), where atom
            index 1 is used as C-alpha for the 24 A residue-pair cutoff.
        seq: per-residue type indices (L), used to index the tables below.
        aamask: per-type atom-existence mask (ntypes, natoms).
        bond_feats: residue-pair bond features; only entry [0] is used and
            the value 6 excludes the backbone atoms (index < 5) of that pair.
        dist_matrix: residue-pair bond distances; only entry [0] is used and
            pairs closer than 4 are excluded for the atom-1/atom-1 pair.
        ljparams: per-type/atom parameters; [..., 0] is summed as pair radius.
        ljcorr: per-type/atom boolean flags; columns 0/1/2 select the hbond
            distance overrides, column 3 marks potential disulfides.
        num_bonds: per-type bond-count table between atoms.
        lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis: radii substituted for the
            flagged pairs.
        tolerance: overlap allowed before a pair counts as a violation.
        agg: "sum" or "mean_over_viol".
        eps: stabiliser inside the distance square root.

    Returns:
        (loss, number of violations). With agg="mean_over_viol" the loss is
        divided by the number of violations (0 when there are none).

    Raises:
        NotImplementedError: for any other value of agg.
    """
    def compute_loss(deltas, ljrs, tolerance, eps=1e-8):
        """
        deltas: differences between atom positions between valid pairs (natom_pairs, 3)
        ljrs: lj radii of valid pairs
        tolerance: how large the deviation can be (this could become a more complicated hyperparameter)
        """
        dist = torch.sqrt( torch.sum ( torch.square( deltas ), dim=-1 ) + eps ) # compute distances
        deviations = torch.clamp(ljrs - dist - tolerance, min=0)
        return torch.sum(deviations), torch.sum(deviations > 0)

    N, L = xs.shape[:2]
    rs = torch.triu_indices(L,L,0, device=xs.device)
    ri,rj = rs[0],rs[1]

    # batch during inference for huge systems
    BATCHSIZE = 65536//N

    # loop-invariant, so done once instead of once per batch
    dist_matrix = torch.nan_to_num(dist_matrix, posinf=4.0) #NOTE: need to run nan_to_num to remove infinities

    running_clash_val = 0
    running_num_violations = 0
    #NOTE: a lot of this is similar to calc_lj -- probably should move into its own fx
    for i_batch in range((len(ri)-1)//BATCHSIZE + 1):
        idx = torch.arange(
            i_batch*BATCHSIZE,
            min( (i_batch+1)*BATCHSIZE, len(ri)),
            device=xs.device
        )
        rii,rjj = ri[idx],rj[idx] # residue pairs we consider

        # ridx indexes into THIS batch (rii/rjj), ai/aj are atom indices
        ridx,ai,aj = (
            aamask[seq[rii]][:,:,None]*aamask[seq[rjj]][:,None,:]
        ).nonzero(as_tuple=True)

        deltas = xs[:,rii,:,None,:]-xs[:,rjj,None,:,:] # N,BATCHSIZE,Natm,Natm,3
        seqi,seqj = seq[rii[ridx]], seq[rjj[ridx]]

        mask = torch.ones_like(ridx, dtype=torch.bool) # are atoms defined?

        # mask out atom pairs from too-distant residues (C-alpha dist > 24A)
        ca_dist = torch.linalg.norm(deltas[:,:,1,1],dim=-1)
        mask *= (ca_dist[:,ridx]<24).any(dim=0)  # will work for batch>1 but very inefficient

        intrares = (rii[ridx]==rjj[ridx])
        mask[intrares*(ai<aj)] = False  # upper tri (atoms)

        ## count-pair
        # a) intra-protein
        mask[intrares] *= num_bonds[seqi[intrares],ai[intrares],aj[intrares]]>=4
        # ridx is batch-local, so it must index rii/rjj; indexing the global
        # ri/rj picked the wrong residue pairs in every batch after the first
        pepbondres = rii[ridx]+1==rjj[ridx]
        mask[pepbondres] *= (
            num_bonds[seqi[pepbondres],ai[pepbondres],2]
            + num_bonds[seqj[pepbondres],0,aj[pepbondres]]
            + 1) >=4

        # b) intra-ligand
        atommask = (ai==1)*(aj==1)
        resmask = (dist_matrix[0,rii,rjj] >= 4) # * will only work for batch=1
        mask[atommask] *= resmask[ ridx[atommask] ]

        # c) protein/ligand
        ##fd NOTE1: changed 6->5 in masking (atom 5 is CG which should always be 4+ bonds away from connected atom)
        ##fd NOTE2: this does NOT work correctly for nucleic acids
        ##fd     for NAs atoms 0-4 are masked, but also 5,7,8 and 9 should be masked!
        bbatommask = (ai<5)*(aj<5)
        resmask = (bond_feats[0,rii,rjj] != 6) # * will only work for batch=1
        mask[bbatommask] *= resmask[ ridx[bbatommask] ]
        # d) potential disulfide
        # disulfide correction previously this was done by reducing well depth but
        # since this loss does not have a well depth it is done by masking the pairs of atoms
        potential_disulf = (ljcorr[seqi,ai,3]*ljcorr[seqj,aj,3] )
        mask *= ~potential_disulf

        # apply mask.  only interactions to be scored remain
        ai,aj,seqi,seqj,ridx = ai[mask],aj[mask],seqi[mask],seqj[mask],ridx[mask]
        deltas = deltas[:,ridx,ai,aj]

        # hbond correction
        use_hb_dis = (
            ljcorr[seqi,ai,0]*ljcorr[seqj,aj,1]
            + ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0] ).nonzero()
        use_ohdon_dis = ( # OH are both donors & acceptors
            ljcorr[seqi,ai,0]*ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0]
            +ljcorr[seqi,ai,0]*ljcorr[seqj,aj,0]*ljcorr[seqj,aj,1]
        ).nonzero()
        use_hb_hdis = (
            ljcorr[seqi,ai,2]*ljcorr[seqj,aj,1]
            +ljcorr[seqi,ai,1]*ljcorr[seqj,aj,2]
        ).nonzero()

        # later assignments win when a pair matches several categories
        ljrs = ljparams[seqi,ai,0] + ljparams[seqj,aj,0]
        ljrs[use_hb_dis] = lj_hb_dis
        ljrs[use_ohdon_dis] = lj_OHdon_dis
        ljrs[use_hb_hdis] = lj_hbond_hdis

        clash_val, num_violations = compute_loss(deltas, ljrs, tolerance, eps=eps)
        running_clash_val += clash_val
        running_num_violations += num_violations

    # aggregate values from all batches
    if agg=="sum":
        return running_clash_val, running_num_violations
    elif agg=="mean_over_viol":
        # no violations -> loss 0 instead of 0/0 = nan
        n_viol = torch.clamp(torch.as_tensor(running_num_violations), min=1)
        return running_clash_val/n_viol, running_num_violations
    else:
        raise NotImplementedError(f"{agg} not an accepted aggregation method for clash loss")
|
||||
|
||||
|
||||
#fd more efficient LJ loss
|
||||
class LJLoss(torch.autograd.Function):
|
||||
630
rf2aa/loss/loss_factory.py
Normal file
630
rf2aa/loss/loss_factory.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
|
||||
from rf2aa.chemical import NAATOKENS
|
||||
from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins
|
||||
from rf2aa.loss.loss import resolve_equiv_natives, resolve_equiv_natives_asmb, \
|
||||
resolve_symmetry_predictions, resolve_symmetry, mask_unresolved_frames, \
|
||||
compute_general_FAPE, torsionAngleLoss, calc_lddt, calc_allatom_lddt_loss, \
|
||||
calc_crd_rmsd, calc_BB_bond_geom, calc_l1_clash_loss, calc_atom_bond_loss
|
||||
from rf2aa.util import is_atom, is_protein, Ls_from_same_chain_2d, get_prot_sm_mask, \
|
||||
xyz_to_frame_xyz, get_frames, NTOTALDOFS, NTOTAL
|
||||
|
||||
cce_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
|
||||
|
||||
def get_loss_and_misc(
    trainer,
    output_i, true_crds, atom_mask, same_chain,
    seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
    loss_param
):
    """Align a model output with its ground truth and compute the loss.

    Steps, in order:
      1. unpack the 12-tuple `output_i` (the last three entries are ignored);
      2. if no all-atom prediction is present, build one with
         trainer.xyz_converter.compute_all_atom from the last predicted
         backbone and torsions;
      3. match prediction and truth: via resolve_symmetry_predictions when
         `symmRs` is given (all per-residue inputs are re-gathered with the
         resulting map), via resolve_equiv_natives_asmb for tasks containing
         'sm_compl' or 'metal_compl', otherwise via resolve_equiv_natives;
      4. build residue/pair masks and the binned c6d labels;
      5. delegate to calc_loss, forwarding `loss_param` as keyword arguments.

    Returns:
        (loss, loss_dict) exactly as returned by calc_loss.

    NOTE(review): tensor shapes are not established here; the gather calls
    assume batch size 1 with the residue axis at the gathered dimension --
    confirm against the caller.
    """
    logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
    if pred_allatom is None:
        _, pred_allatom = trainer.xyz_converter.compute_all_atom(msa[0][0][None],pred_crds[-1][None], alphas[-1][None])
        # NOTE(review): indentation was lost in the reviewed source; these two
        # lines are assumed to belong to this branch (they add the axis that
        # the [None] above supplies for the last layer only) -- confirm.
        pred_crds = pred_crds[:, None]
        alphas = alphas[:, None]
    if (symmRs is not None):
        #print ('a', pred_crds.shape, true_crds.shape, mask_crds.shape)
        ###
        # resolve symmetry
        ###
        true_crds = true_crds[:,0]
        atom_mask = atom_mask[:,0]
        mapT2P = resolve_symmetry_predictions(pred_crds, true_crds, atom_mask, Lasu) # (Nlayer, Ltrue)

        # update all derived data to only include subunits mapping to native
        # (pair tensors are gathered along both residue axes; everything but
        # pred_crds/alphas uses the map of the last layer, mapT2P[-1])
        logit_s_new = []
        for li in logit_s:
            li=torch.gather(li,2,mapT2P[-1][None,None,:,None].repeat(1,li.shape[1],1,li.shape[-1]))
            li=torch.gather(li,3,mapT2P[-1][None,None,None,:].repeat(1,li.shape[1],li.shape[2],1))
            logit_s_new.append(li)
        logit_s = tuple(logit_s_new)

        logit_aa_s = logit_aa_s.view(1,NAATOKENS,msa.shape[-2],msa.shape[-1])
        logit_aa_s = torch.gather(logit_aa_s,3,mapT2P[-1][None,None,None,:].repeat(1,NAATOKENS,logit_aa_s.shape[-2],1))
        logit_aa_s = logit_aa_s.view(1,NAATOKENS,-1)

        msa = torch.gather(msa,2,mapT2P[-1][None,None,:].repeat(1,msa.shape[-2],1))
        mask_msa = torch.gather(mask_msa,2,mapT2P[-1][None,None,:].repeat(1,mask_msa.shape[-2],1))

        logit_pae=torch.gather(logit_pae,2,mapT2P[-1][None,None,:,None].repeat(1,logit_pae.shape[1],1,logit_pae.shape[-1]))
        logit_pae=torch.gather(logit_pae,3,mapT2P[-1][None,None,None,:].repeat(1,logit_pae.shape[1],logit_pae.shape[2],1))

        logit_pde=torch.gather(logit_pde,2,mapT2P[-1][None,None,:,None].repeat(1,logit_pde.shape[1],1,logit_pde.shape[-1]))
        logit_pde=torch.gather(logit_pde,3,mapT2P[-1][None,None,None,:].repeat(1,logit_pde.shape[1],logit_pde.shape[2],1))

        # per-layer quantities use the per-layer map
        pred_crds = torch.gather(pred_crds,2,mapT2P[:,None,:,None,None].repeat(1,1,1,3,3))
        pred_allatom = torch.gather(pred_allatom,1,mapT2P[-1,None,:,None,None].repeat(1,1,NTOTAL,3))
        alphas = torch.gather(alphas,2,mapT2P[:,None,:,None,None].repeat(1,1,1,NTOTALDOFS,2))

        same_chain=torch.gather(same_chain,1,mapT2P[-1][None,:,None].repeat(1,1,same_chain.shape[-1]))
        same_chain=torch.gather(same_chain,2,mapT2P[-1][None,None,:].repeat(1,same_chain.shape[1],1))

        bond_feats=torch.gather(bond_feats,1,mapT2P[-1][None,:,None].repeat(1,1,bond_feats.shape[-1]))
        bond_feats=torch.gather(bond_feats,2,mapT2P[-1][None,None,:].repeat(1,bond_feats.shape[1],1))

        dist_matrix=torch.gather(dist_matrix,1,mapT2P[-1][None,:,None].repeat(1,1,dist_matrix.shape[-1]))
        dist_matrix=torch.gather(dist_matrix,2,mapT2P[-1][None,None,:].repeat(1,dist_matrix.shape[1],1))

        pred_lddts = torch.gather(pred_lddts,2,mapT2P[-1][None,None,:].repeat(1,pred_lddts.shape[-2],1))
        idx_pdb = torch.gather(idx_pdb,1,mapT2P[-1][None,:])
    elif 'sm_compl' in task[0] or 'metal_compl' in task[0]:
        # split the chain lengths into non-atom and atom nodes before
        # resolving equivalent natives of the assembly
        sm_mask = is_atom(seq[0,0])
        Ls_prot = Ls_from_same_chain_2d(same_chain[:,~sm_mask][:,:,~sm_mask])
        Ls_sm = Ls_from_same_chain_2d(same_chain[:,sm_mask][:,:,sm_mask])

        true_crds, atom_mask = resolve_equiv_natives_asmb(
            pred_allatom, true_crds, atom_mask, ch_label, Ls_prot, Ls_sm)
    else:
        true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask)

    res_mask = get_prot_sm_mask(atom_mask, msa[0,0])
    # a residue pair is valid only when both residues are valid
    mask_2d = res_mask[:,None,:] * res_mask[:,:,None]

    true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, 0], atom_frames)
    c6d = xyz_to_c6d(true_crds_frame)
    c6d = c6d_to_bins(c6d, same_chain, negative=negative)

    # contact accuracy not as useful to track anymore
    #prob = self.active_fn(logit_s[0]) # distogram
    #acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)
    loss, loss_dict = calc_loss(
        trainer, logit_s, c6d,
        logit_aa_s, msa, mask_msa, logit_pae, logit_pde, p_bind,
        pred_crds, alphas, pred_allatom, true_crds,
        atom_mask, res_mask, mask_2d, same_chain,
        pred_lddts, idx_pdb, bond_feats, dist_matrix,
        atom_frames=atom_frames,unclamp=unclamp, negative=negative,
        item=item, task=task, **loss_param
    )

    return loss, loss_dict
|
||||
|
||||
|
||||
def calc_loss(trainer, logit_s, label_s,
              logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
              pred, pred_tors, pred_allatom, true,
              mask_crds, mask_BB, mask_2d, same_chain,
              pred_lddt, idx, bond_feats, dist_matrix, atom_frames=None, unclamp=False,
              negative=False, interface=False,
              w_dist=1.0, w_aa=1.0, w_str=1.0, w_inter_fape=0.0, w_lig_fape=1.0, w_lddt=1.0,
              w_bond=1.0, w_clash=0.0, w_atom_bond=0.0, w_skip_bond=0.0, w_rigid=0.0, w_hb=0.0, w_bind=0.0,
              w_pae=0.0, w_pde=0.0, lj_lin=0.85, eps=1e-6, binder_loss_label_smoothing=0.0, item=None, task=None, out_dir='./'
              ):
    """Compute the weighted total training loss and a dictionary of its terms.

    The total is the weighted sum of: c6d (distogram + anglegram) cross
    entropy, masked-token cross entropy, optional binder BCE, layer-weighted
    backbone FAPE, PAE/PDE losses, torsion loss, lDDT loss, all-atom FAPE,
    backbone bond geometry, clash, and atom-bond / skip-bond / rigid losses.
    Several logging-only metrics (CA lDDT, RMSDs, per-subset all-atom lDDTs)
    are also written to the returned dictionary.

    Args:
        trainer: object providing ``fi_dev``, ``config``, ``xyz_converter``,
            ``l2a``, ``aamask``, ``ljlk_parameters``,
            ``lj_correction_parameters`` and ``num_bonds``.
        logit_s, label_s: c6d logits (indexed 0..3) and labels (last axis
            selects the matching component).
        logit_aa_s, label_aa_s, mask_aa_s: masked-token logits, labels and
            mask; ``label_aa_s[:, 0]`` is used as the sequence.
        logit_pae, logit_pde, p_bind: auxiliary-head outputs.
        pred, pred_tors, pred_allatom: predicted coordinates and torsions;
            the leading axis of ``pred`` is treated as the layer axis, and
            ``pred_allatom`` is indexed along its first axis the same way
            (presumably layers too -- TODO confirm shapes with caller).
        true, mask_crds: native coordinates and their atom mask; the code
            asserts a batch size of 1.
        mask_BB, mask_2d, same_chain: residue mask, pair mask and same-chain
            pair indicator.
        pred_lddt, idx, bond_feats, dist_matrix, atom_frames: forwarded to the
            lDDT, bond-geometry, clash and frame helpers.
        unclamp: use the 300.0 FAPE clamp instead of 30.0 and skip the
            intra-protein clamp of 10.
        negative: negative-example flag. May be a Python bool or a
            one-element tensor; it is normalised with ``bool()`` here.
        interface: forwarded to the lDDT helpers.
        w_*: weights of the individual loss terms.
        eps: stabiliser for the c6d mask normalisation.
        binder_loss_label_smoothing: label smoothing for the binder target.
        task: sequence whose first element names the task; ``'tf'`` tasks
            skip the FAPE terms.
        w_inter_fape, w_lig_fape, w_hb, lj_lin, item, out_dir: accepted for
            interface compatibility; not read by this function.

    Returns:
        ``(tot_loss, loss_dict)`` where ``loss_dict`` is an ``OrderedDict`` of
        detached per-term values, including ``'total_loss'``.
    """
    gpu = pred.device

    # Normalise once: the default is a plain bool, callers may pass a
    # one-element tensor. Calling `.item()` on the default would raise.
    is_negative = bool(negative)

    # track losses for printing to local log and uploading to WandB
    loss_dict = OrderedDict()

    B, L, natoms = true.shape[:3]
    seq = label_aa_s[:, 0].clone()

    assert (B == 1)  # fd - code assumes a batch size of 1

    tot_loss = 0.0

    # set up frames
    frames, frame_mask = get_frames(
        pred_allatom[-1, None, ...], mask_crds, seq, trainer.fi_dev, atom_frames)

    # update frames and frames_mask to only include BB frames
    # (have to update both for compatibility with compute_general_FAPE)
    frames_BB = frames.clone()
    frames_BB[..., 1:, :, :] = 0
    frame_mask_BB = frame_mask.clone()
    frame_mask_BB[..., 1:] = False

    # c6d loss
    for i in range(4):
        loss = cce_loss(logit_s[i], label_s[..., i])  # (B, L, L)
        if i == 0:
            # apply distogram loss to all residue pairs with valid BB atoms
            mask_2d_ = mask_2d
        else:
            # apply anglegram loss only when both residues have valid BB frames
            # (i.e. not metal ions, and not examples with unresolved atoms in frames)
            _, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds)  # (1, L, nframes)
            bb_frame_good = bb_frame_good[..., 0]  # (1,L)
            loss_mask_2d = bb_frame_good & bb_frame_good[..., None]
            mask_2d_ = mask_2d & loss_mask_2d

        if is_negative:
            # Don't compute inter-chain distogram losses for negative examples.
            mask_2d_ = mask_2d_ * same_chain
        loss = (mask_2d_ * loss).sum() / (mask_2d_.sum() + eps)
        tot_loss += w_dist * loss
        loss_dict[f'c6d_{i}'] = loss.detach()

    # masked token prediction loss
    loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
    loss = loss * mask_aa_s.reshape(B, -1)
    loss = loss.sum() / (mask_aa_s.sum() + 1e-8)
    tot_loss += w_aa * loss
    loss_dict['aa_cce'] = loss.detach()

    # col 4: binder loss
    # only apply binding loss to complexes
    # note that this will apply loss to positive sets w/o a corresponding negative set
    # (e.g., homomers). Maybe want to change this?
    if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer == "legacy":
        if (torch.sum(same_chain == 0) > 0):
            bce = torch.nn.BCELoss()
            target = torch.tensor(
                [abs(float(not is_negative) - binder_loss_label_smoothing)],
                device=p_bind.device
            )
            loss = bce(p_bind, target)
        else:
            # avoid unused parameter error
            loss = 0.0 * p_bind.sum()

        tot_loss += w_bind * loss
        loss_dict['binder_bce_loss'] = loss.detach()

    ### GENERAL LAYERS
    # Structural loss (layer-wise backbone FAPE)
    dclamp = 300.0 if unclamp else 30.0  # protein & NA FAPE distance cutoffs
    dclamp_prot = 10
    # residue mask for FAPE calculation only masks unresolved protein backbone atoms
    # whereas other losses also mask unresolved ligand atoms (mask_BB)
    # frames with unresolved ligand atoms are masked in compute_general_FAPE
    res_mask = ~((mask_crds[:, :, :3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))

    # create 2d masks for intrachain fape calculations
    nframes = frame_mask.shape[-1]
    frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool()  # B, L, nframes, L, natoms
    frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :, None, :, None].bool().expand(-1, -1, nframes, -1, NTOTAL)
    frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]

    skip_fape = 'tf' in task[0] or res_mask.sum() == 0

    if skip_fape:
        # NOTE(review): unlike the branches below, this path does not guard
        # against logit_pae / logit_pde being None -- confirm with callers.
        tot_str = 0.0 * pred.sum(axis=(1, 2, 3, 4))
        pae_loss = 0.0 * logit_pae.sum()
        pde_loss = 0.0 * logit_pde.sum()
    elif is_negative:  # inter-chain fapes should be ignored for negative cases
        if logit_pae is not None:
            logit_pae = logit_pae[:, :, res_mask[0]][:, :, :, res_mask[0]]
        if logit_pde is not None:
            logit_pde = logit_pde[:, :, res_mask[0]][:, :, :, res_mask[0]]

        # fd: pae/pde loss was once zeroed for negatives; Pascal: PAE/PDE
        # should now be computed correctly for intra-chain pairs, so they are kept.
        tot_str, pae_loss, pde_loss = compute_general_FAPE(
            pred[:, res_mask, :, :3],
            true[:, res_mask[0], :3],
            mask_crds[:, res_mask[0], :3],
            frames_BB[:, res_mask[0]],
            frame_mask_BB[:, res_mask[0]],
            frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
            dclamp=dclamp,
            logit_pae=logit_pae,
            logit_pde=logit_pde,
        )
    else:
        if logit_pae is not None:
            logit_pae = logit_pae[:, :, res_mask[0]][:, :, :, res_mask[0]]
        if logit_pde is not None:
            logit_pde = logit_pde[:, :, res_mask[0]][:, :, :, res_mask[0]]

        # change clamp for intra protein to 10, leave rest at 30
        dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
        if not unclamp:
            is_prot = is_protein(seq)  # (1,L)
            same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1, 1, nframes, 1, natoms)
            # zero out rows and columns with small molecules
            same_chain_clamp_mask[:, ~is_prot[0]] = 0
            same_chain_clamp_mask[:, :, :, ~is_prot[0]] = 0
            dclamp_2d *= ~same_chain_clamp_mask.bool()
            dclamp_2d += same_chain_clamp_mask * dclamp_prot

        tot_str, pae_loss, pde_loss = compute_general_FAPE(
            pred[:, res_mask, :, :3],
            true[:, res_mask[0], :3],
            mask_crds[:, res_mask[0], :3],
            frames_BB[:, res_mask[0]],
            frame_mask_BB[:, res_mask[0]],
            dclamp=None,
            dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0], :3],
            logit_pae=logit_pae,
            logit_pde=logit_pde,
        )

        # free up big intermediate data tensors
        del dclamp_2d
        if not unclamp:
            del same_chain_clamp_mask

    num_layers = pred.shape[0]
    gamma = 1.0  # equal weighting of fape across all layers
    w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
    w_bb_fape = torch.flip(w_bb_fape, (0,))
    w_bb_fape = w_bb_fape / w_bb_fape.sum()
    bb_l_fape = (w_bb_fape * tot_str).sum()

    tot_loss += 0.5 * w_str * bb_l_fape
    for i, layer_fape in enumerate(tot_str):
        loss_dict[f'bb_fape_layer{i}'] = layer_fape.detach()
    loss_dict['bb_fape_full'] = bb_l_fape.detach()

    tot_loss += w_pae * pae_loss + w_pde * pde_loss
    loss_dict['pae_loss'] = pae_loss.detach()
    loss_dict['pde_loss'] = pde_loss.detach()

    # NOTE: the per-subset backbone FAPE breakdown (protein intra/inter,
    # ligand intra/inter, protein<->ligand inter-chain) that used to live here
    # was already disabled (commented out) and has been dropped; w_inter_fape
    # and w_lig_fape are therefore not used.

    ## AllAtom loss
    # get ground-truth torsion angles
    true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
        true, seq, mask_in=mask_crds)
    tors_mask *= mask_BB[..., None]

    # get alternative coordinates for ground-truth
    true_alt = torch.zeros_like(true)
    true_alt.scatter_(2, trainer.l2a[seq, :, None].repeat(1, 1, 1, 3), true)
    natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[..., :3, :], true_tors)
    natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[..., :3, :], true_tors_alt)
    predTs = pred[-1, ...]
    # NOTE(review): the results of this call are not used below (pred_allatom
    # is passed in by the caller); kept in case the call has side effects.
    predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])

    # - resolve symmetry
    xs_mask = trainer.aamask[seq]  # (B, L, 27)
    xs_mask[0, :, 14:] = False  # (ignore hydrogens except lj loss)
    xs_mask *= mask_crds  # mask missing atoms & residues as well
    natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])

    # torsion angle loss
    l_tors = torsionAngleLoss(
        pred_tors,
        true_tors,
        true_tors_alt,
        tors_mask,
        tors_planar,
        eps=1e-10)
    tot_loss += w_str * l_tors
    loss_dict['torsion'] = l_tors.detach()

    ### FINETUNING LAYERS
    # lddts (CA); `negative` is forwarded unchanged to the helpers
    ca_lddt = calc_lddt(pred[:, :, :, 1].detach(), true[:, :, 1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface)
    loss_dict['ca_lddt'] = ca_lddt[-1].detach()

    # lddts (allatom) + lddt loss
    lddt_loss, allatom_lddt = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain,
        negative=negative, interface=interface, N_stripe=10)
    tot_loss += w_lddt * lddt_loss
    loss_dict['lddt_loss'] = lddt_loss.detach()
    loss_dict['allatom_lddt'] = allatom_lddt[0].detach()

    # FAPE losses
    # allatom fape (uses the frames computed at the top of this function)
    if skip_fape:
        l_fape = torch.zeros((pred.shape[0])).to(gpu)
    elif is_negative:  # inter-chain fapes should be ignored for negative cases
        l_fape, _, _ = compute_general_FAPE(
            pred_allatom[:, res_mask[0], :, :3],
            nat_symm[None, res_mask[0], :, :3],
            xs_mask[:, res_mask[0]],
            frames[:, res_mask[0]],
            frame_mask[:, res_mask[0]],
            frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
        )
    else:
        l_fape, _, _ = compute_general_FAPE(
            pred_allatom[:, res_mask[0], :, :3],
            nat_symm[None, res_mask[0], :, :3],
            xs_mask[:, res_mask[0]],
            frames[:, res_mask[0]],
            frame_mask[:, res_mask[0]]
        )

    tot_loss += w_str * l_fape[0]
    loss_dict['allatom_fape'] = l_fape[0].detach()

    # rmsd loss (for logging only)
    if torch.any(mask_BB[0]):
        rmsd = calc_crd_rmsd(
            pred_allatom[:, mask_BB[0], :, :3],
            nat_symm[None, mask_BB[0], :, :3],
            xs_mask[:, mask_BB[0]]
        )
        loss_dict["rmsd"] = rmsd[0].detach()
    else:
        loss_dict["rmsd"] = torch.tensor(0, device=gpu)

    # create protein and not protein masks; not protein could include nucleic acids
    prot_mask_BB = is_protein(label_aa_s[0, 0])  # (L,)
    not_prot_mask_BB = ~prot_mask_BB.bool()
    xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
    xs_mask_prot[:, ~prot_mask_BB] = False
    xs_mask_lig[:, ~not_prot_mask_BB] = False
    if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
        rmsd_prot_prot = calc_crd_rmsd(
            pred=pred_allatom[:, mask_BB[0], :, :3], true=nat_symm[None, mask_BB[0], :, :3],
            atom_mask=xs_mask_prot[:, mask_BB[0]], rmsd_mask=xs_mask_prot[:, mask_BB[0]]
        )
    else:
        rmsd_prot_prot = torch.tensor([0], device=pred.device)
    if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
        rmsd_lig_lig = calc_crd_rmsd(
            pred=pred_allatom[:, mask_BB[0], :, :3], true=nat_symm[None, mask_BB[0], :, :3],
            atom_mask=xs_mask_lig[:, mask_BB[0]], rmsd_mask=xs_mask_lig[:, mask_BB[0]]
        )
    else:
        rmsd_lig_lig = torch.tensor([0], device=pred.device)
    if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
        rmsd_prot_lig = calc_crd_rmsd(
            pred=pred_allatom[:, mask_BB[0], :, :3], true=nat_symm[None, mask_BB[0], :, :3],
            atom_mask=xs_mask_prot[:, mask_BB[0]], rmsd_mask=xs_mask_lig[:, mask_BB[0]],
            alignment_radius=10.0
        )
    else:
        rmsd_prot_lig = torch.tensor([0], device=pred.device)

    loss_dict["rmsd_prot_prot"] = rmsd_prot_prot[0].detach()
    loss_dict["rmsd_lig_lig"] = rmsd_lig_lig[0].detach()
    loss_dict["rmsd_prot_lig"] = rmsd_prot_lig[0].detach()

    # cart bonded (bond geometry)
    bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
    if w_bond > 0.0:
        tot_loss += w_bond * bond_loss
    loss_dict['bond_geom'] = bond_loss.detach()

    # clash [use all atoms not just those in native]
    clash_loss, num_violations = calc_l1_clash_loss(
        pred_allatom, seq[0],
        trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters,
        trainer.lj_correction_parameters, trainer.num_bonds)
    if w_clash > 0.0:
        tot_loss += w_clash * clash_loss.mean()
    loss_dict['clash_loss'] = clash_loss.detach()

    if torch.any(mask_BB[0]):
        atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
            pred=pred_allatom[:, mask_BB[0]],
            true=nat_symm[None, mask_BB[0]],
            bond_feats=bond_feats[:, mask_BB[0]][:, :, mask_BB[0]],
            seq=seq[:, mask_BB[0]]
        )
    else:
        atom_bond_loss = torch.tensor(0, device=gpu)
        skip_bond_loss = torch.tensor(0, device=gpu)
        rigid_loss = torch.tensor(0, device=gpu)

    # NOTE(review): these three use `>= 0.0` whereas w_bond / w_clash use
    # `> 0.0`, so a zero weight still adds 0 * loss -- confirm this is intended.
    if w_atom_bond >= 0.0:
        tot_loss += w_atom_bond * atom_bond_loss
    loss_dict['atom_bond_loss'] = atom_bond_loss.detach()

    if w_skip_bond >= 0.0:
        tot_loss += w_skip_bond * skip_bond_loss
    loss_dict['skip_bond_loss'] = skip_bond_loss.detach()

    if w_rigid >= 0.0:
        tot_loss += w_rigid * rigid_loss
    loss_dict['rigid_loss'] = rigid_loss.detach()

    # logging-only all-atom lDDT restricted to protein-protein pairs
    chain_prot = same_chain.clone()
    protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)

    _, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
        chain_prot, negative=True, N_stripe=10)
    loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()

    _, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
        chain_prot, interface=True, N_stripe=10)
    loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()

    # ... and to non-protein / non-protein pairs
    chain_lig = same_chain.clone()
    not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
    _, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
        chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
    loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()

    _, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
        chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
    loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()

    # protein<->non-protein interface: mark protein-protein and
    # non-protein/non-protein pairs as "same chain" so only mixed pairs count
    chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
    chain_prot_lig_inter += protein_mask_2d
    chain_prot_lig_inter += not_protein_mask_2d
    _, allatom_lddt_inter = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
        chain_prot_lig_inter, interface=True, N_stripe=10)
    loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
    loss_dict['total_loss'] = tot_loss.detach()

    return tot_loss, loss_dict
|
||||
|
||||
|
||||
### this file will contain specific calls to the loss function
|
||||
class LossManager:
    """Scaffold that will compute the loss and hold shared primitives for it.

    Only the bookkeeping containers and a frame cache exist so far; the
    individual loss terms are not wired in yet.
    """

    def __init__(self, config) -> None:
        # Kept for the loss terms that will be configured from it later.
        self.config = config
        self.loss_list = []
        self.loss_weights = []
        self.loss_dict = {}
        # Frame cache read by get_frames(). Initialised here so that calling
        # get_frames() on a fresh instance returns None rather than raising
        # AttributeError.
        self.frames = None
        self.frame_mask = None

    def compute_loss(self, rf_inputs, rf_outputs):
        """Iterate over the registered losses (placeholder: no-op, returns None)."""
        for loss in self.loss_list:
            pass

    def get_frames(self):
        """Return the cached ``(frames, frame_mask)`` pair.

        Returns None when either half of the cache is unset; computing the
        frames on a cache miss is not implemented yet.
        """
        if self.frames is not None and self.frame_mask is not None:
            return self.frames, self.frame_mask
        return None
|
||||
|
||||
|
||||
# Registry keyed by loss name. Every entry is still an unset placeholder
# (None); no loss implementation is registered here yet.
loss_factory = {
    "c6d": None,
    "mlm": None,
    "lddt": None,
    "pae": None,
    "bb_fape": None,
    "allatom_fape": None,

}
|
||||
def c6d_loss(loss_manager, trainer):
    """Placeholder for the c6d loss term; currently does nothing and returns None."""
    pass
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from pathlib import Path
|
||||
|
||||
from rf2aa.submitit_utils import add_slurm_args, create_executor
|
||||
from rf2aa.arguments import get_args
|
||||
from factory import trainer_factory
|
||||
|
||||
|
||||
def call_fn(args, dataset_param, model_param, loader_param, loss_param):
    """Set up the distributed rendezvous environment and run training.

    Picks a random ``MASTER_PORT``, resolves ``MASTER_ADDR`` to the first host
    of the SLURM node list via ``scontrol``, exports both as environment
    variables, prints the parsed arguments, then builds the trainer with
    ``trainer_factory`` and calls its ``run_model_training`` with the number
    of visible CUDA devices.

    Raises:
        subprocess.CalledProcessError: if the ``scontrol`` pipeline exits
            with a non-zero status.
    """
    # TCP ports are 16-bit, so the exclusive upper bound must be 65536; the
    # previous bound of 100000 could produce an unusable port number.
    master_port = str(np.random.randint(10000, 65536))
    # With shell=True the command is passed as a single string.
    master_address = (
        subprocess.check_output(
            'scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True
        )
        .decode()
        .strip()
    )
    print(f"Setting MASTER_PORT={master_port} and MASTER_ADDR={master_address}")
    os.environ["MASTER_PORT"] = master_port
    os.environ["MASTER_ADDR"] = master_address

    print("============== INPUT ARGUMENTS ==============")
    pprint(vars(args))
    print("=============================================")

    mp.freeze_support()
    trainer_object = trainer_factory(args, dataset_param, model_param, loader_param, loss_param)
    trainer_object.run_model_training(torch.cuda.device_count())
|
||||
|
||||
|
||||
def main():
    """Parse the command line and either run locally or submit a SLURM job.

    With ``args.local`` the training function runs in this process.
    Otherwise a log folder is derived from ``args.slurm_log_path`` and
    ``args.model_name`` (plus the checkpoint stem in eval mode), an executor
    is created with ``create_executor`` and ``call_fn`` is submitted to it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        choices=["train", "eval"],
        help="Set to eval to run evaluation only",
    )
    parser = add_slurm_args(parser)
    print("Reading in arguments...")
    args, dataset_param, model_param, loader_param, loss_param = get_args(parser)
    print("Done reading in arguments.")

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this argument validation.
    if args.interactive and not args.local:
        parser.error(
            "When submitting via submitit you either have to launch it locally, or not in interactive mode."
        )

    if args.local:
        call_fn(args, dataset_param, model_param, loader_param, loss_param)
        return

    if args.mode == "train":
        log_folder = Path(args.slurm_log_path) / args.model_name / "training_log/"
    else:
        if args.initialize_model_from_checkpoint is not None:
            model_restore_path = Path(args.initialize_model_from_checkpoint)
            model_name = model_restore_path.stem
        else:
            model_name = "no_checkpoint"

        log_folder = Path(args.slurm_log_path) / args.model_name / f"eval_{model_name}/"
    # Only the parent directory is created here, not log_folder itself.
    log_folder.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the job name says "training" even in eval mode -- confirm
    # whether that is intended.
    job_name = f"{args.model_name}_training"
    executor = create_executor(args, log_folder, job_name)
    job = executor.submit(
        call_fn, args, dataset_param, model_param, loader_param, loss_param
    )
    print(
        f"Submitted job {job.job_id} with name {job_name} and log folder {log_folder}"
    )
|
||||
|
||||
|
||||
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
@@ -3,9 +3,9 @@ import torch.nn as nn
|
||||
import assertpy
|
||||
from assertpy import assert_that
|
||||
from icecream import ic
|
||||
from rf2aa.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, Recycling
|
||||
from rf2aa.Track_module import IterativeSimulator
|
||||
from rf2aa.AuxiliaryPredictor import (
|
||||
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
||||
from rf2aa.model.Track_module import IterativeSimulator
|
||||
from rf2aa.model.layers.AuxiliaryPredictor import (
|
||||
DistanceNetwork,
|
||||
MaskedTokenNetwork,
|
||||
LDDTNetwork,
|
||||
@@ -26,7 +26,7 @@ def get_shape(t):
|
||||
return type(t)
|
||||
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
|
||||
class LegacyRoseTTAFoldModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
|
||||
@@ -50,6 +50,7 @@ class RoseTTAFoldModule(nn.Module):
|
||||
d_hidden_templ=64,
|
||||
p_drop=0.15,
|
||||
additional_dt1d=0,
|
||||
recycling_type="msa_pair",
|
||||
SE3_param={}, SE3_ref_param={},
|
||||
atom_type_index=None,
|
||||
aamask=None,
|
||||
@@ -60,22 +61,25 @@ class RoseTTAFoldModule(nn.Module):
|
||||
cb_tor=None,
|
||||
num_bonds=None,
|
||||
lj_lin=0.6,
|
||||
use_extra_l1=True,
|
||||
use_chiral_l1=True,
|
||||
use_lj_l1=False,
|
||||
use_atom_frames=True,
|
||||
use_same_chain=False,
|
||||
# New for diffusion
|
||||
freeze_track_motif=False,
|
||||
assert_single_sequence_input=False,
|
||||
fit=False,
|
||||
tscale=1.0
|
||||
):
|
||||
super(RoseTTAFoldModule, self).__init__()
|
||||
super(LegacyRoseTTAFoldModule, self).__init__()
|
||||
self.freeze_track_motif = freeze_track_motif
|
||||
self.assert_single_sequence_input = assert_single_sequence_input
|
||||
self.recycling_type = recycling_type
|
||||
#
|
||||
# Input Embeddings
|
||||
d_state = SE3_param["l0_out_features"]
|
||||
self.latent_emb = MSA_emb(
|
||||
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop
|
||||
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain
|
||||
)
|
||||
self.full_emb = Extra_emb(
|
||||
d_msa=d_msa_full, d_init=NAATOKENS - 1 + 4, p_drop=p_drop
|
||||
@@ -97,7 +101,8 @@ class RoseTTAFoldModule(nn.Module):
|
||||
additional_dt1d=additional_dt1d)
|
||||
|
||||
# Update inputs with outputs from previous round
|
||||
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
|
||||
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
#
|
||||
self.simulator = IterativeSimulator(
|
||||
n_extra_block=n_extra_block,
|
||||
@@ -122,23 +127,22 @@ class RoseTTAFoldModule(nn.Module):
|
||||
cb_ang=cb_ang,
|
||||
cb_tor=cb_tor,
|
||||
lj_lin=lj_lin,
|
||||
use_extra_l1=use_extra_l1,
|
||||
use_lj_l1=use_lj_l1,
|
||||
use_chiral_l1=use_chiral_l1,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method,
|
||||
main_block=main_block
|
||||
main_block=main_block,
|
||||
use_same_chain=use_same_chain
|
||||
)
|
||||
|
||||
##
|
||||
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
|
||||
self.lddt_pred = LDDTNetwork(d_state)
|
||||
if (
|
||||
use_extra_l1
|
||||
): # extra l1 features introduced at the same time as pAE/pDE heads
|
||||
self.pae_pred = PAENetwork(d_pair)
|
||||
self.pde_pred = PAENetwork(
|
||||
self.pae_pred = PAENetwork(d_pair)
|
||||
self.pde_pred = PAENetwork(
|
||||
d_pair
|
||||
) # distance error, but use same architecture as aligned error
|
||||
|
||||
@@ -147,8 +151,6 @@ class RoseTTAFoldModule(nn.Module):
|
||||
# this prediction head.
|
||||
# self.binder_network = BinderNetwork(d_pair, d_state)
|
||||
|
||||
self.use_extra_l1 = use_extra_l1
|
||||
|
||||
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
|
||||
|
||||
self.use_atom_frames = use_atom_frames
|
||||
@@ -167,7 +169,7 @@ class RoseTTAFoldModule(nn.Module):
|
||||
dist_matrix,
|
||||
chirals,
|
||||
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
|
||||
msa_prev=None, pair_prev=None, mask_recycle=None, is_motif=None,
|
||||
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
|
||||
return_raw=False,
|
||||
use_checkpoint=False,
|
||||
return_infer=False, #fd ?
|
||||
@@ -338,12 +340,15 @@ class RoseTTAFoldModule(nn.Module):
|
||||
msa_prev = torch.zeros_like(msa_latent[:,0])
|
||||
if pair_prev is None:
|
||||
pair_prev = torch.zeros_like(pair)
|
||||
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
|
||||
state_prev = torch.zeros_like(state)
|
||||
|
||||
msa_recycle, pair_recycle = self.recycle(msa_prev, pair_prev, xyz, mask_recycle)
|
||||
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
|
||||
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
|
||||
|
||||
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
|
||||
pair = pair + pair_recycle
|
||||
state = state + state_recycle # if state is not recycled these will be zeros
|
||||
|
||||
# add template embedding
|
||||
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
|
||||
@@ -383,26 +388,15 @@ class RoseTTAFoldModule(nn.Module):
|
||||
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
|
||||
)
|
||||
|
||||
# if return_infer:
|
||||
# xyz_last = xyz_allatom[-1].unsqueeze(0)
|
||||
# # return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
|
||||
# # return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
|
||||
# return msa[:,0], pair, xyz_last, state, alpha_s[-1], logits_aa.permute(0,2,1), lddt
|
||||
|
||||
logits_pae = logits_pde = p_bind = None
|
||||
if not return_infer:
|
||||
# predict aligned error and distance error
|
||||
if self.use_extra_l1:
|
||||
logits_pae = self.pae_pred(pair)
|
||||
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
|
||||
# predict aligned error and distance error
|
||||
logits_pae = self.pae_pred(pair)
|
||||
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
|
||||
|
||||
#fd predict bind/no-bind
|
||||
p_bind = self.bind_pred(logits_pae,same_chain)
|
||||
else:
|
||||
logits_pae = None
|
||||
logits_pde = None
|
||||
#fd predict bind/no-bind
|
||||
p_bind = self.bind_pred(logits_pae,same_chain)
|
||||
|
||||
return (
|
||||
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
|
||||
)
|
||||
@@ -8,11 +8,10 @@ from icecream import ic
|
||||
from contextlib import ExitStack, nullcontext
|
||||
|
||||
from rf2aa.util_module import *
|
||||
from rf2aa.Attention_module import *
|
||||
from rf2aa.SE3_network import SE3TransformerWrapper
|
||||
from rf2aa.resnet import ResidualNetwork
|
||||
from rf2aa.model.layers.Attention_module import *
|
||||
from rf2aa.model.layers.SE3_network import SE3TransformerWrapper
|
||||
from rf2aa.util import INIT_CRDS, is_atom, xyz_frame_from_rotation_mask
|
||||
from rf2aa.loss import (
|
||||
from rf2aa.loss.loss import (
|
||||
calc_BB_bond_geom_grads, calc_lj_grads, calc_hb_grads, calc_cart_bonded_grads, calc_ljallatom_grads,
|
||||
calc_lj, calc_cart_bonded, calc_chiral_grads
|
||||
)
|
||||
@@ -28,7 +27,7 @@ from rf2aa.symmetry import get_symm_map
|
||||
|
||||
class PositionalEncoding2D(nn.Module):
|
||||
# Add relative positional encoding to pair features
|
||||
def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
|
||||
def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15, use_same_chain=False):
|
||||
super(PositionalEncoding2D, self).__init__()
|
||||
self.minpos = minpos
|
||||
self.maxpos = maxpos
|
||||
@@ -37,8 +36,13 @@ class PositionalEncoding2D(nn.Module):
|
||||
self.nbin_atom = maxpos_atom+2 # include 0 and "unknown" token (maxpos_sm + 1)
|
||||
self.emb_res = nn.Embedding(self.nbin_res, d_pair)
|
||||
self.emb_atom = nn.Embedding(self.nbin_atom, d_pair)
|
||||
|
||||
self.use_same_chain = use_same_chain
|
||||
if use_same_chain:
|
||||
self.emb_chain = nn.Embedding(2, d_pair)
|
||||
|
||||
def forward(self, seq, idx, bond_feats, dist_matrix):
|
||||
|
||||
def forward(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
|
||||
sm_mask = is_atom(seq[0])
|
||||
|
||||
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask,
|
||||
@@ -54,6 +58,10 @@ class PositionalEncoding2D(nn.Module):
|
||||
|
||||
out = emb_res + emb_atom
|
||||
|
||||
if self.use_same_chain and same_chain is not None:
|
||||
emb_c = self.emb_chain(same_chain.long())
|
||||
out += emb_c*0 # cursed but exists for backwards compatibility
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -353,10 +361,10 @@ class PairStr2Pair(nn.Module):
|
||||
|
||||
return pair + (pairnew/countnew[...,None]).reshape(N,L,L,-1)
|
||||
|
||||
def forward(self, pair, rbf_feat, state, crop=64):
|
||||
def forward(self, pair, rbf_feat, state, crop=-1):
|
||||
B,L = pair.shape[:2]
|
||||
|
||||
rbf_feat = self.emb_rbf(rbf_feat)
|
||||
rbf_feat = self.emb_rbf(rbf_feat) # B, L, L, d_pair
|
||||
|
||||
state = self.norm_state(state)
|
||||
left = self.proj_left(state)
|
||||
@@ -458,7 +466,7 @@ class Str2Str(nn.Module):
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
self.norm_state = nn.LayerNorm(d_state)
|
||||
|
||||
|
||||
self.embed_node = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
|
||||
self.ff_node = FeedForwardLayer(SE3_param['l0_in_features'], 2, p_drop=p_drop)
|
||||
self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
|
||||
@@ -887,7 +895,7 @@ class IterBlock(nn.Module):
|
||||
d_hidden=32, d_hidden_msa=None,
|
||||
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
|
||||
nextra_l0=0, nextra_l1=0,
|
||||
nextra_l0=0, nextra_l1=0, use_same_chain=False,
|
||||
symmetrize_repeats=None, repeat_length=None,symmsub_k=None, sym_method=None, main_block=None,
|
||||
fit=False, tscale=1.0
|
||||
):
|
||||
@@ -900,7 +908,7 @@ class IterBlock(nn.Module):
|
||||
self.tscale = tscale
|
||||
|
||||
self.pos = PositionalEncoding2D(d_rbf, minpos=minpos, maxpos=maxpos,
|
||||
maxpos_atom=maxpos_atom, p_drop=p_drop)
|
||||
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain)
|
||||
|
||||
self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
|
||||
n_head=n_head_msa,
|
||||
@@ -932,7 +940,7 @@ class IterBlock(nn.Module):
|
||||
crop=-1
|
||||
):
|
||||
cas = xyz[:,:,1].contiguous()
|
||||
rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, dist_matrix)
|
||||
rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, dist_matrix, same_chain)
|
||||
if use_checkpoint:
|
||||
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
|
||||
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
|
||||
@@ -966,7 +974,7 @@ class IterativeSimulator(nn.Module):
|
||||
atom_type_index=None, aamask=None,
|
||||
ljlk_parameters=None, lj_correction_parameters=None,
|
||||
cb_len=None, cb_ang=None, cb_tor=None,
|
||||
num_bonds=None, lj_lin=0.6, use_extra_l1=True,
|
||||
num_bonds=None, lj_lin=0.6, use_same_chain=False, use_chiral_l1=True, use_lj_l1=False,
|
||||
symmetrize_repeats=None,
|
||||
repeat_length=None,
|
||||
symmsub_k=None,
|
||||
@@ -990,7 +998,8 @@ class IterativeSimulator(nn.Module):
|
||||
self.cb_len = cb_len
|
||||
self.cb_ang = cb_ang
|
||||
self.cb_tor = cb_tor
|
||||
self.use_extra_l1 = use_extra_l1 # set to False to not use chiral & LJ grads
|
||||
self.use_chiral_l1 = use_chiral_l1
|
||||
self.use_lj_l1 = use_lj_l1
|
||||
self.fit = fit
|
||||
self.tscale = tscale
|
||||
|
||||
@@ -1004,7 +1013,8 @@ class IterativeSimulator(nn.Module):
|
||||
p_drop=p_drop,
|
||||
use_global_attn=True,
|
||||
SE3_param=SE3_param,
|
||||
nextra_l1=3 if self.use_extra_l1 else 0,
|
||||
nextra_l1=3 if self.use_chiral_l1 else 0,
|
||||
use_same_chain=use_same_chain,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
@@ -1023,7 +1033,8 @@ class IterativeSimulator(nn.Module):
|
||||
p_drop=p_drop,
|
||||
use_global_attn=False,
|
||||
SE3_param=SE3_param,
|
||||
nextra_l1=3 if self.use_extra_l1 else 0,
|
||||
nextra_l1=3 if self.use_chiral_l1 else 0,
|
||||
use_same_chain=use_same_chain,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
@@ -1035,12 +1046,19 @@ class IterativeSimulator(nn.Module):
|
||||
|
||||
# Final SE(3) refinement
|
||||
if n_ref_block > 0:
|
||||
n_extra_l0 = 0
|
||||
n_extra_l1 = 0
|
||||
if self.use_chiral_l1:
|
||||
n_extra_l1 += 3
|
||||
if self.use_lj_l1:
|
||||
n_extra_l0 += 2*NTOTALDOFS
|
||||
n_extra_l1 += 3
|
||||
self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
|
||||
d_state=SE3_param['l0_out_features'],
|
||||
SE3_param=SE3_ref_param,
|
||||
p_drop=p_drop,
|
||||
nextra_l0=2*NTOTALDOFS if self.use_extra_l1 else 0,
|
||||
nextra_l1=6 if self.use_extra_l1 else 0
|
||||
nextra_l0=n_extra_l0,
|
||||
nextra_l1=n_extra_l1,
|
||||
)
|
||||
|
||||
# # Fine-tuning all-atom SE(3) refinement
|
||||
@@ -1096,7 +1114,7 @@ class IterativeSimulator(nn.Module):
|
||||
for i_m in range(self.n_extra_block):
|
||||
extra_l0 = None
|
||||
extra_l1 = None
|
||||
if self.use_extra_l1:
|
||||
if self.use_chiral_l1:
|
||||
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
|
||||
extra_l1 = dchiraldxyz[0].detach()
|
||||
|
||||
@@ -1119,7 +1137,7 @@ class IterativeSimulator(nn.Module):
|
||||
for i_m in range(self.n_main_block):
|
||||
extra_l0 = None
|
||||
extra_l1 = None
|
||||
if self.use_extra_l1:
|
||||
if self.use_chiral_l1:
|
||||
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
|
||||
extra_l1 = dchiraldxyz[0].detach()
|
||||
msa, pair, xyz, state, alpha, symmsub = self.main_block[i_m](msa, pair,
|
||||
@@ -1140,17 +1158,16 @@ class IterativeSimulator(nn.Module):
|
||||
|
||||
_, xyzallatom = self.xyzconverter.compute_all_atom(seq_unmasked, xyz, alpha) # think about detach here...
|
||||
|
||||
# memory savings: only backprop 1st and another random step
|
||||
backprop = np.random.randint(1,self.n_ref_block)
|
||||
backprop = torch.arange(self.n_ref_block) # backprop through everything
|
||||
for i_m in range(self.n_ref_block):
|
||||
with ExitStack() as stack:
|
||||
if i_m != 0 and i_m != backprop:
|
||||
if (backprop != i_m).all():
|
||||
stack.enter_context(torch.no_grad())
|
||||
|
||||
|
||||
extra_l0 = None
|
||||
extra_l1 = None
|
||||
extra_l1 = []
|
||||
|
||||
if self.use_extra_l1:
|
||||
if self.use_lj_l1:
|
||||
dljdxyz, dljdalpha = calc_lj_grads(
|
||||
seq_unmasked, xyz.detach(), alpha.detach(),
|
||||
self.xyzconverter.compute_all_atom,
|
||||
@@ -1160,14 +1177,19 @@ class IterativeSimulator(nn.Module):
|
||||
self.lj_correction_parameters,
|
||||
self.num_bonds,
|
||||
lj_lin=self.lj_lin)
|
||||
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
|
||||
extra_l0 = dljdalpha.reshape(1,-1,2*NTOTALDOFS).detach()
|
||||
extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
|
||||
extra_l1.append(dljdxyz[0].detach())
|
||||
|
||||
if self.use_chiral_l1:
|
||||
dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
|
||||
#extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
|
||||
extra_l1.append(dchiraldxyz[0].detach())
|
||||
extra_l1 = torch.cat(extra_l1, dim=1)
|
||||
|
||||
xyz, state, alpha = self.str_refiner(
|
||||
msa.float(), pair.float(), xyz.detach().float(), state.float(), idx,
|
||||
rotation_mask, bond_feats, dist_matrix, atom_frames,
|
||||
is_motif, extra_l0.float(), extra_l1.float(), top_k=64, use_atom_frames=use_atom_frames #fd 128->64
|
||||
is_motif, extra_l0, extra_l1.float(), top_k=64, use_atom_frames=use_atom_frames #fd 128->64
|
||||
)
|
||||
|
||||
|
||||
125
rf2aa/model/embedding_blocks.py
Normal file
125
rf2aa/model/embedding_blocks.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.model.layers.Embeddings import MSA_emb, MSA_emb_nostate, \
|
||||
Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
||||
from rf2aa.chemical import NBTYPES, NAATOKENS
|
||||
|
||||
|
||||
class RF2_embedding(nn.Module):
|
||||
|
||||
def __init__(self, global_params, block_params):
    """Build the MSA, bond, template and recycling embedding sub-modules.

    Args:
        global_params: mapping indexed with the keys "d_msa", "d_msa_full",
            "d_pair" and "d_state".
        block_params: object read by attribute (p_drop, use_same_chain,
            d_templ, n_head_templ, d_hidden_templ, templ_p_drop,
            symmetrize_repeats, repeat_length, symmsub_k, sym_method,
            main_block, copy_main_block_template, additional_dt1d,
            recycling_type).

    Raises:
        KeyError: if block_params.recycling_type is not a key of
            recycling_factory (the lookup happens before the assert below).
        AssertionError: if block_params.recycling_type is not "msa_pair".
    """
    super(RF2_embedding, self).__init__()
    d_msa, d_msa_full, d_pair, d_state = global_params["d_msa"], global_params["d_msa_full"], global_params["d_pair"], global_params["d_state"]
    # embedding of the latent (seed) MSA; also yields the initial pair and state features
    self.latent_emb = MSA_emb(
        d_msa=d_msa,
        d_pair=d_pair,
        d_state=d_state,
        p_drop=block_params.p_drop,
        use_same_chain=block_params.use_same_chain
    )
    # embedding of the full (extra) MSA
    self.full_emb = Extra_emb(
        d_msa=d_msa_full,
        d_init=NAATOKENS - 1 + 4, #HACK: should define this from the config (4: ins/del,nterm/cterm feats)
        p_drop=block_params.p_drop
    )
    # bond features are embedded into the pair dimension
    self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)

    self.templ_emb = Templ_emb(d_pair=d_pair,
                               d_templ=block_params.d_templ,
                               d_state=d_state,
                               n_head=block_params.n_head_templ,
                               d_hidden=block_params.d_hidden_templ,
                               p_drop=block_params.templ_p_drop,
                               symmetrize_repeats=block_params.symmetrize_repeats, # repeat protein stuff
                               repeat_length=block_params.repeat_length,
                               symmsub_k=block_params.symmsub_k,
                               sym_method=block_params.sym_method,
                               main_block=block_params.main_block,
                               copy_main_block=block_params.copy_main_block_template,
                               additional_dt1d=block_params.additional_dt1d)

    ## Update inputs with outputs from previous forward pass
    self.recycle = recycling_factory[block_params.recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
    self.recycling_type = block_params.recycling_type
    # only msa+pair recycling is supported by this module; NOTE: asserts are stripped under `python -O`
    assert self.recycling_type == "msa_pair", "no backward compatibility to recycling state"
|
||||
|
||||
def _unpack_inputs(self, rf_inputs):
|
||||
msa_latent, msa_full, seq, idx, bond_feats, dist_matrix = \
|
||||
rf_inputs["msa_latent"], rf_inputs["msa_full"], rf_inputs["seq"], rf_inputs["idx"], rf_inputs["bond_feats"], \
|
||||
rf_inputs["dist_matrix"]
|
||||
## recycling inputs
|
||||
msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle = rf_inputs["msa_prev"], rf_inputs["pair_prev"], None, \
|
||||
rf_inputs["xyz"], rf_inputs["sctors"], rf_inputs["mask_recycle"]
|
||||
return msa_latent, msa_full, seq, idx, bond_feats, dist_matrix, msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle
|
||||
|
||||
def _add_templ_features(self, rf_inputs, pair, state):
|
||||
t1d, t2d, alpha_t, xyz_t, mask_t = rf_inputs["t1d"], rf_inputs["t2d"], \
|
||||
rf_inputs["alpha_t"], rf_inputs["xyz_t"], \
|
||||
rf_inputs["mask_t"]
|
||||
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state)
|
||||
return pair, state
|
||||
|
||||
def forward(self, rf_inputs):
    """Embed the raw inputs, add recycled features and template features.

    Args:
        rf_inputs: mapping holding the entries read by _unpack_inputs and
            _add_templ_features.

    Returns:
        dict with the keys "msa", "msa_full", "pair" and "state".
    """
    msa_latent, msa_full, seq, idx, bond_feats, dist_matrix, msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle = \
        self._unpack_inputs(rf_inputs)
    B, N, L = msa_latent.shape[:3]  # N is not used below

    # remember the input dtype so every embedding can be cast back to it
    dtype = msa_latent.dtype

    msa_latent, pair, state = self.latent_emb(
        msa_latent, seq, idx, bond_feats, dist_matrix
    )
    msa_full = self.full_emb(msa_full, seq, idx)
    pair = pair + self.bond_emb(bond_feats)

    msa_latent, pair = msa_latent.to(dtype), pair.to(dtype)
    msa_full = msa_full.to(dtype)
    # the latent embedding may hand back no state at all, hence the guard
    if state is not None:
        state = state.to(dtype)

    # missing recycling inputs are replaced by zeros
    if msa_prev is None:
        msa_prev = torch.zeros_like(msa_latent[:,0])
    if pair_prev is None:
        pair_prev = torch.zeros_like(pair)
    if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
        # placeholder is shaped like msa_latent[:, 0], not like `state` (which can be None here)
        state_prev = torch.zeros_like(msa_latent[:, 0])

    # state_recycle is not used below (see the commented-out line further down)
    msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)

    msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)

    # in-place write: only row 0 of the embedded MSA receives the recycled MSA features
    msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
    pair = pair + pair_recycle
    # No support for recycling state
    #state = state + state_recycle # if state is not recycled these will be zeros
    # add template embedding
    pair, state = self._add_templ_features(rf_inputs, pair, state)
    return {
        "msa": msa_latent,
        "msa_full": msa_full,
        "pair": pair,
        "state": state
    }
|
||||
|
||||
class RF2_embedding_nostate(RF2_embedding):
    """RF2_embedding whose latent MSA embedding is an MSA_emb_nostate instance."""

    def __init__(self, global_params, block_params):
        super().__init__(global_params, block_params)

        # swap out the latent embedding created by the parent constructor
        self.latent_emb = MSA_emb_nostate(
            d_msa=global_params["d_msa"],
            d_pair=global_params["d_pair"],
            d_state=global_params["d_state"],
            p_drop=block_params.p_drop,
            use_same_chain=block_params.use_same_chain,
        )
|
||||
|
||||
|
||||
|
||||
# Registry mapping an embedding name to the embedding-module class implementing it.
# NOTE(review): presumably the key comes from the model config -- confirm against the caller.
embedding_factory = {
    "rf2aa": RF2_embedding,
    "rf2aa_nostate": RF2_embedding_nostate
}
|
||||
|
||||
@@ -85,7 +85,6 @@ class SequenceWeight(nn.Module):
|
||||
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.dropout = nn.Dropout(p_drop)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
@@ -453,7 +452,7 @@ class BiasedAxialAttention(nn.Module):
|
||||
|
||||
pair = self.norm_pair(pair)
|
||||
bias = self.norm_bias(bias)
|
||||
|
||||
|
||||
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
|
||||
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
|
||||
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
|
||||
@@ -4,9 +4,9 @@ import torch.nn as nn
|
||||
from rf2aa.chemical import NAATOKENS
|
||||
|
||||
class DistanceNetwork(nn.Module):
|
||||
def __init__(self, n_feat, p_drop=0.1):
|
||||
def __init__(self, n_feat, p_drop=0.0):
|
||||
super(DistanceNetwork, self).__init__()
|
||||
#
|
||||
#HACK: dimensions are hard coded here
|
||||
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
|
||||
self.proj_asymm = nn.Linear(n_feat, 37+19)
|
||||
|
||||
@@ -36,7 +36,7 @@ class DistanceNetwork(nn.Module):
|
||||
return logits_dist, logits_omega, logits_theta, logits_phi
|
||||
|
||||
class MaskedTokenNetwork(nn.Module):
|
||||
def __init__(self, n_feat, p_drop=0.1):
|
||||
def __init__(self, n_feat, p_drop=0.0):
|
||||
super(MaskedTokenNetwork, self).__init__()
|
||||
|
||||
#fd note this predicts probability for the mask token (which is never in ground truth)
|
||||
@@ -76,6 +76,7 @@ class PAENetwork(nn.Module):
|
||||
super(PAENetwork, self).__init__()
|
||||
self.proj = nn.Linear(n_feat, n_bin_pae)
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
@@ -101,3 +102,10 @@ class BinderNetwork(nn.Module):
|
||||
prob = torch.sigmoid( self.classify( logits_inter ) )
|
||||
return prob
|
||||
|
||||
# Registry mapping an auxiliary-head name to the predictor class implementing it.
# NOTE(review): presumably the keys are looked up from a config when the heads are built -- confirm against the caller.
aux_predictor_factory = {
    "c6d": DistanceNetwork,
    "mlm": MaskedTokenNetwork,
    "plddt": LDDTNetwork,
    "pae": PAENetwork,
    "binder": BinderNetwork
}
|
||||
@@ -1,3 +1,4 @@
|
||||
from rf2aa.util import NAATOKENS
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -5,8 +6,8 @@ from opt_einsum import contract as einsum
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from rf2aa.util import *
|
||||
from rf2aa.util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
|
||||
from rf2aa.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
|
||||
from rf2aa.Track_module import PairStr2Pair, PositionalEncoding2D
|
||||
from rf2aa.model.layers.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
|
||||
from rf2aa.model.Track_module import PairStr2Pair, PositionalEncoding2D
|
||||
from rf2aa.chemical import NAATOKENS,NTOTALDOFS, NBTYPES
|
||||
|
||||
# Module contains classes and functions to generate initial embeddings
|
||||
@@ -14,7 +15,7 @@ from rf2aa.chemical import NAATOKENS,NTOTALDOFS, NBTYPES
|
||||
class MSA_emb(nn.Module):
|
||||
# Get initial seed MSA embedding
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
|
||||
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1):
|
||||
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
|
||||
super(MSA_emb, self).__init__()
|
||||
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
|
||||
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
|
||||
@@ -22,7 +23,7 @@ class MSA_emb(nn.Module):
|
||||
self.emb_right = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
|
||||
self.emb_state = nn.Embedding(NAATOKENS, d_state)
|
||||
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
|
||||
maxpos_atom=maxpos_atom, p_drop=p_drop)
|
||||
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
@@ -35,6 +36,26 @@ class MSA_emb(nn.Module):
|
||||
|
||||
nn.init.zeros_(self.emb.bias)
|
||||
|
||||
|
||||
def _msa_emb(self, msa, seq):
|
||||
N = msa.shape[1]
|
||||
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
|
||||
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
|
||||
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||
|
||||
return msa
|
||||
|
||||
def _pair_emb(self, seq, idx, bond_feats, dist_matrix):
|
||||
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
|
||||
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
|
||||
pair = left + right # (B, L, L, d_pair)
|
||||
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix) # add relative position
|
||||
|
||||
return pair
|
||||
|
||||
def _state_emb(self, seq):
    """Return the state embedding of the query sequence, i.e. self.emb_state(seq)."""
    return self.emb_state(seq)
|
||||
|
||||
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
|
||||
# Inputs:
|
||||
# - msa: Input MSA (B, N, L, d_init)
|
||||
@@ -45,25 +66,24 @@ class MSA_emb(nn.Module):
|
||||
# - msa: Initial MSA embedding (B, N, L, d_msa)
|
||||
# - pair: Initial Pair embedding (B, L, L, d_pair)
|
||||
|
||||
N = msa.shape[1] # number of sequenes in MSA
|
||||
|
||||
# msa embedding
|
||||
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
|
||||
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
|
||||
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||
#msa = self.drop(msa)
|
||||
msa = self._msa_emb(msa, seq)
|
||||
|
||||
# pair embedding
|
||||
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
|
||||
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
|
||||
pair = left + right # (B, L, L, d_pair)
|
||||
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix) # add relative position
|
||||
|
||||
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
|
||||
# state embedding
|
||||
state = self.emb_state(seq)
|
||||
|
||||
state = self._state_emb(seq)
|
||||
return msa, pair, state
|
||||
|
||||
class MSA_emb_nostate(MSA_emb):
    """MSA_emb variant that produces no state features: forward returns None for state."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2 * NAATOKENS + 2 + 2, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
        super().__init__(
            d_msa=d_msa,
            d_pair=d_pair,
            d_state=d_state,
            d_init=d_init,
            minpos=minpos,
            maxpos=maxpos,
            maxpos_atom=maxpos_atom,
            p_drop=p_drop,
            use_same_chain=use_same_chain,
        )
        # discard the state embedding created by the parent constructor
        self.emb_state = None

    def forward(self, msa, seq, idx, bond_feats, dist_matrix):
        """Return (msa embedding, pair embedding, None)."""
        msa_embedding = self._msa_emb(msa, seq)
        pair_embedding = self._pair_emb(seq, idx, bond_feats, dist_matrix)
        return msa_embedding, pair_embedding, None
|
||||
|
||||
class Extra_emb(nn.Module):
|
||||
# Get initial seed MSA embedding
|
||||
def __init__(self, d_msa=256, d_init=NAATOKENS-1+4, p_drop=0.1):
|
||||
@@ -352,7 +372,7 @@ class Recycling(nn.Module):
|
||||
self.proj_dist = init_lecun_normal(self.proj_dist)
|
||||
nn.init.zeros_(self.proj_dist.bias)
|
||||
|
||||
def forward(self, msa, pair, xyz, mask_recycle=None):
|
||||
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||
B, L = msa.shape[:2]
|
||||
msa = self.norm_msa(msa)
|
||||
pair = self.norm_pair(pair)
|
||||
@@ -367,5 +387,47 @@ class Recycling(nn.Module):
|
||||
|
||||
pair = pair + self.proj_dist(dist_CA)
|
||||
|
||||
return msa, pair
|
||||
return msa, pair, state # state is just zeros
|
||||
|
||||
class RecyclingAllFeatures(nn.Module):
    """Recycling module that feeds back msa, pair, state and side-chain torsion features.

    Same call signature as Recycling.forward, so both can be selected through
    recycling_factory.
    """

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
        super(RecyclingAllFeatures, self).__init__()
        # the pair update consumes rbf distances plus the state of both residues of a pair
        self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)

        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise both projections with init_lecun_normal and zero their biases."""
        self.proj_dist = init_lecun_normal(self.proj_dist)
        nn.init.zeros_(self.proj_dist.bias)
        self.proj_sctors = init_lecun_normal(self.proj_sctors)
        nn.init.zeros_(self.proj_sctors.bias)

    def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
        """Compute the recycled msa, pair and state features.

        Args:
            msa, pair, state: features from the previous pass; each is layer-normalised here.
            xyz: coordinates; atom index 1 of every residue is used for the distances.
            sctors: torsion features, reshaped to (B, -1, 2*NTOTALDOFS) before projection.
            mask_recycle: optional mask multiplied onto the rbf distance features.

        Returns:
            (msa, pair, state) where state is the layer-normalised input state.
        """
        B, L = pair.shape[:2]
        state = self.norm_state(state)

        # broadcast the per-residue state to both sides of every residue pair
        left = state.unsqueeze(2).expand(-1,-1,L,-1)
        right = state.unsqueeze(1).expand(-1,L,-1,-1)

        Ca_or_P = xyz[:,:,1].contiguous()

        dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
        # identity test: `!= None` on a tensor relies on comparison fallback behaviour
        if mask_recycle is not None:
            dist = mask_recycle[...,None].float()*dist
        dist = torch.cat((dist, left, right), dim=-1)
        dist = self.proj_dist(dist)
        pair = dist + self.norm_pair(pair)

        sctors = self.proj_sctors(sctors.reshape(B,-1,2*NTOTALDOFS))
        msa = sctors + self.norm_msa(msa)

        return msa, pair, state
|
||||
|
||||
# Registry mapping a recycling-type name to the recycling module class implementing it.
# Both classes share the forward signature (msa, pair, xyz, state, sctors, mask_recycle=None).
recycling_factory = {
    "msa_pair": Recycling,
    "all": RecyclingAllFeatures
}
|
||||
296
rf2aa/model/layers/SE3_network.py
Normal file
296
rf2aa/model/layers/SE3_network.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from icecream import ic
|
||||
import inspect
|
||||
|
||||
import sys, os
|
||||
#script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
#sys.path.insert(0,script_dir+'SE3Transformer')
|
||||
|
||||
from rf2aa.util import xyz_frame_from_rotation_mask
|
||||
from rf2aa.util_module import init_lecun_normal_param, \
|
||||
make_full_graph, rbf, init_lecun_normal
|
||||
from rf2aa.loss.loss import calc_chiral_grads
|
||||
from rf2aa.model.layers.Attention_module import FeedForwardLayer
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.model.layers.resnet import SCPred
|
||||
|
||||
# Record the source files that SE3Transformer and Fiber were actually imported from.
se3_transformer_path = inspect.getfile(SE3Transformer)
se3_fiber_path = inspect.getfile(Fiber)
# Import-time guard: the SE3Transformer source path must contain 'rf2aa'.
# NOTE: this check disappears when Python runs with -O (asserts are stripped).
assert 'rf2aa' in se3_transformer_path
|
||||
|
||||
class SE3TransformerWrapper(nn.Module):
|
||||
"""SE(3) equivariant GCN with attention"""
|
||||
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
             l0_in_features=32, l0_out_features=32,
             l1_in_features=3, l1_out_features=2,
             num_edge_features=32):
    """Assemble the input/hidden/output/edge fibers and the wrapped SE3Transformer.

    Degree-1 features are only included in the input (resp. output) fiber
    when l1_in_features (resp. l1_out_features) is greater than zero.
    """
    super().__init__()
    # Build the network
    self.l1_in = l1_in_features
    self.l1_out = l1_out_features

    # degree -> channel count for the input and output fibers
    in_degrees = {0: l0_in_features}
    if l1_in_features > 0:
        in_degrees[1] = l1_in_features
    out_degrees = {0: l0_out_features}
    if l1_out_features > 0:
        out_degrees[1] = l1_out_features

    fiber_edge = Fiber({0: num_edge_features})
    fiber_in = Fiber(in_degrees)
    fiber_hidden = Fiber.create(num_degrees, num_channels)
    fiber_out = Fiber(out_degrees)

    self.se3 = SE3Transformer(
        num_layers=num_layers,
        fiber_in=fiber_in,
        fiber_hidden=fiber_hidden,
        fiber_out=fiber_out,
        num_heads=n_heads,
        channels_div=div,
        fiber_edge=fiber_edge,
        populate_edge="arcsin",
        final_layer="lin",
        use_layer_norm=True,
    )

    self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
    """Initialise the parameters of the wrapped SE3Transformer by parameter name.

    - names containing "bias": zeros
    - other 1-D parameters: left untouched
    - names without "radial_func": init_lecun_normal_param
    - "radial_func" names containing "net.6": zeros
    - remaining "radial_func" parameters: kaiming normal (relu)
    Finally the weights of the last graph module are zeroed ('0' always,
    '1' only when there are degree-1 outputs).
    """
    # make sure linear layer before ReLu are initialized with kaiming_normal_
    for name, param in self.se3.named_parameters():
        if "bias" in name:
            nn.init.zeros_(param)
        elif len(param.shape) == 1:
            continue
        elif "radial_func" not in name:
            init_lecun_normal_param(param)
        elif "net.6" in name:
            nn.init.zeros_(param)
        else:
            nn.init.kaiming_normal_(param, nonlinearity='relu')

    # make last layers to be zero-initialized
    final_weights = self.se3.graph_modules[-1].weights
    nn.init.zeros_(final_weights['0'])
    if self.l1_out > 0:
        nn.init.zeros_(final_weights['1'])
|
||||
|
||||
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
    """Pack the features into degree-keyed dicts and call the wrapped network.

    The '1' node entry is only passed when the wrapper was built with
    degree-1 input features (self.l1_in > 0).
    """
    node_features = {'0': type_0_features}
    if self.l1_in > 0:
        node_features['1'] = type_1_features
    return self.se3(G, node_features, {'0': edge_features})
|
||||
|
||||
class FullyConnectedSE3_noR(nn.Module):
|
||||
|
||||
def __init__(self,
    d_msa,
    d_pair,
    d_rbf,
    num_layers,
    num_channels,
    num_degrees,
    n_heads,
    div,
    l0_in_features,
    l0_out_features,
    l1_in_features,
    l1_out_features,
    num_edge_features
):
    """Create the node/edge feature embedders and the wrapped SE(3) transformer.

    d_msa, d_pair and d_rbf size the input layer norms and linear embedders;
    every remaining argument is forwarded unchanged to SE3TransformerWrapper.
    """
    super(FullyConnectedSE3_noR, self).__init__()
    # initial node & pair feature process
    self.norm_msa = nn.LayerNorm(d_msa)
    self.norm_pair = nn.LayerNorm(d_pair)
    # node features: d_msa -> l0_in_features, followed by a residual feed-forward and a norm
    self.embed_node = nn.Linear(d_msa, l0_in_features)
    self.ff_node = FeedForwardLayer(l0_in_features, 2) #HACK: hardcoded value
    self.norm_node = nn.LayerNorm(l0_in_features)

    # edge features: pair features concatenated with rbf features -> num_edge_features
    self.embed_edge = nn.Linear(d_pair+d_rbf, num_edge_features)
    self.ff_edge = FeedForwardLayer(num_edge_features, 2)
    self.norm_edge = nn.LayerNorm(num_edge_features)

    self.se3 = SE3TransformerWrapper(
        num_layers=num_layers,
        num_channels=num_channels,
        num_degrees=num_degrees,
        n_heads=n_heads,
        div=div,
        l0_in_features=l0_in_features,
        l0_out_features=l0_out_features,
        l1_in_features=l1_in_features,
        l1_out_features=l1_out_features,
        num_edge_features=num_edge_features
    )
    # resolves to the subclass override when called from a subclass constructor
    self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
    """Re-initialise the node and edge embedding layers and zero their biases."""
    # initialize weights to normal distribution
    self.embed_node = init_lecun_normal(self.embed_node)
    self.embed_edge = init_lecun_normal(self.embed_edge)

    # initialize bias to zeros
    nn.init.zeros_(self.embed_node.bias)
    nn.init.zeros_(self.embed_edge.bias)
|
||||
|
||||
def embed_node_feats(self, msa, state):
    """Turn row 0 of the MSA into node features.

    `state` is accepted but not used by this implementation.
    """
    query_row = msa[:, 0]
    hidden = self.embed_node(self.norm_msa(query_row))
    # residual feed-forward followed by the final layer norm
    return self.norm_node(hidden + self.ff_node(hidden))
|
||||
|
||||
def embed_edge_feats(self, pair, xyz):
    """Build edge features from the pair features and rbf-encoded distances.

    Distances are taken between atom index 1 of every residue in xyz.
    """
    normed_pair = self.norm_pair(pair)
    anchors = xyz[:, :, 1].contiguous()
    distance_feats = rbf(torch.cdist(anchors, anchors))
    hidden = self.embed_edge(torch.cat((normed_pair, distance_feats), dim=-1))
    # residual feed-forward followed by the final layer norm
    return self.norm_edge(hidden + self.ff_edge(hidden))
|
||||
|
||||
def construct_graph(self, xyz, edge):
    """Build the fully connected graph and its edge features via make_full_graph.

    Returns the (G, edge_feats) pair produced by make_full_graph from the
    coordinates of atom index 1, the edge features and a 0..L-1 index.
    """
    B, L = xyz.shape[:2]
    # arange(L) has L elements, so the reshape to (B, L) fails unless B == 1
    idx = torch.arange(L, device=edge.device).reshape(B, L) # NOTE: only works in B==1
    G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx)
    return G, edge_feats
|
||||
|
||||
def construct_l1_feats(self, xyz, is_atom, atom_frames, chirals):
    """Return the degree-1 input features: chiral vectors restricted to atom index 1.

    is_atom and atom_frames are accepted but not used here.
    """
    # NOTE(review): get_chiral_vectors is not among the imports visible at the top of
    # this module (only calc_chiral_grads is imported from rf2aa.loss.loss) -- confirm
    # it is defined or imported elsewhere, otherwise this raises NameError.
    l1_feats = get_chiral_vectors(xyz[...,:3,:], chirals)[..., 1:2, :] # only pass features from Calpha
    return l1_feats
|
||||
|
||||
def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, is_atom):
    """Run the SE(3) network and translate atom index 1 of every residue.

    Returns (state, xyz_update): the degree-0 output reshaped to (B, L, -1)
    and a copy of xyz whose atom index 1 is shifted by the degree-1 output
    divided by 10. The input xyz is not modified. is_atom is not used.
    """
    n_batch, n_res = xyz.shape[:2]
    se3_out = self.se3(G, node.reshape(n_batch * n_res, -1, 1), l1_feats, edge_feats)

    new_state = se3_out["0"].reshape(n_batch, n_res, -1)
    translation = se3_out["1"].reshape(n_batch, n_res, 3) / 10.0

    moved = xyz.clone()
    moved[..., 1:2, :] = xyz[..., 1:2, :] + translation[..., None, :]
    return new_state, moved
|
||||
|
||||
def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
    """Embed node/edge features, build the graph and apply one structure update.

    Returns the (state, xyz_update) pair from compute_structure_update.
    """
    #TODO: allow these functions to accept kwargs so we can pass
    # different inputs when iterating
    B, N, L = msa.shape[:3]  # unpacked but not used below
    node = self.embed_node_feats(msa, state)
    edge = self.embed_edge_feats(pair, xyz)
    G, edge_feats = self.construct_graph(xyz, edge)
    #TODO: get extra l1 feats automatically and populate the extra l1 dimension
    l1_feats = self.construct_l1_feats(xyz, is_atom, atom_frames, chirals)
    state, xyz_update = self.compute_structure_update(G, node, l1_feats, edge_feats, xyz, is_atom)
    return state, xyz_update
|
||||
|
||||
|
||||
class FullyConnectedSE3(FullyConnectedSE3_noR):
    """SE(3) block that also consumes ``state`` and predicts a rotation.

    Compared with the parent class, node features are embedded from the
    concatenation of the MSA slice and the normalised state, the degree-1
    output is read as a (translation, rotation) pair per residue, and a
    side-chain predictor produces ``alpha`` from the updated state.
    """

    def __init__(self, d_msa, d_pair, d_state, d_rbf, num_layers, num_channels, num_degrees, n_heads, div, l0_in_features, l0_out_features, l1_in_features, l1_out_features, num_edge_features, sc_pred_d_hidden, sc_pred_p_drop):
        super().__init__(
            d_msa,
            d_pair,
            d_rbf,
            num_layers,
            num_channels,
            num_degrees,
            n_heads,
            div,
            l0_in_features,
            l0_out_features,
            l1_in_features,
            l1_out_features,
            num_edge_features,
        )
        # replace the parent's node embedding: input is MSA features + state
        self.embed_node = nn.Linear(d_msa + d_state, l0_in_features)
        self.norm_state = nn.LayerNorm(d_state)

        self.sc_predictor = SCPred(
            d_msa=d_msa,
            d_state=l0_out_features,
            d_hidden=sc_pred_d_hidden,
            p_drop=sc_pred_p_drop,
        )
        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise the embeddings and reset the input LayerNorm scales."""
        embedding_names = ("embed_node", "embed_edge")

        # initialize weights to normal distribution
        for attr in embedding_names:
            setattr(self, attr, init_lecun_normal(getattr(self, attr)))

        # initialize bias to zeros
        for attr in embedding_names:
            nn.init.zeros_(getattr(self, attr).bias)

        for norm in (self.norm_msa, self.norm_pair):
            nn.init.ones_(norm.weight)

    def embed_node_feats(self, msa, state):
        """Embed node features from ``msa[:, 0]`` concatenated with the state."""
        seq_feat = self.norm_msa(msa[:, 0])
        state_feat = self.norm_state(state)
        hidden = self.embed_node(torch.cat((seq_feat, state_feat), dim=-1))
        hidden = hidden + self.ff_node(hidden)
        return self.norm_node(hidden)

    def construct_l1_feats(self, xyz, is_atom, atom_frames, chirals):
        """Concatenate backbone offset vectors and chiral vectors along dim 1."""
        backbone_vecs = get_backbone_offset_vectors(xyz, is_atom, atom_frames)
        chiral_vecs = get_chiral_vectors(xyz, chirals)
        return torch.cat([backbone_vecs, chiral_vecs], dim=1)

    def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, is_atom):
        """Run the SE(3) network and apply a per-residue rigid-body update.

        The degree-1 output holds two 3-vectors per residue: a translation
        (scaled by 1/10) and the vector part of a rotation quaternion
        (scaled by 1/100) whose scalar part is fixed to 1 before
        normalisation. Positions flagged by ``is_atom`` get the identity
        rotation, i.e. they are only translated.

        Returns:
            (state, xyz): the new state (B, L, -1) and the transformed
            coordinates.
        """
        B, L = node.shape[:2]
        se3_out = self.se3(G, node.reshape(B * L, -1, 1), l1_feats, edge_feats)

        state = se3_out["0"].reshape(B, L, -1)
        vectors = se3_out["1"].reshape(B, L, 2, 3)
        T = vectors[:, :, 0, :] / 10
        R = vectors[:, :, 1, :] / 100.0

        # unit quaternion (qA, qB, qC, qD) built from (1, R)
        Qnorm = torch.sqrt(1 + torch.sum(R * R, dim=-1))
        qA = 1 / Qnorm
        qB = R[:, :, 0] / Qnorm
        qC = R[:, :, 1] / Qnorm
        qD = R[:, :, 2] / Qnorm

        # coordinates relative to the index-1 atom, which is the rotation centre
        centre = xyz[:, :, 1:2, :]
        v = xyz - centre

        # quaternion -> rotation matrix, one entry at a time
        entries = {
            (0, 0): qA*qA+qB*qB-qC*qC-qD*qD,
            (0, 1): 2*qB*qC - 2*qA*qD,
            (0, 2): 2*qB*qD + 2*qA*qC,
            (1, 0): 2*qB*qC + 2*qA*qD,
            (1, 1): qA*qA-qB*qB+qC*qC-qD*qD,
            (1, 2): 2*qC*qD - 2*qA*qB,
            (2, 0): 2*qB*qD - 2*qA*qC,
            (2, 1): 2*qC*qD + 2*qA*qB,
            (2, 2): qA*qA-qB*qB-qC*qC+qD*qD,
        }
        Rout = torch.zeros((B, L, 3, 3), device=xyz.device)
        for (row, col), value in entries.items():
            Rout[:, :, row, col] = value

        identity = torch.eye(3, device=Rout.device).expand(B, L, 3, 3)
        Rout = torch.where(is_atom.reshape(B, L, 1, 1), identity, Rout)

        # rotate about the centre, move back, then translate
        xyz = torch.einsum('blij,blaj->blai', Rout, v) + centre + T[:, :, None, :]

        return state, xyz

    def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
        """Apply the structure update and predict ``alpha`` from the new state.

        Returns:
            dict with keys ``"state"``, ``"xyz"`` and ``"alpha"``.
        """
        state, xyz = super().forward(msa, pair, state, xyz, is_atom, atom_frames, chirals)
        outputs = {"state": state, "xyz": xyz}
        outputs["alpha"] = self.sc_predictor(msa[:, 0], state)
        return outputs
|
||||
|
||||
def get_backbone_offset_vectors(xyz, is_atom, atom_frames):
    """Return offsets of the first three frame atoms from the index-1 atom.

    Only the first batch element is returned (index 0 along dim 0).
    """
    frames = xyz_frame_from_rotation_mask(xyz, is_atom, atom_frames)
    centre = frames[:, :, 1, :].unsqueeze(2)
    offsets = frames - centre
    first_batch = offsets[0]
    return first_batch[..., :3, :]
|
||||
|
||||
def get_chiral_vectors(xyz, chirals):
    """Return a detached copy of the chiral gradients for the first batch element.

    ``calc_chiral_grads`` must return exactly one item; its element 0 is
    cloned and detached from the autograd graph.
    """
    (chiral_grads,) = calc_chiral_grads(xyz, chirals)
    return chiral_grads[0].clone().detach()
|
||||
|
||||
37
rf2aa/model/layers/outer_product.py
Normal file
37
rf2aa/model/layers/outer_product.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.util_module import init_lecun_normal
|
||||
|
||||
|
||||
class OuterProductMean(nn.Module):
    """Project an MSA to pair features through a depth-averaged outer product.

    ``p_drop`` is accepted by the constructor but not used by this layer.
    """

    def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
        super().__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden * d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the two input projections; zero the output projection."""
        # normal initialization
        for attr in ("proj_left", "proj_right"):
            setattr(self, attr, init_lecun_normal(getattr(self, attr)))
        for attr in ("proj_left", "proj_right"):
            nn.init.zeros_(getattr(self, attr).bias)

        # zero initialize output
        for tensor in (self.proj_out.weight, self.proj_out.bias):
            nn.init.zeros_(tensor)

    def forward(self, msa):
        """Map ``msa`` (B, N, L, d_msa) to pair features (B, L, L, d_pair)."""
        B, N, L = msa.shape[:3]
        normed = self.norm(msa)
        left = self.proj_left(normed)
        # dividing one factor by N turns the sum over sequences into a mean
        right = self.proj_right(normed) / float(N)
        outer = torch.einsum('bsli,bsmj->blmij', left, right)
        return self.proj_out(outer.reshape(B, L, L, -1))
|
||||
139
rf2aa/model/layers/resnet.py
Normal file
139
rf2aa/model/layers/resnet.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from rf2aa.chemical import NTOTALDOFS
|
||||
from rf2aa.util_module import init_lecun_normal
|
||||
|
||||
# pre-activation bottleneck resblock
|
||||
class ResBlock2D_bottleneck(nn.Module):
    """Pre-activation 2D residual block with a half-width bottleneck.

    Layout: norm/ELU -> 1x1 conv down to ``n_c // 2`` -> norm/ELU ->
    dilated ``kernel`` conv -> norm/ELU -> dropout -> 1x1 conv back to
    ``n_c``. The last conv is zero-initialised, so a freshly built block
    returns its input unchanged.
    """

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super().__init__()
        same_pad = self._get_same_padding(kernel, dilation)
        n_b = n_c // 2  # bottleneck channel

        def norm_act(channels):
            # instance norm followed by an in-place ELU
            return [nn.InstanceNorm2d(channels, affine=True, eps=1e-6), nn.ELU(inplace=True)]

        self.layer = nn.Sequential(
            # pre-activation
            *norm_act(n_c),
            # project down to n_b
            nn.Conv2d(n_c, n_b, 1, bias=False),
            *norm_act(n_b),
            # convolution
            nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=same_pad, bias=False),
            *norm_act(n_b),
            # dropout
            nn.Dropout(p_drop),
            # project up
            nn.Conv2d(n_b, n_c, 1, bias=False),
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Zero the final conv so the residual branch starts at zero."""
        final_conv = self.layer[-1]
        nn.init.zeros_(final_conv.weight)

    def _get_same_padding(self, kernel, dilation):
        """Return the padding for the given kernel size and dilation."""
        dilated_extent = kernel + (kernel - 1) * (dilation - 1)
        return (dilated_extent - 1) // 2

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.layer(x)
        return x + residual
|
||||
|
||||
class ResidualNetwork(nn.Module):
    """Stack of ``ResBlock2D_bottleneck`` blocks with optional 1x1 projections.

    Args:
        n_block: number of residual blocks.
        n_feat_in: channels of the input.
        n_feat_block: channels used inside the residual blocks.
        n_feat_out: channels of the output.
        dilation: sequence of dilation rates; block ``i`` uses
            ``dilation[i % len(dilation)]``.
        p_drop: dropout probability passed to every block.

    A 1x1 convolution is inserted before the blocks only when
    ``n_feat_in != n_feat_block`` and after them only when
    ``n_feat_out != n_feat_block``.
    """

    def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
                 dilation=(1, 2, 4, 8), p_drop=0.15):
        # The default is a tuple rather than a list so that it cannot be
        # mutated and shared between instances.
        super(ResidualNetwork, self).__init__()

        layer_s = list()
        # project to n_feat_block
        if n_feat_in != n_feat_block:
            layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))

        # add resblocks, cycling through the dilation rates
        for i_block in range(n_block):
            d = dilation[i_block % len(dilation)]
            res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
            layer_s.append(res_block)

        if n_feat_out != n_feat_block:
            # project to n_feat_out
            layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))

        self.layer = nn.Sequential(*layer_s)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.layer(x)
|
||||
|
||||
|
||||
#TODO: get rid of this duplicated code, needed now to avoid circular import
|
||||
class SCPred(nn.Module):
    """Predict torsion angles from sequence embeddings and state features."""

    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super().__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 2 * NTOTALDOFS)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise all linear layers (see inline comments for the scheme)."""
        # normal initialization
        plain_layers = ("linear_s0", "linear_si", "linear_out")
        for attr in plain_layers:
            setattr(self, attr, init_lecun_normal(getattr(self, attr)))
        for attr in plain_layers:
            nn.init.zeros_(getattr(self, attr).bias)

        # right before relu activation: He initializer (kaiming normal)
        for layer in (self.linear_1, self.linear_3):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.zeros_(layer.bias)

        # right before residual connection: zero initialize
        for layer in (self.linear_2, self.linear_4):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, seq, state):
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
        '''
        B, L = seq.shape[:2]
        si = self.linear_s0(self.norm_s0(seq)) + self.linear_si(self.norm_si(state))

        # NOTE: F.relu_ works in place, so by the time the sum is formed the
        # left-hand ``si`` already holds the activated values. The statement
        # form below must be kept to preserve that behaviour.
        residual_pairs = (
            (self.linear_1, self.linear_2),
            (self.linear_3, self.linear_4),
        )
        for first, second in residual_pairs:
            si = si + second(F.relu_(first(F.relu_(si))))

        return self.linear_out(F.relu_(si)).view(B, L, NTOTALDOFS, 2)
|
||||
|
||||
62
rf2aa/model/layers/structure_bias.py
Normal file
62
rf2aa/model/layers/structure_bias.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from rf2aa.util_module import rbf, init_lecun_normal
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
from opt_einsum import contract as einsum
|
||||
|
||||
class StructureBias(torch.nn.Module):
    """Pair bias computed from RBF-encoded distances between index-1 atoms."""

    def __init__(self, d_rbf, d_pair) -> None:
        super().__init__()
        self.proj_rbf = nn.Linear(d_rbf, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise the projection weights and zero its bias."""
        projection = init_lecun_normal(self.proj_rbf)
        self.proj_rbf = projection
        nn.init.zeros_(self.proj_rbf.bias)

    def forward(self, xyz):
        """Return the projected RBF features of the pairwise distances."""
        ca_coords = xyz[:, :, 1].contiguous()
        pairwise_dist = torch.cdist(ca_coords, ca_coords)
        return self.proj_rbf(rbf(pairwise_dist))
|
||||
|
||||
|
||||
class GatedStructureBias(torch.nn.Module):
    """Structure bias multiplied by a sigmoid gate derived from the state.

    The gate is computed from the outer product of two linear projections
    of the normalised state.
    """

    def __init__(self, d_rbf, d_state, d_pair, d_hidden_gate) -> None:
        super().__init__()
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_rbf = nn.Linear(d_rbf, d_pair)
        self.proj_left = nn.Linear(d_state, d_hidden_gate)
        self.proj_right = nn.Linear(d_state, d_hidden_gate)
        self.to_gate = nn.Linear(d_hidden_gate * d_hidden_gate, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """No-op: the layers keep the initialisation they were created with."""
        return None

    def forward(self, xyz, state):
        """Return the gated bias of shape (B, L, L, d_pair)."""
        B, L = xyz.shape[:2]

        # distance features between index-1 atoms
        ca_coords = xyz[:, :, 1].contiguous()
        bias = self.proj_rbf(rbf(torch.cdist(ca_coords, ca_coords)))

        # pairwise gate from the per-residue state
        normed_state = self.norm_state(state)
        left = self.proj_left(normed_state)
        right = self.proj_right(normed_state)
        outer = einsum('bli,bmj->blmij', left, right).reshape(B, L, L, -1)
        gate = torch.sigmoid(self.to_gate(outer))

        return gate * bias
|
||||
|
||||
|
||||
# Registry mapping a structure-bias type name to the class implementing it.
structure_bias_factory = dict(
    ungated=StructureBias,
    gated=GatedStructureBias,
)
|
||||
92
rf2aa/model/network.py
Normal file
92
rf2aa/model/network.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.embedding_blocks import embedding_factory
|
||||
from rf2aa.model.refinement_blocks import refinement_factory
|
||||
from rf2aa.model.simulator_blocks import block_factory
|
||||
from rf2aa.model.layers.AuxiliaryPredictor import aux_predictor_factory
|
||||
from rf2aa.training.checkpoint import create_custom_forward
|
||||
from rf2aa.util import is_atom
|
||||
|
||||
|
||||
class RosettaFold(nn.Module):
    """ creates an instance of RosettaFold which includes an embedder, trunk, refinement layers and aux predictor"""

    def __init__(self, config):
        """Build all sub-modules from ``config.model``.

        Args:
            config: configuration object whose ``model`` section provides
                ``global_params``, ``embedding``, ``blocks``, ``refinement``
                and ``auxiliary_predictors``.

        Raises:
            ValueError: if the config does not name exactly one embedder or
                exactly one refinement block, or names an unregistered
                simulator block type.
        """
        super(RosettaFold, self).__init__()
        model_params = config.model

        embedding_type = self._single_key(model_params.embedding, "only can have one embedder")
        self.embedding = embedding_factory[embedding_type](model_params["global_params"], model_params.embedding[embedding_type]["params"])

        ## instantiate blocks of network
        blocks = []
        for block in model_params.blocks.keys():
            if block not in block_factory:
                raise ValueError(f"User specified {block} type, but this block is not registered in rf2aa.model.simulator_blocks.")
            blocks_to_add = [block_factory[block](
                global_config=model_params["global_params"],
                block_params=model_params.blocks[block]["params"])
                for _ in range(model_params.blocks[block]["num_blocks"])]
            blocks.extend(blocks_to_add)

        self.simulator = nn.ModuleList(blocks)

        refinement_type = self._single_key(model_params.refinement, "only can have one refinement block")
        self.refinement = refinement_factory[refinement_type](
            model_params["global_params"],
            model_params.refinement[refinement_type]["params"]
        )

        #HACK: eventually this will just use the correct n_feat from the global config
        aux_tasks = {
            aux_task: aux_predictor_factory[aux_task](
                model_params.auxiliary_predictors[aux_task]["n_feat"])
            for aux_task in model_params.auxiliary_predictors.keys()
        }
        self.auxiliary_predictors = nn.ModuleDict(aux_tasks)
        # name of the latent feature each auxiliary predictor reads
        self.auxiliary_predictor_input_feats = {
            aux_task: model_params.auxiliary_predictors[aux_task]["input_feature"]
            for aux_task in model_params.auxiliary_predictors.keys()
        }

    @staticmethod
    def _single_key(section, message):
        """Return the only key of ``section``; raise ValueError(message) otherwise."""
        keys = list(section.keys())
        if len(keys) != 1:
            # explicit raise instead of assert so the check survives ``python -O``
            raise ValueError(message)
        return keys[0]

    def forward(self, rf_inputs, use_checkpoint):
        """Run embedding, simulator blocks, refinement and auxiliary heads.

        Args:
            rf_inputs: mapping that is passed to the embedder and must also
                provide ``"seq_unmasked"``, ``"atom_frames"``, ``"chirals"``
                and ``"xyz"``.
            use_checkpoint: forwarded to every simulator block.

        Returns:
            (rf_outputs, latent_feats): the refinement output updated with
            one entry per auxiliary task, and the final latent features.
        """
        latent_feats = self.embedding(rf_inputs)
        #load useful primitives into latent_features
        latent_feats.update(
            {
                "is_atom": is_atom(rf_inputs["seq_unmasked"]),
                "atom_frames": rf_inputs["atom_frames"],
                "chirals": rf_inputs["chirals"],
                "xyz": rf_inputs["xyz"]
            }
        )
        for block in self.simulator:
            latent_feats = block(latent_feats, use_checkpoint)

        rf_outputs = self.refinement(latent_feats)

        for aux_task, aux_predictor in self.auxiliary_predictors.items():
            input_feature = self.auxiliary_predictor_input_feats[aux_task]
            auxiliary_predictions = aux_predictor(latent_feats[input_feature])
            rf_outputs.update({aux_task: auxiliary_predictions})

        return rf_outputs, latent_feats
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path='../config/train', config_name='base')
def main(config):
    """Debug entry point: build the model from the hydra config and open pdb."""
    model = RosettaFold(config)
    # NOTE(review): drops into the interactive debugger so ``model`` can be
    # inspected; this call blocks until the debugger session ends.
    import pdb; pdb.set_trace()


if __name__ =="__main__":
    main()
|
||||
172
rf2aa/model/refinement_blocks.py
Normal file
172
rf2aa/model/refinement_blocks.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors
|
||||
from rf2aa.model.Track_module import SCPred
|
||||
from rf2aa.util import NTOTAL, NTOTALDOFS
|
||||
from rf2aa.util_module import rbf, make_topk_graph, init_lecun_normal
|
||||
|
||||
class LocalRefinementSE3(FullyConnectedSE3):
    """``FullyConnectedSE3`` that runs on a top-k neighbour graph.

    Everything except graph construction (embedding, degree-1 features,
    rigid-body update, side-chain prediction, parameter initialisation) is
    inherited from ``FullyConnectedSE3``.

    NOTE: providing a positional / sequence-separation encoding at every
    step is ablated in this implementation.
    """

    def __init__(self, global_config, block_params):
        """Read the hyper-parameters from the two config objects.

        Args:
            global_config: provides ``d_msa``, ``d_pair`` and ``d_state``.
            block_params: provides the SE(3) hyper-parameters, ``top_k``
                and the side-chain predictor settings.
        """
        # read before the parent constructor runs so that a missing key
        # fails before any layer is built
        top_k = block_params.top_k

        super(LocalRefinementSE3, self).__init__(
            global_config.d_msa,
            global_config.d_pair,
            global_config.d_state,
            block_params.d_rbf,
            block_params.n_se3_layers,
            block_params.n_se3_channels,
            block_params.n_se3_degrees,
            block_params.n_se3_head,
            block_params.n_div,
            block_params.l0_in_features,
            block_params.l0_out_features,
            block_params.l1_in_features,
            block_params.l1_out_features,
            block_params.n_se3_edge_features,
            block_params.sc_pred_d_hidden,
            block_params.sc_pred_p_drop,
        )
        self.top_k = top_k
        # The inherited reset_parameter is used; this extra call is kept so
        # that the sequence of random draws during construction is unchanged.
        self.reset_parameter()

    def construct_graph(self, xyz, edge):
        """Build a graph connecting each residue to its ``top_k`` neighbours.

        Neighbours are chosen by ``make_topk_graph`` from the index-1 atom
        coordinates; a single (1, L) index row is passed as ``idx``.
        """
        L = xyz.shape[1]
        idx = torch.arange(L, device=edge.device)[None]
        G, edge_feats = make_topk_graph(xyz[:, :, 1, :], edge, idx, top_k=self.top_k)
        return G, edge_feats
|
||||
|
||||
class RecurrentLocalRefinement(nn.Module):
    """Apply one ``LocalRefinementSE3`` module repeatedly to refine the structure.

    The same module (shared weights) is applied ``num_iterations`` times;
    the state and coordinates produced by one iteration are the input of
    the next one.
    """

    def __init__(self, global_config, block_params):
        super(RecurrentLocalRefinement, self).__init__()
        self.num_iterations = block_params.num_iterations

        self.se3 = LocalRefinementSE3(global_config, block_params)

    def _unpack_inputs(self, latent_feats):
        """Pull the tensors needed by the SE(3) module out of ``latent_feats``."""
        keys = ("msa", "pair", "state", "xyz", "is_atom", "atom_frames", "chirals")
        msa, pair, state, xyz, is_atom, atom_frames, chirals = (latent_feats[k] for k in keys)
        return msa, pair, state, xyz, is_atom, atom_frames, chirals

    def forward(self, latent_feats):
        """Run the refinement iterations.

        Args:
            latent_feats: dict with keys ``msa``, ``pair``, ``state``,
                ``xyz``, ``is_atom``, ``atom_frames`` and ``chirals``.
                ``latent_feats["state"]`` is overwritten with the state of
                the last iteration.

        Returns:
            dict with
            ``"xyzs"``: (num_iterations, L, 3, 3) coordinates per iteration,
            ``"alphas"``: (num_iterations, L, NTOTALDOFS, 2) torsions per
            iteration.

        NOTE(review): the output buffers have no batch dimension, so this
        presumably expects B == 1; feeding the state back also assumes the
        SE(3) output state has the same width as its input state - confirm
        against the config.
        """
        B, N, L = latent_feats["msa"].shape[:3]
        device = latent_feats["msa"].device
        # pre-filled with NaN so an unwritten slot is easy to spot
        xyzs = torch.full((self.num_iterations, L, 3, 3), torch.nan, device=device)
        alphas = torch.full((self.num_iterations, L, NTOTALDOFS, 2), torch.nan, device=device)

        msa, pair, state, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)

        for i in range(self.num_iterations):
            output = self.se3(msa, pair, state, xyz, is_atom, atom_frames, chirals)
            # Feed the refined state and coordinates into the next iteration.
            # Previously the loop kept passing the initial tensors, so every
            # iteration recomputed the same update.
            state = output["state"]
            xyz = output["xyz"]

            xyzs[i] = xyz
            alphas[i] = output["alpha"]
            latent_feats["state"] = state

        return {
            "xyzs": xyzs,
            "alphas": alphas,
        }
|
||||
|
||||
# Registry mapping a refinement type name to the module implementing it.
refinement_factory = dict(
    local=RecurrentLocalRefinement,
)
|
||||
324
rf2aa/model/simulator_blocks.py
Normal file
324
rf2aa/model/simulator_blocks.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from functools import partial
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
||||
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
||||
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
||||
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention
|
||||
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
|
||||
from rf2aa.training.checkpoint import create_custom_forward
|
||||
from rf2aa.util_module import Dropout
|
||||
|
||||
|
||||
class RF2_block(nn.Module):
    """
    nearly faithful implementation of RF2aa blocks in new paradigm
    unfaithful portions are:
    - ablating adding the positional encodings as biases to the attentions
    - the "is_bonded" boolean feature is no longer embedded with the edge features of the SE3 transformer

    One block performs three successive updates on the latent features:
    a 1D (MSA) update, a 2D (pair) update and a 3D (structure) update.
    """

    def __init__(self, global_config, block_params, is_full, **kwargs):
        """Build the layers of the three tracks.

        Args:
            global_config: provides ``d_msa``, ``d_msa_full``, ``d_pair``
                and ``d_state``.
            block_params: per-block hyper-parameters.
            is_full: when true the block reads and writes
                ``latent_feats["msa_full"]`` (width ``d_msa_full``) and uses
                global column attention; otherwise it works on
                ``latent_feats["msa"]``.
            **kwargs: accepted and ignored.
        """
        super(RF2_block, self).__init__()
        d_msa, d_msa_full, d_pair, d_state = global_config.d_msa, global_config.d_msa_full, global_config.d_pair, \
            global_config.d_state
        self.is_full = is_full

        if self.is_full:
            d_msa = d_msa_full

        # MSA update parameters
        self.norm_pair_bias = nn.LayerNorm(d_pair)
        self.norm_state_bias = nn.LayerNorm(d_state)
        self.proj_state_bias = nn.Linear(d_state, d_msa)
        self.msa_str_bias = structure_bias_factory["ungated"](block_params.d_rbf, d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_pair)

        self.msa_row_attn = MSARowAttentionWithBias(
            d_msa=d_msa, d_pair=d_pair, n_head=block_params.n_msa_head, d_hidden=block_params.n_msa_channels)
        if self.is_full:
            self.msa_col_attn = MSAColGlobalAttention(
                d_msa=d_msa,
                n_head=block_params.n_msa_head,
                d_hidden=block_params.n_msa_channels
            )
        else:
            self.msa_col_attn = MSAColAttention(
                d_msa=d_msa,
                n_head=block_params.n_msa_head,
                d_hidden=block_params.n_msa_channels
            )
        self.msa_transition = FeedForwardLayer(d_msa, 4, p_drop=block_params.msa_transition_drop)

        # Pair update parameters
        self.outer_product = OuterProductMean(d_msa, d_pair, d_hidden=block_params.outer_product_channels, \
            p_drop=block_params.p_drop_outer_product)

        self.structure_bias = structure_bias_factory["gated"](block_params.d_rbf, d_state, d_pair, block_params.structure_bias_gate_channels)
        self.tri_mul_outgoing = TriangleMultiplication(d_pair, d_hidden=block_params.n_pair_channels, outgoing=True)
        self.tri_mul_incoming = TriangleMultiplication(d_pair, d_hidden=block_params.n_pair_channels, outgoing=False)
        self.pair_row_attn = BiasedAxialAttention(d_pair, d_pair, block_params.n_pair_head, block_params.n_pair_channels, p_drop=block_params.p_drop_pair, is_row=True)
        self.pair_col_attn = BiasedAxialAttention(d_pair, d_pair, block_params.n_pair_head, block_params.n_pair_channels, p_drop=block_params.p_drop_pair, is_row=False)
        self.pair_transition = FeedForwardLayer(d_pair, 2) # HACK: hardcoded value for transition

        # Structure update parameters
        self.structure_attn = FullyConnectedSE3(d_msa,
            d_pair,
            d_state,
            block_params.d_rbf,
            block_params.n_se3_layers,
            block_params.n_se3_channels,
            block_params.n_se3_degrees,
            block_params.n_se3_head,
            block_params.n_div,
            block_params.l0_in_features,
            block_params.l0_out_features,
            block_params.l1_in_features,
            block_params.l1_out_features,
            block_params.n_se3_edge_features,
            block_params.sc_pred_d_hidden,
            block_params.sc_pred_p_drop
        )

    def _unpack_inputs(self, latent_feats):
        """Fetch the block inputs; ``xyz`` is cut to its first three atoms."""
        pair, state, xyz, is_atom, atom_frames, chirals = \
            latent_feats["pair"], latent_feats["state"], \
            latent_feats["xyz"], latent_feats["is_atom"], \
            latent_feats["atom_frames"], latent_feats["chirals"]
        if self.is_full:
            msa = latent_feats["msa_full"]
        else:
            msa = latent_feats["msa"]
        return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals

    def _pack_outputs(self, msa, pair, state, xyz, alpha, latent_feats):
        """Write the block outputs back into ``latent_feats`` and return it.

        ``xyz`` and ``alpha`` are additionally appended to the
        ``"xyz_intermediate"`` / ``"alpha_intermediate"`` lists, which are
        created on first use.
        """
        if self.is_full:
            latent_feats["msa_full"] = msa
        else:
            latent_feats["msa"] = msa
        latent_feats["pair"] = pair
        latent_feats["state"] = state
        latent_feats["xyz"] = xyz
        #HACK: appending to growing list, this could cause weird memory problems in pytorch
        # eventually want to refactor this to make it more elegant
        if "xyz_intermediate" not in latent_feats:
            latent_feats["xyz_intermediate"] = [xyz]
        else:
            latent_feats["xyz_intermediate"].append(xyz)

        if "alpha_intermediate" not in latent_feats:
            latent_feats["alpha_intermediate"] = [alpha]
        else:
            latent_feats["alpha_intermediate"].append(alpha)
        return latent_feats

    def _1d_update(self, msa, pair, state, xyz):
        """Update the MSA using pair, structure and state information.

        The projected state is added to the MSA entry at index 0 along
        dim 1 only; row attention is biased by the normalised pair features
        plus the ungated structure bias.
        """
        pair = self.norm_pair_bias(pair)
        pair = pair + self.msa_str_bias(xyz)

        state = self.norm_state_bias(state)
        state_update = self.proj_state_bias(state)

        msa = msa.type_as(state_update)
        # index_add along dim 1 needs a source of shape (B, 1, L, d_msa).
        # unsqueeze(1) gives that for every batch size; the previous
        # ``state_update[None]`` only matched when B == 1.
        msa = msa.index_add(1, torch.tensor([0,], device=state_update.device), state_update.unsqueeze(1))
        msa = msa + self.drop_row(self.msa_row_attn(msa, pair))
        msa = msa + self.msa_col_attn(msa)
        msa = msa + self.msa_transition(msa)
        return msa

    def _2d_update(self, msa, pair, state, xyz):
        """Update the pair features from the MSA and the gated structure bias."""
        msa_bias = self.outer_product(msa)
        pair = pair + msa_bias
        str_bias = self.structure_bias(xyz, state)
        pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
        pair = pair + self.drop_row(self.tri_mul_incoming(pair))
        pair = pair + self.drop_row(self.pair_row_attn(pair, str_bias))
        pair = pair + self.drop_col(self.pair_col_attn(pair, str_bias))
        pair = pair + self.pair_transition(pair)
        return pair

    def _3d_update(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
        """Run the SE(3) structure module; returns (state, xyz, alpha)."""
        block_outputs = self.structure_attn(msa, pair, state, xyz, is_atom, atom_frames, chirals)
        return block_outputs["state"], block_outputs["xyz"], block_outputs["alpha"]

    def forward(self, latent_feats, use_checkpoint):
        """Apply the 1D, 2D and 3D updates in order.

        Args:
            latent_feats: dict of latent features; updated in place.
            use_checkpoint: when true each update is wrapped in
                ``torch.utils.checkpoint``.

        Returns:
            The updated ``latent_feats`` dict.
        """
        msa, pair, state, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)
        if use_checkpoint:
            msa = checkpoint.checkpoint(create_custom_forward(self._1d_update), msa, pair, state, xyz)
            pair = checkpoint.checkpoint(create_custom_forward(self._2d_update), msa, pair, state, xyz)
            # 3D track cannot use re-entrant = False because of chiral features call to autograd
            #TODO: allow this to happen since new versions of Pytorch will be using reentrant=False
            state, xyz, alpha = checkpoint.checkpoint(create_custom_forward(self._3d_update), \
                msa, pair, state, xyz, is_atom, atom_frames, chirals)
        else:
            msa = self._1d_update(msa, pair, state, xyz)
            pair = self._2d_update(msa, pair, state, xyz)
            state, xyz, alpha = self._3d_update(msa, pair, state, xyz, is_atom, atom_frames, chirals)
        latent_feats = self._pack_outputs(msa, pair, state, xyz, alpha, latent_feats)
        return latent_feats
|
||||
|
||||
|
||||
class RF2_withgradients(nn.Module):
    """
    Updated version of the RF2 block, without computing rotations,
    to allow gradients to flow through all blocks.

    A forward pass runs a 1D (MSA) update, a 2D (pair) update and a 3D
    (structure) update on the entries of ``latent_feats``. With
    ``is_full=True`` the block reads and writes ``latent_feats["msa_full"]``,
    sizes its MSA layers with ``global_config.d_msa_full`` and uses
    ``MSAColGlobalAttention``; otherwise it works on ``latent_feats["msa"]``
    with ``global_config.d_msa`` and ``MSAColAttention``.
    """

    def __init__(self, global_config=None, block_params=None, is_full=False, **kwargs) -> None:
        """Build all sub-modules.

        Args:
            global_config: object exposing ``d_msa``, ``d_msa_full`` and ``d_pair``.
            block_params: object exposing the per-block hyper-parameters read below.
            is_full: selects the "full MSA" variant (see class docstring).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        d_msa, d_msa_full, d_pair = global_config.d_msa, global_config.d_msa_full, global_config.d_pair
        self.is_full = is_full
        if self.is_full:
            d_msa = d_msa_full

        # --- MSA track ---
        msa_attn_kwargs = dict(
            d_msa=d_msa,
            n_head=block_params.n_msa_head,
            d_hidden=block_params.n_msa_channels,
        )
        self.msa_row_attn = MSARowAttentionWithBias(d_pair=d_pair, **msa_attn_kwargs)
        if self.is_full:
            self.msa_col_attn = MSAColGlobalAttention(**msa_attn_kwargs)
        else:
            self.msa_col_attn = MSAColAttention(**msa_attn_kwargs)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_col)
        self.msa_transition = FeedForwardLayer(d_msa, r_ff=4)

        # --- pair track ---
        self.compute_structure_bias = structure_bias_factory[block_params.structure_bias_type](
            d_rbf=block_params.structure_bias_channels,
            d_pair=d_pair,
        )
        axial_kwargs = dict(
            d_pair=d_pair,
            d_bias=d_pair,
            n_head=block_params.n_pair_head,
            d_hidden=block_params.n_pair_channels,
        )
        self.pair_row_attn = BiasedAxialAttention(is_row=True, **axial_kwargs)
        self.pair_col_attn = BiasedAxialAttention(is_row=False, **axial_kwargs)
        triangle_kwargs = dict(d_pair=d_pair, d_hidden=block_params.n_pair_channels)
        self.tri_mult_incoming = TriangleMultiplication(outgoing=False, **triangle_kwargs)
        self.tri_mult_outgoing = TriangleMultiplication(outgoing=True, **triangle_kwargs)
        self.pair_transition = FeedForwardLayer(d_pair, r_ff=4)
        self.outer_product = OuterProductMean(
            d_msa, d_pair, d_hidden=block_params.outer_product_channels
        )

        # --- structure track ---
        self.structure_attn = FullyConnectedSE3_noR(
            d_msa=d_msa,
            d_pair=d_pair,
            d_rbf=block_params.structure_bias_channels,
            num_layers=block_params.n_se3_layers,
            num_channels=block_params.n_se3_channels,
            num_degrees=block_params.n_se3_degrees,
            n_heads=block_params.n_se3_head,
            div=block_params.n_div,
            l0_in_features=block_params.l0_in_features,
            l0_out_features=block_params.l0_out_features,
            l1_in_features=block_params.l1_in_features,
            l1_out_features=block_params.l1_out_features,
            num_edge_features=block_params.n_se3_edge_features,
        )
        self.structure_transition = FeedForwardLayer(block_params.l0_out_features, r_ff=4)
        self.proj_state = nn.Linear(block_params.l0_out_features, d_msa)
        self.reset_parameter()

    def reset_parameter(self):
        """Hook for custom initialisation; currently does nothing."""
        pass

    def _unpack_inputs(self, latent_feats):
        """Pull the tensors this block needs out of ``latent_feats``.

        Returns ``(msa, pair, xyz, is_atom, atom_frames, chirals)`` where
        ``xyz`` is sliced to ``xyz[..., :3, :]`` and ``msa`` comes from the
        ``"msa_full"`` or ``"msa"`` entry depending on ``self.is_full``.
        """
        pair = latent_feats["pair"]
        xyz = latent_feats["xyz"]
        is_atom = latent_feats["is_atom"]
        atom_frames = latent_feats["atom_frames"]
        chirals = latent_feats["chirals"]
        msa = latent_feats["msa_full" if self.is_full else "msa"]
        return msa, pair, xyz[..., :3, :], is_atom, atom_frames, chirals

    def _pack_outputs(self, msa, pair, state, xyz, latent_feats):
        """Write the updated tensors back into ``latent_feats`` and return it."""
        latent_feats["msa_full" if self.is_full else "msa"] = msa
        latent_feats["pair"] = pair
        latent_feats["state"] = state
        latent_feats["xyz"] = xyz
        return latent_feats

    def _1d_update(self, msa, pair):
        """Update the MSA, then add its outer product to ``pair``.

        Returns ``(msa, pair)``.
        """
        row_update = self.msa_row_attn(msa, pair)
        msa = msa + self.drop_row(row_update)
        col_update = self.msa_col_attn(msa)
        msa = msa + self.drop_col(col_update)
        msa = msa + self.msa_transition(msa)

        msa_bias = self.outer_product(msa)
        return msa, pair + msa_bias

    def _2d_update(self, pair, xyz):
        """Update the pair features using a bias computed from ``xyz``."""
        # break 3d symmetries with bias from coordinates
        structure_bias = self.compute_structure_bias(xyz)
        row_update = self.pair_row_attn(pair, structure_bias)
        pair = pair + self.drop_row(row_update)
        col_update = self.pair_col_attn(pair, structure_bias)
        pair = pair + self.drop_col(col_update)

        # provide triangle inductive bias
        pair = pair + self.drop_row(self.tri_mult_outgoing(pair))
        pair = pair + self.drop_row(self.tri_mult_incoming(pair))
        return pair + self.pair_transition(pair)

    def _3d_update(self, msa, pair, state, xyz, is_atom, atom_frames, chirals):
        """Run structure attention and feed the new state back into the MSA.

        Returns ``(msa, state, xyz)``.
        """
        # apply structure attention and update seq features
        state, xyz = self.structure_attn(msa, pair, state, xyz, is_atom, atom_frames, chirals)
        state = state + self.structure_transition(state)

        # state features bias the msa first row features
        state_update = self.proj_state(state)
        msa = msa.type_as(state_update)
        first_row = torch.tensor([0,], device=state_update.device)
        msa = msa.index_add(1, first_row, state_update.unsqueeze(1))
        return msa, state, xyz

    def forward(self, latent_feats, use_checkpoint):
        """Apply the 1D, 2D and 3D updates and return the updated ``latent_feats``.

        When ``use_checkpoint`` is truthy each update is routed through
        ``checkpoint.checkpoint(create_custom_forward(...))``. ``state`` is
        passed to the 3D update as ``None``.
        """
        msa, pair, xyz, is_atom, atom_frames, chirals = self._unpack_inputs(latent_feats)
        state = None

        if use_checkpoint:
            def run(update, *inputs):
                return checkpoint.checkpoint(create_custom_forward(update), *inputs)
        else:
            def run(update, *inputs):
                return update(*inputs)

        msa, pair = run(self._1d_update, msa, pair)
        pair = run(self._2d_update, pair, xyz)
        # 3D track cannot use re-entrant = False because of chiral features call to autograd
        # TODO: allow this to happen since new versions of Pytorch will be using reentrant=False
        msa, state, xyz = run(
            self._3d_update, msa, pair, state, xyz, is_atom, atom_frames, chirals
        )

        return self._pack_outputs(msa, pair, state, xyz, latent_feats)
|
||||
|
||||
|
||||
|
||||
# Registry mapping a block name to a constructor with ``is_full`` already bound.
block_factory = {
    name: partial(block_cls, is_full=is_full)
    for name, block_cls, is_full in (
        ("RF2_withgradients", RF2_withgradients, False),
        ("RF2_withgradients_full", RF2_withgradients, True),
        ("RF2aa", RF2_block, False),
        ("RF2aa_full", RF2_block, True),
    )
}
|
||||
@@ -1,72 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
# pre-activation bottleneck resblock
|
||||
class ResBlock2D_bottleneck(nn.Module):
    """Pre-activation bottleneck residual block for 2D feature maps.

    Layer stack: norm/ELU, 1x1 conv down to ``n_c // 2`` channels, norm/ELU,
    ``kernel`` x ``kernel`` dilated conv, norm/ELU, dropout, 1x1 conv back up
    to ``n_c`` channels. The output is ``x + stack(x)``.
    """

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super().__init__()
        n_b = n_c // 2  # bottleneck channel
        padding = self._get_same_padding(kernel, dilation)

        def norm_act(channels):
            # pre-activation pair: instance norm followed by in-place ELU
            return [
                nn.InstanceNorm2d(channels, affine=True, eps=1e-6),
                nn.ELU(inplace=True),
            ]

        self.layer = nn.Sequential(
            *norm_act(n_c),
            # project down to n_b
            nn.Conv2d(n_c, n_b, 1, bias=False),
            *norm_act(n_b),
            # convolution
            nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False),
            *norm_act(n_b),
            # dropout
            nn.Dropout(p_drop),
            # project up
            nn.Conv2d(n_b, n_c, 1, bias=False),
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Zero the last conv's weights, the layer right before the residual sum."""
        nn.init.zeros_(self.layer[-1].weight)

    def _get_same_padding(self, kernel, dilation):
        """Return the padding computed from the dilated kernel extent."""
        effective_kernel = kernel + (kernel - 1) * (dilation - 1)
        return (effective_kernel - 1) // 2

    def forward(self, x):
        """Return ``x`` plus the output of the layer stack."""
        return x + self.layer(x)
|
||||
|
||||
class ResidualNetwork(nn.Module):
    """Stack of ``ResBlock2D_bottleneck`` blocks with optional 1x1 projections.

    Args:
        n_block: number of residual blocks.
        n_feat_in: channels of the input.
        n_feat_block: channels used inside the residual blocks. A bias-free
            1x1 conv is inserted first when it differs from ``n_feat_in``.
        n_feat_out: channels of the output. A 1x1 conv is appended when it
            differs from ``n_feat_block``.
        dilation: sequence of dilation rates, cycled over the blocks. The
            default is an immutable tuple so it cannot be shared and mutated
            across instances; any indexable sequence is accepted.
        p_drop: dropout probability passed to every residual block.
    """

    def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
                 dilation=(1, 2, 4, 8), p_drop=0.15):
        super(ResidualNetwork, self).__init__()

        layer_s = []
        # project to n_feat_block
        if n_feat_in != n_feat_block:
            layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))

        # add resblocks, cycling through the dilation rates
        for i_block in range(n_block):
            d = dilation[i_block % len(dilation)]
            layer_s.append(
                ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
            )

        if n_feat_out != n_feat_block:
            # project to n_feat_out
            layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))

        self.layer = nn.Sequential(*layer_s)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layer(x)
|
||||
@@ -1,180 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
|
||||
|
||||
#def get_cosine_with_hard_restarts_schedule_with_warmup(
|
||||
# optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
||||
#):
|
||||
# """
|
||||
# Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
# initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
||||
# linearly between 0 and the initial lr set in the optimizer.
|
||||
#
|
||||
# Args:
|
||||
# optimizer (:class:`~torch.optim.Optimizer`):
|
||||
# The optimizer for which to schedule the learning rate.
|
||||
# num_warmup_steps (:obj:`int`):
|
||||
# The number of steps for the warmup phase.
|
||||
# num_training_steps (:obj:`int`):
|
||||
# The total number of training steps.
|
||||
# num_cycles (:obj:`int`, `optional`, defaults to 1):
|
||||
# The number of hard restarts to use.
|
||||
# last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||
# The index of the last epoch when resuming training.
|
||||
#
|
||||
# Return:
|
||||
# :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
# """
|
||||
#
|
||||
# def lr_lambda(current_step):
|
||||
# if current_step < num_warmup_steps:
|
||||
# return float(current_step) / float(max(1, num_warmup_steps))
|
||||
# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
||||
# if progress >= 1.0:
|
||||
# return 0.0
|
||||
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
||||
#
|
||||
# return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
#
|
||||
|
||||
class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within a cycle the rate rises linearly from ``min_lr`` to the cycle's
    maximum over ``warmup_steps`` steps, then follows a half cosine back down
    towards ``min_lr``. After each cycle the maximum is multiplied by
    ``gamma`` and the non-warmup part of the cycle length by ``cycle_mult``.

    Args:
        optimizer (Optimizer): Wrapped optimizer. Its learning rates are
            overwritten with ``min_lr`` on construction.
        first_cycle_steps (int): First cycle step size.
        cycle_mult (float): Cycle steps magnification. Default: 1.
        max_lr (float): First cycle's max learning rate. Default: 0.1.
        min_lr (float): Min learning rate. Default: 0.001.
        warmup_steps (int): Linear warmup step size; must be smaller than
            ``first_cycle_steps``. Default: 0.
        gamma (float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        assert warmup_steps < first_cycle_steps

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult  # cycle steps magnification
        self.base_max_lr = max_lr  # first max learning rate
        self.max_lr = max_lr  # max learning rate in the current cycle
        self.min_lr = min_lr  # min learning rate
        self.warmup_steps = warmup_steps  # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps  # length of the current cycle
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch  # position inside the current cycle

        # All attributes read by step()/get_lr() must exist before this call.
        super().__init__(optimizer, last_epoch)

        # set learning rate min_lr
        self.init_lr()

    def init_lr(self):
        """Set every param group's lr, and ``self.base_lrs``, to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        """Return one learning rate per param group for the current position."""
        if self.step_in_cycle == -1:
            return self.base_lrs

        if self.step_in_cycle < self.warmup_steps:
            # linear ramp from base_lr up to max_lr
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr
                for base_lr in self.base_lrs
            ]

        # half-cosine decay from max_lr back to base_lr
        elapsed = self.step_in_cycle - self.warmup_steps
        span = self.cur_cycle_steps - self.warmup_steps
        return [
            base_lr + (self.max_lr - base_lr) * (1 + math.cos(math.pi * elapsed / span)) / 2
            for base_lr in self.base_lrs
        ]

    def step(self, epoch=None):
        """Advance the schedule and write the new rates into the optimizer.

        With ``epoch=None`` the schedule moves forward by one step and starts
        a new cycle when the current one is exhausted. With an explicit
        ``epoch`` the cycle index and the position inside the cycle are
        recomputed from that value.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # restart: only the non-warmup part of the cycle is scaled
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # cycle lengths form a geometric series; solve for the cycle index
                n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
|
||||
|
||||
|
||||
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_ratio=0.001, last_epoch=-1):
    """
    Create a schedule whose learning-rate multiplier increases linearly from 0 to 1 during warmup and then
    decreases linearly, never dropping below ``min_ratio``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps; the linear decay would reach 0 at this step.
        min_ratio (:obj:`float`, `optional`, defaults to 0.001):
            Lower bound on the multiplier after warmup, as a fraction of the initial lr.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # linear decay over the post-warmup steps, clamped from below
        return max(
            min_ratio, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
def get_stepwise_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_steps_decay, decay_rate, last_epoch=-1):
    """
    Create a schedule whose learning-rate multiplier increases linearly from 0 to 1 during warmup and is then
    multiplied by ``decay_rate`` once every ``num_steps_decay`` steps.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_steps_decay (:obj:`int`):
            The number of post-warmup steps between two successive decays.
        decay_rate (:obj:`float`):
            The factor applied to the multiplier at each decay.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # number of completed decay intervals since the end of warmup
        num_fades = (current_step - num_warmup_steps) // num_steps_decay
        return decay_rate ** num_fades

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
218
rf2aa/scripts/generate_sample_lengths.py
Normal file
218
rf2aa/scripts/generate_sample_lengths.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import sys
|
||||
sys.path.append('../')
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils import data
|
||||
from tqdm import tqdm
|
||||
from rf2aa.data.data_loader import default_dataloader_params, loader_tf_complex, loader_distil_tf, get_train_valid_set, DistilledDataset, loader_pdb, loader_complex, loader_na_complex, loader_fb, loader_dna_rna, loader_sm_compl_assembly_single, loader_sm_compl_assembly, loader_sm, loader_atomize_pdb, loader_atomize_complex
|
||||
from rf2aa.chemical import load_pdb_ideal_sdf_strings
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
class OrderedSampler(torch.utils.data.sampler.Sampler):
    """
    Custom sampler that yields a fixed sequence of dataset indices, in order.

    Args:
        indices: sequence of dataset indices to visit.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # Lets len() work on this sampler and on a DataLoader that uses it.
        return len(self.indices)
|
||||
|
||||
def parse_arguments():
    """Parse the command-line options of this script and return the namespace."""
    parser = argparse.ArgumentParser(description='Process crop size and MSA limit.')
    # (flag, type, help text, default)
    option_specs = (
        ('--crop-size', int, 'Crop size for the data loader', 256),
        ('--msa-limit', int, 'MSA limit for the data loader', None),
        ('--output', str, 'Output file path for the lengths', 'sample_lengths.pt'),
        ('--num-samples', int, 'Number of samples for the loop', None),
        ('--num-workers', int, 'Number of DataLoader workers', 4),
    )
    for flag, value_type, help_text, default in option_specs:
        parser.add_argument(flag, type=value_type, help=help_text, default=default)
    return parser.parse_args()
|
||||
|
||||
def seed_everything(seed=42):
    """Seed the ``random``, NumPy and PyTorch random number generators.

    Args:
        seed: value passed to all three generators. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
||||
|
||||
def build_dataset(crop_size, msa_limit):
    """
    Build the training DistilledDataset, overriding the default crop size and
    MSA limit when truthy values are given.
    """
    # NOTE: this is the imported default dict itself, not a copy, so the
    # overrides below are written into `default_dataloader_params`.
    loader_param = default_dataloader_params
    for param_name, override in (('CROP', crop_size), ('MSA_LIMIT', msa_limit)):
        loader_param[param_name] = override or loader_param.get(param_name)

    # Build the train/valid tables
    (
        train_ID_dict,
        valid_ID_dict,
        weights_dict,
        train_dict,
        valid_dict,
        homo,
        chid2hash,
        chid2taxid,
        chid2smpartners,
    ) = get_train_valid_set(loader_param)

    # The atomize_pdb sets reuse the examples of the pdb set, and the
    # atomize_complex sets reuse the examples of the compl set.
    split_tables = (train_ID_dict, valid_ID_dict, weights_dict, train_dict, valid_dict)
    for alias, source in (('atomize_pdb', 'pdb'), ('atomize_complex', 'compl')):
        for table in split_tables:
            table[alias] = table[source]

    # Assign loaders to each dataset
    loader_dict = {
        'pdb': loader_pdb,
        'peptide': loader_pdb,
        'compl': loader_complex,
        'neg_compl': loader_complex,
        'na_compl': loader_na_complex,
        'neg_na_compl': loader_na_complex,
        'distil_tf': loader_distil_tf,
        'tf': loader_tf_complex,
        'neg_tf': loader_tf_complex,
        'fb': loader_fb,
        'rna': loader_dna_rna,
        'dna': loader_dna_rna,
        'sm_compl': loader_sm_compl_assembly_single,
        'metal_compl': loader_sm_compl_assembly_single,
        'sm_compl_multi': loader_sm_compl_assembly_single,
        'sm_compl_covale': loader_sm_compl_assembly_single,
        'sm_compl_asmb': loader_sm_compl_assembly,
        'sm': loader_sm,
        'atomize_pdb': loader_atomize_pdb,
        'atomize_complex': loader_atomize_complex,
        'sm_compl_furthest_neg': loader_sm_compl_assembly,
        'sm_compl_permuted_neg': loader_sm_compl_assembly,
        'sm_compl_docked_neg': loader_sm_compl_assembly,
    }

    # Get ligand dictionary. This is used for loading negative examples.
    ligand_dictionary = load_pdb_ideal_sdf_strings(return_only_sdf_strings=True)

    # Build dataset
    return DistilledDataset(
        train_ID_dict,
        train_dict,
        loader_dict,
        homo,
        chid2hash,
        chid2taxid,
        chid2smpartners,
        loader_param,
        native_NA_frac=0.25,
        ligand_dictionary=ligand_dictionary,
    )
|
||||
|
||||
def main(crop_size, msa_limit, output_file, num_samples, num_workers):
    """Compute one length per training example and save them with ``torch.save``.

    Datasets listed as fixed-length take their lengths from the dataset
    tables; datasets listed as variable-length are run through a DataLoader
    and measured from the loaded batch. Examples of datasets in neither list
    keep a length of zero.

    Args:
        crop_size: crop size forwarded to ``build_dataset``.
        msa_limit: MSA limit forwarded to ``build_dataset``.
        output_file: path the length tensor is saved to.
        num_samples: if truthy, only this many randomly chosen
            variable-length examples are measured.
        num_workers: number of DataLoader workers.
    """
    print(f"Building dataset with CROP_SIZE={crop_size} and MSA_LIMIT={msa_limit}...")
    # Setup the dataset
    train_set = build_dataset(crop_size, msa_limit)

    # Datasets whose lengths are read from the dataset tables
    fixed_length_datasets = ['pdb', 'fb']

    # Datasets where we pass the sample through the DataLoader to measure the length
    variable_length_datasets = ["compl", "neg_compl", "na_compl", "neg_na_compl", "distil_tf","tf","neg_tf","rna","dna", "sm_compl", "metal_compl", "sm_compl_multi", "sm_compl_covale", "sm_compl_asmb", "sm", "sm_compl_docked_neg", "sm_compl_permuted_neg", "sm_compl_furthest_neg", "atomize_pdb", "atomize_complex"]

    # Assert that the monomer and variable dataset lists do not overlap
    assert set(fixed_length_datasets).isdisjoint(variable_length_datasets), "Monomer and variable datasets should not overlap"

    # One slot per example in the dataset, initialised to zero
    lengths = torch.zeros(len(train_set))

    offset = 0
    indices_to_process = []
    for key, index_list in train_set.index_dict.items():
        # Shift the per-dataset indices by the running offset
        adjusted_index_list = [offset + local_index for local_index in index_list]

        # Either queue the indices for the DataLoader pass or fill in lengths directly
        if key in variable_length_datasets:
            indices_to_process.extend(adjusted_index_list)
        elif key in fixed_length_datasets:
            table = train_set.dataset_dict[key]
            if key in ['pdb', 'fb']:
                table['LENGTH'] = table["SEQUENCE"].apply(len)
            else:
                # Not reachable while fixed_length_datasets only holds 'pdb' and 'fb'
                table['LENGTH'] = table["LEN_EXIST"]

            # Keep one row per cluster and map cluster ID -> length
            unique_clusters = train_set.dataset_dict[key].drop_duplicates(subset=['CLUSTER'])
            cluster_length_map = dict(zip(unique_clusters['CLUSTER'], unique_clusters['LENGTH']))

            # Look up the length of every ID of this dataset
            fixed_lengths_processed = [
                cluster_length_map.get(cluster_id) for cluster_id in train_set.ID_dict[key]
            ]

            # Write the lengths into the slots of this dataset
            lengths[adjusted_index_list] = torch.tensor(fixed_lengths_processed, dtype=torch.float)

        offset += len(index_list)

    if num_samples:
        print(f"Sampling {num_samples} examples...")
        np.random.shuffle(indices_to_process)
        indices_to_process = indices_to_process[:num_samples]
    else:
        num_samples = len(indices_to_process)
    sampler = OrderedSampler(indices_to_process)

    print("num_workers:", num_workers)

    # Create the training loader
    train_loader = data.DataLoader(
        train_set,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        batch_size=1,
    )

    print(f"Looping through {num_samples} batches.")
    print("First ten indices:", indices_to_process[:10])
    print("Last ten indices:", indices_to_process[-10:])
    for batch_idx, batch in tqdm(enumerate(train_loader), total=num_samples, desc="Processing batches"):
        try:
            # The length is read from the last dimension of the first tensor
            length = batch[0].shape[-1]

            # Store the length in a slot corresponding to the absolute index of the example
            lengths[indices_to_process[batch_idx]] = length
        except Exception as e:
            print(f"An error occurred while processing batch {batch_idx}: {e}")

        # Break if limiting examples
        if batch_idx >= num_samples - 1:
            break

    # Save the length tensor
    print(f"Saving lengths as a tensor to {output_file}...")
    torch.save(lengths, output_file)
|
||||
|
||||
if __name__ == "__main__":
    # Parse the CLI first, then fix the RNG seeds before any sampling happens.
    args = parse_arguments()
    seed_everything()
    main(
        crop_size=args.crop_size,
        msa_limit=args.msa_limit,
        output_file=args.output,
        num_samples=args.num_samples,
        num_workers=args.num_workers,
    )
|
||||
@@ -1,91 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import submitit
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def add_slurm_args(parser: Optional[argparse.ArgumentParser] = None, prefix: str = "-") -> argparse.ArgumentParser:
    """Register the slurm submission options on ``parser`` and return it.

    Args:
        parser: parser to extend; a new one is created when ``None``.
        prefix: string placed in front of every option name.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            description="Submits a job to the digs cluster via submitit"
        )

    # (option name, keyword arguments for add_argument), in registration order
    option_specs = (
        ("slurm_log_path", dict(
            default="/home/psturm/RF2-allatom/slurm_logs",
            type=str,
            help="Path where slurm logs will go.",
        )),
        ("local", dict(
            action="store_true",
            help="Set to true to train locally rather than submitting to slurm",
        )),
        ("slurm_partition", dict(
            type=str,
            default="gpu",
            help="Slurm partition to run job on",
        )),
        ("gpu_type", dict(
            type=str,
            default="a6000",
            help="Which gpus to run on, slurm gres constraint",
        )),
        ("cpu_memory", dict(
            type=int,
            default=64,
            help="Amount of cpu job memory to request for slurm submission",
        )),
        ("cpus_per_task", dict(
            type=int,
            default=4,
            help="Number of cpu cores to request for slurm submission",
        )),
        ("timeout_min", dict(
            type=int,
            default=10000,
            help="Maximum number of minutes for slurm job to run",
        )),
        ("max_slurm_jobs_at_once", dict(
            type=int,
            default=16,
            help="Maximum number of array jobs to run at once",
        )),
        ("num_gpus", dict(
            type=int,
            default=4,
            help="Number of GPUs to train on",
        )),
        ("nodes", dict(
            type=int,
            default=1,
            help="Number of nodes to submit to",
        )),
    )
    for option_name, options in option_specs:
        parser.add_argument(f"{prefix}{option_name}", **options)
    return parser
|
||||
|
||||
|
||||
def create_executor(args: argparse.Namespace, log_folder: str, job_name: str) -> submitit.AutoExecutor:
    """Create a submitit executor configured from the parsed slurm options.

    Args:
        args: parsed options; the attributes ``slurm_partition``,
            ``cpu_memory``, ``cpus_per_task``, ``num_gpus``,
            ``max_slurm_jobs_at_once``, ``nodes``, ``timeout_min`` and
            ``gpu_type`` are read.
        log_folder: folder passed to ``submitit.AutoExecutor``.
        job_name: slurm job name.

    Returns:
        The configured executor.
    """
    executor = submitit.AutoExecutor(folder=log_folder)
    executor.update_parameters(
        slurm_partition=args.slurm_partition,
        slurm_mem=f"{args.cpu_memory}gb",
        slurm_job_name=job_name,
        cpus_per_task=args.cpus_per_task,
        slurm_ntasks_per_node=args.num_gpus,
        slurm_array_parallelism=args.max_slurm_jobs_at_once,
        nodes=args.nodes,
        timeout_min=args.timeout_min,
    )
    # The gres constraint is only set when a GPU type was requested.
    if args.gpu_type != "none":
        executor.update_parameters(
            slurm_gres=f"gpu:{args.gpu_type}:{args.num_gpus}",
        )

    return executor
|
||||
220
rf2aa/test_inference.py
Normal file
220
rf2aa/test_inference.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest, random, os, sys
|
||||
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(script_dir)
|
||||
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
|
||||
import rf2aa.data.parsers as parsers
|
||||
import rf2aa.util as util
|
||||
from rf2aa.chemical import load_pdb_ideal_sdf_strings, NTOTALDOFS
|
||||
from rf2aa.data.data_loader import MSAFeaturize, TemplFeaturize, \
|
||||
center_and_realign_missing, generate_xyz_prev, get_bond_distances
|
||||
from rf2aa.data.compose_dataset import default_dataloader_params
|
||||
from rf2aa.kinematics import get_chirals, xyz_to_t2d
|
||||
from rf2aa.tensor_util import assert_equal, cmp
|
||||
# MSA size limits passed to MSAFeaturize in this test module.
MAXLAT = 256
MAXSEQ = 2048

# SE3 settings for the main blocks.
SE3_param = {
    "num_layers": 1,
    "num_channels": 32,
    "num_degrees": 2,
    "l0_in_features": 64,
    "l0_out_features": 64,
    "l1_in_features": 3,
    "l1_out_features": 2,
    "num_edge_features": 64,
    "div": 4,
    "n_heads": 4,
}

# The refinement settings differ from SE3_param only in the number of layers.
SE3_ref_param = {**SE3_param, "num_layers": 2}

# Model hyper-parameters; the two SE3 dicts above are stored by reference.
MODEL_PARAM = {
    "n_extra_block": 4,
    "n_main_block": 32,
    "n_ref_block": 4,
    "d_msa": 256,
    "d_pair": 192,
    "d_templ": 64,
    "n_head_msa": 8,
    "n_head_pair": 6,
    "n_head_templ": 4,
    "d_hidden": 32,
    "d_hidden_templ": 64,
    "p_drop": 0.0,
    "lj_lin": 0.7,
    "symmetrize_repeats": False,
    "repeat_length": float("nan"),
    "symmsub_k": float("nan"),
    "sym_method": float("nan"),
    "main_block": float("nan"),
    "copy_main_block_template": False,
    "SE3_param": SE3_param,
    "SE3_ref_param": SE3_ref_param,
}
|
||||
params = default_dataloader_params
|
||||
|
||||
def make_deterministic(seed=0):
    """Seed the PyTorch, NumPy and ``random`` generators with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
|
||||
|
||||
def featurize_input(three_letter="HEM", pickle_input=None):
    """Make dummy single-ligand features for the model.

    When ``pickle_input`` names an existing file, the dict stored there is
    loaded with ``torch.load`` and returned unchanged. Otherwise the features
    are built from the ideal SDF entry keyed by ``three_letter`` and, if
    ``pickle_input`` was given, saved to that path before being returned.

    NOTE(review): ``torch.load`` unpickles the file; only point this at
    trusted pickles.
    """
    # Fast path: reuse a previously saved feature dict.
    if pickle_input is not None and os.path.exists(pickle_input):
        return torch.load(pickle_input)

    # Parse the ligand from its ideal SDF string.
    sdf_strings = load_pdb_ideal_sdf_strings(return_only_sdf_strings=True)
    mol, msa_sm, ins_sm, xyz_sm, mask_sm = parsers.parse_mol(
        sdf_strings[three_letter], filetype="sdf", string=True
    )
    msa_batched = msa_sm.unsqueeze(0)
    ins_batched = ins_sm.unsqueeze(0)
    G = util.get_nxgraph(mol)
    # xyz_sm must be 3-dimensional; only the middle extent is used below.
    _, sm_L, _ = xyz_sm.shape
    msa = msa_batched.long()
    ins = ins_batched.long()

    # Per-atom / per-pair descriptors of the small molecule.
    bond_feats = util.get_bond_feats(mol)
    chirals = get_chirals(mol, xyz_sm[0])
    atom_frames = util.get_atom_frames(msa_sm, G)
    idx = torch.arange(sm_L)
    same_chain = torch.ones((sm_L, sm_L)).long()
    dist_matrix = get_bond_distances(bond_feats)

    # MSA features (no masking, single recycle).
    msa_params = {'MAXLAT': MAXLAT, 'MAXSEQ': MAXSEQ, 'MAXCYCLE': 1}
    seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
        msa, ins, p_mask=0.0, params=msa_params, tocpu=True
    )

    # Template features built from an empty template list.
    xyz_t, f1d_t, mask_t, _ = TemplFeaturize(
        {"ids": []}, sm_L, params, offset=0, npick=0, pick_top=True
    )
    xyz_t_frames = util.xyz_t_to_frame_xyz(xyz_t[None], seq, atom_frames[None])
    mask_t_2d = util.get_prot_sm_mask(mask_t, seq[0])[None]  # (B, T, L)
    mask_t_2d = mask_t_2d[:, :, None] * mask_t_2d[:, :, :, None]  # (B, T, L, L)
    t2d = xyz_to_t2d(xyz_t_frames, mask_t_2d[None])
    alpha_t = torch.zeros(1, sm_L, NTOTALDOFS * 3)

    n_templ = xyz_t.shape[0]
    realigned = [
        center_and_realign_missing(xyz_t[t], mask_t[t], same_chain=same_chain)
        for t in range(n_templ)
    ]
    xyz_t = torch.stack(realigned)

    # Initial "previous" structure state for recycling.
    xyz_prev, _ = generate_xyz_prev(xyz_t, mask_t, params)
    alpha_prev = torch.zeros((1, sm_L, NTOTALDOFS, 2))

    rf_input = {
        "msa_clust": msa_seed,
        "msa_extra": msa_extra,
        "seq": seq,
        "seq_unmasked": msa_seed_orig[:, 0],
        "xyz_prev": xyz_prev[None],
        "alpha_prev": alpha_prev,
        "idx_pdb": idx[None],
        "bond_feats": bond_feats[None],
        "dist_matrix": dist_matrix[None],
        "chirals": chirals[None],
        "atom_frames": atom_frames[None],
        "t1d": f1d_t[None],
        "t2d": t2d[0],
        "xyz_t": xyz_t[..., 1, :][None],
        "alpha_t": alpha_t[None],
        "mask_t": mask_t_2d,
        "same_chain": same_chain[None],
        "msa_prev": None,
        "pair_prev": None,
        "state_prev": None,
        "mask_recycle": None,
    }

    # Cache the freshly built features for the next run.
    if pickle_input is not None:
        torch.save(rf_input, pickle_input)
    return rf_input
|
||||
|
||||
|
||||
class FoldAndDock3TestCase(unittest.TestCase):
    """Regression test comparing LegacyRoseTTAFoldModule outputs to a stored reference.

    The first run writes the outputs to ``test_pickles/<name>_out``; later
    runs load that file and compare each named output against it.
    """

    def setUp(self):
        # Chain to TestCase.setUp(); calling __init__() here would re-initialise
        # the test case (method name, cleanups) in the middle of a run.
        super().setUp()
        self.name = "fd3_inference"
        pickle_input = str(os.path.join("test_pickles", f"{self.name}_input"))

        # refactored model param allowing for backwards compatibility
        # NOTE: this mutates the module-level MODEL_PARAM dict in place.
        MODEL_PARAM['use_chiral_l1'] = True
        MODEL_PARAM['use_lj_l1'] = True
        MODEL_PARAM['use_atom_frames'] = True
        MODEL_PARAM['use_same_chain'] = True
        MODEL_PARAM['recycling_type'] = 'all'
        self.model = LegacyRoseTTAFoldModule(
            **MODEL_PARAM,
            aamask=util.allatom_mask,
            atom_type_index=util.atom_type_index,
            ljlk_parameters=util.ljlk_parameters,
            lj_correction_parameters=util.lj_correction_parameters,
            num_bonds=util.num_bonds,
            cb_len=util.cb_length_t,
            cb_ang=util.cb_angle_t,
            cb_tor=util.cb_torsion_t,
        )
        self.rf_input = featurize_input(pickle_input=pickle_input)

    def test_inference(self):
        """Run one forward pass and compare against (or create) the reference."""
        make_deterministic()
        out = self.model(
            msa_latent=self.rf_input["msa_clust"].float(),
            msa_full=self.rf_input["msa_extra"].float(),
            seq=self.rf_input["seq"],
            seq_unmasked=self.rf_input["seq_unmasked"],
            xyz=self.rf_input["xyz_prev"],
            sctors=self.rf_input["alpha_prev"],
            idx=self.rf_input["idx_pdb"],
            bond_feats=self.rf_input["bond_feats"],
            dist_matrix=self.rf_input["dist_matrix"],
            chirals=self.rf_input["chirals"],
            atom_frames=self.rf_input["atom_frames"],
            t1d=self.rf_input["t1d"],
            t2d=self.rf_input["t2d"],
            xyz_t=self.rf_input["xyz_t"],
            alpha_t=self.rf_input["alpha_t"],
            mask_t=self.rf_input["mask_t"],
            same_chain=self.rf_input["same_chain"],
            msa_prev=self.rf_input["msa_prev"],
            pair_prev=self.rf_input["pair_prev"],
            state_prev=self.rf_input["state_prev"],
            mask_recycle=self.rf_input["mask_recycle"],
        )
        # Names for the positional outputs of the model, in order.
        output_names = ("logits_c6d", "logits_aa", "logits_pae",
                        "logits_pde", "p_bind", "xyz", "alpha", "xyz_allatom",
                        "lddt", "seq", "pair", "state")
        output_dict = dict(zip(output_names, out))
        test_out_path = os.path.join("test_pickles", f"{self.name}_out")
        if not os.path.exists(test_out_path):
            # No reference yet: store this run's outputs as the reference.
            torch.save(output_dict, test_out_path)
            print(f"saved output at {test_out_path}")
        else:
            reference_output_dict = torch.load(test_out_path)
            for output in output_names:
                want = reference_output_dict[output]
                got = output_dict[output]
                if output == "logits_c6d":
                    # Only the first element of logits_c6d is compared.
                    want = want[0]
                    got = got[0]
                print(output)
                # NOTE(review): the result of cmp() is not checked here;
                # confirm that cmp raises on mismatch rather than returning.
                cmp(got, want)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
BIN
rf2aa/test_pickles/fd3_inference_input
Normal file
BIN
rf2aa/test_pickles/fd3_inference_input
Normal file
Binary file not shown.
BIN
rf2aa/test_pickles/fd3_inference_out
Normal file
BIN
rf2aa/test_pickles/fd3_inference_out
Normal file
Binary file not shown.
357
rf2aa/trainer_new.py
Normal file
357
rf2aa/trainer_new.py
Normal file
@@ -0,0 +1,357 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing as mp
|
||||
import numpy as np
|
||||
|
||||
import hydra
|
||||
import os
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
|
||||
from rf2aa.data.data_loader import loader_atomize_pdb
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items
|
||||
from rf2aa.debug import debug_unused_params, debug_used_params
|
||||
from rf2aa.training.EMA import EMA, count_parameters
|
||||
from rf2aa.loss.loss_factory import get_loss_and_misc
|
||||
from rf2aa.training.optimizer import add_weight_decay
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_sampling
|
||||
from rf2aa.model.network import RosettaFold
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
|
||||
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
|
||||
from rf2aa.util import frame_indices, long2alt, allatom_mask, num_bonds, \
|
||||
atom_type_index, ljlk_parameters, lj_correction_parameters, hbtypes, \
|
||||
hbbaseatoms, hbpolys, cb_length_t, cb_angle_t, cb_torsion_t
|
||||
import rf2aa.util as util
|
||||
from rf2aa.util_module import XYZConverter
|
||||
|
||||
#TODO: control environment variables from config
|
||||
# limit thread counts
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||
|
||||
## To reproduce errors
|
||||
import random
|
||||
|
||||
def seed_all(seed=0):
    """Seed the stdlib, torch (CPU and CUDA) and numpy random generators.

    ``seed`` is applied as an offset to the base seeds this module has always
    used (0 for ``random``, 5924 for torch, 6636 for numpy), so the default
    call reproduces the historical streams while distinct ``seed`` values
    give distinct streams.
    """
    random.seed(seed)
    torch.manual_seed(5924 + seed)
    torch.cuda.manual_seed(5924 + seed)
    np.random.seed(6636 + seed)
|
||||
|
||||
torch.set_num_threads(4)
|
||||
#torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
class Trainer:
    """Base training driver: builds optimizer/scheduler/scaler and runs the DDP loop.

    Subclasses must implement ``construct_model`` and ``train_step``.
    """

    def __init__(self, config) -> None:
        self.config = config
        assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
        if self.config.experiment.output_dir is not None:
            self.output_dir = self.config.experiment.output_dir
        else:
            self.output_dir = "models/"
        # exist_ok avoids a check-then-create race when several ranks start together.
        os.makedirs(self.output_dir, exist_ok=True)

    def construct_model(self, device="cpu"):
        """Build ``self.model`` on ``device``; must be provided by subclasses.

        The ``device`` keyword matches how ``train_model`` calls this hook.
        """
        raise NotImplementedError()

    def construct_optimizer(self):
        """Create an AdamW optimizer, with weight-decay groups when configured."""
        if self.config.training_params.weight_decay is not None:
            opt_params = add_weight_decay(self.model, self.config.training_params.weight_decay)
        else:
            opt_params = self.model.parameters()
        self.optimizer = torch.optim.AdamW(opt_params, lr=self.config.training_params.learning_rate)

    def construct_scheduler(self):
        """Create the learning-rate scheduler from ``learning_rate_schedule``."""
        self.scheduler = get_stepwise_decay_schedule_with_warmup(
            self.optimizer, **self.config.training_params.learning_rate_schedule)

    def construct_scaler(self):
        """Create the AMP gradient scaler (enabled only when ``use_amp`` is set)."""
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.training_params.use_amp)

    def load_checkpoint(self, rank):
        """Load a checkpoint into ``self.checkpoint``.

        Uses ``<output_dir>/<experiment.name>_last.pt`` when resuming training,
        otherwise ``eval_params.checkpoint_path``.

        Raises:
            ValueError: if neither source is configured.
        """
        # NOTE(review): checkpoint_model() writes "<name>_<epoch>.pt"; nothing in
        # this class writes "<name>_last.pt" -- confirm who produces that file.
        if self.config.training_params.resume_train:
            checkpoint_path = f"{self.output_dir}/{self.config.experiment.name}_last.pt"
        elif self.config.eval_params.checkpoint_path:
            checkpoint_path = self.config.eval_params.checkpoint_path
        else:
            raise ValueError(
                "No checkpoint to load: set training_params.resume_train "
                "or eval_params.checkpoint_path"
            )
        # Remap tensors saved from cuda:0 onto this rank's device.
        map_location = {"cuda:0": f"cuda:{rank}"}
        self.checkpoint = torch.load(checkpoint_path, map_location=map_location)
        print(f"Loading checkpoint from {checkpoint_path} on rank:{rank}")

    def load_model(self):
        """Restore the live and shadow (EMA) model weights from ``self.checkpoint``."""
        torch.cuda.empty_cache()
        if self.config.training_params.resume_train is None:
            raise ValueError("Should not load model when resume_train is not set")
        #TODO: check if model should load the final state dict and not the EMA
        self.model.module.model.load_state_dict(self.checkpoint["final_state_dict"], strict=True)
        self.model.module.shadow.load_state_dict(self.checkpoint["model_state_dict"], strict=False)
        print("Checkpoint loaded into model")

    def load_optimizer(self):
        """Restore the optimizer state from ``self.checkpoint``."""
        if self.config.training_params.resume_train is None:
            raise ValueError("Should not load optimizer when resume_train is not set")
        self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])

    def load_scheduler(self):
        """Restore the scheduler state from ``self.checkpoint``."""
        if self.config.training_params.resume_train is None:
            raise ValueError("Should not load scheduler when resume_train is not set")
        self.scheduler.load_state_dict(self.checkpoint['scheduler_state_dict'])

    def load_scaler(self):
        """Restore the gradient-scaler state from ``self.checkpoint``."""
        if self.config.training_params.resume_train is None:
            raise ValueError("Should not load scaler when resume_train is not set")
        self.scaler.load_state_dict(self.checkpoint['scaler_state_dict'])

    def construct_dataset(self, rank, world_size):
        """Return whatever ``compose_dataset`` builds for this rank."""
        return compose_dataset(self.config.dataset_params, self.config.loader_params, rank, world_size)

    def construct_loss_function(self):
        raise NotImplementedError()

    def move_constants_to_device(self, gpu):
        """Copy the module-level chemistry/geometry constants onto ``gpu``."""
        self.fi_dev = frame_indices.to(gpu)
        self.xyz_converter = XYZConverter().to(gpu)

        self.l2a = long2alt.to(gpu)
        self.aamask = allatom_mask.to(gpu)
        self.num_bonds = num_bonds.to(gpu)
        self.atom_type_index = atom_type_index.to(gpu)
        self.ljlk_parameters = ljlk_parameters.to(gpu)
        self.lj_correction_parameters = lj_correction_parameters.to(gpu)
        self.hbtypes = hbtypes.to(gpu)
        self.hbbaseatoms = hbbaseatoms.to(gpu)
        self.hbpolys = hbpolys.to(gpu)
        self.cb_len = cb_length_t.to(gpu)
        self.cb_ang = cb_angle_t.to(gpu)
        self.cb_tor = cb_torsion_t.to(gpu)

    def checkpoint_model(self, epoch, metadata=None):
        """Save model/optimizer/scheduler/scaler state for ``epoch``.

        ``metadata`` (optional dict) is merged into the saved payload. The file
        is written to ``<output_dir>/<experiment.name>_<epoch>.pt``.
        """
        checkpoint_data = {
            'epoch': epoch,
            'model_state_dict': self.model.module.shadow.state_dict(),
            'final_state_dict': self.model.module.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'training_config': dict(self.config),
        }
        if metadata:
            checkpoint_data.update(metadata)
        torch.save(checkpoint_data, f"{self.output_dir}/{self.config.experiment.name}_{epoch}.pt")

    def launch_distributed_training(self):
        """Set rendezvous env vars and start ``train_model`` on every rank."""
        if ('MASTER_ADDR' not in os.environ):
            os.environ['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
        if ('MASTER_PORT' not in os.environ):
            os.environ['MASTER_PORT'] = '%d'%self.config.ddp_params.port

        if ("SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
            # Under SLURM each task is already its own process.
            world_size = int(os.environ["SLURM_NTASKS"])
            rank = int (os.environ["SLURM_PROCID"])
            print ("Launched from slurm", rank, world_size)
            self.train_model(rank, world_size)
        else:
            print ("Launched from interactive")
            world_size = torch.cuda.device_count()
            if world_size == 1:
                # No need for multiple processes with 1 GPU
                self.train_model(0, world_size)
            else:
                mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)

    def init_process_group(self, rank, world_size):
        """Join the process group and select this rank's CUDA device; return its index."""
        gpu = rank % torch.cuda.device_count()
        dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
        torch.cuda.set_device("cuda:%d"%gpu)
        return gpu

    def cleanup(self):
        dist.destroy_process_group()

    def train_model(self, rank, world_size):
        """ runs model training on each gpu """
        gpu = self.init_process_group(rank, world_size)
        train_loader, train_sampler, valid_loaders, valid_sampler = self.construct_dataset(rank, world_size)
        self.train_loader = train_loader

        # move global information to device
        self.move_constants_to_device(gpu)

        self.construct_model(device=gpu)
        self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
        if rank == 0:
            print(f"Loading model with {count_parameters(self.model)} parameters")

        self.construct_optimizer()
        self.construct_scheduler()
        self.construct_scaler()
        if self.config.training_params.resume_train:
            self.load_checkpoint(rank)
            self.load_model()
            self.load_optimizer()
            self.load_scheduler()
            self.load_scaler()
        # Pre-computed number of recycles, indexed as [epoch, train_idx].
        self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
                                                             self.config.experiment.n_epoch,
                                                             self.config.dataset_params.n_train,
                                                             world_size)
        for epoch in range(self.config.experiment.n_epoch):
            train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
            self.train_epoch(epoch, rank)

        self.cleanup()

    def train_epoch(self, epoch, rank):
        """ train model """
        # turn on gradients
        self.model.train()

        # clear gradients
        self.optimizer.zero_grad()

        accum = self.config.ddp_params.accum
        for train_idx, inputs in enumerate(self.train_loader):
            n_cycle = self.recycle_schedule[epoch, train_idx] # number of recycling

            # run forward pass and compute loss
            loss, loss_dict = self.train_step(inputs, n_cycle)

            # aggregate loss and update parameters
            loss = loss / accum
            self.scaler.scale(loss).backward()

            # Step only once `accum` micro-batches have been accumulated.
            # (Testing train_idx % accum would step after the very first
            # micro-batch, whose gradient is already divided by accum.)
            if (train_idx + 1) % accum == 0:
                self.update_parameters()

            if train_idx % self.config.log_params.log_every_n_examples == 0 and rank == 0:
                self.log_intermediate_losses(inputs, loss_dict, n_cycle)
            torch.cuda.empty_cache()

        if rank == 0:
            self.checkpoint_model(epoch)

    def train_step(self, inputs, n_cycle):
        """ take an input from dataloader, run the model and compute a loss """
        raise NotImplementedError()

    def update_parameters(self):
        """ scale, clip gradients and update parameters """
        # gradient clipping
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training_params.grad_clip)
        self.scaler.step(self.optimizer)
        scale = self.scaler.get_scale()
        self.scaler.update()
        # Do not advance the LR schedule when update() changed the loss scale.
        skip_lr_sched = (scale != self.scaler.get_scale())
        self.optimizer.zero_grad()
        if not skip_lr_sched:
            self.scheduler.step()
        # NOTE(review): requires the wrapped module to expose update(), i.e. the
        # EMA wrapper; construct_model only adds it when training_params.EMA is set.
        self.model.module.update() # apply EMA

    def log_intermediate_losses(self, inputs, loss_dict, n_cycle):
        """Print the example id, peak CUDA memory (GB), recycle count and losses."""
        item = inputs[-1]
        max_mem = torch.cuda.max_memory_allocated()/1e9
        print(f"Example: {item} Max Memory:{max_mem} Recycle:{n_cycle}\n"+
              "\t".join([f"{k}:{v}" for k,v in loss_dict.items()]))
        torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class LegacyTrainer(Trainer):
    """ trains Legacy versions of RFAA """

    def __init__(self, config) -> None:
        super().__init__(config)

    def construct_model(self, device="cpu"):
        """Build LegacyRoseTTAFoldModule on ``device`` and optionally wrap it in EMA."""
        # Constant tensors the legacy module takes as keyword arguments.
        constants = {
            "aamask": util.allatom_mask,
            "atom_type_index": util.atom_type_index,
            "ljlk_parameters": util.ljlk_parameters,
            "lj_correction_parameters": util.lj_correction_parameters,
            "num_bonds": util.num_bonds,
            "cb_len": util.cb_length_t,
            "cb_ang": util.cb_angle_t,
            "cb_tor": util.cb_torsion_t,
        }
        on_device = {key: tensor.to(device) for key, tensor in constants.items()}
        network = LegacyRoseTTAFoldModule(**self.config.legacy_model_param, **on_device)
        self.model = network.to(device)

        ema_decay = self.config.training_params.EMA
        if ema_decay is not None:
            self.model = EMA(self.model, ema_decay)

    def train_step(self, inputs, n_cycle):
        """ take an input from dataloader, run the model and compute a loss """
        device = self.model.device
        # HACK: certain features are constructed during the train step
        # in the future this should only promote the constructed features onto gpu
        (task, item, network_input, true_crds, atom_mask, msa, mask_msa,
         unclamp, negative, symmRs, Lasu, ch_label) = prepare_input(
            inputs, self.xyz_converter, device)

        use_amp = self.config.training_params.use_amp
        output_i = recycle_step_legacy(self.model, network_input, n_cycle, use_amp)
        (seq, same_chain, idx_pdb, bond_feats,
         dist_matrix, atom_frames) = get_loss_calc_items(inputs, device=device)

        #HACK: indexing into msa and mask msa recycle dimension in arguments of this function
        #HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
        msa = msa.to(device)
        mask_msa = mask_msa.to(device)
        final_cycle = n_cycle - 1
        loss, loss_dict = get_loss_and_misc(
            self,  # avoid reloading constants to device
            output_i, true_crds, atom_mask, same_chain,
            seq, msa[:, final_cycle], mask_msa[:, final_cycle],
            idx_pdb, bond_feats, dist_matrix, atom_frames,
            unclamp, negative, task, item, symmRs, Lasu, ch_label,
            self.config.loss_param,
        )
        return loss, loss_dict
|
||||
|
||||
|
||||
class ComposedTrainer(Trainer):
    """ trains composed versions of RFAA """

    def __init__(self, config) -> None:
        super().__init__(config)

    def construct_model(self, device="cpu"):
        """Build RosettaFold from the config on ``device`` and optionally wrap it in EMA."""
        network = RosettaFold(self.config)
        self.model = network.to(device)

        ema_decay = self.config.training_params.EMA
        if ema_decay is not None:
            self.model = EMA(self.model, ema_decay)

    def train_step(self, inputs, n_cycle):
        """ take an input from dataloader, run the model and compute a loss """
        device = self.model.device
        # HACK: certain features are constructed during the train step
        # in the future this should only promote the constructed features onto gpu
        (task, item, network_input, true_crds, atom_mask, msa, mask_msa,
         unclamp, negative, symmRs, Lasu, ch_label) = prepare_input(
            inputs, self.xyz_converter, device)

        use_amp = self.config.training_params.use_amp
        output_i = recycle_step_packed(self.model, network_input, n_cycle, use_amp)
        (seq, same_chain, idx_pdb, bond_feats,
         dist_matrix, atom_frames) = get_loss_calc_items(inputs, device=device)

        #HACK: indexing into msa and mask msa recycle dimension in arguments of this function
        #HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
        msa = msa.to(device)
        mask_msa = mask_msa.to(device)
        final_cycle = n_cycle - 1
        loss, loss_dict = get_loss_and_misc(
            self,  # avoid reloading constants to device
            output_i, true_crds, atom_mask, same_chain,
            seq, msa[:, final_cycle], mask_msa[:, final_cycle],
            idx_pdb, bond_feats, dist_matrix, atom_frames,
            unclamp, negative, task, item, symmRs, Lasu, ch_label,
            self.config.loss_param,
        )
        return loss, loss_dict
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/train')
def main(config):
    """Seed the RNGs, build the trainer named in the config and launch training."""
    seed_all()
    trainer_cls = trainer_factory[config.experiment.trainer]
    trainer = trainer_cls(config=config)
    trainer.launch_distributed_training()
|
||||
|
||||
trainer_factory = {
|
||||
"legacy": LegacyTrainer,
|
||||
"composed": ComposedTrainer,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
55
rf2aa/training/EMA.py
Normal file
55
rf2aa/training/EMA.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class EMA(nn.Module):
    """Wrap a model together with an exponential-moving-average copy of it.

    ``model`` is the live network; ``shadow`` is a deep copy whose parameters
    are detached and moved towards the live parameters by ``update()``.
    ``forward`` dispatches to ``model`` in training mode and to ``shadow``
    otherwise.
    """

    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # The shadow copy is never trained directly.
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        """Move the shadow parameters towards the live ones and copy buffers.

        Outside training mode this only prints a warning to stderr and returns.
        """
        if not self.training:
            # Imported locally so the warning path does not depend on
            # file-level imports (the bare name `stderr` was never defined).
            import sys
            print("EMA update should only be called during training", file=sys.stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check if both model contains the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check if both model contains the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied
            shadow_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        """Run the live model when training, the shadow model otherwise."""
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)
|
||||
|
||||
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
5
rf2aa/training/checkpoint.py
Normal file
5
rf2aa/training/checkpoint.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# for gradient checkpointing
|
||||
def create_custom_forward(module, **kwargs):
    """Return a callable that forwards positional inputs to ``module``.

    The keyword arguments captured here are passed along on every call.
    """
    bound_kwargs = kwargs

    def custom_forward(*inputs):
        result = module(*inputs, **bound_kwargs)
        return result

    return custom_forward
|
||||
11
rf2aa/training/optimizer.py
Normal file
11
rf2aa/training/optimizer.py
Normal file
@@ -0,0 +1,11 @@
|
||||
def add_weight_decay(model, l2_coeff):
    """Split trainable parameters into no-decay and decay optimizer groups.

    Parameters whose name contains "norm" or ends with ".bias" get a weight
    decay of 0.0; every other trainable parameter gets ``l2_coeff``. Frozen
    parameters are left out. The no-decay group comes first in the result.
    """
    def _is_exempt(name):
        return "norm" in name or name.endswith(".bias")

    trainable = [(name, param) for name, param in model.named_parameters()
                 if param.requires_grad]
    no_decay = [param for name, param in trainable if _is_exempt(name)]
    decay = [param for name, param in trainable if not _is_exempt(name)]
    return [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': l2_coeff}]
|
||||
136
rf2aa/training/recycling.py
Normal file
136
rf2aa/training/recycling.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
|
||||
from contextlib import ExitStack
|
||||
|
||||
from rf2aa.util import INIT_CRDS, NTOTAL
|
||||
|
||||
def recycle_step_legacy(ddp_model, input, n_cycle, use_amp):
    """Run ``n_cycle`` recycling passes, calling the model with keyword inputs.

    All passes but the last run under ``torch.no_grad()`` and
    ``ddp_model.no_sync()`` with ``return_raw=True``; the last pass runs with
    ``use_checkpoint=True``. Every pass runs under CUDA autocast controlled by
    ``use_amp``. Returns the output of the last pass (or the initial recycling
    tuple when ``n_cycle`` is 0).
    """
    gpu = ddp_model.device
    # Initial recycling state: no previous msa/pair, inputs' prev coordinates.
    output_i = (None, None, input["xyz_prev"], input["alpha_prev"], input["mask_recycle"])
    final_index = n_cycle - 1
    for i_cycle in range(n_cycle):
        is_final = not (i_cycle < final_index)
        with ExitStack() as stack:
            if not is_final:
                stack.enter_context(torch.no_grad())
                stack.enter_context(ddp_model.no_sync())
            stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
            input_i = add_recycle_inputs(
                input, output_i, i_cycle, gpu,
                return_raw=not is_final, use_checkpoint=is_final,
            )
            output_i = ddp_model(**input_i)
    return output_i
|
||||
|
||||
def recycle_step_packed(ddp_model, input, n_cycle, use_amp):
    """ exactly same logic as legacy recycling, except inputs and outputs are dictionaries"""
    gpu = ddp_model.device
    # Initial recycling state: no previous msa/pair, inputs' prev coordinates.
    output_i = (None, None, input["xyz_prev"], input["alpha_prev"], input["mask_recycle"])
    final_index = n_cycle - 1
    for i_cycle in range(n_cycle):
        is_final = not (i_cycle < final_index)
        with ExitStack() as stack:
            if not is_final:
                # Intermediate passes: no gradients and no DDP gradient sync.
                stack.enter_context(torch.no_grad())
                stack.enter_context(ddp_model.no_sync())
            stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
            input_i = add_recycle_inputs(
                input, output_i, i_cycle, gpu,
                return_raw=not is_final, use_checkpoint=is_final,
            )
            rf_outputs, rf_latents = ddp_model(input_i, use_checkpoint=is_final)
            output_i = unpack_outputs(rf_outputs, rf_latents, not is_final)
    return output_i
|
||||
|
||||
def unpack_outputs(rf_outputs, rf_latents, return_raw):
    """Flatten the model's output/latent dicts into the legacy tuple layout.

    With ``return_raw`` a 5-tuple for the next recycle is returned:
    ``(msa[:, 0], pair, xyz_prev, alpha_prev, None)``. Otherwise a 12-tuple in
    the order the legacy loss expects is returned; ``pde``, ``p_bind`` and
    ``xyz_allatom`` slots are always ``None``.
    """
    #HACK: this just unpacks the outputs into the way the previous RFAA loss function accepts it
    # in the future the loss function should accept rf_outputs and rf_latents
    msa = rf_latents["msa"]
    pair = rf_latents["pair"]
    state = rf_latents["state"]

    if return_raw:
        last_xyz = rf_outputs["xyzs"][-1][None]
        last_alpha = rf_outputs["alphas"][-1]
        # mask_recycle is always None
        return msa[:, 0], pair, last_xyz, last_alpha, None

    c6d_logits = rf_outputs["c6d"]
    mlm_logits = rf_outputs["mlm"]
    pae_logits = rf_outputs["pae"]
    plddt_logits = rf_outputs["plddt"]
    xyz = rf_outputs["xyzs"]
    alphas = rf_outputs["alphas"]

    # Prepend intermediate structures along dim 0 when the model provides them.
    if "xyz_intermediate" in rf_latents:
        earlier_xyz = torch.cat(rf_latents["xyz_intermediate"], dim=0)
        xyz = torch.cat((earlier_xyz, xyz), dim=0)
    if "alpha_intermediate" in rf_latents:
        earlier_alpha = torch.cat(rf_latents["alpha_intermediate"], dim=0)
        alphas = torch.cat((earlier_alpha, alphas), dim=0)

    return (c6d_logits, mlm_logits, pae_logits, None, None,
            xyz, alphas, None, plddt_logits, msa[:, 0], pair, state)
|
||||
|
||||
|
||||
def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
    """Assemble the model inputs for recycle ``i_cycle``.

    The per-cycle entries ('msa_latent', 'msa_full', 'seq') are indexed at
    ``i_cycle`` along dim 1 and moved to ``gpu``; all other entries are passed
    through. The recycled msa/pair/alpha/mask from ``output_i`` are added,
    while ``xyz`` is always set to INIT_CRDS repeated over the sequence
    length (the coordinates in ``output_i`` are not used). The 'xyz_prev' and
    'alpha_prev' keys are removed from the result.
    """
    per_cycle_keys = ('msa_latent', 'msa_full', 'seq')
    input_i = {
        key: (value[:, i_cycle].to(gpu, non_blocking=True) if key in per_cycle_keys else value)
        for key, value in network_input.items()
    }

    seq_len = input_i["msa_latent"].shape[2]
    msa_prev, pair_prev, _, alpha, mask_recycle = output_i
    init_xyz = INIT_CRDS.reshape(1, 1, NTOTAL, 3).repeat(1, seq_len, 1, 1)

    input_i.update({
        'msa_prev': msa_prev,
        'pair_prev': pair_prev,
        'xyz': init_xyz.to(gpu, non_blocking=True),
        'mask_recycle': mask_recycle,
        'sctors': alpha,
        'return_raw': return_raw,
        'use_checkpoint': use_checkpoint,
    })

    # These were only needed to seed the recycling state.
    del input_i['xyz_prev']
    del input_i['alpha_prev']
    return input_i
|
||||
|
||||
|
||||
def get_recycle_schedule(max_cycle, n_epochs, n_train, world_size, **kwargs):
    """Return the number of recycles for every example of every epoch.

    The result has one row per epoch and ``n_train // world_size`` columns,
    each entry drawn from 1..max_cycle inclusive. numpy's global RNG is
    re-seeded with 0 first, so every caller gets the same schedule.
    """
    assert n_train % world_size == 0
    # need to sync different gpus: make deterministic
    np.random.seed(0)
    per_rank = n_train // world_size
    upper = max_cycle + 1

    def _draw_epoch():
        # One scalar draw per example, in order.
        return torch.tensor([np.random.randint(1, upper) for _ in range(per_rank)])

    epoch_rows = [_draw_epoch() for _ in range(n_epochs)]
    return torch.stack(epoch_rows, dim=0)
|
||||
|
||||
def get_recycle_schedule_opt(max_cycle, n_epochs, n_train, world_size, **kwargs):
    """Return an ``(n_epochs, n_train // world_size)`` tensor of recycle counts.

    Entries are drawn in a single numpy call from 1..max_cycle inclusive after
    re-seeding numpy's global RNG with 0.
    """
    assert n_train % world_size == 0
    np.random.seed(0)
    schedule_shape = (n_epochs, n_train // world_size)
    draws = np.random.randint(1, max_cycle + 1, schedule_shape)
    return torch.tensor(draws)
|
||||
|
||||
def get_random_recycle(max_cycle, **kwargs):
    """Draw one recycle count from 1..max_cycle inclusive using numpy's global RNG."""
    return np.random.randint(1, max_cycle + 1)
|
||||
|
||||
|
||||
recycle_sampling = {
|
||||
"random": get_random_recycle,
|
||||
"by_batch": get_recycle_schedule_opt
|
||||
}
|
||||
@@ -287,7 +287,7 @@ def xyz_t_to_frame_xyz_sm_mask(xyz_t, is_sm, atom_frames):
|
||||
return xyz_t_frame
|
||||
|
||||
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
|
||||
B,L,natoms = xyz_in.shape[:3]
|
||||
#B,L,natoms = xyz_in.shape[:3]
|
||||
frames = frame_indices[seq]
|
||||
atoms = is_atom(seq)
|
||||
if torch.any(atoms):
|
||||
@@ -1117,6 +1117,7 @@ def atomize_protein(i_start, msa, xyz, mask, n_res_atomize=5):
|
||||
lig_mask = residue_atomize_mask[r, a].repeat(nat_symm.shape[0], 1)
|
||||
bond_feats = get_atomize_protein_bond_feats(i_start, msa, ra, n_res_atomize=n_res_atomize)
|
||||
#HACK: use networkx graph to make the atom frames, correct implementation will include frames with "residue atoms"
|
||||
# NOTE: REQUIRES NETWORKX < 3.0
|
||||
G = nx.from_numpy_matrix(bond_feats.numpy())
|
||||
|
||||
frames = get_atom_frames(lig_seq, G)
|
||||
@@ -1628,7 +1629,7 @@ def get_alt_query_ligand(chains, ligand_name, partners, lig_akeys, asmb_xfs):
|
||||
|
||||
return xyz_alt_s, mask_alt_s
|
||||
|
||||
def get_automorphs(mol, xyz_sm, mask_sm):
|
||||
def get_automorphs(mol, xyz_sm, mask_sm, max_symm=1000):
|
||||
"""Enumerate atom symmetry permutations."""
|
||||
try:
|
||||
automorphs = openbabel.vvpairUIntUInt()
|
||||
@@ -1647,7 +1648,9 @@ def get_automorphs(mol, xyz_sm, mask_sm):
|
||||
except Exception as e:
|
||||
xyz_sm = xyz_sm[None]
|
||||
mask_sm = mask_sm[None]
|
||||
|
||||
if xyz_sm.shape[0] > max_symm:
|
||||
xyz_sm = xyz_sm[:max_symm]
|
||||
mask_sm = mask_sm[:max_symm]
|
||||
return xyz_sm, mask_sm
|
||||
|
||||
def expand_xyz_sm_to_ntotal(xyz_sm, mask_sm, N_symmetry=None):
|
||||
|
||||
@@ -262,8 +262,7 @@ def make_full_graph(xyz, pair, idx):
|
||||
src = b*L+i
|
||||
tgt = b*L+j
|
||||
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
|
||||
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
|
||||
|
||||
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
|
||||
return G, pair[b,i,j][...,None]
|
||||
|
||||
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
|
||||
|
||||
69
rf2aa/validate.py
Normal file
69
rf2aa/validate.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import os
|
||||
import hydra
|
||||
import pandas as pd
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from rf2aa.trainer_new import LegacyTrainer
|
||||
from rf2aa.data.compose_dataset import compose_posebusters
|
||||
|
||||
class Validator(LegacyTrainer):
    """Base class for evaluation runs built on LegacyTrainer.

    The methods below are placeholders: two raise NotImplementedError and
    one is a no-op. Concrete benchmarks are expected to override them.
    """

    def evaluate_model(self, rank, world_size):
        """Run the evaluation; not implemented in the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    # NOTE(review): PoseBustersBenchmark in this file implements
    # construct_dataset(rank, world_size), not compose_dataset() -- confirm
    # which hook name is the intended one.
    def compose_dataset(self):
        """Build the evaluation dataset; not implemented in the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def valid_step(self, inputs, n_cycle):
        """Placeholder for a single validation step; does nothing and returns None."""
        pass
|
||||
|
||||
|
||||
class PoseBustersBenchmark(Validator):
    """Run the model over the PoseBusters set and write per-example losses to CSV."""

    def construct_dataset(self, rank, world_size):
        """Return whatever compose_posebusters builds for this rank.

        evaluate_model iterates over the returned object as a loader.
        """
        loader_params = self.config.loader_params
        return compose_posebusters(loader_params, rank, world_size)

    def evaluate_model(self, rank, world_size):
        """Initialise the process group, load the model and score every item.

        For each item yielded by the benchmark loader, train_step is run
        under torch.no_grad(); the resulting loss dictionary is tagged with
        the item's CHAINID, its tensor values are converted to Python
        scalars, and all records are written to
        "<output_dir>/<experiment.name>_posebusters.csv".

        Args:
            rank: rank of this process.
            world_size: overwritten below with the CUDA device count.
        """
        # NOTE(review): the caller-supplied world_size is discarded in favour
        # of the local GPU count -- confirm this is intended when the caller
        # passes world_size=1 on a multi-GPU host.
        world_size = torch.cuda.device_count()

        # Rendezvous defaults; multinode requires MASTER_ADDR set in the submit script.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        if 'MASTER_PORT' not in os.environ:
            os.environ['MASTER_PORT'] = '%d' % self.config.ddp_params.port

        device = self.init_process_group(rank, world_size)
        loader = self.construct_dataset(rank, world_size)

        # move global information to device
        self.move_constants_to_device(device)

        self.construct_model(device=device)
        self.model = DDP(
            self.model,
            device_ids=[device],
            find_unused_parameters=False,
            broadcast_buffers=False,
        )

        self.load_checkpoint(rank)
        self.load_model()
        self.model.eval()

        rows = []
        for batch in loader:
            # The last element of each batch carries the item metadata.
            meta = batch[-1]
            with torch.no_grad():
                _, losses = self.train_step(batch, self.config.loader_params.maxcycle)
            losses["CHAINID"] = meta["CHAINID"][0]
            # Replace tensor entries with plain Python scalars, in place, so
            # the record can go straight into a DataFrame.
            losses.update(
                {name: val.item() for name, val in losses.items() if torch.is_tensor(val)}
            )
            rows.append(losses)

        table = pd.DataFrame(rows)
        table.to_csv(f"{self.output_dir}/{self.config.experiment.name}_posebusters.csv")
        torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/train')
def main(config):
    """Hydra entry point: build the PoseBusters benchmark from `config` and run it.

    evaluate_model is invoked with rank 0 and world_size 1.
    """
    PoseBustersBenchmark(config=config).evaluate_model(0, 1)


if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user