Files
ReQFlow/data/interpolant.py
Angxiao Yue 5bad7f2134 upload code
2025-02-20 17:54:00 +08:00

488 lines
20 KiB
Python

from collections import defaultdict
import torch
from data import so3_utils
from data import utils as du
from scipy.spatial.transform import Rotation
from data import all_atom
import copy
from torch import autograd
from motif_scaffolding import twisting
from openfold.utils.rigid_utils import rot_to_quat
def _centered_gaussian(num_batch, num_res, device):
noise = torch.randn(num_batch, num_res, 3, device=device)
return noise - torch.mean(noise, dim=-2, keepdims=True)
def _uniform_so3(num_batch, num_res, device):
return torch.tensor(
Rotation.random(num_batch*num_res).as_matrix(),
device=device,
dtype=torch.float32,
).reshape(num_batch, num_res, 3, 3)
def _trans_diffuse_mask(trans_t, trans_1, diffuse_mask):
return trans_t * diffuse_mask[..., None] + trans_1 * (1 - diffuse_mask[..., None])
def _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask):
return (
rotmats_t * diffuse_mask[..., None, None]
+ rotmats_1 * (1 - diffuse_mask[..., None, None])
)
def _rots_quats_diffuse_mask(rotquats_t, rotquats_1, diffuse_mask):
return (
rotquats_t * diffuse_mask[..., None]
+ rotquats_1 * (1 - diffuse_mask[..., None])
)
class Interpolant:
    """Flow-matching interpolant over residue frames (translations + rotations).

    Provides training-time corruption (``corrupt_batch`` /
    ``rectify_corrupt_batch``) and Euler-integration sampling (``sample``).
    In the corruption and sampling paths, rotations are handled as quaternions
    (``rotquats``). Rotation-matrix helpers (``_corrupt_rotmats``,
    ``_rots_euler_step``, ``guidance``) are kept alongside them.

    ``set_device`` must be called before any method that allocates tensors,
    because ``self._device`` is not initialised in ``__init__``.

    Args:
        cfg: configuration object exposing ``rots``, ``trans``, ``sampling``
            and ``min_t``. ``sample`` additionally reads ``self_condition``,
            and reads ``twisting`` when motif scaffolding.
    """

    def __init__(self, cfg):
        self._cfg = cfg
        self._rots_cfg = cfg.rots
        self._trans_cfg = cfg.trans
        self._sample_cfg = cfg.sampling
        # Built lazily on first access of ``igso3``.
        self._igso3 = None

    @property
    def igso3(self):
        """Lazily built ``so3_utils.SampleIGSO3`` sampler.

        On first access it is constructed with 1000 sigma values linearly
        spaced in [0.1, 1.5] and ``cache_dir='.cache'``. The same instance is
        then reused.
        """
        if self._igso3 is None:
            sigma_grid = torch.linspace(0.1, 1.5, 1000)
            self._igso3 = so3_utils.SampleIGSO3(
                1000, sigma_grid, cache_dir='.cache')
        return self._igso3

    def set_device(self, device):
        """Set the device on which noise and time tensors are allocated."""
        self._device = device

    def sample_t(self, num_batch, truncate=None):
        """Draw one training time per batch element.

        Args:
            num_batch: number of times to draw.
            truncate: if ``None`` (the default), times are uniform on
                ``[min_t, 1 - min_t)``. Otherwise they are uniform on
                ``[min_t, min_t + truncate)``.

        Returns:
            Tensor of shape ``[num_batch]`` on ``self._device``.
        """
        t = torch.rand(num_batch, device=self._device)
        if truncate is None:
            return t * (1 - 2 * self._cfg.min_t) + self._cfg.min_t
        return t * truncate + self._cfg.min_t

    def _igso3_perturb(self, rotmats_1, num_batch, num_res):
        """Return ``rotmats_1`` right-multiplied by IGSO(3) noise (sigma 1.5).

        ``rotmats_1`` is expected to be ``[num_batch, num_res, 3, 3]``; the
        noise is reshaped to that shape before the batched matmul.
        """
        noisy_rotmats = self.igso3.sample(
            torch.tensor([1.5]),
            num_batch * num_res
        ).to(self._device)
        noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3)
        return torch.einsum("...ij,...jk->...ik", rotmats_1, noisy_rotmats)

    def _corrupt_trans(self, trans_1, t, res_mask, diffuse_mask, trans_0=None):
        """Linearly interpolate translations between noise and data.

        Args:
            trans_1: clean translations, last axis xyz.
            t: times broadcastable as ``t[..., None]`` against ``trans_1``
                (``corrupt_batch`` passes ``[B, 1]``).
            res_mask: residue mask ``[B, N]``; masked residues are zeroed.
            diffuse_mask: residues with mask 0 keep their clean value.
            trans_0: optional noise endpoint. If ``None``, a centred Gaussian
                scaled by ``du.NM_TO_ANG_SCALE`` is drawn.

        Returns:
            ``(1 - t) * trans_0 + t * trans_1`` after masking.
        """
        if trans_0 is None:
            trans_nm_0 = _centered_gaussian(*res_mask.shape, self._device)
            trans_0 = trans_nm_0 * du.NM_TO_ANG_SCALE
        trans_t = (1 - t[..., None]) * trans_0 + t[..., None] * trans_1
        trans_t = _trans_diffuse_mask(trans_t, trans_1, diffuse_mask)
        return trans_t * res_mask[..., None]

    def _corrupt_rotmats(self, rotmats_1, t, res_mask, diffuse_mask, rotmats_0=None):
        """Geodesically interpolate rotation matrices between noise and data.

        If ``rotmats_0`` is ``None``, it is drawn as ``rotmats_1`` perturbed by
        IGSO(3) noise. Residues with ``res_mask == 0`` are set to the identity.
        Residues with ``diffuse_mask == 0`` keep ``rotmats_1``.
        """
        num_batch, num_res = res_mask.shape
        if rotmats_0 is None:
            rotmats_0 = self._igso3_perturb(rotmats_1, num_batch, num_res)
        rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
        identity = torch.eye(3, device=self._device)
        rotmats_t = (
            rotmats_t * res_mask[..., None, None]
            + identity[None, None] * (1 - res_mask[..., None, None])
        )
        return _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask)

    def _corrupt_rotquats(self, rotmats_1, rotquats_1, t, res_mask, diffuse_mask, rotquats_0=None):
        """Interpolate quaternions between noise and data.

        Args:
            rotmats_1: clean rotation matrices ``[B, N, 3, 3]``. Only used to
                build the IGSO(3) noise endpoint when ``rotquats_0`` is ``None``.
            rotquats_1: clean quaternions ``[B, N, 4]``.
            t: times ``[B, 1]``, expanded to ``[B, N]``.
            res_mask: ``[B, N]``; masked residues are set to ``[1, 0, 0, 0]``.
            diffuse_mask: residues with mask 0 keep ``rotquats_1``.
            rotquats_0: optional noise endpoint ``[B, N, 4]``.

        Returns:
            Interpolated quaternions ``[B, N, 4]``.
        """
        num_batch, num_res = res_mask.shape
        if rotquats_0 is None:
            rotmats_0 = self._igso3_perturb(rotmats_1, num_batch, num_res)
            rotquats_0 = rot_to_quat(rotmats_0)  # [B, N, 4]
        rotquats_t = so3_utils.quaternion_slerp_exp(
            t.expand(-1, num_res), rotquats_1, rotquats_0)
        # Float identity so the blend below keeps the quaternions' dtype.
        identity = torch.tensor(
            [1.0, 0.0, 0.0, 0.0], device=self._device, dtype=rotquats_t.dtype)
        rotquats_t = (
            rotquats_t * res_mask[..., None]
            + identity[None] * (1 - res_mask[..., None])
        )
        return _rots_quats_diffuse_mask(rotquats_t, rotquats_1, diffuse_mask)

    def corrupt_batch(self, batch):
        """Corrupt a clean training batch at a random time per structure.

        Reads ``trans_1`` (Angstrom), ``rotmats_1``, ``rotquats_1``,
        ``res_mask`` and ``diffuse_mask`` from ``batch``. Returns a deep copy
        that additionally holds ``so3_t`` and ``r3_t`` (both ``[B, 1]``, same
        values), ``trans_t`` and ``rotquats_t``.

        Raises:
            ValueError: if the corrupted translations or quaternions contain NaN.
        """
        noisy_batch = copy.deepcopy(batch)
        trans_1 = batch['trans_1']  # [B, N, 3], Angstrom
        rotmats_1 = batch['rotmats_1']  # [B, N, 3, 3]
        rotquats_1 = batch['rotquats_1']  # [B, N, 4]
        res_mask = batch['res_mask']  # [B, N]
        diffuse_mask = batch['diffuse_mask']
        num_batch, _ = diffuse_mask.shape

        # The rotation and translation flows share the same time.
        t = self.sample_t(num_batch)[:, None]  # [B, 1]
        so3_t = t
        r3_t = t
        noisy_batch['so3_t'] = so3_t
        noisy_batch['r3_t'] = r3_t

        if self._trans_cfg.corrupt:
            trans_t = self._corrupt_trans(
                trans_1, r3_t, res_mask, diffuse_mask)
        else:
            trans_t = trans_1
        if torch.any(torch.isnan(trans_t)):
            raise ValueError('NaN in trans_t during corruption')
        noisy_batch['trans_t'] = trans_t

        if self._rots_cfg.corrupt:
            rotquats_t = self._corrupt_rotquats(
                rotmats_1, rotquats_1, so3_t, res_mask, diffuse_mask)
        else:
            rotquats_t = rotquats_1
        if torch.any(torch.isnan(rotquats_t)):
            raise ValueError('NaN in rotquats_t during corruption')
        noisy_batch['rotquats_t'] = rotquats_t
        return noisy_batch

    def rectify_corrupt_batch(self, batch):
        """Corrupt a batch using paired (noise, sample) endpoints.

        Unlike ``corrupt_batch``, the t=0 endpoint is not drawn fresh.
        ``batch['noise']['trans_1']`` / ``['rotquats_1']`` serve as the noise
        endpoint, and ``batch['sample']`` supplies the data endpoint plus
        ``rotmats_1``, ``res_mask`` and ``diffuse_mask``.

        Returns:
            A deep copy of ``batch['noise']``. Its ``trans_1``, ``rotmats_1``
            and ``rotquats_1`` are overwritten with the sample's, and it also
            holds ``so3_t``, ``r3_t``, ``trans_t`` and ``rotquats_t``.

        Raises:
            ValueError: if the corrupted translations or quaternions contain NaN.
        """
        noisy_batch = copy.deepcopy(batch['noise'])
        sample = batch['sample']
        trans_1 = sample['trans_1']  # [B, N, 3]
        rotmats_1 = sample['rotmats_1']  # [B, N, 3, 3]
        rotquats_1 = sample['rotquats_1']  # [B, N, 4]
        res_mask = sample['res_mask']  # [B, N]
        diffuse_mask = sample['diffuse_mask']
        num_batch, _ = diffuse_mask.shape

        t = self.sample_t(num_batch)[:, None]  # [B, 1]
        so3_t = t
        r3_t = t
        noisy_batch['so3_t'] = so3_t
        noisy_batch['r3_t'] = r3_t

        if self._trans_cfg.corrupt:
            trans_t = self._corrupt_trans(
                trans_1, r3_t, res_mask, diffuse_mask,
                trans_0=noisy_batch['trans_1'])
        else:
            trans_t = trans_1
        if torch.any(torch.isnan(trans_t)):
            raise ValueError('NaN in trans_t during corruption')

        if self._rots_cfg.corrupt:
            rotquats_t = self._corrupt_rotquats(
                rotmats_1, rotquats_1, so3_t, res_mask, diffuse_mask,
                rotquats_0=noisy_batch['rotquats_1'])
        else:
            rotquats_t = rotquats_1
        if torch.any(torch.isnan(rotquats_t)):
            raise ValueError('NaN in rotquats_t during corruption')

        noisy_batch['trans_1'] = trans_1
        noisy_batch['rotmats_1'] = rotmats_1
        noisy_batch['rotquats_1'] = rotquats_1
        noisy_batch['trans_t'] = trans_t
        noisy_batch['rotquats_t'] = rotquats_t
        return noisy_batch

    def rot_sample_kappa(self, t):
        """Map time to rotation progress under ``rots.sample_schedule``.

        ``'exp'`` gives ``1 - exp(-t * exp_rate)``; ``'linear'`` gives ``t``.

        Raises:
            ValueError: for any other schedule.
        """
        if self._rots_cfg.sample_schedule == 'exp':
            return 1 - torch.exp(-t * self._rots_cfg.exp_rate)
        elif self._rots_cfg.sample_schedule == 'linear':
            return t
        else:
            raise ValueError(
                f'Invalid schedule: {self._rots_cfg.sample_schedule}')

    def _trans_vector_field(self, t, trans_1, trans_t):
        """Linear-path velocity pointing from ``trans_t`` to ``trans_1``."""
        return (trans_1 - trans_t) / (1 - t)

    def _trans_euler_step(self, d_t, t, trans_1, trans_t):
        """Take one forward Euler step of size ``d_t`` (> 0) on translations."""
        assert d_t > 0
        trans_vf = self._trans_vector_field(t, trans_1, trans_t)
        return trans_t + trans_vf * d_t

    def _rots_schedule_scaling(self, t):
        """Return the rotation step scaling for ``rots.sample_schedule``.

        ``'linear'`` gives ``1 / (1 - t)``; ``'exp'`` gives the constant
        ``exp_rate``.

        Raises:
            ValueError: for any other schedule.
        """
        schedule = self._rots_cfg.sample_schedule
        if schedule == 'linear':
            return 1 / (1 - t)
        if schedule == 'exp':
            return self._rots_cfg.exp_rate
        raise ValueError(
            f'Unknown sample schedule {self._rots_cfg.sample_schedule}')

    def _rots_euler_step(self, d_t, t, rotmats_1, rotmats_t):
        """Step rotation matrices along the geodesic towards ``rotmats_1``."""
        scaling = self._rots_schedule_scaling(t)
        return so3_utils.geodesic_t(
            scaling * d_t, rotmats_1, rotmats_t)

    def _rots_quats_euler_step(self, d_t, t, rotquats_1, rotquats_t):
        """Step quaternions towards ``rotquats_1`` via ``quaternion_slerp_exp``."""
        scaling = self._rots_schedule_scaling(t)
        return so3_utils.quaternion_slerp_exp(
            scaling * d_t, rotquats_1, rotquats_t)

    def _set_model_inputs(self, batch, trans_t, rotquats_t, trans_1, rotquats_1):
        """Write the current frames into ``batch['trans_t']`` / ``['rotquats_t']``.

        When a modality is not corrupted, its clean value (``trans_1`` /
        ``rotquats_1``) is fed instead, so it must be provided.

        Raises:
            ValueError: if a non-corrupted modality has no clean value.
        """
        if self._trans_cfg.corrupt:
            batch['trans_t'] = trans_t
        else:
            if trans_1 is None:
                raise ValueError('Must provide trans_1 if not corrupting.')
            batch['trans_t'] = trans_1
        if self._rots_cfg.corrupt:
            batch['rotquats_t'] = rotquats_t
        else:
            if rotquats_1 is None:
                raise ValueError('Must provide rotquats_1 if not corrupting.')
            batch['rotquats_t'] = rotquats_1

    def sample(
            self,
            num_batch,
            num_res,
            model,
            num_timesteps=None,
            trans_potential=None,
            trans_0=None,
            rotmats_0=None,
            trans_1=None,
            rotmats_1=None,
            rotquats_1=None,
            diffuse_mask=None,
            chain_idx=None,
            res_idx=None,
            verbose=False,
    ):
        """Generate structures by Euler-integrating the learned flow.

        The starting state is a centred Gaussian translation prior (scaled by
        ``du.NM_TO_ANG_SCALE``) and a uniform SO(3) rotation prior, unless
        ``trans_0`` / ``rotmats_0`` are given. Integration runs from ``min_t``
        to 1 over ``num_timesteps`` grid points (default
        ``sampling.num_timesteps``), with rotations integrated as quaternions.

        Motif scaffolding is enabled when ``diffuse_mask``, ``trans_1``,
        ``rotmats_1`` and ``rotquats_1`` are all given. Without twisting, the
        motif (``diffuse_mask == 0``) is held at its clean value. The
        twisting-guidance path is not supported by this quaternion sampler and
        raises ``NotImplementedError`` once guidance would be applied.

        Args:
            num_batch, num_res: number of structures and residues per structure.
            model: callable taking the batch dict and returning a dict with
                ``pred_trans``, ``pred_rotquats`` and ``pred_rotmats``.
            trans_potential: optional callable; its gradient w.r.t. the
                predicted translations is subtracted at each step.
            chain_idx: accepted for interface compatibility; unused here.
            res_idx: residue indices; defaults to ``arange(num_res)`` per batch.
            verbose: print the step index and time (plus CUDA memory on GPU).

        Returns:
            ``(prot_traj, atom37_traj, clean_atom37_traj, clean_traj)``.
            ``prot_traj`` holds only the initial ``(trans_0, rotmats_0)`` and
            the final prediction; ``clean_traj`` holds only the final
            prediction. The atom37 entries are their
            ``all_atom.transrot_to_atom37`` conversions.

        Raises:
            ValueError: NaN in the masked ``trans_0``, or a missing clean value
                for a non-corrupted modality.
            NotImplementedError: when twisting guidance would be applied.
        """
        res_mask = torch.ones(num_batch, num_res, device=self._device)

        # Prior samples.
        if trans_0 is None:
            trans_0 = _centered_gaussian(
                num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE
        if rotmats_0 is None:
            # Uniform SO(3) prior. (Training corruption in
            # ``_corrupt_rotquats`` instead perturbs the data with IGSO(3).)
            rotmats_0 = _uniform_so3(num_batch, num_res, self._device)
        # Always derive the quaternion prior so that a caller-supplied
        # rotmats_0 is honoured as well.
        rotquats_0 = rot_to_quat(rotmats_0)

        if res_idx is None:
            res_idx = torch.arange(
                num_res,
                device=self._device,
                dtype=torch.float32)[None].repeat(num_batch, 1)
        batch = {
            'res_mask': res_mask,
            'diffuse_mask': res_mask,
            'res_idx': res_idx,
        }

        motif_scaffolding = (
            diffuse_mask is not None
            and trans_1 is not None
            and rotmats_1 is not None
            and rotquats_1 is not None
        )
        motif_mask = ~diffuse_mask.bool().squeeze(0) if motif_scaffolding else None

        if motif_scaffolding and not self._cfg.twisting.use:  # amortisation
            diffuse_mask = diffuse_mask.expand(num_batch, -1)  # [B, N]
            batch['diffuse_mask'] = diffuse_mask
            rotmats_0 = _rots_diffuse_mask(rotmats_0, rotmats_1, diffuse_mask)
            rotquats_0 = _rots_quats_diffuse_mask(rotquats_0, rotquats_1, diffuse_mask)
            trans_0 = _trans_diffuse_mask(trans_0, trans_1, diffuse_mask)
            if torch.isnan(trans_0).any():
                raise ValueError('NaN detected in trans_0')

        logs_traj = defaultdict(list)
        if motif_scaffolding and self._cfg.twisting.use:  # sampling / guidance
            assert trans_1.shape[0] == 1  # assume only one motif
            motif_locations = torch.nonzero(motif_mask).squeeze().tolist()
            true_motif_locations, motif_segments_length = twisting.find_ranges_and_lengths(motif_locations)
            # Marginalise both rotation and motif location.
            assert len(motif_mask.shape) == 1
            trans_motif = trans_1[:, motif_mask]  # [1, motif_res, 3]
            R_motif = rotmats_1[:, motif_mask]  # [1, motif_res, 3, 3]
            # TODO: quaternion version for motif_scaffolding
            num_res = trans_1.shape[-2]
            with torch.inference_mode(False):
                motif_locations = true_motif_locations if self._cfg.twisting.motif_loc else None
                F, motif_locations = twisting.motif_offsets_and_rots_vec_F(
                    num_res, motif_segments_length, motif_locations=motif_locations,
                    num_rots=self._cfg.twisting.num_rots, align=self._cfg.twisting.align,
                    scale=self._cfg.twisting.scale_rots, trans_motif=trans_motif,
                    R_motif=R_motif, max_offsets=self._cfg.twisting.max_offsets,
                    device=self._device, dtype=torch.float64, return_rots=False)

        if motif_mask is not None and len(motif_mask.shape) == 1:
            motif_mask = motif_mask[None].expand((num_batch, -1))

        # Time grid.
        if num_timesteps is None:
            num_timesteps = self._sample_cfg.num_timesteps
        ts = torch.linspace(self._cfg.min_t, 1.0, num_timesteps)
        t_1 = ts[0]

        prot_traj = [(trans_0, rotmats_0)]
        prot_traj_quats = [(trans_0, rotquats_0)]
        clean_traj = []
        clean_traj_quats = []
        for i, t_2 in enumerate(ts[1:]):
            if verbose:
                print(f'{i=}, t={t_1.item():.2f}')
                # CUDA memory queries fail for CPU tensors.
                if trans_0.is_cuda:
                    print(torch.cuda.mem_get_info(trans_0.device),
                          torch.cuda.memory_allocated(trans_0.device))

            # Run the model on the current state.
            trans_t_1, rotquats_t_1 = prot_traj_quats[-1]
            self._set_model_inputs(batch, trans_t_1, rotquats_t_1, trans_1, rotquats_1)
            batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
            batch['so3_t'] = batch['t']
            batch['r3_t'] = batch['t']
            d_t = t_2 - t_1

            use_twisting = (
                motif_scaffolding
                and self._cfg.twisting.use
                and t_1 >= self._cfg.twisting.t_min
            )
            if use_twisting:
                # Reconstruction guidance (``twisting`` + ``self.guidance``)
                # works on rotation matrices, but this sampler only tracks
                # quaternions (see the TODO in the motif set-up above).
                raise NotImplementedError(
                    'Twisting guidance is not supported by the quaternion '
                    'sampler.')
            with torch.no_grad():
                model_out = model(batch)

            # Process model output.
            pred_trans_1 = model_out['pred_trans']
            pred_rotquats_1 = model_out['pred_rotquats']
            clean_traj_quats.append(
                (pred_trans_1.detach().cpu(), pred_rotquats_1.detach().cpu())
            )
            if self._cfg.self_condition:
                if motif_scaffolding:
                    batch['trans_sc'] = (
                        pred_trans_1 * diffuse_mask[..., None]
                        + trans_1 * (1 - diffuse_mask[..., None])
                    )
                else:
                    batch['trans_sc'] = pred_trans_1

            # Take the reverse step.
            trans_t_2 = self._trans_euler_step(
                d_t, t_1, pred_trans_1, trans_t_1)
            if trans_potential is not None:
                with torch.inference_mode(False):
                    grad_pred_trans_1 = pred_trans_1.clone().detach().requires_grad_(True)
                    pred_trans_potential = autograd.grad(
                        outputs=trans_potential(grad_pred_trans_1),
                        inputs=grad_pred_trans_1)[0]
                # In-place update kept outside inference_mode(False): updating
                # an inference tensor in place outside InferenceMode is an error.
                if self._trans_cfg.potential_t_scaling:
                    trans_t_2 -= t_1 / (1 - t_1) * pred_trans_potential * d_t
                else:
                    trans_t_2 -= pred_trans_potential * d_t
            rotquats_t_2 = self._rots_quats_euler_step(
                d_t, t_1, pred_rotquats_1, rotquats_t_1)
            if motif_scaffolding and not self._cfg.twisting.use:
                trans_t_2 = _trans_diffuse_mask(trans_t_2, trans_1, diffuse_mask)
                rotquats_t_2 = _rots_quats_diffuse_mask(rotquats_t_2, rotquats_1, diffuse_mask)
            prot_traj_quats.append((trans_t_2, rotquats_t_2))
            t_1 = t_2

        # The loop stops at the last grid time; one more model call produces
        # the final clean prediction.
        t_1 = ts[-1]
        trans_t_1, rotquats_t_1 = prot_traj_quats[-1]
        self._set_model_inputs(batch, trans_t_1, rotquats_t_1, trans_1, rotquats_1)
        batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
        # NOTE(review): 'so3_t' / 'r3_t' are not refreshed here, so they still
        # hold the previous grid time. Confirm which time keys the model reads
        # before changing this, because a change would alter final predictions.
        with torch.no_grad():
            model_out = model(batch)
        pred_trans_1 = model_out['pred_trans']
        pred_rotmats_1 = model_out['pred_rotmats']
        pred_rotquats_1 = model_out['pred_rotquats']
        clean_traj.append(
            (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
        )
        clean_traj_quats.append(
            (pred_trans_1.detach().cpu(), pred_rotquats_1.detach().cpu())
        )
        prot_traj.append((pred_trans_1, pred_rotmats_1))
        prot_traj_quats.append((pred_trans_1, pred_rotquats_1))

        # Convert trajectories to atom37.
        atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask)
        clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask)
        return prot_traj, atom37_traj, clean_atom37_traj, clean_traj

    def guidance(self, trans_t, rotmats_t, model_out, motif_mask, R_motif, trans_motif, Log_delta_R, delta_x, t, d_t, logs_traj):
        """Apply one twisting (reconstruction-guidance) update to the frames.

        Calls ``twisting.grad_log_lik_approx`` on the predicted motif frames,
        which yields ``grad_Log_delta_R`` and ``grad_x_log_p_motif``. It then
        takes a ``twisting.step`` scaled by ``twisting.scale`` divided by a
        variance term. That variance is chosen by ``twisting.scale_w_t``
        (``'ot'``, ``'linear'`` or ``'constant'``) plus
        ``twisting.obs_noise ** 2``.

        Returns:
            ``(trans_t, rotmats_t, logs_traj)``. ``logs_traj`` is returned
            unchanged, and the values of ``model_out`` are detached in place.

        Raises:
            ValueError: for an unknown ``twisting.scale_w_t``.
        """
        # Select motif residues.
        motif_mask = motif_mask.clone()
        trans_pred = model_out['pred_trans'][:, motif_mask]  # [B, motif_res, 3]
        R_pred = model_out['pred_rotmats'][:, motif_mask]  # [B, motif_res, 3, 3]

        # Proposal for marginalising motif rotation.
        F = twisting.motif_rots_vec_F(
            trans_motif, R_motif, self._cfg.twisting.num_rots,
            align=self._cfg.twisting.align, scale=self._cfg.twisting.scale_rots,
            device=self._device, dtype=torch.float32)

        # Estimate p(motif | predicted_motif).
        grad_Log_delta_R, grad_x_log_p_motif, logs = twisting.grad_log_lik_approx(
            R_pred, trans_pred, R_motif, trans_motif, Log_delta_R, delta_x,
            None, None, None, F,
            twist_potential_rot=self._cfg.twisting.potential_rot,
            twist_potential_trans=self._cfg.twisting.potential_trans)

        with torch.no_grad():
            # Choose the time-dependent scaling.
            t_trans = t
            t_so3 = t
            scale_w_t = self._cfg.twisting.scale_w_t
            if scale_w_t == 'ot':
                var_trans = ((1 - t_trans) / t_trans)[:, None]
                var_rot = ((1 - t_so3) / t_so3)[:, None, None]
            elif scale_w_t == 'linear':
                var_trans = (1 - t)[:, None]
                var_rot = (1 - t_so3)[:, None, None]
            elif scale_w_t == 'constant':
                num_batch = trans_pred.shape[0]
                var_trans = torch.ones((num_batch, 1, 1)).to(R_pred.device)
                var_rot = torch.ones((num_batch, 1, 1, 1)).to(R_pred.device)
            else:
                raise ValueError(f'Unknown twisting.scale_w_t: {scale_w_t}')
            var_trans = var_trans + self._cfg.twisting.obs_noise ** 2
            var_rot = var_rot + self._cfg.twisting.obs_noise ** 2
            trans_scale_t = self._cfg.twisting.scale / var_trans
            rot_scale_t = self._cfg.twisting.scale / var_rot

            # Compute the update.
            trans_t, rotmats_t = twisting.step(
                trans_t, rotmats_t, grad_x_log_p_motif, grad_Log_delta_R, d_t,
                trans_scale_t, rot_scale_t,
                self._cfg.twisting.update_trans, self._cfg.twisting.update_rot)

        # Drop gradient-carrying intermediates promptly to limit memory use.
        del grad_Log_delta_R
        del grad_x_log_p_motif
        del Log_delta_R
        del delta_x
        for key, value in model_out.items():
            model_out[key] = value.detach().requires_grad_(False)

        return trans_t, rotmats_t, logs_traj