Add af3-style 2track net

Add config file

Error in config file

Update config

Fix unit tests.  Update config

Turn of bias in attention norms

Add some temp fixes

Change amp to use bfloat16

More model stability changes

add allatom data transform

added in allatom noising and denoising

allatom flow matching trains now

add loss params in generative refinement yaml

Some fixes and reorg for diffusive training

Remove debug output; update config

Silly bugfix

Validation bugfix.  Adjust diffusive weights.

One more bugfix

some progress towards validation in allatom fm

Add recycling.  Add stability changes from trunk.

Fix rare bug in get_residue_contacts

more work towards sampler:

runs validation, but ooms

updates for diffusion infernece in validation

Sampler respects use_amp flag

Revert "Sampler respects use_amp flag"

This reverts commit 722c32dda150bc3f12167839bd4cbae2fe8026c5.

changes for training

Move hardcoded options to options system.

Add proper conversion of atom/bond level features

changes to allow gradients to flow through diffusion module

added mlm back into config

fix d_t bug

Several updates to allatom gen refinement: 1) DNA and ligand bond distances are correctly computed; 2) atom graph includes nearby-bonded neighbors even when distance is large; 3) small change to default parameters.

Update/bugfix BiasedSequenceAttention in refinement

Reenable checkpointing

Small bugfix.

Set bonded atoms to D=-1 so they are always preferred over D=0 atoms

Updates to AF2-like training

Change af3-like architecture config

Bugfixes in validation.  Small config updates for stable training

dt buig

added timstep embedding

Some very minor stability fixes

Msa module

Re commit accidentally reverted changes.  Track t in loss reporting.

Revert some of the scaling I was doing.

add alignment to vector field matching

Several bugfixes, code simplifications
This commit is contained in:
fdimaio
2024-05-03 18:34:49 -07:00
committed by Rohith Krishna
parent 22dd371a4e
commit d33c097ff5
37 changed files with 3044 additions and 594 deletions

View File

@@ -222,6 +222,7 @@ class ConvSE3(nn.Module):
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False,
sum_over_edge: bool = True,
low_memory: bool = False
):
"""
@@ -242,6 +243,7 @@ class ConvSE3(nn.Module):
self.self_interaction = self_interaction
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
self.sum_over_edge = sum_over_edge
self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
# channels_in: account for the concatenation of edge features

View File

@@ -50,7 +50,10 @@ class NormSE3(nn.Module):
└──> feature_phase ────────────────────────────┘
"""
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
#NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
#fd w/o disconnected gradients this value still causes exploding grads
NORM_CLAMP = 2 ** -16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__()

View File

