# Source: ReQFlow/data/rectify_datasets.py
# Commit 5bad7f2134 ("upload code") by Angxiao Yue, 2025-02-20 17:54:00 +08:00
# (369 lines, 14 KiB, Python)

import numpy as np
import pandas as pd
import logging
import tree
import torch
import random
from torch.utils.data import Dataset
from data import utils as du
from openfold.data import data_transforms
from openfold.utils import rigid_utils
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
def _rog_filter(df, quantile):
    """Drop rows whose radius of gyration is unusually large for their length.

    For every distinct ``modeled_seq_len`` the ``quantile``-th quantile of
    ``radius_gyration`` is computed, a degree-4 polynomial in the length is
    fit to those per-length quantiles, and a row is kept only when its
    ``radius_gyration`` is strictly below the fitted value (plus a 0.1
    margin) evaluated at that row's own length.

    Args:
        df: DataFrame with integer-valued ``modeled_seq_len`` and
            ``radius_gyration`` columns.
        quantile: Quantile passed to ``np.quantile`` for each length group.

    Returns:
        The subset of ``df`` passing the cutoff. An empty ``df`` is returned
        unchanged (there is nothing to fit).
    """
    if df.empty:
        # pivot_table would be empty and the regressor cannot be fit on no data.
        return df

    y_quant = pd.pivot_table(
        df,
        values='radius_gyration',
        index='modeled_seq_len',
        aggfunc=lambda x: np.quantile(x, quantile),
    )
    x_quant = y_quant.index.to_numpy()
    y_quant = y_quant.radius_gyration.to_numpy()

    # Fit polynomial regressor: sequence length -> RoG quantile.
    poly = PolynomialFeatures(degree=4, include_bias=True)
    poly_features = poly.fit_transform(x_quant[:, None])
    poly_reg_model = LinearRegression()
    poly_reg_model.fit(poly_features, y_quant)

    # Evaluate the cutoff at lengths 1..max_len so that pred_y[L - 1] is the
    # prediction for length L (evaluating on arange(max_len) shifted it by one).
    max_len = int(df.modeled_seq_len.max())
    pred_poly_features = poly.transform(np.arange(1, max_len + 1)[:, None])
    # Add a little more.
    pred_y = poly_reg_model.predict(pred_poly_features) + 0.1

    seq_lens = df.modeled_seq_len.to_numpy().astype(int)
    row_rog_cutoffs = pred_y[seq_lens - 1]
    return df[df.radius_gyration.to_numpy() < row_rog_cutoffs]
