mirror of
https://github.com/RosettaCommons/RFdiffusion.git
synced 2026-06-04 18:44:21 +08:00
Code cleaning suggestions from arogozhnikov
This commit is contained in:
committed by
Joseph Watson/Watchwell
parent
5c6f2f1b14
commit
0d629aa672
@@ -64,8 +64,6 @@ model:
|
||||
l1_in_features: 3
|
||||
l1_out_features: 2
|
||||
num_edge_features: 64
|
||||
d_time_emb: null
|
||||
d_time_emb_proj: null
|
||||
freeze_track_motif: False
|
||||
use_motif_timestep: False
|
||||
|
||||
|
||||
@@ -40,7 +40,6 @@ RUN apt-get -q update \
|
||||
decorator==5.1.0 \
|
||||
hydra-core==1.3.2 \
|
||||
pyrsistent==0.19.3 \
|
||||
icecream==2.1.3 \
|
||||
/app/RFdiffusion/env/SE3Transformer \
|
||||
&& pip install --no-cache-dir /app/RFdiffusion --no-deps
|
||||
|
||||
@@ -48,4 +47,4 @@ WORKDIR /app/RFdiffusion
|
||||
|
||||
ENV DGLBACKEND="pytorch"
|
||||
|
||||
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
|
||||
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
|
||||
|
||||
1
env/SE3nv.yml
vendored
1
env/SE3nv.yml
vendored
@@ -12,7 +12,6 @@ dependencies:
|
||||
- torchvision
|
||||
- cudatoolkit=11.1
|
||||
- dgl-cuda11.1
|
||||
- icecream
|
||||
- pip
|
||||
- pip:
|
||||
- hydra-core
|
||||
|
||||
@@ -31,7 +31,7 @@ class FeedForwardLayer(nn.Module):
|
||||
|
||||
class Attention(nn.Module):
|
||||
# calculate multi-head attention
|
||||
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
|
||||
def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
|
||||
super(Attention, self).__init__()
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
|
||||
@@ -34,7 +34,7 @@ class DistanceNetwork(nn.Module):
|
||||
return logits_dist, logits_omega, logits_theta, logits_phi
|
||||
|
||||
class MaskedTokenNetwork(nn.Module):
|
||||
def __init__(self, n_feat, p_drop=0.1):
|
||||
def __init__(self, n_feat):
|
||||
super(MaskedTokenNetwork, self).__init__()
|
||||
self.proj = nn.Linear(n_feat, 21)
|
||||
|
||||
|
||||
@@ -12,82 +12,6 @@ import math
|
||||
|
||||
# Module contains classes and functions to generate initial embeddings
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
# Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
class Timestep_emb(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
output_size,
|
||||
T,
|
||||
use_motif_timestep=True
|
||||
):
|
||||
super(Timestep_emb, self).__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.T = T
|
||||
|
||||
# get source for timestep embeddings at all t AND zero (for the motif)
|
||||
self.source_embeddings = get_timestep_embedding(torch.arange(self.T+1), self.input_size)
|
||||
self.source_embeddings.requires_grad = False
|
||||
|
||||
# Layers to use for projection
|
||||
self.node_embedder = nn.Sequential(
|
||||
nn.Linear(input_size, output_size, bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Linear(output_size, output_size, bias=True),
|
||||
nn.LayerNorm(output_size),
|
||||
)
|
||||
|
||||
|
||||
def get_init_emb(self, t, L, motif_mask):
|
||||
"""
|
||||
Calculates and stacks a timestep embedding to project
|
||||
|
||||
Parameters:
|
||||
|
||||
t (int, required): Current timestep
|
||||
|
||||
L (int, required): Length of protein
|
||||
|
||||
motif_mask (torch.tensor, required): Boolean mask where True denotes a fixed motif position
|
||||
"""
|
||||
assert t > 0, 't should be 1-indexed and cant have t=0'
|
||||
|
||||
t_emb = torch.clone(self.source_embeddings[t.squeeze()]).to(motif_mask.device)
|
||||
zero_emb = torch.clone(self.source_embeddings[0]).to(motif_mask.device)
|
||||
|
||||
# timestep embedding for all residues
|
||||
timestep_embedding = torch.stack([t_emb]*L)
|
||||
|
||||
# slice in motif zero timestep features
|
||||
timestep_embedding[motif_mask] = zero_emb
|
||||
|
||||
return timestep_embedding
|
||||
|
||||
|
||||
def forward(self, L, t, motif_mask):
|
||||
"""
|
||||
Constructs and projects a timestep embedding
|
||||
"""
|
||||
emb_in = self.get_init_emb(t,L,motif_mask)
|
||||
emb_out = self.node_embedder(emb_in)
|
||||
return emb_out
|
||||
|
||||
class PositionalEncoding2D(nn.Module):
|
||||
# Add relative positional encoding to pair features
|
||||
def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
|
||||
@@ -194,14 +118,6 @@ class Extra_emb(nn.Module):
|
||||
|
||||
# Sergey's one hot trick
|
||||
seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
|
||||
"""
|
||||
#TODO delete this once verified
|
||||
if self.input_seq_onehot:
|
||||
# Sergey's one hot trick
|
||||
seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
|
||||
else:
|
||||
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
|
||||
"""
|
||||
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||
return self.drop(msa)
|
||||
|
||||
@@ -278,14 +194,14 @@ class Templ_emb(nn.Module):
|
||||
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||
d_hidden=d_hidden, p_drop=p_drop)
|
||||
|
||||
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
|
||||
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)
|
||||
|
||||
# process torsion angles
|
||||
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
|
||||
self.proj_t1d = nn.Linear(d_templ, d_templ)
|
||||
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||
# d_hidden=d_hidden, p_drop=p_drop)
|
||||
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
|
||||
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling, Timestep_emb
|
||||
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
|
||||
from rfdiffusion.Track_module import IterativeSimulator
|
||||
from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
|
||||
from opt_einsum import contract as einsum
|
||||
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
|
||||
def __init__(self,
|
||||
n_extra_block,
|
||||
@@ -23,8 +22,6 @@ class RoseTTAFoldModule(nn.Module):
|
||||
p_drop,
|
||||
d_t1d,
|
||||
d_t2d,
|
||||
d_time_emb, # total dims for input timestep emb
|
||||
d_time_emb_proj, # size of projected timestep emb
|
||||
T, # total timesteps (used in timestep emb
|
||||
use_motif_timestep, # Whether to have a distinct emb for motif
|
||||
freeze_track_motif, # Whether to freeze updates to motif in track
|
||||
@@ -47,15 +44,6 @@ class RoseTTAFoldModule(nn.Module):
|
||||
n_head=n_head_templ,
|
||||
d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
|
||||
|
||||
# timestep embedder
|
||||
if d_time_emb:
|
||||
print('NOTE: Using sinusoidal timestep embeddings of dim ',d_time_emb, ' projected to dim ',d_time_emb_proj)
|
||||
assert d_t1d >= 22 + d_time_emb_proj, 'timestep projection size doesn\'t fit into RF t1d projection layers'
|
||||
self.timestep_embedder = Timestep_emb(input_size=d_time_emb,
|
||||
output_size=d_time_emb_proj,
|
||||
T=T,
|
||||
use_motif_timestep=use_motif_timestep)
|
||||
|
||||
|
||||
# Update inputs with outputs from previous round
|
||||
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
@@ -72,7 +60,7 @@ class RoseTTAFoldModule(nn.Module):
|
||||
p_drop=p_drop)
|
||||
##
|
||||
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
|
||||
self.aa_pred = MaskedTokenNetwork(d_msa)
|
||||
self.lddt_pred = LDDTNetwork(d_state)
|
||||
|
||||
self.exp_pred = ExpResolvedNetwork(d_msa, d_state)
|
||||
|
||||
@@ -1,23 +1,7 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.spatial
|
||||
|
||||
# calculate dihedral angles defined by 4 sets of points
|
||||
def get_dihedrals(a, b, c, d):
|
||||
|
||||
b0 = -1.0*(b - a)
|
||||
b1 = c - b
|
||||
b2 = d - c
|
||||
|
||||
b1 /= np.linalg.norm(b1, axis=-1)[:,None]
|
||||
|
||||
v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
|
||||
w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
|
||||
|
||||
x = np.sum(v*w, axis=-1)
|
||||
y = np.sum(np.cross(b1, v)*w, axis=-1)
|
||||
|
||||
return np.arctan2(y, x)
|
||||
from rfdiffusion.kinematics import get_dih
|
||||
|
||||
# calculate planar angles defined by 3 sets of points
|
||||
def get_angles(a, b, c):
|
||||
@@ -65,11 +49,10 @@ def get_coords6d(xyz, dmax):
|
||||
|
||||
# matrix of Ca-Cb-Cb-Ca dihedrals
|
||||
omega6d = np.zeros((nres, nres), dtype=np.float32)
|
||||
omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
|
||||
|
||||
omega6d[idx0,idx1] = get_dih(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
|
||||
# matrix of polar coord theta
|
||||
theta6d = np.zeros((nres, nres), dtype=np.float32)
|
||||
theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
|
||||
theta6d[idx0,idx1] = get_dih(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
|
||||
|
||||
# matrix of polar coord phi
|
||||
phi6d = np.zeros((nres, nres), dtype=np.float32)
|
||||
|
||||
@@ -1,255 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from rfdiffusion.chemical import INIT_CRDS
|
||||
from icecream import ic
|
||||
|
||||
|
||||
def th_min_angle(start, end, radians=False):
|
||||
"""
|
||||
Finds the angle you would add to <start> in order to get to <end>
|
||||
on the shortest path.
|
||||
"""
|
||||
a,b,c = (np.pi, 2*np.pi, 3*np.pi) if radians else (180, 360, 540)
|
||||
shortest_angle = ((((end - start) % b) + c) % b) - a
|
||||
return shortest_angle
|
||||
|
||||
|
||||
def th_interpolate_angles(start, end, T, n_diffuse,mindiff=None, radians=True):
|
||||
"""
|
||||
|
||||
"""
|
||||
# find the minimum angle to add to get from start to end
|
||||
angle_diffs = th_min_angle(start, end, radians=radians)
|
||||
if mindiff is not None:
|
||||
assert torch.sum(mindiff.flatten()-angle_diffs) == 0.
|
||||
if n_diffuse is None:
|
||||
# default is to diffuse for max steps
|
||||
n_diffuse = torch.full((len(angle_diffs)), T)
|
||||
|
||||
|
||||
interps = []
|
||||
for i,diff in enumerate(angle_diffs):
|
||||
N = int(n_diffuse[i])
|
||||
actual_interp = torch.linspace(start[i], start[i]+diff, N)
|
||||
whole_interp = torch.full((T,), float(start[i]+diff))
|
||||
temp=torch.clone(whole_interp)
|
||||
whole_interp[:N] = actual_interp
|
||||
|
||||
interps.append(whole_interp)
|
||||
|
||||
return torch.stack(interps, dim=0)
|
||||
|
||||
|
||||
def th_interpolate_angle_single(start, end, step, T, mindiff=None, radians=True):
|
||||
"""
|
||||
|
||||
"""
|
||||
# find the minimum angle to add to get from start to end
|
||||
angle_diffs = th_min_angle(start, end, radians=radians)
|
||||
if mindiff is not None:
|
||||
assert torch.sum(mindiff.flatten()-angle_diffs) == 0.
|
||||
|
||||
# linearly interpolate between x = [0, T-1], y = [start, start + diff]
|
||||
x_range = T-1
|
||||
interps = step / x_range * angle_diffs + start
|
||||
|
||||
return interps
|
||||
|
||||
|
||||
def get_aa_schedule(T, L, nsteps=100):
|
||||
"""
|
||||
Returns the steps t when each amino acid should be decoded,
|
||||
as well as how many steps that amino acids chi angles will be diffused
|
||||
|
||||
Parameters:
|
||||
T (int, required): Total number of steps we are decoding the sequence over
|
||||
|
||||
L (int, required): Length of protein sequence
|
||||
|
||||
nsteps (int, optional): Number of steps over the course of which to decode the amino acids
|
||||
|
||||
Returns: three items
|
||||
decode_times (list): List of times t when the positions in <decode_order> should be decoded
|
||||
|
||||
decode_order (list): List of lists, each element containing which positions are going to be decoded at
|
||||
the corresponding time in <decode_times>
|
||||
|
||||
idx2diffusion_steps (np.array): Array mapping the index of the residue to how many diffusion steps it will require
|
||||
|
||||
"""
|
||||
# nsteps can't be more than T or more than length of protein
|
||||
if (nsteps > T) or (nsteps > L):
|
||||
nsteps = min([T,L])
|
||||
|
||||
|
||||
decode_order = [[a] for a in range(L)]
|
||||
random.shuffle(decode_order)
|
||||
|
||||
while len(decode_order) > nsteps:
|
||||
# pop an element and then add those positions randomly to some other step
|
||||
tmp_seqpos = decode_order.pop()
|
||||
decode_order[random.randint(0,len(decode_order)-1)] += tmp_seqpos
|
||||
random.shuffle(decode_order)
|
||||
|
||||
decode_times = np.arange(nsteps)+1
|
||||
|
||||
# now given decode times, calculate number of diffusion steps each position gets
|
||||
aa_masks = np.full((200,L), False)
|
||||
|
||||
idx2diffusion_steps = np.full((L,),float(np.nan))
|
||||
for i,t in enumerate(decode_times):
|
||||
decode_pos = decode_order[i] # positions to be decoded at this step
|
||||
|
||||
for j,pos in enumerate(decode_pos):
|
||||
# calculate number of diffusion steps this residue gets
|
||||
idx2diffusion_steps[pos] = int(t)
|
||||
aa_masks[t,pos] = True
|
||||
|
||||
aa_masks = np.cumsum(aa_masks, axis=0)
|
||||
|
||||
return decode_times, decode_order, idx2diffusion_steps, ~(aa_masks.astype(bool))
|
||||
|
||||
####################
|
||||
### for SecStruc ###
|
||||
####################
|
||||
|
||||
def ss_to_tensor(ss_dict):
|
||||
"""
|
||||
Function to convert ss files to indexed tensors
|
||||
0 = Helix
|
||||
1 = Strand
|
||||
2 = Loop
|
||||
3 = Mask/unknown
|
||||
4 = idx for pdb
|
||||
"""
|
||||
ss_conv = {'H':0,'E':1,'L':2}
|
||||
ss_int = np.array([int(ss_conv[i]) for i in ss_dict['ss']])
|
||||
return ss_int
|
||||
|
||||
def mask_ss(ss, min_mask = 0, max_mask = 0.75):
|
||||
"""
|
||||
Function to take ss array, find the junctions, and randomly mask these until a random proportion (up to 75%) is masked
|
||||
Input: numpy array of ss (H=0,E=1,L=2,mask=3)
|
||||
output: tensor with some proportion of junctions masked
|
||||
"""
|
||||
mask_prop = random.uniform(min_mask, max_mask)
|
||||
transitions = np.where(ss[:-1] - ss[1:] != 0)[0] #gets last index of each block of ss
|
||||
counter = 0
|
||||
#TODO think about masking whole ss elements
|
||||
while len(ss[ss == 3])/len(ss) < mask_prop and counter < 100: #very hacky - do better
|
||||
try:
|
||||
width = random.randint(1,9)
|
||||
start = random.choice(transitions)
|
||||
offset = random.randint(-8,1)
|
||||
ss[start+offset:start+offset+width] = 3
|
||||
counter += 1
|
||||
except:
|
||||
counter += 1
|
||||
ss = torch.tensor(ss)
|
||||
mask = torch.where(ss == 3, True, False)
|
||||
ss = torch.nn.functional.one_hot(ss, num_classes=4)
|
||||
return ss, mask
|
||||
|
||||
def construct_block_adj_matrix( sstruct, xyz, nan_mask, cutoff=6, include_loops=False ):
|
||||
'''
|
||||
Given a sstruct specification and backbone coordinates, build a block adjacency matrix.
|
||||
|
||||
Input:
|
||||
|
||||
sstruct (torch.FloatTensor): (L) length tensor with numeric encoding of sstruct at each position
|
||||
|
||||
xyz (torch.FloatTensor): (L,3,3) tensor of Cartesian coordinates of backbone N,Ca,C atoms
|
||||
|
||||
cutoff (float): The Cb distance cutoff under which residue pairs are considered adjacent
|
||||
By eye, Nate thinks 6A is a good Cb distance cutoff
|
||||
|
||||
Output:
|
||||
|
||||
block_adj (torch.FloatTensor): (L,L) boolean matrix where adjacent secondary structure contacts are 1
|
||||
'''
|
||||
|
||||
# Remove nans at this stage, as ss doesn't consider nans
|
||||
xyz_nonan = xyz[nan_mask]
|
||||
L = xyz_nonan.shape[0]
|
||||
assert L == sstruct.shape[0]
|
||||
# three anchor atoms
|
||||
N = xyz_nonan[:,0]
|
||||
Ca = xyz_nonan[:,1]
|
||||
C = xyz_nonan[:,2]
|
||||
|
||||
# recreate Cb given N,Ca,C
|
||||
Cb = generate_Cbeta(N,Ca,C)
|
||||
|
||||
dist = get_pair_dist(Cb,Cb) # [L,L]
|
||||
dist[torch.isnan(dist)] = 999.9
|
||||
assert torch.sum(torch.isnan(dist)) == 0
|
||||
dist += 999.9*torch.eye(L,device=xyz.device)
|
||||
|
||||
# Now we have dist matrix and sstruct specification, turn this into a block adjacency matrix
|
||||
|
||||
# First: Construct a list of segments and the index at which they begin and end
|
||||
in_segment = True
|
||||
segments = []
|
||||
|
||||
begin = -1
|
||||
end = -1
|
||||
# need to expand ss out to size L
|
||||
|
||||
|
||||
for i in range(sstruct.shape[0]):
|
||||
# Starting edge case
|
||||
if i == 0:
|
||||
begin = 0
|
||||
continue
|
||||
|
||||
if not sstruct[i] == sstruct[i-1]:
|
||||
end = i
|
||||
segments.append( (sstruct[i-1], begin, end) )
|
||||
|
||||
begin = i
|
||||
|
||||
# Ending edge case: last segment is length one
|
||||
if not end == sstruct.shape[0]:
|
||||
segments.append( (sstruct[-1], begin, sstruct.shape[0]) )
|
||||
|
||||
# Second: Using segments and dgram, determine adjacent blocks
|
||||
block_adj = torch.zeros_like(dist)
|
||||
for i in range(len(segments)):
|
||||
curr_segment = segments[i]
|
||||
|
||||
if curr_segment[0] == 2 and not include_loops: continue
|
||||
|
||||
begin_i = curr_segment[1]
|
||||
end_i = curr_segment[2]
|
||||
for j in range(i+1, len(segments)):
|
||||
j_segment = segments[j]
|
||||
|
||||
if j_segment[0] == 2 and not include_loops: continue
|
||||
|
||||
begin_j = j_segment[1]
|
||||
end_j = j_segment[2]
|
||||
|
||||
if torch.any( dist[begin_i:end_i, begin_j:end_j] < cutoff ):
|
||||
# Matrix is symmetic
|
||||
block_adj[begin_i:end_i, begin_j:end_j] = torch.ones(end_i - begin_i, end_j - begin_j)
|
||||
block_adj[begin_j:end_j, begin_i:end_i] = torch.ones(end_j - begin_j, end_i - begin_i)
|
||||
|
||||
return block_adj
|
||||
|
||||
def get_pair_dist(a, b):
|
||||
"""calculate pair distances between two sets of points
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of two sets of atoms
|
||||
Returns
|
||||
-------
|
||||
dist : pytorch tensor of shape [batch,nres,nres]
|
||||
stores paitwise distances between atoms in a and b
|
||||
"""
|
||||
|
||||
dist = torch.cdist(a, b, p=2)
|
||||
return dist
|
||||
@@ -237,8 +237,7 @@ class IGSO3:
|
||||
num_sigma=self.num_sigma,
|
||||
min_sigma=self.min_sigma,
|
||||
max_sigma=self.max_sigma,
|
||||
num_omega=self.num_omega,
|
||||
L=L,
|
||||
num_omega=self.num_omega
|
||||
)
|
||||
write_pkl(cache_fname, igso3_vals)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def igso3_score(R, t, L=L_default):
|
||||
unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None]
|
||||
return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None]
|
||||
|
||||
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma, L=L_default):
|
||||
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma):
|
||||
"""calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
|
||||
and score norms and expected squared score norms.
|
||||
|
||||
|
||||
@@ -247,7 +247,6 @@ class Sampler:
|
||||
'L': L,
|
||||
'diffuser': self.diffuser,
|
||||
'potential_manager': self.potential_manager,
|
||||
'visible': visible
|
||||
})
|
||||
return iu.Denoise(**denoise_kwargs)
|
||||
|
||||
|
||||
@@ -223,8 +223,6 @@ class Denoise:
|
||||
T,
|
||||
L,
|
||||
diffuser,
|
||||
visible,
|
||||
seq_diffuser=None,
|
||||
b_0=0.001,
|
||||
b_T=0.1,
|
||||
min_b=1.0,
|
||||
@@ -256,7 +254,6 @@ class Denoise:
|
||||
self.T = T
|
||||
self.L = L
|
||||
self.diffuser = diffuser
|
||||
self.seq_diffuser = seq_diffuser
|
||||
self.b_0 = b_0
|
||||
self.b_T = b_T
|
||||
self.noise_level = noise_level
|
||||
@@ -301,8 +298,6 @@ class Denoise:
|
||||
Third, centre at origin
|
||||
"""
|
||||
|
||||
# if True:
|
||||
# return px0
|
||||
def rmsd(V, W, eps=0):
|
||||
# First sum down atoms, then sum down xyz
|
||||
N = V.shape[-2]
|
||||
@@ -358,17 +353,12 @@ class Denoise:
|
||||
px0[~atom_mask] = 0 # convert nans to 0
|
||||
px0 = px0.reshape(-1, 3) - px0_motif_mean
|
||||
px0_ = px0 @ R
|
||||
# xT_motif_out = xT_motif.reshape(-1,3)
|
||||
# xT_motif_out = (xT_motif_out @ R ) + px0_motif_mean
|
||||
# ic(xT_motif_out.shape)
|
||||
# xT_motif_out = xT_motif_out.reshape((diffusion_mask.sum(),3,3))
|
||||
|
||||
# 3 put in same global position as xT
|
||||
px0_ = px0_ + xT_motif_mean
|
||||
px0_ = px0_.reshape([L, n_atom, 3])
|
||||
px0_[~atom_mask] = float("nan")
|
||||
return torch.Tensor(px0_)
|
||||
# return torch.tensor(xT_motif_out)
|
||||
|
||||
def get_potential_gradients(self, xyz, diffusion_mask):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from rfdiffusion.chemical import INIT_CRDS
|
||||
from rfdiffusion.util import generate_Cbeta
|
||||
|
||||
PARAMS = {
|
||||
"DMIN" : 2.0,
|
||||
@@ -55,13 +56,18 @@ def get_dih(a, b, c, d):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b,c,d : pytorch tensors of shape [batch,nres,3]
|
||||
a,b,c,d : pytorch tensors or numpy array of shape [batch,nres,3]
|
||||
store Cartesian coordinates of four sets of atoms
|
||||
Returns
|
||||
-------
|
||||
dih : pytorch tensor of shape [batch,nres]
|
||||
dih : pytorch tensor or numpy array of shape [batch,nres]
|
||||
stores resulting dihedrals
|
||||
"""
|
||||
convert_to_torch = lambda *arrays: [torch.from_numpy(arr) for arr in arrays]
|
||||
output_np=False
|
||||
if isinstance(a, np.ndarray):
|
||||
output_np=True
|
||||
a,b,c,d = convert_to_torch(a,b,c,d)
|
||||
b0 = a - b
|
||||
b1 = c - b
|
||||
b2 = d - c
|
||||
@@ -73,18 +79,10 @@ def get_dih(a, b, c, d):
|
||||
|
||||
x = torch.sum(v*w, dim=-1)
|
||||
y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1)
|
||||
|
||||
return torch.atan2(y, x)
|
||||
|
||||
def get_Cb(xyz):
|
||||
'''recreate Cb given N,Ca,C'''
|
||||
N = xyz[...,0,:]
|
||||
Ca = xyz[...,1,:]
|
||||
C = xyz[...,2,:]
|
||||
b = Ca - N
|
||||
c = C - Ca
|
||||
a = torch.cross(b, c, dim=-1)
|
||||
return -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||
output = torch.atan2(y, x)
|
||||
if output_np:
|
||||
return output.numpy()
|
||||
return output
|
||||
|
||||
# ============================================================
|
||||
def xyz_to_c6d(xyz, params=PARAMS):
|
||||
@@ -108,7 +106,7 @@ def xyz_to_c6d(xyz, params=PARAMS):
|
||||
N = xyz[:,:,0]
|
||||
Ca = xyz[:,:,1]
|
||||
C = xyz[:,:,2]
|
||||
Cb = get_Cb(xyz)
|
||||
Cb = generate_Cbeta(N, Ca, C)
|
||||
|
||||
# 6d coordinates order: (dist,omega,theta,phi)
|
||||
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
|
||||
|
||||
@@ -113,9 +113,6 @@ class PotentialManager:
|
||||
if setting['type'] in potentials.require_binderlen:
|
||||
setting.update(binderlen_update)
|
||||
|
||||
if setting['type'] in potentials.require_hotspot_res:
|
||||
setting.update(hotspot_res_update)
|
||||
|
||||
self.potentials_to_apply = self.initialize_all_potentials(setting_list)
|
||||
self.T = diffuser_config.T
|
||||
|
||||
@@ -199,7 +196,7 @@ class PotentialManager:
|
||||
# Linear interpolation with y2: 0, y1: guide_scale, x2: 0, x1: T, x: t
|
||||
'linear' : lambda t: t/self.T * self.guide_scale,
|
||||
'quadratic' : lambda t: t**2/self.T**2 * self.guide_scale,
|
||||
'cubic' : lambda t: t**3/self.T**3
|
||||
'cubic' : lambda t: t**3/self.T**3 * self.guide_scale
|
||||
}
|
||||
|
||||
if self.guide_decay not in implemented_decay_types:
|
||||
|
||||
@@ -146,52 +146,6 @@ class binder_ncontacts(Potential):
|
||||
#Potential value is the average of both radii of gyration (is avg. the best way to do this?)
|
||||
return self.weight * binder_ncontacts.sum()
|
||||
|
||||
|
||||
class dimer_ncontacts(Potential):
|
||||
|
||||
'''
|
||||
Differentiable way to maximise number of contacts for two individual monomers in a dimer
|
||||
|
||||
Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
|
||||
|
||||
Author: PV
|
||||
'''
|
||||
|
||||
|
||||
def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
|
||||
|
||||
self.binderlen = binderlen
|
||||
self.r_0 = r_0
|
||||
self.weight = weight
|
||||
self.d_0 = d_0
|
||||
|
||||
def compute(self, xyz):
|
||||
|
||||
# Only look at binder Ca residues
|
||||
Ca = xyz[:self.binderlen,1] # [Lb,3]
|
||||
#cdist needs a batch dimension - NRB
|
||||
dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
|
||||
divide_by_r_0 = (dgram - self.d_0) / self.r_0
|
||||
numerator = torch.pow(divide_by_r_0,6)
|
||||
denominator = torch.pow(divide_by_r_0,12)
|
||||
binder_ncontacts = (1 - numerator) / (1 - denominator)
|
||||
#Potential is the sum of values in the tensor
|
||||
binder_ncontacts = binder_ncontacts.sum()
|
||||
|
||||
# Only look at target Ca residues
|
||||
Ca = xyz[self.binderlen:,1] # [Lb,3]
|
||||
dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
|
||||
divide_by_r_0 = (dgram - self.d_0) / self.r_0
|
||||
numerator = torch.pow(divide_by_r_0,6)
|
||||
denominator = torch.pow(divide_by_r_0,12)
|
||||
target_ncontacts = (1 - numerator) / (1 - denominator)
|
||||
#Potential is the sum of values in the tensor
|
||||
target_ncontacts = target_ncontacts.sum()
|
||||
|
||||
print("DIMER NCONTACTS:", (binder_ncontacts+target_ncontacts)/2)
|
||||
#Returns average of n contacts withiin monomer 1 and monomer 2
|
||||
return self.weight * (binder_ncontacts+target_ncontacts)/2
|
||||
|
||||
class interface_ncontacts(Potential):
|
||||
|
||||
'''
|
||||
@@ -266,42 +220,6 @@ class monomer_contacts(Potential):
|
||||
return self.weight * ncontacts.sum()
|
||||
|
||||
|
||||
def make_contact_matrix(nchain, contact_string=None):
|
||||
"""
|
||||
Calculate a matrix of inter/intra chain contact indicators
|
||||
|
||||
Parameters:
|
||||
nchain (int, required): How many chains are in this design
|
||||
|
||||
contact_str (str, optional): String denoting how to define contacts, comma delimited between pairs of chains
|
||||
'!' denotes repulsive, '&' denotes attractive
|
||||
"""
|
||||
alphabet = [a for a in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ']
|
||||
letter2num = {a:i for i,a in enumerate(alphabet)}
|
||||
|
||||
contacts = np.zeros((nchain,nchain))
|
||||
written = np.zeros((nchain,nchain))
|
||||
|
||||
contact_list = contact_string.split(',')
|
||||
for c in contact_list:
|
||||
if not len(c) == 3:
|
||||
raise SyntaxError('Invalid contact(s) specification')
|
||||
|
||||
i,j = letter2num[c[0]],letter2num[c[2]]
|
||||
symbol = c[1]
|
||||
|
||||
# denote contacting/repulsive
|
||||
assert symbol in ['!','&']
|
||||
if symbol == '!':
|
||||
contacts[i,j] = -1
|
||||
contacts[j,i] = -1
|
||||
else:
|
||||
contacts[i,j] = 1
|
||||
contacts[j,i] = 1
|
||||
|
||||
return contacts
|
||||
|
||||
|
||||
class olig_contacts(Potential):
|
||||
"""
|
||||
Applies PV's num contacts potential within/between chains in symmetric oligomers
|
||||
@@ -343,17 +261,6 @@ class olig_contacts(Potential):
|
||||
self.nchain=shape[0]
|
||||
|
||||
|
||||
# self._compute_chain_indices()
|
||||
|
||||
# def _compute_chain_indices(self):
|
||||
# # make list of shape [i,N] for indices of each chain in total length
|
||||
# indices = []
|
||||
# start = 0
|
||||
# for l in self.chain_lengths:
|
||||
# indices.append(torch.arange(start,start+l))
|
||||
# start += l
|
||||
# self.indices = indices
|
||||
|
||||
def _get_idx(self,i,L):
|
||||
"""
|
||||
Returns the zero-indexed indices of the residues in chain i
|
||||
@@ -398,51 +305,6 @@ class olig_contacts(Potential):
|
||||
|
||||
return all_contacts
|
||||
|
||||
|
||||
class olig_intra_contacts(Potential):
|
||||
"""
|
||||
Applies PV's num contacts potential for each chain individually in an oligomer design
|
||||
|
||||
Author: DJ
|
||||
"""
|
||||
|
||||
def __init__(self, chain_lengths, weight=1):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
chain_lengths (list, required): Ordered list of chain lengths
|
||||
|
||||
weight (int/float, optional): Scaling/weighting factor
|
||||
"""
|
||||
self.chain_lengths = chain_lengths
|
||||
self.weight = weight
|
||||
|
||||
|
||||
def compute(self, xyz):
|
||||
"""
|
||||
Computes intra-chain num contacts potential
|
||||
"""
|
||||
assert sum(self.chain_lengths) == xyz.shape[0], 'given chain lengths do not match total protein length'
|
||||
|
||||
all_contacts = 0
|
||||
start = 0
|
||||
for Lc in self.chain_lengths:
|
||||
Ca = xyz[start:start+Lc] # slice out crds for this chain
|
||||
dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
|
||||
divide_by_r_0 = (dgram - self.d_0) / self.r_0
|
||||
numerator = torch.pow(divide_by_r_0,6)
|
||||
denominator = torch.pow(divide_by_r_0,12)
|
||||
ncontacts = (1 - numerator) / (1 - denominator)
|
||||
|
||||
# add contacts for this chain to all contacts
|
||||
all_contacts += ncontacts.sum()
|
||||
|
||||
# increment the start to be at the next chain
|
||||
start += Lc
|
||||
|
||||
|
||||
return self.weight * all_contacts
|
||||
|
||||
def get_damped_lj(r_min, r_lin,p1=6,p2=12):
|
||||
|
||||
y_at_r_lin = lj(r_lin, r_min, p1, p2)
|
||||
@@ -592,131 +454,15 @@ class substrate_contacts(Potential):
|
||||
self.motif_frame = xyz[rand_idx[0],:4]
|
||||
self.motif_mapping = [(rand_idx, i) for i in range(4)]
|
||||
|
||||
class binder_distance_ReLU(Potential):
|
||||
'''
|
||||
Given the current coordinates of the diffusion trajectory, calculate a potential that is the distance between each residue
|
||||
and the closest target residue.
|
||||
|
||||
This potential is meant to encourage the binder to interact with a certain subset of residues on the target that
|
||||
define the binding site.
|
||||
|
||||
Author: NRB
|
||||
'''
|
||||
|
||||
def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):
|
||||
|
||||
self.binderlen = binderlen
|
||||
self.hotspot_res = [res + binderlen for res in hotspot_res]
|
||||
self.weight = weight
|
||||
self.min_dist = min_dist
|
||||
self.use_Cb = use_Cb
|
||||
|
||||
def compute(self, xyz):
|
||||
binder = xyz[:self.binderlen,:,:] # (Lb,27,3)
|
||||
target = xyz[self.hotspot_res,:,:] # (N,27,3)
|
||||
|
||||
if self.use_Cb:
|
||||
N = binder[:,0]
|
||||
Ca = binder[:,1]
|
||||
C = binder[:,2]
|
||||
|
||||
Cb = generate_Cbeta(N,Ca,C) # (Lb,3)
|
||||
|
||||
N_t = target[:,0]
|
||||
Ca_t = target[:,1]
|
||||
C_t = target[:,2]
|
||||
|
||||
Cb_t = generate_Cbeta(N_t,Ca_t,C_t) # (N,3)
|
||||
|
||||
dgram = torch.cdist(Cb[None,...], Cb_t[None,...], p=2) # (1,Lb,N)
|
||||
|
||||
else:
|
||||
# Use Ca dist for potential
|
||||
|
||||
Ca = binder[:,1] # (Lb,3)
|
||||
|
||||
Ca_t = target[:,1] # (N,3)
|
||||
|
||||
dgram = torch.cdist(Ca[None,...], Ca_t[None,...], p=2) # (1,Lb,N)
|
||||
|
||||
closest_dist = torch.min(dgram.squeeze(0), dim=1)[0] # (Lb)
|
||||
|
||||
# Cap the distance at a minimum value
|
||||
min_distance = self.min_dist * torch.ones_like(closest_dist) # (Lb)
|
||||
potential = torch.maximum(min_distance, closest_dist) # (Lb)
|
||||
|
||||
# torch.Tensor.backward() requires the potential to be a single value
|
||||
potential = torch.sum(potential, dim=-1)
|
||||
|
||||
return -1 * self.weight * potential
|
||||
|
||||
class binder_any_ReLU(Potential):
    '''
    Given the current coordinates of the diffusion trajectory, calculate a potential that is the minimum distance between
    ANY binder residue and the closest hotspot residue on the target.

    In contrast to binder_distance_ReLU, this potential only penalizes a pose if all of the binder residues are further
    than min_dist from the hotspot residues: the single smallest binder-hotspot distance is floored at min_dist.

    Args:
        binderlen:   number of leading entries of xyz that belong to the binder
        hotspot_res: hotspot residue indices; each one is shifted by binderlen
                     before it is used to index xyz (presumably because the
                     target follows the binder in xyz - confirm with caller)
        weight:      multiplier applied to the returned potential
        min_dist:    floor applied to the smallest binder-hotspot distance
        use_Cb:      measure between Cb positions from generate_Cbeta instead
                     of between atom index 1 (Ca)

    Author: NRB
    '''

    def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):

        self.binderlen = binderlen
        self.hotspot_res = [res + binderlen for res in hotspot_res]
        self.weight = weight
        self.min_dist = min_dist
        self.use_Cb = use_Cb

    def compute(self, xyz):
        '''
        Return -weight * max(min_dist, smallest binder-hotspot distance) as a
        single-valued tensor.
        '''
        binder = xyz[:self.binderlen,:,:]  # (Lb,27,3)
        target = xyz[self.hotspot_res,:,:] # (N,27,3)

        # use_Cb is instance state; it must be read through self
        if self.use_Cb:
            N  = binder[:,0]
            Ca = binder[:,1]
            C  = binder[:,2]

            Cb = generate_Cbeta(N,Ca,C) # (Lb,3)

            N_t  = target[:,0]
            Ca_t = target[:,1]
            C_t  = target[:,2]

            Cb_t = generate_Cbeta(N_t,Ca_t,C_t) # (N,3)

            dgram = torch.cdist(Cb[None,...], Cb_t[None,...], p=2) # (1,Lb,N)

        else:
            # Use Ca dist for potential

            Ca = binder[:,1] # (Lb,3)

            Ca_t = target[:,1] # (N,3)

            dgram = torch.cdist(Ca[None,...], Ca_t[None,...], p=2) # (1,Lb,N)

        # Smallest distance over every binder/hotspot pair (single value)
        closest_dist = torch.min(dgram.squeeze(0))

        # torch.maximum only accepts tensors, so lift min_dist to a tensor
        # shaped like closest_dist, as binder_distance_ReLU does
        min_distance = self.min_dist * torch.ones_like(closest_dist)
        potential = torch.maximum(min_distance, closest_dist)

        return -1 * self.weight * potential