@@ -381,6 +381,54 @@ class ChemicalData:
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,"N" ,"C" ,"N" ,"N" ,"C" ,"C" ,"C" ,"O" ,"N" ,"C" ,"N" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None),#G
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,"N" ,"C" ,"O" ,"N" ,"C" ,"O" ,"C" ,"C" ,None,None,None,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None,None),#U
("O" ,"C" ,"C" ,"O" ,"P" ,"O" ,"O" ,"C" ,"C" ,"C" ,"O" ,"O" ,None,None,None,None,None,None,None,None,None,None,None,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,"H" ,None,None,None,None,None),#RX
("N", "C","C" ,"O", "C", "C", "N" ,"C" ,"C" ,"N" ,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),# HIS-D NOT CORRECT!!!!!!!!!!
(None,"Al",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Al
(None,"As",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# As
(None,"Au",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Au
(None,"B",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# B
(None,"Be",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Be
(None,"Br",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Br
(None,"C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# C
(None,"Ca",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ca
(None,"Cl",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cl
(None,"Co",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Co
(None,"Cr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cr
(None,"Cu",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cu
(None,"F",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# F
(None,"Fe",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Fe
(None,"Hg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Hg
(None,"I",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# I
(None,"Ir",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ir
(None,"K",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# K
(None,"Li",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Li
(None,"Mg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mg
(None,"Mn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mn
(None,"Mo",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mo
(None,"N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# N
(None,"Ni",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ni
(None,"O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# O
(None,"Os",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Os
(None,"P",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# P
(None,"Pb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pb
(None,"Pd",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pd
(None,"Pr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pr
(None,"Pt",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pt
(None,"Re",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Re
(None,"Rh",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Rh
(None,"Ru",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ru
(None,"S",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# S
(None,"Sb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sb
(None,"Se",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Se
(None,"Si",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Si
(None,"Sn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sn
(None,"Tb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Tb
(None,"Te",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Te
(None,"U",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# U
(None,"W",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# W
(None,"V",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# V
(None,"Y",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Y
(None,"Zn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Zn
(None,"ATM",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# ATM ]
]
# frames for generic FAPE
@@ -1171,6 +1219,9 @@ class ChemicalData:
],
]
# atom ids forming polymer connections
self.protein_connect = (0,2) # N, C
self.na_connect = (4,10) # P, O3'
else:
# USE PHOSPHATE FRAME
self.aa2long=[
@@ -1363,6 +1414,54 @@ class ChemicalData:
("O","P","O","O","C","C","O","C","O","C","C","O","N","C","N","N","C","C","C","O","N","C","N","H","H","H","H","H","H","H","H","H","H","H",None,None),#G
("O","P","O","O","C","C","O","C","O","C","C","O","N","C","O","N","C","O","C","C",None,None,None,"H","H","H","H","H","H","H","H","H","H",None,None,None),#U
("O","P","O","O","C","C","O","C","O","C","C","O",None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H",None,None,None,None,None),#RX
("N", "C","C" ,"O", "C", "C", "N" ,"C" ,"C" ,"N" ,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),# HIS-D NOT CORRECT!!!!!!!!!!
(None,"Al",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Al
(None,"As",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# As
(None,"Au",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Au
(None,"B",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# B
(None,"Be",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Be
(None,"Br",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Br
(None,"C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# C
(None,"Ca",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ca
(None,"Cl",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cl
(None,"Co",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Co
(None,"Cr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cr
(None,"Cu",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Cu
(None,"F",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# F
(None,"Fe",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Fe
(None,"Hg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Hg
(None,"I",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# I
(None,"Ir",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ir
(None,"K",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# K
(None,"Li",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Li
(None,"Mg",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mg
(None,"Mn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mn
(None,"Mo",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Mo
(None,"N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# N
(None,"Ni",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ni
(None,"O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# O
(None,"Os",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Os
(None,"P",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# P
(None,"Pb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pb
(None,"Pd",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pd
(None,"Pr",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pr
(None,"Pt",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Pt
(None,"Re",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Re
(None,"Rh",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Rh
(None,"Ru",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Ru
(None,"S",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# S
(None,"Sb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sb
(None,"Se",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Se
(None,"Si",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Si
(None,"Sn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Sn
(None,"Tb",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Tb
(None,"Te",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Te
(None,"U",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# U
(None,"W",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# W
(None,"V",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# V
(None,"Y",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Y
(None,"Zn",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# Zn
(None,"ATM",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None),# ATM ]
]
# frames for generic FAPE
@@ -2151,6 +2250,12 @@ class ChemicalData:
],
]
# atom ids forming polymer connections
self.protein_connect = (0,2) # N, C
self.na_connect = (1,8) # P, O3'
# general case
self.aabonds=[
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # ala
@@ -2444,6 +2549,8 @@ class ChemicalData:
for j,a in enumerate(i_l):
if (a is None):
self.long2alt[i,j] = j
elif ("H" in a):
self.long2alt[i,j] = i_l.index(a)
else:
self.long2alt[i,j] = i_lalt.index(a)
self.allatom_mask[i,j] = True
@@ -2453,14 +2560,18 @@ class ChemicalData:
self.allatom_mask[self.NNAPROTAAS:,1] = True
# bond graph traversal
self.num_bonds = torch.zeros((self.NAATOKENS,self.NTOTAL,self.NTOTAL), dtype=torch.long)
self.MAX_BOND_DIST = 9 # largest bond separation we consider
self.num_bonds = torch.full((self.NAATOKENS,self.NTOTAL,self.NTOTAL), self.MAX_BOND_DIST, dtype=torch.long)
self.num_bonds[:,torch.arange(self.NTOTAL),torch.arange(self.NTOTAL)] = 0 # atom self-interaction
# compute for all protein & na using csgraph.shortest_path
for i in range(self.NNAPROTAAS):
num_bonds_i = np.zeros((self.NTOTAL,self.NTOTAL))
for (bnamei,bnamej) in self.aabonds[i]:
bi,bj = self.aa2long[i].index(bnamei),self.aa2long[i].index(bnamej)
num_bonds_i[bi,bj] = 1
num_bonds_i = scipy.sparse.csgraph.shortest_path (num_bonds_i,directed=False)
num_bonds_i[num_bonds_i>=4] = 4
num_bonds_i[num_bonds_i>=self.MAX_BOND_DIST] = self.MAX_BOND_DIST
self.num_bonds[i,...] = torch.tensor(num_bonds_i)
@@ -2472,6 +2583,21 @@ class ChemicalData:
self.idx2aatype.append(y)
self.aatype2idx = {x:i for i,x in enumerate(self.idx2aatype)}
# element indices
self.idx2elt = []
for x in self.aa2elt:
for y in x:
if y and y not in self.idx2elt:
self.idx2elt.append(y)
self.elt2idx = {x:i for i,x in enumerate(self.idx2elt)}
self.aa2eltidx = torch.zeros((self.NAATOKENS,self.NTOTAL), dtype=torch.long)
for i in range(self.NAATOKENS):
for j in range(self.NTOTAL):
if self.aa2elt[i][j] is not None:
self.aa2eltidx[i,j] = self.elt2idx[ self.aa2elt[i][j] ]
self.NELTTYPES = len(self.elt2idx)
# LJ/LK scoring parameters
self.atom_type_index = torch.zeros((self.NAATOKENS,self.NTOTAL), dtype=torch.long)

222
rf2aa/config/train/af3.yaml Normal file
View File

@@ -0,0 +1,222 @@
defaults:
- base
experiment:
name: rf2aa-af3_v4
trainer: "flow_matching"
training_params:
resume_from_checkpoint_path: /home/dimaio/RF2-allatom-af3/rf2aa/models/rf2aa-af3_v4_last.pt
reset_optimizer_params: True
EMA: 0.99
weight_decay: 0.01
learning_rate: .001
learning_rate_schedule:
num_warmup_steps: 0
num_steps_decay: 5000
decay_rate: 0.95
grad_clip: 0.2
use_amp: False
ddp_backend: nccl
loader_params:
maxseq: 1024
maxtoken: 1024
maxlat: 1
ddp_params:
accum: 1
batch_size: 1
port: 12465
loss_param:
w_dist: 1.0
w_str: 0.0
w_inter_fape: 0.0
w_lig_fape: 0.0
w_lddt: 0.0
w_aa: 0.5
w_bond: 0.0
w_bind: 0.0
binder_loss_label_smoothing: 0.0
w_clash: 0.0
w_atom_bond: 0.0
w_skip_bond: 0.0
w_rigid: 0.0
w_hb: 0.0
w_pae: 0.00
w_pde: 0.00
lj_lin: 0.75
w_trans: 1.0
t_normalize_clip: 0.925
trans_scale: 1.0
interpolant:
min_t: 1e-2
separate_t: False
provide_kappa: False
hierarchical_t: False
codesign_separate_t: False
codesign_forward_fold_prop: 0.0
codesign_inverse_fold_prop: 0.0
twisting:
use: False
rots:
corrupt: True
train_schedule: linear
sample_schedule: linear
exp_rate: 10
trans:
corrupt: True
batch_ot: True
train_schedule: linear
sample_schedule: linear
sample_temp: 1.0
vpsde_bmin: 0.1
vpsde_bmax: 20.0
potential: null
potential_t_scaling: False
rog:
weight: 10.0
cutoff: 5.0
aatypes:
corrupt: False
schedule: linear
schedule_exp_rate: 10
temp: 1.0
noise: 0.0
do_purity: False
train_extra_mask: 0.0
interpolant_type: masking
num_tokens: 80
sampling:
num_timesteps: 10
do_sde: False
self_condition: False
dataset_params:
n_train: 25600
validate_every_n_epochs: 1
fraction_pdb: 0.12
fraction_fb: 0.36
fraction_compl: 0.055
fraction_neg_compl: 0
fraction_na_compl: 0.055
fraction_neg_na_compl: 0
fraction_distil_tf: 0.055
fraction_tf: 0
fraction_neg_tf: 0
fraction_rna: 0.02
fraction_dna: 0.005
fraction_sm_compl: 0.11
fraction_metal_compl: 0.03
fraction_sm_compl_multi: 0.025
fraction_sm_compl_covale: 0.025
fraction_sm: 0.0
fraction_atomize_pdb: 0.0425
fraction_atomize_complex: 0.0425
fraction_sm_compl_asmb: 0.055
n_valid_pdb: 256
n_valid_homo: 0
n_valid_dslf: 0
n_valid_compl: 256
n_valid_neg_compl: 0
n_valid_na_compl: 256
n_valid_neg_na_compl: 0
n_valid_distil_tf: 0
n_valid_tf: 0
n_valid_neg_tf: 0
n_valid_rna: 256
n_valid_dna: 256
n_valid_sm_compl: 256
n_valid_metal_compl: 0
n_valid_sm_compl_multi: 0
n_valid_sm_compl_covale: 0
n_valid_sm_compl_strict: 256
n_valid_sm: 0
n_valid_atomize_pdb: 0
n_valid_atomize_complex: 0
n_valid_sm_compl_asmb: 0
p_homo_cut: 0
p_short_crop: 0
p_dslf_crop: 0
dslf_fb_upsample: 1
model:
global_params:
d_msa: 384
d_msa_full: 64
d_pair: 128
d_state: 128
embedding:
rf2aa:
params:
p_drop: 0.15
d_templ: 64
n_head_templ: 4
d_hidden_templ: 64
templ_p_drop: 0.25
symmetrize_repeats: False
repeat_length: null
symmsub_k: null
sym_method: null
main_block: null
copy_main_block_template: False
additional_dt1d: 0
recycling_type: "msa_pair_only"
use_same_chain: False
blocks:
AF3_full:
num_blocks: 4
params:
p_drop_row: 0.25
p_drop_pair: 0.25
msa_transition_drop: 0.0
outer_product_channels: 16
p_drop_outer_product: 0.0
n_pair_head: 6
n_pair_channels: 32
n_msa_head: 8
n_msa_channels: 8
norm_msa_row: False
AF3:
num_blocks: 48
params:
p_drop_row: 0.25
p_drop_pair: 0.25
msa_transition_drop: 0.0
outer_product_channels: 16
p_drop_outer_product: 0.0
n_pair_head: 6
n_pair_channels: 32
n_msa_head: 8
n_msa_channels: 32
norm_msa_row: False
refinement:
generative:
params:
num_attention_layers: 24
num_channels: 32
num_degrees: 2
num_layers: 3
n_heads: 4
div: 4
l0_in_features: 128
l0_out_features: 128
num_edge_features: 32
n_channels: 32
msa_transition_drop: 0.0
compute_gradients: True
auxiliary_predictors:
c6d:
n_feat: 128
input_feature: "pair"
mlm:
n_feat: 384
input_feature: "msa"

View File

@@ -7,7 +7,7 @@ experiment:
model:
embedding: null
blocks: null
refinment: null
refinement: null
auxiliary_predictors: {}
legacy_model: null
dataset_params:
@@ -130,6 +130,7 @@ loss_param:
w_hb: 0.0
w_pae: 0.05
w_pde: 0.05
w_trans: 0.0
lj_lin: 0.75
log_params:

View File

@@ -20,7 +20,7 @@ interpolant:
rots:
corrupt: True
train_schedule: linear
sample_schedule: exp
sample_schedule: linear
exp_rate: 10
trans:
@@ -52,5 +52,3 @@ interpolant:
num_timesteps: 100
do_sde: False
self_condition: False
training_params:
resume_from_checkpoint_path: /net/tukwila/rohith/assorted_weights/models/rf2aa-baseline-noinplace_134.pt

View File

@@ -0,0 +1,148 @@
defaults:
- rf2aa
experiment:
name: rfaa-flow-matching-allatom
trainer: "flow_matching"
interpolant:
min_t: 1e-2
separate_t: False
provide_kappa: False
hierarchical_t: False
codesign_separate_t: False
codesign_forward_fold_prop: 0.0
codesign_inverse_fold_prop: 0.0
twisting:
use: False
rots:
corrupt: True
train_schedule: linear
sample_schedule: linear
exp_rate: 10
trans:
corrupt: True
batch_ot: True
train_schedule: linear
sample_schedule: linear
sample_temp: 1.0
vpsde_bmin: 0.1
vpsde_bmax: 20.0
potential: null
potential_t_scaling: False
rog:
weight: 10.0
cutoff: 5.0
aatypes:
corrupt: False
schedule: linear
schedule_exp_rate: 10
temp: 1.0
noise: 0.0
do_purity: False
train_extra_mask: 0.0
interpolant_type: masking
num_tokens: 80
sampling:
num_timesteps: 20
do_sde: False
self_condition: False
model:
blocks:
RF2aa_full:
num_blocks: 1
params:
d_rbf: 64
p_drop_row: 0.25
p_drop_pair: 0.25
p_drop_layer: 0.0
msa_transition_drop: 0.0
outer_product_channels: 16
p_drop_outer_product: 0.0
n_pair_head: 6
n_pair_channels: 32
n_msa_head: 8
n_msa_channels: 8
structure_bias_gate_channels: 16
structure_bias_channels: 64
n_se3_layers: 1
n_se3_channels: 32
n_se3_degrees: 2
n_se3_head: 4
n_div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 6
l1_out_features: 2
n_se3_edge_features: 64
sc_pred_d_hidden: 128
sc_pred_p_drop: 0.0
residual_state: True
RF2aa:
num_blocks: 1
refinement:
generative:
params:
num_atom_encode_layers: 1
num_attention_layers: 1
num_channels: 32
num_degrees: 2
n_heads: 4
div: 4
l0_in_features: 64
l0_out_features: 64
l1_in_features: 3
l1_out_features: 1
num_edge_features: 4
n_channels: 32
msa_transition_drop: 0.0
dataset_params:
n_train: 24000
validate_every_n_epochs: 1
fraction_pdb: 0.12
fraction_fb: 0.36
fraction_compl: 0.055
fraction_neg_compl: 0
fraction_na_compl: 0.055
fraction_neg_na_compl: 0
fraction_distil_tf: 0.055
fraction_tf: 0
fraction_neg_tf: 0
fraction_rna: 0.02
fraction_dna: 0.005
fraction_sm_compl: 0.11
fraction_metal_compl: 0.03
fraction_sm_compl_multi: 0.025
fraction_sm_compl_covale: 0.025
fraction_sm: 0.0
fraction_atomize_pdb: 0.0425
fraction_atomize_complex: 0.0425
fraction_sm_compl_asmb: 0.055
loss_param:
w_trans: 10.0
t_normalize_clip: 0.9
trans_scale: 0.1
w_dist: 1.0
w_str: 0.0
w_inter_fape: 0.0
w_lig_fape: 0.0
w_lddt: 0.0
w_aa: 3.0
w_bond: 0.0
w_bind: 0.0
binder_loss_label_smoothing: 0.0
w_clash: 0.0
w_atom_bond: 0.0
w_skip_bond: 0.0
w_rigid: 0.0
w_hb: 0.0
w_pae: 0.0
w_pde: 0.0
lj_lin: 0.75

View File

@@ -30,6 +30,7 @@ model:
n_se3_edge_features: 32
residual_state: True
norm_msa_row: True
bias_in_attn_norm: True
use_flash_attn: True
RF2aa:
@@ -41,6 +42,7 @@ model:
n_se3_edge_features: 32
residual_state: True
norm_msa_row: True
bias_in_attn_norm: True
use_flash_attn: True
refinement:

View File

@@ -299,7 +299,7 @@ def MSAFeaturize(
ins,
params,
p_mask=0.15,
eps=1e-6,
eps=1e-4,
nmer=1,
L_s=[],
term_info=None,

View File

@@ -16,7 +16,6 @@ def prepare_input(inputs, xyz_converter, gpu):
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
# transfer inputs to device
B, _, N, L = msa.shape
@@ -141,6 +140,7 @@ def get_loss_calc_items(inputs,device="cpu"):
return seq.to(device), same_chain.to(device), idx_pdb.to(device), bond_feats.to(device), \
dist_matrix.to(device), atom_frames.to(device), true_crds.to(device), mask_crds.to(device)
# prepate inputs for flow matching
def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
@@ -158,7 +158,7 @@ def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
msa = msa.to(device, non_blocking=True)
atom_frames = atom_frames.to(device, non_blocking=True)
interpolant.set_device(device)
#interpolant.set_device(device)
batch = data_transforms.convert_dataloader_inputs_to_rigids(inputs, interpolant._device)
noisy_batch = interpolant.corrupt_batch(batch)
rotmats, trans = noisy_batch["rotmats_t"], noisy_batch["trans_t"]
@@ -213,6 +213,62 @@ def prepare_input_fm(inputs, interpolant, xyz_converter, device="cpu"):
symmRs = None
return task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label
def prepare_input_fm_allatom(inputs, interpolant, xyz_converter, device="cpu"):
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input(inputs, xyz_converter, device)
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
B, _, N, L = msa.shape
idx_pdb = idx_pdb.to(device, non_blocking=True) # (B, L)
true_crds = true_crds.to(device, non_blocking=True) # (B, L, 27, 3)
mask_crds = mask_crds.to(device, non_blocking=True) # (B, L, 27)
same_chain = same_chain.to(device, non_blocking=True)
t1d = t1d.to(device, non_blocking=True)
seq = seq.to(device, non_blocking=True)
msa = msa.to(device, non_blocking=True)
atom_frames = atom_frames.to(device, non_blocking=True)
interpolant.set_device(device)
allatom_mask = ChemData().allatom_mask.to(device, non_blocking=True)
# remove symmetry dimension
if len(true_crds.shape) == 5:
true_crds = true_crds[:, 0:1]
mask_crds = mask_crds[:, 0:1]
else:
true_crds = true_crds[None]
mask_crds = mask_crds[None]
# want to unroll the coordinate tensors to get the full coordinates in (atoms, 3)
is_real_atom = allatom_mask[msa[:,0,0]][None].bool()
atom_coords = true_crds[is_real_atom]
atom_mask = mask_crds[is_real_atom]
t = interpolant.sample_t(B)[:, None]
#t = torch.tensor([[0.95]], device=t.device)
trans_1 = center_allatom_chain(atom_coords, atom_mask)
#fd what to do with "real" atoms that do not have a native position?
#fd -> lets try the origin
trans_1 = trans_1.nan_to_num()
diffuse_mask = torch.ones_like(atom_mask).long()
trans_t = interpolant._corrupt_trans(trans_1[None], t, atom_mask[None], diffuse_mask[None])
xyz_noised_t = torch.zeros((1, B, L, ChemData().NTOTAL, 3), device=device)
xyz_noised_t[is_real_atom] = trans_t
# add trans_t (and t) to network input
network_input["trans_t"] = xyz_noised_t
network_input["t"] = t
return task, item, network_input, true_crds, mask_crds, msa, mask_msa, \
unclamp, negative, symmRs, Lasu, ch_label, t, trans_1, atom_mask
def construct_template_feats(xyz_t, mask_t, t1d, seq, atom_frames, xyz_converter, use_atom_frames=False):
B, T, Lasu, _, _ = xyz_t.shape
@@ -243,3 +299,16 @@ def construct_template_feats(xyz_t, mask_t, t1d, seq, atom_frames, xyz_converter
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, Lasu*Osub, 3*ChemData().NTOTALDOFS)
alpha_prev = torch.zeros((B,Lasu*Osub,ChemData().NTOTALDOFS,2))
return t2d, mask_t_2d, alpha_t, alpha_prev
def center_allatom_chain(atom_coords, atom_mask):
"""
center allatom coordinates at origin
"""
assert len(atom_coords.shape) == 2
assert len(atom_mask.shape) == 1
assert atom_coords.shape[-1] == 3
atom_coords_allatom = atom_coords[atom_mask]
atom_coords_allatom = atom_coords_allatom - atom_coords_allatom.mean(dim=0)
atom_coords[atom_mask] = atom_coords_allatom
return atom_coords

View File

@@ -28,3 +28,7 @@ def debug_grads(model):
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: {param.grad.norm().item()}")
def debug_nan_params(model):
for name, param in model.named_parameters():
print(f"{name}: {torch.sum(param.isnan())}")

View File

@@ -191,7 +191,7 @@ def calculate_neighbor_angles(R_ac, R_ab):
# sin(alpha) = |u x v| / (|u|*|v|)
y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
# avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
y = torch.max(y, torch.tensor(1e-9))
y = torch.max(y, torch.tensor(1e-4))
angle = torch.atan2(y, x)
return angle

View File

@@ -8,14 +8,9 @@ import pickle
import os
import torch
from typing import List, Dict, Any
#import tree
from rf2aa.flow_matching import rigid_utils as ru
from torch_scatter import scatter_add, scatter
#from Bio.PDB.Chain import Chain
#from Bio import PDB
#from cogen.data import protein, residue_constants, parsers
from torch_geometric.utils import scatter
from glob import glob
#from pytorch_lightning.utilities import rank_zero_only
Rigid = ru.Rigid
#Protein = protein.Protein
@@ -140,19 +135,19 @@ def adjust_oxygen_pos(
# Calpha to carbonyl both in the current frame.
calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-4
)
# For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0.
# The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change.
# For masked positions, they are all 0 and so we add 1e-4 to avoid division by 0.
# The positions are in Angstroms and so are on the order ~1 so 1e-4 is an insignificant change.
# Nitrogen of the next frame to carbonyl of the current frame.
nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-4
)
carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl # (N-1, 3)
carbonyl_to_oxygen = carbonyl_to_oxygen / (
torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-4
)
atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23
@@ -161,17 +156,17 @@ def adjust_oxygen_pos(
# Calpha to carbonyl both in the current frame. (N, 3)
calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-4
)
# Calpha to nitrogen both in the current frame. (N, 3)
calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-4
)
carbonyl_to_oxygen_term: torch.Tensor = (
calpha_to_carbonyl_term + calpha_to_nitrogen_term
) # (N, 3)
carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-4
)
# Create a mask that is 1 when the next residue is not available either
@@ -236,7 +231,7 @@ def chain_str_to_int(chain_str: str):
#chain_feats['bb_mask'] = chain_feats['atom_mask'][:, CA_IDX]
#bb_pos = chain_feats['atom_positions'][:, CA_IDX]
#if center:
#bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
#bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-4)
#centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
#scaled_pos = centered_pos / scale_factor
#else:
@@ -341,8 +336,8 @@ def align_structures(
reference_positions = center_zero(reference_positions, batch_indices)
# Compute covariance matrix for optimal rotation (Q.T @ P) -> [B x 3 x 3].
cov = scatter_add(
batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
cov = scatter(
batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0, reduce="sum"
)
# Perform singular value decomposition. (all [B x 3 x 3])

View File

@@ -105,14 +105,28 @@ class Interpolant:
aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
batch_nm_0, batch_nm_1, mask=batch_mask
)
# fd shortcut with batch 1
if num_batch == 1:
return aligned_nm_0
aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)
# Compute cost matrix of aligned noise to ground truth
batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
#fd ensure masked positions are zeroed out
cost_matrix = torch.sum(
torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
torch.linalg.norm(
batch_mask[...,None] * (
aligned_nm_0.nan_to_num() -
aligned_nm_1.nan_to_num()
), dim=-1
), dim=-1
) / torch.sum(batch_mask, dim=-1)
#fd gpu->cpu slowdown
noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]

View File

@@ -564,7 +564,7 @@ class Rotation:
else:
raise ValueError("Both rotations are None")
def get_rotvec(self, eps=1e-6) -> torch.Tensor:
def get_rotvec(self, eps=1e-4) -> torch.Tensor:
"""
Return the underlying axis-angle rotation vector.
@@ -1265,7 +1265,7 @@ class Rigid:
p_neg_x_axis: torch.Tensor,
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8
eps: float = 1e-4
):
"""
Implements algorithm 21. Constructs transformations from sets of 3

View File

@@ -4,8 +4,10 @@ from typing import Any, Dict, Tuple
from rf2aa.flow_matching.interpolant import _centered_gaussian, _uniform_so3
import rf2aa.flow_matching.data_utils as du
from rf2aa.flow_matching import data_transforms
from rf2aa.training.recycling import recycle_step_packed
from rf2aa.data.dataloader_adaptor import prepare_input_fm, construct_template_feats
from rf2aa.training.recycling import recycle_step_packed, recycle_step_gen
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.dataloader_adaptor import prepare_input_fm, construct_template_feats, prepare_input_fm_allatom
from rf2aa.training.recycling import unpack_outputs
from rf2aa.util import rigid_from_3_points, writepdb_file
@@ -19,7 +21,7 @@ class Sampler:
self.xyz_converter = xyz_converter
self.is_training = is_training
def sample(self, inputs: Tuple[str, Any]) -> Dict[str, Any]:
def sample(self, inputs: Tuple[str, Any], use_amp=False) -> Dict[str, Any]:
# first receive inputs from dataloader
# convert them into features
network_input = self._get_network_input(inputs)
@@ -37,7 +39,7 @@ class Sampler:
updated_features = self._construct_xt_features(rotmats_t_1, trans_t_1, network_input)
network_input.update(updated_features)
# run model
output_i = recycle_step_packed(self.model, network_input,1, use_amp=True, nograds=True)
output_i = recycle_step_packed(self.model, network_input,1, use_amp=use_amp, nograds=True)
xyz = output_i[5][-1]
N, Ca, C = xyz[...,0, :], xyz[...,1, :], xyz[...,2, :]
px1s.append(xyz)
@@ -48,13 +50,13 @@ class Sampler:
rotmats_t_1, trans_t_1, d_t, t_1)
# set prev_rots, prev_trans to curr
rotmats_t_1, trans_t_1 = rotmats_t_2, trans_t_2
t_1 = t_2
#
# only integrated to min_t, still need to take final step
updated_features = self._construct_xt_features(rotmats_t_1, trans_t_1, network_input)
network_input.update(updated_features)
# run model
output_i = recycle_step_packed(self.model, network_input,1, use_amp=True, nograds=True)
output_i = recycle_step_packed(self.model, network_input,1, use_amp=use_amp, nograds=True)
# return the updated positions
mask = torch.ones(xyz.shape[:-1], device=xyz.device).bool()
mask = F.pad(mask, (0,33))
@@ -113,4 +115,88 @@ class Sampler:
for i, xyz in enumerate(xyz_list):
writepdb_file(f, xyz, seq, modelnum=i, atom_mask=mask)
class AllAtomSampler(Sampler):
""" sampler for model which predicts all atom positions, not frames/torsions """
def __init__(self, model, num_timesteps, min_t, interpolant, xyz_converter, is_training) -> None:
super().__init__(model, num_timesteps, min_t, interpolant, xyz_converter, is_training)
self.allatom_mask = ChemData().allatom_mask.to(self.device)
def sample(self, inputs: Tuple[str, Any], n_cycle=1, use_amp=False) -> Dict[str, Any]:
# first receive inputs from dataloader
# convert them into features
network_input = self._get_network_input(inputs)
ts = torch.linspace(self.min_t, 1.0, self.num_timesteps)
# create prior
seq_unmasked = network_input["seq_unmasked"].to(self.device)
trans_t_1 = self._setup_prior(seq_unmasked)
network_input["trans_t"] = trans_t_1[None]
# run first model fwd pass to get evoformer features
output_i = recycle_step_gen(self.model, network_input, n_cycle, use_amp=use_amp, nograds=True)
latent_feats = {
"msa": output_i[-3],
"pair": output_i[-2],
"state": output_i[-1],
"seq_unmasked": seq_unmasked,
"dist_matrix": network_input["dist_matrix"].to(self.device),
"idx": network_input["idx"].to(self.device),
"trans_t": trans_t_1[None],
"t": network_input["t"]
}
output_i_trunk = output_i
# run the model refinement for n_steps
output_i, px1, xts = self.run_refiner(latent_feats, ts)
# HACK: get features from evoformer, this needs to become a dictionary to allow for assignment
output_i = list(output_i)
for i in range(len(output_i)):
if output_i[i] is None:
output_i[i] = output_i_trunk[i]
return tuple(output_i)
def _setup_prior(self, seq_unmasked):
B, L = seq_unmasked.shape
xyz = torch.zeros(B, L, ChemData().NTOTAL, 3, device=self.device)
is_real_atom = self.allatom_mask[seq_unmasked]
num_atoms = is_real_atom.sum()
xyz[is_real_atom] = _centered_gaussian(B, num_atoms, self.device) * du.NM_TO_ANG_SCALE
return xyz
def _get_network_input(self, inputs):
out = prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, device=self.device)
network_input = out[2]
return network_input
def _take_step(self, pred_trans_1, trans_t_1, d_t, t_1, seq_unmasked):
is_real_atom = self.allatom_mask[seq_unmasked]
trans_t_2_rolled = pred_trans_1.clone()
pred_trans_1 = pred_trans_1[is_real_atom]
trans_t_1 = trans_t_1[is_real_atom]
trans_t_2 = self.interpolant._trans_euler_step(
d_t, t_1, pred_trans_1, trans_t_1)
trans_t_2_rolled[is_real_atom] = trans_t_2
return trans_t_2_rolled
def run_refiner(self, latent_feats, ts):
px1s = []
xts = []
t_1 = ts[0]
trans_t_1 = latent_feats["trans_t"][0]
for t_2 in ts[1:]:
d_t = t_2-t_1
# collect features for each step
pred_trans_1 = self._run_diffusion_step(latent_feats)
xts.append(trans_t_1)
trans_t_2 = self._take_step(pred_trans_1, trans_t_1, d_t, t_1, seq_unmasked=latent_feats["seq_unmasked"])
trans_t_1 = trans_t_2
latent_feats["trans_t"] = trans_t_1[None]
t_1 = t_2
outputs = {}
outputs["xyz"] = pred_trans_1
outputs["state"] = latent_feats["state"]
output_i = unpack_outputs(outputs, latent_feats, return_raw=False)
return output_i, px1s, xts
def _run_diffusion_step(self, latent_feats):
outputs = self.model.module.model.refinement(latent_feats)
pred_trans_1 = outputs["xyz"]
return pred_trans_1

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
def scale_rotmat(
rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-7
rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-4
) -> torch.Tensor:
"""
Scale rotation matrix. This is done by converting it to vector representation,
@@ -83,7 +83,7 @@ def skew_matrix_exponential_map_axis_angle(
def skew_matrix_exponential_map(
angles: torch.Tensor, skew_matrices: torch.Tensor, tol=1e-7
angles: torch.Tensor, skew_matrices: torch.Tensor, tol=1e-4
) -> torch.Tensor:
"""
Compute the matrix exponential of a rotation vector in skew matrix representation. Maps the
@@ -137,7 +137,7 @@ def skew_matrix_exponential_map(
return exp_skew
def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
"""
Convert rotation vectors to rotation matrix representation. The length of the rotation vector
is the angle of rotation, the unit vector the rotation axis.
@@ -232,7 +232,7 @@ def rotmat_to_rotvec(rotation_matrices: torch.Tensor) -> torch.Tensor:
skew_outer = skew_outer + (torch.relu(skew_outer) - skew_outer) * id3
# Get basic rotation vector as sqrt of diagonal (is unit vector).
vector_pi = torch.sqrt(torch.diagonal(torch.clamp(skew_outer, min=1e-8), dim1=-2, dim2=-1))
vector_pi = torch.sqrt(torch.diagonal(torch.clamp(skew_outer, min=1e-4), dim1=-2, dim2=-1))
# Compute the signs of vector elements (up to a global phase).
# Fist select indices for outer product slices with the largest norm.
@@ -326,7 +326,7 @@ def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor:
def _rotquat_to_axis_angle(
rotation_quaternions: torch.Tensor, tol: float = 1e-7
rotation_quaternions: torch.Tensor, tol: float = 1e-4
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Auxiliary routine for computing rotation angle and rotation axis from unit quaternions. To avoid
@@ -334,7 +334,7 @@ def _rotquat_to_axis_angle(
Args:
rotation_quaternions (torch.Tensor): Rotation quaternions in [r, i, j, k] format.
tol (float, optional): Threshold for small rotations. Defaults to 1e-7.
tol (float, optional): Threshold for small rotations. Defaults to 1e-4.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rotation angles and axes.
@@ -385,7 +385,7 @@ def rotquat_to_rotmat(rotation_quaternions: torch.Tensor) -> torch.Tensor:
def apply_rotvec_to_rotmat(
rotation_matrices: torch.Tensor,
rotation_vectors: torch.Tensor,
tol: float = 1e-7,
tol: float = 1e-4,
) -> torch.Tensor:
"""
Update a rotation encoded in a rotation matrix with a rotation vector.
@@ -602,7 +602,7 @@ class BaseSampleSO3(nn.Module):
num_omega: int,
sigma_grid: torch.Tensor,
omega_exponent: int = 3,
tol: float = 1e-7,
tol: float = 1e-4,
interpolate: bool = True,
cache_dir: Optional[str] = None,
overwrite_cache: bool = False,
@@ -625,7 +625,7 @@ class BaseSampleSO3(nn.Module):
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
its power with the provided number. Defaults to 3.
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
cache_dir: Path to an optional cache directory. If set to None, lookup tables are
@@ -907,7 +907,7 @@ class SampleIGSO3(BaseSampleSO3):
num_omega: int,
sigma_grid: torch.Tensor,
omega_exponent: int = 3,
tol: float = 1e-7,
tol: float = 1e-4,
interpolate: bool = True,
l_max: int = 1000,
cache_dir: Optional[str] = None,
@@ -931,7 +931,7 @@ class SampleIGSO3(BaseSampleSO3):
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
its power with the provided number. Defaults to 3.
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
l_max (int, optional): Maximum number of terms used in the series expansion.
@@ -1022,7 +1022,7 @@ class SampleUSO3(BaseSampleSO3):
num_omega: int,
sigma_grid: torch.Tensor,
omega_exponent: int = 3,
tol: float = 1e-7,
tol: float = 1e-4,
interpolate: bool = True,
cache_dir: Optional[str] = None,
overwrite_cache: bool = False,
@@ -1044,7 +1044,7 @@ class SampleUSO3(BaseSampleSO3):
sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
its power with the provided number. Defaults to 3.
tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
tol (float, optional): Small value for numerical stability. Defaults to 1e-4.
interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
cache_dir: Path to an optional cache directory. If set to None, lookup tables are
@@ -1122,7 +1122,7 @@ class ScoreSO3(nn.Module):
sigma_grid: torch.Tensor,
omega_exponent: int = 3,
l_max: int = 1000,
tol: float = 1e-7,
tol: float = 1e-4,
cache_dir: Optional[str] = None,
overwrite_cache: bool = False,
) -> None:
@@ -1359,7 +1359,7 @@ def uniform_so3_density(omega: torch.Tensor) -> torch.Tensor:
def igso3_expansion(
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
) -> torch.Tensor:
"""
Compute the IGSO(3) angle probability distribution function for pairs of angles and std dev
@@ -1418,7 +1418,7 @@ def igso3_expansion(
def digso3_expansion(
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
) -> torch.Tensor:
"""
Compute the derivative of the IGSO(3) angle probability distribution function with respect to
@@ -1477,7 +1477,7 @@ def digso3_expansion(
def dlog_igso3_expansion(
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-4
) -> torch.Tensor:
"""
Compute the derivative of the logarithm of the IGSO(3) angle distribution function for pairs of
@@ -1509,7 +1509,7 @@ def generate_lookup_table(
omega_grid: torch.Tensor,
sigma_grid: torch.Tensor,
l_max: int = 1000,
tol: float = 1e-7,
tol: float = 1e-4,
):
"""
Auxiliary function for generating a lookup table from IGSO(3) expansions and their derivatives.
@@ -1550,7 +1550,7 @@ def generate_igso3_lookup_table(
omega_grid: torch.Tensor,
sigma_grid: torch.Tensor,
l_max: int = 1000,
tol: float = 1e-7,
tol: float = 1e-4,
) -> torch.Tensor:
"""
Generate a lookup table for the IGSO(3) probability distribution function of angles.
@@ -1579,7 +1579,7 @@ def generate_dlog_igso3_lookup_table(
omega_grid: torch.Tensor,
sigma_grid: torch.Tensor,
l_max: int = 1000,
tol: float = 1e-7,
tol: float = 1e-4,
) -> torch.Tensor:
"""
Generate a lookup table for the derivative of the logarithm of the angular IGSO(3) probability
@@ -1650,7 +1650,7 @@ def Exp(A): return exp(hat(A))
# Angle of rotation SO(3) to R^+, this is the norm in our chosen orthonormal basis
def Omega(R, eps=1e-6):
def Omega(R, eps=1e-4):
# multiplying by (1-epsilon) prevents instability of arccos when provided with -1 or 1 as input.
R_ = R.to(torch.float64)
assert not torch.any(torch.abs(R) > 1.1)

View File

@@ -17,6 +17,8 @@ from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.kinematics import get_dih, get_ang
from rf2aa.scoring import HbHybType
from rf2aa.flow_matching import data_utils as du
from typing import List, Dict, Optional
import logging
logger = logging.getLogger(__name__)
@@ -1655,28 +1657,70 @@ def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain,
lddt = torch.cat(lddt_s, dim=1) # (N, L, Natm)
# per-res
final_lddt_by_res = torch.clamp(
(lddt[-1]*atm_mask[0]).sum(-1)
/ (atm_mask.sum(-1) + eps), min=0.0, max=1.0)
# calculate lddt prediction loss
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
pred_lddt, true_lddt_label[-1])
res_mask = atm_mask.any(dim=-1)
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
# method 1: average per-residue
#lddt = lddt.sum(dim=-1) / (atm_mask.sum(dim=-1)+1e-4) # L
#lddt = (res_mask*lddt).sum() / (res_mask.sum() + 1e-4)
# method 2: average per-atom
# per-struct
atm_mask = atm_mask * (pair_mask_accum != 0)
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
# calculate lddt prediction loss
if pred_lddt is not None:
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
pred_lddt, true_lddt_label[-1])
res_mask = atm_mask.any(dim=-1)
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
else:
lddt_loss = None # no pred lddt provided
return lddt_loss, lddt
def rms_aln_tgt(predin, true, mask):
def centroid(X):
return X.mean(dim=-2, keepdim=True)
pred = predin[mask]
true = true[mask]
pred = pred - centroid(pred)
cT = centroid(true)
true = true - cT
C = torch.einsum('ji,jk->ik', pred, true)
V, S, W = torch.svd(C)
d = torch.ones([3,3], device=pred.device)
d[:,-1] = torch.sign(torch.det(V)*torch.det(W)).unsqueeze(0)
U = torch.matmul(d*V, W.permute(1,0)) # (3, 3)
rpred = torch.matmul(predin, U) + cT
return rpred
def translation_vector_field(pred_trans_1, noaln_gt_trans_1, mask, r3_t, params):
gt_trans_1 = rms_aln_tgt(noaln_gt_trans_1, pred_trans_1.detach(), mask)
return translation_vector_field_noaln(pred_trans_1, gt_trans_1, mask, r3_t, params)
def translation_vector_field_noaln(pred_trans_1, gt_trans_1, mask, r3_t, params):
t_normalize_clip = params.t_normalize_clip
trans_scale = params.trans_scale # global scale
t_dep_scale = 1 - torch.min( # t-dependant scale
r3_t[..., None], torch.tensor(t_normalize_clip)
) # (B, 1, 1)
trans_error = trans_scale / t_dep_scale * (gt_trans_1 - pred_trans_1)
loss_denom = 3 * torch.sum(mask)
trans_loss = torch.sum(
trans_error*trans_error*mask[...,None], dim=(-1,-2)
) / loss_denom
return trans_loss

View File

@@ -6,7 +6,8 @@ from rf2aa.kinematics import xyz_to_c6d, c6d_to_bins
from rf2aa.loss.loss import resolve_equiv_natives, resolve_equiv_natives_asmb, \
resolve_symmetry_predictions, resolve_symmetry, mask_unresolved_frames, \
compute_general_FAPE, torsionAngleLoss, calc_lddt, calc_allatom_lddt_loss, \
calc_crd_rmsd, calc_BB_bond_geom, calc_lj, calc_atom_bond_loss
calc_crd_rmsd, calc_BB_bond_geom, calc_lj, calc_atom_bond_loss,translation_vector_field
from rf2aa.util import is_atom, is_protein, Ls_from_same_chain_2d, get_prot_sm_mask, \
xyz_to_frame_xyz, get_frames, writepdb
from rf2aa.chemical import ChemicalData as ChemData
@@ -17,20 +18,16 @@ cce_loss = nn.CrossEntropyLoss(reduction='none')
def get_loss_and_misc(
trainer,
output_i, true_crds, atom_mask, same_chain,
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
seq, msa, mask_msa, idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
unclamp, negative, task, item, symmRs, Lasu, ch_label,
loss_param
):
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
if pred_allatom is None:
_, pred_allatom = trainer.xyz_converter.compute_all_atom(msa[0][0][None],pred_crds[-1], alphas[-1])
#pred_crds = pred_crds[:, None]
#alphas = alphas[:, None]
if (symmRs is not None):
###
# resolve symmetry
###
if symmRs is not None:
true_crds = true_crds[:,0]
atom_mask = atom_mask[:,0]
mapT2P = resolve_symmetry_predictions(pred_crds, true_crds, atom_mask, Lasu) # (Nlayer, Ltrue)
@@ -88,230 +85,233 @@ def get_loss_and_misc(
c6d = xyz_to_c6d(true_crds_frame)
c6d = c6d_to_bins(c6d, same_chain, negative=negative)
# contact accuray not as useful to track anymore
#prob = self.active_fn(logit_s[0]) # distogram
#acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)
loss, loss_dict = calc_loss(
trainer, logit_s, c6d,
trainer, loss_param, logit_s, c6d,
logit_aa_s, msa, mask_msa, logit_pae, logit_pde, p_bind,
pred_crds, alphas, pred_allatom, true_crds,
pred_crds, alphas, pred_allatom, trans_1, r3_t, true_crds,
atom_mask, res_mask, mask_2d, same_chain,
pred_lddts, idx_pdb, bond_feats, dist_matrix,
atom_frames=atom_frames,unclamp=unclamp, negative=negative,
item=item, task=task, **loss_param
item=item, task=task
)
return loss, loss_dict
def calc_loss(trainer, logit_s, label_s,
logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
pred, pred_tors, pred_allatom, true,
mask_crds, mask_BB, mask_2d, same_chain,
pred_lddt, idx, bond_feats, dist_matrix, atom_frames=None, unclamp=False,
negative=False, interface=False,
w_dist=1.0, w_aa=1.0, w_str=1.0, w_inter_fape=0.0, w_lig_fape=1.0, w_lddt=1.0,
w_bond=1.0, w_clash=0.0, w_atom_bond=0.0, w_skip_bond=0.0, w_rigid=0.0, w_hb=0.0, w_bind=0.0,
w_pae=0.0, w_pde=0.0, lj_lin=0.85, eps=1e-4, binder_loss_label_smoothing = 0.0, item=None, task=None, out_dir='./'
):
gpu = pred.device
def calc_loss(
trainer, loss_param, logit_s, label_s,
logit_aa_s, label_aa_s, mask_aa_s, logit_pae, logit_pde, p_bind,
pred, pred_tors, pred_allatom, trans_1, r3_t, true,
mask_crds, mask_BB, mask_2d, same_chain,
pred_lddt, idx, bond_feats, dist_matrix,
atom_frames=None, unclamp=False, negative=False, interface=False, item=None, task=None,
eps=1e-4, binder_loss_label_smoothing = 0.0, out_dir='./'
):
# fd: force these to be specified in config
w_dist, w_aa, w_str = loss_param["w_dist"], loss_param["w_aa"], loss_param["w_str"]
w_trans, w_inter_fape, w_lig_fape = loss_param["w_trans"], loss_param["w_inter_fape"], loss_param["w_lig_fape"]
w_lddt, w_bond, w_clash = loss_param["w_lddt"], loss_param["w_bond"], loss_param["w_clash"]
w_atom_bond, w_skip_bond, w_rigid = loss_param["w_atom_bond"], loss_param["w_skip_bond"], loss_param["w_rigid"]
w_hb, w_bind, w_pae = loss_param["w_hb"], loss_param["w_bind"], loss_param["w_pae"]
w_pde, lj_lin, w_pae = loss_param["w_pde"], loss_param["lj_lin"], loss_param["w_pae"]
# track losses for printing to local log and uploading to WandB
loss_dict = OrderedDict()
# track losses for printing to local log and uploading to WandB
loss_dict = OrderedDict()
B, L, natoms = true.shape[:3]
seq = label_aa_s[:,0].clone()
B, L, natoms = true.shape[:3]
seq = label_aa_s[:,0].clone()
assert (B==1) # fd - code assumes a batch size of 1
assert (B==1) # fd - code assumes a batch size of 1
tot_loss = 0.0
# set up frames
frames, frame_mask = get_frames(
pred_allatom[-1,None,...], mask_crds, seq, trainer.fi_dev, atom_frames)
tot_loss = 0.0
# set up frames
frames, frame_mask = get_frames(
pred_allatom[-1,None,...], mask_crds, seq, trainer.fi_dev, atom_frames)
# update frames and frames_mask to only include BB frames (have to update both for compatibility with compute_general_FAPE)
frames_BB = frames.clone()
frames_BB[..., 1:, :, :] = 0
frame_mask_BB = frame_mask.clone()
frame_mask_BB[...,1:] =False
# update frames and frames_mask to only include BB frames (have to update both for compatibility with compute_general_FAPE)
frames_BB = frames.clone()
frames_BB[..., 1:, :, :] = 0
frame_mask_BB = frame_mask.clone()
frame_mask_BB[...,1:] =False
# c6d loss
for i in range(4):
loss = cce_loss(logit_s[i], label_s[...,i]) # (B, L, L)
if i==0: # apply distogram loss to all residue pairs with valid BB atoms
mask_2d_ = mask_2d
else:
# apply anglegram loss only when both residues have valid BB frames (i.e. not metal ions, and not examples with unresolved atoms in frames)
_, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds) # (1, L, nframes)
bb_frame_good = bb_frame_good[...,0] # (1,L)
loss_mask_2d = bb_frame_good & bb_frame_good[...,None]
mask_2d_ = mask_2d & loss_mask_2d
# c6d loss
for i in range(4):
loss = cce_loss(logit_s[i], label_s[...,i]) # (B, L, L)
if i==0: # apply distogram loss to all residue pairs with valid BB atoms
mask_2d_ = mask_2d
else:
# apply anglegram loss only when both residues have valid BB frames (i.e. not metal ions, and not examples with unresolved atoms in frames)
_, bb_frame_good = mask_unresolved_frames(frames_BB, frame_mask_BB, mask_crds) # (1, L, nframes)
bb_frame_good = bb_frame_good[...,0] # (1,L)
loss_mask_2d = bb_frame_good & bb_frame_good[...,None]
mask_2d_ = mask_2d & loss_mask_2d
if negative.item():
# Don't compute inter-chain distogram losses
# for negative examples.
mask_2d_ = mask_2d_ * same_chain
if negative.item():
# Don't compute inter-chain distogram losses
# for negative examples.
mask_2d_ = mask_2d_ * same_chain
#fd upcast loss to float to avoid overflow
loss = (mask_2d_*loss.float()).sum() / (mask_2d_.sum() + eps)
tot_loss += w_dist*loss
loss_dict[f'c6d_{i}'] = loss.detach()
#fd upcast loss to float to avoid overflow
loss = (mask_2d_*loss.float()).sum() / (mask_2d_.sum() + eps)
tot_loss += w_dist*loss
loss_dict[f'c6d_{i}'] = loss.detach()
# masked token prediction loss
loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
loss = loss * mask_aa_s.reshape(B, -1)
loss = loss.float().sum() / (mask_aa_s.sum() + 1e-4)
tot_loss += w_aa*loss
loss_dict['aa_cce'] = loss.detach()
# masked token prediction loss
loss = cce_loss(logit_aa_s, label_aa_s.reshape(B, -1))
loss = loss * mask_aa_s.reshape(B, -1)
loss = loss.float().sum() / (mask_aa_s.sum() + 1e-4)
tot_loss += w_aa*loss
loss_dict['aa_cce'] = loss.detach()
# col 4: binder loss
# only apply binding loss to complexes
# note that this will apply loss to positive sets w/o a corresponding negative set
# (e.g., homomers). Maybe want to change this?
if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer =="legacy":
if (torch.sum(same_chain==0) > 0):
bce = torch.nn.BCELoss()
target = torch.tensor(
[abs(float(not negative) - binder_loss_label_smoothing)],
device=p_bind.device
)
loss = bce(p_bind,target)
else:
# avoid unused parameter error
loss = 0.0 * p_bind.sum()
tot_loss += w_bind * loss
loss_dict['binder_bce_loss'] = loss.detach()
### GENERAL LAYERS
# Structural loss (layer-wise backbone FAPE)
dclamp = 300.0 if unclamp else 30.0 # protein & NA FAPE distance cutoffs
dclamp_sm, Z_sm = 4, 4 # sm mol FAPE distance cutoffs
dclamp_prot = 10
# residue mask for FAPE calculation only masks unresolved protein backbone atoms
# whereas other losses also maks unresolved ligand atoms (mask_BB)
# frames with unresolved ligand atoms are masked in compute_general_FAPE
res_mask = ~((mask_crds[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))
# create 2d masks for intrachain and interchain fape calculations
nframes = frame_mask.shape[-1]
frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool() # B, L, nframes, L, natoms
frame_atom_mask_2d = frame_atom_mask_2d_allatom[:, :, :, :, :3]
frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :,None, :, None].bool().expand(-1,-1,nframes,-1, ChemData().NTOTAL)
frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]
different_chain = ~same_chain.bool()
frame_atom_mask_2d_inter = frame_atom_mask_2d*different_chain[:, :,None, :, None].expand(-1,-1,nframes,-1, 3)
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
tot_str = 0.0 * pred.sum(axis=(1,2,3,4))
pae_loss = 0.0 * logit_pae.sum()
pde_loss = 0.0 * logit_pde.sum()
elif negative: # inter-chain fapes should be ignored for negative cases
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
dclamp=dclamp,
logit_pae=logit_pae,
logit_pde=logit_pde,
# col 4: binder loss
# only apply binding loss to complexes
# note that this will apply loss to positive sets w/o a corresponding negative set
# (e.g., homomers). Maybe want to change this?
if "binder" in trainer.config.model.auxiliary_predictors or trainer.config.experiment.trainer =="legacy":
if (torch.sum(same_chain==0) > 0):
bce = torch.nn.BCELoss()
target = torch.tensor(
[abs(float(not negative) - binder_loss_label_smoothing)],
device=p_bind.device
)
loss = bce(p_bind,target)
else:
# avoid unused parameter error
loss = 0.0 * p_bind.sum()
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
tot_loss += w_bind * loss
loss_dict['binder_bce_loss'] = loss.detach()
# change clamp for intra protein to 10, leave rest at 30
dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
if not unclamp:
is_prot = is_protein(seq) # (1,L)
same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1,1,nframes,1, natoms)
# zero out rows and columns with small molecules
same_chain_clamp_mask[:, ~is_prot[0]] = 0
same_chain_clamp_mask[:,:, :, ~is_prot[0]] = 0
dclamp_2d *= ~same_chain_clamp_mask.bool()
dclamp_2d += same_chain_clamp_mask*dclamp_prot
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
dclamp=None,
dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0],:3],
logit_pae=logit_pae,
logit_pde=logit_pde,
)
### GENERAL LAYERS
# Structural loss (layer-wise backbone FAPE)
dclamp = 300.0 if unclamp else 30.0 # protein & NA FAPE distance cutoffs
dclamp_sm, Z_sm = 4, 4 # sm mol FAPE distance cutoffs
dclamp_prot = 10
# residue mask for FAPE calculation only masks unresolved protein backbone atoms
# whereas other losses also maks unresolved ligand atoms (mask_BB)
# frames with unresolved ligand atoms are masked in compute_general_FAPE
res_mask = ~((mask_crds[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(seq)))
# free up big intermediate data tensors
del dclamp_2d
if not unclamp:
del same_chain_clamp_mask
# create 2d masks for intrachain and interchain fape calculations
nframes = frame_mask.shape[-1]
frame_atom_mask_2d_allatom = torch.einsum('bfn,bra->bfnra', frame_mask_BB, mask_crds).bool() # B, L, nframes, L, natoms
frame_atom_mask_2d = frame_atom_mask_2d_allatom[:, :, :, :, :3]
frame_atom_mask_2d_intra_allatom = frame_atom_mask_2d_allatom * same_chain[:, :,None, :, None].bool().expand(-1,-1,nframes,-1, ChemData().NTOTAL)
frame_atom_mask_2d_intra = frame_atom_mask_2d_intra_allatom[:, :, :, :, :3]
different_chain = ~same_chain.bool()
frame_atom_mask_2d_inter = frame_atom_mask_2d*different_chain[:, :,None, :, None].expand(-1,-1,nframes,-1, 3)
num_layers = pred.shape[0]
gamma = 1.0 # equal weighting of fape across all layers
w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
w_bb_fape = torch.flip(w_bb_fape, (0,))
w_bb_fape = w_bb_fape / w_bb_fape.sum()
bb_l_fape = (w_bb_fape*tot_str).sum()
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
tot_str = 0.0 * pred.sum(axis=(1,2,3,4))
pae_loss = 0.0 * logit_pae.sum()
pde_loss = 0.0 * logit_pde.sum()
elif negative: # inter-chain fapes should be ignored for negative cases
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra[:, res_mask[0]][:, :, :, res_mask[0]],
dclamp=dclamp,
logit_pae=logit_pae,
logit_pde=logit_pde,
)
tot_loss += 0.5*w_str*bb_l_fape
for i in range(len(tot_str)):
loss_dict[f'bb_fape_layer{i}'] = tot_str[i].detach()
loss_dict['bb_fape_full'] = bb_l_fape.detach()
else:
tot_loss += w_pae*pae_loss + w_pde*pde_loss
loss_dict['pae_loss'] = pae_loss.detach()
loss_dict['pde_loss'] = pde_loss.detach()
if logit_pae is not None:
logit_pae = logit_pae[:,:,res_mask[0]][:,:,:,res_mask[0]]
if logit_pde is not None:
logit_pde = logit_pde[:,:,res_mask[0]][:,:,:,res_mask[0]]
## small-molecule ligands
sm_res_mask = is_atom(label_aa_s[0,0])*res_mask[0] # (L,)
# change clamp for intra protein to 10, leave rest at 30
dclamp_2d = torch.full_like(frame_atom_mask_2d_allatom, dclamp, dtype=torch.float32)
if not unclamp:
is_prot = is_protein(seq) # (1,L)
same_chain_clamp_mask = same_chain[:, :, None, :, None].bool().repeat(1,1,nframes,1, natoms)
# zero out rows and columns with small molecules
same_chain_clamp_mask[:, ~is_prot[0]] = 0
same_chain_clamp_mask[:,:, :, ~is_prot[0]] = 0
dclamp_2d *= ~same_chain_clamp_mask.bool()
dclamp_2d += same_chain_clamp_mask*dclamp_prot
## AllAtom loss
# get ground-truth torsion angles
true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
true, seq, mask_in=mask_crds)
tors_mask *= mask_BB[...,None]
tot_str, pae_loss, pde_loss = compute_general_FAPE(
pred[:,res_mask,:,:3],
true[:,res_mask[0],:3],
mask_crds[:,res_mask[0],:3],
frames_BB[:,res_mask[0]],
frame_mask_BB[:,res_mask[0]],
dclamp=None,
dclamp_2d=dclamp_2d[:, res_mask[0]][:, :, :, res_mask[0],:3],
logit_pae=logit_pae,
logit_pde=logit_pde,
)
# get alternative coordinates for ground-truth
true_alt = torch.zeros_like(true)
true_alt.scatter_(2, trainer.l2a[seq,:,None].repeat(1,1,1,3), true)
natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[...,:3,:], true_tors)
natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[...,:3,:], true_tors_alt)
predTs = pred[-1,...]
predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])
# free up big intermediate data tensors
del dclamp_2d
if not unclamp:
del same_chain_clamp_mask
# - resolve symmetry
xs_mask = trainer.aamask[seq] # (B, L, 27)
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
xs_mask *= mask_crds # mask missing atoms & residues as well
natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
num_layers = pred.shape[0]
gamma = 1.0 # equal weighting of fape across all layers
w_bb_fape = torch.pow(torch.full((num_layers,), gamma, device=pred.device), torch.arange(num_layers, device=pred.device))
w_bb_fape = torch.flip(w_bb_fape, (0,))
w_bb_fape = w_bb_fape / w_bb_fape.sum()
bb_l_fape = (w_bb_fape*tot_str).sum()
# torsion angle loss
l_tors = torsionAngleLoss(
pred_tors,
true_tors,
true_tors_alt,
tors_mask,
tors_planar,
eps = 1e-4)
tot_loss += w_str*l_tors
loss_dict['torsion'] = l_tors.detach()
tot_loss += 0.5*w_str*bb_l_fape
for i in range(len(tot_str)):
loss_dict[f'bb_fape_layer{i}'] = tot_str[i].detach()
loss_dict['bb_fape_full'] = bb_l_fape.detach()
### FINETUNING LAYERS
# lddts (CA)
tot_loss += w_pae*pae_loss + w_pde*pde_loss
loss_dict['pae_loss'] = pae_loss.detach()
loss_dict['pde_loss'] = pde_loss.detach()
## small-molecule ligands
sm_res_mask = is_atom(label_aa_s[0,0])*res_mask[0] # (L,)
## AllAtom loss
# get ground-truth torsion angles
true_tors, true_tors_alt, tors_mask, tors_planar = trainer.xyz_converter.get_torsions(
true, seq, mask_in=mask_crds)
tors_mask *= mask_BB[...,None]
# get alternative coordinates for ground-truth
true_alt = torch.zeros_like(true)
true_alt.scatter_(2, trainer.l2a[seq,:,None].repeat(1,1,1,3), true)
natRs_all, _n0 = trainer.xyz_converter.compute_all_atom(seq, true[...,:3,:], true_tors)
natRs_all_alt, _n1 = trainer.xyz_converter.compute_all_atom(seq, true_alt[...,:3,:], true_tors_alt)
predTs = pred[-1,...]
predRs_all, pred_all = trainer.xyz_converter.compute_all_atom(seq, predTs, pred_tors[-1])
# - resolve symmetry
xs_mask = trainer.aamask[seq] # (B, L, 27)
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
xs_mask *= mask_crds # mask missing atoms & residues as well
natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
# torsion angle loss
l_tors = torsionAngleLoss(
pred_tors,
true_tors,
true_tors_alt,
tors_mask,
tors_planar,
eps = 1e-4)
tot_loss += w_str*l_tors
loss_dict['torsion'] = l_tors.detach()
### FINETUNING LAYERS
# lddts (CA)
if pred_lddt is not None:
ca_lddt = calc_lddt(pred[:,:,:,1].detach(), true[:,:,1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface)
loss_dict['ca_lddt'] = ca_lddt[-1].detach()
@@ -323,174 +323,162 @@ def calc_loss(trainer, logit_s, label_s,
loss_dict['lddt_loss'] = lddt_loss.detach()
loss_dict['allatom_lddt'] = allatom_lddt[0].detach()
# FAPE losses
# allatom fape and torsion angle loss
# frames, frame_mask = get_frames(
# pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
l_fape = torch.zeros((pred.shape[0])).to(gpu)
# FAPE losses
# allatom fape and torsion angle loss
# frames, frame_mask = get_frames(
# pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
if task[0] in ['tf','neg_tf'] or res_mask.sum() == 0:
l_fape = torch.zeros((pred.shape[0])).to(gpu)
elif negative.item(): # inter-chain fapes should be ignored for negative cases
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
)
else:
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]]
)
tot_loss += w_str*l_fape[0]
loss_dict['allatom_fape'] = l_fape[0].detach()
# rmsd loss (for logging only)
if torch.any(mask_BB[0]):
rmsd = calc_crd_rmsd(
pred_allatom[:,mask_BB[0],:,:3],
nat_symm[None,mask_BB[0],:,:3],
xs_mask[:,mask_BB[0]]
)
loss_dict["rmsd"] = rmsd[0].detach()
else:
loss_dict["rmsd"] = torch.tensor(0, device=gpu)
# create protein and not protein masks; not protein could include nucleic acids
prot_mask_BB = is_protein(label_aa_s[0,0]) #*mask_BB[0] # (L,)
not_prot_mask_BB = ~prot_mask_BB.bool()
xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
xs_mask_prot[:,~prot_mask_BB] = False
xs_mask_lig[:,~not_prot_mask_BB] = False
if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_prot = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_prot[:,mask_BB[0]]
)
else:
rmsd_prot_prot = torch.tensor([0], device=pred.device)
if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_lig_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_lig[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]]
)
else:
rmsd_lig_lig = torch.tensor([0], device=pred.device)
if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]],
alignment_radius=10.0
)
# fd rms of target ligand only
#fd get target ligand mask
#fd this is more difficult than expected with only the data we have
#fd a) target ligand is 1st one
#fd b) examples are all protein followed by ligand
sm_mask = not_prot_mask_BB
Ls_prot = Ls_from_same_chain_2d(same_chain[:,~sm_mask][:,:,~sm_mask])
Ls_sm = Ls_from_same_chain_2d(same_chain[:,sm_mask][:,:,sm_mask])
xs_mask_tgt = xs_mask.clone()
xs_mask_tgt[:,:sum(Ls_prot)] = False
xs_mask_tgt[:,(sum(Ls_prot)+Ls_sm[0]):]= False
rmsd_prot_tgt = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_tgt[:,mask_BB[0]],
alignment_radius=10.0
)
else:
rmsd_prot_lig = torch.tensor([0], device=pred.device)
rmsd_prot_tgt = torch.tensor([0], device=pred.device)
loss_dict["rmsd_prot_prot"]= rmsd_prot_prot[0].detach()
loss_dict["rmsd_lig_lig"]= rmsd_lig_lig[0].detach()
loss_dict["rmsd_prot_lig"]= rmsd_prot_lig[0].detach()
loss_dict["rmsd_prot_tgt"]= rmsd_prot_tgt[0].detach()
# cart bonded (bond geometry)
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
if w_bond > 0.0:
tot_loss += w_bond*bond_loss
loss_dict['bond_geom'] = bond_loss.detach()
# clash [use all atoms not just those in native]
clash_loss = calc_lj(
seq[0], pred_allatom,
trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters, trainer.lj_correction_parameters, trainer.num_bonds,
lj_lin=lj_lin
elif negative.item(): # inter-chain fapes should be ignored for negative cases
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]],
frame_atom_mask_2d=frame_atom_mask_2d_intra_allatom[:, res_mask[0]][:, :, :, res_mask[0]]
)
if w_clash > 0.0:
tot_loss += w_clash*clash_loss.mean()
loss_dict['clash_loss'] = clash_loss[0].detach()
if torch.any(mask_BB[0]):
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
pred=pred_allatom[:,mask_BB[0]],
true=nat_symm[None,mask_BB[0]],
bond_feats=bond_feats[:,mask_BB[0]][:,:,mask_BB[0]],
seq=seq[:,mask_BB[0]]
)
else:
atom_bond_loss = torch.tensor(0, device=gpu)
skip_bond_loss = torch.tensor(0, device=gpu)
rigid_loss = torch.tensor(0, device=gpu)
if w_atom_bond >= 0.0:
tot_loss += w_atom_bond*atom_bond_loss
loss_dict['atom_bond_loss'] = ( atom_bond_loss.detach() )
else:
l_fape, _, _ = compute_general_FAPE(
pred_allatom[:,res_mask[0],:,:3],
nat_symm[None,res_mask[0],:,:3],
xs_mask[:,res_mask[0]],
frames[:,res_mask[0]],
frame_mask[:,res_mask[0]]
)
if w_skip_bond >= 0.0:
tot_loss += w_skip_bond*skip_bond_loss
loss_dict['skip_bond_loss'] = ( skip_bond_loss.detach() )
tot_loss += w_str*l_fape[0]
loss_dict['allatom_fape'] = l_fape[0].detach()
if w_rigid >= 0.0:
tot_loss += w_rigid*rigid_loss
loss_dict['rigid_loss'] = ( rigid_loss.detach() )
chain_prot = same_chain.clone()
protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)
# translation loss (for generative refinement only)
if trans_1 is not None:
allatom_mask = ChemData().allatom_mask.to(seq.device, non_blocking=True)
is_real_atom = allatom_mask[seq].bool()
pred_trans_1 = pred_allatom[is_real_atom]
mask = mask_crds[is_real_atom]
trans = translation_vector_field (pred_trans_1, trans_1, mask, r3_t, loss_param)
tot_loss += w_trans*trans[0]
loss_dict["trans_loss"] = trans[0].detach()
loss_dict["t"] = r3_t[0,0].detach()
_, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, negative=True, N_stripe=10)
loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()
# rmsd loss (for logging only)
if torch.any(mask_BB[0]):
rmsd = calc_crd_rmsd(
pred_allatom[:,mask_BB[0],:,:3],
nat_symm[None,mask_BB[0],:,:3],
xs_mask[:,mask_BB[0]]
)
loss_dict["rmsd"] = rmsd[0].detach()
else:
loss_dict["rmsd"] = torch.tensor(0, device=gpu)
_, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()
chain_lig = same_chain.clone()
not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
_, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()
_, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()
# create protein and not protein masks; not protein could include nucleic acids
prot_mask_BB = is_protein(label_aa_s[0,0]) #*mask_BB[0] # (L,)
not_prot_mask_BB = ~prot_mask_BB.bool()
xs_mask_prot, xs_mask_lig = xs_mask.clone(), xs_mask.clone()
xs_mask_prot[:,~prot_mask_BB] = False
xs_mask_lig[:,~not_prot_mask_BB] = False
if torch.any(prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_prot = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_prot[:,mask_BB[0]]
)
else:
rmsd_prot_prot = torch.tensor([0], device=pred.device)
if torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_lig_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_lig[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]]
)
else:
rmsd_lig_lig = torch.tensor([0], device=pred.device)
if torch.any(prot_mask_BB) and torch.any(not_prot_mask_BB) and torch.any(mask_BB[0]):
rmsd_prot_lig = calc_crd_rmsd(
pred=pred_allatom[:,mask_BB[0],:,:3], true=nat_symm[None,mask_BB[0],:,:3],
atom_mask=xs_mask_prot[:,mask_BB[0]], rmsd_mask=xs_mask_lig[:,mask_BB[0]],
alignment_radius=10.0
)
else:
rmsd_prot_lig = torch.tensor([0], device=pred.device)
chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
chain_prot_lig_inter += protein_mask_2d
chain_prot_lig_inter += not_protein_mask_2d
_, allatom_lddt_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
chain_prot_lig_inter, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
loss_dict['total_loss'] = tot_loss.detach()
loss_dict["rmsd_prot_prot"]= rmsd_prot_prot[0].detach()
loss_dict["rmsd_lig_lig"]= rmsd_lig_lig[0].detach()
loss_dict["rmsd_prot_lig"]= rmsd_prot_lig[0].detach()
return tot_loss, loss_dict
# cart bonded (bond geometry)
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
if w_bond > 0.0:
tot_loss += w_bond*bond_loss
loss_dict['bond_geom'] = bond_loss.detach()
# clash [use all atoms not just those in native]
clash_loss = calc_lj(
seq[0], pred_allatom,
trainer.aamask, bond_feats, dist_matrix, trainer.ljlk_parameters, trainer.lj_correction_parameters, trainer.num_bonds,
lj_lin=lj_lin
)
if w_clash > 0.0:
tot_loss += w_clash*clash_loss.mean()
loss_dict['clash_loss'] = clash_loss[0].detach()
if torch.any(mask_BB[0]):
atom_bond_loss, skip_bond_loss, rigid_loss = calc_atom_bond_loss(
pred=pred_allatom[:,mask_BB[0]],
true=nat_symm[None,mask_BB[0]],
bond_feats=bond_feats[:,mask_BB[0]][:,:,mask_BB[0]],
seq=seq[:,mask_BB[0]]
)
else:
atom_bond_loss = torch.tensor(0, device=gpu)
skip_bond_loss = torch.tensor(0, device=gpu)
rigid_loss = torch.tensor(0, device=gpu)
if w_atom_bond >= 0.0:
tot_loss += w_atom_bond*atom_bond_loss
loss_dict['atom_bond_loss'] = ( atom_bond_loss.detach() )
if w_skip_bond >= 0.0:
tot_loss += w_skip_bond*skip_bond_loss
loss_dict['skip_bond_loss'] = ( skip_bond_loss.detach() )
if w_rigid >= 0.0:
tot_loss += w_rigid*rigid_loss
loss_dict['rigid_loss'] = ( rigid_loss.detach() )
chain_prot = same_chain.clone()
protein_mask_2d = torch.einsum('l,r-> lr', prot_mask_BB, prot_mask_BB)
_, allatom_lddt_prot_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, negative=True, N_stripe=10)
loss_dict['allatom_lddt_prot_intra'] = allatom_lddt_prot_intra[0].detach()
_, allatom_lddt_prot_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, protein_mask_2d[None],
chain_prot, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_inter'] = allatom_lddt_prot_inter[0].detach()
chain_lig = same_chain.clone()
not_protein_mask_2d = torch.einsum('l,r-> lr', not_prot_mask_BB, not_prot_mask_BB)
_, allatom_lddt_lig_intra = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, negative=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_intra'] = allatom_lddt_lig_intra[0].detach()
_, allatom_lddt_lig_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, not_protein_mask_2d[None],
chain_lig, interface=True, bin_scaling=0.5, N_stripe=10)
loss_dict['allatom_lddt_lig_inter'] = allatom_lddt_lig_inter[0].detach()
chain_prot_lig_inter = torch.zeros_like(same_chain, dtype=bool)
chain_prot_lig_inter += protein_mask_2d
chain_prot_lig_inter += not_protein_mask_2d
_, allatom_lddt_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d,
chain_prot_lig_inter, interface=True, N_stripe=10)
loss_dict['allatom_lddt_prot_lig_inter'] = allatom_lddt_inter[0].detach()
loss_dict['total_loss'] = tot_loss.detach()
return tot_loss, loss_dict
### this file will contain specific calls to the loss function

452
rf2aa/model/AF3_blocks.py Normal file
View File

@@ -0,0 +1,452 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import torch.nn.functional as F
from functools import partial
import numpy as np
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util_module import Dropout, init_lecun_normal
from opt_einsum import contract as einsum
# MSA transformer
class AF3_full_block(nn.Module):
"""
AF3_full_block:
- MSA/Pair updates as in AF2
- MSA then Pair
"""
def __init__(self, global_config, block_params, **kwargs):
super(AF3_full_block, self).__init__()
d_msa, d_pair = (
global_config.d_msa_full,
global_config.d_pair
)
# to do: optionally disable norm bias
self.norm_pair_bias = nn.LayerNorm(d_pair, bias=False)
self.drop_row = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_row)
self.drop_col = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_pair)
# to do: optionally disable norm bias
self.msa_row_attn = MSARowAttentionWithBias(
d_msa=d_msa,
d_pair=d_pair,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels,
nseq_normalization=block_params.norm_msa_row,
bias=False
)
self.msa_col_attn = MSAColGlobalAttention(
d_msa=d_msa,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels,
bias=False
)
self.msa_transition = FeedForwardLayer(d_msa, 4, p_drop=block_params.msa_transition_drop)
# Pair update parameters
self.outer_product = OuterProductMean(d_msa, d_pair, d_hidden=block_params.outer_product_channels, \
p_drop=block_params.p_drop_outer_product)
# to do: optionally disable norm bias
self.tri_mul_outgoing = TriangleMultiplication(
d_pair, d_hidden=block_params.n_pair_channels, outgoing=True, bias=False)
self.tri_mul_incoming = TriangleMultiplication(
d_pair, d_hidden=block_params.n_pair_channels, outgoing=False, bias=False)
self.tri_attn_start = TriangleAttention(
d_pair, d_hidden=block_params.n_pair_channels, start_node=True)
self.tri_attn_end = TriangleAttention(
d_pair, d_hidden=block_params.n_pair_channels, start_node=False)
self.pair_transition = FeedForwardLayer(d_pair, 2) # HACK: hardcoded value for transition
def _unpack_inputs(self, latent_feats):
pair = latent_feats["pair"]
msa = latent_feats["msa_full"]
return msa, pair
def _pack_outputs(self, msa, pair, latent_feats):
latent_feats["msa_full"] = msa
latent_feats["pair"] = pair
return latent_feats
def _1d_update(self, msa, pair):
pair = self.norm_pair_bias(pair)
msa = msa + self.drop_row(self.msa_row_attn(msa, pair))
msa = msa + self.msa_col_attn(msa)
msa = msa + self.msa_transition(msa)
return msa
def _2d_update(self, msa, pair):
msa_bias = self.outer_product(msa)
pair = pair + msa_bias
pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
pair = pair + self.drop_row(self.tri_mul_incoming(pair))
pair = pair + self.drop_row(self.tri_attn_start(pair))
pair = pair + self.drop_row(self.tri_attn_end(pair))
pair = pair + self.pair_transition(pair)
return pair
def forward(self, latent_feats, use_checkpoint, use_amp):
msa, pair = self._unpack_inputs(latent_feats)
if use_checkpoint:
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
msa = checkpoint.checkpoint(self._1d_update, msa, pair, use_reentrant=True)
pair = checkpoint.checkpoint(self._2d_update, msa, pair, use_reentrant=True)
else:
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
msa = self._1d_update(msa, pair)
pair = self._2d_update(msa, pair)
latent_feats = self._pack_outputs(msa, pair, latent_feats)
return latent_feats
# Pair transformer
class AF3_block(nn.Module):
"""
Attempt to replicate AF3 architecture from scratch.
"""
def __init__(self, global_config, block_params, **kwargs):
super(AF3_block, self).__init__()
d_singleseq, d_pair = (
global_config.d_msa,
global_config.d_pair
)
# to do: optionally disable norm bias
self.norm_pair_bias = nn.LayerNorm(d_pair, bias=False)
self.drop_row = Dropout(broadcast_dim=2, p_drop=block_params.p_drop_row)
self.drop_col = Dropout(broadcast_dim=1, p_drop=block_params.p_drop_pair)
# single sequence attn
self.msa_row_attn = MSARowAttentionWithBias(
d_msa=d_singleseq,
d_pair=d_pair,
n_head=block_params.n_msa_head,
d_hidden=block_params.n_msa_channels,
nseq_normalization=block_params.norm_msa_row,
bias=False
)
self.msa_transition = FeedForwardLayer(d_singleseq, 4, p_drop=block_params.msa_transition_drop)
# to do: optionally disable norm bias
self.tri_mul_outgoing = TriangleMultiplication(
d_pair, d_hidden=block_params.n_pair_channels, outgoing=True, bias=False)
self.tri_mul_incoming = TriangleMultiplication(
d_pair, d_hidden=block_params.n_pair_channels, outgoing=False, bias=False)
self.tri_attn_start = TriangleAttention(
d_pair, d_hidden=block_params.n_pair_channels, start_node=True)
self.tri_attn_end = TriangleAttention(
d_pair, d_hidden=block_params.n_pair_channels, start_node=False)
self.pair_transition = FeedForwardLayer(d_pair, 4) # HACK: hardcoded value for transition
def _unpack_inputs(self, latent_feats):
pair = latent_feats["pair"]
singleseq = latent_feats["msa"]
return singleseq, pair
def _pack_outputs(self, singleseq, pair, latent_feats):
latent_feats["msa"] = singleseq
latent_feats["pair"] = pair
return latent_feats
def _1d_update(self, msa, pair):
pair = self.norm_pair_bias(pair)
msa = msa + self.drop_row(self.msa_row_attn(msa, pair)) # pair biased attn
msa = msa + self.msa_transition(msa)
return msa
def _2d_update(self, msa, pair):
pair = pair + self.drop_row(self.tri_mul_outgoing(pair))
pair = pair + self.drop_row(self.tri_mul_incoming(pair))
pair = pair + self.drop_row(self.tri_attn_start(pair))
pair = pair + self.drop_col(self.tri_attn_end(pair))
pair = pair + self.pair_transition(pair)
return pair
def forward(self, latent_feats, use_checkpoint, use_amp):
singleseq, pair = self._unpack_inputs(latent_feats)
drop_layer = 0
if use_checkpoint:
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
# 2d then 1d update
pair = checkpoint.checkpoint(self._2d_update, singleseq, pair, use_reentrant=True)
singleseq = checkpoint.checkpoint(self._1d_update, singleseq, pair, use_reentrant=True)
else:
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
# 2d then 1d update
pair = self._2d_update(singleseq, pair)
singleseq = self._1d_update(singleseq, pair)
latent_feats = self._pack_outputs(singleseq, pair, latent_feats)
return latent_feats
class MsaModule(nn.Module):
def __init__(self,
n_blocks,
subsampled_embedding,
outer_product,
msa_pair_weighted_averaging,
msa_transition,
triangle_multiplication_incoming,
triangle_multiplication_outgoing,
triangle_attention_starting,
triangle_attention_ending,
pair_transition,
):
super(MsaModule, self).__init__()
self.n_blocks = n_blocks
self.msa_subsampler = MsaSubsampleEmbedder(subsampled_embedding)
self.outer_product = OuterProductMean(outer_product)
self.msa_pair_weighted_averaging = MsaPairWeightedAverage(msa_pair_weighted_averaging)
self.msa_transition = FeedForwardLayer(msa_transition)
#TODO: check if row and col dropout are right
self.drop_row = Dropout(broadcast_dim=1, p_drop=0.25)
self.drop_col = Dropout(broadcast_dim=2, p_drop=0.25)
self.tri_mult_outgoing = TriangleMultiplication(triangle_multiplication_outgoing)
self.tri_mult_incoming = TriangleMultiplication(triangle_multiplication_incoming)
self.tri_attn_start = TriangleAttention(triangle_attention_starting)
self.tri_attn_end = TriangleAttention(triangle_attention_ending)
self.pair_transition = FeedForwardLayer(pair_transition)
def forward(self,
f_dict,
pair_II,
S_inputs
):
msa_SI = f_dict["msa_SI"]
msa_SI = self.msa_subsampler(msa_SI, S_inputs)
for i in range(self.n_blocks):
pair_II = pair_II + self.outer_product(msa_SI)
msa_SI = msa_SI + self.drop_row(self.msa_pair_weighted_averaging(msa_SI, pair_II))
msa_SI = msa_SI + self.msa_transition(msa_SI)
pair_II = pair_II + self.drop_row(self.tri_mult_outgoing(pair_II))
pair_II = pair_II + self.drop_row(self.tri_mult_incoming(pair_II))
pair_II = pair_II + self.drop_row(self.tri_attn_start(pair_II))
pair_II = pair_II + self.drop_row(self.tri_attn_end(pair_II))
pair_II = pair_II + self.pair_transition(pair_II)
return pair_II
class MsaSubsampleEmbedder(nn.Module):
def __init__(self, params):
super(MsaSubsampleEmbedder, self).__init__()
self.num_sequences = params["num_sequences"]
self.emb_msa = nn.Linear(params["msa_dim"], params["msa_channels"], bias=False)
self.emb_S_inputs = nn.Linear(params["S_dim"], params["msa_channels"], bias=False)
def forward(self,
msa_SI,
S_inputs # (B, L, S_dim)
):
B, S, I = msa_SI.shape[:3]
num_samples = torch.min(torch.tensor([self.num_sequences, S]))
weights = torch.ones(num_samples.item(), device=msa_SI.device)
samples = torch.multinomial(weights, num_samples, replacement=False)
msa_SI = msa_SI[:, samples]
msa_SI = self.emb_msa(msa_SI)
msa_SI = msa_SI + self.emb_S_inputs(S_inputs)
return msa_SI
class MsaPairWeightedAverage(nn.Module):
""" implements Algorithm 10 from AF3 paper"""
def __init__(self, params):
super(MsaPairWeightedAverage, self).__init__()
self.weighted_average_channels = params["weighted_average_channels"]
self.n_heads = params["n_heads"]
self.msa_channels = params["msa_channels"]
self.pair_channels = params["pair_channels"]
self.norm_msa = nn.LayerNorm(self.msa_channels)
self.to_v = nn.Linear(self.msa_channels, self.n_heads*self.weighted_average_channels, bias=False)
self.norm_pair = nn.LayerNorm(self.pair_channels)
self.to_bias = nn.Linear(self.msa_channels, self.n_heads, bias=False)
self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)
self.to_out = nn.Linear(self.weighted_average_channels, self.msa_channels, bias=False)
def forward(self,
msa_SI,
pair_II
):
B, S, I = msa_SI.shape[:3]
msa_SI = self.norm_msa(msa_SI)
v_SIH = self.to_v(msa_SI).reshape(B, S, I, self.n_heads, self.d_head)
bias_IIH = self.to_bias(self.norm_pair(pair_II))
gate_SIH = torch.sigmoid(self.to_gate(msa_SI))
w_IIH = F.softmax(bias_IIH, dim=-2)
weights = torch.einsum( "bijh,bsjhc->bsihc", w_IIH, v_SIH)
o_SIH = gate_SIH * weights
msa_update_SI = self.to_out(o_SIH.reshape(B, S, I, -1))
return msa_update_SI
class BiasedSequenceAttention(nn.Module):
def __init__(self, global_params, block_params):
super(BiasedSequenceAttention, self).__init__()
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
#
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
self.scaling = 1/np.sqrt(block_params.n_channels)
self.h = block_params.n_heads
self.dim = block_params.n_channels
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
def forward(self, state, pair): # TODO: make this as tied-attention
B, L = state.shape[:2]
#
state = self.norm_state(state)
pair = self.norm_pair(pair)
query = self.to_q(state).reshape(B, L, self.h, self.dim)
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
value = self.to_v(state).reshape(B, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(state))
attn = einsum('bqhd,bkhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
out = gate * out
out = self.to_out(out)
out = state + self.transition(out)
return out
class TemplateEmbedding(nn.Module):
def __init__(self, params):
super(TemplateEmbedding, self).__init__()
self.template_channels = params["template_channels"]
self.emb_pair = nn.Linear(params["pair_dim"], params["template_channels"], bias=False)
self.norm_pair_before_pairformer = nn.LayerNorm(params["pair_dim"])
self.norm_after_pairformer = nn.LayerNorm(params["template_channels"])
# HACK: need the actual pairformer block
self.pairformer = AF3_block(params["pair_dim"], params["template_channels"], params["pairformer_channels"], params["n_pairformer_layers"])
# NOTE: this is not consistent with AF3 paper which outputs this tensor in the template_channel dimension
self.agg_emb = nn.Linear(params["template_channels"], params["pair_dim"], bias=False)
def forward(self,
f_dict,
pair_II,
):
B, I = pair_II.shape[:2]
template_frame_mask = f_dict["template_frame_mask"][None, :] * f_dict["template_frame_mask"][:, None]
template_pseudo_beta_mask = f_dict["template_pseudo_beta_mask"][None, :] * f_dict["template_pseudo_beta_mask"][:, None]
template_feats = torch.cat([f_dict["template_distogram"], template_frame_mask, f_dict["template_unit_vector"], template_pseudo_beta_mask])
template_feats = template_feats * (f_dict["asym_id"][None, :] == f_dict["asym_id"][:, None])
T = template_feats.shape[1]
u_II = torch.zeros(B, I, I, self.template_channels, device=pair_II.device)
for i in range(T):
v_II = self.emb_pair(self.norm_pair_before_pairformer(pair_II)) + template_feats[:, i]
v_II = self.pairformer(v_II)
u_II = u_II + self.norm_after_pairformer(v_II)
u_II = u_II / T
return self.agg_emb(F.relu(u_II))
class BiasedSequenceAttention(nn.Module):
def __init__(self, global_params, block_params):
super(BiasedSequenceAttention, self).__init__()
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
#
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
self.scaling = 1/np.sqrt(block_params.n_channels)
self.h = block_params.n_heads
self.dim = block_params.n_channels
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
def forward(self, state, pair): # TODO: make this as tied-attention
B, L = state.shape[:2]
#
state = self.norm_state(state)
pair = self.norm_pair(pair)
query = self.to_q(state).reshape(B, L, self.h, self.dim)
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
value = self.to_v(state).reshape(B, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(state))
attn = einsum('bqhd,bkhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
out = gate * out
out = self.to_out(out)
out = state + self.transition(out)
return out

View File

@@ -0,0 +1,779 @@
import torch
import torch.nn as nn
from torch.nn.functional import one_hot, sigmoid
import torch.utils.checkpoint as checkpoint
from functools import partial
import numpy as np
from torch import relu
from icecream import ic
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util_module import Dropout
#from rf2aa.alignment import weighted_rigid_align
'''
Glossary:
I: # tokens (coarse representation)
L: # atoms (fine representation)
M: # msa
T: # templates
'''
linearNoBias = partial(torch.nn.Linear, bias=False)
class AtomAttentionEncoder(nn.Module):
def __init__(self, c_atom, c_atompair, c_token, atom_1d_features, atom_transformer):
super().__init__()
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_s = c_token
self.atom_1d_features = atom_1d_features
self.pair_mlp = nn.Sequential(
nn.ReLU(),
linearNoBias(self.c_atompair, c_atompair),
nn.ReLU(),
linearNoBias(self.c_atompair, c_atompair),
nn.ReLU(),
linearNoBias(self.c_atompair, c_atompair),
)
self.atom_transformer = AtomTransformer(c_atom=c_atom, c_atompair=c_atompair, **atom_transformer)
def forward(
self,
f, # Dict (Input feature dictionary)
Rl, # [B, L, 3]
Si_trunk, # [B, I, C_S_trunk]
Zij, # [B, I, I, C_Z]
tok_idx, # [L] maps l --> i
):
B, I, _ = Si_trunk.shape
# Create the atom single conditioning: Embed per-atom meta data
Cl = self.linear_no_bias_1(torch.concat(f[feature_name] for feature_name in self.atom_1d_features))
# Embed offsets between atom reference positions
Dlm = f['ref_pos'].unsqueeze(-1) - f['ref_pos'].unsqueeze(-2)
Vlm = f['ref_space_uid'].unsqueeze(-1) == f['ref_space_uid'].unsqueeze(-2)
Plm = self.linear_1(Dlm) * Vlm
# Embed pairwise inverse squared distances, and the valid mask
Plm += self.linear_2(1/(1+torch.linalg.norm(Dlm, dim=-1))) * Vlm
# Initialise the atom single representation as the single conditioning.
Ql = Cl
# If provided, add trunk embeddings and noisy positions.
## Broadcast the single and pair embedding from the trunk.
Sl_trunk = Si_trunk[:, self.tok_idx]
Cl += self.linear_3(self.layer_norm_1(Sl_trunk))
## Add the noisy positions.
Ql += self.linear_4(Rl)
# Add the combined single conditioning to the pair representation.
Plm += self.linear_5(relu(Cl)).unsqueeze(-1) + self.linear_6(relu(Cl)).unsqueeze(-2)
# Run a small MLP on the pair activations
Plm += self.pair_mlp(Plm)
# Cross attention transformer.
Ql = self.atom_transformer(Ql, Cl, Plm)
# Aggregate per-atom representation to per-token representation
Ai = torch.zeros((B, I, self.c_token)).reduce(
1,
tok_idx,
relu(self.linear_6(Ql)),
'mean',
include_self=False)
return Ai, Ql, Cl, Plm
class AtomAttentionDecoder(nn.Module):
def __init__(self, c_token, c_atom, c_atompair, atom_transformer):
super().__init__()
self.atom_transformer = AtomTransformer(c_atom=c_atom, c_atompair=c_atompair, **atom_transformer)
self.linear_1 = linearNoBias(c_token, c_atom)
self.linear_2 = linearNoBias(c_atom, 3)
def forward(
self,
Ai, # [L, C_token]
Ql_skip, # [L, C_atom]
Cl_skip, # [L, C_atom]
Plm_skip, # [L, L, C_atompair]
tok_idx, # [L] maps l --> i
):
# Broadcast per-token activiations to per-atom activations and add the skip connection
Ql = self.linear_1(Ai[tok_idx]) + Ql_skip
# Cross attention transformer.
Ql = self.atom_transformer(Ql, Cl_skip, Plm_skip)
# Map to positions update
Rl_update = self.linear_2(self.layer_norm_2(Ql))
return Rl_update
class AtomTransformer(nn.Module):
def __init__(
self,
c_atom,
c_atompair,
diffusion_transformer,
n_queries,
n_keys,
l_max,
):
super().__init__()
self.l_max = l_max
subset_centers = torch.arange(0, n_queries, l_max) + (n_queries-1 + n_queries / 2)
l = torch.arange(l_max).unsqueeze(-1).unsqueeze(-1) # [l_max, 1, 1]
m = torch.arange(l_max).unsqueeze(0).unsqueeze(-1) # [1, l_max, 1]
c = subset_centers.unsqueeze(0).unsqueeze(0) # [1, 1, S]
Beta_lms_binary = (torch.abs(l - m) < n_queries / 2) * (torch.abs(m - c) < n_keys / 2)
ic(
Beta_lms_binary.dtype,
)
Beta_lm_binary = Beta_lms_binary.prod(dim=-1, dtype=bool)
ic(
Beta_lm_binary.dtype,
)
self.Beta_lm = torch.where(Beta_lm_binary, 0, -10e10)
self.diffusion_transformer = DiffusionTransformer(c_token=c_atom, c_tokenpair=c_atompair, **diffusion_transformer)
def forward(
self,
Ql, # [B, L, C_atom]
Cl, # [B, L, C_atom]
Plm, # [B, L, L, C_atompair]
):
B, L, _ = Ql.shape
assert L < self.l_max
Beta_lm = self.Beta_lm[:L, :L]
return self.diffusion_transformer(Ql, Cl, Plm, Beta_lm)
class DiffusionTransformer(nn.Module):
def __init__(self, c_token, c_tokenpair, n_block, diffusion_transformer_block):
super().__init__()
self.blocks = torch.nn.Sequential(*[
DiffusionTransformerBlock(c_token=c_token, c_tokenpair=c_tokenpair, **diffusion_transformer_block)
for _ in range(n_block)
])
def forward(
self,
Ai, # [B, I, C_token]
Si, # [B, I, C_token]
Zij, # [B, I, I, C_tokenpair]
Beta_ij, # [I, I]
):
return self.blocks(Ai, Si, Zij, Beta_ij)
class DiffusionTransformerBlock(nn.Module):
def __init__(self, c_token, c_tokenpair, n_head):
super().__init__()
self.attention_pair_bias = AttentionPairBias(c_a=c_token, c_pair=c_tokenpair, n_head=n_head)
self.conditioned_transition_block = ConditionedTransitionBlock(c_token=c_token)
def forward(
self,
Ai, # [B, I, C_token]
Si, # [B, I, C_token]
Zij, # [B, I, I, C_tokenpair]
Beta_ij, # [I, I]
):
Bi = self.attention_pair_bias(Ai, Si, Zij, Beta_ij)
Ai = Bi + self.conditioned_transition_block(Ai, Si)
return Ai, Si, Zij, Beta_ij
# class MultiHeadLinear(nn.Linear):
# def __init__(self, in_features, out_features, h, *args, **kwargs):
# self.h = h
# self.out_features = out_features
# super().__init__(in_features, out_features, *args, **kwargs)
# def forward(self, x):
# return sel
class MultiDimLinear(nn.Linear):
def __init__(self, in_features, out_shape, **kwargs):
self.out_shape = out_shape
out_features = np.prod(out_shape)
super().__init__(in_features, out_features, **kwargs)
def forward(self, x):
out = super().forward(x)
return out.reshape(x.shape[:-1] + self.out_shape)
class LinearBiasInit(nn.Linear):
def __init__(self, *args, biasinit, **kwargs):
assert biasinit == -2. # Sanity check
self.biasinit = biasinit
super().__init__(*args, **kwargs)
def reset_parameters(self) -> None:
super().reset_parameters()
self.bias.data.fill_(self.biasinit)
class AttentionPairBias(nn.Module):
def __init__(self, c_a, c_pair, n_head):
super().__init__()
c = c_a // n_head
ic(c, n_head)
self.to_q = MultiDimLinear(c, (n_head, c))
self.to_k = MultiDimLinear(c, (n_head, c), bias=False)
self.to_v = MultiDimLinear(c, (n_head, c), bias=False)
self.to_b = linearNoBias(c_pair, n_head)
self.to_g = nn.Sequential(
MultiDimLinear(c_a, (n_head, c), bias=False),
nn.Sigmoid(),
)
self.to_a = linearNoBias(c_a, c_a)
self.linear_output_project = nn.Sequential(
LinearBiasInit(c_a, c_a, biasinit=-2.),
nn.Sigmoid(),
)
def forward(
self,
Ai, # [B, I, C_token]
Si, # [B, I, C_token] | None
Zij, # [B, I, I, C_tokepair]
Beta_ij, # [I, I]
):
# Input projections
if Si is not None:
Ai = self.ada_ln_1(Ai, Si)
else:
Ai = self.ln_1(Ai, Si)
Qih = self.to_q(Ai)
Kih = self.to_k(Ai)
Vih = self.to_v(Ai)
Bijh = self.to_b(Zij) + Beta_ij
Gih = self.to_g(Ai)
# Attention
Aijh = torch.softmax(torch.pow(self.c, -1/2) * torch.einsum("bihd,bjhd->bijh", Qih, Kih) + Bijh, dim=-2) # softmax over j
## Gih: [B, I, H, C]
## Aijh: [B, I, I, H]
## ViH: [B, I, H, C]
head_i = torch.einsum("bijh,bjhc->bihc", Aijh, Vih)
head_i = Gih * head_i # [B, I, H, C]
Ai = torch.concat(head_i, dim=-2) # [B, I, Ca]
Ai = self.to_a(Ai)
# Output projection (from adaLN-Zero)
if Si is not None:
Ai = self.linear_output_project(Si) * Ai
return Ai
# SwiGLU transition block with adaptive layernorm
class ConditionedTransitionBlock(nn.Module):
def __init__(self, c_token, n=2):
super().__init__()
self.ada_ln = AdaLN(c_token=c_token)
self.linear_1 = linearNoBias(c_token, c_token*n)
self.linear_2 = linearNoBias(c_token, c_token*n)
self.linear_output_project = nn.Sequential(
LinearBiasInit(c_token, c_token, biasinit=-2.),
nn.Sigmoid(),
)
def forward(
self,
Ai, # [B, I, C_token]
Si, # [B, I, C_token]
):
Ai = self.ada_ln(Ai, Si)
Bi = torch.silu(self.linear_1(Ai)) * self.linear_2(Ai)
# Output projection (from adaLN-Zero)
return self.linear_output_project(Si) * self.linear_3(Bi)
class AdaLN(nn.Module):
def __init__(self, c_token, n=2):
super().__init__()
self.ln = nn.LayerNorm(normalized_shape=(c_token,), elementwise_affine=False)
self.ln_learnable_gain = nn.LayerNorm(normalized_shape=(c_token,), bias=False)
self.linear_1 = nn.Linear(c_token, c_token)
self.linear_2 = nn.Linear(c_token, c_token)
def forward(
self,
Ai, # [B, I, C_token]
Si, # [B, I, C_token]
):
Ai = self.ln(Ai)
Si = self.ln_learnable_gain(Si)
return torch.sigmoid(self.linear_1(Si)) * Ai + self.linear_2(Si)
class DiffusionModule(nn.Module):
def __init__(self, sigma_data, c_atom, c_atompair, c_token, c_s, c_z, diffusion_conditioning, atom_attention_encoder, diffusion_transformer, atom_attention_decoder):
super().__init__()
self.sigma_data = sigma_data
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_token = c_token
self.diffusion_conditioning = DiffusionConditioning(sigma_data=sigma_data, c_s=c_s, c_z=c_z, **diffusion_conditioning)
self.atom_attention_encoder = AtomAttentionEncoder(c_token=c_token, c_atom=c_atom, c_atompair=c_atompair, **atom_attention_encoder)
self.diffusion_transformer = DiffusionTransformer(c_token=c_token, c_tokenpair=c_atompair, **diffusion_transformer)
self.layer_norm_1 = nn.LayerNorm(c_token)
self.atom_attention_decoder = AtomAttentionDecoder(c_token=c_token, c_atom=c_atom, c_atompair=c_atompair, **atom_attention_decoder)
def forward(self,
X_noisy_L, # [B, L, 3]
t, # [B] (0 is ground truth)
f, # Dict (Input feature dictionary)
S_input_I, # [B, I, C_S_input]
S_trunk_I, # [B, I, C_S_trunk]
Z_trunk_II, # [B, I, I, C_Z]
):
# Conditioning
S_I, Z_II = self.diffusion_conditioning(t, f, S_input_I, S_trunk_I, Z_trunk_II)
# Scale positions to dimensionless vectors with approximately unit variance
R_noisy_L = X_noisy_L / torch.sqrt(t^2 + self.sigma_data)
# Sequence-local Atom Attention and aggregation to coarse-grained tokens
A_I, Q_skip_L, C_skip_L, P_skip_LL = self.atom_attention_encoder(f, R_noisy_L, S_trunk_I, Z_II)
# Full self-attention on token level
A_I += self.linear_no_bias(self.layer_norm(S_I))
A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=0)
A_I = self.layer_norm_1(A_I)
# Broadcast token activations to atoms and run Sequence-local Atom Attention
R_update_L = self.atom_attention_decoder(A_I, Q_skip_L, C_skip_L, P_skip_LL)
# Rescale updates to positions and combine with input positions
X_out_L = self.sigma_data**2 / (self.sigma_data**2 + t**2) * X_noisy_L + self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5 * R_update_L
return X_out_L
class DiffusionConditioning(nn.Module):
def __init__(self, sigma_data, c_z, c_s, c_t_embed):
super().__init__()
self.sigma_data = sigma_data
self.to_zii = nn.Sequential(
nn.LayerNorm(c_z),
linearNoBias(c_z, c_z)
)
self.transition_1 = nn.ModuleList([
Transition(c=c_s, n=2),
Transition(c=c_s, n=2),
])
self.to_si = nn.Sequential(
nn.LayerNorm(c_s),
linearNoBias(c_s, c_s)
)
c_t_embed = 256
self.fourier_embedding = FourierEmbedding(c_t_embed)
self.process_n = nn.Sequential(
nn.LayerNorm(c_t_embed),
linearNoBias(c_t_embed, c_s)
)
self.transition_2 = nn.ModuleList([
Transition(c=c_s, n=2),
Transition(c=c_s, n=2),
])
def forward(self,
t,
f,
S_inputs_I,
S_trunk_I,
Z_trunk_II):
# Pair conditioning
Z_II = torch.concat([Z_trunk_II, self.relative_position_encoding(f)])
Z_II = self.to_zii(Z_II)
for b in range(2):
Z_II += self.transition_1[b](Z_II)
# Single conditioning
S_I = torch.concat([S_trunk_I, S_inputs_I])
S_I = self.to_si(S_I)
N_T = self.fourier_embedding(1/4 * torch.log(t/self.sigma_data))
S_I += self.process_n(N_T)
for b in range(2):
S_I += self.transition_2[b](S_I)
return S_I, Z_II
class Transition(nn.Module):
def __init__(self, n, c):
super().__init__()
self.n = n
self.layer_norm_1 = nn.LayerNorm(c)
self.linear_1 = linearNoBias(c, n*c)
self.linear_2 = linearNoBias(c, n*c)
self.linear_3 = linearNoBias(n*c, c)
def forward(self,
X,
):
X = self.layer_norm_1(X)
A = self.linear_1(X)
B = self.linear_2(X)
X = self.linear_3(torch.silu(A) * B)
return X
pi = torch.acos(torch.zeros(1)).item() * 2
class FourierEmbedding(nn.Module):
def __init__(self, c):
super().__init__()
self.c = c
self.register_buffer('w', torch.zeros((c), dtype=torch.float32))
self.register_buffer('b', torch.zeros((c), dtype=torch.float32))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.w)
nn.init.normal_(self.b)
def forward(self,
t,
):
return torch.cos(2 * pi * (t*self.w + self.b))
# from dataclasses import dataclass
# @dataclass
# class RecyclingInput:
class Model(nn.Module):
def __init__(self,
c_s,
c_z,
c_atom,
c_atompair,
feature_initializer,
recycler,
diffusion_module,
**kwargs
):
super().__init__()
self.feature_initializer = FeatureInitializer(c_s=c_s, c_z=c_z, c_atom=c_atom, c_atompair=c_atompair, **feature_initializer)
self.recycler = Recycler(c_s=c_s, c_z=c_z, **recycler)
self.diffusion_module = DiffusionModule(c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z, **diffusion_module)
def forward(self,
f,
X_noisy_L,
t,
n_cycle,
):
super().__init__()
S_input_I, S_init_I, Z_init_II = self.feature_initializer(f)
S_I = torch.zeros_like(S_init_I)
Z_II = torch.zeros_like(Z_init_II)
for _ in range(n_cycle):
S_I, Z_II = self.recycler(f, S_input_I, S_init_I, Z_init_II, S_I, Z_II)
X_pred = self.diffusion_module(
X_noisy_L,
t,
f,
S_input_I,
S_I,
Z_II,
)
return X_pred
def pre_recycle(self,
f,
X_noisy_L,
t):
S_input_I, S_init_I, Z_init_II = self.feature_initializer(f)
S_I = torch.zeros_like(S_init_I)
Z_II = torch.zeros_like(Z_init_II)
return S_input_I, S_init_I, Z_init_II, S_I, Z_II, f, X_noisy_L, t
def recycle(self,
S_input_I,
S_init_I,
Z_init_II,
S_I,
Z_II,
f,
X_noisy_L,
t,
):
S_I, Z_II = self.recycler(
S_input_I,
S_init_I,
Z_init_II,
S_I,
Z_II
)
return S_input_I, S_init_I, Z_init_II, S_I, Z_II, f, X_noisy_L, t
def post_recycle(self,
S_input_I,
S_init_I,
Z_init_II,
S_I,
Z_II,
f,
X_noisy_L,
t,
):
X_pred = self.diffusion_module(
X_noisy_L,
t,
f,
S_input_I,
S_I,
Z_II,
)
return X_pred
class Recycler(nn.Module):
def __init__(self,
c_s,
c_z,
template_embedder,
msa_module,
n_pairformer_blocks,
pairformer_block,
):
super().__init__()
self.process_zh = nn.Sequential(
nn.LayerNorm(c_z),
linearNoBias(c_z, c_z),
)
self.template_embedder = TemplateEmbedder(c_z=c_z, **template_embedder)
self.msa_module = MSAModule(**msa_module)
self.process_sh = nn.Sequential(
nn.LayerNorm(c_s),
linearNoBias(c_s, c_s),
)
self.pairformer_stack = nn.Sequential(*[
PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block) for _ in range(n_pairformer_blocks)
])
def forward(self,
f,
S_inputs_I,
S_init_I,
Z_init_II,
Sh_I,
Zh_II,
):
Z_II = Z_init_II + self.process_zh(Zh_II)
Z_II += self.template_embedder(f, Z_II)
Z_II += self.msa_module(f['msa'], Z_II, S_inputs_I)
S_I = S_init_I + self.process_sh(Sh_I)
S_I, Z_II = self.pairformer_stack(S_I, Z_II)
return S_I, Z_II
class PairformerBlock(nn.Module):
"""
Attempt to replicate AF3 architecture from scratch.
"""
def __init__(self,
c_s,
c_z,
p_drop,
c,
attention_pair_bias,
n_transition=4,
):
super().__init__()
self.drop_row = Dropout(broadcast_dim=2, p_drop=p_drop)
self.drop_col = Dropout(broadcast_dim=1, p_drop=p_drop)
self.tri_mul_outgoing = TriangleMultiplication(
c_z, d_hidden=c, outgoing=True, bias=False)
self.tri_mul_incoming = TriangleMultiplication(
c_z, d_hidden=c, outgoing=False, bias=False)
self.tri_attn_start = TriangleAttention(
c_z, d_hidden=c, start_node=True)
self.tri_attn_end = TriangleAttention(
c_z, d_hidden=c, start_node=False)
self.z_transition = Transition(c_z, n_transition)
self.s_transition = Transition(c_s, n_transition)
self.attention_pair_bias = AttentionPairBias(c_a=c_s, c_pair=c_z, **attention_pair_bias)
def forward(self,
S_I,
Z_II):
Z_II += self.drop_row(self.tri_mul_outgoing(Z_II))
Z_II += self.drop_row(self.tri_mul_incoming(Z_II))
Z_II += self.drop_row(self.tri_attn_start(Z_II))
Z_II += self.drop_col(self.tri_attn_end(Z_II))
Z_II += self.z_transition(Z_II)
S_I += self.attention_pair_bias(S_I, None, Z_II, Beta_II=0)
S_I += self.s_transition(S_I)
return S_I, Z_II
class FeatureInitializer(nn.Module):
def __init__(self,
c_s,
c_z,
c_atom,
c_atompair,
input_feature_embedder,
relative_position_encoding):
super().__init__()
self.input_feature_embedder = InputFeatureEmbedder(c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, **input_feature_embedder)
self.to_s_init = linearNoBias(c_s, c_s)
self.to_z_init_i = linearNoBias(c_s, c_z)
self.to_z_init_j = linearNoBias(c_s, c_z)
self.relative_position_encoding = RelativePositionEncoding(c_z=c_z, **relative_position_encoding)
self.process_token_bonds = linearNoBias(1, c_z)
def forward(self,
f,
):
S_inputs_I = self.input_feature_embedder(f)
S_init_I = self.to_s_init(S_init_I)
Z_init_II = self.to_z_init_i(S_inputs_I).unsqueeze(-3) + self.to_z_init_j(S_inputs_I).unsqueeze(-2)
Z_init_II += self.relative_position_encoding(f)
Z_init_II += self.process_token_bonds(f['token_bonds'])
return S_inputs_I, S_init_I, Z_init_II
class InputFeatureEmbedder(nn.Module):
def __init__(self,
features,
c_atom,
c_atompair,
c_s,
atom_attention_encoder):
super().__init__()
self.atom_attention_encoder = AtomAttentionEncoder(c_atom=c_atom, c_atompair=c_atompair, c_token=c_s, **atom_attention_encoder)
self.features = features
def forward(self,
f,
A_I,
):
S_I, _, _, _ = self.atom_attention_encoder(A_I)
S_I = torch.concat([A_I] + [f[feature] for feature in self.features])
return S_I
class RelativePositionEncoding(nn.Module):
def __init__(self,
r_max,
s_max,
c_z):
super().__init__()
self.r_max = r_max
self.s_max = s_max
self.c_z = c_z
self.linear = linearNoBias(2*(2*self.r_max+2) + (2*self.s_max+2) + 1, c_z)
def forward(self,
f):
b_samechain_II = f['asym_id'].unsqueeze(-1) == f['asym_id'].unsqueeze(-2)
b_sameresidue_II = f['residue_index'].unsqueeze(-1) == f['residue_index'].unsqueeze(-2)
b_same_entity_II = f['entity_id'].unsqueeze(-1) == f['entity_id'].unsqueeze(-2)
d_residue_II = torch.where(
b_samechain_II,
torch.clip(f['residue_index'].unsqueeze(-2) - f['residue_index'].unsqueeze(-1) + self.r_max, 0, 2*self.r_max),
2 * self.r_max + 1
)
A_relpos_II = one_hot(d_residue_II, 2*self.r_max+2)
d_token_II = torch.where(
b_samechain_II * b_sameresidue_II,
torch.clip(f['token_index'].unsqueeze(-2) - f['token_index'].unsqueeze(-1) + self.r_max, 0, 2*self.r_max),
2 * self.r_max + 1
)
A_reltoken_II = one_hot(d_token_II, 2*self.r_max+2)
d_chain_II = torch.where(
b_samechain_II,
torch.clip(f['sym_id'].unsqueeze(-2) - f['sym_id'].unsqueeze(-1) + self.s_max, 0, 2*self.s_max),
2 * self.s_max + 1
)
A_relchain_II = one_hot(d_chain_II, 2*self.s_max+2)
return self.linear(torch.cat([A_relpos_II, A_reltoken_II, b_same_entity_II.unsqueeze(-1), A_relchain_II], dim=3))
# Mock for testing.
class MSAModule(nn.Module):
def __init__(self, n_block, c_m):
super().__init__()
def forward(self,
f,
Z_II,
S_inputs_I,
):
return Z_II
# Mock for testing.
class TemplateEmbedder(nn.Module):
def __init__(self,
n_block,
c_z,
c):
super().__init__()
self.c =c
self.linear = linearNoBias(c_z, c)
def forward(self,
f,
Z_II,
):
return self.linear(Z_II)
class Loss:
def __init__(self,
sigma_data,
):
self.sigma_data = sigma_data
def __call__(self,
f,
X_L, # [B, L, 3]
X_gt_L, # [B, L, 3]
t, # [B]
):
w_L = 1 + (
f['is_dna']*self.alpha_is_dna +
f['is_rna'] * self.alpha_is_rna +
f['is_ligand'] * self.alpha_is_ligand
)
# Align ground truth onto predictions.
X_gt_aligned_L = weighted_rigid_align(X_gt_L, X_L, w_L)
l_mse = 1/3 * w_L * torch.mean(torch.linalg.norm(X_L, X_gt_aligned_L, dim=-1), dim=-1) # [B]
l_diffusion = (t**2 + self.sigma_data**2) / (t + self.sigma_data)**2 * l_mse
# TODO: implement auxiliary losses
l_total = l_diffusion.sum()
return l_total, {
'diffusion_loss': l_diffusion
}

View File

@@ -41,7 +41,7 @@ class RF2_embedding(nn.Module):
## Update inputs with outputs from previous forward pass
self.recycle = recycling_factory[block_params.recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
self.recycling_type = block_params.recycling_type
assert self.recycling_type == "msa_pair", "no backward compatibility to recycling state"
assert self.recycling_type != "all", "no backward compatibility to recycling state"
def _unpack_inputs(self, rf_inputs):
msa_latent, msa_full, seq, idx, bond_feats, dist_matrix = \
@@ -90,9 +90,7 @@ class RF2_embedding(nn.Module):
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
# No support for recycling state
#state = state + state_recycle # if state is not recycled these will be zeros
# add template embedding
pair, state = self._add_templ_features(rf_inputs, pair, state)
return {
"msa": msa_latent,

View File

@@ -0,0 +1,367 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors, SE3TransformerWrapper
from rf2aa.model.AF3_structure import FourierEmbedding
from rf2aa.model.Track_module import Str2Str
from rf2aa.util_module import rbf, make_topk_graph, make_full_graph, init_lecun_normal
from rf2aa.util import is_atom,is_nucleic,is_protein
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.loss.loss import calc_chiral_grads
from rf2aa.model.layers.Attention_module import FeedForwardLayer
from rf2aa.flow_matching import data_utils as du
import dgl
def get_bondgraph(bonds, num_bonds, dist_matrix, idx, is_prot, is_na, is_atom):
''' Combine different sources of info to get bond-distance graph for all i,j pairs
'''
L = bonds.shape[1]
# intra-prot/intra-na bonds
bonds[:,torch.arange(L),:,torch.arange(L),:] = num_bonds.transpose(0,1)
# ligand bonds
ia = is_atom.nonzero()[:,0]
bonds[:,ia[:,None],1,ia[None,:],1] = dist_matrix[ia[:,None],ia[None,:]].to(dtype=bonds.dtype)
# we need to handle covalent bonds between residues
# to reduce computational load only consider +/- 1 residue
# prot-prot
ii,jj = ChemData().protein_connect
resmask = (is_prot[0,:-1] * is_prot[0,1:] * ((idx[1:]-idx[:-1])==1)).nonzero()[...,0]
bonds[:,resmask+1,:,resmask ,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
bonds[:,resmask ,:,resmask+1,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
# na-na
ii,jj = ChemData().na_connect
resmask = (is_na[0,:-1] * is_na[0,1:] * ((idx[1:]-idx[:-1])==1)).nonzero()[...,0]
bonds[:,resmask+1,:,resmask ,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
bonds[:,resmask ,:,resmask+1,:] = (num_bonds[:,resmask+1,ii:(ii+1)] + num_bonds[:,resmask,:,jj:jj+1] + 1).transpose(0,1)
return bonds
#
def make_atom_graph(
xyz, mask, is_prot, is_na, is_atom, idx, num_bonds, dist_matrix, top_k=24, max_nbonds_encode=8, max_nbonds_connect=3
):
'''
Build an atom level graph from a mixed residue/ligand pose
Parameters of interest:
max_nbonds_encode - edge features encode # bonds, max this number
max_nbonds_connect - force connections between atoms this # bonds or fewer
Ensure top_k is large enough for max_nbonds_connect:
with max_nbonds_connect=3, ~15 atoms are brought in by bonds alone
with max_nbonds_connect=2, ~11 atoms are brought in by bonds alone
with max_nbonds_connect=1, ~4 atoms are brought in by bonds alone
'''
B,L,A = xyz.shape[:3]
device = xyz.device
D = torch.norm(
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
)
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
bonds = torch.full_like(D, ChemData().MAX_BOND_DIST, dtype=num_bonds.dtype)
bonds = get_bondgraph(bonds, num_bonds, dist_matrix, idx, is_prot, is_na, is_atom)
# set D to _negative_ for close bonded
# all missing-atom pairs will have D==0 so need to prefer these
D[bonds<=max_nbonds_connect] = -1.0
D[bonds==0] = np.inf # set D large for self
D[~mask2d] = np.inf # set D large for non-atoms
# select top K neighbors for each atom
# keep indices as batch/res/atm indices
nmaxedge = torch.sum(mask)-1 # most edges = num atoms - 1
if top_k > nmaxedge:
top_k = nmaxedge
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, A, top_k)
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
bi,ri,ai = mask.nonzero(as_tuple=True)
bi = bi[:,None].repeat(1,top_k).reshape(-1)
ri = ri[:,None].repeat(1,top_k).reshape(-1)
ai = ai[:,None].repeat(1,top_k).reshape(-1)
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
# on each edge, encode:
# a) 1-hot encode the number of bonds (up to maxbonds) separating each atom
# b) 1/D
bonds = bonds[bi,ri,ai,rj,aj]
bonds[bonds >= max_nbonds_encode] = max_nbonds_encode
natm = torch.sum(mask)
index = torch.zeros_like(mask, dtype=torch.long, device=device)
index[mask] = torch.arange(natm, device=device)
src=index[bi,ri,ai]
tgt=index[bi,rj,aj]
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj])
edge = torch.cat([
F.one_hot(bonds-1),
1 / (torch.norm(G.edata['rel_pos'], dim=-1,keepdim=True)+1)
], dim=-1)
return G, edge
class BiasedSequenceAttention(nn.Module):
def __init__(self, global_params, block_params):
super(BiasedSequenceAttention, self).__init__()
self.norm_state = nn.LayerNorm(global_params.d_state, bias=False)
self.norm_pair = nn.LayerNorm(global_params.d_pair, bias=False)
#
self.to_q = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_k = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_v = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads, bias=False)
self.to_b = nn.Linear(global_params.d_pair, block_params.n_heads, bias=False)
self.to_g = nn.Linear(global_params.d_state, block_params.n_channels*block_params.n_heads)
self.to_out = nn.Linear( block_params.n_channels*block_params.n_heads, global_params.d_state, bias=False)
self.scaling = 1/np.sqrt(block_params.n_channels)
self.h = block_params.n_heads
self.dim = block_params.n_channels
self.transition = FeedForwardLayer(global_params.d_state, 4, p_drop=block_params.msa_transition_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
#nn.init.zeros_(self.to_out.weight)
def forward(self, state, pair): # TODO: make this as tied-attention
B, L = state.shape[:2]
#
state = self.norm_state(state)
pair = self.norm_pair(pair)
query = self.to_q(state).reshape(B, L, self.h, self.dim)
key = self.scaling * self.to_k(state).reshape(B, L, self.h, self.dim)
value = self.to_v(state).reshape(B, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(state))
attn = einsum('bqhd,bkhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
out = gate * out
out = self.to_out(out)
out = state + self.transition(out)
return out
class GenerativeRefinement(nn.Module):
def __init__(self, global_params, block_params) -> None:
super(GenerativeRefinement, self).__init__()
self.proj_atoms = nn.Linear(ChemData().NELTTYPES, global_params.d_state)
self.proj_edge = nn.Linear(9, block_params.num_edge_features) #max_bonds_encode+1, need to refactor
self.norm_state_0 = nn.LayerNorm(global_params.d_state+global_params.d_msa)
self.proj_state_0 = nn.Linear(global_params.d_state+global_params.d_msa, global_params.d_state)
self.ff_state_0 = FeedForwardLayer(global_params.d_state, 2, zero_init=False)
self.norm_pair_0 = nn.LayerNorm(global_params.d_pair)
self.proj_pair_0 = nn.Linear(global_params.d_pair, global_params.d_pair)
self.ff_pair_0 = FeedForwardLayer(global_params.d_pair, 2, zero_init=False)
self.timestep_embedding_dim = 256
self.fourier_embedding = FourierEmbedding(self.timestep_embedding_dim)
self.norm_timstep_emb = nn.LayerNorm(self.timestep_embedding_dim)
self.emb_timestep = nn.Linear(self.timestep_embedding_dim, global_params.d_state, bias=False)
self.proj_state_1 = nn.Linear(global_params.d_state, global_params.d_state, bias=False)
self.proj_state_2 = nn.Linear(global_params.d_state, global_params.d_state, bias=False)
self.norm_state_1 = nn.LayerNorm(global_params.d_state)
self.norm_state_2 = nn.LayerNorm(global_params.d_state)
self.sigma_data = 16 # expose as parameter
self.atom_encoder = nn.ModuleList(
[
SE3TransformerWrapper(
num_layers=block_params.num_layers,
num_channels=block_params.num_channels,
num_degrees=block_params.num_degrees,
n_heads=block_params.n_heads,
div=block_params.div,
l0_in_features=block_params.l0_in_features,
l0_out_features=block_params.l0_out_features,
l1_in_features=1,
l1_out_features=1,
num_edge_features=block_params.num_edge_features,
compute_gradients=True
)
]
)
self.token_processing = nn.ModuleList(
[
BiasedSequenceAttention(global_params, block_params)
for i in range(block_params.num_attention_layers)
]
)
self.atom_decoder = nn.ModuleList(
[
SE3TransformerWrapper(
num_layers=block_params.num_layers,
num_channels=block_params.num_channels,
num_degrees=block_params.num_degrees,
n_heads=block_params.n_heads,
div=block_params.div,
l0_in_features=block_params.l0_in_features,
l0_out_features=0,
l1_in_features=1,
l1_out_features=1,
num_edge_features=block_params.num_edge_features,
compute_gradients=True
)
]
)
def _unpack_latents(self, latent_feats):
msa, pair, state = latent_feats["msa"], latent_feats["pair"], latent_feats["state"]
seq_unmasked = latent_feats["seq_unmasked"]
allatom_mask = ChemData().allatom_mask.to(state.device)
is_valid_atom = allatom_mask[seq_unmasked]
num_bonds = ChemData().num_bonds.to(state.device)
num_bonds_sequence = num_bonds[seq_unmasked]
xyz = latent_feats["trans_t"][0]
t = latent_feats["t"]
is_atomized = is_atom(seq_unmasked)
is_prot = is_protein(seq_unmasked)
is_na = is_nucleic(seq_unmasked)
dist_matrix = latent_feats["dist_matrix"][0]
idx = latent_feats["idx"][0]
return (
msa, pair, state, seq_unmasked,
is_valid_atom, num_bonds_sequence, xyz, t,
is_atomized, is_prot, is_na, dist_matrix, idx
)
def _embed_1d(self, latent_feats):
seq_unmasked = latent_feats["seq_unmasked"]
# feature set 1: element
elts = ChemData().aa2eltidx.to(seq_unmasked.device)
elts = elts[seq_unmasked]
elts = torch.nn.functional.one_hot(elts, ChemData().NELTTYPES).float()
return self.proj_atoms(elts)
def forward(self, latent_feats):
# get outputs + noised xyz
msa, pair, state, seq_unmasked, is_valid_atom, num_bonds_sequence, xyz, t, \
is_atomized, is_prot, is_na, dist_matrix, idx \
= self._unpack_latents(latent_feats)
t_hat = (1-t)*du.NM_TO_ANG_SCALE ## ?
# initial state embedding from msa (single seq) + state
msa = msa.squeeze(1)
state = torch.cat([msa,state], dim=-1)
state = self.proj_state_0(self.norm_state_0(state))
# add timestep embedding
timestep_emb_T = self.fourier_embedding(1/4 * torch.log(t_hat/self.sigma_data))
timestep_emb_T = self.emb_timestep(self.norm_timstep_emb(timestep_emb_T))
state = state + timestep_emb_T
state = self.ff_state_0(state)
# initial pair embedding
pair = self.proj_pair_0(self.norm_pair_0(pair))
pair = self.ff_pair_0(pair)
# add atom embeddings
# to do: a lot more...
atomstate = state[...,None,:] + self._embed_1d(latent_feats)
# encode atom level
atomstate, xyz = self.atom_update(
t_hat, xyz, atomstate, is_valid_atom,
is_prot, is_na, is_atomized, idx,
num_bonds_sequence, dist_matrix, encoder=True)
# to residue level
atomstate = F.relu_(self.proj_state_1(F.relu_(atomstate)))
state = (
self.proj_state_1(self.norm_state_1(state))
+ (atomstate * is_valid_atom[..., None]).sum(dim=2) / is_valid_atom.sum(dim=2)[...,None] # project down atomstate
)
# res level updates
for layer in self.token_processing:
state = layer(state, pair)
# combine atom&res level state, decode atom level
atomstate = atomstate + self.proj_state_2(self.norm_state_2(state))[...,None,:]
atomstate, xyz = self.atom_update(
t_hat, xyz, atomstate, is_valid_atom,
is_prot, is_na, is_atomized, idx,
num_bonds_sequence, dist_matrix, encoder=False)
# to residue level
state = (atomstate * is_valid_atom[..., None]).sum(dim=2) / is_valid_atom.sum(dim=2)[...,None]
return {
"state": state,
"xyz": xyz,
}
def atom_update(self, t_hat, xyz, state, is_valid_atom, is_prot, is_na, is_atomized, idx, num_bonds_sequence, dist_matrix, encoder=True):
SE3_SCALE = 10.0 #to do: make this a parameter
if encoder:
layers = self.atom_encoder
else:
layers = self.atom_decoder
for layer in layers:
G, edge = make_atom_graph(
xyz, is_valid_atom, is_prot, is_na, is_atomized, idx, num_bonds_sequence, dist_matrix,
)
node = state[is_valid_atom]
node_l1 = xyz[is_valid_atom].unsqueeze(-2) #torch.ones((node.shape[0], 3,3), device=state.device, dtype=state.dtype)
edge = self.proj_edge(edge).unsqueeze(-1)
shift = layer(G, node[..., None], node_l1, edge)
xyz[is_valid_atom] = xyz[is_valid_atom] + shift["1"].squeeze(1)/SE3_SCALE
if encoder:
state[is_valid_atom] = state[is_valid_atom] + shift["0"][...,0]/SE3_SCALE
else:
xyz = xyz+shift["0"].sum() #fd hack to avoid unused grads (even though shift['0'] is of dim 0)
return state, xyz

View File

@@ -7,23 +7,24 @@ from einops import rearrange
from rf2aa.util_module import init_lecun_normal
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, r_ff, p_drop=0.1):
def __init__(self, d_model, r_ff, p_drop=0.1, zero_init=True):
super(FeedForwardLayer, self).__init__()
self.norm = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, d_model*r_ff)
self.dropout = nn.Dropout(p_drop)
self.linear2 = nn.Linear(d_model*r_ff, d_model)
self.reset_parameter()
self.reset_parameter(zero_init)
def reset_parameter(self):
def reset_parameter(self,zero_init):
# initialize linear layer right before ReLu: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear1.bias)
# initialize linear layer right before residual connection: zero initialize
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
if zero_init:
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
def forward(self, src):
src = self.norm(src)

View File

@@ -77,6 +77,7 @@ class MSA_emb(nn.Module):
# pair embedding
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix, same_chain=same_chain,cyclize=cyclize)
# state embedding
state = self._state_emb(seq)
return msa, pair, state
@@ -522,6 +523,21 @@ class Templ_emb_NoPtwise(nn.Module):
return pair, state
class RecyclingMSAPairOnly(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=0):
super(RecyclingMSAPairOnly, self).__init__()
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_pair = nn.Linear(d_pair, d_pair, bias=False)
self.norm_msa = nn.LayerNorm(d_msa)
self.proj_msa = nn.Linear(d_msa, d_msa, bias=False)
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = msa.shape[:2]
msa = msa + self.proj_msa(self.norm_msa(msa))
pair = pair + self.proj_pair(self.norm_pair(pair))
return msa, pair, state # state is dummy
class Recycling(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
super(Recycling, self).__init__()
@@ -593,6 +609,7 @@ class RecyclingAllFeatures(nn.Module):
return msa, pair, state
recycling_factory = {
"msa_pair_only": RecyclingMSAPairOnly,
"msa_pair": Recycling,
"all": RecyclingAllFeatures
}

View File

@@ -27,7 +27,7 @@ class SE3TransformerWrapper(nn.Module):
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32,
compute_gradients=False):
compute_gradients=False, **kwargs):
super().__init__()
# Build the network
self.l1_in = l1_in_features
@@ -63,36 +63,16 @@ class SE3TransformerWrapper(nn.Module):
populate_edge="arcsin",
final_layer="lin",
use_layer_norm=True,
compute_gradients=compute_gradients
compute_gradients=compute_gradients,
)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
if self.l1_out > 0:
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
pass
#nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
#if self.l1_out > 0:
# nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
@@ -282,7 +262,6 @@ class FullyConnectedSE3(FullyConnectedSE3_noR):
def compute_structure_update(self, G, node, l1_feats, edge_feats, xyz, state, is_atom, drop_layer=False, is_motif=None):
weight = 0. if drop_layer else 1.
B, L = node.shape[:2]
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
@@ -320,6 +299,8 @@ class FullyConnectedSE3(FullyConnectedSE3_noR):
quat_update = torch.stack([qA, qB, qC, qD], dim=2)
return state, xyz, quat_update
def forward(self, msa, pair, state, xyz, is_atom, atom_frames, chirals,idx, bond_feats, dist_matrix, drop_layer=False, is_motif=None):
block_outputs = super().forward(msa, pair, state, xyz, is_atom, atom_frames, chirals,idx, bond_feats, dist_matrix, is_motif=is_motif)

View File

@@ -36,7 +36,7 @@ class RosettaFold(nn.Module):
self.simulator = nn.ModuleList(blocks)
n_refinement_blocks = len(model_params.refinement.keys())
assert n_refinement_blocks <= 1, "only can have one refinment block"
assert n_refinement_blocks <= 1, "only can have one refinement block"
self.refinement = None
if n_refinement_blocks == 1:
refinement_type = next(iter(model_params.refinement.keys()))
@@ -60,7 +60,7 @@ class RosettaFold(nn.Module):
for aux_task in model_params.auxiliary_predictors.keys()
}
def forward(self, rf_inputs, use_checkpoint, use_amp):
def forward(self, rf_inputs, use_checkpoint, use_amp, skip_refinement=False):
latent_feats = self.embedding(rf_inputs)
#load useful primitives into latent_features
latent_feats.update(
@@ -73,20 +73,24 @@ class RosettaFold(nn.Module):
"bond_feats": rf_inputs["bond_feats"],
"dist_matrix": rf_inputs["dist_matrix"],
"is_motif": rf_inputs.get("is_motif", None),
"seq_unmasked": rf_inputs["seq_unmasked"],
"trans_t": rf_inputs.get("trans_t", None),
"t": rf_inputs.get("t", None),
}
)
for block in self.simulator:
latent_feats = block(latent_feats, use_checkpoint, use_amp)
rf_outputs = {}
if self.refinement:
rf_outputs = self.refinement(latent_feats)
for aux_task, aux_predictor in self.auxiliary_predictors.items():
input_feature = self.auxiliary_predictor_input_feats[aux_task]
auxiliary_predictions = aux_predictor(latent_feats[input_feature])
rf_outputs.update({aux_task: auxiliary_predictions})
if not skip_refinement:
if self.refinement:
rf_outputs = self.refinement(latent_feats)
for aux_task, aux_predictor in self.auxiliary_predictors.items():
input_feature = self.auxiliary_predictor_input_feats[aux_task]
auxiliary_predictions = aux_predictor(latent_feats[input_feature])
rf_outputs.update({aux_task: auxiliary_predictions})
return rf_outputs, latent_feats

View File

@@ -2,11 +2,13 @@ import torch
import torch.nn as nn
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, get_backbone_offset_vectors, get_chiral_vectors, SE3TransformerWrapper
from rf2aa.model.Track_module import Str2Str
from rf2aa.util_module import rbf, make_topk_graph, init_lecun_normal
from rf2aa.util_module import rbf, make_topk_graph, make_full_graph, init_lecun_normal
from rf2aa.util import is_atom,is_nucleic,is_protein
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.loss.loss import calc_chiral_grads
from rf2aa.model.generative_refinement import GenerativeRefinement
class LocalRefinementSE3(FullyConnectedSE3):
@@ -58,7 +60,11 @@ class LocalRefinementSE3(FullyConnectedSE3):
def construct_graph(self, xyz, edge):
L = xyz.shape[1]
idx = torch.arange(L, device=edge.device)[None]
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
if self.top_k>0:
G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=self.top_k)
else:
G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx) # top_k=1 -> fully connected
return G, edge_feats
class RecurrentLocalRefinement(nn.Module):
@@ -75,7 +81,7 @@ class RecurrentLocalRefinement(nn.Module):
latent_feats["state"], latent_feats["xyz"], latent_feats["is_atom"], \
latent_feats["atom_frames"], latent_feats["chirals"]
idx, bond_feats, dist_matrix = latent_feats["idx"], latent_feats["bond_feats"], latent_feats["dist_matrix"]
return msa, pair, state, xyz, is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
def forward(self, latent_feats):
B, N, L = latent_feats["msa"].shape[:3]
@@ -110,7 +116,7 @@ class RecurrentLocalRefinement_w_Adaptor(nn.Module):
latent_feats["state"], latent_feats["xyz"], latent_feats["is_atom"], \
latent_feats["atom_frames"], latent_feats["chirals"]
idx, bond_feats, dist_matrix = latent_feats["idx"], latent_feats["bond_feats"], latent_feats["dist_matrix"]
return msa, pair, state, xyz, is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
return msa, pair, state, xyz[..., :3, :], is_atom, atom_frames, chirals, idx, bond_feats, dist_matrix
def forward(self, latent_feats):
B, N, L = latent_feats["msa"].shape[:3]
@@ -199,5 +205,6 @@ class LegacyRefiner(nn.Module):
refinement_factory ={
"local": RecurrentLocalRefinement,
"local_adaptor": RecurrentLocalRefinement_w_Adaptor,
"legacy": LegacyRefiner
"legacy": LegacyRefiner,
"generative": GenerativeRefinement
}

View File

@@ -5,6 +5,7 @@ from functools import partial
import numpy as np
from rf2aa.debug import debug_nans
from rf2aa.model.AF3_blocks import AF3_block,AF3_full_block
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
@@ -287,11 +288,13 @@ class RF2_untied(RF2_block):
block_factory = {
"RF2aa": partial(RF2_block, is_full=False),
"RF2aa_full": partial(RF2_block, is_full=True),
"untied_p2p": partial(RF2_untied, is_full=False),
"untied_p2p_full": partial(RF2_untied, is_full=True),
"AF3": AF3_block,
"AF3_full": AF3_full_block,
"RF2aa": partial(RF2_block, is_full=False),
"RF2aa_full": partial(RF2_block, is_full=True),
"untied_p2p": partial(RF2_untied, is_full=False),
"untied_p2p_full": partial(RF2_untied, is_full=True),
"RF2_withgradients": partial(RF2_withgradients, is_full=False),
"RF2_withgradients_full": partial(RF2_withgradients, is_full=True),
"RF2_withgradients_R": partial(RF2_block, is_full=False, backprop_through_xyz=True),
"RF2_withgradients_R": partial(RF2_block, is_full=False, backprop_through_xyz=True),
}

View File

@@ -0,0 +1,61 @@
import pytest
import torch
from rf2aa.model.AF3_blocks import MsaSubsampleEmbedder
def test_msa_module():
pass
def test_msa_subsampler():
B, N, L = 1, 100, 20
params = {
"num_sequences": 256,
"msa_dim": 20,
"msa_channels": 64,
"S_dim": 32
}
msa_SI = torch.rand(B, N, L, 20)
S_inputs = torch.rand(B, L, 32)
subsampler = MsaSubsampleEmbedder(params)
msa_SI = subsampler(msa_SI, S_inputs)
assert msa_SI.shape == (B, N, L, 64)
B, N, L = 1, 500, 20
params = {
"num_sequences": 256,
"msa_dim": 20,
"msa_channels": 64,
"S_dim": 32
}
msa_SI = torch.rand(B, N, L, 20)
S_inputs = torch.rand(B, L, 32)
subsampler = MsaSubsampleEmbedder(params)
msa_SI = subsampler(msa_SI, S_inputs)
assert msa_SI.shape == (B, 256, L, 64)
def test_msa_pair_weighted_average():
pass
def test_msa_weighting_einsum():
B, I, S, H, c = 1, 5, 10, 8, 4
gate_SIH = torch.randn(B, S, I, H, c)
w_IIH = torch.randn(B, I, I, H)
v_SIH = torch.randn(B, S, I, H, c)
# Initialize the result tensor
C = torch.zeros((B, S, I, H, c))
# Perform the einsum contraction in smaller steps
#for idx_b in range(B):
#for idx_s in range(S):
#for idx_i in range(I):
#for idx_h in range(H):
#for idx_c in range(c):
#C[idx_b, idx_s, idx_i, idx_h, idx_c] = torch.sum(
#v_SIH[idx_b, idx_s, :, idx_h, idx_c] * w_IIH[idx_b, :, idx_i, idx_h]
#)
unaggregated_weights = torch.einsum("bsihc, bijh -> bsijhc", v_SIH, w_IIH)
weights = torch.einsum("bsihc, biih -> bsihc", v_SIH, w_IIH)
o_SIH = gate_SIH * weights

View File

@@ -21,7 +21,7 @@ from functools import partial
# goal is to test all the configs on a broad set of datasets
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
test_conditions, test_ids = setup_benchmark_array(["pdb196"], ["rf2aa","rf2_deep_layerdropout"])
test_conditions, test_ids = setup_benchmark_array(["pdb196"], ["rf2aa","rf2_deep_layerdropout","af3"])
def setup_test(example, trainer):
model = trainer.model
@@ -76,7 +76,7 @@ def test_benchmark_fw_bw(benchmark, example, trainer):
def run():
output_i = recycle_step_packed(trainer.model, network_input, 1, trainer.config.training_params.use_amp, nograds=False, force_device=gpu)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(dataloader_inputs, device=gpu)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, _, _ = get_loss_calc_items(dataloader_inputs, device=gpu)
loss, loss_dict = get_loss_and_misc(
trainer,

View File

@@ -24,7 +24,7 @@ sm_compl_asmb_item = {'CHAINID': '4i7z_D', 'DEPOSITION': '2012-12-01', 'RESOLUTI
benchmark196_item = {'Unnamed: 0': 489, 'CHAINID': '4y7y_BA', 'DEPOSITION': '2015-02-16', 'RESOLUTION': 2.4, 'HASH': '105747', 'CLUSTER': 10756, 'SEQUENCE': 'TSIMAVTFKDGVILGADSRTTTGAYIANRVTDKLTRVHDKIWCCRSGSAADTQAIADIVQYHLELYTSQYGTPSTETAASVFKELCYENKDNLTAGIIVAGYDDKNKGEVYTIPLGGSVHKLPYAIAGSGSTFIYGYCDKNFRENMSKEETVDFIKHSLSQAIKWDGSSGGVIRMVVLTAAGVERLIFYPDEYEQL', 'LEN_EXIST': 196, 'TAXID': ''}
# configurations to test
configs = ["rf2aa", "rf2_deep_layerdropout", "legacy_train", "rf2aa_legacy_refinement", "rf_with_gradients", "untied_p2p"]
configs = ["rf2aa", "rf2_deep_layerdropout", "legacy_train", "rf2aa_legacy_refinement", "rf_with_gradients", "untied_p2p", "af3"]
datasets = {
"pdb" : (loader_pdb, pdb_item, {"homo": {"CHAIN_A": pd.Series(dtype=np.float32)}}),
"compl" : (loader_complex, compl_item, {}),

View File

@@ -1,5 +1,6 @@
import os
import torch
import hydra
from torch.nn.parallel import DistributedDataParallel as DDP
import pytest
@@ -16,10 +17,10 @@ from rf2aa.data.compose_dataset import compose_single_item_dataset
# goal is to test all the configs on a broad set of datasets
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
test_conditions, test_ids = setup_benchmark_array(["pdb256"], ["rf_with_gradients"], device=gpu)
#test_conditions, test_ids = setup_benchmark_array(["pdb256"], ["rf_with_gradients"], device=gpu)
def setup_test(example, trainer):
model = trainer.model
#model = trainer.model
config = trainer.config.chem_params
# initialize chemical database
@@ -28,14 +29,16 @@ def setup_test(example, trainer):
# to GPU
trainer.move_constants_to_device(gpu)
model = model.to(gpu)
trainer.construct_model(device=gpu)
#model = model.to(gpu)
dataset_name = example[0]
item, loader_params, _, loader, loader_kwargs = example[1:]
#HACK: reduce crop size
loader_params["CROP"] = 200
loader_params["MAXCYCLE"] = 10
loader_params["CROP"] = 100
loader_params["MAXCYCLE"] = 4
# read from disk, move to device
dataloader = compose_single_item_dataset( None, item, loader_params, loader, loader_kwargs)
return dataloader
@@ -49,7 +52,7 @@ def test_minimize_example(example, trainer):
dataloader = setup_test(example, trainer)
trainer.train_loader = dataloader
trainer.model = DDP(trainer.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
#trainer.model = DDP(trainer.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
trainer.construct_optimizer()
trainer.construct_scheduler()
@@ -64,9 +67,18 @@ def test_minimize_example(example, trainer):
for file in glob("models/*"):
os.remove(file)
if __name__ == "__main__":
example, trainer = test_conditions[0]
@hydra.main(config_path="../config/train", config_name="base")
def main(config):
from rf2aa.trainer_new import trainer_factory
trainer = trainer_factory[config.experiment.trainer](config)
example = data["pdb"]
os.environ['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
os.environ['MASTER_PORT'] = '%d'%trainer.config.ddp_params.port
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
test_minimize_example(example, trainer)
test_minimize_example(example, trainer)
if __name__ == "__main__":
from rf2aa.tests.test_conditions import setup_data
data = setup_data()
main()

View File

@@ -10,19 +10,20 @@ import time
import omegaconf
from contextlib import nullcontext
import datetime
from datetime import timedelta
import certifi
import warnings
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
from rf2aa.flow_matching.interpolant import Interpolant
from rf2aa.flow_matching.sampler import Sampler
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads
from rf2aa.training.EMA import EMA, count_parameters
from rf2aa.loss.loss import translation_vector_field
from rf2aa.loss.loss_factory import get_loss_and_misc
from rf2aa.training.optimizer import add_weight_decay
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_sampling, run_model_forward
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward
from rf2aa.model.network import RosettaFold
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
@@ -93,8 +94,24 @@ class Trainer:
def load_model(self):
torch.cuda.empty_cache()
self.model.module.model.load_state_dict(self.checkpoint["final_state_dict"], strict=True)
self.model.module.shadow.load_state_dict(self.checkpoint["model_state_dict"], strict=False)
new_model_state = {}
new_shadow_state = {}
state_dict = self.model.module.model.state_dict()
for param in state_dict:
if param not in self.checkpoint['model_state_dict']:
print ('missing',param)
elif (self.checkpoint['model_state_dict'][param].shape == state_dict[param].shape):
new_model_state[param] = self.checkpoint['final_state_dict'][param]
new_shadow_state[param] = self.checkpoint['model_state_dict'][param]
else:
print (
'wrong size',param,
self.checkpoint['model_state_dict'][param].shape,
state_dict[param].shape )
self.model.module.model.load_state_dict(new_model_state, strict=False)
self.model.module.shadow.load_state_dict(new_shadow_state, strict=False)
print("Checkpoint loaded into model")
def load_optimizer(self):
@@ -179,7 +196,7 @@ class Trainer:
def init_process_group(self, rank, world_size):
gpu = rank % torch.cuda.device_count()
dist.init_process_group(backend=self.config.training_params.ddp_backend, world_size=world_size, rank=rank)
dist.init_process_group(backend=self.config.training_params.ddp_backend, timeout=timedelta(seconds=1800), world_size=world_size, rank=rank)
torch.cuda.set_device("cuda:%d"%gpu)
return gpu
@@ -226,7 +243,6 @@ class Trainer:
self.move_constants_to_device(gpu)
self.construct_model(device=gpu)
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
if rank == 0:
print(f"Loading model with {count_parameters(self.model)} parameters")
@@ -249,6 +265,11 @@ class Trainer:
self.config.experiment.n_epoch,
self.config.dataset_params.n_train,
world_size)
#for d, valid_sampler in valid_samplers.items():
# valid_sampler.set_epoch(start_epoch-1)
#self.valid_epoch(start_epoch-1, rank, world_size)
for epoch in range(start_epoch,self.config.experiment.n_epoch):
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
self.train_epoch(epoch, rank, world_size)
@@ -279,8 +300,27 @@ class Trainer:
# aggregate loss and update parameters
loss = loss / self.config.ddp_params.accum
if (torch.any(torch.isnan(loss))):
print ('NAN in loss',inputs[-1])
print ('NAN in loss',loss_dict)
exit(1)
self.scaler.scale(loss).backward()
hasnans=False
for n,p in self.model.named_parameters():
if (p.grad is not None and torch.any(torch.isnan(p.grad))):
hasnans = True
if hasnans:
print ('NAN in grad')
for n,p in self.model.named_parameters():
if (p.grad is not None):
print (n, torch.max( torch.abs(p.flatten()) ), torch.max( torch.abs(p.grad.flatten()) ))
exit(1)
train_time = time.time() - start_time
if train_idx%self.config.ddp_params.accum == 0:
self.update_parameters()
@@ -345,6 +385,7 @@ class Trainer:
def update_parameters(self):
""" scale, clip gradients and update parameters """
# gradient clipping
#debug_grads(self.model)
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training_params.grad_clip)
self.scaler.step(self.optimizer)
@@ -390,6 +431,8 @@ class LegacyTrainer(Trainer):
).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
""" take an input from dataloader, run the model and compute a loss """
@@ -401,7 +444,7 @@ class LegacyTrainer(Trainer):
= prepare_input(inputs, self.xyz_converter, gpu)
output_i = recycle_step_legacy(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=nograds)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames = get_loss_calc_items(inputs, device=gpu)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, _, _ = get_loss_calc_items(inputs, device=gpu)
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
@@ -410,7 +453,8 @@ class LegacyTrainer(Trainer):
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, None, None,
unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
if return_outputs:
@@ -429,6 +473,7 @@ class ComposedTrainer(Trainer):
self.model = RosettaFold(self.config).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
""" take an input from dataloader, run the model and compute a loss """
@@ -452,7 +497,8 @@ class ComposedTrainer(Trainer):
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, None, None,
unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
@@ -468,7 +514,8 @@ class FlowMatchingTrainer(Trainer):
self.model = RosettaFold(self.config).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
self.sampler = Sampler(self.model,
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
self.sampler = AllAtomSampler(self.model,
self.config.interpolant.sampling.num_timesteps,
self.config.interpolant.min_t,
self.interpolant,
@@ -482,11 +529,47 @@ class FlowMatchingTrainer(Trainer):
def train_step(self, inputs, n_cycle, no_grads=False, return_outputs=False):
gpu = self.model.device
#try:
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input_fm(inputs, self.interpolant, self.xyz_converter, gpu)
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label, \
r3_t, trans_1, mask_allatom \
= prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, gpu)
output_i = recycle_step_packed(self.model, network_input, 1, self.config.training_params.use_amp, nograds=no_grads)
output_i = recycle_step_gen(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=no_grads)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask = get_loss_calc_items(inputs, device=gpu)
logit_s, logit_aa_s, logit_pae, logit_pde, p_bind, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = output_i
#loss = (pred_allatom - true_crds).mean()
#loss_dict = {"loss": loss.mean()}
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
msa = msa.to(gpu)
mask_msa = mask_msa.to(gpu)
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
if return_outputs:
return loss, loss_dict, output_i
else:
return loss, loss_dict
def valid_step(self, inputs, n_cycle, no_grads=True, return_outputs=False):
gpu = self.model.device
#try:
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label, \
r3_t, trans_1, mask_allatom \
= prepare_input_fm_allatom(inputs, self.interpolant, self.xyz_converter, gpu)
#output_i = recycle_step_gen(self.model, network_input, n_cycle, self.config.training_params.use_amp, nograds=no_grads)
with torch.no_grad():
output_i = self.sampler.sample(inputs, n_cycle=n_cycle, use_amp=self.config.training_params.use_amp)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask = get_loss_calc_items(inputs, device=gpu)
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
@@ -497,37 +580,14 @@ class FlowMatchingTrainer(Trainer):
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, trans_1, r3_t,
unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
if return_outputs:
return loss, loss_dict, output_i
else:
return loss, loss_dict
def valid_step(self, inputs, n_cycle, no_grads=True, return_outputs=False):
gpu = self.interpolant._device
self.sampler.model = self.model
task, item, network_input, true_crds, \
atom_mask, msa, mask_msa, unclamp, negative, symmRs, Lasu, ch_label \
= prepare_input_fm(inputs, self.interpolant, self.xyz_converter, gpu)
output_i = self.sampler.sample(inputs)
seq, same_chain, idx_pdb, bond_feats, dist_matrix, atom_frames, true_crds, atom_mask \
= get_loss_calc_items(inputs, device=gpu)
#HACK: indexing into msa and mask msa recycle dimension in arguments of this function
#HACK: need to promote some inputs to gpu for loss calculation, all promotions should happen together
msa = msa.to(gpu)
mask_msa = mask_msa.to(gpu)
loss, loss_dict = get_loss_and_misc(
self, # avoid reloading constants to device
output_i, true_crds, atom_mask, same_chain,
seq, msa[:, n_cycle-1], mask_msa[:, n_cycle-1], idx_pdb, bond_feats, dist_matrix, atom_frames, unclamp, negative, task, item, symmRs, Lasu, ch_label,
self.config.loss_param
)
#fd last layer l0 are unused in grads
#fd to do: fix this in refinement module
loss += 0.0*output_i[-1].sum()
if return_outputs:
return loss, loss_dict, output_i
@@ -543,7 +603,8 @@ class FlowMatchingTrainer(Trainer):
for valid_idx, inputs in enumerate(valid_loader):
n_cycle = self.config.loader_params.maxcycle
loss, loss_dict = self.valid_step(inputs, n_cycle)
loss, loss_dict = self.valid_step(inputs, n_cycle)
#print (loss_dict)
if valid_loss_dict is None:
valid_loss_dict = torch.zeros_like(torch.stack(list(loss_dict.values())))

View File

@@ -18,7 +18,6 @@ def recycle_step_legacy(ddp_model, input, n_cycle, use_amp, nograds=False, force
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
#stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
if i_cycle < n_cycle -1 or nograds is True:
stack.enter_context(torch.no_grad())
stack.enter_context(ddp_model.no_sync())
@@ -41,7 +40,6 @@ def recycle_step_packed(ddp_model, input, n_cycle, use_amp, nograds=False, force
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
#stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
if i_cycle < n_cycle -1 or nograds is True:
stack.enter_context(torch.no_grad())
stack.enter_context(ddp_model.no_sync())
@@ -54,6 +52,30 @@ def recycle_step_packed(ddp_model, input, n_cycle, use_amp, nograds=False, force
return output_i
def recycle_step_gen(ddp_model, input, n_cycle, use_amp, nograds=False, force_device=None):
""" exactly same logic as legacy recycling, except inputs and outputs are dictionaries"""
if force_device is not None:
gpu = force_device
else:
gpu = ddp_model.device
xyz_prev, alpha_prev, mask_recycle = \
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
if i_cycle < n_cycle -1 or nograds is True:
stack.enter_context(torch.no_grad())
stack.enter_context(ddp_model.no_sync())
return_raw = (i_cycle < n_cycle -1)
use_checkpoint = not nograds and (i_cycle == n_cycle -1)
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
rf_outputs, rf_latents = ddp_model(input_i, use_checkpoint=use_checkpoint, use_amp=use_amp, skip_refinement=return_raw)
output_i = unpack_outputs(rf_outputs, rf_latents, return_raw)
return output_i
def run_model_forward(model, network_input, use_checkpoint=False, device="cpu"):
""" run model forward pass, no recycling, no ddp (for tests) """
gpu = device
@@ -90,16 +112,38 @@ def unpack_outputs(rf_outputs, rf_latents, return_raw):
msa, pair, state = rf_latents["msa"], rf_latents["pair"], rf_latents["state"]
if return_raw:
xyz_prev = rf_outputs["xyzs"][-1][None]
alpha_prev = rf_outputs["alphas"][-1]
xyz_prev, alpha_prev = None, None
if "xyzs" in rf_outputs:
xyz_prev = rf_outputs["xyzs"][-1][None]
if "alphas" in rf_outputs:
alpha_prev = rf_outputs["alphas"][-1]
return msa[:, 0], pair, xyz_prev, alpha_prev, None # mask_recycle is always None
else:
c6d_logits, mlm_logits, pae_logits, plddt_logits = rf_outputs["c6d"], rf_outputs["mlm"], \
rf_outputs["pae"], rf_outputs["plddt"]
c6d_logits, mlm_logits, pae_logits, plddt_logits = None,None,None,None
if "c6d" in rf_outputs:
c6d_logits = rf_outputs["c6d"]
if "mlm" in rf_outputs:
mlm_logits = rf_outputs["mlm"]
if "pae" in rf_outputs:
pae_logits = rf_outputs["pae"]
if "plddt" in rf_outputs:
plddt_logits = rf_outputs["plddt"]
pde_logits = None
p_bind = None
xyz, alphas = rf_outputs["xyzs"], rf_outputs["alphas"]
if "xyzs" in rf_outputs:
xyz = rf_outputs["xyzs"]
else:
xyz = rf_outputs["xyz"][..., :3, :][None]
if "state" in rf_outputs:
state = rf_outputs["state"]
B = 1
L = xyz.shape[2]
if "alphas" in rf_outputs:
alphas = rf_outputs["alphas"]
else:
alphas = torch.zeros((1, B, L, ChemData().NTOTALDOFS, 2), device=xyz.device)
if "xyz_intermediate" in rf_latents:
intermediate_xyzs = torch.stack(rf_latents["xyz_intermediate"], dim=0)
xyz = torch.cat((intermediate_xyzs, xyz), dim=0)
@@ -107,8 +151,8 @@ def unpack_outputs(rf_outputs, rf_latents, return_raw):
if "alpha_intermediate" in rf_latents:
alpha_intermediate = torch.stack(rf_latents["alpha_intermediate"], dim=0)
alphas = torch.cat((alpha_intermediate, alphas), dim=0)
xyz_allatom = None
# HACK: breaking change for all atom flow matching
xyz_allatom = rf_outputs["xyz"]
return (c6d_logits, mlm_logits, pae_logits, pde_logits, p_bind,
xyz, alphas, xyz_allatom, plddt_logits, msa[:, 0], pair, state)

View File

@@ -1469,7 +1469,8 @@ def get_residue_contacts(xyz, idx, seq_dist_greater_than=10, n_contacts=5):
kk = k
for j in nodes:
ds[j,kk] = ds[kk,j] = 999.
nodes.append(kk)
if kk>=0:
nodes.append(kk)
nodes,_ = torch.sort(torch.stack(nodes))
return idx[nodes]

View File

@@ -268,7 +268,7 @@ def make_full_graph(xyz, pair, idx):
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-4):
@@ -323,48 +323,8 @@ def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True,
return G, pair[b,i,j][...,None]
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
B,L,A = xyz.shape[:3]
device = xyz.device
D = torch.norm(
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
)
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
D[~mask2d] = 9999.
D[D==0] = 9999.
# select top K neighbors for each atom
# keep indices as batch/res/atm indices
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
bi,ri,ai = mask.nonzero(as_tuple=True)
bi = bi[:,None].repeat(1,top_k).reshape(-1)
ri = ri[:,None].repeat(1,top_k).reshape(-1)
ai = ai[:,None].repeat(1,top_k).reshape(-1)
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
edge = torch.full(ri.shape, maxbonds, device=device)
resmask = ri==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
resmask = ri+1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
resmask = ri-1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
edge = edge.clamp(0,maxbonds-1)
edge = F.one_hot(edge)[...,None]
natm = torch.sum(mask)
index = torch.zeros_like(mask, dtype=torch.long, device=device)
index[mask] = torch.arange(natm, device=device)
src=index[bi,ri,ai]
tgt=index[bi,rj,aj]
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
return G, edge
# rotate about the x axis