def _length_filter(data_csv, min_res, max_res):
return data_csv[
(data_csv.modeled_seq_len >= min_res)
& (data_csv.modeled_seq_len <= max_res)
]
def _plddt_percent_filter(data_csv, min_plddt_percent):
return data_csv[data_csv.num_confident_plddt > min_plddt_percent]
def _max_coil_filter(data_csv, max_coil_percent):
return data_csv[data_csv.coil_percent <= max_coil_percent]
def _process_csv_row(processed_file_path):
    """Load one processed structure pickle and convert it to model features.

    The pickle is read with ``du.read_pkl`` and parsed with
    ``du.parse_chain_feats``. Every feature is then cropped to the span
    ``[min(modeled_idx), max(modeled_idx)]``, and backbone frames are built
    with OpenFold's ``atom37_to_frames`` (rigid group 0 is used).

    Residue indices are renumbered per chain so each chain starts at 1. Chain
    ids are randomly permuted using the global ``random`` module.

    Args:
        processed_file_path: Path to the processed pickle.

    Returns:
        Dict with keys ``res_plddt``, ``aatype``, ``rotmats_1``,
        ``rotquats_1``, ``trans_1``, ``res_mask``, ``chain_idx`` and
        ``res_idx``.

    Raises:
        ValueError: If the translations or rotation matrices contain NaNs.
    """
    processed_feats = du.read_pkl(processed_file_path)
    processed_feats = du.parse_chain_feats(processed_feats)

    # Only take modeled residues.
    modeled_idx = processed_feats['modeled_idx']
    min_idx = np.min(modeled_idx)
    max_idx = np.max(modeled_idx)
    del processed_feats['modeled_idx']
    processed_feats = tree.map_structure(
        lambda x: x[min_idx:(max_idx + 1)], processed_feats)

    # Run through OpenFold data transforms.
    chain_feats = {
        'aatype': torch.tensor(processed_feats['aatype']).long(),
        'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
        'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double()
    }
    chain_feats = data_transforms.atom37_to_frames(chain_feats)
    rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0]
    rotmats_1 = rigids_1.get_rots().get_rot_mats()
    rotquats_1 = rigids_1.get_rots().get_quats()
    trans_1 = rigids_1.get_trans()
    # Column 1 of b_factors -- presumably the CA atom in atom37 order; confirm.
    res_plddt = processed_feats['b_factors'][:, 1]
    res_mask = torch.tensor(processed_feats['bb_mask']).int()

    # Re-number residue indices for each chain such that it starts from 1.
    # Randomize chain indices.
    chain_idx = processed_feats['chain_index']
    res_idx = processed_feats['residue_index']
    new_res_idx = np.zeros_like(res_idx)
    new_chain_idx = np.zeros_like(res_idx)
    all_chain_idx = np.unique(chain_idx).tolist()
    shuffled_chain_idx = np.array(
        random.sample(all_chain_idx, len(all_chain_idx))) - np.min(all_chain_idx) + 1
    for i, chain_id in enumerate(all_chain_idx):
        in_chain = chain_idx == chain_id
        chain_mask = in_chain.astype(int)
        # Take the minimum over this chain's residues only. An additive 1e3
        # offset on the other chains breaks once residue indices reach 1000.
        chain_min_idx = np.min(res_idx[in_chain]).astype(int)
        new_res_idx = new_res_idx + (res_idx - chain_min_idx + 1) * chain_mask

        # Shuffle chain_index
        replacement_chain_id = shuffled_chain_idx[i]
        new_chain_idx = new_chain_idx + replacement_chain_id * chain_mask

    if torch.isnan(trans_1).any() or torch.isnan(rotmats_1).any():
        raise ValueError(f'Found NaNs in {processed_file_path}')
    return {
        'res_plddt': res_plddt,
        'aatype': chain_feats['aatype'],
        'rotmats_1': rotmats_1,
        'rotquats_1': rotquats_1,
        'trans_1': trans_1,
        'res_mask': res_mask,
        'chain_idx': new_chain_idx,
        'res_idx': new_res_idx,
    }
