Code cleaning suggestions from arogozhnikov

This commit is contained in:
Joseph Watson
2023-04-26 11:26:17 -07:00
committed by Joseph Watson/Watchwell
parent 5c6f2f1b14
commit 0d629aa672
18 changed files with 35 additions and 681 deletions

View File

@@ -64,8 +64,6 @@ model:
l1_in_features: 3
l1_out_features: 2
num_edge_features: 64
d_time_emb: null
d_time_emb_proj: null
freeze_track_motif: False
use_motif_timestep: False

View File

@@ -40,7 +40,6 @@ RUN apt-get -q update \
decorator==5.1.0 \
hydra-core==1.3.2 \
pyrsistent==0.19.3 \
icecream==2.1.3 \
/app/RFdiffusion/env/SE3Transformer \
&& pip install --no-cache-dir /app/RFdiffusion --no-deps
@@ -48,4 +47,4 @@ WORKDIR /app/RFdiffusion
ENV DGLBACKEND="pytorch"
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]

1
env/SE3nv.yml vendored
View File

@@ -12,7 +12,6 @@ dependencies:
- torchvision
- cudatoolkit=11.1
- dgl-cuda11.1
- icecream
- pip
- pip:
- hydra-core

View File

@@ -31,7 +31,7 @@ class FeedForwardLayer(nn.Module):
class Attention(nn.Module):
# calculate multi-head attention
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
super(Attention, self).__init__()
self.h = n_head
self.dim = d_hidden

View File

@@ -34,7 +34,7 @@ class DistanceNetwork(nn.Module):
return logits_dist, logits_omega, logits_theta, logits_phi
class MaskedTokenNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
def __init__(self, n_feat):
super(MaskedTokenNetwork, self).__init__()
self.proj = nn.Linear(n_feat, 21)

View File

@@ -12,82 +12,6 @@ import math
# Module contains classes and functions to generate initial embeddings
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """
    Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Code adapted from
    https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    Parameters:
        timesteps (torch.Tensor, required): 1-D tensor of timestep values
        embedding_dim (int, required): Width of each returned embedding
        max_positions (int, optional): Controls the spacing of the sinusoid frequencies

    Returns:
        torch.Tensor of shape (len(timesteps), embedding_dim). The first half of
        each row holds sines, the second half cosines; an odd embedding_dim is
        zero padded by one column.
    """
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # half_dim - 1 is zero for embedding_dim in (2, 3); clamp the divisor so a
    # single-frequency embedding is produced instead of a ZeroDivisionError.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    emb = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        # Spelled out in full so the block does not rely on a module-level `F` alias.
        emb = torch.nn.functional.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