|
||||
|
||||
# Registry of potential types, keyed by potential name. The PotentialManager
# uses this dictionary, so a newly implemented potential must be added here
# before the PotentialManager can use it.
implemented_potentials = {
    'monomer_ROG': monomer_ROG,
    'binder_ROG': binder_ROG,
    'binder_distance_ReLU': binder_distance_ReLU,
    'binder_any_ReLU': binder_any_ReLU,
    'dimer_ROG': dimer_ROG,
    'binder_ncontacts': binder_ncontacts,
    'dimer_ncontacts': dimer_ncontacts,
    'interface_ncontacts': interface_ncontacts,
    'monomer_contacts': monomer_contacts,
    'olig_intra_contacts': olig_intra_contacts,
    'olig_contacts': olig_contacts,
    'substrate_contacts': substrate_contacts,
}
|
||||
|
||||
@@ -725,9 +471,5 @@ require_binderlen = { 'binder_ROG',
|
||||
'binder_any_ReLU',
|
||||
'dimer_ROG',
|
||||
'binder_ncontacts',
|
||||
'dimer_ncontacts',
|
||||
'interface_ncontacts'}
|
||||
|
||||
# Names of the potentials that need hotspot_res (both listed potentials take a
# hotspot_res constructor argument); presumably consulted by the
# PotentialManager - confirm against its usage.
require_hotspot_res = {
    'binder_distance_ReLU',
    'binder_any_ReLU',
}
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@ def generate_Cbeta(N, Ca, C):
|
||||
b = Ca - N
|
||||
c = C - Ca
|
||||
a = torch.cross(b, c, dim=-1)
|
||||
# Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||
# These are the values used during training
|
||||
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||
# fd: below matches sidechain generator (=Rosetta params)
|
||||
Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
|
||||
# Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
|
||||
|
||||
return Cb
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import copy
|
||||
import dgl
|
||||
from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points
|
||||
|
||||
def init_lecun_normal(module, scale=1.0):
|
||||
def init_lecun_normal(module):
|
||||
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
|
||||
normal = torch.distributions.normal.Normal(0, 1)
|
||||
|
||||
@@ -23,14 +23,14 @@ def init_lecun_normal(module, scale=1.0):
|
||||
|
||||
return x
|
||||
|
||||
def sample_truncated_normal(shape, scale=1.0):
|
||||
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||
def sample_truncated_normal(shape):
|
||||
stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||
return stddev * truncated_normal(torch.rand(shape))
|
||||
|
||||
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
|
||||
return module
|
||||
|
||||
def init_lecun_normal_param(weight, scale=1.0):
|
||||
def init_lecun_normal_param(weight):
|
||||
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
|
||||
normal = torch.distributions.normal.Normal(0, 1)
|
||||
|
||||
@@ -46,8 +46,8 @@ def init_lecun_normal_param(weight, scale=1.0):
|
||||
|
||||
return x
|
||||
|
||||
def sample_truncated_normal(shape, scale=1.0):
|
||||
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||
def sample_truncated_normal(shape):
|
||||
stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||
return stddev * truncated_normal(torch.rand(shape))
|
||||
|
||||
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
|
||||
|
||||
Reference in New Issue
Block a user