Set any nans/infs in model scores to a small value.

With some low (but non-trivial) frequency, processing through the convolutional layers diverges and node attributes become a mixture of NaN and inf (which all seem to turn into NaN). This later throws an exception during the Kabsch transform, which ruins the results for the whole complex. Setting these values to a small number close to 0 basically skips an iteration; at worst it ruins one of the sampled complexes but leaves the others intact. Note that this is only applied to the main model, *not* the confidence model.
This commit is contained in:
Jacob Silterra
2024-06-06 10:25:51 -04:00
parent 5238b18d4a
commit 9d0bf3d884

View File

@@ -9,8 +9,8 @@ from torch_geometric.loader import DataLoader
from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
from utils.torsion import modify_conformer_torsion_angles
from scipy.spatial.transform import Rotation as R
from utils.utils import crop_beyond
from utils.logging_utils import get_logger
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
@@ -72,6 +72,7 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s
temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False):
N = len(data_list)
trajectory = []
logger = get_logger()
if return_features:
lig_features, rec_features = [], []
assert batch_size >= N, "Not implemented yet"
@@ -113,6 +114,21 @@ def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_s
'all_atoms' in model_args and model_args.all_atoms, device)
tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]
mean_scores = torch.mean(tr_score, dim=-1)
num_nans = torch.sum(torch.isnan(mean_scores))
if num_nans > 0:
name = complex_graph_batch['name']
if isinstance(name, list):
name = name[0]
logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: "
f"{num_nans} / {mean_scores.numel()} samples failed")
# Set the nan values to a small value, just want to disturb slightly
# Hopefully won't get nan the next iteration
tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps)
rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps)
tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps)
del eps
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))