mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Add af3-style 2track net
Add config file Error in config file Update config Fix unit tests. Update config Turn of bias in attention norms Add some temp fixes Change amp to use bfloat16 More model stability changes add allatom data transform added in allatom noising and denoising allatom flow matching trains now add loss params in generative refinement yaml Some fixes and reorg for diffusive training Remove debug output; update config Silly bugfix Validation bugfix. Adjust diffusive weights. One more bugfix some progress towards validation in allatom fm Add recycling. Add stability changes from trunk. Fix rare bug in get_residue_contacts more work towards sampler: runs validation, but ooms updates for diffusion infernece in validation Sampler respects use_amp flag Revert "Sampler respects use_amp flag" This reverts commit 722c32dda150bc3f12167839bd4cbae2fe8026c5. changes for training Move hardcoded options to options system. Add proper conversion of atom/bond level features changes to allow gradients to flow through diffusion module added mlm back into config fix d_t bug Several updates to allatom gen refinement: 1) DNA and ligand bond distances are correctly computed; 2) atom graph includes nearby-bonded neighbors even when distance is large; 3) small change to default parameters. Update/bugfix BiasedSequenceAttention in refinement Reenable checkpointing Small bugfix. Set bonded atoms to D=-1 so they are always preferred over D=0 atoms Updates to AF2-like training Change af3-like architecture config Bugfixes in validation. Small config updates for stable training dt buig added timstep embedding Some very minor stability fixes Msa module Re commit accidentally reverted changes. Track t in loss reporting. Revert some of the scaling I was doing. add alignment to vector field matching Several bugfixes, code simplifications
This commit is contained in:
@@ -222,6 +222,7 @@ class ConvSE3(nn.Module):
|
||||
max_degree: int = 4,
|
||||
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
||||
allow_fused_output: bool = False,
|
||||
sum_over_edge: bool = True,
|
||||
low_memory: bool = False
|
||||
):
|
||||
"""
|
||||
@@ -242,6 +243,7 @@ class ConvSE3(nn.Module):
|
||||
self.self_interaction = self_interaction
|
||||
self.max_degree = max_degree
|
||||
self.allow_fused_output = allow_fused_output
|
||||
self.sum_over_edge = sum_over_edge
|
||||
self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
|
||||
|
||||
# channels_in: account for the concatenation of edge features
|
||||
|
||||
@@ -50,7 +50,10 @@ class NormSE3(nn.Module):
|
||||
└──> feature_phase ────────────────────────────┘
|
||||
"""
|
||||
|
||||
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
||||
#NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
||||
|
||||
#fd w/o disconnected gradients this value still causes exploding grads
|
||||
NORM_CLAMP = 2 ** -16
|
||||
|
||||
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
||||
super().__init__()
|
||||
|
||||
@@ -381,6 +381,54 @@ class ChemicalData:
|
||||
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,"N" ,"C" ,"N" ,"N" ,"C" ,"C" ,"C" ,"O" ,"N" ,"C" ,"N" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None),#G
|
||||
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,"N" ,"C" ,"O" ,"N" ,"C" ,"O" ,"C" ,"C" ,None,None,None,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None,None),#U
|
||||
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,None,None,None,None,None,None,None,None,None,None,None,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None,None,None,None),#RX
|
||||
("N", "C","C" ,"O", "C", "C", "N" ,"C" ,"C" ,"N" ,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),# HIS-D NOT CORRECT!!!!!!!!!!
|
||||
(None,"Al",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Al
|
||||
(None,"As",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# As
|
||||
(None,"Au",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Au
|
||||
(None,"B",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# B
|
||||
(None,"Be",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Be
|
||||
(None,"Br",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Br
|
||||
(None,"C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# C
|
||||
(None,"Ca",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ca
|
||||
(None,"Cl",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cl
|
||||
(None,"Co",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Co
|
||||
(None,"Cr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cr
|
||||
(None,"Cu",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cu
|
||||
(None,"F",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# F
|
||||
(None,"Fe",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Fe
|
||||
(None,"Hg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Hg
|
||||
(None,"I",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# I
|
||||
(None,"Ir",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ir
|
||||
(None,"K",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# K
|
||||
(None,"Li",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Li
|
||||
(None,"Mg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mg
|
||||
(None,"Mn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mn
|
||||
(None,"Mo",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mo
|
||||
(None,"N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# N
|
||||
(None,"Ni",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ni
|
||||
(None,"O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# O
|
||||
(None,"Os",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Os
|
||||
(None,"P",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# P
|
||||
(None,"Pb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pb
|
||||
(None,"Pd",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pd
|
||||
(None,"Pr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pr
|
||||
(None,"Pt",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pt
|
||||
(None,"Re",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Re
|
||||
(None,"Rh",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Rh
|
||||
(None,"Ru",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ru
|
||||
(None,"S",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# S
|
||||
(None,"Sb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sb
|
||||
(None,"Se",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Se
|
||||
(None,"Si",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Si
|
||||
(None,"Sn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sn
|
||||
(None,"Tb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Tb
|
||||
(None,"Te",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Te
|
||||
(None,"U",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# U
|
||||
(None,"W",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# W
|
||||
(None,"V",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# V
|
||||
(None,"Y",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Y
|
||||
(None,"Zn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Zn
|
||||
(None,"ATM",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# ATM ]
|
||||
]
|
||||
|
||||
# frames for generic FAPE
|
||||
@@ -1171,6 +1219,9 @@ class ChemicalData:
|
||||
],
|
||||
]
|
||||
|
||||
# atom ids forming polymer connections
|
||||
self.protein_connect = (0,2) # N, C
|
||||
self.na_connect = (4,10) # P, O3'
|
||||
else:
|
||||
# USE PHOSPHATE FRAME
|
||||
self.aa2long=[
|
||||
@@ -1363,6 +1414,54 @@ class ChemicalData:
|
||||
("O","P","O","O","C","C","O","C","O","C","C","O","N","C","N","N","C","C","C","O","N","C","N","H","H","H","H","H","H","H","H","H","H","H",None,None),#G
|
||||
("O","P","O","O","C","C","O","C","O","C","C","O","N","C","O","N","C","O","C","C",None,None,None,"H","H","H","H","H","H","H","H","H","H",None,None,None),#U
|
||||
("O","P","O","O","C","C","O","C","O","C","C","O",None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H",None,None,None,None,None),#RX
|
||||
("N", "C","C" ,"O", "C", "C", "N" ,"C" ,"C" ,"N" ,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),# HIS-D NOT CORRECT!!!!!!!!!!
|
||||
(None,"Al",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Al
|
||||
(None,"As",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# As
|
||||
(None,"Au",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Au
|
||||
(None,"B",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# B
|
||||
(None,"Be",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Be
|
||||
(None,"Br",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Br
|
||||
(None,"C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# C
|
||||
(None,"Ca",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ca
|
||||
(None,"Cl",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cl
|
||||
(None,"Co",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Co
|
||||
(None,"Cr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cr
|
||||
(None,"Cu",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cu
|
||||
(None,"F",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# F
|
||||
(None,"Fe",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Fe
|
||||
(None,"Hg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Hg
|
||||
(None,"I",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# I
|
||||
(None,"Ir",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ir
|
||||
(None,"K",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# K
|
||||
(None,"Li",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Li
|
||||
(None,"Mg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mg
|
||||
(None,"Mn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mn
|
||||
(None,"Mo",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mo
|
||||
(None,"N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# N
|
||||
(None,"Ni",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ni
|
||||
(None,"O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# O
|
||||
(None,"Os",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Os
|
||||
(None,"P",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# P
|
||||
(None,"Pb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pb
|
||||
(None,"Pd",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pd
|
||||
(None,"Pr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pr
|
||||
(None,"Pt",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pt
|
||||
(None,"Re",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Re
|
||||
(None,"Rh",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Rh
|
||||
(None,"Ru",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ru
|
||||
(None,"S",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# S
|
||||
(None,"Sb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sb
|
||||
(None,"Se",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Se
|
||||
(None,"Si",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Si
|
||||
(None,"Sn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sn
|
||||
(None,"Tb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Tb
|
||||
(None,"Te",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Te
|
||||
(None,"U",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# U
|
||||
(None,"W",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# W
|
||||
(None,"V",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# V
|
||||
(None,"Y",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Y
|
||||
(None,"Zn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Zn
|
||||
(None,"ATM",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# ATM ]
|
||||
]
|
||||
|
||||
# frames for generic FAPE
|
||||
@@ -2151,6 +2250,12 @@ class ChemicalData:
|
||||
],
|
||||
]
|
||||
|
||||
# atom ids forming polymer connections
|
||||
self.protein_connect = (0,2) # N, C
|
||||
self.na_connect = (1,8) # P, O3'
|
||||
|
||||
# general case
|
||||
|
||||
self.aabonds=[
|
||||
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
|
||||
((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # ala
|
||||
@@ -2444,6 +2549,8 @@ class ChemicalData:
|
||||
for j,a in enumerate(i_l):
|
||||
if (a is None):
|
||||
self.long2alt[i,j] = j
|
||||
elif ("H" in a):
|
||||
self.long2alt[i,j] = i_l.index(a)
|
||||
else:
|
||||
self.long2alt[i,j] = i_lalt.index(a)
|
||||
self.allatom_mask[i,j] = True
|
||||
@@ -2453,14 +2560,18 @@ class ChemicalData:
|
||||
self.allatom_mask[self.NNAPROTAAS:,1] = True
|
||||
|
||||
# bond graph traversal
|
||||
self.num_bonds = torch.zeros((self.NAATOKENS,self.NTOTAL,self.NTOTAL), dtype=torch.long)
|
||||
self.MAX_BOND_DIST = 9 # largest bond separation we consider
|
||||
self.num_bonds = torch.full((self.NAATOKENS,self.NTOTAL,self.NTOTAL), self.MAX_BOND_DIST, dtype=torch.long)
|
||||
self.num_bonds[:,torch.arange(self.NTOTAL),torch.arange(self.NTOTAL)] = 0 # atom self-interaction
|
||||
|
||||
# compute for all protein & na using csgraph.shortest_path
|
||||
for i in range(self.NNAPROTAAS):
|
||||
num_bonds_i = np.zeros((self.NTOTAL,self.NTOTAL))
|
||||
for (bnamei,bnamej) in self.aabonds[i]:
|
||||
bi,bj = self.aa2long[i].index(bnamei),self.aa2long[i].index(bnamej)
|
||||
num_bonds_i[bi,bj] = 1
|
||||
num_bonds_i = scipy.sparse.csgraph.shortest_path (num_bonds_i,directed=False)
|
||||
num_bonds_i[num_bonds_i>=4] = 4
|
||||
num_bonds_i[num_bonds_i>=self.MAX_BOND_DIST] = self.MAX_BOND_DIST
|
||||
self.num_bonds[i,...] = torch.tensor(num_bonds_i)
|
||||
|
||||
|
||||
@@ -2472,6 +2583,21 @@ class ChemicalData:
|
||||
self.idx2aatype.append(y)
|
||||
self.aatype2idx = {x:i for i,x in enumerate(self.idx2aatype)}
|
||||
|
||||
# element indices
|
||||
self.idx2elt = []
|
||||
for x in self.aa2elt:
|
||||
for y in x:
|
||||
if y and y not in self.idx2elt:
|
||||
self.idx2elt.append(y)
|
||||
self.elt2idx = {x:i for i,x in enumerate(self.idx2elt)}
|
||||
|
||||
self.aa2eltidx = torch.zeros((self.NAATOKENS,self.NTOTAL), dtype=torch.long)
|
||||
for i in range(self.NAATOKENS):
|
||||
for j in range(self.NTOTAL):
|
||||
if self.aa2elt[i][j] is not None:
|
||||
self.aa2eltidx[i,j] = self.elt2idx[ self.aa2elt[i][j] ]
|
||||
self.NELTTYPES = len(self.elt2idx)
|
||||
|
||||
# LJ/LK scoring parameters
|
||||
self.atom_type_index = torch.zeros((self.NAATOKENS,self.NTOTAL), dtype=torch.long)
|
||||
|
||||
|
||||
222
rf2aa/config/train/af3.yaml
Normal file
222
rf2aa/config/train/af3.yaml
Normal file
@@ -0,0 +1,222 @@
|
||||
defaults:
|
||||
- base
|
||||
|
||||
experiment:
|
||||
name: rf2aa-af3_v4
|
||||
trainer: "flow_matching"
|
||||
|
||||
training_params:
|
||||
resume_from_checkpoint_path: /home/dimaio/RF2-allatom-af3/rf2aa/models/rf2aa-af3_v4_last.pt
|
||||
reset_optimizer_params: True
|
||||
EMA: 0.99
|
||||
weight_decay: 0.01
|
||||
learning_rate: .001
|
||||
learning_rate_schedule:
|
||||
num_warmup_steps: 0
|
||||
num_steps_decay: 5000
|
||||
decay_rate: 0.95
|
||||
grad_clip: 0.2
|
||||
use_amp: False
|
||||
ddp_backend: nccl
|
||||
|
||||
loader_params:
|
||||
maxseq: 1024
|
||||
maxtoken: 1024
|
||||
maxlat: 1
|
||||
|
||||
ddp_params:
|
||||
accum: 1
|
||||
batch_size: 1
|
||||
port: 12465
|
||||
|
||||
loss_param:
|
||||
w_dist: 1.0
|
||||
w_str: 0.0
|
||||
w_inter_fape: 0.0
|
||||
w_lig_fape: 0.0
|
||||
w_lddt: 0.0
|
||||
w_aa: 0.5
|
||||
w_bond: 0.0
|
||||
w_bind: 0.0
|
||||
binder_loss_label_smoothing: 0.0
|
||||
w_clash: 0.0
|
||||
w_atom_bond: 0.0
|
||||
w_skip_bond: 0.0
|
||||
w_rigid: 0.0
|
||||
w_hb: 0.0
|
||||
w_pae: 0.00
|
||||
w_pde: 0.00
|
||||
lj_lin: 0.75
|
||||
w_trans: 1.0
|
||||
t_normalize_clip: 0.925
|
||||
trans_scale: 1.0
|
||||
|
||||
interpolant:
|
||||
min_t: 1e-2
|
||||
separate_t: False
|
||||
provide_kappa: False
|
||||
hierarchical_t: False
|
||||
codesign_separate_t: False
|
||||
codesign_forward_fold_prop: 0.0
|
||||
codesign_inverse_fold_prop: 0.0
|
||||
|
||||
twisting:
|
||||
use: False
|
||||
|
||||
rots:
|
||||
corrupt: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
exp_rate: 10
|
||||
|
||||
trans:
|
||||
corrupt: True
|
||||
batch_ot: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
sample_temp: 1.0
|
||||
vpsde_bmin: 0.1
|
||||
vpsde_bmax: 20.0
|
||||
potential: null
|
||||
potential_t_scaling: False
|
||||
rog:
|
||||
weight: 10.0
|
||||
cutoff: 5.0
|
||||
|
||||
aatypes:
|
||||
corrupt: False
|
||||
schedule: linear
|
||||
schedule_exp_rate: 10
|
||||
temp: 1.0
|
||||
noise: 0.0
|
||||
do_purity: False
|
||||
train_extra_mask: 0.0
|
||||
interpolant_type: masking
|
||||
num_tokens: 80
|
||||
|
||||
sampling:
|
||||
num_timesteps: 10
|
||||
do_sde: False
|
||||
self_condition: False
|
||||
|
||||
dataset_params:
|
||||
n_train: 25600
|
||||
validate_every_n_epochs: 1
|
||||
fraction_pdb: 0.12
|
||||
fraction_fb: 0.36
|
||||
fraction_compl: 0.055
|
||||
fraction_neg_compl: 0
|
||||
fraction_na_compl: 0.055
|
||||
fraction_neg_na_compl: 0
|
||||
fraction_distil_tf: 0.055
|
||||
fraction_tf: 0
|
||||
fraction_neg_tf: 0
|
||||
fraction_rna: 0.02
|
||||
fraction_dna: 0.005
|
||||
fraction_sm_compl: 0.11
|
||||
fraction_metal_compl: 0.03
|
||||
fraction_sm_compl_multi: 0.025
|
||||
fraction_sm_compl_covale: 0.025
|
||||
fraction_sm: 0.0
|
||||
fraction_atomize_pdb: 0.0425
|
||||
fraction_atomize_complex: 0.0425
|
||||
fraction_sm_compl_asmb: 0.055
|
||||
n_valid_pdb: 256
|
||||
n_valid_homo: 0
|
||||
n_valid_dslf: 0
|
||||
n_valid_compl: 256
|
||||
n_valid_neg_compl: 0
|
||||
n_valid_na_compl: 256
|
||||
n_valid_neg_na_compl: 0
|
||||
n_valid_distil_tf: 0
|
||||
n_valid_tf: 0
|
||||
n_valid_neg_tf: 0
|
||||
n_valid_rna: 256
|
||||
n_valid_dna: 256
|
||||
n_valid_sm_compl: 256
|
||||
n_valid_metal_compl: 0
|
||||
n_valid_sm_compl_multi: 0
|
||||
n_valid_sm_compl_covale: 0
|
||||
n_valid_sm_compl_strict: 256
|
||||
n_valid_sm: 0
|
||||
n_valid_atomize_pdb: 0
|
||||
n_valid_atomize_complex: 0
|
||||
n_valid_sm_compl_asmb: 0
|
||||
p_homo_cut: 0
|
||||
p_short_crop: 0
|
||||
p_dslf_crop: 0
|
||||
dslf_fb_upsample: 1
|
||||
|
||||
model:
|
||||
global_params:
|
||||
d_msa: 384
|
||||
d_msa_full: 64
|
||||
d_pair: 128
|
||||
d_state: 128
|
||||
|
||||
embedding:
|
||||
rf2aa:
|
||||
params:
|
||||
p_drop: 0.15
|
||||
d_templ: 64
|
||||
n_head_templ: 4
|
||||
d_hidden_templ: 64
|
||||
templ_p_drop: 0.25
|
||||
symmetrize_repeats: False
|
||||
repeat_length: null
|
||||
symmsub_k: null
|
||||
sym_method: null
|
||||
main_block: null
|
||||
copy_main_block_template: False
|
||||
additional_dt1d: 0
|
||||
recycling_type: "msa_pair_only"
|
||||
use_same_chain: False
|
||||
blocks:
|
||||
AF3_full:
|
||||
num_blocks: 4
|
||||
params:
|
||||
p_drop_row: 0.25
|
||||
p_drop_pair: 0.25
|
||||
msa_transition_drop: 0.0
|
||||
outer_product_channels: 16
|
||||
p_drop_outer_product: 0.0
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 32
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 8
|
||||
norm_msa_row: False
|
||||
AF3:
|
||||
num_blocks: 48
|
||||
params:
|
||||
p_drop_row: 0.25
|
||||
p_drop_pair: 0.25
|
||||
msa_transition_drop: 0.0
|
||||
outer_product_channels: 16
|
||||
p_drop_outer_product: 0.0
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 32
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 32
|
||||
norm_msa_row: False
|
||||
refinement:
|
||||
generative:
|
||||
params:
|
||||
num_attention_layers: 24
|
||||
num_channels: 32
|
||||
num_degrees: 2
|
||||
num_layers: 3
|
||||
n_heads: 4
|
||||
div: 4
|
||||
l0_in_features: 128
|
||||
l0_out_features: 128
|
||||
num_edge_features: 32
|
||||
n_channels: 32
|
||||
msa_transition_drop: 0.0
|
||||
compute_gradients: True
|
||||
auxiliary_predictors:
|
||||
c6d:
|
||||
n_feat: 128
|
||||
input_feature: "pair"
|
||||
mlm:
|
||||
n_feat: 384
|
||||
input_feature: "msa"
|
||||
@@ -7,7 +7,7 @@ experiment:
|
||||
model:
|
||||
embedding: null
|
||||
blocks: null
|
||||
refinment: null
|
||||
refinement: null
|
||||
auxiliary_predictors: {}
|
||||
legacy_model: null
|
||||
dataset_params:
|
||||
@@ -130,6 +130,7 @@ loss_param:
|
||||
w_hb: 0.0
|
||||
w_pae: 0.05
|
||||
w_pde: 0.05
|
||||
w_trans: 0.0
|
||||
lj_lin: 0.75
|
||||
|
||||
log_params:
|
||||
|
||||
@@ -20,7 +20,7 @@ interpolant:
|
||||
rots:
|
||||
corrupt: True
|
||||
train_schedule: linear
|
||||
sample_schedule: exp
|
||||
sample_schedule: linear
|
||||
exp_rate: 10
|
||||
|
||||
trans:
|
||||
@@ -52,5 +52,3 @@ interpolant:
|
||||
num_timesteps: 100
|
||||
do_sde: False
|
||||
self_condition: False
|
||||
training_params:
|
||||
resume_from_checkpoint_path: /net/tukwila/rohith/assorted_weights/models/rf2aa-baseline-noinplace_134.pt
|
||||
|
||||
148
rf2aa/config/train/generative_refinement.yaml
Normal file
148
rf2aa/config/train/generative_refinement.yaml
Normal file
@@ -0,0 +1,148 @@
|
||||
defaults:
|
||||
- rf2aa
|
||||
|
||||
experiment:
|
||||
name: rfaa-flow-matching-allatom
|
||||
trainer: "flow_matching"
|
||||
|
||||
interpolant:
|
||||
min_t: 1e-2
|
||||
separate_t: False
|
||||
provide_kappa: False
|
||||
hierarchical_t: False
|
||||
codesign_separate_t: False
|
||||
codesign_forward_fold_prop: 0.0
|
||||
codesign_inverse_fold_prop: 0.0
|
||||
|
||||
twisting:
|
||||
use: False
|
||||
|
||||
rots:
|
||||
corrupt: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
exp_rate: 10
|
||||
|
||||
trans:
|
||||
corrupt: True
|
||||
batch_ot: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
sample_temp: 1.0
|
||||
vpsde_bmin: 0.1
|
||||
vpsde_bmax: 20.0
|
||||
potential: null
|
||||
potential_t_scaling: False
|
||||
rog:
|
||||
weight: 10.0
|
||||
cutoff: 5.0
|
||||
|
||||
aatypes:
|
||||
corrupt: False
|
||||
schedule: linear
|
||||
schedule_exp_rate: 10
|
||||
temp: 1.0
|
||||
noise: 0.0
|
||||
do_purity: False
|
||||
train_extra_mask: 0.0
|
||||
interpolant_type: masking
|
||||
num_tokens: 80
|
||||
|
||||
sampling:
|
||||
num_timesteps: 20
|
||||
do_sde: False
|
||||
self_condition: False
|
||||
|
||||
model:
|
||||
blocks:
|
||||
RF2aa_full:
|
||||
num_blocks: 1
|
||||
params:
|
||||
d_rbf: 64
|
||||
p_drop_row: 0.25
|
||||
p_drop_pair: 0.25
|
||||
p_drop_layer: 0.0
|
||||
msa_transition_drop: 0.0
|
||||
outer_product_channels: 16
|
||||
p_drop_outer_product: 0.0
|
||||
n_pair_head: 6
|
||||
n_pair_channels: 32
|
||||
n_msa_head: 8
|
||||
n_msa_channels: 8
|
||||
structure_bias_gate_channels: 16
|
||||
structure_bias_channels: 64
|
||||
n_se3_layers: 1
|
||||
n_se3_channels: 32
|
||||
n_se3_degrees: 2
|
||||
n_se3_head: 4
|
||||
n_div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 6
|
||||
l1_out_features: 2
|
||||
n_se3_edge_features: 64
|
||||
sc_pred_d_hidden: 128
|
||||
sc_pred_p_drop: 0.0
|
||||
residual_state: True
|
||||
RF2aa:
|
||||
num_blocks: 1
|
||||
refinement:
|
||||
generative:
|
||||
params:
|
||||
num_atom_encode_layers: 1
|
||||
num_attention_layers: 1
|
||||
num_channels: 32
|
||||
num_degrees: 2
|
||||
n_heads: 4
|
||||
div: 4
|
||||
l0_in_features: 64
|
||||
l0_out_features: 64
|
||||
l1_in_features: 3
|
||||
l1_out_features: 1
|
||||
num_edge_features: 4
|
||||
n_channels: 32
|
||||
msa_transition_drop: 0.0
|
||||
|
||||
dataset_params:
|
||||
n_train: 24000
|
||||
validate_every_n_epochs: 1
|
||||
fraction_pdb: 0.12
|
||||
fraction_fb: 0.36
|
||||
fraction_compl: 0.055
|
||||
fraction_neg_compl: 0
|
||||
fraction_na_compl: 0.055
|
||||
fraction_neg_na_compl: 0
|
||||
fraction_distil_tf: 0.055
|
||||
fraction_tf: 0
|
||||
fraction_neg_tf: 0
|
||||
fraction_rna: 0.02
|
||||
fraction_dna: 0.005
|
||||
fraction_sm_compl: 0.11
|
||||
fraction_metal_compl: 0.03
|
||||
fraction_sm_compl_multi: 0.025
|
||||
fraction_sm_compl_covale: 0.025
|
||||
fraction_sm: 0.0
|
||||
fraction_atomize_pdb: 0.0425
|
||||
fraction_atomize_complex: 0.0425
|
||||
fraction_sm_compl_asmb: 0.055
|
||||
loss_param:
|
||||
w_trans: 10.0
|
||||
t_normalize_clip: 0.9
|
||||
trans_scale: 0.1
|
||||
w_dist: 1.0
|
||||
w_str: 0.0
|
||||
w_inter_fape: 0.0
|
||||
w_lig_fape: 0.0
|
||||
w_lddt: 0.0
|
||||
w_aa: 3.0
|
||||
w_bond: 0.0
|
||||
w_bind: 0.0
|
||||
binder_loss_label_smoothing: 0.0
|
||||
w_clash: 0.0
|
||||
w_atom_bond: 0.0
|
||||
w_skip_bond: 0.0
|
||||
w_rigid: 0.0
|
||||
w_hb: 0.0
|
||||
w_pae: 0.0
|
||||
w_pde: 0.0
|
||||
lj_lin: 0.75
|
||||
@@ -30,6 +30,7 @@ model:
|
||||
n_se3_edge_features: 32
|
||||
residual_state: True
|
||||
norm_msa_row: True
|
||||
bias_in_attn_norm: True
|
||||
use_flash_attn: True
|
||||
|
||||
RF2aa:
|
||||
@@ -41,6 +42,7 @@ model:
|
||||
n_se3_edge_features: 32
|
||||
residual_state: True
|
||||
norm_msa_row: True
|
||||
bias_in_attn_norm: True
|
||||
use_flash_attn: True
|
||||
|
||||
refinement:
|
||||
|
||||
@@ -299,7 +299,7 @@ def MSAFeaturize(
|
||||
ins,
|
||||
params,
|
||||
p_mask=0.15,
|
||||
eps=1e-6,
|
||||
eps=1e-4,
|
||||
nmer=1,
|
||||
L_s=[],
|
||||
term_info=None,
|
||||
|
||||
@@ -16,7 +16,6 @@ def prepare_input(inputs, xyz_converter, gpu):
|
||||
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
|
||||
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
|
||||
) = inputs
|
||||
|
||||
# transfer inputs to device
|
||||
B, _, N, L = msa.shape
|
||||
|
||||
@@ -141,6 +140,7 @@ def get_loss_calc_items(inputs,device="cpu"):
|
||||
return seq.to(device), same_chain.to(device), idx_pdb.to(device), bond_feats.to(device), \
|
||||
dist_matrix.to(device), atom_frames.to(device), true_crds.to(device), mask_crds.to(device)
|
||||
|
||||
# prepate inputs for flow matching
|
||||
def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
|
||||
(
|
||||
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
|
||||
@@ -158,7 +158,7 @@ def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
|
||||
msa = msa.to(device, non_blocking=True)
|
||||
atom_frames = atom_frames.to(device, non_blocking=True)
|
||||
|
||||
interpolant.set_device(device)
|
||||
#interpolant.set_device(device)
|
||||
batch = data_transforms.convert_dataloader_inputs_to_rigids(inputs, interpolant._device)
|
||||
noisy_batch = interpolant.corrupt_batch(batch)
|
||||
rotmats, trans = noisy_batch["rotmats_t"], noisy_batch["trans_t"]
|
||||
@@ -213,6 +213,62 @@ def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
|
||||
symmRs = None
|
||||
return task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
|
||||
|
||||
def prepare_input_fm_allatom(inputs, interpolant, xyz_converter, device="cpu"):
|
||||
task, item, network_input, true_crds, \
|
||||
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
|
||||
= prepare_input(inputs, xyz_converter, device)
|
||||
(
|
||||
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
|
||||
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
|
||||
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
|
||||
) = inputs
|
||||
B, _, N, L = msa.shape
|
||||
|
||||
idx_pdb = idx_pdb.to(device, non_blocking=True) # (B, L)
|
||||
true_crds = true_crds.to(device, non_blocking=True) # (B, L, 27, 3)
|
||||
mask_crds = mask_crds.to(device, non_blocking=True) # (B, L, 27)
|
||||
same_chain = same_chain.to(device, non_blocking=True)
|
||||
t1d = t1d.to(device, non_blocking=True)
|
||||
seq = seq.to(device, non_blocking=True)
|
||||
msa = msa.to(device, non_blocking=True)
|
||||
atom_frames = atom_frames.to(device, non_blocking=True)
|
||||
|
||||
interpolant.set_device(device)
|
||||
allatom_mask = ChemData().allatom_mask.to(device, non_blocking=True)
|
||||
|
||||
# remove symmetry dimension
|
||||
if len(true_crds.shape) == 5:
|
||||
true_crds = true_crds[:, 0:1]
|
||||
mask_crds = mask_crds[:, 0:1]
|
||||
else:
|
||||
true_crds = true_crds[None]
|
||||
mask_crds = mask_crds[None]
|
||||
|
||||
# want to unroll the coordinate tensors to get the full coordinates in (atoms, 3)
|
||||
is_real_atom = allatom_mask[msa[:,0,0]][None].bool()
|
||||
atom_coords = true_crds[is_real_atom]
|
||||
atom_mask = mask_crds[is_real_atom]
|
||||
t = interpolant.sample_t(B)[:, None]
|
||||
#t = torch.tensor([[0.95]], device=t.device)
|
||||
|
||||
trans_1 = center_allatom_chain(atom_coords, atom_mask)
|
||||
|
||||
#fd what to do with "real" atoms that do not have a native position?
|
||||
#fd -> lets try the origin
|
||||
trans_1 = trans_1.nan_to_num()
|
||||
|
||||
diffuse_mask = torch.ones_like(atom_mask).long()
|
||||
trans_t = interpolant._corrupt_trans(trans_1[None], t, atom_mask[None], diffuse_mask[None])
|
||||
xyz_noised_t = torch.zeros((1, B, L, ChemData().NTOTAL, 3), device=device)
|
||||
xyz_noised_t[is_real_atom] = trans_t
|
||||
|
||||
# add trans_t (and t) to network input
|
||||
network_input["trans_t"] = xyz_noised_t
|
||||
network_input["t"] = t
|
||||
|
||||
return task, item, network_input, true_crds, mask_crds, msa, mask_msa, \
|
||||
unclamp, negative, symmRs, Lasu, ch_label, t, trans_1, atom_mask
|
||||
|
||||
|
||||
def construct_template_feats(xyz_t, mask_t, t1d, seq, atom_frames, xyz_converter, use_atom_frames=False):
|
||||
B, T, Lasu, _, _ = xyz_t.shape
|
||||
@@ -243,3 +299,16 @@ def construct_template_feats(xyz_t, mask_t, t1d, seq, atom_frames, xyz_converter
|
||||
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, Lasu*Osub, 3*ChemData().NTOTALDOFS)
|
||||
alpha_prev = torch.zeros((B,Lasu*Osub,ChemData().NTOTALDOFS,2))
|
||||
return t2d, mask_t_2d, alpha_t, alpha_prev
|
||||
|
||||
def center_allatom_chain(atom_coords, atom_mask):
|
||||
"""
|
||||
center allatom coordinates at origin
|
||||
"""
|
||||
assert len(atom_coords.shape) == 2
|
||||
assert len(atom_mask.shape) == 1
|
||||
assert atom_coords.shape[-1] == 3
|
||||
|
||||
atom_coords_allatom = atom_coords[atom_mask]
|
||||
atom_coords_allatom = atom_coords_allatom - atom_coords_allatom.mean(dim=0)
|
||||
atom_coords[atom_mask] = atom_coords_allatom
|
||||
return atom_coords
|
||||
@@ -28,3 +28,7 @@ def debug_grads(model):
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
print(f"{name}: {param.grad.norm().item()}")
|
||||
|
||||
def debug_nan_params(model):
|
||||
for name, param in model.named_parameters():
|
||||
print(f"{name}: {torch.sum(param.isnan())}")
|
||||
|
||||
@@ -191,7 +191,7 @@ def calculate_neighbor_angles(R_ac, R_ab):
|
||||
# sin(alpha) = |u x v| / (|u|*|v|)
|
||||
y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
|
||||
# avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
|
||||
y = torch.max(y, torch.tensor(1e-9))
|
||||
y = torch.max(y, torch.tensor(1e-4))
|
||||
angle = torch.atan2(y, x)
|
||||
return angle
|
||||
|
||||
|
||||
@@ -8,14 +8,9 @@ import pickle
|
||||
import os
|
||||
import torch
|
||||
from typing import List, Dict, Any
|
||||
#import tree
|
||||
from rf2aa.flow_matching import rigid_utils as ru
|
||||
from torch_scatter import scatter_add, scatter
|
||||
#from Bio.PDB.Chain import Chain
|
||||
#from Bio import PDB
|
||||
#from cogen.data import protein, residue_constants, parsers
|
||||
from torch_geometric.utils import scatter
|
||||
from glob import glob
|
||||
#from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
Rigid = ru.Rigid
|
||||
#Protein = protein.Protein
|
||||
@@ -140,19 +135,19 @@ def adjust_oxygen_pos(
|
||||
|
||||
# Calpha to carbonyl both in the current frame.
|
||||
calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
|
||||
torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
|
||||
torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-4
|
||||
)
|
||||
# For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0.
|
||||
# The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change.
|
||||
# For masked positions, they are all 0 and so we add 1e-4 to avoid division by 0.
|
||||
# The positions are in Angstroms and so are on the order ~1 so 1e-4 is an insignificant change.
|
||||
|
||||
# Nitrogen of the next frame to carbonyl of the current frame.
|
||||
nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
|
||||
torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
|
||||
torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-4
|
||||
)
|
||||
|
||||
carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl # (N-1, 3)
|
||||
carbonyl_to_oxygen = carbonyl_to_oxygen / (
|
||||
torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
|
||||
torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-4
|
||||
)
|
||||
|
||||
atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23
|
||||
@@ -161,17 +156,17 @@ def adjust_oxygen_pos(
|
||||
|
||||
# Calpha to carbonyl both in the current frame. (N, 3)
|
||||
calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
|
||||
torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
|
||||
torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-4
|
||||
)
|
||||
# Calpha to nitrogen both in the current frame. (N, 3)
|
||||
calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
|
||||
torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
|
||||
torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-4
|
||||
)
|
||||
carbonyl_to_oxygen_term: torch.Tensor = (
|
||||
calpha_to_carbonyl_term + calpha_to_nitrogen_term
|
||||
) # (N, 3)
|
||||
carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
|
||||
torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
|
||||
torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-4
|
||||
)
|
||||
|
||||
# Create a mask that is 1 when the next residue is not available either
|
||||
@@ -236,7 +231,7 @@ def chain_str_to_int(chain_str: str):
|
||||
#chain_feats['bb_mask'] = chain_feats['atom_mask'][:, CA_IDX]
|
||||
#bb_pos = chain_feats['atom_positions'][:, CA_IDX]
|
||||
#if center:
|
||||
#bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
|
||||
#bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-4)
|
||||
#centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
|
||||
#scaled_pos = centered_pos / scale_factor
|
||||
#else:
|
||||
@@ -341,8 +336,8 @@ def align_structures(
|
||||
reference_positions = center_zero(reference_positions, batch_indices)
|
||||
|
||||
# Compute covariance matrix for optimal rotation (Q.T @ P) -> [B x 3 x 3].
|
||||
cov = scatter_add(
|
||||
batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
|
||||
cov = scatter(
|
||||
batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0, reduce="sum"
|
||||
)
|
||||
|
||||
# Perform singular value decomposition. (all [B x 3 x 3])
|
||||
|
||||
@@ -105,14 +105,28 @@ class Interpolant:
|
||||
aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
|
||||
batch_nm_0, batch_nm_1, mask=batch_mask
|
||||
)
|
||||
|
||||
# fd shortcut with batch 1
|
||||
if num_batch == 1:
|
||||
return aligned_nm_0
|
||||
|
||||
aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
|
||||
aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)
|
||||
|
||||
|
||||
# Compute cost matrix of aligned noise to ground truth
|
||||
batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
|
||||
|
||||
#fd ensure masked positions are zeroed out
|
||||
cost_matrix = torch.sum(
|
||||
torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
|
||||
torch.linalg.norm(
|
||||
batch_mask[...,None] * (
|
||||
aligned_nm_0.nan_to_num() -
|
||||
aligned_nm_1.nan_to_num()
|
||||
), dim=-1
|
||||
), dim=-1
|
||||
) / torch.sum(batch_mask, dim=-1)
|
||||
|
||||
#fd gpu->cpu slowdown
|
||||
noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
|
||||
return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]
|
||||
|
||||
|
||||
@@ -564,7 +564,7 @@ class Rotation:
|
||||
else:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
def get_rotvec(self, eps=1e-6) -> torch.Tensor:
|
||||
def get_rotvec(self, eps=1e-4) -> torch.Tensor:
|
||||
"""
|
||||
Return the underlying axis-angle rotation vector.
|
||||
|
||||
@@ -1265,7 +1265,7 @@ class Rigid:
|
||||
p_neg_x_axis: torch.Tensor,
|
||||
origin: torch.Tensor,
|
||||
p_xy_plane: torch.Tensor,
|
||||
eps: float = 1e-8
|
||||
eps: float = 1e-4
|
||||
):
|
||||
"""
|
||||
Implements algorithm 21. Constructs transformations from sets of 3
|
||||
|
||||
@@ -4,8 +4,10 @@ from typing import Any, Dict, Tuple
|
||||
from rf2aa.flow_matching.interpolant import _centered_gaussian, _uniform_so3
|
||||
import rf2aa.flow_matching.data_utils as du
|
||||
from rf2aa.flow_matching import data_transforms
|
||||
from rf2aa.training.recycling import recycle_step_packed
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input_fm, construct_template_feats
|
||||
from rf2aa.training.recycling import recycle_step_packed, recycle_step_gen
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input_fm, construct_template_feats, prepare_input_fm_allatom
|
||||
from rf2aa.training.recycling import unpack_outputs
|
||||
from rf2aa.util import rigid_from_3_points, writepdb_file
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ class Sampler:
|
||||
self.xyz_converter = xyz_converter
|
||||
self.is_training = is_training
|
||||
|
||||
def sample(self, inputs: Tuple[str, Any]) -> Dict[str, Any]:
|
||||
def sample(self, inputs: Tuple[str, Any], use_amp=False) -> Dict[str, Any]:
|
||||
# first receive inputs from dataloader
|
||||
# convert them into features
|
||||
network_input = self._get_network_input(inputs)
|
||||
@@ -37,7 +39,7 @@ class Sampler:
|
||||
updated_features = self._construct_xt_features(rotmats_t_1, trans_t_1, network_input)
|
||||
network_input.update(updated_features)
|
||||
# run model
|
||||
output_i = recycle_step_packed(self.model, network_input,1, use_amp=True, nograds=True)
|
||||
output_i = recycle_step_packed(self.model, network_input,1, use_amp=use_amp, nograds=True)
|
||||
xyz = output_i[5][-1]
|
||||
N, Ca, C = xyz[...,0, :], xyz[...,1, :], xyz[...,2, :]
|
||||
px1s.append(xyz)
|
||||
@@ -48,13 +50,13 @@ class Sampler:
|
||||
rotmats_t_1, trans_t_1, d_t, t_1)
|
||||
# set prev_rots, prev_trans to curr
|
||||
rotmats_t_1, trans_t_1 = rotmats_t_2, trans_t_2
|
||||
t_1 = t_2
|
||||
#
|
||||
# only integrated to min_t, still need to take final step
|
||||
|
||||
updated_features = self._construct_xt_features(rotmats_t_1, trans_t_1, network_input)
|
||||
network_input.update(updated_features)
|
||||
# run model
|
||||
output_i = recycle_step_packed(self.model, network_input,1, use_amp=True, nograds=True)
|
||||
output_i = recycle_step_packed(self.model, network_input,1, use_amp=use_amp, nograds=True)
|
||||
# return the updated positions
|
||||
mask = torch.ones(xyz.shape[:-1], device=xyz.device).bool()
|
||||
mask = F.pad(mask, (0,33))
|
||||
@@ -113,4 +115,88 @@ class Sampler:
|
||||
for i, xyz in enumerate(xyz_list):
|
||||
writepdb_file(f, xyz, seq, modelnum=i, atom_mask=mask)
|
||||
|
||||
|
||||
|
||||
class AllAtomSampler(Sampler):
|
||||
""" sampler for model which predicts all atom positions, not frames/torsions """
|
||||
def __init__(self, model, num_timesteps, min_t, interpolant, xyz_converter, is_training) -> None:
|
||||
super().__init__(model, num_timesteps, min_t, interpolant, xyz_converter, is_training)
|
||||
self.allatom_mask = ChemData().allatom_mask.to(self.device)
|
||||
|
||||
def sample(self, inputs: Tuple[str, Any], n_cycle=1, use_amp=False) -> Dict[str, Any]:
|
||||
# first receive inputs from dataloader
|
||||
# convert them into features
|
||||
network_input = self._get_network_input(inputs)
|
||||
ts = torch.linspace(self.min_t, 1.0, self.num_timesteps)
|
||||
# create prior
|
||||
seq_unmasked = network_input["seq_unmasked"].to(self.device)
|
||||
trans_t_1 = self._setup_prior(seq_unmasked)
|
||||
network_input["trans_t"] = trans_t_1[None]
|
||||
# run first model fwd pass to get evoformer features
|
||||
output_i = recycle_step_gen(self.model, network_input, n_cycle, use_amp=use_amp, nograds=True)
|
||||
latent_feats = {
|
||||
"msa": output_i[-3],
|
||||
"pair": output_i[-2],
|
||||
"state": output_i[-1],
|
||||
"seq_unmasked": seq_unmasked,
|
||||
"dist_matrix": network_input["dist_matrix"].to(self.device),
|
||||
"idx": network_input["idx"].to(self.device),
|
||||
"trans_t": trans_t_1[None],
|
||||
"t": network_input["t"]
|
||||
}
|
||||
output_i_trunk = output_i
|
||||
# run the model refinement for n_steps
|
||||
output_i, px1, xts = self.run_refiner(latent_feats, ts)
|
||||
# HACK: get features from evoformer, this needs to become a dictionary to allow for assignment
|
||||
output_i = list(output_i)
|
||||
for i in range(len(output_i)):
|
||||
if output_i[i] is None:
|
||||
output_i[i] = output_i_trunk[i]
|
||||
return tuple(output_i)
|
||||
|
||||
def _setup_prior(self, seq_unmasked):
|
||||
B, L = seq_unmasked.shape
|
||||
xyz = torch.zeros(B, L, ChemData().NTOTAL, 3, device=self.device)
|
||||
is_real_atom = self.allatom_mask[seq_unmasked]
|
||||
num_atoms = is_real_atom.sum()
|
||||
xyz[is_real_atom] = _centered_gaussian(B, num_atoms, self.device) * du.NM_TO_ANG_SCALE
|
||||
return xyz
|
||||
|
||||
def _get_network_input(self, inputs):
|
||||
out = prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, device=self.device)
|
||||
network_input = out[2]
|
||||
return network_input
|
||||
|
||||
def _take_step(self, pred_trans_1, trans_t_1, d_t, t_1, seq_unmasked):
|
||||
is_real_atom = self.allatom_mask[seq_unmasked]
|
||||
trans_t_2_rolled = pred_trans_1.clone()
|
||||
pred_trans_1 = pred_trans_1[is_real_atom]
|
||||
trans_t_1 = trans_t_1[is_real_atom]
|
||||
trans_t_2 = self.interpolant._trans_euler_step(
|
||||
d_t, t_1, pred_trans_1, trans_t_1)
|
||||
trans_t_2_rolled[is_real_atom] = trans_t_2
|
||||
return trans_t_2_rolled
|
||||
|
||||
def run_refiner(self, latent_feats, ts):
|
||||
px1s = []
|
||||
xts = []
|
||||
t_1 = ts[0]
|
||||
trans_t_1 = latent_feats["trans_t"][0]
|
||||
for t_2 in ts[1:]:
|
||||
d_t = t_2-t_1
|
||||
# collect features for each step
|
||||
pred_trans_1 = self._run_diffusion_step(latent_feats)
|
||||
xts.append(trans_t_1)
|
||||
trans_t_2 = self._take_step(pred_trans_1, trans_t_1, d_t, t_1, seq_unmasked=latent_feats["seq_unmasked"])
|
||||
trans_t_1 = trans_t_2
|
||||
latent_feats["trans_t"] = trans_t_1[None]
|
||||
t_1 = t_2
|
||||
outputs = {}
|
||||
outputs["xyz"] = pred_trans_1
|
||||
outputs["state"] = latent_feats["state"]
|
||||
output_i = unpack_outputs(outputs, latent_feats, return_raw=False)
|
||||
return output_i, px1s, xts
|
||||
|
||||
def _run_diffusion_step(self, latent_feats):
|
||||
outputs = self.model.module.model.refinement(latent_feats)
|
||||
pred_trans_1 = outputs["xyz"]
|
||||
return pred_trans_1
|
||||
|
||||
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def scale_rotmat(
|
||||
rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-7
|
||||
rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scale rotation matrix. This is done by converting it to vector representation,
|
||||
@@ -83,7 +83,7 @@ def skew_matrix_exponential_map_axis_angle(
|
||||
|
||||
|
||||
def skew_matrix_exponential_map(
|
||||
angles: torch.Tensor, skew_matrices: torch.Tensor, tol=1e-7
|
||||
angles: torch.Tensor, skew_matrices: torch.Tensor, tol=1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the matrix exponential of a rotation vector in skew matrix representation. Maps the
|
||||
@@ -137,7 +137,7 @@ def skew_matrix_exponential_map(
|
||||
return exp_skew
|
||||
|
||||
|
||||
def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
|
||||
def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotation vectors to rotation matrix representation. The length of the rotation vector
|
||||
is the angle of rotation, the unit vector the rotation axis.
|
||||
@@ -232,7 +232,7 @@ def rotmat_to_rotvec(rotation_matrices: torch.Tensor) -> torch.Tensor:
|
||||
skew_outer = skew_outer + (torch.relu(skew_outer) - skew_outer) * id3
|
||||
|
||||
# Get basic rotation vector as sqrt of diagonal (is unit vector).
|
||||
vector_pi = torch.sqrt(torch.diagonal(torch.clamp(skew_outer, min=1e-8), dim1=-2, dim2=-1))
|
||||
vector_pi = torch.sqrt(torch.diagonal(torch.clamp(skew_outer, min=1e-4), dim1=-2, dim2=-1))
|
||||
|
||||
# Compute the signs of vector elements (up to a global phase).
|
||||
# Fist select indices for outer product slices with the largest norm.
|
||||
@@ -326,7 +326,7 @@ def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def _rotquat_to_axis_angle(
|
||||
rotation_quaternions: torch.Tensor, tol: float = 1e-7
|
||||
rotation_quaternions: torch.Tensor, tol: float = 1e-4
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Auxiliary routine for computing rotation angle and rotation axis from unit quaternions. To avoid
|
||||
@@ -334,7 +334,7 @@ def _rotquat_to_axis_angle(
|
||||
|
||||
Args:
|
||||
rotation_quaternions (torch.Tensor): Rotation quaternions in [r, i, j, k] format.
|
||||
tol (float, optional): Threshold for small rotations. Defaults to 1e-7.
|
||||
tol (float, optional): Threshold for small rotations. Defaults to 1e-4.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Rotation angles and axes.
|
||||
@@ -385,7 +385,7 @@ def rotquat_to_rotmat(rotation_quaternions: torch.Tensor) -> torch.Tensor:
|
||||
def apply_rotvec_to_rotmat(
|
||||
rotation_matrices: torch.Tensor,
|
||||
rotation_vectors: torch.Tensor,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Update a rotation encoded in a rotation matrix with a rotation vector.
|
||||
@@ -602,7 +602,7 @@ class BaseSampleSO3(nn.Module):
|
||||
num_omega: int,
|
||||
sigma_grid: torch.Tensor,
|
||||
omega_exponent: int = 3,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
interpolate: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
overwrite_cache: bool = False,
|
||||
@@ -625,7 +625,7 @@ class BaseSampleSO3(nn.Module):
|
||||
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
|
||||
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
|
||||
its power with the provided number. Defaults to 3.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
|
||||
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
|
||||
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
|
||||
cache_dir: Path to an optional cache directory. If set to None, lookup tables are
|
||||
@@ -907,7 +907,7 @@ class SampleIGSO3(BaseSampleSO3):
|
||||
num_omega: int,
|
||||
sigma_grid: torch.Tensor,
|
||||
omega_exponent: int = 3,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
interpolate: bool = True,
|
||||
l_max: int = 1000,
|
||||
cache_dir: Optional[str] = None,
|
||||
@@ -931,7 +931,7 @@ class SampleIGSO3(BaseSampleSO3):
|
||||
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
|
||||
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
|
||||
its power with the provided number. Defaults to 3.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
|
||||
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
|
||||
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
|
||||
l_max (int, optional): Maximum number of terms used in the series expansion.
|
||||
@@ -1022,7 +1022,7 @@ class SampleUSO3(BaseSampleSO3):
|
||||
num_omega: int,
|
||||
sigma_grid: torch.Tensor,
|
||||
omega_exponent: int = 3,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
interpolate: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
overwrite_cache: bool = False,
|
||||
@@ -1044,7 +1044,7 @@ class SampleUSO3(BaseSampleSO3):
|
||||
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
|
||||
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
|
||||
its power with the provided number. Defaults to 3.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
|
||||
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
|
||||
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
|
||||
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
|
||||
cache_dir: Path to an optional cache directory. If set to None, lookup tables are
|
||||
@@ -1122,7 +1122,7 @@ class ScoreSO3(nn.Module):
|
||||
sigma_grid: torch.Tensor,
|
||||
omega_exponent: int = 3,
|
||||
l_max: int = 1000,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
cache_dir: Optional[str] = None,
|
||||
overwrite_cache: bool = False,
|
||||
) -> None:
|
||||
@@ -1359,7 +1359,7 @@ def uniform_so3_density(omega: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def igso3_expansion(
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the IGSO(3) angle probability distribution function for pairs of angles and std dev
|
||||
@@ -1418,7 +1418,7 @@ def igso3_expansion(
|
||||
|
||||
|
||||
def digso3_expansion(
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the derivative of the IGSO(3) angle probability distribution function with respect to
|
||||
@@ -1477,7 +1477,7 @@ def digso3_expansion(
|
||||
|
||||
|
||||
def dlog_igso3_expansion(
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
|
||||
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the derivative of the logarithm of the IGSO(3) angle distribution function for pairs of
|
||||
@@ -1509,7 +1509,7 @@ def generate_lookup_table(
|
||||
omega_grid: torch.Tensor,
|
||||
sigma_grid: torch.Tensor,
|
||||
l_max: int = 1000,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
):
|
||||
"""
|
||||
Auxiliary function for generating a lookup table from IGSO(3) expansions and their derivatives.
|
||||
@@ -1550,7 +1550,7 @@ def generate_igso3_lookup_table(
|
||||
omega_grid: torch.Tensor,
|
||||
sigma_grid: torch.Tensor,
|
||||
l_max: int = 1000,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate a lookup table for the IGSO(3) probability distribution function of angles.
|
||||
@@ -1579,7 +1579,7 @@ def generate_dlog_igso3_lookup_table(
|
||||
omega_grid: torch.Tensor,
|
||||
sigma_grid: torch.Tensor,
|
||||
l_max: int = 1000,
|
||||
tol: float = 1e-7,
|
||||
tol: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate a lookup table for the derivative of the logarithm of the angular IGSO(3) probability
|
||||
@@ -1650,7 +1650,7 @@ def Exp(A): return exp(hat(A))
|
||||
|
||||
|
||||
# Angle of rotation SO(3) to R^+, this is the norm in our chosen orthonormal basis
|
||||
def Omega(R, eps=1e-6):
|
||||
def Omega(R, eps=1e-4):
|
||||
# multiplying by (1-epsilon) prevents instability of arccos when provided with -1 or 1 as input.
|
||||
R_ = R.to(torch.float64)
|
||||
assert not torch.any(torch.abs(R) > 1.1)
|
||||
|
||||
@@ -17,6 +17,8 @@ from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.kinematics import get_dih, get_ang
|
||||
from rf2aa.scoring import HbHybType
|
||||
|
||||
from rf2aa.flow_matching import data_utils as du
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1655,28 +1657,70 @@ def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain,
|
||||
|
||||
lddt = torch.cat(lddt_s, dim=1) # (N, L, Natm)
|
||||
|
||||
# per-res
|
||||
final_lddt_by_res = torch.clamp(
|
||||
(lddt[-1]*atm_mask[0]).sum(-1)
|
||||
/ (atm_mask.sum(-1) + eps), min=0.0, max=1.0)
|
||||
|
||||
# calculate lddt prediction loss
|
||||
nbin = pred_lddt.shape[1]
|
||||
bin_step = 1.0 / nbin
|
||||
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
|
||||
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
|
||||
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
|
||||
pred_lddt, true_lddt_label[-1])
|
||||
|
||||
res_mask = atm_mask.any(dim=-1)
|
||||
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
|
||||
|
||||
# method 1: average per-residue
|
||||
#lddt = lddt.sum(dim=-1) / (atm_mask.sum(dim=-1)+1e-4) # L
|
||||
#lddt = (res_mask*lddt).sum() / (res_mask.sum() + 1e-4)
|
||||
|
||||
# method 2: average per-atom
|
||||
# per-struct
|
||||
atm_mask = atm_mask * (pair_mask_accum != 0)
|
||||
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
|
||||
|
||||
# calculate lddt prediction loss
|
||||
if pred_lddt is not None:
|
||||
nbin = pred_lddt.shape[1]
|
||||
bin_step = 1.0 / nbin
|
||||
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
|
||||
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
|
||||
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
|
||||
pred_lddt, true_lddt_label[-1])
|
||||
|
||||
res_mask = atm_mask.any(dim=-1)
|
||||
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
|
||||
else:
|
||||
lddt_loss = None # no pred lddt provided
|
||||
|
||||
|
||||
return lddt_loss, lddt
|
||||
|
||||
def rms_aln_tgt(predin, true, mask):
|
||||
def centroid(X):
|
||||
return X.mean(dim=-2, keepdim=True)
|
||||
|
||||
pred = predin[mask]
|
||||
true = true[mask]
|
||||
pred = pred - centroid(pred)
|
||||
cT = centroid(true)
|
||||
true = true - cT
|
||||
C = torch.einsum('ji,jk->ik', pred, true)
|
||||
V, S, W = torch.svd(C)
|
||||
d = torch.ones([3,3], device=pred.device)
|
||||
d[:,-1] = torch.sign(torch.det(V)*torch.det(W)).unsqueeze(0)
|
||||
U = torch.matmul(d*V, W.permute(1,0)) # (3, 3)
|
||||
|
||||
rpred = torch.matmul(predin, U) + cT
|
||||
return rpred
|
||||
|
||||
|
||||
def translation_vector_field(pred_trans_1, noaln_gt_trans_1, mask, r3_t, params):
|
||||
gt_trans_1 = rms_aln_tgt(noaln_gt_trans_1, pred_trans_1.detach(), mask)
|
||||
|
||||
return translation_vector_field_noaln(pred_trans_1, gt_trans_1, mask, r3_t, params)
|
||||
|
||||
|
||||
def translation_vector_field_noaln(pred_trans_1, gt_trans_1, mask, r3_t, params):
|
||||
t_normalize_clip = params.t_normalize_clip
|
||||
trans_scale = params.trans_scale # global scale
|
||||
|
||||
t_dep_scale = 1 - torch.min( # t-dependant scale
|
||||
r3_t[..., None], torch.tensor(t_normalize_clip)
|
||||
) # (B, 1, 1)
|
||||
trans_error = trans_scale / t_dep_scale * (gt_trans_1 - pred_trans_1)
|
||||
|
||||
loss_denom = 3 * torch.sum(mask)
|
||||
trans_loss = torch.sum(
|
||||
trans_error*trans_error*mask[...,None], dim=(-1,-2)
|
||||
) / loss_denom
|
||||
|
||||
return trans_loss
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins
|
||||
from rf2aa.loss.loss import resolve_equiv_natives, resolve_equiv_natives_asmb, \
|
||||
resolve_symmetry_predictions, resolve_symmetry, mask_unresolved_frames, \
|
||||
compute_general_FAPE, torsionAngleLoss, calc_lddt, calc_allatom_lddt_loss, \
|
||||
calc_crd_rmsd, calc_BB_bond_geom, calc_lj, calc_atom_bond_loss
|
||||
calc_crd_rmsd, calc_BB_bond_geom, calc_lj, calc_atom_bond_loss,translation_vector_field
|
||||
|
||||
from rf2aa.util import is_atom, is_protein, Ls_from_same_chain_2d, get_prot_sm_mask, \
|
||||
xyz_to_frame_xyz, get_frames, writepdb
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
@@ -17,20 +18,16 @@ cce_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
def get_loss_and_misc(
|
||||
trainer,
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
loss_param
|
||||
):
|
||||
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
|
||||
|
||||
if pred_allatom is None:
|
||||
_, pred_allatom = trainer.xyz_converter.compute_all_atom(msa[0][0][None],pred_crds[-1], alphas[-1])
|
||||
#pred_crds = pred_crds[:, None]
|
||||
#alphas = alphas[:, None]
|
||||
|
||||
if (symmRs is not None):
|
||||
###
|
||||
# resolve symmetry
|
||||
###
|
||||
if symmRs is not None:
|
||||
true_crds = true_crds[:,0]
|
||||
atom_mask = atom_mask[:,0]
|
||||
mapT2P = resolve_symmetry_predictions(pred_crds, true_crds, atom_mask, Lasu) # (Nlayer, Ltrue)
|
||||
@@ -88,230 +85,233 @@ def get_loss_and_misc(
|
||||
c6d = xyz_to_c6d(true_crds_frame)
|
||||
c6d = c6d_to_bins(c6d, same_chain, negative=negative)
|
||||
|
||||
# contact accuray not as useful to track anymore
|
||||
#prob = self.active_fn(logit_s[0]) # distogram
|
||||
#acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)
|
||||
loss, loss_dict = calc_loss(
|
||||
trainer, logit_s, c6d,
|
||||
trainer, loss_param, logit_s, c6d,
|
||||
logit_aa_s, msa, mask_msa, logit_pae, logit_pde, p_bind,
|
||||
pred_crds, alphas, pred_allatom, true_crds,
|
||||
pred_crds, alphas, pred_allatom, trans_1, r3_t, true_crds,
|
||||
atom_mask, res_mask, mask_2d, same_chain,
|
||||
pred_lddts, idx_pdb, bond_feats, dist_matrix,
|
||||
atom_frames=atom_frames,unclamp=unclamp, negative=negative,
|
||||
item=item, task=task, **loss_param
|
||||
item=item, task=task
|
||||
)
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
def calc_loss(trainer, logit_s, label_s,
|
||||
logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
|
||||
pred, pred_tors, pred_allatom, true,
|
||||
mask_crds, mask_BB, mask_2d, same_chain,
|
||||
pred_lddt, idx, bond_feats, dist_matrix, atom_frames=None, unclamp=False,
|
||||
negative=False, interface=False,
|
||||
w_dist=1.0, w_aa=1.0, w_str=1.0, w_inter_fape=0.0, w_lig_fape=1.0, w_lddt=1.0,
|
||||
w_bond=1.0, w_clash=0.0, w_atom_bond=0.0, w_skip_bond=0.0, w_rigid=0.0, w_hb=0.0, w_bind=0.0,
|
||||
w_pae=0.0, w_pde=0.0, lj_lin=0.85, eps=1e-4, binder_loss_label_smoothing = 0.0, item=None, task=None, out_dir='./'
|
||||
):
|
||||
gpu = pred.device
|
||||
def calc_loss(
|
||||
trainer, loss_param, logit_s, label_s,
|
||||
logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
|
||||
pred, pred_tors, pred_allatom, trans_1, r3_t, true,
|
||||
mask_crds, mask_BB, mask_2d, same_chain,
|
||||
pred_lddt, idx, bond_feats, dist_matrix,
|
||||
atom_frames=None, unclamp=False, negative=False, interface=False, item=None, task=None,
|
||||
eps=1e-4, binder_loss_label_smoothing = 0.0, out_dir='./'
|
||||
):
|
||||
# fd: force these to be specified in config
|
||||
w_dist, w_aa, w_str = loss_param["w_dist"], loss_param["w_aa"], loss_param["w_str"]
|
||||
w_trans, w_inter_fape, w_lig_fape = loss_param["w_trans"], loss_param["w_inter_fape"], loss_param["w_lig_fape"]
|
||||
w_lddt, w_bond, w_clash = loss_param["w_lddt"], loss_param["w_bond"], loss_param["w_clash"]
|
||||
w_atom_bond, w_skip_bond, w_rigid = loss_param["w_atom_bond"], loss_param["w_skip_bond"], loss_param["w_rigid"]
|
||||
w_hb, w_bind, w_pae = loss_param["w_hb"], loss_param["w_bind"], loss_param["w_pae"]
|
||||
w_pde, lj_lin, w_pae = loss_param["w_pde"], loss_param["lj_lin"], loss_param["w_pae"]
|
||||
|
||||
# track losses for printing to local log and uploading to WandB
|
||||
loss_dict = OrderedDict()
|
||||
# track losses for printing to local log and uploading to WandB
|
||||
loss_dict = OrderedDict()
|
||||
|
||||
B, L, natoms = true.shape[:3]
|
||||
seq = label_aa_s[:,0].clone()
|
||||
B, L, natoms = true.shape[:3]
|
||||
seq = label_aa_s[:,0].clone()
|
||||
|
||||
assert (B==1) # fd - code assumes a batch size of 1
|
||||
assert (B==1) # fd - code assumes a batch size of 1
|
||||
|
||||
tot_loss = 0.0
|
||||
# set up frames
|
||||
frames, frame_mask = get_frames(
|
||||
pred_allatom[-1,None,...], mask_crds, seq, trainer.fi_dev, atom_frames)
|
||||
tot_loss = 0.0
|
||||
# set up frames
|
||||
frames, frame_mask = get_frames(
|
||||
pred_allatom[-1,None,...], mask_crds, seq, trainer.fi_dev, atom_frames)
|
||||
|
||||
# update frames and frames_mask to only include BB frames (have to update both for compatibility with compute_general_FAPE)
|
||||
frames_BB = frames.clone()
|
||||
frames_BB[..., 1:, :, :] = 0
|
||||
frame_mask_BB = frame_mask.clone()
|
||||
frame_mask_BB[...,1:] =False
|
||||
# update frames and frames_mask to only include BB frames (have to update both for compatibility with compute_general_FAPE)
|
||||
frames_BB = frames.clone()
|
||||
frames_BB[..., 1:, :, :] = 0
|
||||
frame_mask_BB = frame_mask.clone()
|
||||
frame_mask_BB[...,1:] =False
|
||||
|
||||
# c6d loss
|
||||
for i in range(4):
|
||||
loss = cce_loss(logit_s[i], label_s[...,i]) # (B, L, L)
|
||||
if i==0: # apply distogram loss to all residue pairs with valid BB atoms
|
||||
mask_2d_ = mask_2d
|
||||
else:
|
||||
# apply anglegram loss only when both residues have valid BB frames (i.e. not metal ions, and not examples with unresolved atoms in frames)
|
||||
_, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds) # (1, L, nframes)
|
||||
bb_frame_good = bb_frame_good[...,0] # (1,L)
|
||||
loss_mask_2d = bb_frame_good & bb_frame_good[...,None]
|
||||
mask_2d_ = mask_2d & loss_mask_2d
|
||||
# c6d loss
|
||||
for i in range(4):
|
||||
loss = cce_loss(logit_s[i], label_s[...,i]) # (B, L, L)
|
||||
if i==0: # apply distogram loss to all residue pairs with valid BB atoms
|
||||
mask_2d_ = mask_2d
|
||||
else:
|
||||
# apply anglegram loss only when both residues have valid BB frames (i.e. not metal ions, and not examples with unresolved atoms in frames)
|
||||
_, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds) # (1, L, nframes)
|
||||
bb_frame_good = bb_frame_good[...,0] # (1,L)
|
||||
loss_mask_2d = bb_frame_good & bb_frame_good[...,None]
|
||||
mask_2d_ = mask_2d & loss_mask_2d
|
||||
|
||||
if negative.item():
|
||||
# Don't compute inter-chain distogram losses
|
||||
# for negative examples.
|
||||
mask_2d_ = mask_2d_ * same_chain
|
||||
if negative.item():
|
||||
# Don't compute inter-chain distogram losses
|
||||
# for negative examples.
|
||||
mask_2d_ = mask_2d_ * same_chain
|
||||
|
||||
#fd upcast loss to float to avoid overflow
|
||||
loss = (mask_2d_*loss.float()).sum() / (mask_2d_.sum() + eps)
|
||||
tot_loss += w_dist*loss
|
||||
loss_dict[f'c6d_{i}'] = loss.detach()
|
||||
#fd upcast loss to float to avoid overflow
|
||||
loss = (mask_2d_*loss.float()).sum() / (mask_2d_.sum() + eps)
|
||||
tot_loss += w_dist*loss
|
||||
loss_dict[f'c6d_{i}'] = loss.detach()
|
||||
|
||||
# masked token prediction loss
|
||||
loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
|
||||
loss = loss * mask_aa_s.reshape(B, -1)
|
||||
loss = loss.float().sum() / (mask_aa_s.sum() + 1e-4)
|
||||
tot_loss += w_aa*loss
|
||||
loss_dict['aa_cce'] = loss.detach()
|
||||
# masked token prediction loss
|
||||
loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
|
||||
loss = loss * mask_aa_s.reshape(B, -1)
|
||||
loss = loss.float().sum() / (mask_aa_s.sum() + 1e-4)
|
||||
tot_loss += w_aa*loss
|
||||
loss_dict['aa_cce'] = loss.detach()
|
||||
|
||||
# col 4: binder loss
|
||||
# only apply binding loss to complexes
|
||||
# note that this will apply loss to positive sets w/o a corresponding negative set
|
||||
# (e.g., homomers). Maybe want to change this?
|
||||
if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer =="legacy":
|
||||
if (torch.sum(same_chain==0) > 0):
|
||||
bce = torch.nn.BCELoss()
|
||||
target = torch.tensor(
|
||||
[abs(float(not negative) - binder_loss_label_smoothing)],
|
||||
device=p_bind.device
|
||||
)
|
||||
loss = bce(p_bind,target)
|
||||
else:
|
||||
# avoid unused parameter error
|
||||
loss = 0.0 * p_bind.sum()
|
||||
|
||||
tot_loss += w_bind * loss
|
||||
loss_dict['binder_bce_loss'] = loss.detach()
|
||||
|
||||
|
||||
### GENERAL LAYERS
|
||||
# Structural loss (layer-wise backbone FAPE)
|
||||
dclamp = 300.0 if unclamp else 30.0 # protein & NA FAPE distance cutoffs
|
||||
dclamp_sm, Z_sm = 4, 4 # sm mol FAPE distance cutoffs
|
||||
dclamp_prot = 10
|
||||
# residue mask for FAPE calculation only masks unresolved protein backbone atoms
|
||||
# whereas other losses also maks unresolved ligand atoms (mask_BB)
|
||||
# frames with unresolved ligand atoms are masked in compute_general_FAPE
|
||||
res_mask = ~((mask_crds[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))
|
||||
|
||||
# create 2d masks for intrachain and interchain fape calculations
|
||||
nframes = frame_mask.shape[-1]
|
||||
frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool() # B, L, nframes, L, natoms
|
||||
frame_atom_mask_2d = frame_atom_mask_2d_allatom[:, :, :, :, :3]
|
||||
frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :,None, :, None].bool().expand(-1,-1,nframes,-1, ChemData().NTOTAL)
|
||||
frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]
|
||||
different_chain = ~same_chain.bool()
|
||||
frame_atom_mask_2d_inter = frame_atom_mask_2d*different_chain[:, :,None, :, None].expand(-1,-1,nframes,-1, 3)
|
||||
|
||||
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
|
||||
tot_str = 0.0 * pred.sum(axis=(1,2,3,4))
|
||||
pae_loss = 0.0 * logit_pae.sum()
|
||||
pde_loss = 0.0 * logit_pde.sum()
|
||||
elif negative: # inter-chain fapes should be ignored for negative cases
|
||||
if logit_pae is not None:
|
||||
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
if logit_pde is not None:
|
||||
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
|
||||
tot_str, pae_loss, pde_loss = compute_general_FAPE(
|
||||
pred[:,res_mask,:,:3],
|
||||
true[:,res_mask[0],:3],
|
||||
mask_crds[:,res_mask[0],:3],
|
||||
frames_BB[:,res_mask[0]],
|
||||
frame_mask_BB[:,res_mask[0]],
|
||||
frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
|
||||
dclamp=dclamp,
|
||||
logit_pae=logit_pae,
|
||||
logit_pde=logit_pde,
|
||||
# col 4: binder loss
|
||||
# only apply binding loss to complexes
|
||||
# note that this will apply loss to positive sets w/o a corresponding negative set
|
||||
# (e.g., homomers). Maybe want to change this?
|
||||
if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer =="legacy":
|
||||
if (torch.sum(same_chain==0) > 0):
|
||||
bce = torch.nn.BCELoss()
|
||||
target = torch.tensor(
|
||||
[abs(float(not negative) - binder_loss_label_smoothing)],
|
||||
device=p_bind.device
|
||||
)
|
||||
|
||||
loss = bce(p_bind,target)
|
||||
else:
|
||||
# avoid unused parameter error
|
||||
loss = 0.0 * p_bind.sum()
|
||||
|
||||
if logit_pae is not None:
|
||||
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
if logit_pde is not None:
|
||||
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
tot_loss += w_bind * loss
|
||||
loss_dict['binder_bce_loss'] = loss.detach()
|
||||
|
||||
# change clamp for intra protein to 10, leave rest at 30
|
||||
dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
|
||||
if not unclamp:
|
||||
is_prot = is_protein(seq) # (1,L)
|
||||
same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1,1,nframes,1, natoms)
|
||||
# zero out rows and columns with small molecules
|
||||
same_chain_clamp_mask[:, ~is_prot[0]] = 0
|
||||
same_chain_clamp_mask[:,:, :, ~is_prot[0]] = 0
|
||||
dclamp_2d *= ~same_chain_clamp_mask.bool()
|
||||
dclamp_2d += same_chain_clamp_mask*dclamp_prot
|
||||
|
||||
tot_str, pae_loss, pde_loss = compute_general_FAPE(
|
||||
pred[:,res_mask,:,:3],
|
||||
true[:,res_mask[0],:3],
|
||||
mask_crds[:,res_mask[0],:3],
|
||||
frames_BB[:,res_mask[0]],
|
||||
frame_mask_BB[:,res_mask[0]],
|
||||
dclamp=None,
|
||||
dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0],:3],
|
||||
logit_pae=logit_pae,
|
||||
logit_pde=logit_pde,
|
||||
)
|
||||
### GENERAL LAYERS
|
||||
# Structural loss (layer-wise backbone FAPE)
|
||||
dclamp = 300.0 if unclamp else 30.0 # protein & NA FAPE distance cutoffs
|
||||
dclamp_sm, Z_sm = 4, 4 # sm mol FAPE distance cutoffs
|
||||
dclamp_prot = 10
|
||||
# residue mask for FAPE calculation only masks unresolved protein backbone atoms
|
||||
# whereas other losses also maks unresolved ligand atoms (mask_BB)
|
||||
# frames with unresolved ligand atoms are masked in compute_general_FAPE
|
||||
res_mask = ~((mask_crds[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))
|
||||
|
||||
# free up big intermediate data tensors
|
||||
del dclamp_2d
|
||||
if not unclamp:
|
||||
del same_chain_clamp_mask
|
||||
# create 2d masks for intrachain and interchain fape calculations
|
||||
nframes = frame_mask.shape[-1]
|
||||
frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool() # B, L, nframes, L, natoms
|
||||
frame_atom_mask_2d = frame_atom_mask_2d_allatom[:, :, :, :, :3]
|
||||
frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :,None, :, None].bool().expand(-1,-1,nframes,-1, ChemData().NTOTAL)
|
||||
frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]
|
||||
different_chain = ~same_chain.bool()
|
||||
frame_atom_mask_2d_inter = frame_atom_mask_2d*different_chain[:, :,None, :, None].expand(-1,-1,nframes,-1, 3)
|
||||
|
||||
num_layers = pred.shape[0]
|
||||
gamma = 1.0 # equal weighting of fape across all layers
|
||||
w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
|
||||
w_bb_fape = torch.flip(w_bb_fape, (0,))
|
||||
w_bb_fape = w_bb_fape / w_bb_fape.sum()
|
||||
bb_l_fape = (w_bb_fape*tot_str).sum()
|
||||
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
|
||||
tot_str = 0.0 * pred.sum(axis=(1,2,3,4))
|
||||
pae_loss = 0.0 * logit_pae.sum()
|
||||
pde_loss = 0.0 * logit_pde.sum()
|
||||
elif negative: # inter-chain fapes should be ignored for negative cases
|
||||
if logit_pae is not None:
|
||||
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
if logit_pde is not None:
|
||||
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
|
||||
tot_str, pae_loss, pde_loss = compute_general_FAPE(
|
||||
pred[:,res_mask,:,:3],
|
||||
true[:,res_mask[0],:3],
|
||||
mask_crds[:,res_mask[0],:3],
|
||||
frames_BB[:,res_mask[0]],
|
||||
frame_mask_BB[:,res_mask[0]],
|
||||
frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
|
||||
dclamp=dclamp,
|
||||
logit_pae=logit_pae,
|
||||
logit_pde=logit_pde,
|
||||
)
|
||||
|
||||
tot_loss += 0.5*w_str*bb_l_fape
|
||||
for i in range(len(tot_str)):
|
||||
loss_dict[f'bb_fape_layer{i}'] = tot_str[i].detach()
|
||||
loss_dict['bb_fape_full'] = bb_l_fape.detach()
|
||||
else:
|
||||
|
||||
tot_loss += w_pae*pae_loss + w_pde*pde_loss
|
||||
loss_dict['pae_loss'] = pae_loss.detach()
|
||||
loss_dict['pde_loss'] = pde_loss.detach()
|
||||
if logit_pae is not None:
|
||||
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
if logit_pde is not None:
|
||||
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
|
||||
|
||||
## small-molecule ligands
|
||||
sm_res_mask = is_atom(label_aa_s[0,0])*res_mask[0] # (L,)
|
||||
# change clamp for intra protein to 10, leave rest at 30
|
||||
dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
|
||||
if not unclamp:
|
||||
is_prot = is_protein(seq) # (1,L)
|
||||
same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1,1,nframes,1, natoms)
|
||||
# zero out rows and columns with small molecules
|
||||
same_chain_clamp_mask[:, ~is_prot[0]] = 0
|
||||
same_chain_clamp_mask[:,:, :, ~is_prot[0]] = 0
|
||||
dclamp_2d *= ~same_chain_clamp_mask.bool()
|
||||
dclamp_2d += same_chain_clamp_mask*dclamp_prot
|
||||
|
||||
## AllAtom loss
|
||||
# get ground-truth torsion angles
|
||||
true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
|
||||
true, seq, mask_in=mask_crds)
|
||||
tors_mask *= mask_BB[...,None]
|
||||
tot_str, pae_loss, pde_loss = compute_general_FAPE(
|
||||
pred[:,res_mask,:,:3],
|
||||
true[:,res_mask[0],:3],
|
||||
mask_crds[:,res_mask[0],:3],
|
||||
frames_BB[:,res_mask[0]],
|
||||
frame_mask_BB[:,res_mask[0]],
|
||||
dclamp=None,
|
||||
dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0],:3],
|
||||
logit_pae=logit_pae,
|
||||
logit_pde=logit_pde,
|
||||
)
|
||||
|
||||
# get alternative coordinates for ground-truth
|
||||
true_alt = torch.zeros_like(true)
|
||||
true_alt.scatter_(2, trainer.l2a[seq,:,None].repeat(1,1,1,3), true)
|
||||
natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[...,:3,:], true_tors)
|
||||
natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[...,:3,:], true_tors_alt)
|
||||
predTs = pred[-1,...]
|
||||
predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])
|
||||
# free up big intermediate data tensors
|
||||
del dclamp_2d
|
||||
if not unclamp:
|
||||
del same_chain_clamp_mask
|
||||
|
||||
# - resolve symmetry
|
||||
xs_mask = trainer.aamask[seq] # (B, L, 27)
|
||||
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
|
||||
xs_mask *= mask_crds # mask missing atoms & residues as well
|
||||
natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
|
||||
num_layers = pred.shape[0]
|
||||
gamma = 1.0 # equal weighting of fape across all layers
|
||||
w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
|
||||
w_bb_fape = torch.flip(w_bb_fape, (0,))
|
||||
w_bb_fape = w_bb_fape / w_bb_fape.sum()
|
||||
bb_l_fape = (w_bb_fape*tot_str).sum()
|
||||
|
||||
# torsion angle loss
|
||||
l_tors = torsionAngleLoss(
|
||||
pred_tors,
|
||||
true_tors,
|
||||
true_tors_alt,
|
||||
tors_mask,
|
||||
tors_planar,
|
||||
eps = 1e-4)
|
||||
tot_loss += w_str*l_tors
|
||||
loss_dict['torsion'] = l_tors.detach()
|
||||
tot_loss += 0.5*w_str*bb_l_fape
|
||||
for i in range(len(tot_str)):
|
||||
loss_dict[f'bb_fape_layer{i}'] = tot_str[i].detach()
|
||||
loss_dict['bb_fape_full'] = bb_l_fape.detach()
|
||||
|
||||
### FINETUNING LAYERS
|
||||
# lddts (CA)
|
||||
tot_loss += w_pae*pae_loss + w_pde*pde_loss
|
||||
loss_dict['pae_loss'] = pae_loss.detach()
|
||||
loss_dict['pde_loss'] = pde_loss.detach()
|
||||
|
||||
## small-molecule ligands
|
||||
sm_res_mask = is_atom(label_aa_s[0,0])*res_mask[0] # (L,)
|
||||
|
||||
## AllAtom loss
|
||||
# get ground-truth torsion angles
|
||||
true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
|
||||
true, seq, mask_in=mask_crds)
|
||||
tors_mask *= mask_BB[...,None]
|
||||
|
||||
# get alternative coordinates for ground-truth
|
||||
true_alt = torch.zeros_like(true)
|
||||
true_alt.scatter_(2, trainer.l2a[seq,:,None].repeat(1,1,1,3), true)
|
||||
natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[...,:3,:], true_tors)
|
||||
natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[...,:3,:], true_tors_alt)
|
||||
predTs = pred[-1,...]
|
||||
predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])
|
||||
|
||||
# - resolve symmetry
|
||||
xs_mask = trainer.aamask[seq] # (B, L, 27)
|
||||
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
|
||||
xs_mask *= mask_crds # mask missing atoms & residues as well
|
||||
natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
|
||||
|
||||
# torsion angle loss
|
||||
l_tors = torsionAngleLoss(
|
||||
pred_tors,
|
||||
true_tors,
|
||||
true_tors_alt,
|
||||
tors_mask,
|
||||
tors_planar,
|
||||
eps = 1e-4)
|
||||
tot_loss += w_str*l_tors
|
||||
loss_dict['torsion'] = l_tors.detach()
|
||||
|
||||
### FINETUNING LAYERS
|
||||
# lddts (CA)
|
||||
if pred_lddt is not None:
|
||||
ca_lddt = calc_lddt(pred[:,:,:,1].detach(), true[:,:,1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface)
|
||||
loss_dict['ca_lddt'] = ca_lddt[-1].detach()
|
||||
|
||||
@@ -323,174 +323,162 @@ def calc_loss(trainer, logit_s, label_s,
|
||||
loss_dict['lddt_loss'] = lddt_loss.detach()
|
||||
loss_dict['allatom_lddt'] = allatom_lddt[0].detach()
|
||||
|
||||
# FAPE losses
|
||||
# allatom fape and torsion angle loss
|
||||
# frames, frame_mask = get_frames(
|
||||
# pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
|
||||
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
|
||||
l_fape = torch.zeros((pred.shape[0])).to(gpu)
|
||||
# FAPE losses
|
||||
# allatom fape and torsion angle loss
|
||||
# frames, frame_mask = get_frames(
|
||||
# pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
|
||||
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
|
||||
l_fape = torch.zeros((pred.shape[0])).to(gpu)
|
||||
|
||||
elif negative.item(): # inter-chain fapes should be ignored for negative cases
|
||||
l_fape, _, _ = compute_general_FAPE(
|
||||
pred_allatom[:,res_mask[0],:,:3],
|
||||
nat_symm[None,res_mask[0],:,:3],
|
||||
xs_mask[:,res_mask[0]],
|
||||
frames[:,res_mask[0]],
|
||||
frame_mask[:,res_mask[0]],
|
||||
frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
|
||||
)
|
||||
|
||||
else:
|
||||
l_fape, _, _ = compute_general_FAPE(
|
||||
pred_allatom[:,res_mask[0],:,:3],
|
||||
nat_symm[None,res_mask[0],:,:3],
|
||||
xs_mask[:,res_mask[0]],
|
||||
frames[:,res_mask[0]],
|
||||
frame_mask[:,res_mask[0]]
|
||||
)
|
||||
|
||||
tot_loss += w_str*l_fape[0]
|
||||
loss_dict['allatom_fape'] = l_fape[0].detach()
|
||||
|
||||
# rmsd loss (for logging only)
|
||||
if torch.any(mask_BB[0]):
|
||||
rmsd = calc_crd_rmsd(
|
||||
pred_allatom[:,mask_BB[0],:,:3],
|
||||
nat_symm[None,mask_BB[0],:,:3],
|
||||
xs_mask[:,mask_BB[0]]
|
||||
)
|
||||
loss_dict["rmsd"] = rmsd[0].detach()
|
||||
else:
|
||||
loss_dict["rmsd"] = torch.tensor(0, device=gpu)
|
||||
|
||||
# create protein and not protein masks; not protein could include nucleic acids
|
||||
prot_mask_BB = is_protein(label_aa_s[0,0]) #*mask_BB[0] # (L,)
|
||||
not_prot_mask_BB = ~prot_mask_BB.bool()
|
||||
xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
|
||||
xs_mask_prot[:,~prot_mask_BB] = False
|
||||
xs_mask_lig[:,~not_prot_mask_BB] = False
|
||||
|
||||
if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_prot_prot = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_prot[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
rmsd_prot_prot = torch.tensor([0], device=pred.device)
|
||||
if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_lig_lig = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_lig[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
rmsd_lig_lig = torch.tensor([0], device=pred.device)
|
||||
|
||||
if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_prot_lig = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]],
|
||||
alignment_radius=10.0
|
||||
)
|
||||
|
||||
# fd rms of target ligand only
|
||||
#fd get target ligand mask
|
||||
#fd this is more difficult than expected with only the data we have
|
||||
#fd a) target ligand is 1st one
|
||||
#fd b) examples are all protein followed by ligand
|
||||
sm_mask = not_prot_mask_BB
|
||||
Ls_prot = Ls_from_same_chain_2d(same_chain[:,~sm_mask][:,:,~sm_mask])
|
||||
Ls_sm = Ls_from_same_chain_2d(same_chain[:,sm_mask][:,:,sm_mask])
|
||||
xs_mask_tgt = xs_mask.clone()
|
||||
xs_mask_tgt[:,:sum(Ls_prot)] = False
|
||||
xs_mask_tgt[:,(sum(Ls_prot)+Ls_sm[0]):]= False
|
||||
|
||||
rmsd_prot_tgt = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_tgt[:,mask_BB[0]],
|
||||
alignment_radius=10.0
|
||||
)
|
||||
else:
|
||||
rmsd_prot_lig = torch.tensor([0], device=pred.device)
|
||||
rmsd_prot_tgt = torch.tensor([0], device=pred.device)
|
||||
|
||||
loss_dict["rmsd_prot_prot"]= rmsd_prot_prot[0].detach()
|
||||
loss_dict["rmsd_lig_lig"]= rmsd_lig_lig[0].detach()
|
||||
loss_dict["rmsd_prot_lig"]= rmsd_prot_lig[0].detach()
|
||||
loss_dict["rmsd_prot_tgt"]= rmsd_prot_tgt[0].detach()
|
||||
|
||||
# cart bonded (bond geometry)
|
||||
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
|
||||
if w_bond > 0.0:
|
||||
tot_loss += w_bond*bond_loss
|
||||
loss_dict['bond_geom'] = bond_loss.detach()
|
||||
|
||||
# clash [use all atoms not just those in native]
|
||||
clash_loss = calc_lj(
|
||||
seq[0], pred_allatom,
|
||||
trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters, trainer.lj_correction_parameters, trainer.num_bonds,
|
||||
lj_lin=lj_lin
|
||||
elif negative.item(): # inter-chain fapes should be ignored for negative cases
|
||||
l_fape, _, _ = compute_general_FAPE(
|
||||
pred_allatom[:,res_mask[0],:,:3],
|
||||
nat_symm[None,res_mask[0],:,:3],
|
||||
xs_mask[:,res_mask[0]],
|
||||
frames[:,res_mask[0]],
|
||||
frame_mask[:,res_mask[0]],
|
||||
frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
|
||||
)
|
||||
if w_clash > 0.0:
|
||||
tot_loss += w_clash*clash_loss.mean()
|
||||
loss_dict['clash_loss'] = clash_loss[0].detach()
|
||||
if torch.any(mask_BB[0]):
|
||||
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
|
||||
pred=pred_allatom[:,mask_BB[0]],
|
||||
true=nat_symm[None,mask_BB[0]],
|
||||
bond_feats=bond_feats[:,mask_BB[0]][:,:,mask_BB[0]],
|
||||
seq=seq[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
atom_bond_loss = torch.tensor(0, device=gpu)
|
||||
skip_bond_loss = torch.tensor(0, device=gpu)
|
||||
rigid_loss = torch.tensor(0, device=gpu)
|
||||
|
||||
if w_atom_bond >= 0.0:
|
||||
tot_loss += w_atom_bond*atom_bond_loss
|
||||
loss_dict['atom_bond_loss'] = ( atom_bond_loss.detach() )
|
||||
else:
|
||||
l_fape, _, _ = compute_general_FAPE(
|
||||
pred_allatom[:,res_mask[0],:,:3],
|
||||
nat_symm[None,res_mask[0],:,:3],
|
||||
xs_mask[:,res_mask[0]],
|
||||
frames[:,res_mask[0]],
|
||||
frame_mask[:,res_mask[0]]
|
||||
)
|
||||
|
||||
if w_skip_bond >= 0.0:
|
||||
tot_loss += w_skip_bond*skip_bond_loss
|
||||
loss_dict['skip_bond_loss'] = ( skip_bond_loss.detach() )
|
||||
tot_loss += w_str*l_fape[0]
|
||||
loss_dict['allatom_fape'] = l_fape[0].detach()
|
||||
|
||||
if w_rigid >= 0.0:
|
||||
tot_loss += w_rigid*rigid_loss
|
||||
loss_dict['rigid_loss'] = ( rigid_loss.detach() )
|
||||
chain_prot = same_chain.clone()
|
||||
protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)
|
||||
# translation loss (for generative refinement only)
|
||||
if trans_1 is not None:
|
||||
allatom_mask = ChemData().allatom_mask.to(seq.device, non_blocking=True)
|
||||
is_real_atom = allatom_mask[seq].bool()
|
||||
pred_trans_1 = pred_allatom[is_real_atom]
|
||||
mask = mask_crds[is_real_atom]
|
||||
trans = translation_vector_field (pred_trans_1, trans_1, mask, r3_t, loss_param)
|
||||
tot_loss += w_trans*trans[0]
|
||||
loss_dict["trans_loss"] = trans[0].detach()
|
||||
loss_dict["t"] = r3_t[0,0].detach()
|
||||
|
||||
_, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
|
||||
chain_prot, negative=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()
|
||||
# rmsd loss (for logging only)
|
||||
if torch.any(mask_BB[0]):
|
||||
rmsd = calc_crd_rmsd(
|
||||
pred_allatom[:,mask_BB[0],:,:3],
|
||||
nat_symm[None,mask_BB[0],:,:3],
|
||||
xs_mask[:,mask_BB[0]]
|
||||
)
|
||||
loss_dict["rmsd"] = rmsd[0].detach()
|
||||
else:
|
||||
loss_dict["rmsd"] = torch.tensor(0, device=gpu)
|
||||
|
||||
_, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
|
||||
chain_prot, interface=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()
|
||||
|
||||
chain_lig = same_chain.clone()
|
||||
not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
|
||||
_, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
|
||||
chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
|
||||
loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()
|
||||
|
||||
_, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
|
||||
chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
|
||||
loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()
|
||||
# create protein and not protein masks; not protein could include nucleic acids
|
||||
prot_mask_BB = is_protein(label_aa_s[0,0]) #*mask_BB[0] # (L,)
|
||||
not_prot_mask_BB = ~prot_mask_BB.bool()
|
||||
xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
|
||||
xs_mask_prot[:,~prot_mask_BB] = False
|
||||
xs_mask_lig[:,~not_prot_mask_BB] = False
|
||||
if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_prot_prot = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_prot[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
rmsd_prot_prot = torch.tensor([0], device=pred.device)
|
||||
if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_lig_lig = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_lig[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
rmsd_lig_lig = torch.tensor([0], device=pred.device)
|
||||
if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
|
||||
rmsd_prot_lig = calc_crd_rmsd(
|
||||
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
|
||||
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]],
|
||||
alignment_radius=10.0
|
||||
)
|
||||
else:
|
||||
rmsd_prot_lig = torch.tensor([0], device=pred.device)
|
||||
|
||||
chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
|
||||
chain_prot_lig_inter += protein_mask_2d
|
||||
chain_prot_lig_inter += not_protein_mask_2d
|
||||
_, allatom_lddt_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
|
||||
chain_prot_lig_inter, interface=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
|
||||
loss_dict['total_loss'] = tot_loss.detach()
|
||||
loss_dict["rmsd_prot_prot"]= rmsd_prot_prot[0].detach()
|
||||
loss_dict["rmsd_lig_lig"]= rmsd_lig_lig[0].detach()
|
||||
loss_dict["rmsd_prot_lig"]= rmsd_prot_lig[0].detach()
|
||||
|
||||
return tot_loss, loss_dict
|
||||
# cart bonded (bond geometry)
|
||||
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
|
||||
if w_bond > 0.0:
|
||||
tot_loss += w_bond*bond_loss
|
||||
loss_dict['bond_geom'] = bond_loss.detach()
|
||||
|
||||
# clash [use all atoms not just those in native]
|
||||
clash_loss = calc_lj(
|
||||
seq[0], pred_allatom,
|
||||
trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters, trainer.lj_correction_parameters, trainer.num_bonds,
|
||||
lj_lin=lj_lin
|
||||
)
|
||||
if w_clash > 0.0:
|
||||
tot_loss += w_clash*clash_loss.mean()
|
||||
loss_dict['clash_loss'] = clash_loss[0].detach()
|
||||
if torch.any(mask_BB[0]):
|
||||
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
|
||||
pred=pred_allatom[:,mask_BB[0]],
|
||||
true=nat_symm[None,mask_BB[0]],
|
||||
bond_feats=bond_feats[:,mask_BB[0]][:,:,mask_BB[0]],
|
||||
seq=seq[:,mask_BB[0]]
|
||||
)
|
||||
else:
|
||||
atom_bond_loss = torch.tensor(0, device=gpu)
|
||||
skip_bond_loss = torch.tensor(0, device=gpu)
|
||||
rigid_loss = torch.tensor(0, device=gpu)
|
||||
|
||||
if w_atom_bond >= 0.0:
|
||||
tot_loss += w_atom_bond*atom_bond_loss
|
||||
loss_dict['atom_bond_loss'] = ( atom_bond_loss.detach() )
|
||||
|
||||
if w_skip_bond >= 0.0:
|
||||
tot_loss += w_skip_bond*skip_bond_loss
|
||||
loss_dict['skip_bond_loss'] = ( skip_bond_loss.detach() )
|
||||
|
||||
if w_rigid >= 0.0:
|
||||
tot_loss += w_rigid*rigid_loss
|
||||
loss_dict['rigid_loss'] = ( rigid_loss.detach() )
|
||||
chain_prot = same_chain.clone()
|
||||
protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)
|
||||
|
||||
_, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
|
||||
chain_prot, negative=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()
|
||||
|
||||
_, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
|
||||
chain_prot, interface=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()
|
||||
|
||||
chain_lig = same_chain.clone()
|
||||
not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
|
||||
_, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
|
||||
chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
|
||||
loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()
|
||||
|
||||
_, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
|
||||
chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
|
||||
loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()
|
||||
|
||||
chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
|
||||
chain_prot_lig_inter += protein_mask_2d
|
||||
chain_prot_lig_inter += not_protein_mask_2d
|
||||
_, allatom_lddt_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
|
||||
chain_prot_lig_inter, interface=True, N_stripe=10)
|
||||
loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
|
||||
loss_dict['total_loss'] = tot_loss.detach()
|
||||
return tot_loss, loss_dict
|
||||
|
||||
|
||||
### this file will contain specific calls to the loss function
|
||||
|
||||
452
rf2aa/model/AF3_blocks.py
Normal file
452
rf2aa/model/AF3_blocks.py
Normal file
@@ -0,0 +1,452 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
||||
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
||||
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
||||
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
|
||||
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
|
||||
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
|
||||
from rf2aa.training.checkpoint import create_custom_forward
|
||||
from rf2aa.util_module import Dropout, init_lecun_normal
|
||||
from opt_einsum import contract as einsum
|
||||
|
||||
|
||||
# MSA transformer
|
||||
class AF3_full_block(nn.Module):
|
||||
"""
|
||||
AF3_full_block:
|
||||
- MSA/Pair updates as in AF2
|
||||
- MSA then Pair
|
||||
"""
|
||||
def __init__(self, global_config, block_params, **kwargs):
|
||||
super(AF3_full_block, self).__init__()
|
||||
d_msa, d_pair = (
|
||||
global_config.d_msa_full,
|
||||
global_config.d_pair
|
||||
)
|
||||
|
||||
# to do: optionally disable norm bias
|
||||
self.norm_pair_bias = nn.LayerNorm(d_pair, bias=False)
|
||||
self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
|
||||
self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_pair)
|
||||
|
||||
# to do: optionally disable norm bias
|
||||
self.msa_row_attn = MSARowAttentionWithBias(
|
||||
d_msa=d_msa,
|
||||
d_pair=d_pair,
|
||||
n_head=block_params.n_msa_head,
|
||||
d_hidden=block_params.n_msa_channels,
|
||||
nseq_normalization=block_params.norm_msa_row,
|
||||
bias=False
|
||||
)
|
||||
self.msa_col_attn = MSAColGlobalAttention(
|
||||
d_msa=d_msa,
|
||||
n_head=block_params.n_msa_head,
|
||||
d_hidden=block_params.n_msa_channels,
|
||||
bias=False
|
||||
)
|
||||
self.msa_transition = FeedForwardLayer(d_msa, 4, p_drop=block_params.msa_transition_drop)
|
||||
|
||||
# Pair update parameters
|
||||
self.outer_product = OuterProductMean(d_msa, d_pair, d_hidden=block_params.outer_product_channels, \
|
||||
p_drop=block_params.p_drop_outer_product)
|
||||
|
||||
# to do: optionally disable norm bias
|
||||
self.tri_mul_outgoing = TriangleMultiplication(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, outgoing=True, bias=False)
|
||||
self.tri_mul_incoming = TriangleMultiplication(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, outgoing=False, bias=False)
|
||||
self.tri_attn_start = TriangleAttention(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, start_node=True)
|
||||
self.tri_attn_end = TriangleAttention(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, start_node=False)
|
||||
|
||||
self.pair_transition = FeedForwardLayer(d_pair, 2) # HACK: hardcoded value for transition
|
||||
|
||||
def _unpack_inputs(self, latent_feats):
|
||||
pair = latent_feats["pair"]
|
||||
msa = latent_feats["msa_full"]
|
||||
|
||||
return msa, pair
|
||||
|
||||
def _pack_outputs(self, msa, pair, latent_feats):
|
||||
latent_feats["msa_full"] = msa
|
||||
latent_feats["pair"] = pair
|
||||
return latent_feats
|
||||
|
||||
def _1d_update(self, msa, pair):
|
||||
pair = self.norm_pair_bias(pair)
|
||||
|
||||
msa = msa + self.drop_row(self.msa_row_attn(msa, pair))
|
||||
msa = msa + self.msa_col_attn(msa)
|
||||
msa = msa + self.msa_transition(msa)
|
||||
|
||||
return msa
|
||||
|
||||
def _2d_update(self, msa, pair):
|
||||
msa_bias = self.outer_product(msa)
|
||||
pair = pair + msa_bias
|
||||
pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
|
||||
pair = pair + self.drop_row(self.tri_mul_incoming(pair))
|
||||
pair = pair + self.drop_row(self.tri_attn_start(pair))
|
||||
pair = pair + self.drop_row(self.tri_attn_end(pair))
|
||||
pair = pair + self.pair_transition(pair)
|
||||
return pair
|
||||
|
||||
def forward(self, latent_feats, use_checkpoint, use_amp):
|
||||
msa, pair = self._unpack_inputs(latent_feats)
|
||||
if use_checkpoint:
|
||||
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
|
||||
msa = checkpoint.checkpoint(self._1d_update, msa, pair, use_reentrant=True)
|
||||
pair = checkpoint.checkpoint(self._2d_update, msa, pair, use_reentrant=True)
|
||||
else:
|
||||
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
|
||||
msa = self._1d_update(msa, pair)
|
||||
pair = self._2d_update(msa, pair)
|
||||
|
||||
latent_feats = self._pack_outputs(msa, pair, latent_feats)
|
||||
return latent_feats
|
||||
|
||||
|
||||
# Pair transformer
|
||||
class AF3_block(nn.Module):
|
||||
"""
|
||||
Attempt to replicate AF3 architecture from scratch.
|
||||
"""
|
||||
def __init__(self, global_config, block_params, **kwargs):
|
||||
super(AF3_block, self).__init__()
|
||||
d_singleseq, d_pair = (
|
||||
global_config.d_msa,
|
||||
global_config.d_pair
|
||||
)
|
||||
|
||||
# to do: optionally disable norm bias
|
||||
self.norm_pair_bias = nn.LayerNorm(d_pair, bias=False)
|
||||
|
||||
self.drop_row = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_row)
|
||||
self.drop_col = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_pair)
|
||||
|
||||
# single sequence attn
|
||||
self.msa_row_attn = MSARowAttentionWithBias(
|
||||
d_msa=d_singleseq,
|
||||
d_pair=d_pair,
|
||||
n_head=block_params.n_msa_head,
|
||||
d_hidden=block_params.n_msa_channels,
|
||||
nseq_normalization=block_params.norm_msa_row,
|
||||
bias=False
|
||||
)
|
||||
|
||||
self.msa_transition = FeedForwardLayer(d_singleseq, 4, p_drop=block_params.msa_transition_drop)
|
||||
|
||||
# to do: optionally disable norm bias
|
||||
self.tri_mul_outgoing = TriangleMultiplication(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, outgoing=True, bias=False)
|
||||
self.tri_mul_incoming = TriangleMultiplication(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, outgoing=False, bias=False)
|
||||
self.tri_attn_start = TriangleAttention(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, start_node=True)
|
||||
self.tri_attn_end = TriangleAttention(
|
||||
d_pair, d_hidden=block_params.n_pair_channels, start_node=False)
|
||||
|
||||
self.pair_transition = FeedForwardLayer(d_pair, 4) # HACK: hardcoded value for transition
|
||||
|
||||
|
||||
def _unpack_inputs(self, latent_feats):
|
||||
pair = latent_feats["pair"]
|
||||
singleseq = latent_feats["msa"]
|
||||
return singleseq, pair
|
||||
|
||||
def _pack_outputs(self, singleseq, pair, latent_feats):
|
||||
latent_feats["msa"] = singleseq
|
||||
latent_feats["pair"] = pair
|
||||
return latent_feats
|
||||
|
||||
def _1d_update(self, msa, pair):
|
||||
pair = self.norm_pair_bias(pair)
|
||||
msa = msa + self.drop_row(self.msa_row_attn(msa, pair)) # pair biased attn
|
||||
msa = msa + self.msa_transition(msa)
|
||||
|
||||
return msa
|
||||
|
||||
def _2d_update(self, msa, pair):
|
||||
pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
|
||||
pair = pair + self.drop_row(self.tri_mul_incoming(pair))
|
||||
pair = pair + self.drop_row(self.tri_attn_start(pair))
|
||||
pair = pair + self.drop_col(self.tri_attn_end(pair))
|
||||
pair = pair + self.pair_transition(pair)
|
||||
return pair
|
||||
|
||||
def forward(self, latent_feats, use_checkpoint, use_amp):
|
||||
singleseq, pair = self._unpack_inputs(latent_feats)
|
||||
drop_layer = 0
|
||||
if use_checkpoint:
|
||||
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
|
||||
# 2d then 1d update
|
||||
pair = checkpoint.checkpoint(self._2d_update, singleseq, pair, use_reentrant=True)
|
||||
singleseq = checkpoint.checkpoint(self._1d_update, singleseq, pair, use_reentrant=True)
|
||||
else:
|
||||
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
|
||||
# 2d then 1d update
|
||||
pair = self._2d_update(singleseq, pair)
|
||||
singleseq = self._1d_update(singleseq, pair)
|
||||
|
||||
latent_feats = self._pack_outputs(singleseq, pair, latent_feats)
|
||||
return latent_feats
|
||||
|
||||
class MsaModule(nn.Module):
|
||||
def __init__(self,
|
||||
n_blocks,
|
||||
subsampled_embedding,
|
||||
outer_product,
|
||||
msa_pair_weighted_averaging,
|
||||
msa_transition,
|
||||
triangle_multiplication_incoming,
|
||||
triangle_multiplication_outgoing,
|
||||
triangle_attention_starting,
|
||||
triangle_attention_ending,
|
||||
pair_transition,
|
||||
):
|
||||
super(MsaModule, self).__init__()
|
||||
self.n_blocks = n_blocks
|
||||
self.msa_subsampler = MsaSubsampleEmbedder(subsampled_embedding)
|
||||
self.outer_product = OuterProductMean(outer_product)
|
||||
self.msa_pair_weighted_averaging = MsaPairWeightedAverage(msa_pair_weighted_averaging)
|
||||
self.msa_transition = FeedForwardLayer(msa_transition)
|
||||
|
||||
#TODO: check if row and col dropout are right
|
||||
self.drop_row = Dropout(broadcast_dim=1, p_drop=0.25)
|
||||
self.drop_col = Dropout(broadcast_dim=2, p_drop=0.25)
|
||||
|
||||
self.tri_mult_outgoing = TriangleMultiplication(triangle_multiplication_outgoing)
|
||||
self.tri_mult_incoming = TriangleMultiplication(triangle_multiplication_incoming)
|
||||
self.tri_attn_start = TriangleAttention(triangle_attention_starting)
|
||||
self.tri_attn_end = TriangleAttention(triangle_attention_ending)
|
||||
self.pair_transition = FeedForwardLayer(pair_transition)
|
||||
|
||||
def forward(self,
|
||||
f_dict,
|
||||
pair_II,
|
||||
S_inputs
|
||||
):
|
||||
msa_SI = f_dict["msa_SI"]
|
||||
msa_SI = self.msa_subsampler(msa_SI, S_inputs)
|
||||
for i in range(self.n_blocks):
|
||||
pair_II = pair_II + self.outer_product(msa_SI)
|
||||
msa_SI = msa_SI + self.drop_row(self.msa_pair_weighted_averaging(msa_SI, pair_II))
|
||||
msa_SI = msa_SI + self.msa_transition(msa_SI)
|
||||
|
||||
pair_II = pair_II + self.drop_row(self.tri_mult_outgoing(pair_II))
|
||||
pair_II = pair_II + self.drop_row(self.tri_mult_incoming(pair_II))
|
||||
pair_II = pair_II + self.drop_row(self.tri_attn_start(pair_II))
|
||||
pair_II = pair_II + self.drop_row(self.tri_attn_end(pair_II))
|
||||
pair_II = pair_II + self.pair_transition(pair_II)
|
||||
return pair_II
|
||||
|
||||
class MsaSubsampleEmbedder(nn.Module):
|
||||
def __init__(self, params):
|
||||
super(MsaSubsampleEmbedder, self).__init__()
|
||||
self.num_sequences = params["num_sequences"]
|
||||
self.emb_msa = nn.Linear(params["msa_dim"], params["msa_channels"], bias=False)
|
||||
self.emb_S_inputs = nn.Linear(params["S_dim"], params["msa_channels"], bias=False)
|
||||
|
||||
def forward(self,
|
||||
msa_SI,
|
||||
S_inputs # (B, L, S_dim)
|
||||
):
|
||||
B, S, I = msa_SI.shape[:3]
|
||||
num_samples = torch.min(torch.tensor([self.num_sequences, S]))
|
||||
weights = torch.ones(num_samples.item(), device=msa_SI.device)
|
||||
samples = torch.multinomial(weights, num_samples, replacement=False)
|
||||
msa_SI = msa_SI[:, samples]
|
||||
msa_SI = self.emb_msa(msa_SI)
|
||||
|
||||
msa_SI = msa_SI + self.emb_S_inputs(S_inputs)
|
||||
return msa_SI
|
||||
|
||||
|
||||
class MsaPairWeightedAverage(nn.Module):
|
||||
""" implements Algorithm 10 from AF3 paper"""
|
||||
def __init__(self, params):
|
||||
super(MsaPairWeightedAverage, self).__init__()
|
||||
self.weighted_average_channels = params["weighted_average_channels"]
|
||||
self.n_heads = params["n_heads"]
|
||||
self.msa_channels = params["msa_channels"]
|
||||
self.pair_channels = params["pair_channels"]
|
||||
self.norm_msa = nn.LayerNorm(self.msa_channels)
|
||||
self.to_v = nn.Linear(self.msa_channels, self.n_heads*self.weighted_average_channels, bias=False)
|
||||
self.norm_pair = nn.LayerNorm(self.pair_channels)
|
||||
self.to_bias = nn.Linear(self.msa_channels, self.n_heads, bias=False)
|
||||
self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)
|
||||
self.to_out = nn.Linear(self.weighted_average_channels, self.msa_channels, bias=False)
|
||||
|
||||
def forward(self,
|
||||
msa_SI,
|
||||
pair_II
|
||||
):
|
||||
B, S, I = msa_SI.shape[:3]
|
||||
msa_SI = self.norm_msa(msa_SI)
|
||||
v_SIH = self.to_v(msa_SI).reshape(B, S, I, self.n_heads, self.d_head)
|
||||
bias_IIH = self.to_bias(self.norm_pair(pair_II))
|
||||
gate_SIH = torch.sigmoid(self.to_gate(msa_SI))
|
||||
w_IIH = F.softmax(bias_IIH, dim=-2)
|
||||
weights = torch.einsum( "bijh,bsjhc->bsihc", w_IIH, v_SIH)
|
||||
o_SIH = gate_SIH * weights
|
||||
msa_update_SI = self.to_out(o_SIH.reshape(B, S, I, -1))
|
||||
return msa_update_SI
|
||||
|
||||
class BiasedSequenceAttention(nn.Module):
|
||||
def __init__(self, global_params, block_params):
|
||||
super(BiasedSequenceAttention, self).__init__()
|
||||
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
|
||||
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
|
||||
#
|
||||
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
|
||||
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
|
||||
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
|
||||
|
||||
self.scaling = 1/np.sqrt(block_params.n_channels)
|
||||
self.h = block_params.n_heads
|
||||
self.dim = block_params.n_channels
|
||||
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
def forward(self, state, pair): # TODO: make this as tied-attention
|
||||
B, L = state.shape[:2]
|
||||
#
|
||||
state = self.norm_state(state)
|
||||
pair = self.norm_pair(pair)
|
||||
|
||||
query = self.to_q(state).reshape(B, L, self.h, self.dim)
|
||||
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
|
||||
value = self.to_v(state).reshape(B, L, self.h, self.dim)
|
||||
bias = self.to_b(pair) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(state))
|
||||
|
||||
attn = einsum('bqhd,bkhd->bqkh', query, key)
|
||||
attn = attn + bias
|
||||
attn = F.softmax(attn, dim=-2)
|
||||
|
||||
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
|
||||
out = gate * out
|
||||
|
||||
out = self.to_out(out)
|
||||
out = state + self.transition(out)
|
||||
|
||||
return out
|
||||
class TemplateEmbedding(nn.Module):
|
||||
def __init__(self, params):
|
||||
super(TemplateEmbedding, self).__init__()
|
||||
self.template_channels = params["template_channels"]
|
||||
self.emb_pair = nn.Linear(params["pair_dim"], params["template_channels"], bias=False)
|
||||
self.norm_pair_before_pairformer = nn.LayerNorm(params["pair_dim"])
|
||||
self.norm_after_pairformer = nn.LayerNorm(params["template_channels"])
|
||||
# HACK: need the actual pairformer block
|
||||
self.pairformer = AF3_block(params["pair_dim"], params["template_channels"], params["pairformer_channels"], params["n_pairformer_layers"])
|
||||
# NOTE: this is not consistent with AF3 paper which outputs this tensor in the template_channel dimension
|
||||
self.agg_emb = nn.Linear(params["template_channels"], params["pair_dim"], bias=False)
|
||||
def forward(self,
|
||||
f_dict,
|
||||
pair_II,
|
||||
):
|
||||
B, I = pair_II.shape[:2]
|
||||
template_frame_mask = f_dict["template_frame_mask"][None, :] * f_dict["template_frame_mask"][:, None]
|
||||
template_pseudo_beta_mask = f_dict["template_pseudo_beta_mask"][None, :] * f_dict["template_pseudo_beta_mask"][:, None]
|
||||
template_feats = torch.cat([f_dict["template_distogram"], template_frame_mask, f_dict["template_unit_vector"], template_pseudo_beta_mask])
|
||||
template_feats = template_feats * (f_dict["asym_id"][None, :] == f_dict["asym_id"][:, None])
|
||||
T = template_feats.shape[1]
|
||||
u_II = torch.zeros(B, I, I, self.template_channels, device=pair_II.device)
|
||||
for i in range(T):
|
||||
v_II = self.emb_pair(self.norm_pair_before_pairformer(pair_II)) + template_feats[:, i]
|
||||
v_II = self.pairformer(v_II)
|
||||
u_II = u_II + self.norm_after_pairformer(v_II)
|
||||
|
||||
u_II = u_II / T
|
||||
|
||||
return self.agg_emb(F.relu(u_II))
|
||||
|
||||
|
||||
class BiasedSequenceAttention(nn.Module):
|
||||
def __init__(self, global_params, block_params):
|
||||
super(BiasedSequenceAttention, self).__init__()
|
||||
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
|
||||
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
|
||||
#
|
||||
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
|
||||
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
|
||||
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
|
||||
|
||||
self.scaling = 1/np.sqrt(block_params.n_channels)
|
||||
self.h = block_params.n_heads
|
||||
self.dim = block_params.n_channels
|
||||
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
def forward(self, state, pair): # TODO: make this as tied-attention
|
||||
B, L = state.shape[:2]
|
||||
#
|
||||
state = self.norm_state(state)
|
||||
pair = self.norm_pair(pair)
|
||||
|
||||
query = self.to_q(state).reshape(B, L, self.h, self.dim)
|
||||
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
|
||||
value = self.to_v(state).reshape(B, L, self.h, self.dim)
|
||||
bias = self.to_b(pair) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(state))
|
||||
|
||||
attn = einsum('bqhd,bkhd->bqkh', query, key)
|
||||
attn = attn + bias
|
||||
attn = F.softmax(attn, dim=-2)
|
||||
|
||||
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
|
||||
out = gate * out
|
||||
|
||||
out = self.to_out(out)
|
||||
out = state + self.transition(out)
|
||||
|
||||
return out
|
||||
779
rf2aa/model/AF3_structure.py
Normal file
779
rf2aa/model/AF3_structure.py
Normal file
@@ -0,0 +1,779 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import one_hot, sigmoid
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from torch import relu
|
||||
from icecream import ic
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
||||
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
||||
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
||||
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
|
||||
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
|
||||
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
|
||||
from rf2aa.training.checkpoint import create_custom_forward
|
||||
from rf2aa.util_module import Dropout
|
||||
#from rf2aa.alignment import weighted_rigid_align
|
||||
|
||||
'''
|
||||
Glossary:
|
||||
I: # tokens (coarse representation)
|
||||
L: # atoms (fine representation)
|
||||
M: # msa
|
||||
T: # templates
|
||||
'''
|
||||
|
||||
linearNoBias = partial(torch.nn.Linear, bias=False)
|
||||
|
||||
class AtomAttentionEncoder(nn.Module):
|
||||
|
||||
def __init__(self, c_atom, c_atompair, c_token, atom_1d_features, atom_transformer):
|
||||
super().__init__()
|
||||
self.c_atom = c_atom
|
||||
self.c_atompair = c_atompair
|
||||
self.c_s = c_token
|
||||
self.atom_1d_features = atom_1d_features
|
||||
|
||||
self.pair_mlp = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
linearNoBias(self.c_atompair, c_atompair),
|
||||
nn.ReLU(),
|
||||
linearNoBias(self.c_atompair, c_atompair),
|
||||
nn.ReLU(),
|
||||
linearNoBias(self.c_atompair, c_atompair),
|
||||
)
|
||||
|
||||
self.atom_transformer = AtomTransformer(c_atom=c_atom, c_atompair=c_atompair, **atom_transformer)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
f, # Dict (Input feature dictionary)
|
||||
Rl, # [B, L, 3]
|
||||
Si_trunk, # [B, I, C_S_trunk]
|
||||
Zij, # [B, I, I, C_Z]
|
||||
tok_idx, # [L] maps l --> i
|
||||
):
|
||||
B, I, _ = Si_trunk.shape
|
||||
|
||||
# Create the atom single conditioning: Embed per-atom meta data
|
||||
Cl = self.linear_no_bias_1(torch.concat(f[feature_name] for feature_name in self.atom_1d_features))
|
||||
|
||||
# Embed offsets between atom reference positions
|
||||
Dlm = f['ref_pos'].unsqueeze(-1) - f['ref_pos'].unsqueeze(-2)
|
||||
Vlm = f['ref_space_uid'].unsqueeze(-1) == f['ref_space_uid'].unsqueeze(-2)
|
||||
Plm = self.linear_1(Dlm) * Vlm
|
||||
|
||||
# Embed pairwise inverse squared distances, and the valid mask
|
||||
Plm += self.linear_2(1/(1+torch.linalg.norm(Dlm, dim=-1))) * Vlm
|
||||
|
||||
# Initialise the atom single representation as the single conditioning.
|
||||
Ql = Cl
|
||||
|
||||
# If provided, add trunk embeddings and noisy positions.
|
||||
## Broadcast the single and pair embedding from the trunk.
|
||||
Sl_trunk = Si_trunk[:, self.tok_idx]
|
||||
Cl += self.linear_3(self.layer_norm_1(Sl_trunk))
|
||||
|
||||
## Add the noisy positions.
|
||||
Ql += self.linear_4(Rl)
|
||||
|
||||
# Add the combined single conditioning to the pair representation.
|
||||
Plm += self.linear_5(relu(Cl)).unsqueeze(-1) + self.linear_6(relu(Cl)).unsqueeze(-2)
|
||||
|
||||
# Run a small MLP on the pair activations
|
||||
Plm += self.pair_mlp(Plm)
|
||||
|
||||
# Cross attention transformer.
|
||||
Ql = self.atom_transformer(Ql, Cl, Plm)
|
||||
|
||||
# Aggregate per-atom representation to per-token representation
|
||||
Ai = torch.zeros((B, I, self.c_token)).reduce(
|
||||
1,
|
||||
tok_idx,
|
||||
relu(self.linear_6(Ql)),
|
||||
'mean',
|
||||
include_self=False)
|
||||
|
||||
return Ai, Ql, Cl, Plm
|
||||
|
||||
|
||||
class AtomAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(self, c_token, c_atom, c_atompair, atom_transformer):
|
||||
super().__init__()
|
||||
self.atom_transformer = AtomTransformer(c_atom=c_atom, c_atompair=c_atompair, **atom_transformer)
|
||||
self.linear_1 = linearNoBias(c_token, c_atom)
|
||||
self.linear_2 = linearNoBias(c_atom, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [L, C_token]
|
||||
Ql_skip, # [L, C_atom]
|
||||
Cl_skip, # [L, C_atom]
|
||||
Plm_skip, # [L, L, C_atompair]
|
||||
tok_idx, # [L] maps l --> i
|
||||
):
|
||||
# Broadcast per-token activiations to per-atom activations and add the skip connection
|
||||
Ql = self.linear_1(Ai[tok_idx]) + Ql_skip
|
||||
|
||||
# Cross attention transformer.
|
||||
Ql = self.atom_transformer(Ql, Cl_skip, Plm_skip)
|
||||
|
||||
# Map to positions update
|
||||
Rl_update = self.linear_2(self.layer_norm_2(Ql))
|
||||
|
||||
return Rl_update
|
||||
|
||||
class AtomTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_atom,
|
||||
c_atompair,
|
||||
diffusion_transformer,
|
||||
n_queries,
|
||||
n_keys,
|
||||
l_max,
|
||||
):
|
||||
super().__init__()
|
||||
self.l_max = l_max
|
||||
subset_centers = torch.arange(0, n_queries, l_max) + (n_queries-1 + n_queries / 2)
|
||||
|
||||
l = torch.arange(l_max).unsqueeze(-1).unsqueeze(-1) # [l_max, 1, 1]
|
||||
m = torch.arange(l_max).unsqueeze(0).unsqueeze(-1) # [1, l_max, 1]
|
||||
c = subset_centers.unsqueeze(0).unsqueeze(0) # [1, 1, S]
|
||||
|
||||
Beta_lms_binary = (torch.abs(l - m) < n_queries / 2) * (torch.abs(m - c) < n_keys / 2)
|
||||
ic(
|
||||
Beta_lms_binary.dtype,
|
||||
)
|
||||
Beta_lm_binary = Beta_lms_binary.prod(dim=-1, dtype=bool)
|
||||
ic(
|
||||
Beta_lm_binary.dtype,
|
||||
)
|
||||
self.Beta_lm = torch.where(Beta_lm_binary, 0, -10e10)
|
||||
|
||||
self.diffusion_transformer = DiffusionTransformer(c_token=c_atom, c_tokenpair=c_atompair, **diffusion_transformer)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ql, # [B, L, C_atom]
|
||||
Cl, # [B, L, C_atom]
|
||||
Plm, # [B, L, L, C_atompair]
|
||||
):
|
||||
B, L, _ = Ql.shape
|
||||
assert L < self.l_max
|
||||
Beta_lm = self.Beta_lm[:L, :L]
|
||||
return self.diffusion_transformer(Ql, Cl, Plm, Beta_lm)
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
|
||||
def __init__(self, c_token, c_tokenpair, n_block, diffusion_transformer_block):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.Sequential(*[
|
||||
DiffusionTransformerBlock(c_token=c_token, c_tokenpair=c_tokenpair, **diffusion_transformer_block)
|
||||
for _ in range(n_block)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token]
|
||||
Zij, # [B, I, I, C_tokenpair]
|
||||
Beta_ij, # [I, I]
|
||||
):
|
||||
return self.blocks(Ai, Si, Zij, Beta_ij)
|
||||
|
||||
class DiffusionTransformerBlock(nn.Module):
|
||||
def __init__(self, c_token, c_tokenpair, n_head):
|
||||
super().__init__()
|
||||
self.attention_pair_bias = AttentionPairBias(c_a=c_token, c_pair=c_tokenpair, n_head=n_head)
|
||||
self.conditioned_transition_block = ConditionedTransitionBlock(c_token=c_token)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token]
|
||||
Zij, # [B, I, I, C_tokenpair]
|
||||
Beta_ij, # [I, I]
|
||||
):
|
||||
Bi = self.attention_pair_bias(Ai, Si, Zij, Beta_ij)
|
||||
Ai = Bi + self.conditioned_transition_block(Ai, Si)
|
||||
return Ai, Si, Zij, Beta_ij
|
||||
|
||||
# class MultiHeadLinear(nn.Linear):
|
||||
# def __init__(self, in_features, out_features, h, *args, **kwargs):
|
||||
# self.h = h
|
||||
# self.out_features = out_features
|
||||
# super().__init__(in_features, out_features, *args, **kwargs)
|
||||
|
||||
# def forward(self, x):
|
||||
# return sel
|
||||
|
||||
|
||||
class MultiDimLinear(nn.Linear):
|
||||
def __init__(self, in_features, out_shape, **kwargs):
|
||||
self.out_shape = out_shape
|
||||
out_features = np.prod(out_shape)
|
||||
super().__init__(in_features, out_features, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
out = super().forward(x)
|
||||
return out.reshape(x.shape[:-1] + self.out_shape)
|
||||
|
||||
class LinearBiasInit(nn.Linear):
|
||||
|
||||
def __init__(self, *args, biasinit, **kwargs):
|
||||
assert biasinit == -2. # Sanity check
|
||||
self.biasinit = biasinit
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
super().reset_parameters()
|
||||
self.bias.data.fill_(self.biasinit)
|
||||
|
||||
class AttentionPairBias(nn.Module):
|
||||
def __init__(self, c_a, c_pair, n_head):
|
||||
super().__init__()
|
||||
c = c_a // n_head
|
||||
|
||||
ic(c, n_head)
|
||||
self.to_q = MultiDimLinear(c, (n_head, c))
|
||||
self.to_k = MultiDimLinear(c, (n_head, c), bias=False)
|
||||
self.to_v = MultiDimLinear(c, (n_head, c), bias=False)
|
||||
self.to_b = linearNoBias(c_pair, n_head)
|
||||
self.to_g = nn.Sequential(
|
||||
MultiDimLinear(c_a, (n_head, c), bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self.to_a = linearNoBias(c_a, c_a)
|
||||
self.linear_output_project = nn.Sequential(
|
||||
LinearBiasInit(c_a, c_a, biasinit=-2.),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token] | None
|
||||
Zij, # [B, I, I, C_tokepair]
|
||||
Beta_ij, # [I, I]
|
||||
):
|
||||
# Input projections
|
||||
if Si is not None:
|
||||
Ai = self.ada_ln_1(Ai, Si)
|
||||
else:
|
||||
Ai = self.ln_1(Ai, Si)
|
||||
|
||||
Qih = self.to_q(Ai)
|
||||
Kih = self.to_k(Ai)
|
||||
Vih = self.to_v(Ai)
|
||||
Bijh = self.to_b(Zij) + Beta_ij
|
||||
Gih = self.to_g(Ai)
|
||||
|
||||
# Attention
|
||||
Aijh = torch.softmax(torch.pow(self.c, -1/2) * torch.einsum("bihd,bjhd->bijh", Qih, Kih) + Bijh, dim=-2) # softmax over j
|
||||
|
||||
## Gih: [B, I, H, C]
|
||||
## Aijh: [B, I, I, H]
|
||||
## ViH: [B, I, H, C]
|
||||
head_i = torch.einsum("bijh,bjhc->bihc", Aijh, Vih)
|
||||
head_i = Gih * head_i # [B, I, H, C]
|
||||
Ai = torch.concat(head_i, dim=-2) # [B, I, Ca]
|
||||
Ai = self.to_a(Ai)
|
||||
|
||||
# Output projection (from adaLN-Zero)
|
||||
if Si is not None:
|
||||
Ai = self.linear_output_project(Si) * Ai
|
||||
|
||||
return Ai
|
||||
|
||||
# SwiGLU transition block with adaptive layernorm
|
||||
class ConditionedTransitionBlock(nn.Module):
|
||||
def __init__(self, c_token, n=2):
|
||||
super().__init__()
|
||||
self.ada_ln = AdaLN(c_token=c_token)
|
||||
self.linear_1 = linearNoBias(c_token, c_token*n)
|
||||
self.linear_2 = linearNoBias(c_token, c_token*n)
|
||||
self.linear_output_project = nn.Sequential(
|
||||
LinearBiasInit(c_token, c_token, biasinit=-2.),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token]
|
||||
):
|
||||
Ai = self.ada_ln(Ai, Si)
|
||||
Bi = torch.silu(self.linear_1(Ai)) * self.linear_2(Ai)
|
||||
|
||||
# Output projection (from adaLN-Zero)
|
||||
return self.linear_output_project(Si) * self.linear_3(Bi)
|
||||
|
||||
|
||||
class AdaLN(nn.Module):
|
||||
def __init__(self, c_token, n=2):
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(normalized_shape=(c_token,), elementwise_affine=False)
|
||||
self.ln_learnable_gain = nn.LayerNorm(normalized_shape=(c_token,), bias=False)
|
||||
self.linear_1 = nn.Linear(c_token, c_token)
|
||||
self.linear_2 = nn.Linear(c_token, c_token)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token]
|
||||
):
|
||||
Ai = self.ln(Ai)
|
||||
Si = self.ln_learnable_gain(Si)
|
||||
return torch.sigmoid(self.linear_1(Si)) * Ai + self.linear_2(Si)
|
||||
|
||||
class DiffusionModule(nn.Module):
|
||||
def __init__(self, sigma_data, c_atom, c_atompair, c_token, c_s, c_z, diffusion_conditioning, atom_attention_encoder, diffusion_transformer, atom_attention_decoder):
|
||||
super().__init__()
|
||||
self.sigma_data = sigma_data
|
||||
self.c_atom = c_atom
|
||||
self.c_atompair = c_atompair
|
||||
self.c_token = c_token
|
||||
|
||||
self.diffusion_conditioning = DiffusionConditioning(sigma_data=sigma_data, c_s=c_s, c_z=c_z, **diffusion_conditioning)
|
||||
self.atom_attention_encoder = AtomAttentionEncoder(c_token=c_token, c_atom=c_atom, c_atompair=c_atompair, **atom_attention_encoder)
|
||||
self.diffusion_transformer = DiffusionTransformer(c_token=c_token, c_tokenpair=c_atompair, **diffusion_transformer)
|
||||
self.layer_norm_1 = nn.LayerNorm(c_token)
|
||||
self.atom_attention_decoder = AtomAttentionDecoder(c_token=c_token, c_atom=c_atom, c_atompair=c_atompair, **atom_attention_decoder)
|
||||
|
||||
def forward(self,
|
||||
X_noisy_L, # [B, L, 3]
|
||||
t, # [B] (0 is ground truth)
|
||||
f, # Dict (Input feature dictionary)
|
||||
S_input_I, # [B, I, C_S_input]
|
||||
S_trunk_I, # [B, I, C_S_trunk]
|
||||
Z_trunk_II, # [B, I, I, C_Z]
|
||||
):
|
||||
# Conditioning
|
||||
S_I, Z_II = self.diffusion_conditioning(t, f, S_input_I, S_trunk_I, Z_trunk_II)
|
||||
|
||||
# Scale positions to dimensionless vectors with approximately unit variance
|
||||
R_noisy_L = X_noisy_L / torch.sqrt(t^2 + self.sigma_data)
|
||||
|
||||
# Sequence-local Atom Attention and aggregation to coarse-grained tokens
|
||||
A_I, Q_skip_L, C_skip_L, P_skip_LL = self.atom_attention_encoder(f, R_noisy_L, S_trunk_I, Z_II)
|
||||
|
||||
# Full self-attention on token level
|
||||
A_I += self.linear_no_bias(self.layer_norm(S_I))
|
||||
A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=0)
|
||||
A_I = self.layer_norm_1(A_I)
|
||||
|
||||
# Broadcast token activations to atoms and run Sequence-local Atom Attention
|
||||
R_update_L = self.atom_attention_decoder(A_I, Q_skip_L, C_skip_L, P_skip_LL)
|
||||
|
||||
# Rescale updates to positions and combine with input positions
|
||||
X_out_L = self.sigma_data**2 / (self.sigma_data**2 + t**2) * X_noisy_L + self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5 * R_update_L
|
||||
|
||||
return X_out_L
|
||||
|
||||
class DiffusionConditioning(nn.Module):
|
||||
def __init__(self, sigma_data, c_z, c_s, c_t_embed):
|
||||
super().__init__()
|
||||
self.sigma_data = sigma_data
|
||||
self.to_zii = nn.Sequential(
|
||||
nn.LayerNorm(c_z),
|
||||
linearNoBias(c_z, c_z)
|
||||
)
|
||||
self.transition_1 = nn.ModuleList([
|
||||
Transition(c=c_s, n=2),
|
||||
Transition(c=c_s, n=2),
|
||||
])
|
||||
self.to_si = nn.Sequential(
|
||||
nn.LayerNorm(c_s),
|
||||
linearNoBias(c_s, c_s)
|
||||
)
|
||||
c_t_embed = 256
|
||||
self.fourier_embedding = FourierEmbedding(c_t_embed)
|
||||
self.process_n = nn.Sequential(
|
||||
nn.LayerNorm(c_t_embed),
|
||||
linearNoBias(c_t_embed, c_s)
|
||||
)
|
||||
self.transition_2 = nn.ModuleList([
|
||||
Transition(c=c_s, n=2),
|
||||
Transition(c=c_s, n=2),
|
||||
])
|
||||
|
||||
def forward(self,
|
||||
t,
|
||||
f,
|
||||
S_inputs_I,
|
||||
S_trunk_I,
|
||||
Z_trunk_II):
|
||||
# Pair conditioning
|
||||
Z_II = torch.concat([Z_trunk_II, self.relative_position_encoding(f)])
|
||||
Z_II = self.to_zii(Z_II)
|
||||
for b in range(2):
|
||||
Z_II += self.transition_1[b](Z_II)
|
||||
|
||||
# Single conditioning
|
||||
S_I = torch.concat([S_trunk_I, S_inputs_I])
|
||||
S_I = self.to_si(S_I)
|
||||
N_T = self.fourier_embedding(1/4 * torch.log(t/self.sigma_data))
|
||||
S_I += self.process_n(N_T)
|
||||
for b in range(2):
|
||||
S_I += self.transition_2[b](S_I)
|
||||
|
||||
return S_I, Z_II
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
def __init__(self, n, c):
|
||||
super().__init__()
|
||||
self.n = n
|
||||
self.layer_norm_1 = nn.LayerNorm(c)
|
||||
self.linear_1 = linearNoBias(c, n*c)
|
||||
self.linear_2 = linearNoBias(c, n*c)
|
||||
self.linear_3 = linearNoBias(n*c, c)
|
||||
|
||||
def forward(self,
|
||||
X,
|
||||
):
|
||||
X = self.layer_norm_1(X)
|
||||
A = self.linear_1(X)
|
||||
B = self.linear_2(X)
|
||||
X = self.linear_3(torch.silu(A) * B)
|
||||
return X
|
||||
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
class FourierEmbedding(nn.Module):
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.c = c
|
||||
self.register_buffer('w', torch.zeros((c), dtype=torch.float32))
|
||||
self.register_buffer('b', torch.zeros((c), dtype=torch.float32))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.w)
|
||||
nn.init.normal_(self.b)
|
||||
|
||||
def forward(self,
|
||||
t,
|
||||
):
|
||||
return torch.cos(2 * pi * (t*self.w + self.b))
|
||||
|
||||
# from dataclasses import dataclass
|
||||
# @dataclass
|
||||
# class RecyclingInput:
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self,
|
||||
c_s,
|
||||
c_z,
|
||||
c_atom,
|
||||
c_atompair,
|
||||
feature_initializer,
|
||||
recycler,
|
||||
diffusion_module,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.feature_initializer = FeatureInitializer(c_s=c_s, c_z=c_z, c_atom=c_atom, c_atompair=c_atompair, **feature_initializer)
|
||||
self.recycler = Recycler(c_s=c_s, c_z=c_z, **recycler)
|
||||
self.diffusion_module = DiffusionModule(c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z, **diffusion_module)
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
X_noisy_L,
|
||||
t,
|
||||
n_cycle,
|
||||
):
|
||||
super().__init__()
|
||||
S_input_I, S_init_I, Z_init_II = self.feature_initializer(f)
|
||||
S_I = torch.zeros_like(S_init_I)
|
||||
Z_II = torch.zeros_like(Z_init_II)
|
||||
for _ in range(n_cycle):
|
||||
S_I, Z_II = self.recycler(f, S_input_I, S_init_I, Z_init_II, S_I, Z_II)
|
||||
X_pred = self.diffusion_module(
|
||||
X_noisy_L,
|
||||
t,
|
||||
f,
|
||||
S_input_I,
|
||||
S_I,
|
||||
Z_II,
|
||||
)
|
||||
return X_pred
|
||||
|
||||
def pre_recycle(self,
|
||||
f,
|
||||
X_noisy_L,
|
||||
t):
|
||||
S_input_I, S_init_I, Z_init_II = self.feature_initializer(f)
|
||||
S_I = torch.zeros_like(S_init_I)
|
||||
Z_II = torch.zeros_like(Z_init_II)
|
||||
return S_input_I, S_init_I, Z_init_II, S_I, Z_II, f, X_noisy_L, t
|
||||
|
||||
def recycle(self,
|
||||
S_input_I,
|
||||
S_init_I,
|
||||
Z_init_II,
|
||||
S_I,
|
||||
Z_II,
|
||||
f,
|
||||
X_noisy_L,
|
||||
t,
|
||||
):
|
||||
S_I, Z_II = self.recycler(
|
||||
S_input_I,
|
||||
S_init_I,
|
||||
Z_init_II,
|
||||
S_I,
|
||||
Z_II
|
||||
)
|
||||
return S_input_I, S_init_I, Z_init_II, S_I, Z_II, f, X_noisy_L, t
|
||||
|
||||
def post_recycle(self,
|
||||
S_input_I,
|
||||
S_init_I,
|
||||
Z_init_II,
|
||||
S_I,
|
||||
Z_II,
|
||||
f,
|
||||
X_noisy_L,
|
||||
t,
|
||||
):
|
||||
X_pred = self.diffusion_module(
|
||||
X_noisy_L,
|
||||
t,
|
||||
f,
|
||||
S_input_I,
|
||||
S_I,
|
||||
Z_II,
|
||||
)
|
||||
return X_pred
|
||||
|
||||
|
||||
class Recycler(nn.Module):
|
||||
def __init__(self,
|
||||
c_s,
|
||||
c_z,
|
||||
template_embedder,
|
||||
msa_module,
|
||||
n_pairformer_blocks,
|
||||
pairformer_block,
|
||||
):
|
||||
super().__init__()
|
||||
self.process_zh = nn.Sequential(
|
||||
nn.LayerNorm(c_z),
|
||||
linearNoBias(c_z, c_z),
|
||||
)
|
||||
self.template_embedder = TemplateEmbedder(c_z=c_z, **template_embedder)
|
||||
self.msa_module = MSAModule(**msa_module)
|
||||
self.process_sh = nn.Sequential(
|
||||
nn.LayerNorm(c_s),
|
||||
linearNoBias(c_s, c_s),
|
||||
)
|
||||
self.pairformer_stack = nn.Sequential(*[
|
||||
PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block) for _ in range(n_pairformer_blocks)
|
||||
])
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
S_inputs_I,
|
||||
S_init_I,
|
||||
Z_init_II,
|
||||
Sh_I,
|
||||
Zh_II,
|
||||
):
|
||||
Z_II = Z_init_II + self.process_zh(Zh_II)
|
||||
Z_II += self.template_embedder(f, Z_II)
|
||||
Z_II += self.msa_module(f['msa'], Z_II, S_inputs_I)
|
||||
S_I = S_init_I + self.process_sh(Sh_I)
|
||||
S_I, Z_II = self.pairformer_stack(S_I, Z_II)
|
||||
return S_I, Z_II
|
||||
|
||||
class PairformerBlock(nn.Module):
|
||||
"""
|
||||
Attempt to replicate AF3 architecture from scratch.
|
||||
"""
|
||||
def __init__(self,
|
||||
c_s,
|
||||
c_z,
|
||||
p_drop,
|
||||
c,
|
||||
attention_pair_bias,
|
||||
n_transition=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.drop_row = Dropout(broadcast_dim=2, p_drop=p_drop)
|
||||
self.drop_col = Dropout(broadcast_dim=1, p_drop=p_drop)
|
||||
|
||||
self.tri_mul_outgoing = TriangleMultiplication(
|
||||
c_z, d_hidden=c, outgoing=True, bias=False)
|
||||
self.tri_mul_incoming = TriangleMultiplication(
|
||||
c_z, d_hidden=c, outgoing=False, bias=False)
|
||||
self.tri_attn_start = TriangleAttention(
|
||||
c_z, d_hidden=c, start_node=True)
|
||||
self.tri_attn_end = TriangleAttention(
|
||||
c_z, d_hidden=c, start_node=False)
|
||||
|
||||
self.z_transition = Transition(c_z, n_transition)
|
||||
self.s_transition = Transition(c_s, n_transition)
|
||||
|
||||
self.attention_pair_bias = AttentionPairBias(c_a=c_s, c_pair=c_z, **attention_pair_bias)
|
||||
|
||||
def forward(self,
|
||||
S_I,
|
||||
Z_II):
|
||||
Z_II += self.drop_row(self.tri_mul_outgoing(Z_II))
|
||||
Z_II += self.drop_row(self.tri_mul_incoming(Z_II))
|
||||
Z_II += self.drop_row(self.tri_attn_start(Z_II))
|
||||
Z_II += self.drop_col(self.tri_attn_end(Z_II))
|
||||
Z_II += self.z_transition(Z_II)
|
||||
S_I += self.attention_pair_bias(S_I, None, Z_II, Beta_II=0)
|
||||
S_I += self.s_transition(S_I)
|
||||
return S_I, Z_II
|
||||
|
||||
|
||||
class FeatureInitializer(nn.Module):
|
||||
def __init__(self,
|
||||
c_s,
|
||||
c_z,
|
||||
c_atom,
|
||||
c_atompair,
|
||||
input_feature_embedder,
|
||||
relative_position_encoding):
|
||||
super().__init__()
|
||||
self.input_feature_embedder = InputFeatureEmbedder(c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, **input_feature_embedder)
|
||||
self.to_s_init = linearNoBias(c_s, c_s)
|
||||
self.to_z_init_i = linearNoBias(c_s, c_z)
|
||||
self.to_z_init_j = linearNoBias(c_s, c_z)
|
||||
self.relative_position_encoding = RelativePositionEncoding(c_z=c_z, **relative_position_encoding)
|
||||
self.process_token_bonds = linearNoBias(1, c_z)
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
):
|
||||
S_inputs_I = self.input_feature_embedder(f)
|
||||
S_init_I = self.to_s_init(S_init_I)
|
||||
Z_init_II = self.to_z_init_i(S_inputs_I).unsqueeze(-3) + self.to_z_init_j(S_inputs_I).unsqueeze(-2)
|
||||
Z_init_II += self.relative_position_encoding(f)
|
||||
Z_init_II += self.process_token_bonds(f['token_bonds'])
|
||||
return S_inputs_I, S_init_I, Z_init_II
|
||||
|
||||
|
||||
class InputFeatureEmbedder(nn.Module):
|
||||
def __init__(self,
|
||||
features,
|
||||
c_atom,
|
||||
c_atompair,
|
||||
c_s,
|
||||
atom_attention_encoder):
|
||||
super().__init__()
|
||||
self.atom_attention_encoder = AtomAttentionEncoder(c_atom=c_atom, c_atompair=c_atompair, c_token=c_s, **atom_attention_encoder)
|
||||
self.features = features
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
A_I,
|
||||
):
|
||||
S_I, _, _, _ = self.atom_attention_encoder(A_I)
|
||||
S_I = torch.concat([A_I] + [f[feature] for feature in self.features])
|
||||
return S_I
|
||||
|
||||
class RelativePositionEncoding(nn.Module):
|
||||
def __init__(self,
|
||||
r_max,
|
||||
s_max,
|
||||
c_z):
|
||||
super().__init__()
|
||||
self.r_max = r_max
|
||||
self.s_max = s_max
|
||||
self.c_z = c_z
|
||||
self.linear = linearNoBias(2*(2*self.r_max+2) + (2*self.s_max+2) + 1, c_z)
|
||||
|
||||
def forward(self,
|
||||
f):
|
||||
b_samechain_II = f['asym_id'].unsqueeze(-1) == f['asym_id'].unsqueeze(-2)
|
||||
b_sameresidue_II = f['residue_index'].unsqueeze(-1) == f['residue_index'].unsqueeze(-2)
|
||||
b_same_entity_II = f['entity_id'].unsqueeze(-1) == f['entity_id'].unsqueeze(-2)
|
||||
d_residue_II = torch.where(
|
||||
b_samechain_II,
|
||||
torch.clip(f['residue_index'].unsqueeze(-2) - f['residue_index'].unsqueeze(-1) + self.r_max, 0, 2*self.r_max),
|
||||
2 * self.r_max + 1
|
||||
)
|
||||
A_relpos_II = one_hot(d_residue_II, 2*self.r_max+2)
|
||||
d_token_II = torch.where(
|
||||
b_samechain_II * b_sameresidue_II,
|
||||
torch.clip(f['token_index'].unsqueeze(-2) - f['token_index'].unsqueeze(-1) + self.r_max, 0, 2*self.r_max),
|
||||
2 * self.r_max + 1
|
||||
)
|
||||
A_reltoken_II = one_hot(d_token_II, 2*self.r_max+2)
|
||||
d_chain_II = torch.where(
|
||||
b_samechain_II,
|
||||
torch.clip(f['sym_id'].unsqueeze(-2) - f['sym_id'].unsqueeze(-1) + self.s_max, 0, 2*self.s_max),
|
||||
2 * self.s_max + 1
|
||||
)
|
||||
A_relchain_II = one_hot(d_chain_II, 2*self.s_max+2)
|
||||
return self.linear(torch.cat([A_relpos_II, A_reltoken_II, b_same_entity_II.unsqueeze(-1), A_relchain_II], dim=3))
|
||||
|
||||
# Mock for testing.
|
||||
class MSAModule(nn.Module):
|
||||
def __init__(self, n_block, c_m):
|
||||
super().__init__()
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
Z_II,
|
||||
S_inputs_I,
|
||||
):
|
||||
return Z_II
|
||||
|
||||
# Mock for testing.
|
||||
class TemplateEmbedder(nn.Module):
|
||||
def __init__(self,
|
||||
n_block,
|
||||
c_z,
|
||||
c):
|
||||
super().__init__()
|
||||
self.c =c
|
||||
self.linear = linearNoBias(c_z, c)
|
||||
|
||||
def forward(self,
|
||||
f,
|
||||
Z_II,
|
||||
):
|
||||
return self.linear(Z_II)
|
||||
|
||||
class Loss:
|
||||
def __init__(self,
|
||||
sigma_data,
|
||||
):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self,
|
||||
f,
|
||||
X_L, # [B, L, 3]
|
||||
X_gt_L, # [B, L, 3]
|
||||
t, # [B]
|
||||
):
|
||||
w_L = 1 + (
|
||||
f['is_dna']*self.alpha_is_dna +
|
||||
f['is_rna'] * self.alpha_is_rna +
|
||||
f['is_ligand'] * self.alpha_is_ligand
|
||||
)
|
||||
# Align ground truth onto predictions.
|
||||
X_gt_aligned_L = weighted_rigid_align(X_gt_L, X_L, w_L)
|
||||
l_mse = 1/3 * w_L * torch.mean(torch.linalg.norm(X_L, X_gt_aligned_L, dim=-1), dim=-1) # [B]
|
||||
|
||||
l_diffusion = (t**2 + self.sigma_data**2) / (t + self.sigma_data)**2 * l_mse
|
||||
|
||||
# TODO: implement auxiliary losses
|
||||
l_total = l_diffusion.sum()
|
||||
|
||||
return l_total, {
|
||||
'diffusion_loss': l_diffusion
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ class RF2_embedding(nn.Module):
|
||||
## Update inputs with outputs from previous forward pass
|
||||
self.recycle = recycling_factory[block_params.recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
self.recycling_type = block_params.recycling_type
|
||||
assert self.recycling_type == "msa_pair", "no backward compatibility to recycling state"
|
||||
assert self.recycling_type != "all", "no backward compatibility to recycling state"
|
||||
|
||||
def _unpack_inputs(self, rf_inputs):
|
||||
msa_latent, msa_full, seq, idx, bond_feats, dist_matrix = \
|
||||
@@ -90,9 +90,7 @@ class RF2_embedding(nn.Module):
|
||||
|
||||
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
|
||||
pair = pair + pair_recycle
|
||||
# No support for recycling state
|
||||
#state = state + state_recycle # if state is not recycled these will be zeros
|
||||
# add template embedding
|
||||
|
||||
pair, state = self._add_templ_features(rf_inputs, pair, state)
|
||||
return {
|
||||
"msa": msa_latent,
|
||||
|
||||
367
rf2aa/model/generative_refinement.py
Normal file
367
rf2aa/model/generative_refinement.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from opt_einsum import contract as einsum
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors, SE3TransformerWrapper
|
||||
from rf2aa.model.AF3_structure import FourierEmbedding
|
||||
from rf2aa.model.Track_module import Str2Str
|
||||
from rf2aa.util_module import rbf, make_topk_graph, make_full_graph, init_lecun_normal
|
||||
from rf2aa.util import is_atom,is_nucleic,is_protein
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.loss.loss import calc_chiral_grads
|
||||
from rf2aa.model.layers.Attention_module import FeedForwardLayer
|
||||
from rf2aa.flow_matching import data_utils as du
|
||||
|
||||
import dgl
|
||||
|
||||
def get_bondgraph(bonds, num_bonds, dist_matrix, idx, is_prot, is_na, is_atom):
|
||||
''' Combine different sources of info to get bond-distance graph for all i,j pairs
|
||||
'''
|
||||
L = bonds.shape[1]
|
||||
|
||||
# intra-prot/intra-na bonds
|
||||
bonds[:,torch.arange(L),:,torch.arange(L),:] = num_bonds.transpose(0,1)
|
||||
|
||||
# ligand bonds
|
||||
ia = is_atom.nonzero()[:,0]
|
||||
bonds[:,ia[:,None],1,ia[None,:],1] = dist_matrix[ia[:,None],ia[None,:]].to(dtype=bonds.dtype)
|
||||
|
||||
# we need to handle covalent bonds between residues
|
||||
# to reduce computational load only consider +/- 1 residue
|
||||
# prot-prot
|
||||
ii,jj = ChemData().protein_connect
|
||||
resmask = (is_prot[0,:-1] * is_prot[0,1:] * ((idx[1:]-idx[:-1])==1)).nonzero()[...,0]
|
||||
bonds[:,resmask+1,:,resmask ,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
|
||||
bonds[:,resmask ,:,resmask+1,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
|
||||
# na-na
|
||||
ii,jj = ChemData().na_connect
|
||||
resmask = (is_na[0,:-1] * is_na[0,1:] * ((idx[1:]-idx[:-1])==1)).nonzero()[...,0]
|
||||
bonds[:,resmask+1,:,resmask ,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
|
||||
bonds[:,resmask ,:,resmask+1,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
|
||||
|
||||
return bonds
|
||||
|
||||
|
||||
#
|
||||
def make_atom_graph(
|
||||
xyz, mask, is_prot, is_na, is_atom, idx, num_bonds, dist_matrix, top_k=24, max_nbonds_encode=8, max_nbonds_connect=3
|
||||
):
|
||||
'''
|
||||
Build an atom level graph from a mixed residue/ligand pose
|
||||
Parameters of interest:
|
||||
max_nbonds_encode - edge features encode # bonds, max this number
|
||||
max_nbonds_connect - force connections between atoms this # bonds or fewer
|
||||
Ensure top_k is large enough for max_nbonds_connect:
|
||||
with max_nbonds_connect=3, ~15 atoms are brought in by bonds alone
|
||||
with max_nbonds_connect=2, ~11 atoms are brought in by bonds alone
|
||||
with max_nbonds_connect=1, ~4 atoms are brought in by bonds alone
|
||||
'''
|
||||
B,L,A = xyz.shape[:3]
|
||||
device = xyz.device
|
||||
D = torch.norm(
|
||||
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
|
||||
)
|
||||
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
|
||||
|
||||
bonds = torch.full_like(D, ChemData().MAX_BOND_DIST, dtype=num_bonds.dtype)
|
||||
bonds = get_bondgraph(bonds, num_bonds, dist_matrix, idx, is_prot, is_na, is_atom)
|
||||
|
||||
# set D to _negative_ for close bonded
|
||||
# all missing-atom pairs will have D==0 so need to prefer these
|
||||
D[bonds<=max_nbonds_connect] = -1.0
|
||||
D[bonds==0] = np.inf # set D large for self
|
||||
D[~mask2d] = np.inf # set D large for non-atoms
|
||||
|
||||
# select top K neighbors for each atom
|
||||
# keep indices as batch/res/atm indices
|
||||
nmaxedge = torch.sum(mask)-1 # most edges = num atoms - 1
|
||||
|
||||
if top_k > nmaxedge:
|
||||
top_k = nmaxedge
|
||||
|
||||
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, A, top_k)
|
||||
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
|
||||
bi,ri,ai = mask.nonzero(as_tuple=True)
|
||||
bi = bi[:,None].repeat(1,top_k).reshape(-1)
|
||||
ri = ri[:,None].repeat(1,top_k).reshape(-1)
|
||||
ai = ai[:,None].repeat(1,top_k).reshape(-1)
|
||||
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
|
||||
|
||||
# on each edge, encode:
|
||||
# a) 1-hot encode the number of bonds (up to maxbonds) separating each atom
|
||||
# b) 1/D
|
||||
bonds = bonds[bi,ri,ai,rj,aj]
|
||||
bonds[bonds >= max_nbonds_encode] = max_nbonds_encode
|
||||
|
||||
natm = torch.sum(mask)
|
||||
index = torch.zeros_like(mask, dtype=torch.long, device=device)
|
||||
index[mask] = torch.arange(natm, device=device)
|
||||
src=index[bi,ri,ai]
|
||||
tgt=index[bi,rj,aj]
|
||||
|
||||
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
|
||||
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj])
|
||||
|
||||
edge = torch.cat([
|
||||
F.one_hot(bonds-1),
|
||||
1 / (torch.norm(G.edata['rel_pos'], dim=-1,keepdim=True)+1)
|
||||
], dim=-1)
|
||||
|
||||
return G, edge
|
||||
|
||||
|
||||
class BiasedSequenceAttention(nn.Module):
|
||||
def __init__(self, global_params, block_params):
|
||||
super(BiasedSequenceAttention, self).__init__()
|
||||
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
|
||||
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
|
||||
#
|
||||
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
|
||||
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
|
||||
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
|
||||
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
|
||||
|
||||
self.scaling = 1/np.sqrt(block_params.n_channels)
|
||||
self.h = block_params.n_heads
|
||||
self.dim = block_params.n_channels
|
||||
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
#nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
def forward(self, state, pair): # TODO: make this as tied-attention
|
||||
B, L = state.shape[:2]
|
||||
#
|
||||
state = self.norm_state(state)
|
||||
pair = self.norm_pair(pair)
|
||||
|
||||
query = self.to_q(state).reshape(B, L, self.h, self.dim)
|
||||
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
|
||||
value = self.to_v(state).reshape(B, L, self.h, self.dim)
|
||||
bias = self.to_b(pair) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(state))
|
||||
|
||||
attn = einsum('bqhd,bkhd->bqkh', query, key)
|
||||
attn = attn + bias
|
||||
attn = F.softmax(attn, dim=-2)
|
||||
|
||||
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
|
||||
out = gate * out
|
||||
|
||||
out = self.to_out(out)
|
||||
out = state + self.transition(out)
|
||||
|
||||
return out
|
||||
|
||||
class GenerativeRefinement(nn.Module):
|
||||
def __init__(self, global_params, block_params) -> None:
|
||||
super(GenerativeRefinement, self).__init__()
|
||||
self.proj_atoms = nn.Linear(ChemData().NELTTYPES, global_params.d_state)
|
||||
|
||||
self.proj_edge = nn.Linear(9, block_params.num_edge_features) #max_bonds_encode+1, need to refactor
|
||||
|
||||
self.norm_state_0 = nn.LayerNorm(global_params.d_state+global_params.d_msa)
|
||||
self.proj_state_0 = nn.Linear(global_params.d_state+global_params.d_msa, global_params.d_state)
|
||||
self.ff_state_0 = FeedForwardLayer(global_params.d_state, 2, zero_init=False)
|
||||
|
||||
self.norm_pair_0 = nn.LayerNorm(global_params.d_pair)
|
||||
self.proj_pair_0 = nn.Linear(global_params.d_pair, global_params.d_pair)
|
||||
self.ff_pair_0 = FeedForwardLayer(global_params.d_pair, 2, zero_init=False)
|
||||
|
||||
self.timestep_embedding_dim = 256
|
||||
self.fourier_embedding = FourierEmbedding(self.timestep_embedding_dim)
|
||||
self.norm_timstep_emb = nn.LayerNorm(self.timestep_embedding_dim)
|
||||
self.emb_timestep = nn.Linear(self.timestep_embedding_dim, global_params.d_state, bias=False)
|
||||
|
||||
|
||||
self.proj_state_1 = nn.Linear(global_params.d_state, global_params.d_state, bias=False)
|
||||
self.proj_state_2 = nn.Linear(global_params.d_state, global_params.d_state, bias=False)
|
||||
self.norm_state_1 = nn.LayerNorm(global_params.d_state)
|
||||
self.norm_state_2 = nn.LayerNorm(global_params.d_state)
|
||||
|
||||
self.sigma_data = 16 # expose as parameter
|
||||
|
||||
self.atom_encoder = nn.ModuleList(
|
||||
[
|
||||
SE3TransformerWrapper(
|
||||
num_layers=block_params.num_layers,
|
||||
num_channels=block_params.num_channels,
|
||||
num_degrees=block_params.num_degrees,
|
||||
n_heads=block_params.n_heads,
|
||||
div=block_params.div,
|
||||
l0_in_features=block_params.l0_in_features,
|
||||
l0_out_features=block_params.l0_out_features,
|
||||
l1_in_features=1,
|
||||
l1_out_features=1,
|
||||
num_edge_features=block_params.num_edge_features,
|
||||
compute_gradients=True
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.token_processing = nn.ModuleList(
|
||||
[
|
||||
BiasedSequenceAttention(global_params, block_params)
|
||||
for i in range(block_params.num_attention_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.atom_decoder = nn.ModuleList(
|
||||
[
|
||||
SE3TransformerWrapper(
|
||||
num_layers=block_params.num_layers,
|
||||
num_channels=block_params.num_channels,
|
||||
num_degrees=block_params.num_degrees,
|
||||
n_heads=block_params.n_heads,
|
||||
div=block_params.div,
|
||||
l0_in_features=block_params.l0_in_features,
|
||||
l0_out_features=0,
|
||||
l1_in_features=1,
|
||||
l1_out_features=1,
|
||||
num_edge_features=block_params.num_edge_features,
|
||||
compute_gradients=True
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _unpack_latents(self, latent_feats):
|
||||
msa, pair, state = latent_feats["msa"], latent_feats["pair"], latent_feats["state"]
|
||||
|
||||
seq_unmasked = latent_feats["seq_unmasked"]
|
||||
allatom_mask = ChemData().allatom_mask.to(state.device)
|
||||
is_valid_atom = allatom_mask[seq_unmasked]
|
||||
num_bonds = ChemData().num_bonds.to(state.device)
|
||||
num_bonds_sequence = num_bonds[seq_unmasked]
|
||||
xyz = latent_feats["trans_t"][0]
|
||||
t = latent_feats["t"]
|
||||
|
||||
is_atomized = is_atom(seq_unmasked)
|
||||
is_prot = is_protein(seq_unmasked)
|
||||
is_na = is_nucleic(seq_unmasked)
|
||||
dist_matrix = latent_feats["dist_matrix"][0]
|
||||
idx = latent_feats["idx"][0]
|
||||
|
||||
return (
|
||||
msa, pair, state, seq_unmasked,
|
||||
is_valid_atom, num_bonds_sequence, xyz, t,
|
||||
is_atomized, is_prot, is_na, dist_matrix, idx
|
||||
)
|
||||
|
||||
def _embed_1d(self, latent_feats):
|
||||
seq_unmasked = latent_feats["seq_unmasked"]
|
||||
|
||||
# feature set 1: element
|
||||
elts = ChemData().aa2eltidx.to(seq_unmasked.device)
|
||||
elts = elts[seq_unmasked]
|
||||
elts = torch.nn.functional.one_hot(elts, ChemData().NELTTYPES).float()
|
||||
|
||||
return self.proj_atoms(elts)
|
||||
|
||||
def forward(self, latent_feats):
|
||||
# get outputs + noised xyz
|
||||
msa, pair, state, seq_unmasked, is_valid_atom, num_bonds_sequence, xyz, t, \
|
||||
is_atomized, is_prot, is_na, dist_matrix, idx \
|
||||
= self._unpack_latents(latent_feats)
|
||||
|
||||
t_hat = (1-t)*du.NM_TO_ANG_SCALE ## ?
|
||||
|
||||
# initial state embedding from msa (single seq) + state
|
||||
msa = msa.squeeze(1)
|
||||
state = torch.cat([msa,state], dim=-1)
|
||||
state = self.proj_state_0(self.norm_state_0(state))
|
||||
|
||||
# add timestep embedding
|
||||
timestep_emb_T = self.fourier_embedding(1/4 * torch.log(t_hat/self.sigma_data))
|
||||
timestep_emb_T = self.emb_timestep(self.norm_timstep_emb(timestep_emb_T))
|
||||
state = state + timestep_emb_T
|
||||
|
||||
state = self.ff_state_0(state)
|
||||
|
||||
# initial pair embedding
|
||||
pair = self.proj_pair_0(self.norm_pair_0(pair))
|
||||
pair = self.ff_pair_0(pair)
|
||||
|
||||
# add atom embeddings
|
||||
# to do: a lot more...
|
||||
atomstate = state[...,None,:] + self._embed_1d(latent_feats)
|
||||
|
||||
# encode atom level
|
||||
atomstate, xyz = self.atom_update(
|
||||
t_hat, xyz, atomstate, is_valid_atom,
|
||||
is_prot, is_na, is_atomized, idx,
|
||||
num_bonds_sequence, dist_matrix, encoder=True)
|
||||
|
||||
# to residue level
|
||||
atomstate = F.relu_(self.proj_state_1(F.relu_(atomstate)))
|
||||
state = (
|
||||
self.proj_state_1(self.norm_state_1(state))
|
||||
+ (atomstate * is_valid_atom[..., None]).sum(dim=2) / is_valid_atom.sum(dim=2)[...,None] # project down atomstate
|
||||
)
|
||||
|
||||
# res level updates
|
||||
for layer in self.token_processing:
|
||||
state = layer(state, pair)
|
||||
|
||||
# combine atom&res level state, decode atom level
|
||||
atomstate = atomstate + self.proj_state_2(self.norm_state_2(state))[...,None,:]
|
||||
|
||||
atomstate, xyz = self.atom_update(
|
||||
t_hat, xyz, atomstate, is_valid_atom,
|
||||
is_prot, is_na, is_atomized, idx,
|
||||
num_bonds_sequence, dist_matrix, encoder=False)
|
||||
|
||||
# to residue level
|
||||
state = (atomstate * is_valid_atom[..., None]).sum(dim=2) / is_valid_atom.sum(dim=2)[...,None]
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"xyz": xyz,
|
||||
}
|
||||
|
||||
|
||||
def atom_update(self, t_hat, xyz, state, is_valid_atom, is_prot, is_na, is_atomized, idx, num_bonds_sequence, dist_matrix, encoder=True):
|
||||
SE3_SCALE = 10.0 #to do: make this a parameter
|
||||
|
||||
if encoder:
|
||||
layers = self.atom_encoder
|
||||
else:
|
||||
layers = self.atom_decoder
|
||||
|
||||
for layer in layers:
|
||||
G, edge = make_atom_graph(
|
||||
xyz, is_valid_atom, is_prot, is_na, is_atomized, idx, num_bonds_sequence, dist_matrix,
|
||||
)
|
||||
node = state[is_valid_atom]
|
||||
node_l1 = xyz[is_valid_atom].unsqueeze(-2) #torch.ones((node.shape[0], 3,3), device=state.device, dtype=state.dtype)
|
||||
edge = self.proj_edge(edge).unsqueeze(-1)
|
||||
shift = layer(G, node[..., None], node_l1, edge)
|
||||
|
||||
xyz[is_valid_atom] = xyz[is_valid_atom] + shift["1"].squeeze(1)/SE3_SCALE
|
||||
if encoder:
|
||||
state[is_valid_atom] = state[is_valid_atom] + shift["0"][...,0]/SE3_SCALE
|
||||
else:
|
||||
xyz = xyz+shift["0"].sum() #fd hack to avoid unused grads (even though shift['0'] is of dim 0)
|
||||
|
||||
return state, xyz
|
||||
|
||||
@@ -7,23 +7,24 @@ from einops import rearrange
|
||||
from rf2aa.util_module import init_lecun_normal
|
||||
|
||||
class FeedForwardLayer(nn.Module):
|
||||
def __init__(self, d_model, r_ff, p_drop=0.1):
|
||||
def __init__(self, d_model, r_ff, p_drop=0.1, zero_init=True):
|
||||
super(FeedForwardLayer, self).__init__()
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.linear1 = nn.Linear(d_model, d_model*r_ff)
|
||||
self.dropout = nn.Dropout(p_drop)
|
||||
self.linear2 = nn.Linear(d_model*r_ff, d_model)
|
||||
|
||||
self.reset_parameter()
|
||||
self.reset_parameter(zero_init)
|
||||
|
||||
def reset_parameter(self):
|
||||
def reset_parameter(self,zero_init):
|
||||
# initialize linear layer right before ReLu: He initializer (kaiming normal)
|
||||
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
|
||||
nn.init.zeros_(self.linear1.bias)
|
||||
|
||||
# initialize linear layer right before residual connection: zero initialize
|
||||
nn.init.zeros_(self.linear2.weight)
|
||||
nn.init.zeros_(self.linear2.bias)
|
||||
if zero_init:
|
||||
nn.init.zeros_(self.linear2.weight)
|
||||
nn.init.zeros_(self.linear2.bias)
|
||||
|
||||
def forward(self, src):
|
||||
src = self.norm(src)
|
||||
|
||||
@@ -77,6 +77,7 @@ class MSA_emb(nn.Module):
|
||||
# pair embedding
|
||||
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix, same_chain=same_chain,cyclize=cyclize)
|
||||
# state embedding
|
||||
|
||||
state = self._state_emb(seq)
|
||||
return msa, pair, state
|
||||
|
||||
@@ -522,6 +523,21 @@ class Templ_emb_NoPtwise(nn.Module):
|
||||
return pair, state
|
||||
|
||||
|
||||
class RecyclingMSAPairOnly(nn.Module):
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=0):
|
||||
super(RecyclingMSAPairOnly, self).__init__()
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
self.proj_pair = nn.Linear(d_pair, d_pair, bias=False)
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
self.proj_msa = nn.Linear(d_msa, d_msa, bias=False)
|
||||
|
||||
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||
B, L = msa.shape[:2]
|
||||
msa = msa + self.proj_msa(self.norm_msa(msa))
|
||||
pair = pair + self.proj_pair(self.norm_pair(pair))
|
||||
|
||||
return msa, pair, state # state is dummy
|
||||
|
||||
class Recycling(nn.Module):
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
|
||||
super(Recycling, self).__init__()
|
||||
@@ -593,6 +609,7 @@ class RecyclingAllFeatures(nn.Module):
|
||||
return msa, pair, state
|
||||
|
||||
recycling_factory = {
|
||||
"msa_pair_only": RecyclingMSAPairOnly,
|
||||
"msa_pair": Recycling,
|
||||
"all": RecyclingAllFeatures
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ class SE3TransformerWrapper(nn.Module):
|
||||
l0_in_features=32, l0_out_features=32,
|
||||
l1_in_features=3, l1_out_features=2,
|
||||
num_edge_features=32,
|
||||
compute_gradients=False):
|
||||
compute_gradients=False, **kwargs):
|
||||
super().__init__()
|
||||
# Build the network
|
||||
self.l1_in = l1_in_features
|
||||
@@ -63,36 +63,16 @@ class SE3TransformerWrapper(nn.Module):
|
||||
populate_edge="arcsin",
|
||||
final_layer="lin",
|
||||
use_layer_norm=True,
|
||||
compute_gradients=compute_gradients
|
||||
compute_gradients=compute_gradients,
|
||||
)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
|
||||
# make sure linear layer before ReLu are initialized with kaiming_normal_
|
||||
for n, p in self.se3.named_parameters():
|
||||
if "bias" in n:
|
||||
nn.init.zeros_(p)
|
||||
elif len(p.shape) == 1:
|
||||
continue
|
||||
else:
|
||||
if "radial_func" not in n:
|
||||
p = init_lecun_normal_param(p)
|
||||
else:
|
||||
if "net.6" in n:
|
||||
nn.init.zeros_(p)
|
||||
else:
|
||||
nn.init.kaiming_normal_(p, nonlinearity='relu')
|
||||
|
||||
# make last layers to be zero-initialized
|
||||
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
|
||||
if self.l1_out > 0:
|
||||
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
|
||||
pass
|
||||
#nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
|
||||
#if self.l1_out > 0:
|
||||
# nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
|
||||
|
||||
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
|
||||
if self.l1_in > 0:
|
||||
@@ -282,7 +262,6 @@ class FullyConnectedSE3(FullyConnectedSE3_noR):
|
||||
|
||||
def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, state, is_atom, drop_layer=False, is_motif=None):
|
||||
weight = 0. if drop_layer else 1.
|
||||
|
||||
B, L = node.shape[:2]
|
||||
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
|
||||
|
||||
@@ -320,6 +299,8 @@ class FullyConnectedSE3(FullyConnectedSE3_noR):
|
||||
quat_update = torch.stack([qA, qB, qC, qD], dim=2)
|
||||
|
||||
return state, xyz, quat_update
|
||||
|
||||
|
||||
def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals,idx, bond_feats, dist_matrix, drop_layer=False, is_motif=None):
|
||||
|
||||
block_outputs = super().forward(msa, pair, state, xyz, is_atom, atom_frames, chirals,idx, bond_feats, dist_matrix, is_motif=is_motif)
|
||||
|
||||
@@ -36,7 +36,7 @@ class RosettaFold(nn.Module):
|
||||
self.simulator = nn.ModuleList(blocks)
|
||||
|
||||
n_refinement_blocks = len(model_params.refinement.keys())
|
||||
assert n_refinement_blocks <= 1, "only can have one refinment block"
|
||||
assert n_refinement_blocks <= 1, "only can have one refinement block"
|
||||
self.refinement = None
|
||||
if n_refinement_blocks == 1:
|
||||
refinement_type = next(iter(model_params.refinement.keys()))
|
||||
@@ -60,7 +60,7 @@ class RosettaFold(nn.Module):
|
||||
for aux_task in model_params.auxiliary_predictors.keys()
|
||||
}
|
||||
|
||||
def forward(self, rf_inputs, use_checkpoint, use_amp):
|
||||
def forward(self, rf_inputs, use_checkpoint, use_amp, skip_refinement=False):
|
||||
latent_feats = self.embedding(rf_inputs)
|
||||
#load useful primitives into latent_features
|
||||
latent_feats.update(
|
||||
@@ -73,20 +73,24 @@ class RosettaFold(nn.Module):
|
||||
"bond_feats": rf_inputs["bond_feats"],
|
||||
"dist_matrix": rf_inputs["dist_matrix"],
|
||||
"is_motif": rf_inputs.get("is_motif", None),
|
||||
"seq_unmasked": rf_inputs["seq_unmasked"],
|
||||
"trans_t": rf_inputs.get("trans_t", None),
|
||||
"t": rf_inputs.get("t", None),
|
||||
}
|
||||
)
|
||||
for block in self.simulator:
|
||||
latent_feats = block(latent_feats, use_checkpoint, use_amp)
|
||||
|
||||
rf_outputs = {}
|
||||
if self.refinement:
|
||||
rf_outputs = self.refinement(latent_feats)
|
||||
|
||||
for aux_task, aux_predictor in self.auxiliary_predictors.items():
|
||||
input_feature = self.auxiliary_predictor_input_feats[aux_task]
|
||||
auxiliary_predictions = aux_predictor(latent_feats[input_feature])
|
||||
rf_outputs.update({aux_task: auxiliary_predictions})
|
||||
if not skip_refinement:
|
||||
if self.refinement:
|
||||
rf_outputs = self.refinement(latent_feats)
|
||||
|
||||
for aux_task, aux_predictor in self.auxiliary_predictors.items():
|
||||
input_feature = self.auxiliary_predictor_input_feats[aux_task]
|
||||
auxiliary_predictions = aux_predictor(latent_feats[input_feature])
|
||||
rf_outputs.update({aux_task: auxiliary_predictions})
|
||||
|
||||
return rf_outputs, latent_feats
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors, SE3TransformerWrapper
|
||||
from rf2aa.model.Track_module import Str2Str
|
||||
from rf2aa.util_module import rbf, make_topk_graph, init_lecun_normal
|
||||
from rf2aa.util_module import rbf, make_topk_graph, make_full_graph, init_lecun_normal
|
||||
from rf2aa.util import is_atom,is_nucleic,is_protein
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.loss.loss import calc_chiral_grads
|
||||
from rf2aa.model.generative_refinement import GenerativeRefinement
|
||||
|
||||
class LocalRefinementSE3(FullyConnectedSE3):
|
||||
|
||||
@@ -58,7 +60,11 @@ class LocalRefinementSE3(FullyConnectedSE3):
|
||||
def construct_graph(self, xyz, edge):
|
||||
L = xyz.shape[1]
|
||||
idx = torch.arange(L, device=edge.device)[None]
|
||||
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
|
||||
if self.top_k>0:
|
||||
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
|
||||
else:
|
||||
G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx) # top_k=1 -> fully connected
|
||||
|
||||
return G, edge_feats
|
||||
|
||||
class RecurrentLocalRefinement(nn.Module):
|
||||
@@ -75,7 +81,7 @@ class RecurrentLocalRefinement(nn.Module):
|
||||
latent_feats["state"], latent_feats["xyz"], latent_feats["is_atom"], \
|
||||
latent_feats["atom_frames"], latent_feats["chirals"]
|
||||
idx, bond_feats, dist_matrix = latent_feats["idx"], latent_feats["bond_feats"], latent_feats["dist_matrix"]
|
||||
return msa, pair, state, xyz, is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
|
||||
return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
|
||||
|
||||
def forward(self, latent_feats):
|
||||
B, N, L = latent_feats["msa"].shape[:3]
|
||||
@@ -110,7 +116,7 @@ class RecurrentLocalRefinement_w_Adaptor(nn.Module):
|
||||
latent_feats["state"], latent_feats["xyz"], latent_feats["is_atom"], \
|
||||
latent_feats["atom_frames"], latent_feats["chirals"]
|
||||
idx, bond_feats, dist_matrix = latent_feats["idx"], latent_feats["bond_feats"], latent_feats["dist_matrix"]
|
||||
return msa, pair, state, xyz, is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
|
||||
return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
|
||||
|
||||
def forward(self, latent_feats):
|
||||
B, N, L = latent_feats["msa"].shape[:3]
|
||||
@@ -199,5 +205,6 @@ class LegacyRefiner(nn.Module):
|
||||
refinement_factory ={
|
||||
"local": RecurrentLocalRefinement,
|
||||
"local_adaptor": RecurrentLocalRefinement_w_Adaptor,
|
||||
"legacy": LegacyRefiner
|
||||
"legacy": LegacyRefiner,
|
||||
"generative": GenerativeRefinement
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.AF3_blocks import AF3_block,AF3_full_block
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
||||
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
||||
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
||||
@@ -287,11 +288,13 @@ class RF2_untied(RF2_block):
|
||||
|
||||
|
||||
block_factory = {
|
||||
"RF2aa": partial(RF2_block, is_full=False),
|
||||
"RF2aa_full": partial(RF2_block, is_full=True),
|
||||
"untied_p2p": partial(RF2_untied, is_full=False),
|
||||
"untied_p2p_full": partial(RF2_untied, is_full=True),
|
||||
"AF3": AF3_block,
|
||||
"AF3_full": AF3_full_block,
|
||||
"RF2aa": partial(RF2_block, is_full=False),
|
||||
"RF2aa_full": partial(RF2_block, is_full=True),
|
||||
"untied_p2p": partial(RF2_untied, is_full=False),
|
||||
"untied_p2p_full": partial(RF2_untied, is_full=True),
|
||||
"RF2_withgradients": partial(RF2_withgradients, is_full=False),
|
||||
"RF2_withgradients_full": partial(RF2_withgradients, is_full=True),
|
||||
"RF2_withgradients_R": partial(RF2_block, is_full=False, backprop_through_xyz=True),
|
||||
"RF2_withgradients_R": partial(RF2_block, is_full=False, backprop_through_xyz=True),
|
||||
}
|
||||
|
||||
61
rf2aa/tests/test_AF3_blocks.py
Normal file
61
rf2aa/tests/test_AF3_blocks.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
import torch
|
||||
from rf2aa.model.AF3_blocks import MsaSubsampleEmbedder
|
||||
|
||||
|
||||
def test_msa_module():
|
||||
pass
|
||||
|
||||
def test_msa_subsampler():
|
||||
B, N, L = 1, 100, 20
|
||||
params = {
|
||||
"num_sequences": 256,
|
||||
"msa_dim": 20,
|
||||
"msa_channels": 64,
|
||||
"S_dim": 32
|
||||
}
|
||||
msa_SI = torch.rand(B, N, L, 20)
|
||||
S_inputs = torch.rand(B, L, 32)
|
||||
subsampler = MsaSubsampleEmbedder(params)
|
||||
msa_SI = subsampler(msa_SI, S_inputs)
|
||||
assert msa_SI.shape == (B, N, L, 64)
|
||||
B, N, L = 1, 500, 20
|
||||
params = {
|
||||
"num_sequences": 256,
|
||||
"msa_dim": 20,
|
||||
"msa_channels": 64,
|
||||
"S_dim": 32
|
||||
}
|
||||
msa_SI = torch.rand(B, N, L, 20)
|
||||
S_inputs = torch.rand(B, L, 32)
|
||||
subsampler = MsaSubsampleEmbedder(params)
|
||||
msa_SI = subsampler(msa_SI, S_inputs)
|
||||
assert msa_SI.shape == (B, 256, L, 64)
|
||||
|
||||
def test_msa_pair_weighted_average():
|
||||
pass
|
||||
|
||||
def test_msa_weighting_einsum():
|
||||
B, I, S, H, c = 1, 5, 10, 8, 4
|
||||
|
||||
gate_SIH = torch.randn(B, S, I, H, c)
|
||||
w_IIH = torch.randn(B, I, I, H)
|
||||
v_SIH = torch.randn(B, S, I, H, c)
|
||||
|
||||
# Initialize the result tensor
|
||||
C = torch.zeros((B, S, I, H, c))
|
||||
|
||||
# Perform the einsum contraction in smaller steps
|
||||
#for idx_b in range(B):
|
||||
#for idx_s in range(S):
|
||||
#for idx_i in range(I):
|
||||
#for idx_h in range(H):
|
||||
#for idx_c in range(c):
|
||||
#C[idx_b, idx_s, idx_i, idx_h, idx_c] = torch.sum(
|
||||
#v_SIH[idx_b, idx_s, :, idx_h, idx_c] * w_IIH[idx_b, :, idx_i, idx_h]
|
||||
#)
|
||||
unaggregated_weights = torch.einsum("bsihc, bijh -> bsijhc", v_SIH, w_IIH)
|
||||
|
||||
weights = torch.einsum("bsihc, biih -> bsihc", v_SIH, w_IIH)
|
||||
o_SIH = gate_SIH * weights
|
||||
|
||||
@@ -21,7 +21,7 @@ from functools import partial
|
||||
# goal is to test all the configs on a broad set of datasets
|
||||
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
test_conditions, test_ids = setup_benchmark_array(["pdb196"], ["rf2aa","rf2_deep_layerdropout"])
|
||||
test_conditions, test_ids = setup_benchmark_array(["pdb196"], ["rf2aa","rf2_deep_layerdropout","af3"])
|
||||
|
||||
def setup_test(example, trainer):
|
||||
model = trainer.model
|
||||
@@ -76,7 +76,7 @@ def test_benchmark_fw_bw(benchmark, example, trainer):
|
||||
|
||||
def run():
|
||||
output_i = recycle_step_packed(trainer.model, network_input, 1, trainer.config.training_params.use_amp, nograds=False, force_device=gpu)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(dataloader_inputs, device=gpu)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, _, _ = get_loss_calc_items(dataloader_inputs, device=gpu)
|
||||
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
trainer,
|
||||
|
||||
@@ -24,7 +24,7 @@ sm_compl_asmb_item = {'CHAINID': '4i7z_D', 'DEPOSITION': '2012-12-01', 'RESOLUTI
|
||||
benchmark196_item = {'Unnamed: 0': 489, 'CHAINID': '4y7y_BA', 'DEPOSITION': '2015-02-16', 'RESOLUTION': 2.4, 'HASH': '105747', 'CLUSTER': 10756, 'SEQUENCE': 'TSIMAVTFKDGVILGADSRTTTGAYIANRVTDKLTRVHDKIWCCRSGSAADTQAIADIVQYHLELYTSQYGTPSTETAASVFKELCYENKDNLTAGIIVAGYDDKNKGEVYTIPLGGSVHKLPYAIAGSGSTFIYGYCDKNFRENMSKEETVDFIKHSLSQAIKWDGSSGGVIRMVVLTAAGVERLIFYPDEYEQL', 'LEN_EXIST': 196, 'TAXID': ''}
|
||||
|
||||
# configurations to test
|
||||
configs = ["rf2aa", "rf2_deep_layerdropout", "legacy_train", "rf2aa_legacy_refinement", "rf_with_gradients", "untied_p2p"]
|
||||
configs = ["rf2aa", "rf2_deep_layerdropout", "legacy_train", "rf2aa_legacy_refinement", "rf_with_gradients", "untied_p2p", "af3"]
|
||||
datasets = {
|
||||
"pdb" : (loader_pdb, pdb_item, {"homo": {"CHAIN_A": pd.Series(dtype=np.float32)}}),
|
||||
"compl" : (loader_complex, compl_item, {}),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
import hydra
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import pytest
|
||||
@@ -16,10 +17,10 @@ from rf2aa.data.compose_dataset import compose_single_item_dataset
|
||||
# goal is to test all the configs on a broad set of datasets
|
||||
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
test_conditions, test_ids = setup_benchmark_array(["pdb256"], ["rf_with_gradients"], device=gpu)
|
||||
#test_conditions, test_ids = setup_benchmark_array(["pdb256"], ["rf_with_gradients"], device=gpu)
|
||||
|
||||
def setup_test(example, trainer):
|
||||
model = trainer.model
|
||||
#model = trainer.model
|
||||
config = trainer.config.chem_params
|
||||
|
||||
# initialize chemical database
|
||||
@@ -28,14 +29,16 @@ def setup_test(example, trainer):
|
||||
|
||||
# to GPU
|
||||
trainer.move_constants_to_device(gpu)
|
||||
model = model.to(gpu)
|
||||
|
||||
trainer.construct_model(device=gpu)
|
||||
#model = model.to(gpu)
|
||||
|
||||
dataset_name = example[0]
|
||||
item, loader_params, _, loader, loader_kwargs = example[1:]
|
||||
#HACK: reduce crop size
|
||||
|
||||
loader_params["CROP"] = 200
|
||||
loader_params["MAXCYCLE"] = 10
|
||||
loader_params["CROP"] = 100
|
||||
loader_params["MAXCYCLE"] = 4
|
||||
# read from disk, move to device
|
||||
dataloader = compose_single_item_dataset( None, item, loader_params, loader, loader_kwargs)
|
||||
return dataloader
|
||||
@@ -49,7 +52,7 @@ def test_minimize_example(example, trainer):
|
||||
|
||||
dataloader = setup_test(example, trainer)
|
||||
trainer.train_loader = dataloader
|
||||
trainer.model = DDP(trainer.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
|
||||
#trainer.model = DDP(trainer.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
|
||||
|
||||
trainer.construct_optimizer()
|
||||
trainer.construct_scheduler()
|
||||
@@ -64,9 +67,18 @@ def test_minimize_example(example, trainer):
|
||||
for file in glob("models/*"):
|
||||
os.remove(file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
example, trainer = test_conditions[0]
|
||||
@hydra.main(config_path="../config/train", config_name="base")
|
||||
def main(config):
|
||||
from rf2aa.trainer_new import trainer_factory
|
||||
trainer = trainer_factory[config.experiment.trainer](config)
|
||||
example = data["pdb"]
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
|
||||
os.environ['MASTER_PORT'] = '%d'%trainer.config.ddp_params.port
|
||||
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
|
||||
test_minimize_example(example, trainer)
|
||||
test_minimize_example(example, trainer)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
from rf2aa.tests.test_conditions import setup_data
|
||||
data = setup_data()
|
||||
main()
|
||||
@@ -10,19 +10,20 @@ import time
|
||||
import omegaconf
|
||||
from contextlib import nullcontext
|
||||
import datetime
|
||||
|
||||
from datetime import timedelta
|
||||
import certifi
|
||||
import warnings
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
|
||||
from rf2aa.flow_matching.interpolant import Interpolant
|
||||
from rf2aa.flow_matching.sampler import Sampler
|
||||
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
|
||||
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads
|
||||
from rf2aa.training.EMA import EMA, count_parameters
|
||||
from rf2aa.loss.loss import translation_vector_field
|
||||
from rf2aa.loss.loss_factory import get_loss_and_misc
|
||||
from rf2aa.training.optimizer import add_weight_decay
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_sampling, run_model_forward
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward
|
||||
from rf2aa.model.network import RosettaFold
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
|
||||
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
|
||||
@@ -93,8 +94,24 @@ class Trainer:
|
||||
|
||||
def load_model(self):
|
||||
torch.cuda.empty_cache()
|
||||
self.model.module.model.load_state_dict(self.checkpoint["final_state_dict"], strict=True)
|
||||
self.model.module.shadow.load_state_dict(self.checkpoint["model_state_dict"], strict=False)
|
||||
|
||||
new_model_state = {}
|
||||
new_shadow_state = {}
|
||||
state_dict = self.model.module.model.state_dict()
|
||||
for param in state_dict:
|
||||
if param not in self.checkpoint['model_state_dict']:
|
||||
print ('missing',param)
|
||||
elif (self.checkpoint['model_state_dict'][param].shape == state_dict[param].shape):
|
||||
new_model_state[param] = self.checkpoint['final_state_dict'][param]
|
||||
new_shadow_state[param] = self.checkpoint['model_state_dict'][param]
|
||||
else:
|
||||
print (
|
||||
'wrong size',param,
|
||||
self.checkpoint['model_state_dict'][param].shape,
|
||||
state_dict[param].shape )
|
||||
|
||||
self.model.module.model.load_state_dict(new_model_state, strict=False)
|
||||
self.model.module.shadow.load_state_dict(new_shadow_state, strict=False)
|
||||
print("Checkpoint loaded into model")
|
||||
|
||||
def load_optimizer(self):
|
||||
@@ -179,7 +196,7 @@ class Trainer:
|
||||
|
||||
def init_process_group(self, rank, world_size):
|
||||
gpu = rank % torch.cuda.device_count()
|
||||
dist.init_process_group(backend=self.config.training_params.ddp_backend, world_size=world_size, rank=rank)
|
||||
dist.init_process_group(backend=self.config.training_params.ddp_backend, timeout=timedelta(seconds=1800), world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device("cuda:%d"%gpu)
|
||||
return gpu
|
||||
|
||||
@@ -226,7 +243,6 @@ class Trainer:
|
||||
self.move_constants_to_device(gpu)
|
||||
|
||||
self.construct_model(device=gpu)
|
||||
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
|
||||
if rank == 0:
|
||||
print(f"Loading model with {count_parameters(self.model)} parameters")
|
||||
|
||||
@@ -249,6 +265,11 @@ class Trainer:
|
||||
self.config.experiment.n_epoch,
|
||||
self.config.dataset_params.n_train,
|
||||
world_size)
|
||||
|
||||
#for d, valid_sampler in valid_samplers.items():
|
||||
# valid_sampler.set_epoch(start_epoch-1)
|
||||
#self.valid_epoch(start_epoch-1, rank, world_size)
|
||||
|
||||
for epoch in range(start_epoch,self.config.experiment.n_epoch):
|
||||
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
|
||||
self.train_epoch(epoch, rank, world_size)
|
||||
@@ -279,8 +300,27 @@ class Trainer:
|
||||
|
||||
# aggregate loss and update parameters
|
||||
loss = loss / self.config.ddp_params.accum
|
||||
|
||||
if (torch.any(torch.isnan(loss))):
|
||||
print ('NAN in loss',inputs[-1])
|
||||
print ('NAN in loss',loss_dict)
|
||||
exit(1)
|
||||
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
hasnans=False
|
||||
for n,p in self.model.named_parameters():
|
||||
if (p.grad is not None and torch.any(torch.isnan(p.grad))):
|
||||
hasnans = True
|
||||
if hasnans:
|
||||
print ('NAN in grad')
|
||||
for n,p in self.model.named_parameters():
|
||||
if (p.grad is not None):
|
||||
print (n, torch.max( torch.abs(p.flatten()) ), torch.max( torch.abs(p.grad.flatten()) ))
|
||||
exit(1)
|
||||
|
||||
train_time = time.time() - start_time
|
||||
|
||||
if train_idx%self.config.ddp_params.accum == 0:
|
||||
self.update_parameters()
|
||||
|
||||
@@ -345,6 +385,7 @@ class Trainer:
|
||||
def update_parameters(self):
|
||||
""" scale, clip gradients and update parameters """
|
||||
# gradient clipping
|
||||
#debug_grads(self.model)
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training_params.grad_clip)
|
||||
self.scaler.step(self.optimizer)
|
||||
@@ -390,6 +431,8 @@ class LegacyTrainer(Trainer):
|
||||
).to(device)
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
|
||||
|
||||
|
||||
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
|
||||
""" take an input from dataloader, run the model and compute a loss """
|
||||
@@ -401,7 +444,7 @@ class LegacyTrainer(Trainer):
|
||||
= prepare_input(inputs, self.xyz_converter, gpu)
|
||||
|
||||
output_i = recycle_step_legacy(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=nograds)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(inputs, device=gpu)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, _, _ = get_loss_calc_items(inputs, device=gpu)
|
||||
|
||||
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
|
||||
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
|
||||
@@ -410,7 +453,8 @@ class LegacyTrainer(Trainer):
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
self, # avoid reloading constants to device
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, None, None,
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
self.config.loss_param
|
||||
)
|
||||
if return_outputs:
|
||||
@@ -429,6 +473,7 @@ class ComposedTrainer(Trainer):
|
||||
self.model = RosettaFold(self.config).to(device)
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
|
||||
|
||||
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
|
||||
""" take an input from dataloader, run the model and compute a loss """
|
||||
@@ -452,7 +497,8 @@ class ComposedTrainer(Trainer):
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
self, # avoid reloading constants to device
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, None, None,
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
self.config.loss_param
|
||||
)
|
||||
|
||||
@@ -468,7 +514,8 @@ class FlowMatchingTrainer(Trainer):
|
||||
self.model = RosettaFold(self.config).to(device)
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
self.sampler = Sampler(self.model,
|
||||
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
|
||||
self.sampler = AllAtomSampler(self.model,
|
||||
self.config.interpolant.sampling.num_timesteps,
|
||||
self.config.interpolant.min_t,
|
||||
self.interpolant,
|
||||
@@ -482,11 +529,47 @@ class FlowMatchingTrainer(Trainer):
|
||||
|
||||
def train_step(self, inputs, n_cycle, no_grads=False, return_outputs=False):
|
||||
gpu = self.model.device
|
||||
#try:
|
||||
task, item, network_input, true_crds, \
|
||||
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
|
||||
= prepare_input_fm(inputs, self.interpolant, self.xyz_converter, gpu)
|
||||
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label, \
|
||||
r3_t, trans_1, mask_allatom \
|
||||
= prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, gpu)
|
||||
|
||||
output_i = recycle_step_packed(self.model, network_input, 1, self.config.training_params.use_amp, nograds=no_grads)
|
||||
output_i = recycle_step_gen(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=no_grads)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask = get_loss_calc_items(inputs, device=gpu)
|
||||
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
|
||||
#loss = (pred_allatom - true_crds).mean()
|
||||
#loss_dict = {"loss": loss.mean()}
|
||||
|
||||
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
|
||||
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
|
||||
msa = msa.to(gpu)
|
||||
mask_msa = mask_msa.to(gpu)
|
||||
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
self, # avoid reloading constants to device
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
self.config.loss_param
|
||||
)
|
||||
|
||||
if return_outputs:
|
||||
return loss, loss_dict, output_i
|
||||
else:
|
||||
return loss, loss_dict
|
||||
|
||||
def valid_step(self, inputs, n_cycle, no_grads=True, return_outputs=False):
|
||||
gpu = self.model.device
|
||||
#try:
|
||||
task, item, network_input, true_crds, \
|
||||
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label, \
|
||||
r3_t, trans_1, mask_allatom \
|
||||
= prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, gpu)
|
||||
|
||||
#output_i = recycle_step_gen(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=no_grads)
|
||||
with torch.no_grad():
|
||||
output_i = self.sampler.sample(inputs, n_cycle=n_cycle, use_amp=self.config.training_params.use_amp)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask = get_loss_calc_items(inputs, device=gpu)
|
||||
|
||||
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
|
||||
@@ -497,37 +580,14 @@ class FlowMatchingTrainer(Trainer):
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
self, # avoid reloading constants to device
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
|
||||
unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
self.config.loss_param
|
||||
)
|
||||
|
||||
if return_outputs:
|
||||
return loss, loss_dict, output_i
|
||||
else:
|
||||
return loss, loss_dict
|
||||
|
||||
def valid_step(self, inputs, n_cycle, no_grads=True, return_outputs=False):
|
||||
gpu = self.interpolant._device
|
||||
self.sampler.model = self.model
|
||||
task, item, network_input, true_crds, \
|
||||
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
|
||||
= prepare_input_fm(inputs, self.interpolant, self.xyz_converter, gpu)
|
||||
|
||||
output_i = self.sampler.sample(inputs)
|
||||
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask \
|
||||
= get_loss_calc_items(inputs, device=gpu)
|
||||
|
||||
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
|
||||
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
|
||||
msa = msa.to(gpu)
|
||||
mask_msa = mask_msa.to(gpu)
|
||||
|
||||
loss, loss_dict = get_loss_and_misc(
|
||||
self, # avoid reloading constants to device
|
||||
output_i, true_crds, atom_mask, same_chain,
|
||||
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
|
||||
self.config.loss_param
|
||||
)
|
||||
#fd last layer l0 are unused in grads
|
||||
#fd to do: fix this in refinement module
|
||||
loss += 0.0*output_i[-1].sum()
|
||||
|
||||
if return_outputs:
|
||||
return loss, loss_dict, output_i
|
||||
@@ -543,7 +603,8 @@ class FlowMatchingTrainer(Trainer):
|
||||
for valid_idx, inputs in enumerate(valid_loader):
|
||||
n_cycle = self.config.loader_params.maxcycle
|
||||
|
||||
loss, loss_dict = self.valid_step(inputs, n_cycle)
|
||||
loss, loss_dict = self.valid_step(inputs, n_cycle)
|
||||
#print (loss_dict)
|
||||
|
||||
if valid_loss_dict is None:
|
||||
valid_loss_dict = torch.zeros_like(torch.stack(list(loss_dict.values())))
|
||||
|
||||
@@ -18,7 +18,6 @@ def recycle_step_legacy(ddp_model, input, n_cycle, use_amp, nograds=False, force
|
||||
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||
for i_cycle in range(n_cycle):
|
||||
with ExitStack() as stack:
|
||||
#stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
|
||||
if i_cycle < n_cycle -1 or nograds is True:
|
||||
stack.enter_context(torch.no_grad())
|
||||
stack.enter_context(ddp_model.no_sync())
|
||||
@@ -41,7 +40,6 @@ def recycle_step_packed(ddp_model, input, n_cycle, use_amp, nograds=False, force
|
||||
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||
for i_cycle in range(n_cycle):
|
||||
with ExitStack() as stack:
|
||||
#stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
|
||||
if i_cycle < n_cycle -1 or nograds is True:
|
||||
stack.enter_context(torch.no_grad())
|
||||
stack.enter_context(ddp_model.no_sync())
|
||||
@@ -54,6 +52,30 @@ def recycle_step_packed(ddp_model, input, n_cycle, use_amp, nograds=False, force
|
||||
|
||||
return output_i
|
||||
|
||||
def recycle_step_gen(ddp_model, input, n_cycle, use_amp, nograds=False, force_device=None):
|
||||
""" exactly same logic as legacy recycling, except inputs and outputs are dictionaries"""
|
||||
if force_device is not None:
|
||||
gpu = force_device
|
||||
else:
|
||||
gpu = ddp_model.device
|
||||
|
||||
xyz_prev, alpha_prev, mask_recycle = \
|
||||
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
|
||||
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||
for i_cycle in range(n_cycle):
|
||||
with ExitStack() as stack:
|
||||
if i_cycle < n_cycle -1 or nograds is True:
|
||||
stack.enter_context(torch.no_grad())
|
||||
stack.enter_context(ddp_model.no_sync())
|
||||
return_raw = (i_cycle < n_cycle -1)
|
||||
use_checkpoint = not nograds and (i_cycle == n_cycle -1)
|
||||
|
||||
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
|
||||
rf_outputs, rf_latents = ddp_model(input_i, use_checkpoint=use_checkpoint, use_amp=use_amp, skip_refinement=return_raw)
|
||||
output_i = unpack_outputs(rf_outputs, rf_latents, return_raw)
|
||||
return output_i
|
||||
|
||||
|
||||
def run_model_forward(model, network_input, use_checkpoint=False, device="cpu"):
|
||||
""" run model forward pass, no recycling, no ddp (for tests) """
|
||||
gpu = device
|
||||
@@ -90,16 +112,38 @@ def unpack_outputs(rf_outputs, rf_latents, return_raw):
|
||||
msa, pair, state = rf_latents["msa"], rf_latents["pair"], rf_latents["state"]
|
||||
|
||||
if return_raw:
|
||||
xyz_prev = rf_outputs["xyzs"][-1][None]
|
||||
alpha_prev = rf_outputs["alphas"][-1]
|
||||
xyz_prev, alpha_prev = None, None
|
||||
if "xyzs" in rf_outputs:
|
||||
xyz_prev = rf_outputs["xyzs"][-1][None]
|
||||
if "alphas" in rf_outputs:
|
||||
alpha_prev = rf_outputs["alphas"][-1]
|
||||
return msa[:, 0], pair, xyz_prev, alpha_prev, None # mask_recycle is always None
|
||||
|
||||
else:
|
||||
c6d_logits, mlm_logits, pae_logits, plddt_logits = rf_outputs["c6d"], rf_outputs["mlm"], \
|
||||
rf_outputs["pae"], rf_outputs["plddt"]
|
||||
c6d_logits, mlm_logits, pae_logits, plddt_logits = None,None,None,None
|
||||
if "c6d" in rf_outputs:
|
||||
c6d_logits = rf_outputs["c6d"]
|
||||
if "mlm" in rf_outputs:
|
||||
mlm_logits = rf_outputs["mlm"]
|
||||
if "pae" in rf_outputs:
|
||||
pae_logits = rf_outputs["pae"]
|
||||
if "plddt" in rf_outputs:
|
||||
plddt_logits = rf_outputs["plddt"]
|
||||
pde_logits = None
|
||||
p_bind = None
|
||||
xyz, alphas = rf_outputs["xyzs"], rf_outputs["alphas"]
|
||||
|
||||
if "xyzs" in rf_outputs:
|
||||
xyz = rf_outputs["xyzs"]
|
||||
else:
|
||||
xyz = rf_outputs["xyz"][..., :3, :][None]
|
||||
if "state" in rf_outputs:
|
||||
state = rf_outputs["state"]
|
||||
B = 1
|
||||
L = xyz.shape[2]
|
||||
if "alphas" in rf_outputs:
|
||||
alphas = rf_outputs["alphas"]
|
||||
else:
|
||||
alphas = torch.zeros((1, B, L, ChemData().NTOTALDOFS, 2), device=xyz.device)
|
||||
if "xyz_intermediate" in rf_latents:
|
||||
intermediate_xyzs = torch.stack(rf_latents["xyz_intermediate"], dim=0)
|
||||
xyz = torch.cat((intermediate_xyzs, xyz), dim=0)
|
||||
@@ -107,8 +151,8 @@ def unpack_outputs(rf_outputs, rf_latents, return_raw):
|
||||
if "alpha_intermediate" in rf_latents:
|
||||
alpha_intermediate = torch.stack(rf_latents["alpha_intermediate"], dim=0)
|
||||
alphas = torch.cat((alpha_intermediate, alphas), dim=0)
|
||||
|
||||
xyz_allatom = None
|
||||
# HACK: breaking change for all atom flow matching
|
||||
xyz_allatom = rf_outputs["xyz"]
|
||||
return (c6d_logits, mlm_logits, pae_logits, pde_logits, p_bind,
|
||||
xyz, alphas, xyz_allatom, plddt_logits, msa[:, 0], pair, state)
|
||||
|
||||
|
||||
@@ -1469,7 +1469,8 @@ def get_residue_contacts(xyz, idx, seq_dist_greater_than=10, n_contacts=5):
|
||||
kk = k
|
||||
for j in nodes:
|
||||
ds[j,kk] = ds[kk,j] = 999.
|
||||
nodes.append(kk)
|
||||
if kk>=0:
|
||||
nodes.append(kk)
|
||||
nodes,_ = torch.sort(torch.stack(nodes))
|
||||
|
||||
return idx[nodes]
|
||||
|
||||
@@ -268,7 +268,7 @@ def make_full_graph(xyz, pair, idx):
|
||||
src = b*L+i
|
||||
tgt = b*L+j
|
||||
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
|
||||
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
|
||||
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
|
||||
return G, pair[b,i,j][...,None]
|
||||
|
||||
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-4):
|
||||
@@ -323,48 +323,8 @@ def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True,
|
||||
|
||||
return G, pair[b,i,j][...,None]
|
||||
|
||||
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
|
||||
B,L,A = xyz.shape[:3]
|
||||
device = xyz.device
|
||||
|
||||
D = torch.norm(
|
||||
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
|
||||
)
|
||||
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
|
||||
D[~mask2d] = 9999.
|
||||
D[D==0] = 9999.
|
||||
|
||||
# select top K neighbors for each atom
|
||||
# keep indices as batch/res/atm indices
|
||||
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
|
||||
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
|
||||
bi,ri,ai = mask.nonzero(as_tuple=True)
|
||||
bi = bi[:,None].repeat(1,top_k).reshape(-1)
|
||||
ri = ri[:,None].repeat(1,top_k).reshape(-1)
|
||||
ai = ai[:,None].repeat(1,top_k).reshape(-1)
|
||||
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
|
||||
|
||||
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
|
||||
edge = torch.full(ri.shape, maxbonds, device=device)
|
||||
resmask = ri==rj
|
||||
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
|
||||
resmask = ri+1==rj
|
||||
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
|
||||
resmask = ri-1==rj
|
||||
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
|
||||
edge = edge.clamp(0,maxbonds-1)
|
||||
edge = F.one_hot(edge)[...,None]
|
||||
|
||||
natm = torch.sum(mask)
|
||||
index = torch.zeros_like(mask, dtype=torch.long, device=device)
|
||||
index[mask] = torch.arange(natm, device=device)
|
||||
src=index[bi,ri,ai]
|
||||
tgt=index[bi,rj,aj]
|
||||
|
||||
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
|
||||
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
|
||||
|
||||
return G, edge
|
||||
|
||||
|
||||
# rotate about the x axis
|
||||
|
||||
Reference in New Issue
Block a user