Files
ReQFlow/data/rectify_prot_datasets.py
Angxiao Yue 5bad7f2134 upload code
2025-02-20 17:54:00 +08:00

117 lines
4.3 KiB
Python

import numpy as np
import pandas as pd
import logging
import tree
import torch
import random
from torch.utils.data import Dataset
from data import utils as du
from openfold.data import data_transforms
from openfold.utils import rigid_utils
from openfold.utils.rigid_utils import rot_to_quat, quat_to_rot
class RectifyProtDataset(Dataset):
    """Dataset of paired (noise, sample) protein frames loaded from saved files.

    The CSV at ``dataset_cfg.rectify_csv_path`` must provide at least a
    ``prot_path`` column (a file readable by ``torch.load``) and a ``length``
    column (number of residues). In training mode every row is used; otherwise
    a fixed-seed subset spread over the available lengths is selected.
    """

    def __init__(
            self,
            *,
            dataset_cfg,
            task,
            is_training):
        """Read the CSV and build the train or evaluation split.

        Args:
            dataset_cfg: Config object exposing ``rectify_csv_path`` and, for
                evaluation, ``min_eval_length``, ``max_eval_length``,
                ``num_eval_lengths`` and ``samples_per_eval_length``.
            task: Accepted for interface compatibility; not used by this class.
            is_training: If true use all rows, otherwise build the eval subset.
        """
        super().__init__()
        self._log = logging.getLogger(__name__)
        self._dataset_cfg = dataset_cfg
        self.is_training = is_training
        data = pd.read_csv(self._dataset_cfg.rectify_csv_path)
        self._create_split(data)

    def _create_split(self, data_csv):
        """Populate ``self.csv`` from ``data_csv`` and add a running ``index`` column.

        ``data_csv`` itself is never modified.

        Raises:
            ValueError: In evaluation mode, if no row satisfies the configured
                length bounds.
        """
        if self.is_training:
            # Copy so that adding the 'index' column below does not mutate the
            # caller's DataFrame.
            self.csv = data_csv.copy()
            self._log.info(f'Loaded training dataset with {len(self.csv)} samples.')
        else:
            self.csv = self._select_eval_rows(data_csv)
        self.csv['index'] = list(range(len(self.csv)))

    def _select_eval_rows(self, data_csv):
        """Return the evaluation subset of ``data_csv`` as a new DataFrame."""
        cfg = self._dataset_cfg

        # Each bound is optional; a bound of None leaves that side unfiltered.
        lengths = data_csv['length']
        if cfg.min_eval_length is not None:
            lengths = lengths[lengths >= cfg.min_eval_length]
        if cfg.max_eval_length is not None:
            lengths = lengths[lengths <= cfg.max_eval_length]

        all_lengths = np.sort(lengths.unique())
        if all_lengths.size == 0:
            raise ValueError(
                'No rows satisfy the evaluation length bounds '
                f'(min_eval_length={cfg.min_eval_length}, '
                f'max_eval_length={cfg.max_eval_length}).'
            )

        # Pick `num_eval_lengths` positions evenly spaced over the sorted
        # unique lengths (positions are truncated to integers and may repeat
        # when fewer unique lengths are available).
        length_indices = (len(all_lengths) - 1) * np.linspace(
            0.0, 1.0, cfg.num_eval_lengths)
        eval_lengths = all_lengths[length_indices.astype(int)]

        eval_csv = data_csv[data_csv['length'].isin(eval_lengths)]
        # Fixed seed keeps the evaluation set identical across runs; sampling
        # with replacement allows lengths with fewer rows than requested.
        eval_csv = eval_csv.groupby('length').sample(
            cfg.samples_per_eval_length,
            replace=True,
            random_state=123
        )
        eval_csv = eval_csv.sort_values('length', ascending=False)
        self._log.info(
            f'Validation: {len(eval_csv)} examples with lengths {eval_lengths}')
        return eval_csv

    def __len__(self):
        """Return the number of rows in the active split."""
        return len(self.csv)

    @staticmethod
    def _make_feats(trans, rotmats, num_res, idx):
        """Build one feature dict from a translation/rotation pair.

        A leading dimension of size 1 (if present) is squeezed away; per the
        original comments the inputs are expected to be (1, num_res, 3) and
        (1, num_res, 3, 3) -- TODO confirm against the saved files.
        Fresh mask/index tensors are created on every call so that the
        returned dicts never share storage.
        """
        return {
            'trans_1': trans.squeeze(dim=0),
            'rotmats_1': rotmats.squeeze(dim=0),
            'rotquats_1': rot_to_quat(rotmats).squeeze(dim=0),
            'res_mask': torch.ones(num_res, dtype=torch.int),
            'diffuse_mask': torch.ones(num_res, dtype=torch.int),
            'res_idx': torch.arange(num_res, dtype=torch.int),
            'csv_idx': torch.tensor([idx], dtype=torch.long),
            # NOTE(review): shape (1,) here, unlike the per-residue features
            # above -- kept as in the original; confirm this is intended.
            'chain_idx': torch.zeros(1, dtype=torch.int),
        }

    def __getitem__(self, idx):
        """Load the example at position ``idx``.

        Returns:
            A dict with two entries, ``'sample'`` and ``'noise'``. Both hold the
            same set of keys; ``'noise'`` is built from the first element of the
            loaded object and ``'sample'`` from the second, each of which is
            indexed as ``(translations, rotation_matrices)``.
        """
        row = self.csv.iloc[idx]
        # NOTE: torch.load unpickles the file; only point the CSV at trusted paths.
        prot_traj = torch.load(row['prot_path'], map_location='cpu')
        # The row may carry a NumPy or float scalar; torch sizes need an int.
        num_res = int(row['length'])

        noise_trans, noise_rotmats = prot_traj[0][0], prot_traj[0][1]
        sample_trans, sample_rotmats = prot_traj[1][0], prot_traj[1][1]

        noise_batch = self._make_feats(noise_trans, noise_rotmats, num_res, idx)
        sample_batch = self._make_feats(sample_trans, sample_rotmats, num_res, idx)
        return {'sample': sample_batch, 'noise': noise_batch}