fix confidence ckpt and improve memory

This commit is contained in:
Gabriele Corso
2022-11-16 14:13:42 -05:00
parent 4c60e5b7ee
commit 5393a8cafc
2 changed files with 19 additions and 14 deletions

View File

@@ -149,7 +149,6 @@ def test_epoch(model, loader, rmsd_prediction):
model.eval()
meter = AverageMeter(['loss'], unpooled_metrics=True) if rmsd_prediction else AverageMeter(['confidence_loss', 'accuracy', 'ROC AUC'], unpooled_metrics=True)
all_labels = []
all_affinities = []
for data in tqdm(loader, total=len(loader)):
try:
with torch.no_grad():
@@ -250,8 +249,10 @@ def train(args, model, optimizer, scheduler, train_loader, val_loader, run_dir):
def construct_loader_confidence(args, device):
common_args = {'cache_path': args.cache_path, 'original_model_dir': args.original_model_dir, 'device': device,
'inference_steps': args.inference_steps, 'samples_per_complex': args.samples_per_complex,
'limit_complexes': args.limit_complexes, 'all_atoms': args.all_atoms, 'balance': args.balance, 'rmsd_classification_cutoff': args.rmsd_classification_cutoff,
'use_original_model_cache': args.use_original_model_cache, 'cache_creation_id': args.cache_creation_id, "cache_ids_to_combine": args.cache_ids_to_combine}
'limit_complexes': args.limit_complexes, 'all_atoms': args.all_atoms, 'balance': args.balance,
'rmsd_classification_cutoff': args.rmsd_classification_cutoff, 'use_original_model_cache': args.use_original_model_cache,
'cache_creation_id': args.cache_creation_id, "cache_ids_to_combine": args.cache_ids_to_combine,
"model_ckpt": args.ckpt}
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
exception_flag = False

View File

@@ -57,8 +57,8 @@ def get_args_and_cache_path(original_model_dir, split):
class ConfidenceDataset(Dataset):
def __init__(self, cache_path, original_model_dir, split, device, limit_complexes,
inference_steps, samples_per_complex, all_atoms,
args, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2,
cache_ids_to_combine= None, cache_creation_id=None):
args, model_ckpt, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2,
cache_ids_to_combine=None, cache_creation_id=None):
super(ConfidenceDataset, self).__init__()
@@ -73,9 +73,21 @@ class ConfidenceDataset(Dataset):
self.cache_ids_to_combine = cache_ids_to_combine
self.cache_creation_id = cache_creation_id
self.samples_per_complex = samples_per_complex
self.model_ckpt = model_ckpt
self.original_model_args, original_model_cache = get_args_and_cache_path(original_model_dir, split)
self.complex_graphs_cache = original_model_cache if self.use_original_model_cache else get_cache_path(args, split)
# check if the docked positions have already been computed, if not run the preprocessing (docking every complex)
self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}'
f'_split_{split}_limit_{limit_complexes}')
if (not os.path.exists(os.path.join(self.full_cache_path, "ligand_positions.pkl")) and self.cache_creation_id is None) or \
(not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{self.cache_creation_id}.pkl")) and self.cache_creation_id is not None):
os.makedirs(self.full_cache_path, exist_ok=True)
self.preprocessing(original_model_cache)
# load the graphs that the confidence model will use
print('Using the cached complex graphs of the original model args' if self.use_original_model_cache else 'Not using the cached complex graphs of the original model args. Instead the complex graphs are used that are at the location given by the dataset parameters given to confidence_train.py')
print(self.complex_graphs_cache)
if not os.path.exists(os.path.join(self.complex_graphs_cache, "heterographs.pkl")):
@@ -99,14 +111,6 @@ class ConfidenceDataset(Dataset):
complex_graphs = pickle.load(f)
self.complex_graph_dict = {d.name: d for d in complex_graphs}
self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}'
f'_split_{split}_limit_{limit_complexes}')
if (not os.path.exists(os.path.join(self.full_cache_path, "ligand_positions.pkl")) and self.cache_creation_id is None) or \
(not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions_id{self.cache_creation_id}.pkl")) and self.cache_creation_id is not None):
os.makedirs(self.full_cache_path, exist_ok=True)
self.preprocessing(original_model_cache)
if self.cache_ids_to_combine is None:
print(f'HAPPENING | Loading positions and rmsds from: {os.path.join(self.full_cache_path, "ligand_positions.pkl")}')
with open(os.path.join(self.full_cache_path, "ligand_positions.pkl"), 'rb') as f:
@@ -209,7 +213,7 @@ class ConfidenceDataset(Dataset):
t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args)
model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True)
state_dict = torch.load(f'{self.original_model_dir}/best_model.pt', map_location=torch.device('cpu'))
state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)
model = model.to(self.device)
model.eval()