# Files
# ReQFlow/experiments/inference_se3_flows.py
# Angxiao Yue 5bad7f2134 upload code
# 2025-02-20 17:54:00 +08:00

# 120 lines
# 4.2 KiB
# Python

"""Script for running inference and evaluation."""
import os
import time
import numpy as np
import hydra
import torch
import GPUtil
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from experiments import utils as eu
from models.flow_module import FlowModule
import pandas as pd
# Select PyTorch's 'high' float32 matmul precision mode for the whole process.
# NOTE(review): per the PyTorch docs this permits reduced-precision internal
# formats (e.g. TF32) for float32 matmuls -- confirm this is acceptable for eval.
torch.set_float32_matmul_precision('high')
# Module-level logger obtained from the project's experiment utilities.
log = eu.get_pylogger(__name__)
class EvalRunner:
    """Load a ``FlowModule`` from a checkpoint and run sampling with it.

    The runner merges the inference config with the ``config.yaml`` stored
    next to the checkpoint, prepares the output directory on local rank 0,
    restores the module from the checkpoint and hands it to a Lightning
    ``Trainer`` for prediction.
    """

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config. Must provide ``inference.ckpt_path``; a
                ``config.yaml`` is expected in the same directory as the
                checkpoint and is merged on top of ``cfg``.
        """
        ckpt_path = cfg.inference.ckpt_path
        ckpt_dir = os.path.dirname(ckpt_path)
        ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        # Set-up config. Struct mode is disabled so that keys present in only
        # one of the two configs can be merged / added afterwards.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.set_struct(ckpt_cfg, False)
        # ckpt_cfg is passed last, so its values win on conflicting keys.
        cfg = OmegaConf.merge(cfg, ckpt_cfg)
        cfg.experiment.checkpointer.dirpath = './'
        cfg.experiment.is_training = False

        self._cfg = cfg
        self._exp_cfg = cfg.experiment
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples
        self._task = self._infer_cfg.task
        self._rng = np.random.default_rng(self._infer_cfg.seed)

        # Set-up output directory only on rank 0.
        # Environment variables are strings: without int() a launcher-provided
        # LOCAL_RANK='0' would never compare equal to the integer 0 and the
        # output directory / config dump would be skipped on rank 0.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if local_rank == 0:
            inference_dir = self.setup_inference_dir(ckpt_path)
            self._exp_cfg.inference_dir = inference_dir
            config_path = os.path.join(inference_dir, 'config.yaml')
            with open(config_path, 'w') as f:
                OmegaConf.save(config=self._cfg, f=f)
            log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            cfg=self._cfg,
            strict=False
        )
        log.info(pl.utilities.model_summary.ModelSummary(self._flow_module))
        self._flow_module.eval()
        # Hand the inference / sample settings to the module directly.
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg

    @property
    def inference_dir(self):
        """Return the ``inference_dir`` attribute of the wrapped flow module."""
        return self._flow_module.inference_dir

    def setup_inference_dir(self, ckpt_path):
        """Create the output directory for this run and return its path.

        The directory is
        ``<predict_dir>/<ckpt_name>/<task>/<inference_subdir>``, where
        ``ckpt_name`` is made of the last three ``/``-separated components of
        ``ckpt_path`` with ``'.ckpt'`` removed. ``ckpt_name`` is also stored
        on ``self._ckpt_name``.

        Args:
            ckpt_path: path of the checkpoint being evaluated.

        Returns:
            The path of the (created if missing) output directory.
        """
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        output_dir = os.path.join(
            self._infer_cfg.predict_dir,
            self._ckpt_name,
            self._infer_cfg.task,
            self._infer_cfg.inference_subdir,
        )
        os.makedirs(output_dir, exist_ok=True)
        log.info(f'Saving results to {output_dir}')
        return output_dir

    def run_sampling(self):
        """Build the evaluation dataset for the configured task and predict.

        Raises:
            ValueError: if ``inference.task`` is neither ``'unconditional'``
                nor ``'scaffolding'``.
        """
        # Honour the configured GPU count instead of a fixed device list, so
        # the number of devices actually used matches `inference.num_gpus`.
        num_gpus = self._infer_cfg.num_gpus
        devices = list(range(num_gpus))
        log.info(f"Using devices: {devices}")
        log.info(f'Evaluating {self._infer_cfg.task}')

        if self._infer_cfg.task == 'unconditional':
            if self._infer_cfg.for_rectify:
                eval_dataset = eu.RectifyLengthDataset(dataset_cfg=self._infer_cfg.datasets)
            else:
                eval_dataset = eu.LengthDataset(self._samples_cfg)
        elif self._infer_cfg.task == 'scaffolding':
            eval_dataset = eu.ScaffoldingDataset(self._samples_cfg)
        else:
            raise ValueError(f'Unknown task {self._infer_cfg.task}')

        dataloader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=1, shuffle=False, drop_last=False)
        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)
@hydra.main(version_base=None, config_path="../configs", config_name="inference_unconditional")
def run(cfg: DictConfig) -> None:
    """Hydra entry point: build an ``EvalRunner``, sample, and log the wall time.

    Args:
        cfg: composed Hydra config; ``cfg.inference.num_gpus`` is read for
            the start-up log message and the whole config is passed to
            ``EvalRunner``.
    """
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')

    # Time covers both model construction and the sampling run.
    began_at = time.time()
    runner = EvalRunner(cfg)
    runner.run_sampling()
    elapsed_time = time.time() - began_at

    log.info(f'Finished in {elapsed_time:.2f}s')
# Run the entry point only when this file is executed as a script.
if __name__ == "__main__":
    run()