class Timestep_emb(nn.Module):
    """
    Looks up a precomputed sinusoidal timestep embedding per residue and
    projects it through a small MLP.

    Row 0 of the precomputed table is used for motif positions; rows 1..T are
    used for the current diffusion timestep.
    """

    def __init__(self, input_size, output_size, T, use_motif_timestep=True):
        """
        Parameters:
            input_size (int, required): Width of the raw sinusoidal embedding
            output_size (int, required): Width of the projected embedding
            T (int, required): Total number of timesteps
            use_motif_timestep (bool, optional): Accepted but not read by this class
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.T = T

        # Table of embeddings for every t in [0, T]; index 0 is reserved for the motif
        table = get_timestep_embedding(torch.arange(self.T+1), self.input_size)
        self.source_embeddings = table
        self.source_embeddings.requires_grad = False

        # Projection applied to the stacked per-residue embeddings
        projection_layers = [
            nn.Linear(input_size, output_size, bias=False),
            nn.ReLU(),
            nn.Linear(output_size, output_size, bias=True),
            nn.LayerNorm(output_size),
        ]
        self.node_embedder = nn.Sequential(*projection_layers)

    def get_init_emb(self, t, L, motif_mask):
        """
        Stack the unprojected timestep embedding for every residue.

        Parameters:
            t (required): Current timestep, 1-indexed; `.squeeze()` is called on it,
                          so it must be a tensor-like object
            L (int, required): Length of protein
            motif_mask (torch.tensor, required): Boolean mask where True denotes a fixed motif position

        Returns:
            Tensor with L rows; motif rows hold the t=0 embedding
        """
        assert t > 0, 't should be 1-indexed and cant have t=0'

        device = motif_mask.device
        current_emb = torch.clone(self.source_embeddings[t.squeeze()]).to(device)
        motif_emb = torch.clone(self.source_embeddings[0]).to(device)

        # same embedding repeated for every residue ...
        timestep_embedding = torch.stack([current_emb for _ in range(L)])

        # ... then overwritten with the zero-timestep embedding on the motif
        timestep_embedding[motif_mask] = motif_emb

        return timestep_embedding

    def forward(self, L, t, motif_mask):
        """
        Construct the per-residue timestep embedding and project it.
        """
        return self.node_embedder(self.get_init_emb(t, L, motif_mask))
class PositionalEncoding2D(nn.Module):
# Add relative positional encoding to pair features
def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
@@ -194,14 +118,6 @@ class Extra_emb(nn.Module):
# Sergey's one hot trick
seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
"""
#TODO delete this once verified
if self.input_seq_onehot:
# Sergey's one hot trick
seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
else:
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
"""
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
return self.drop(msa)
@@ -278,14 +194,14 @@ class Templ_emb(nn.Module):
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
d_hidden=d_hidden, p_drop=p_drop)
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)
# process torsion angles
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
self.proj_t1d = nn.Linear(d_templ, d_templ)
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
# d_hidden=d_hidden, p_drop=p_drop)
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)
self.reset_parameter()

View File

@@ -1,11 +1,10 @@
import torch
import torch.nn as nn
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling, Timestep_emb
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
from rfdiffusion.Track_module import IterativeSimulator
from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
from opt_einsum import contract as einsum
class RoseTTAFoldModule(nn.Module):
def __init__(self,
n_extra_block,
@@ -23,8 +22,6 @@ class RoseTTAFoldModule(nn.Module):
p_drop,
d_t1d,
d_t2d,
d_time_emb, # total dims for input timestep emb
d_time_emb_proj, # size of projected timestep emb
T, # total timesteps (used in timestep emb
use_motif_timestep, # Whether to have a distinct emb for motif
freeze_track_motif, # Whether to freeze updates to motif in track
@@ -47,15 +44,6 @@ class RoseTTAFoldModule(nn.Module):
n_head=n_head_templ,
d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
# timestep embedder
if d_time_emb:
print('NOTE: Using sinusoidal timestep embeddings of dim ',d_time_emb, ' projected to dim ',d_time_emb_proj)
assert d_t1d >= 22 + d_time_emb_proj, 'timestep projection size doesn\'t fit into RF t1d projection layers'
self.timestep_embedder = Timestep_emb(input_size=d_time_emb,
output_size=d_time_emb_proj,
T=T,
use_motif_timestep=use_motif_timestep)
# Update inputs with outputs from previous round
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
@@ -72,7 +60,7 @@ class RoseTTAFoldModule(nn.Module):
p_drop=p_drop)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa)
self.lddt_pred = LDDTNetwork(d_state)
self.exp_pred = ExpResolvedNetwork(d_msa, d_state)

View File

@@ -1,23 +1,7 @@
import numpy as np
import scipy
import scipy.spatial
# calculate dihedral angles defined by 4 sets of points
def get_dihedrals(a, b, c, d):
    """
    Compute the dihedral angle defined by four sets of points.

    Parameters:
        a, b, c, d (np.ndarray): Arrays of identical shape [..., 3]. Any number
            of leading dimensions is accepted (the original implementation only
            handled 2-D [N, 3] inputs).

    Returns:
        np.ndarray of shape [...] with the dihedrals in radians, in (-pi, pi].
    """
    b0 = a - b
    b1 = c - b
    b2 = d - c

    # Out-of-place division so integer inputs do not fail on in-place true division.
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)

    # Components of b0 and b2 perpendicular to the central bond b1.
    v = b0 - np.sum(b0*b1, axis=-1, keepdims=True)*b1
    w = b2 - np.sum(b2*b1, axis=-1, keepdims=True)*b1

    x = np.sum(v*w, axis=-1)
    y = np.sum(np.cross(b1, v)*w, axis=-1)

    return np.arctan2(y, x)
from rfdiffusion.kinematics import get_dih
# calculate planar angles defined by 3 sets of points
def get_angles(a, b, c):
@@ -65,11 +49,10 @@ def get_coords6d(xyz, dmax):
# matrix of Ca-Cb-Cb-Ca dihedrals
omega6d = np.zeros((nres, nres), dtype=np.float32)
omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
omega6d[idx0,idx1] = get_dih(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
# matrix of polar coord theta
theta6d = np.zeros((nres, nres), dtype=np.float32)
theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
theta6d[idx0,idx1] = get_dih(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
# matrix of polar coord phi
phi6d = np.zeros((nres, nres), dtype=np.float32)

View File

@@ -1,255 +0,0 @@
import torch
import numpy as np
import random
from rfdiffusion.chemical import INIT_CRDS
from icecream import ic
def th_min_angle(start, end, radians=False):
    """
    Return the signed angle to add to <start> to reach <end> along the
    shortest path around the circle.

    Parameters:
        start, end: Angles (scalars or arrays/tensors supporting `-` and `%`)
        radians (bool, optional): Interpret the angles as radians instead of degrees
    """
    if radians:
        half_turn, full_turn, one_and_half = np.pi, 2*np.pi, 3*np.pi
    else:
        half_turn, full_turn, one_and_half = 180, 360, 540

    wrapped = (end - start) % full_turn
    return ((wrapped + one_and_half) % full_turn) - half_turn
def th_interpolate_angles(start, end, T, n_diffuse, mindiff=None, radians=True):
    """
    Linearly interpolate each angle from <start> towards <end> along the
    shortest angular path, producing a length-T trajectory per angle.

    Parameters:
        start, end: 1-D tensors of angles, indexed per element
        T (int, required): Length of every returned trajectory
        n_diffuse: Per-angle number of interpolation steps; None means T steps
                   for every angle. Entries are expected to be <= T.
        mindiff (optional): If given, asserted to sum to the same total as the
                            computed shortest-path differences
        radians (bool, optional): Passed through to th_min_angle

    Returns:
        Tensor of shape (len(start), T). Row i interpolates over its first
        n_diffuse[i] entries and holds the final angle for the remainder.
    """
    # find the minimum angle to add to get from start to end
    angle_diffs = th_min_angle(start, end, radians=radians)
    if mindiff is not None:
        assert torch.sum(mindiff.flatten()-angle_diffs) == 0.

    if n_diffuse is None:
        # default is to diffuse for max steps
        # (size must be a tuple; the original passed a bare parenthesised int)
        n_diffuse = torch.full((len(angle_diffs),), T)

    interps = []
    for i, diff in enumerate(angle_diffs):
        N = int(n_diffuse[i])
        final_angle = start[i] + diff

        # Pre-fill with the final angle, then overwrite the interpolated prefix.
        whole_interp = torch.full((T,), float(final_angle))
        whole_interp[:N] = torch.linspace(start[i], final_angle, N)

        interps.append(whole_interp)

    return torch.stack(interps, dim=0)
def th_interpolate_angle_single(start, end, step, T, mindiff=None, radians=True):
    """
    Return the angle at a single <step> of a linear interpolation from
    <start> towards <end> along the shortest angular path.

    The interpolation runs over x = [0, T-1] with y = [start, start + diff].
    """
    # shortest signed difference from start to end
    angle_diffs = th_min_angle(start, end, radians=radians)
    if mindiff is not None:
        assert torch.sum(mindiff.flatten()-angle_diffs) == 0.

    last_step = T-1
    return step / last_step * angle_diffs + start
def get_aa_schedule(T, L, nsteps=100):
    """
    Returns the steps t when each amino acid should be decoded,
    as well as how many steps that amino acids chi angles will be diffused

    Parameters:
        T (int, required): Total number of steps we are decoding the sequence over
        L (int, required): Length of protein sequence
        nsteps (int, optional): Number of steps over the course of which to decode the amino acids

    Returns: four items
        decode_times (np.array): Times t (1..nsteps) when the positions in <decode_order> should be decoded
        decode_order (list): List of lists, each element containing which positions are going to be decoded at
                             the corresponding time in <decode_times>
        idx2diffusion_steps (np.array): Array mapping the index of the residue to how many diffusion steps it will require
        aa_masks (np.array): Boolean array of shape (max(200, nsteps+1), L); entry [t, pos] is True
                             while position pos has not yet been decoded at row t
    """
    # nsteps can't be more than T or more than length of protein
    if (nsteps > T) or (nsteps > L):
        nsteps = min([T, L])

    decode_order = [[a] for a in range(L)]
    random.shuffle(decode_order)

    while len(decode_order) > nsteps:
        # pop an element and then add those positions randomly to some other step
        tmp_seqpos = decode_order.pop()
        decode_order[random.randint(0, len(decode_order)-1)] += tmp_seqpos
        random.shuffle(decode_order)

    decode_times = np.arange(nsteps)+1

    # The mask is indexed by t in 1..nsteps. Keep the historical 200 rows, but
    # grow when nsteps would otherwise index past the end (IndexError before).
    n_rows = max(200, nsteps + 1)
    aa_masks = np.full((n_rows, L), False)
    idx2diffusion_steps = np.full((L,), float(np.nan))

    for t, decode_pos in zip(decode_times, decode_order):
        # decode_pos: positions to be decoded at step t
        for pos in decode_pos:
            idx2diffusion_steps[pos] = int(t)
            aa_masks[t, pos] = True

    # After the cumulative sum a position is non-zero from its decode row onward.
    aa_masks = np.cumsum(aa_masks, axis=0)

    return decode_times, decode_order, idx2diffusion_steps, ~(aa_masks.astype(bool))
####################
### for SecStruc ###
####################
def ss_to_tensor(ss_dict):
    """
    Convert the secondary-structure string stored under ss_dict['ss'] into a
    numpy array of integer codes.

    Encoding convention:
    0 = Helix
    1 = Strand
    2 = Loop
    3 = Mask/unknown
    4 = idx for pdb
    Only 'H', 'E' and 'L' are translated here; any other character raises KeyError.
    """
    ss_conv = {'H':0,'E':1,'L':2}
    codes = []
    for letter in ss_dict['ss']:
        codes.append(int(ss_conv[letter]))
    return np.array(codes)
def mask_ss(ss, min_mask = 0, max_mask = 0.75):
    """
    Take an ss array, find the junctions between secondary-structure blocks,
    and randomly mask windows around them until a random proportion
    (drawn uniformly from [min_mask, max_mask]) is masked, or 100 windows
    have been tried.

    Input: numpy array of ss (H=0,E=1,L=2,mask=3). The input is not modified.
    Output: (ss, mask)
        ss   - one-hot tensor of shape (len(ss), 4) with masked positions set to class 3
        mask - boolean tensor, True where the position is masked
    """
    # Work on a copy so the caller's array is left untouched.
    ss = np.array(ss, copy=True)
    n_res = len(ss)

    mask_prop = random.uniform(min_mask, max_mask)
    transitions = np.where(ss[:-1] - ss[1:] != 0)[0] #gets last index of each block of ss

    # Without any junction there is nothing to centre a window on. The original
    # relied on a bare `except:` swallowing the resulting IndexError 100 times.
    if n_res > 0 and len(transitions) > 0:
        counter = 0
        #TODO think about masking whole ss elements
        while np.count_nonzero(ss == 3)/n_res < mask_prop and counter < 100:
            width = random.randint(1,9)
            start = random.choice(transitions)
            offset = random.randint(-8,1)
            # Clip the window at the N-terminus; a negative slice bound would
            # otherwise wrap around and mask residues at the far end.
            lo = max(start+offset, 0)
            hi = max(start+offset+width, 0)
            ss[lo:hi] = 3
            counter += 1

    ss = torch.tensor(ss).long()
    mask = torch.where(ss == 3, True, False)
    ss = torch.nn.functional.one_hot(ss, num_classes=4)
    return ss, mask
def construct_block_adj_matrix( sstruct, xyz, nan_mask, cutoff=6, include_loops=False ):
    '''
    Given a sstruct specification and backbone coordinates, build a block adjacency matrix.

    Input:

        sstruct (torch.FloatTensor): (L) length tensor with numeric encoding of sstruct at each position

        xyz (torch.FloatTensor): (L,3,3) tensor of Cartesian coordinates of backbone N,Ca,C atoms

        nan_mask: index/mask applied to xyz to select the positions that line up with sstruct

        cutoff (float): The Cb distance cutoff under which residue pairs are considered adjacent
                        By eye, Nate thinks 6A is a good Cb distance cutoff

        include_loops (bool): When False, segments whose code equals 2 are ignored

    Output:

        block_adj (torch.FloatTensor): (L,L) matrix where adjacent secondary structure contacts are 1
    '''
    # Drop masked-out positions first, as ss doesn't consider nans
    coords = xyz[nan_mask]
    L = coords.shape[0]
    assert L == sstruct.shape[0]

    # Rebuild Cb from the three anchor atoms (N, Ca, C) and measure Cb-Cb distances
    Cb = generate_Cbeta(coords[:,0], coords[:,1], coords[:,2])
    dist = get_pair_dist(Cb,Cb) # [L,L]
    dist[torch.isnan(dist)] = 999.9
    assert torch.sum(torch.isnan(dist)) == 0

    # Push the diagonal far beyond any cutoff so a residue is never adjacent to itself
    dist += 999.9*torch.eye(L,device=xyz.device)

    # First: list every run of identical ss codes as (code, begin, end), end exclusive
    n_pos = sstruct.shape[0]
    segments = []
    seg_begin = 0
    for pos in range(1, n_pos):
        if not sstruct[pos] == sstruct[pos-1]:
            segments.append( (sstruct[pos-1], seg_begin, pos) )
            seg_begin = pos

    # The final run is never closed inside the loop, so it is always added here
    segments.append( (sstruct[-1], seg_begin, n_pos) )

    # Second: mark every pair of segments that has at least one residue pair under the cutoff
    block_adj = torch.zeros_like(dist)
    for i, (code_i, begin_i, end_i) in enumerate(segments):
        if code_i == 2 and not include_loops:
            continue

        for code_j, begin_j, end_j in segments[i+1:]:
            if code_j == 2 and not include_loops:
                continue

            if torch.any( dist[begin_i:end_i, begin_j:end_j] < cutoff ):
                # Matrix is symmetric
                block_adj[begin_i:end_i, begin_j:end_j] = 1.0
                block_adj[begin_j:end_j, begin_i:end_i] = 1.0

    return block_adj
def get_pair_dist(a, b):
    """compute Euclidean distances between every point in a and every point in b

    Parameters
    ----------
        a,b : pytorch tensors of shape [batch,nres,3]
              store Cartesian coordinates of two sets of atoms
    Returns
    -------
        dist : pytorch tensor of shape [batch,nres,nres]
               stores pairwise distances between atoms in a and b
    """
    return torch.cdist(a, b, p=2)

View File

@@ -237,8 +237,7 @@ class IGSO3:
num_sigma=self.num_sigma,
min_sigma=self.min_sigma,
max_sigma=self.max_sigma,
num_omega=self.num_omega,
L=L,
num_omega=self.num_omega
)
write_pkl(cache_fname, igso3_vals)

View File

@@ -71,7 +71,7 @@ def igso3_score(R, t, L=L_default):
unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None]
return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None]
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma, L=L_default):
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma):
"""calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
and score norms and expected squared score norms.

