diff --git a/confidence/confidence_train.py b/confidence/confidence_train.py index 8130ee8..6e52280 100644 --- a/confidence/confidence_train.py +++ b/confidence/confidence_train.py @@ -149,7 +149,6 @@ def test_epoch(model, loader, rmsd_prediction): model.eval() meter = AverageMeter(['loss'], unpooled_metrics=True) if rmsd_prediction else AverageMeter(['confidence_loss', 'accuracy', 'ROC AUC'], unpooled_metrics=True) all_labels = [] - all_affinities = [] for data in tqdm(loader, total=len(loader)): try: with torch.no_grad(): @@ -250,8 +249,10 @@ def train(args, model, optimizer, scheduler, train_loader, val_loader, run_dir): def construct_loader_confidence(args, device): common_args = {'cache_path': args.cache_path, 'original_model_dir': args.original_model_dir, 'device': device, 'inference_steps': args.inference_steps, 'samples_per_complex': args.samples_per_complex, - 'limit_complexes': args.limit_complexes, 'all_atoms': args.all_atoms, 'balance': args.balance, 'rmsd_classification_cutoff': args.rmsd_classification_cutoff, - 'use_original_model_cache': args.use_original_model_cache, 'cache_creation_id': args.cache_creation_id, "cache_ids_to_combine": args.cache_ids_to_combine} + 'limit_complexes': args.limit_complexes, 'all_atoms': args.all_atoms, 'balance': args.balance, + 'rmsd_classification_cutoff': args.rmsd_classification_cutoff, 'use_original_model_cache': args.use_original_model_cache, + 'cache_creation_id': args.cache_creation_id, "cache_ids_to_combine": args.cache_ids_to_combine, + "model_ckpt": args.ckpt} loader_class = DataListLoader if torch.cuda.is_available() else DataLoader exception_flag = False diff --git a/confidence/dataset.py b/confidence/dataset.py index bee4fc1..1103b79 100644 --- a/confidence/dataset.py +++ b/confidence/dataset.py @@ -57,8 +57,8 @@ def get_args_and_cache_path(original_model_dir, split): class ConfidenceDataset(Dataset): def __init__(self, cache_path, original_model_dir, split, device, limit_complexes, inference_steps, samples_per_complex, all_atoms, - args, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2, - cache_ids_to_combine= None, cache_creation_id=None): + args, model_ckpt, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2, + cache_ids_to_combine=None, cache_creation_id=None): super(ConfidenceDataset, self).__init__() @@ -73,9 +73,21 @@ class ConfidenceDataset(Dataset): self.cache_ids_to_combine = cache_ids_to_combine self.cache_creation_id = cache_creation_id self.samples_per_complex = samples_per_complex + self.model_ckpt = model_ckpt self.original_model_args, original_model_cache = get_args_and_cache_path(original_model_dir, split) self.complex_graphs_cache = original_model_cache if self.use_original_model_cache else get_cache_path(args, split) + + # check if the docked positions have already been computed, if not run the preprocessing (docking every complex) + self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}' + f'_split_{split}_limit_{limit_complexes}') + + if (not os.path.exists(os.path.join(self.full_cache_path, "ligand_positions.pkl")) and self.cache_creation_id is None) or \ + (not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{self.cache_creation_id}.pkl")) and self.cache_creation_id is not None): + os.makedirs(self.full_cache_path, exist_ok=True) + self.preprocessing(original_model_cache) + + # load the graphs that the confidence model will use print('Using the cached complex graphs of the original model args' if self.use_original_model_cache else 'Not using the cached complex graphs of the original model args. Instead the complex graphs are used that are at the location given by the dataset parameters given to confidence_train.py') print(self.complex_graphs_cache) if not os.path.exists(os.path.join(self.complex_graphs_cache, "heterographs.pkl")): @@ -99,14 +111,6 @@ class ConfidenceDataset(Dataset): complex_graphs = pickle.load(f) self.complex_graph_dict = {d.name: d for d in complex_graphs} - self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}' - f'_split_{split}_limit_{limit_complexes}') - - if (not os.path.exists(os.path.join(self.full_cache_path, "ligand_positions.pkl")) and self.cache_creation_id is None) or \ - (not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{self.cache_creation_id}.pkl")) and self.cache_creation_id is not None): - os.makedirs(self.full_cache_path, exist_ok=True) - self.preprocessing(original_model_cache) - if self.cache_ids_to_combine is None: print(f'HAPPENING | Loading positions and rmsds from: {os.path.join(self.full_cache_path, "ligand_positions.pkl")}') with open(os.path.join(self.full_cache_path, "ligand_positions.pkl"), 'rb') as f: @@ -209,7 +213,7 @@ class ConfidenceDataset(Dataset): t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args) model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True) - state_dict = torch.load(f'{self.original_model_dir}/best_model.pt', map_location=torch.device('cpu')) + state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}.pt', map_location=torch.device('cpu')) model.load_state_dict(state_dict, strict=True) model = model.to(self.device) model.eval()