# Mirror of https://github.com/AngxiaoYue/ReQFlow.git
# (synced 2026-06-04 12:14:23 +08:00; 488 lines, 20 KiB, Python)
from collections import defaultdict
|
|
import torch
|
|
from data import so3_utils
|
|
from data import utils as du
|
|
from scipy.spatial.transform import Rotation
|
|
from data import all_atom
|
|
import copy
|
|
from torch import autograd
|
|
from motif_scaffolding import twisting
|
|
|
|
from openfold.utils.rigid_utils import rot_to_quat
|
|
|
|
def _centered_gaussian(num_batch, num_res, device):
|
|
noise = torch.randn(num_batch, num_res, 3, device=device)
|
|
return noise - torch.mean(noise, dim=-2, keepdims=True)
|
|
|
|
def _uniform_so3(num_batch, num_res, device):
|
|
return torch.tensor(
|
|
Rotation.random(num_batch*num_res).as_matrix(),
|
|
device=device,
|
|
dtype=torch.float32,
|
|
).reshape(num_batch, num_res, 3, 3)
|
|
|
|
def _trans_diffuse_mask(trans_t, trans_1, diffuse_mask):
|
|
return trans_t * diffuse_mask[..., None] + trans_1 * (1 - diffuse_mask[..., None])
|
|
|
|
def _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask):
|
|
return (
|
|
rotmats_t * diffuse_mask[..., None, None]
|
|
+ rotmats_1 * (1 - diffuse_mask[..., None, None])
|
|
)
|
|
|
|
def _rots_quats_diffuse_mask(rotquats_t, rotquats_1, diffuse_mask):
|
|
return (
|
|
rotquats_t * diffuse_mask[..., None]
|
|
+ rotquats_1 * (1 - diffuse_mask[..., None])
|
|
)
|
|
|
|
class Interpolant:
    """Flow-matching interpolant over residue frames (translations + rotations).

    Provides training-time corruption (``corrupt_batch``,
    ``rectify_corrupt_batch``) and an Euler sampler (``sample``). The training
    and sampling paths interpolate rotations as quaternions; the
    rotation-matrix helpers ``_corrupt_rotmats`` and ``_rots_euler_step`` are
    also kept.

    ``cfg`` is read for ``min_t``, ``rots`` (``corrupt``, ``sample_schedule``,
    ``exp_rate``), ``trans`` (``corrupt``, ``potential_t_scaling``),
    ``sampling`` (``num_timesteps``), ``self_condition`` and ``twisting``.
    Call ``set_device`` before any method that allocates tensors.
    """

    # Sigma passed to the IGSO3 sampler when building training-time rotation priors.
    _IGSO3_SIGMA = 1.5

    def __init__(self, cfg):
        self._cfg = cfg
        self._rots_cfg = cfg.rots
        self._trans_cfg = cfg.trans
        self._sample_cfg = cfg.sampling
        # Built lazily by the ``igso3`` property.
        self._igso3 = None

    @property
    def igso3(self):
        """Lazily constructed ``so3_utils.SampleIGSO3`` sampler.

        Built on first access from 1000 sigma values in [0.1, 1.5] with
        ``cache_dir='.cache'``, then reused.
        """
        if self._igso3 is None:
            sigma_grid = torch.linspace(0.1, 1.5, 1000)
            self._igso3 = so3_utils.SampleIGSO3(
                1000, sigma_grid, cache_dir='.cache')
        return self._igso3

    def set_device(self, device):
        """Set the device on which sampled tensors are allocated."""
        self._device = device

    def sample_t(self, num_batch, truncate=None):
        """Draw one interpolation time per batch element.

        Args:
            num_batch: Number of times to draw.
            truncate: If ``None`` (default), times are uniform on
                ``[min_t, 1 - min_t)``. Otherwise they are uniform on
                ``[min_t, min_t + truncate)``.

        Returns:
            Tensor of shape ``[num_batch]`` on ``self._device``.
        """
        t = torch.rand(num_batch, device=self._device)
        if truncate is None:
            return t * (1 - 2 * self._cfg.min_t) + self._cfg.min_t
        return t * truncate + self._cfg.min_t

    def _corrupt_trans(self, trans_1, t, res_mask, diffuse_mask, trans_0=None):
        """Linearly interpolate translations from a prior sample to the data.

        Args:
            trans_1: Clean translations, ``[B, N, 3]``.
            t: Times broadcastable as ``t[..., None]`` (``[B, 1]`` in this class).
            res_mask: ``[B, N]``; masked-out residues are zeroed.
            diffuse_mask: ``[B, N]``; where 0, ``trans_1`` is kept unchanged.
            trans_0: Optional prior sample. Defaults to a centred Gaussian
                scaled by ``du.NM_TO_ANG_SCALE``.
        """
        if trans_0 is None:
            trans_nm_0 = _centered_gaussian(*res_mask.shape, self._device)
            trans_0 = trans_nm_0 * du.NM_TO_ANG_SCALE
        trans_t = (1 - t[..., None]) * trans_0 + t[..., None] * trans_1
        trans_t = _trans_diffuse_mask(trans_t, trans_1, diffuse_mask)
        return trans_t * res_mask[..., None]

    def _sample_igso3_rotmats_0(self, rotmats_1, num_batch, num_res):
        """Sample a rotation prior ``rotmats_1 @ R_noise``, ``R_noise`` from ``self.igso3``.

        Returns a ``[num_batch, num_res, 3, 3]`` tensor on ``self._device``.
        """
        noisy_rotmats = self.igso3.sample(
            torch.tensor([self._IGSO3_SIGMA]),
            num_batch * num_res,
        ).to(self._device)
        noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3)
        return torch.einsum("...ij,...jk->...ik", rotmats_1, noisy_rotmats)

    def _corrupt_rotmats(self, rotmats_1, t, res_mask, diffuse_mask, rotmats_0=None):
        """Geodesically interpolate rotation matrices from a prior to the data.

        Masked-out residues (``res_mask == 0``) get the identity rotation;
        residues with ``diffuse_mask == 0`` keep ``rotmats_1``.
        """
        num_batch, num_res = res_mask.shape
        if rotmats_0 is None:
            rotmats_0 = self._sample_igso3_rotmats_0(rotmats_1, num_batch, num_res)
        rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
        identity = torch.eye(3, device=self._device)
        rotmats_t = (
            rotmats_t * res_mask[..., None, None]
            + identity[None, None] * (1 - res_mask[..., None, None])
        )
        return _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask)

    def _corrupt_rotquats(self, rotmats_1, rotquats_1, t, res_mask, diffuse_mask, rotquats_0=None):
        """Interpolate quaternions from a prior to the data with ``quaternion_slerp_exp``.

        Args:
            rotmats_1: Clean rotations ``[B, N, 3, 3]``; only used to build the
                default prior when ``rotquats_0`` is ``None``.
            rotquats_1: Clean quaternions ``[B, N, 4]``.
            t: ``[B, 1]`` times (expanded to ``[B, N]`` below).
            res_mask: ``[B, N]``; masked-out residues get ``[1, 0, 0, 0]``.
            diffuse_mask: ``[B, N]``; where 0, ``rotquats_1`` is kept.
            rotquats_0: Optional prior quaternions; defaults to the IGSO3
                prior around ``rotmats_1`` converted with ``rot_to_quat``.
        """
        num_batch, num_res = res_mask.shape
        if rotquats_0 is None:
            rotmats_0 = self._sample_igso3_rotmats_0(rotmats_1, num_batch, num_res)
            rotquats_0 = rot_to_quat(rotmats_0)  # [B, N, 4]
        rotquats_t = so3_utils.quaternion_slerp_exp(
            t.expand(-1, num_res), rotquats_1, rotquats_0)  # time per residue: [B, N]
        identity = torch.tensor(
            [1.0, 0.0, 0.0, 0.0], device=self._device, dtype=rotquats_t.dtype)
        rotquats_t = (
            rotquats_t * res_mask[..., None]
            + identity[None] * (1 - res_mask[..., None])
        )
        return _rots_quats_diffuse_mask(rotquats_t, rotquats_1, diffuse_mask)

    def corrupt_batch(self, batch):
        """Return a deep copy of ``batch`` with noised frames added.

        Reads ``trans_1`` ([B, N, 3], Angstrom), ``rotmats_1`` ([B, N, 3, 3]),
        ``rotquats_1`` ([B, N, 4]), ``res_mask`` and ``diffuse_mask``
        ([B, N]); adds ``so3_t``, ``r3_t`` ([B, 1], same values),
        ``trans_t`` and ``rotquats_t``.

        Raises:
            ValueError: If the corrupted translations or quaternions contain NaN.
        """
        noisy_batch = copy.deepcopy(batch)
        trans_1 = batch['trans_1']
        rotmats_1 = batch['rotmats_1']
        rotquats_1 = batch['rotquats_1']
        res_mask = batch['res_mask']
        diffuse_mask = batch['diffuse_mask']
        num_batch, _ = diffuse_mask.shape

        # One time per example, shared by translations and rotations: [B, 1].
        t = self.sample_t(num_batch)[:, None]
        so3_t = t
        r3_t = t
        noisy_batch['so3_t'] = so3_t
        noisy_batch['r3_t'] = r3_t

        if self._trans_cfg.corrupt:
            trans_t = self._corrupt_trans(trans_1, r3_t, res_mask, diffuse_mask)
        else:
            trans_t = trans_1
        if torch.any(torch.isnan(trans_t)):
            raise ValueError('NaN in trans_t during corruption')
        noisy_batch['trans_t'] = trans_t

        if self._rots_cfg.corrupt:
            rotquats_t = self._corrupt_rotquats(
                rotmats_1, rotquats_1, so3_t, res_mask, diffuse_mask)
        else:
            rotquats_t = rotquats_1
        if torch.any(torch.isnan(rotquats_t)):
            raise ValueError('NaN in rotquats_t during corruption')
        noisy_batch['rotquats_t'] = rotquats_t
        return noisy_batch

    def rectify_corrupt_batch(self, batch):
        """Build a noised batch from paired (noise, sample) frames.

        ``batch['noise']`` supplies the t=0 endpoint (its ``trans_1`` and
        ``rotquats_1`` entries are used as ``trans_0`` / ``rotquats_0``) and
        ``batch['sample']`` supplies the t=1 endpoint. Presumably this serves
        rectified-flow (reflow) training — confirm with the caller.

        Returns a deep copy of ``batch['noise']`` whose ``trans_1``,
        ``rotmats_1``, ``rotquats_1`` are replaced by the sample's, plus
        ``so3_t``, ``r3_t``, ``trans_t`` and ``rotquats_t``.

        Raises:
            ValueError: If the interpolated translations or quaternions contain NaN.
        """
        noisy_batch = copy.deepcopy(batch['noise'])
        sample = batch['sample']
        trans_1 = sample['trans_1']  # [B, N, 3]
        rotmats_1 = sample['rotmats_1']  # [B, N, 3, 3]
        rotquats_1 = sample['rotquats_1']  # [B, N, 4]
        res_mask = sample['res_mask']  # [B, N]
        diffuse_mask = sample['diffuse_mask']
        num_batch, _ = diffuse_mask.shape

        t = self.sample_t(num_batch)[:, None]
        so3_t = t
        r3_t = t
        noisy_batch['so3_t'] = so3_t
        noisy_batch['r3_t'] = r3_t

        if self._trans_cfg.corrupt:
            trans_t = self._corrupt_trans(
                trans_1, r3_t, res_mask, diffuse_mask, trans_0=noisy_batch['trans_1'])
        else:
            trans_t = trans_1
        if torch.any(torch.isnan(trans_t)):
            raise ValueError('NaN in trans_t during corruption')

        if self._rots_cfg.corrupt:
            rotquats_t = self._corrupt_rotquats(
                rotmats_1, rotquats_1, so3_t, res_mask, diffuse_mask,
                rotquats_0=noisy_batch['rotquats_1'])
        else:
            rotquats_t = rotquats_1
        if torch.any(torch.isnan(rotquats_t)):
            raise ValueError('NaN in rotquats_t during corruption')

        noisy_batch['trans_1'] = trans_1
        noisy_batch['rotmats_1'] = rotmats_1
        noisy_batch['rotquats_1'] = rotquats_1
        noisy_batch['trans_t'] = trans_t
        noisy_batch['rotquats_t'] = rotquats_t
        return noisy_batch

    def rot_sample_kappa(self, t):
        """Rotation schedule value at ``t``: ``1 - exp(-t * exp_rate)`` or ``t``.

        Raises:
            ValueError: For an unknown ``rots.sample_schedule``.
        """
        if self._rots_cfg.sample_schedule == 'exp':
            return 1 - torch.exp(-t * self._rots_cfg.exp_rate)
        elif self._rots_cfg.sample_schedule == 'linear':
            return t
        else:
            raise ValueError(
                f'Invalid schedule: {self._rots_cfg.sample_schedule}')

    def _trans_vector_field(self, t, trans_1, trans_t):
        """Velocity ``(trans_1 - trans_t) / (1 - t)`` of the straight path to ``trans_1``."""
        return (trans_1 - trans_t) / (1 - t)

    def _trans_euler_step(self, d_t, t, trans_1, trans_t):
        """One forward Euler step of size ``d_t`` (must be positive) for translations."""
        assert d_t > 0
        trans_vf = self._trans_vector_field(t, trans_1, trans_t)
        return trans_t + trans_vf * d_t

    def _rots_scaling(self, t):
        """Step-size multiplier for rotation updates under ``rots.sample_schedule``.

        ``'linear'`` gives ``1 / (1 - t)``; ``'exp'`` gives ``rots.exp_rate``.

        Raises:
            ValueError: For an unknown schedule.
        """
        schedule = self._rots_cfg.sample_schedule
        if schedule == 'linear':
            return 1 / (1 - t)
        if schedule == 'exp':
            return self._rots_cfg.exp_rate
        raise ValueError(f'Unknown sample schedule {schedule}')

    def _rots_euler_step(self, d_t, t, rotmats_1, rotmats_t):
        """Move ``rotmats_t`` toward ``rotmats_1`` along the geodesic by ``scaling * d_t``."""
        return so3_utils.geodesic_t(
            self._rots_scaling(t) * d_t, rotmats_1, rotmats_t)

    def _rots_quats_euler_step(self, d_t, t, rotquats_1, rotquats_t):
        """Quaternion counterpart of ``_rots_euler_step`` via ``quaternion_slerp_exp``."""
        return so3_utils.quaternion_slerp_exp(
            self._rots_scaling(t) * d_t, rotquats_1, rotquats_t)

    def _set_sample_inputs(self, batch, trans_t, rotquats_t, trans_1, rotquats_1):
        """Write the model's current frame inputs into ``batch`` (in place).

        Uses the noisy state when the corresponding modality is corrupted,
        otherwise the clean ``trans_1`` / ``rotquats_1``.

        Raises:
            ValueError: If a clean input is needed but was not provided.
        """
        if self._trans_cfg.corrupt:
            batch['trans_t'] = trans_t
        else:
            if trans_1 is None:
                raise ValueError('Must provide trans_1 if not corrupting.')
            batch['trans_t'] = trans_1
        if self._rots_cfg.corrupt:
            batch['rotquats_t'] = rotquats_t
        else:
            if rotquats_1 is None:
                raise ValueError('Must provide rotquats_1 if not corrupting.')
            batch['rotquats_t'] = rotquats_1

    def sample(
        self,
        num_batch,
        num_res,
        model,
        num_timesteps=None,
        trans_potential=None,
        trans_0=None,
        rotmats_0=None,
        trans_1=None,
        rotmats_1=None,
        rotquats_1=None,
        diffuse_mask=None,
        chain_idx=None,
        res_idx=None,
        verbose=False,
    ):
        """Generate frames by Euler-integrating the learned flow from ``min_t`` to 1.

        Args:
            num_batch, num_res: Output batch size and number of residues.
            model: Callable on the feature dict; must return ``pred_trans``,
                ``pred_rotmats`` and ``pred_rotquats``.
            num_timesteps: Number of time points; defaults to
                ``sampling.num_timesteps``.
            trans_potential: Optional scalar potential whose gradient w.r.t.
                the predicted translations is subtracted at each step.
            trans_0, rotmats_0: Optional prior samples; default to a centred
                Gaussian (scaled by ``du.NM_TO_ANG_SCALE``) and uniform SO(3).
            trans_1, rotmats_1, rotquats_1, diffuse_mask: When all are given,
                motif scaffolding is enabled (residues with mask 0 are motif).
            chain_idx: Unused.
            res_idx: Residue indices; defaults to ``arange(num_res)``.
            verbose: Print progress (and CUDA memory on CUDA devices).

        Returns:
            ``(prot_traj, atom37_traj, clean_atom37_traj, clean_traj)``.
            ``prot_traj`` holds only the prior and the final prediction as
            ``(trans, rotmats)`` pairs; ``clean_traj`` holds the final
            prediction on CPU. The atom37 entries come from
            ``all_atom.transrot_to_atom37``.
        """
        res_mask = torch.ones(num_batch, num_res, device=self._device)

        # Set-up initial prior samples.
        if trans_0 is None:
            trans_0 = _centered_gaussian(
                num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE
        if rotmats_0 is None:
            # Uniform prior on SO(3); training uses an IGSO3 prior around
            # rotmats_1 instead (see _sample_igso3_rotmats_0).
            rotmats_0 = _uniform_so3(num_batch, num_res, self._device)
        # Derived unconditionally so that a caller-supplied rotmats_0 also gets
        # a quaternion prior (previously rotquats_0 was left undefined then).
        rotquats_0 = rot_to_quat(rotmats_0)

        if res_idx is None:
            res_idx = torch.arange(
                num_res,
                device=self._device,
                dtype=torch.float32)[None].repeat(num_batch, 1)
        batch = {
            'res_mask': res_mask,
            'diffuse_mask': res_mask,
            'res_idx': res_idx,
        }

        motif_scaffolding = (
            diffuse_mask is not None
            and trans_1 is not None
            and rotmats_1 is not None
            and rotquats_1 is not None
        )
        motif_mask = ~diffuse_mask.bool().squeeze(0) if motif_scaffolding else None

        if motif_scaffolding and not self._cfg.twisting.use:  # amortisation
            diffuse_mask = diffuse_mask.expand(num_batch, -1)  # [B, N]
            batch['diffuse_mask'] = diffuse_mask
            rotmats_0 = _rots_diffuse_mask(rotmats_0, rotmats_1, diffuse_mask)
            rotquats_0 = _rots_quats_diffuse_mask(rotquats_0, rotquats_1, diffuse_mask)
            trans_0 = _trans_diffuse_mask(trans_0, trans_1, diffuse_mask)
            if torch.isnan(trans_0).any():
                raise ValueError('NaN detected in trans_0')

        logs_traj = defaultdict(list)
        if motif_scaffolding and self._cfg.twisting.use:  # sampling / guidance
            assert trans_1.shape[0] == 1  # assume only one motif
            motif_locations = torch.nonzero(motif_mask).squeeze().tolist()
            true_motif_locations, motif_segments_length = twisting.find_ranges_and_lengths(motif_locations)

            # Marginalise both rotation and motif location
            assert len(motif_mask.shape) == 1
            trans_motif = trans_1[:, motif_mask]  # [1, motif_res, 3]
            R_motif = rotmats_1[:, motif_mask]  # [1, motif_res, 3, 3]
            # TODO: quaternion version for motif_scaffolding
            num_res = trans_1.shape[-2]
            with torch.inference_mode(False):
                motif_locations = true_motif_locations if self._cfg.twisting.motif_loc else None
                F, motif_locations = twisting.motif_offsets_and_rots_vec_F(
                    num_res, motif_segments_length, motif_locations=motif_locations,
                    num_rots=self._cfg.twisting.num_rots, align=self._cfg.twisting.align,
                    scale=self._cfg.twisting.scale_rots, trans_motif=trans_motif,
                    R_motif=R_motif, max_offsets=self._cfg.twisting.max_offsets,
                    device=self._device, dtype=torch.float64, return_rots=False)

        if motif_mask is not None and len(motif_mask.shape) == 1:
            motif_mask = motif_mask[None].expand((num_batch, -1))

        # Set-up time.
        if num_timesteps is None:
            num_timesteps = self._sample_cfg.num_timesteps
        ts = torch.linspace(self._cfg.min_t, 1.0, num_timesteps)
        t_1 = ts[0]

        # prot_traj only ever receives the prior and the final prediction; the
        # intermediate Euler states are tracked in prot_traj_quats.
        prot_traj = [(trans_0, rotmats_0)]
        prot_traj_quats = [(trans_0, rotquats_0)]
        clean_traj = []
        for i, t_2 in enumerate(ts[1:]):
            if verbose:
                print(f'{i=}, t={t_1.item():.2f}')
                # torch.cuda.mem_get_info only works for CUDA devices.
                if trans_0.device.type == 'cuda':
                    print(torch.cuda.mem_get_info(trans_0.device), torch.cuda.memory_allocated(trans_0.device))

            # Run model.
            trans_t_1, rotquats_t_1 = prot_traj_quats[-1]
            self._set_sample_inputs(batch, trans_t_1, rotquats_t_1, trans_1, rotquats_1)
            batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
            batch['so3_t'] = batch['t']
            batch['r3_t'] = batch['t']
            d_t = t_2 - t_1

            use_twisting = motif_scaffolding and self._cfg.twisting.use and t_1 >= self._cfg.twisting.t_min
            if use_twisting:  # Reconstruction guidance
                # NOTE(review): this branch reads `rotmats_t_1` before anything
                # in this method assigns it, so its first pass raises
                # UnboundLocalError; its rotation output is also never fed back
                # into `rotquats_t_1`. Needs the quaternion TODO above.
                with torch.inference_mode(False):
                    batch, Log_delta_R, delta_x = twisting.perturbations_for_grad(batch)
                    model_out = model(batch)
                    t = batch['r3_t']  # TODO: different time for SO3?
                    trans_t_1, rotmats_t_1, logs_traj = self.guidance(
                        trans_t_1, rotmats_t_1, model_out, motif_mask, R_motif,
                        trans_motif, Log_delta_R, delta_x, t, d_t, logs_traj)
            else:
                with torch.no_grad():
                    model_out = model(batch)

            # Process model output.
            pred_trans_1 = model_out['pred_trans']
            pred_rotquats_1 = model_out['pred_rotquats']
            if self._cfg.self_condition:
                if motif_scaffolding:
                    batch['trans_sc'] = _trans_diffuse_mask(pred_trans_1, trans_1, diffuse_mask)
                else:
                    batch['trans_sc'] = pred_trans_1

            # Take reverse step.
            trans_t_2 = self._trans_euler_step(d_t, t_1, pred_trans_1, trans_t_1)
            if trans_potential is not None:
                with torch.inference_mode(False):
                    grad_pred_trans_1 = pred_trans_1.clone().detach().requires_grad_(True)
                    pred_trans_potential = autograd.grad(
                        outputs=trans_potential(grad_pred_trans_1),
                        inputs=grad_pred_trans_1)[0]
                if self._trans_cfg.potential_t_scaling:
                    trans_t_2 -= t_1 / (1 - t_1) * pred_trans_potential * d_t
                else:
                    trans_t_2 -= pred_trans_potential * d_t
            rotquats_t_2 = self._rots_quats_euler_step(
                d_t, t_1, pred_rotquats_1, rotquats_t_1)
            if motif_scaffolding and not self._cfg.twisting.use:
                trans_t_2 = _trans_diffuse_mask(trans_t_2, trans_1, diffuse_mask)
                rotquats_t_2 = _rots_quats_diffuse_mask(rotquats_t_2, rotquats_1, diffuse_mask)

            prot_traj_quats.append((trans_t_2, rotquats_t_2))
            t_1 = t_2

        # The loop's last state sits at ts[-1]; one more model call there gives
        # the final clean prediction.
        t_1 = ts[-1]
        trans_t_1, rotquats_t_1 = prot_traj_quats[-1]
        self._set_sample_inputs(batch, trans_t_1, rotquats_t_1, trans_1, rotquats_1)
        # NOTE(review): only batch['t'] is refreshed here; batch['so3_t'] and
        # batch['r3_t'] still hold the previous step's time. Confirm which key
        # the model reads before changing this.
        batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
        with torch.no_grad():
            model_out = model(batch)
        pred_trans_1 = model_out['pred_trans']
        pred_rotmats_1 = model_out['pred_rotmats']
        clean_traj.append(
            (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
        )
        prot_traj.append((pred_trans_1, pred_rotmats_1))

        # Convert trajectories to atom37.
        atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask)
        clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask)
        return prot_traj, atom37_traj, clean_atom37_traj, clean_traj

    def guidance(self, trans_t, rotmats_t, model_out, motif_mask, R_motif, trans_motif, Log_delta_R, delta_x, t, d_t, logs_traj):
        """Apply one reconstruction-guidance (twisting) update to the frames.

        Computes likelihood gradients of the motif under the model's predicted
        motif frames (via ``twisting.grad_log_lik_approx``), scales them by
        ``twisting.scale / (var(t) + obs_noise**2)`` and applies
        ``twisting.step``. As a side effect every tensor in ``model_out`` is
        replaced by a detached copy.

        Returns:
            ``(trans_t, rotmats_t, logs_traj)``; ``logs_traj`` is returned
            unchanged.

        Raises:
            ValueError: For an unknown ``twisting.scale_w_t``.
        """
        twisting_cfg = self._cfg.twisting

        # Select motif.
        motif_mask = motif_mask.clone()
        trans_pred = model_out['pred_trans'][:, motif_mask]  # [B, motif_res, 3]
        R_pred = model_out['pred_rotmats'][:, motif_mask]  # [B, motif_res, 3, 3]

        # Proposal for marginalising motif rotation.
        F = twisting.motif_rots_vec_F(
            trans_motif, R_motif, twisting_cfg.num_rots, align=twisting_cfg.align,
            scale=twisting_cfg.scale_rots, device=self._device, dtype=torch.float32)

        # Estimate p(motif | predicted motif).
        grad_Log_delta_R, grad_x_log_p_motif, _ = twisting.grad_log_lik_approx(
            R_pred, trans_pred, R_motif, trans_motif, Log_delta_R, delta_x,
            None, None, None, F,
            twist_potential_rot=twisting_cfg.potential_rot,
            twist_potential_trans=twisting_cfg.potential_trans)

        with torch.no_grad():
            # Choose scaling (translations and rotations share the same time).
            scale_w_t = twisting_cfg.scale_w_t
            if scale_w_t == 'ot':
                var_trans = ((1 - t) / t)[:, None]
                var_rot = ((1 - t) / t)[:, None, None]
            elif scale_w_t == 'linear':
                var_trans = (1 - t)[:, None]
                var_rot = (1 - t)[:, None, None]
            elif scale_w_t == 'constant':
                num_batch = trans_pred.shape[0]
                var_trans = torch.ones((num_batch, 1, 1)).to(R_pred.device)
                var_rot = torch.ones((num_batch, 1, 1, 1)).to(R_pred.device)
            else:
                raise ValueError(f'Unknown twisting scale_w_t: {scale_w_t}')
            var_trans = var_trans + twisting_cfg.obs_noise ** 2
            var_rot = var_rot + twisting_cfg.obs_noise ** 2

            trans_scale_t = twisting_cfg.scale / var_trans
            rot_scale_t = twisting_cfg.scale / var_rot

            # Compute update.
            trans_t, rotmats_t = twisting.step(
                trans_t, rotmats_t, grad_x_log_p_motif, grad_Log_delta_R, d_t,
                trans_scale_t, rot_scale_t,
                twisting_cfg.update_trans, twisting_cfg.update_rot)

        # Drop autograd history held by the caller's model outputs.
        for key, value in model_out.items():
            model_out[key] = value.detach().requires_grad_(False)

        return trans_t, rotmats_t, logs_traj