View File

@@ -247,7 +247,6 @@ class Sampler:
'L': L,
'diffuser': self.diffuser,
'potential_manager': self.potential_manager,
'visible': visible
})
return iu.Denoise(**denoise_kwargs)

View File

@@ -223,8 +223,6 @@ class Denoise:
T,
L,
diffuser,
visible,
seq_diffuser=None,
b_0=0.001,
b_T=0.1,
min_b=1.0,
@@ -256,7 +254,6 @@ class Denoise:
self.T = T
self.L = L
self.diffuser = diffuser
self.seq_diffuser = seq_diffuser
self.b_0 = b_0
self.b_T = b_T
self.noise_level = noise_level
@@ -301,8 +298,6 @@ class Denoise:
Third, centre at origin
"""
# if True:
# return px0
def rmsd(V, W, eps=0):
# First sum down atoms, then sum down xyz
N = V.shape[-2]
@@ -358,17 +353,12 @@ class Denoise:
px0[~atom_mask] = 0 # convert nans to 0
px0 = px0.reshape(-1, 3) - px0_motif_mean
px0_ = px0 @ R
# xT_motif_out = xT_motif.reshape(-1,3)
# xT_motif_out = (xT_motif_out @ R ) + px0_motif_mean
# ic(xT_motif_out.shape)
# xT_motif_out = xT_motif_out.reshape((diffusion_mask.sum(),3,3))
# 3 put in same global position as xT
px0_ = px0_ + xT_motif_mean
px0_ = px0_.reshape([L, n_atom, 3])
px0_[~atom_mask] = float("nan")
return torch.Tensor(px0_)
# return torch.tensor(xT_motif_out)
def get_potential_gradients(self, xyz, diffusion_mask):
"""

