diff --git a/utils/sampling.py b/utils/sampling.py index 67f3c50..78b2434 100644 --- a/utils/sampling.py +++ b/utils/sampling.py @@ -9,8 +9,8 @@ from torch_geometric.loader import DataLoader from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch from utils.torsion import modify_conformer_torsion_angles from scipy.spatial.transform import Rotation as R - from utils.utils import crop_beyond +from utils.logging_utils import get_logger def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7, @@ -72,6 +72,7 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False): N = len(data_list) trajectory = [] + logger = get_logger() if return_features: lig_features, rec_features = [], [] assert batch_size >= N, "Not implemented yet" @@ -113,6 +114,21 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s 'all_atoms' in model_args and model_args.all_atoms, device) tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3] + mean_scores = torch.mean(tr_score, dim=-1) + num_nans = torch.sum(torch.isnan(mean_scores)) + if num_nans > 0: + name = complex_graph_batch['name'] + if isinstance(name, list): + name = name[0] + logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: " + f"{num_nans} / {mean_scores.numel()} samples failed") + + # Set the nan values to a small value, just want to disturb slightly + # Hopefully won't get nan the next iteration + tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps) + rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps) + tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps) + del eps tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min))) rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))