mirror of
https://github.com/RosettaCommons/RFdiffusion.git
synced 2026-06-04 18:44:21 +08:00
119 lines
4.7 KiB
Python
119 lines
4.7 KiB
Python
"""SO(3) diffusion methods."""
|
|
import numpy as np
|
|
import os
|
|
from functools import cached_property
|
|
import torch
|
|
from scipy.spatial.transform import Rotation
|
|
import scipy.linalg
|
|
|
|
|
|
### First define geometric operations on the SO3 manifold
|
|
|
|
# hat map from vector space R^3 to Lie algebra so(3)
|
|
def hat(v):
|
|
hat_v = torch.zeros([v.shape[0], 3, 3])
|
|
hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0]
|
|
return hat_v + -hat_v.transpose(2, 1)
|
|
|
|
# Logarithmic map from SO(3) to R^3 (i.e. rotation vector)
|
|
def Log(R): return torch.tensor(Rotation.from_matrix(R.numpy()).as_rotvec())
|
|
|
|
# logarithmic map from SO(3) to so(3), this is the matrix logarithm
|
|
def log(R): return hat(Log(R))
|
|
|
|
# Exponential map from vector space of so(3) to SO(3), this is the matrix
|
|
# exponential combined with the "hat" map
|
|
def Exp(A): return torch.tensor(Rotation.from_rotvec(A.numpy()).as_matrix())
|
|
|
|
# Angle of rotation SO(3) to R^+
|
|
def Omega(R): return np.linalg.norm(log(R), axis=[-2, -1])/np.sqrt(2.)
|
|
|
|
L_default = 2000
|
|
def f_igso3(omega, t, L=L_default):
|
|
"""Truncated sum of IGSO(3) distribution.
|
|
|
|
This function approximates the power series in equation 5 of
|
|
"DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL
|
|
ALIGNMENT"
|
|
Leach et al. 2022
|
|
|
|
This expression diverges from the expression in Leach in that here, sigma =
|
|
sqrt(2) * eps, if eps_leach were the scale parameter of the IGSO(3).
|
|
|
|
With this reparameterization, IGSO(3) agrees with the Brownian motion on
|
|
SO(3) with t=sigma^2 when defined for the canonical inner product on SO3,
|
|
<u, v>_SO3 = Trace(u v^T)/2
|
|
|
|
Args:
|
|
omega: i.e. the angle of rotation associated with rotation matrix
|
|
t: variance parameter of IGSO(3), maps onto time in Brownian motion
|
|
L: Truncation level
|
|
"""
|
|
ls = torch.arange(L)[None] # of shape [1, L]
|
|
return ((2*ls + 1) * torch.exp(-ls*(ls+1)*t/2) *
|
|
torch.sin(omega[:, None]*(ls+1/2)) / torch.sin(omega[:, None]/2)).sum(dim=-1)
|
|
|
|
def d_logf_d_omega(omega, t, L=L_default):
|
|
omega = torch.tensor(omega, requires_grad=True)
|
|
log_f = torch.log(f_igso3(omega, t, L))
|
|
return torch.autograd.grad(log_f.sum(), omega)[0].numpy()
|
|
|
|
# IGSO3 density with respect to the volume form on SO(3)
|
|
def igso3_density(Rt, t, L=L_default):
|
|
return f_igso3(torch.tensor(Omega(Rt)), t, L).numpy()
|
|
|
|
def igso3_density_angle(omega, t, L=L_default):
|
|
return f_igso3(torch.tensor(omega), t, L).numpy()*(1-np.cos(omega))/np.pi
|
|
|
|
# grad_R log IGSO3(R; I_3, t)
|
|
def igso3_score(R, t, L=L_default):
|
|
omega = Omega(R)
|
|
unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None]
|
|
return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None]
|
|
|
|
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma, L=L_default):
|
|
"""calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
|
|
and score norms and expected squared score norms.
|
|
|
|
Args:
|
|
num_sigma: number of different sigmas for which to compute igso3
|
|
quantities.
|
|
num_omega: number of point in the discretization in the angle of
|
|
rotation.
|
|
min_sigma, max_sigma: the upper and lower ranges for the angle of
|
|
rotation on which to consider the IGSO3 distribution. This cannot
|
|
be too low or it will create numerical instability.
|
|
"""
|
|
# Discretize omegas for calculating CDFs. Skip omega=0.
|
|
discrete_omega = np.linspace(0, np.pi, num_omega+1)[1:]
|
|
|
|
# Exponential noise schedule. This choice is closely tied to the
|
|
# scalings used when simulating the reverse time SDE. For each step n,
|
|
# discrete_sigma[n] = min_eps^(1-n/num_eps) * max_eps^(n/num_eps)
|
|
discrete_sigma = 10 ** np.linspace(np.log10(min_sigma), np.log10(max_sigma), num_sigma + 1)[1:]
|
|
|
|
# Compute the pdf and cdf values for the marginal distribution of the angle
|
|
# of rotation (which is needed for sampling)
|
|
pdf_vals = np.asarray(
|
|
[igso3_density_angle(discrete_omega, sigma**2) for sigma in discrete_sigma])
|
|
cdf_vals = np.asarray(
|
|
[pdf.cumsum() / num_omega * np.pi for pdf in pdf_vals])
|
|
|
|
# Compute the norms of the scores. This are used to scale the rotation axis when
|
|
# computing the score as a vector.
|
|
score_norm = np.asarray(
|
|
[d_logf_d_omega(discrete_omega, sigma**2) for sigma in discrete_sigma])
|
|
|
|
# Compute the standard deviation of the score norm for each sigma
|
|
exp_score_norms = np.sqrt(
|
|
np.sum(
|
|
score_norm**2 * pdf_vals, axis=1) / np.sum(
|
|
pdf_vals, axis=1))
|
|
return {
|
|
'cdf': cdf_vals,
|
|
'score_norm': score_norm,
|
|
'exp_score_norms': exp_score_norms,
|
|
'discrete_omega': discrete_omega,
|
|
'discrete_sigma': discrete_sigma,
|
|
}
|