View File

@@ -1,6 +1,7 @@
import numpy as np
import torch
from rfdiffusion.chemical import INIT_CRDS
from rfdiffusion.util import generate_Cbeta
PARAMS = {
"DMIN" : 2.0,
@@ -55,13 +56,18 @@ def get_dih(a, b, c, d):
Parameters
----------
a,b,c,d : pytorch tensors of shape [batch,nres,3]
a,b,c,d : pytorch tensors or numpy array of shape [batch,nres,3]
store Cartesian coordinates of four sets of atoms
Returns
-------
dih : pytorch tensor of shape [batch,nres]
dih : pytorch tensor or numpy array of shape [batch,nres]
stores resulting dihedrals
"""
convert_to_torch = lambda *arrays: [torch.from_numpy(arr) for arr in arrays]
output_np=False
if isinstance(a, np.ndarray):
output_np=True
a,b,c,d = convert_to_torch(a,b,c,d)
b0 = a - b
b1 = c - b
b2 = d - c
@@ -73,18 +79,10 @@ def get_dih(a, b, c, d):
x = torch.sum(v*w, dim=-1)
y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1)
return torch.atan2(y, x)
def get_Cb(xyz):
    '''Place a virtual Cb from the N, Ca and C atoms stored at indices 0, 1, 2
    of the second-to-last axis of xyz.'''
    N, Ca, C = xyz[...,0,:], xyz[...,1,:], xyz[...,2,:]

    n_to_ca = Ca - N
    ca_to_c = C - Ca
    normal = torch.cross(n_to_ca, ca_to_c, dim=-1)

    return -0.58273431*normal + 0.56802827*n_to_ca - 0.54067466*ca_to_c + Ca
output = torch.atan2(y, x)
if output_np:
return output.numpy()
return output
# ============================================================
def xyz_to_c6d(xyz, params=PARAMS):
@@ -108,7 +106,7 @@ def xyz_to_c6d(xyz, params=PARAMS):
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
Cb = get_Cb(xyz)
Cb = generate_Cbeta(N, Ca, C)
# 6d coordinates order: (dist,omega,theta,phi)
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)

View File

@@ -113,9 +113,6 @@ class PotentialManager:
if setting['type'] in potentials.require_binderlen:
setting.update(binderlen_update)
if setting['type'] in potentials.require_hotspot_res:
setting.update(hotspot_res_update)
self.potentials_to_apply = self.initialize_all_potentials(setting_list)
self.T = diffuser_config.T
@@ -199,7 +196,7 @@ class PotentialManager:
# Linear interpolation with y2: 0, y1: guide_scale, x2: 0, x1: T, x: t
'linear' : lambda t: t/self.T * self.guide_scale,
'quadratic' : lambda t: t**2/self.T**2 * self.guide_scale,
'cubic' : lambda t: t**3/self.T**3
'cubic' : lambda t: t**3/self.T**3 * self.guide_scale
}
if self.guide_decay not in implemented_decay_types:

View File

@@ -146,52 +146,6 @@ class binder_ncontacts(Potential):
#Potential value is the average of both radii of gyration (is avg. the best way to do this?)
return self.weight * binder_ncontacts.sum()
class dimer_ncontacts(Potential):
    '''
    Differentiable way to maximise number of contacts for two individual monomers in a dimer

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html

    Author: PV
    '''

    def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
        self.binderlen = binderlen
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0

    def _ncontacts(self, Ca):
        '''Summed switching-function value over all Ca-Ca pairs of one monomer.'''
        #cdist needs a batch dimension - NRB
        batched = Ca[None,...].contiguous()
        dgram = torch.cdist(batched, Ca[None,...].contiguous(), p=2) # [1,Lm,Lm]
        divide_by_r_0 = (dgram - self.d_0) / self.r_0
        numerator = torch.pow(divide_by_r_0,6)
        denominator = torch.pow(divide_by_r_0,12)
        #Potential is the sum of values in the tensor
        return ((1 - numerator) / (1 - denominator)).sum()

    def compute(self, xyz):
        # First monomer: Ca atoms of the first binderlen residues
        binder_ncontacts = self._ncontacts(xyz[:self.binderlen,1])

        # Second monomer: Ca atoms of everything after the binder
        target_ncontacts = self._ncontacts(xyz[self.binderlen:,1])

        average = (binder_ncontacts+target_ncontacts)/2
        print("DIMER NCONTACTS:", average)

        #Returns average of n contacts within monomer 1 and monomer 2
        return self.weight * average
class interface_ncontacts(Potential):
'''
@@ -266,42 +220,6 @@ class monomer_contacts(Potential):
return self.weight * ncontacts.sum()
def make_contact_matrix(nchain, contact_string=None):
    """
    Calculate a matrix of inter/intra chain contact indicators

    Parameters:
        nchain (int, required): How many chains are in this design

        contact_string (str, optional): String denoting how to define contacts, comma delimited between pairs of chains
            '!' denotes repulsive, '&' denotes attractive.
            Each entry is three characters: chain letter, symbol, chain letter (e.g. 'A&B,B!C').
            When None, no contacts are defined and an all-zero matrix is returned.

    Returns:
        np.ndarray of shape (nchain, nchain): symmetric, 1 for attractive pairs,
        -1 for repulsive pairs, 0 elsewhere.

    Raises:
        SyntaxError: if an entry is not exactly three characters long
        AssertionError: if the middle character is neither '!' nor '&'
    """
    letter2num = {a:i for i,a in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}

    contacts = np.zeros((nchain,nchain))

    # The default of None used to crash on `.split`; treat it as "no contacts".
    if contact_string is None:
        return contacts

    for c in contact_string.split(','):
        if not len(c) == 3:
            raise SyntaxError('Invalid contact(s) specification')

        i,j = letter2num[c[0]],letter2num[c[2]]
        symbol = c[1]

        # denote contacting/repulsive
        assert symbol in ['!','&']
        value = -1 if symbol == '!' else 1

        # the relationship is symmetric between the two chains
        contacts[i,j] = value
        contacts[j,i] = value

    return contacts
class olig_contacts(Potential):
"""
Applies PV's num contacts potential within/between chains in symmetric oligomers
@@ -343,17 +261,6 @@ class olig_contacts(Potential):
self.nchain=shape[0]
# self._compute_chain_indices()
# def _compute_chain_indices(self):
# # make list of shape [i,N] for indices of each chain in total length
# indices = []
# start = 0
# for l in self.chain_lengths:
# indices.append(torch.arange(start,start+l))
# start += l
# self.indices = indices
def _get_idx(self,i,L):
"""
Returns the zero-indexed indices of the residues in chain i
@@ -398,51 +305,6 @@ class olig_contacts(Potential):
return all_contacts
class olig_intra_contacts(Potential):
    """
    Applies PV's num contacts potential for each chain individually in an oligomer design

    Author: DJ
    """

    def __init__(self, chain_lengths, weight=1, r_0=8, d_0=4):
        """
        Parameters:

            chain_lengths (list, required): Ordered list of chain lengths

            weight (int/float, optional): Scaling/weighting factor

            r_0 (int/float, optional): Scale of the switching function

            d_0 (int/float, optional): Distance offset of the switching function
        """
        self.chain_lengths = chain_lengths
        self.weight = weight
        # compute() reads r_0 and d_0, but they were never set, so every call
        # raised AttributeError. Defaults mirror dimer_ncontacts.
        self.r_0 = r_0
        self.d_0 = d_0

    def compute(self, xyz):
        """
        Computes intra-chain num contacts potential

        Raises AssertionError when the chain lengths do not add up to xyz.shape[0].
        """
        assert sum(self.chain_lengths) == xyz.shape[0], 'given chain lengths do not match total protein length'

        all_contacts = 0
        start = 0
        for Lc in self.chain_lengths:
            # NOTE(review): sibling potentials select the Ca atom with `[..., 1]`;
            # this slice keeps every atom of the chain. Confirm what shape of xyz
            # callers pass before changing it.
            Ca = xyz[start:start+Lc] # slice out crds for this chain
            dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb]
            divide_by_r_0 = (dgram - self.d_0) / self.r_0
            numerator = torch.pow(divide_by_r_0,6)
            denominator = torch.pow(divide_by_r_0,12)
            ncontacts = (1 - numerator) / (1 - denominator)

            # add contacts for this chain to all contacts
            all_contacts += ncontacts.sum()

            # increment the start to be at the next chain
            start += Lc

        return self.weight * all_contacts
def get_damped_lj(r_min, r_lin,p1=6,p2=12):
y_at_r_lin = lj(r_lin, r_min, p1, p2)
@@ -592,131 +454,15 @@ class substrate_contacts(Potential):
self.motif_frame = xyz[rand_idx[0],:4]
self.motif_mapping = [(rand_idx, i) for i in range(4)]
class binder_distance_ReLU(Potential):
    '''
    Given the current coordinates of the diffusion trajectory, calculate a potential that is the distance between each residue
    and the closest target residue.

    This potential is meant to encourage the binder to interact with a certain subset of residues on the target that
    define the binding site.

    Author: NRB
    '''

    def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):
        self.binderlen = binderlen
        # hotspot indices are shifted by binderlen so they index into the full xyz
        self.hotspot_res = [res + binderlen for res in hotspot_res]
        self.weight = weight
        self.min_dist = min_dist
        self.use_Cb = use_Cb

    def compute(self, xyz):
        binder = xyz[:self.binderlen,:,:] # (Lb,27,3)
        target = xyz[self.hotspot_res,:,:] # (N,27,3)

        if self.use_Cb:
            # Rebuild Cb from N, Ca, C (atom indices 0, 1, 2) on both sides
            binder_pts = generate_Cbeta(binder[:,0], binder[:,1], binder[:,2]) # (Lb,3)
            target_pts = generate_Cbeta(target[:,0], target[:,1], target[:,2]) # (N,3)
        else:
            # Use Ca dist for potential
            binder_pts = binder[:,1] # (Lb,3)
            target_pts = target[:,1] # (N,3)

        dgram = torch.cdist(binder_pts[None,...], target_pts[None,...], p=2) # (1,Lb,N)

        # distance from each binder residue to its nearest hotspot
        closest_dist = torch.min(dgram.squeeze(0), dim=1)[0] # (Lb)

        # Cap the distance at a minimum value
        floor = self.min_dist * torch.ones_like(closest_dist) # (Lb)
        capped = torch.maximum(floor, closest_dist) # (Lb)

        # torch.Tensor.backward() requires the potential to be a single value
        potential = torch.sum(capped, dim=-1)

        return -1 * self.weight * potential
class binder_any_ReLU(Potential):
    '''
    Given the current coordinates of the diffusion trajectory, calculate a potential that is the minimum distance between
    ANY residue and the closest target residue.

    In contrast to binder_distance_ReLU this potential will only penalize a pose if all of the binder residues are outside
    of a certain distance from the target residues.

    Author: NRB
    '''

    def __init__(self, binderlen, hotspot_res, weight=1, min_dist=15, use_Cb=False):
        self.binderlen = binderlen
        # hotspot indices are shifted by binderlen so they index into the full xyz
        self.hotspot_res = [res + binderlen for res in hotspot_res]
        self.weight = weight
        self.min_dist = min_dist
        self.use_Cb = use_Cb

    def compute(self, xyz):
        binder = xyz[:self.binderlen,:,:] # (Lb,27,3)
        target = xyz[self.hotspot_res,:,:] # (N,27,3)

        # Was `if use_Cb:`, a NameError; the flag lives on the instance.
        if self.use_Cb:
            N = binder[:,0]
            Ca = binder[:,1]
            C = binder[:,2]
            Cb = generate_Cbeta(N,Ca,C) # (Lb,3)

            N_t = target[:,0]
            Ca_t = target[:,1]
            C_t = target[:,2]
            Cb_t = generate_Cbeta(N_t,Ca_t,C_t) # (N,3)

            dgram = torch.cdist(Cb[None,...], Cb_t[None,...], p=2) # (1,Lb,N)

        else:
            # Use Ca dist for potential
            Ca = binder[:,1] # (Lb,3)
            Ca_t = target[:,1] # (N,3)

            dgram = torch.cdist(Ca[None,...], Ca_t[None,...], p=2) # (1,Lb,N)

        # single smallest binder-hotspot distance
        closest_dist = torch.min(dgram.squeeze(0)) # scalar tensor

        # Cap the distance at a minimum value. The original referenced a bare
        # `min_dist` (NameError) and torch.maximum needs tensors on both sides,
        # so build the floor the same way binder_distance_ReLU does.
        min_distance = self.min_dist * torch.ones_like(closest_dist)
        potential = torch.maximum(min_distance, closest_dist) # scalar tensor

        return -1 * self.weight * potential
# Dictionary of types of potentials indexed by name of potential. Used by PotentialManager.
# If you implement a new potential you must add it to this dictionary for it to be used by
# the PotentialManager
implemented_potentials = { 'monomer_ROG': monomer_ROG,
'binder_ROG': binder_ROG,
'binder_distance_ReLU': binder_distance_ReLU,
'binder_any_ReLU': binder_any_ReLU,
'dimer_ROG': dimer_ROG,
'binder_ncontacts': binder_ncontacts,
'dimer_ncontacts': dimer_ncontacts,
'interface_ncontacts': interface_ncontacts,
'monomer_contacts': monomer_contacts,
'olig_intra_contacts': olig_intra_contacts,
'olig_contacts': olig_contacts,
'substrate_contacts': substrate_contacts}
@@ -725,9 +471,5 @@ require_binderlen = { 'binder_ROG',
'binder_any_ReLU',
'dimer_ROG',
'binder_ncontacts',
'dimer_ncontacts',
'interface_ncontacts'}
require_hotspot_res = { 'binder_distance_ReLU',
'binder_any_ReLU' }

View File

@@ -8,9 +8,10 @@ def generate_Cbeta(N, Ca, C):
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
# Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
# These are the values used during training
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
# fd: below matches sidechain generator (=Rosetta params)
Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
# Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca
return Cb

View File

@@ -7,7 +7,7 @@ import copy
import dgl
from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points
def init_lecun_normal(module, scale=1.0):
def init_lecun_normal(module):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
@@ -23,14 +23,14 @@ def init_lecun_normal(module, scale=1.0):
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
def sample_truncated_normal(shape):
stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
return module
def init_lecun_normal_param(weight, scale=1.0):
def init_lecun_normal_param(weight):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
@@ -46,8 +46,8 @@ def init_lecun_normal_param(weight, scale=1.0):
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
def sample_truncated_normal(shape):
stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )