mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 18:04:23 +08:00
Set any nans/infs in model scores to a small value.
With some low (but non-trivial) frequency, processing through the convolutional layers diverges and the node attributes become a mixture of NaN and inf (which all seem to turn into NaN). This later throws an exception during the Kabsch transform, which ruins the results for the whole complex. Setting these values to a small number (close to 0) basically skips an iteration; at worst it ruins one of the sampled complexes but leaves the others intact. Note that this is only applied to the main model, *not* the confidence model.
This commit is contained in:
@@ -9,8 +9,8 @@ from torch_geometric.loader import DataLoader
|
||||
from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
|
||||
from utils.torsion import modify_conformer_torsion_angles
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from utils.utils import crop_beyond
|
||||
from utils.logging_utils import get_logger
|
||||
|
||||
|
||||
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
|
||||
@@ -72,6 +72,7 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s
|
||||
temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False):
|
||||
N = len(data_list)
|
||||
trajectory = []
|
||||
logger = get_logger()
|
||||
if return_features:
|
||||
lig_features, rec_features = [], []
|
||||
assert batch_size >= N, "Not implemented yet"
|
||||
@@ -113,6 +114,21 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s
|
||||
'all_atoms' in model_args and model_args.all_atoms, device)
|
||||
|
||||
tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]
|
||||
mean_scores = torch.mean(tr_score, dim=-1)
|
||||
num_nans = torch.sum(torch.isnan(mean_scores))
|
||||
if num_nans > 0:
|
||||
name = complex_graph_batch['name']
|
||||
if isinstance(name, list):
|
||||
name = name[0]
|
||||
logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: "
|
||||
f"{num_nans} / {mean_scores.numel()} samples failed")
|
||||
|
||||
# Set the nan values to a small value, just want to disturb slightly
|
||||
# Hopefully won't get nan the next iteration
|
||||
tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps)
|
||||
rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps)
|
||||
tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps)
|
||||
del eps
|
||||
|
||||
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
|
||||
rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
|
||||
|
||||
Reference in New Issue
Block a user