From 9d0bf3d884a44cd371aec76a376bcafb0367d9bd Mon Sep 17 00:00:00 2001 From: Jacob Silterra Date: Thu, 6 Jun 2024 10:25:51 -0400 Subject: [PATCH] Set any nans/infs in model scores to a small value. With some low (but non-trivial) frequency, processing through the convolutional layers diverge and node attributes become a mixture of nan and inf (which seem to all turn to nan). This later throws an exception during the Kabsch transform, which ruins results for the whole complex. Setting these to 0 basically skips an iteration, at worst it ruins one of the sampled complexes, but leaves the others. Note this is only applied to the main model, *not* the confidence model. --- utils/sampling.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/utils/sampling.py b/utils/sampling.py index 67f3c50..78b2434 100644 --- a/utils/sampling.py +++ b/utils/sampling.py @@ -9,8 +9,8 @@ from torch_geometric.loader import DataLoader from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch from utils.torsion import modify_conformer_torsion_angles from scipy.spatial.transform import Rotation as R - from utils.utils import crop_beyond +from utils.logging_utils import get_logger def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7, @@ -72,6 +72,7 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False): N = len(data_list) trajectory = [] + logger = get_logger() if return_features: lig_features, rec_features = [], [] assert batch_size >= N, "Not implemented yet" @@ -113,6 +114,21 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s 'all_atoms' in model_args and model_args.all_atoms, device) tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3] + mean_scores = torch.mean(tr_score, dim=-1) + num_nans = torch.sum(torch.isnan(mean_scores)) + if num_nans > 0: + name = complex_graph_batch['name'] + if isinstance(name, list): + name = name[0] + logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: " + f"{num_nans} / {mean_scores.numel()} samples failed") + + # Set the nan values to a small value, just want to disturb slightly + # Hopefully won't get nan the next iteration + tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps) + rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps) + tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps) + del eps tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min))) rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))