def _add_plddt_mask(feats, plddt_threshold):
feats['plddt_mask'] = torch.tensor(
feats['res_plddt'] > plddt_threshold).int()
def _read_clusters(cluster_path):
pdb_to_cluster = {}
with open(cluster_path, "r") as f:
for i,line in enumerate(f):
for chain in line.split(' '):
pdb = chain.split('_')[0]
pdb_to_cluster[pdb.upper()] = i
return pdb_to_cluster
class RectifyDataset(Dataset):
    """Dataset of paired ``'sample'`` / ``'noise'`` structures.

    The metadata CSV at ``dataset_cfg.rectify_csv_path`` is filtered in three
    steps:

    1. Keep only ``pdb_name`` groups whose ``type`` values are exactly
       ``{'sample', 'noise'}``.
    2. Apply length / chain-count / coil / radius-of-gyration filters to the
       ``'sample'`` rows.
    3. Sort by ``modeled_seq_len`` (descending) and build the training or
       validation split in ``self.csv``.
    """

    def __init__(
        self,
        *,
        dataset_cfg,
        is_training,
        task,
    ):
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.task = task
        self.raw_csv = pd.read_csv(self._dataset_cfg.rectify_csv_path)
        metadata_csv = self._filter_paired_metadata()
        metadata_csv = self._filter_metadata(metadata_csv)
        metadata_csv = metadata_csv.sort_values(
            'modeled_seq_len', ascending=False)
        self._create_split(metadata_csv)  # Populates self.csv.
        # Order follows self.csv (longest first); training indexes into this.
        self.paired_pdb_names = self.csv['pdb_name'].unique()
        # path -> processed row; only rows longer than cfg.cache_num_res. Unbounded.
        self._cache = {}
        self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)

    @property
    def is_training(self):
        return self._is_training

    @property
    def dataset_cfg(self):
        return self._dataset_cfg

    def _filter_paired_metadata(self):
        """Keep only ``pdb_name`` groups that have both a sample and a noise row.

        A group is kept when the set of its ``type`` values is exactly
        ``{'sample', 'noise'}``.

        Returns:
            A copy of the matching rows of ``self.raw_csv``.
        """
        grouped = self.raw_csv.groupby('pdb_name')
        paired_pdbs = grouped.filter(
            lambda x: set(x['type']) == {'sample', 'noise'})['pdb_name'].unique()
        filtered_csv = self.raw_csv[self.raw_csv['pdb_name'].isin(paired_pdbs)].copy()
        self._log.info(f'Found {len(paired_pdbs)} paired pdbs.')
        return filtered_csv

    def __len__(self):
        """Return the number of valid indices for ``__getitem__``.

        Training items are indexed per paired ``pdb_name``. Each name owns at
        least one sample and one noise row in ``self.csv``, so counting csv
        rows would hand out indices past the end of ``paired_pdb_names``.
        Validation items are indexed per row of ``self.csv``.
        """
        if self.is_training:
            return len(self.paired_pdb_names)
        return len(self.csv)

    def _filter_metadata(self, meta_csv: pd.DataFrame) -> pd.DataFrame:
        """Filter pdbs on their ``'sample'`` rows and keep all rows of survivors.

        The filters are applied in order: length, ``num_chains``, maximum coil
        fraction, and radius of gyration. Every returned row gets
        ``oligomeric_detail = 'monomeric'``.

        Args:
            meta_csv: Paired metadata (sample and noise rows).

        Returns:
            A new DataFrame with all rows (any ``type``) of the surviving pdbs.
        """
        filter_cfg = self._dataset_cfg.filter
        sample_csv = meta_csv[meta_csv.type == 'sample']
        data_csv = _length_filter(
            sample_csv,
            filter_cfg.min_num_res,
            filter_cfg.max_num_res
        )
        data_csv = data_csv[
            data_csv.num_chains.isin(filter_cfg.num_chains)]
        data_csv = _max_coil_filter(data_csv, filter_cfg.max_coil_percent)
        data_csv = _rog_filter(data_csv, filter_cfg.rog_quantile)
        valid_pdb_names = data_csv['pdb_name'].unique()
        # Copy so the column assignment below does not write into a view.
        filtered_csv = meta_csv[meta_csv['pdb_name'].isin(valid_pdb_names)].copy()
        filtered_csv['oligomeric_detail'] = 'monomeric'
        return filtered_csv

    def _create_split(self, data_csv):
        """Populate ``self.csv`` with the training or validation rows.

        Training uses ``data_csv`` as is.

        Validation works in three steps:

        1. Restrict lengths to ``[min_eval_length, max_eval_length]``; either
           bound may be ``None``.
        2. Pick ``num_eval_lengths`` evenly spaced distinct lengths.
        3. Draw ``samples_per_eval_length`` rows per length, with replacement
           and a fixed seed.

        If no lengths remain, the validation split is empty. In both modes an
        ``index`` column holding ``0..len-1`` is added.
        """
        if self.is_training:
            self.csv = data_csv
            self._log.info(
                f'Training: {len(self.csv)} examples')
        else:  # Validation
            cfg = self._dataset_cfg
            eval_lengths = data_csv.modeled_seq_len
            if cfg.min_eval_length is not None:
                eval_lengths = eval_lengths[eval_lengths >= cfg.min_eval_length]
            if cfg.max_eval_length is not None:
                eval_lengths = eval_lengths[eval_lengths <= cfg.max_eval_length]
            all_lengths = np.sort(eval_lengths.unique())
            if len(all_lengths) == 0:
                # (len - 1) * linspace would produce index -1 into an empty array.
                eval_lengths = all_lengths
                eval_csv = data_csv.iloc[0:0].copy()
            else:
                length_indices = (len(all_lengths) - 1) * np.linspace(
                    0.0, 1.0, cfg.num_eval_lengths)
                length_indices = length_indices.astype(int)
                eval_lengths = all_lengths[length_indices]
                eval_csv = data_csv[data_csv.modeled_seq_len.isin(eval_lengths)]
                # Fix a random seed to get the same split each time.
                eval_csv = eval_csv.groupby('modeled_seq_len').sample(
                    cfg.samples_per_eval_length,
                    replace=True,
                    random_state=123
                )
                eval_csv = eval_csv.sort_values('modeled_seq_len', ascending=False)
            self.csv = eval_csv
            self._log.info(
                f'Validation: {len(self.csv)} examples with lengths {eval_lengths}')
        self.csv['index'] = list(range(len(self.csv)))

    def process_csv_row(self, csv_row):
        """Load the features for one csv row, memoizing large structures.

        Rows whose ``modeled_seq_len`` exceeds ``cache_num_res`` are cached by
        ``processed_path`` in ``self._cache``. Cached entries are handed out
        as shallow copies, so callers that add or replace keys (as
        ``process_feats`` and ``__getitem__`` do) never modify the cache.
        """
        path = csv_row['processed_path']
        seq_len = csv_row['modeled_seq_len']
        # Large protein files are slow to read. Cache them.
        use_cache = seq_len > self._dataset_cfg.cache_num_res
        if use_cache and path in self._cache:
            return dict(self._cache[path])
        processed_row = _process_csv_row(path)
        if use_cache:
            self._cache[path] = processed_row
            return dict(processed_row)
        return processed_row

    def _sample_scaffold_mask(self, batch, rng):
        """Sample a random scaffold mask (1 = scaffold, 0 = motif).

        All randomness comes from ``rng``, so a seeded generator yields a
        reproducible mask. Degenerate ranges on very short chains are widened
        so ``Generator.integers`` never sees ``high <= low``.

        Args:
            batch: Dict with ``trans_1`` (first dim = residues) and ``res_mask``.
            rng: ``numpy.random.Generator``.

        Returns:
            Float tensor ``scaffold_mask * batch['res_mask']``.
        """
        trans_1 = batch['trans_1']
        num_res = trans_1.shape[0]
        min_motif_size = int(self._dataset_cfg.min_motif_percent * num_res)
        max_motif_size = int(self._dataset_cfg.max_motif_percent * num_res)

        # Sample the total number of residues that will be used as the motif.
        total_motif_size = rng.integers(
            low=min_motif_size,
            high=max(max_motif_size, min_motif_size + 1)
        )

        # Sample motifs at different locations.
        num_motifs = rng.integers(low=1, high=max(total_motif_size, 2))

        # Attempt to sample
        attempt = 0
        while attempt < 100:
            # Sample lengths of each motif.
            motif_lengths = np.sort(
                rng.integers(
                    low=1,
                    high=max(max_motif_size, 2),
                    size=(num_motifs,)
                )
            )

            # Truncate motifs to not go over the motif length.
            cumulative_lengths = np.cumsum(motif_lengths)
            motif_lengths = motif_lengths[cumulative_lengths < total_motif_size]
            if len(motif_lengths) > 0:
                break
            attempt += 1
        if len(motif_lengths) == 0:
            motif_lengths = [total_motif_size]

        # Sample start location of each motif.
        seed_residues = rng.integers(
            low=0,
            high=max(num_res - 1, 1),
            size=(len(motif_lengths),)
        )

        # Construct the motif mask.
        motif_mask = torch.zeros(num_res)
        for motif_seed, motif_len in zip(seed_residues, motif_lengths):
            motif_mask[motif_seed:min(motif_seed + motif_len, num_res)] = 1.0
        scaffold_mask = 1 - motif_mask
        return scaffold_mask * batch['res_mask']

    def setup_inpainting(self, feats, rng):
        """Set ``feats['diffuse_mask']`` from a sampled scaffold mask.

        The scaffold mask is combined with ``plddt_mask`` when present. If
        nothing is left to diffuse, the mask falls back to all ones.
        """
        diffuse_mask = self._sample_scaffold_mask(feats, rng)
        if 'plddt_mask' in feats:
            diffuse_mask = diffuse_mask * feats['plddt_mask']
        if torch.sum(diffuse_mask) < 1:
            # Should only happen rarely.
            diffuse_mask = torch.ones_like(diffuse_mask)
        feats['diffuse_mask'] = diffuse_mask

    def __getitem__(self, idx):
        """Return ``{'sample': feats, 'noise': feats}`` for item ``idx``.

        In training, ``idx`` indexes ``paired_pdb_names``; the first
        ``'sample'`` row and the first ``'noise'`` row of that pdb are used.
        In validation, ``idx`` indexes ``self.csv``; the first row with the
        same ``pdb_name`` (whatever its ``type``) is used for both sides.

        Raises:
            ValueError: If the sample or noise row for the pdb is missing.
        """
        if self.is_training:
            pdb_name = self.paired_pdb_names[idx]
            pdb_rows = self.csv[self.csv['pdb_name'] == pdb_name]
            sample_rows = pdb_rows[pdb_rows['type'] == 'sample']
            noise_rows = pdb_rows[pdb_rows['type'] == 'noise']
        else:
            pdb_name = self.csv.iloc[idx]['pdb_name']
            sample_rows = self.csv[self.csv['pdb_name'] == pdb_name]
            noise_rows = sample_rows

        # Check before .iloc[0], which would raise IndexError on an empty frame.
        if sample_rows.empty or noise_rows.empty:
            raise ValueError(f"Missing sample or noise row for pdb_name: {pdb_name}")
        sample_row = sample_rows.iloc[0]
        noise_row = noise_rows.iloc[0]

        # Keep load-then-process order; both steps draw from the global `random`.
        sample_feats = self.process_csv_row(sample_row)
        noise_feats = self.process_csv_row(noise_row)
        sample_feats = self.process_feats(sample_feats)
        noise_feats = self.process_feats(noise_feats)

        # Storing the csv index is helpful for debugging.
        if self.is_training:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx)
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx + 1)
        else:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
        return {'sample': sample_feats, 'noise': noise_feats}

    def process_feats(self, feats):
        """Add ``plddt_mask`` and ``diffuse_mask`` to ``feats`` and return it.

        For ``task == 'inpainting'`` (with probability ``inpainting_percent``),
        a motif is sampled and ``trans_1`` is re-centered on the motif. The
        re-centered tensor is a new tensor; the input tensor is not modified.

        Raises:
            ValueError: For an unknown ``self.task``.
        """
        if self._dataset_cfg.add_plddt_mask:
            _add_plddt_mask(feats, self._dataset_cfg.min_plddt_threshold)
        else:
            feats['plddt_mask'] = torch.ones_like(feats['res_mask'])

        if self.task == 'hallucination':
            feats['diffuse_mask'] = torch.ones_like(feats['res_mask']).bool()
        elif self.task == 'inpainting':
            if self._dataset_cfg.inpainting_percent < random.random():
                feats['diffuse_mask'] = torch.ones_like(feats['res_mask'])
            else:
                # Fixed seed in validation for reproducible masks.
                rng = self._rng if self.is_training else np.random.default_rng(seed=123)
                self.setup_inpainting(feats, rng)
                # Center based on motif locations
                motif_mask = 1 - feats['diffuse_mask']
                trans_1 = feats['trans_1']
                motif_1 = trans_1 * motif_mask[:, None]
                # The +1 keeps the division finite when the motif is empty.
                motif_com = torch.sum(motif_1, dim=0) / (torch.sum(motif_mask) + 1)
                # Out-of-place: trans_1 may be shared with the row cache.
                feats['trans_1'] = trans_1 - motif_com[None, :]
        else:
            raise ValueError(f'Unknown task {self.task}')
        feats['diffuse_mask'] = feats['diffuse_mask'].int()
        return feats