mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Use loggers
Add utilities for training on a single-entry dataset. Allow validation skipping. WIP AF3 Non-equivariant structure encoder/decoder Add flag to force training from scratch Force training from scratch in debug config All modules in diffusion module implemented Document behavior of dropout with test Finish majority of model trunk Convert some ModuleLists to nn.Sequential Add RelativePositionEncoding and WIP af3_repro config Fix ref_space_uid embedding in AtomEncoder Put Model together with fake MSAModule and TemplateEmbedder AF3 repro loads model. WIP af3 data-adaptor, AF3_structure fixes Feature initializer working Standardize S_inputs_I Fix pairformer stack Forward pass working, WIP: backward pass stale reference fixing Add dataloader_adaptor_af3.py Backward pass working, WIP: still some unused params Backprop working Training runs Add pytorch lightning training and some wandb logging Training converging for single example. Run: /home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small logger.use_wandb=True af3_data_prep.D=6 Log loss Training working for single example. Run: /home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small_working_4 logger.use_wandb=True on an a4000 Add test_diffusion_module.py
This commit is contained in:
committed by
Rohith Krishna
parent
d33c097ff5
commit
11101963df
228
data/subsample_dataset_pickle.ipynb
Normal file
228
data/subsample_dataset_pickle.ipynb
Normal file
@@ -0,0 +1,228 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9986c90c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Dataset subsampling\n",
|
||||
"\n",
|
||||
"This notebook can be used to subsample a dataset pickle to debug training runs, especially for new archictures by attempting to train the architecture on a single example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "32cf6d8c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"import pickle\n",
|
||||
"import gc\n",
|
||||
"import copy\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"root_dir = os.path.join(\n",
|
||||
" os.getcwd(),\n",
|
||||
" '..',\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"sys.path.append(root_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "960c344d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the existing large dataset\n",
|
||||
"\n",
|
||||
"with open(os.path.join(root_dir, 'rf2aa/dataset_20240318.pkl'), 'rb') as fh:\n",
|
||||
" dataset_clean = pickle.load(fh)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "31f55497",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ii=1154 len(dici[\"PARTNERS\"])=1 dici[\"LEN_EXIST\"]=65\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Define filters\n",
|
||||
"\n",
|
||||
"def get_index_passing_filter(train_dict, f):\n",
|
||||
" for i, row in train_dict.iterrows():\n",
|
||||
" if f(row):\n",
|
||||
" return i\n",
|
||||
" raise Exception('No index passing filter')\n",
|
||||
"\n",
|
||||
"def get_short_single_ligand(row):\n",
|
||||
" if row['LEN_EXIST'] > 100:\n",
|
||||
" return False\n",
|
||||
" if len(row['PARTNERS']) != 1:\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"ii = get_index_passing_filter(dataset_clean['train_dict']['sm_compl'], get_short_single_ligand)\n",
|
||||
"dici = dataset_clean['train_dict']['sm_compl'].loc[ii]\n",
|
||||
"print(f'{ii=} {len(dici[\"PARTNERS\"])=} {dici[\"LEN_EXIST\"]=}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "04c40e33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Subsample the dataset\n",
|
||||
"\n",
|
||||
"def subsample_v2(p, n_per_dataset=1, filter_by_dataset=dict()):\n",
|
||||
" assert n_per_dataset==1\n",
|
||||
" train_ID_dict = p['train_ID_dict']\n",
|
||||
" valid_ID_dict = p['valid_ID_dict']\n",
|
||||
" weights_dict = p['weights_dict']\n",
|
||||
" train_dict = p['train_dict']\n",
|
||||
" valid_dict = p['valid_dict']\n",
|
||||
" homo = p['homo']\n",
|
||||
" chid2hash = p['chid2hash']\n",
|
||||
" chid2taxid = p['chid2taxid']\n",
|
||||
" chid2smpartners = p['chid2smpartners']\n",
|
||||
"\n",
|
||||
" all_chain_ids = set()\n",
|
||||
" datasets = set(train_ID_dict.keys())\n",
|
||||
" \n",
|
||||
" for k, v in train_ID_dict.items():\n",
|
||||
" if k in datasets:\n",
|
||||
" v = v[:n_per_dataset]\n",
|
||||
" dic = train_dict[k]\n",
|
||||
" row_filter = filter_by_dataset.get(k, lambda x: True)\n",
|
||||
" i = [get_index_passing_filter(train_dict[k], row_filter)]\n",
|
||||
" dic = dic.loc[i]\n",
|
||||
" v = np.array(dic.loc[i]['CLUSTER'])\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" v = []\n",
|
||||
" dic = train_dict[k][0:0]\n",
|
||||
" weights_dict[k] = []\n",
|
||||
" if 'CHAINID' in dic:\n",
|
||||
" chain_ids = dic['CHAINID']\n",
|
||||
" for ch_id in chain_ids:\n",
|
||||
" all_chain_ids.add(ch_id)\n",
|
||||
" for ch_id in ch_id.split(':'):\n",
|
||||
" all_chain_ids.add(ch_id)\n",
|
||||
"\n",
|
||||
" train_ID_dict[k] = v\n",
|
||||
" train_dict[k] = dic\n",
|
||||
" weights_dict[k] = weights_dict[k][:n_per_dataset]\n",
|
||||
" \n",
|
||||
" for k, v in valid_ID_dict.items():\n",
|
||||
" v = v[:n_per_dataset]\n",
|
||||
" dic = valid_dict[k]\n",
|
||||
" dic = dic[dic.CLUSTER.isin(v)].reset_index(drop=True)\n",
|
||||
" if 'CHAINID' in dic:\n",
|
||||
" chain_ids = dic['CHAINID']\n",
|
||||
" for ch_id in chain_ids:\n",
|
||||
" all_chain_ids.add(ch_id)\n",
|
||||
" for ch_id in ch_id.split(':'):\n",
|
||||
" all_chain_ids.add(ch_id)\n",
|
||||
"\n",
|
||||
" valid_ID_dict[k] = v\n",
|
||||
" valid_dict[k] = dic\n",
|
||||
" \n",
|
||||
" homo = homo.sample(n_per_dataset)\n",
|
||||
" chid2hash = {k:v for k,v in chid2hash.items() if k in all_chain_ids}\n",
|
||||
" chid2taxid = {k:v for k,v in chid2taxid.items() if k in all_chain_ids}\n",
|
||||
" chid2smpartners = {k:v for k,v in chid2smpartners.items() if k in all_chain_ids}\n",
|
||||
" return dict(\n",
|
||||
" train_ID_dict=train_ID_dict,\n",
|
||||
" valid_ID_dict=valid_ID_dict,\n",
|
||||
" weights_dict=weights_dict,\n",
|
||||
" train_dict=train_dict,\n",
|
||||
" valid_dict=valid_dict,\n",
|
||||
" homo=homo,\n",
|
||||
" chid2hash=chid2hash,\n",
|
||||
" chid2taxid=chid2taxid,\n",
|
||||
" chid2smpartners=chid2smpartners,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"dataset = copy.deepcopy(dataset_clean)\n",
|
||||
"dataset_subsampled = subsample_v2(dataset, filter_by_dataset={'sm_compl': get_short_single_ligand})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "3a4c76c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Override the train dict with the valid dict\n",
|
||||
"def use_train_as_valid(dataset):\n",
|
||||
" dataset['valid_ID_dict'] = {k: [] for k,v in dataset['valid_ID_dict'].items()}\n",
|
||||
" dataset['valid_ID_dict'].update(dataset['train_ID_dict'])\n",
|
||||
" dataset['valid_dict'] = {k: v.iloc[0:0] for k,v in dataset['valid_dict'].items()}\n",
|
||||
" dataset['valid_dict'].update(dataset['train_dict'])\n",
|
||||
"\n",
|
||||
"use_train_as_valid(dataset_subsampled)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "c7c8cd07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Write the subsampled dataset pickle, which can now be used to debug architectures\n",
|
||||
"# by specifying loader_param.datapkl=YOUR_PKL_PATH\n",
|
||||
"\n",
|
||||
"dataset_subsampled_path = os.path.join(root_dir, 'rf2aa/subsampled/dataset_20240318_n-2.pkl')\n",
|
||||
"assert not os.path.exists(dataset_subsampled_path)\n",
|
||||
"with open(dataset_subsampled_path, 'wb') as fh:\n",
|
||||
" pickle.dump(dataset_subsampled, fh)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "343b8cca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
47
rf2aa/alignment.py
Normal file
47
rf2aa/alignment.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from rf2aa.debug import pretty_describe_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def weighted_rigid_align(
|
||||
X_L, # [B, L, 3]
|
||||
X_gt_L, # [B, L, 3]
|
||||
w_L, # [B, L]
|
||||
):
|
||||
'''
|
||||
Weighted rigid body alignment of X_L onto X_gt_L
|
||||
Returns:
|
||||
X_align_L: [B, L, 3]
|
||||
'''
|
||||
assert X_L.shape == X_gt_L.shape
|
||||
assert X_L.shape[:-1] == w_L.shape
|
||||
|
||||
u_X = torch.mean(X_L * w_L.unsqueeze(-1), dim=-2) / torch.mean(w_L, dim=-1, keepdim=True)
|
||||
u_X_gt = torch.mean(X_gt_L * w_L.unsqueeze(-1), dim=-2) / torch.mean(w_L, dim=-1, keepdim=True)
|
||||
|
||||
X_L = X_L - u_X.unsqueeze(-2)
|
||||
X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2)
|
||||
|
||||
# Computation of the covariance matrix
|
||||
C = torch.transpose(X_gt_L, -1, -2) @ X_L
|
||||
|
||||
U, S, V = torch.linalg.svd(C)
|
||||
|
||||
R = U @ V
|
||||
B, _, _ = X_L.shape
|
||||
F = torch.eye(3,3, device=X_L.device)[None].tile((B,1,1,))
|
||||
|
||||
F[...,-1, -1] = torch.sign(torch.linalg.det(R))
|
||||
R = U @ F @ V
|
||||
|
||||
X_align_L = X_L @ R.transpose(-1, -2) + u_X_gt.unsqueeze(-2)
|
||||
|
||||
return X_align_L.detach()
|
||||
|
||||
def get_rmsd(xyz1, xyz2, eps=1e-4):
|
||||
L = xyz1.shape[-2]
|
||||
rmsd = torch.sqrt(torch.sum((xyz2-xyz1)*(xyz2-xyz1), axis=(-1, -2)) / L + eps)
|
||||
return rmsd
|
||||
200
rf2aa/callbacks.py
Normal file
200
rf2aa/callbacks.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import os
|
||||
import csv
|
||||
import shutil
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
from icecream import ic
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tree
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning import Trainer, LightningModule
|
||||
from lightning_fabric.loggers.csv_logs import _ExperimentWriter
|
||||
|
||||
from rf2aa.tensor_util import apply_to_tensors
|
||||
from rf2aa.debug import pretty_describe_dict
|
||||
from rf2aa.model.AF3_structure import Loss
|
||||
from rf2aa import pymol_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def flatten_dictionary(dictionary, parent_key='', separator='.'):
|
||||
flattened_dict = {}
|
||||
for key, value in dictionary.items():
|
||||
new_key = f"{parent_key}{separator}{key}" if parent_key else key
|
||||
if isinstance(value, dict):
|
||||
flattened_dict.update(flatten_dictionary(value, new_key, separator))
|
||||
else:
|
||||
flattened_dict[new_key] = value
|
||||
return flattened_dict
|
||||
|
||||
class LogMetrics(Callback):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int) -> None:
|
||||
|
||||
logger.debug('on_train_batch_end outputs:\n' + pretty_describe_dict(outputs))
|
||||
|
||||
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
|
||||
o = {}
|
||||
stratifications = defaultdict(list)
|
||||
for metric in [diffusion_losses]:
|
||||
metric_d, stratification_keys = metric(self.config, outputs)
|
||||
stratifications[stratification_keys].extend(metric_d.keys())
|
||||
o.update(metric_d)
|
||||
|
||||
o['t'] = outputs['t']
|
||||
o['t_quantile_4'] = get_t_quantiles(outputs['t'], self.config.loss.sigma_data, 4)
|
||||
df = pd.DataFrame.from_dict(o)
|
||||
df = df.reindex(sorted(df.columns), axis=1)
|
||||
|
||||
D, = outputs['t'].shape
|
||||
df['batch_idx'] = batch_idx
|
||||
df['data_idx'] = np.arange(D)
|
||||
df['global_step'] = trainer.global_step
|
||||
trainer.logger.log_df(df, stratifications=stratifications)
|
||||
|
||||
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
|
||||
|
||||
|
||||
def diffusion_losses(config, outputs):
|
||||
|
||||
loss = Loss(**config.loss)
|
||||
|
||||
loss_dict_by_type = {}
|
||||
t = outputs['t']
|
||||
X_noisy_L = outputs['X_noisy_L']
|
||||
sigma_data = 16
|
||||
|
||||
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[...,None,None] * X_noisy_L
|
||||
|
||||
sigma_gt = torch.var(outputs['X_gt_L'], dim=(1,2))**0.5
|
||||
for input_type, X_L in (
|
||||
('pred', outputs['X_L']),
|
||||
# ('input', outputs['X_noisy_L']),
|
||||
('true', outputs['X_gt_L']),
|
||||
('null_pred', null_pred),
|
||||
):
|
||||
l_total, _, loss_dict_batched = loss(
|
||||
outputs['f'],
|
||||
X_L,
|
||||
outputs['X_gt_L'],
|
||||
outputs['t'],
|
||||
)
|
||||
# loss_dict_by_type[input_type] = loss_dict_batched
|
||||
loss_dict_batched_prefixed = {f'{k}.{input_type}':v for k,v in loss_dict_batched.items()}
|
||||
loss_dict_by_type.update(loss_dict_batched_prefixed)
|
||||
|
||||
# Correcting for EDM : AF3 lambda conversion
|
||||
edm_corr = (t+loss.sigma_data)**2 / (t*loss.sigma_data)**2
|
||||
loss_dict_batched_edm = {k:v * edm_corr for k,v in loss_dict_batched.items()}
|
||||
loss_dict_batched_prefixed_edm = {f'{k}_edm.{input_type}':v for k,v in loss_dict_batched_edm.items()}
|
||||
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
|
||||
|
||||
# Correcting for Var(gt) != sigma_data
|
||||
expected_loss_gt = 1 / (loss.sigma_data**2 + t**2) * (loss.sigma_data**2 + t**2 * sigma_gt**2 / loss.sigma_data**2)
|
||||
loss_dict_batched_edm_gt_corr = {k: edm_corr * v / expected_loss_gt for k,v in loss_dict_batched.items()}
|
||||
loss_dict_batched_prefixed_edm = {f'{k}_edm_gt_corr.{input_type}':v for k,v in loss_dict_batched_edm_gt_corr.items()}
|
||||
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
|
||||
|
||||
|
||||
o = flatten_dictionary(loss_dict_by_type)
|
||||
o['pred_over_null_pred'] = o['diffusion_loss.pred'] / o['diffusion_loss.null_pred']
|
||||
o['pred_over_null_pred_norm'] = o['diffusion_loss_edm_gt_corr.pred'] / o['diffusion_loss_edm_gt_corr.null_pred']
|
||||
return o, ('t_quantile_4',)
|
||||
|
||||
def get_normal_quantiles(n):
|
||||
# Generate n evenly spaced probabilities between 0 and 1
|
||||
probabilities = np.linspace(0, 1, n)
|
||||
# Use the percent point function (inverse CDF) of the standard normal distribution
|
||||
return norm.ppf(probabilities)
|
||||
|
||||
def get_t_quantiles(t, sigma_data, n):
|
||||
bins = sigma_data * np.exp(-1.2 + 1.5 * get_normal_quantiles(n+1))
|
||||
t_binned_list = []
|
||||
for t in t:
|
||||
t_bin = np.digitize(t, bins) - 1
|
||||
bin_start = bins[t_bin]
|
||||
bin_end = bins[t_bin+1]
|
||||
t_binned = f't=[{bin_start:.2f},{bin_end:.2f})'
|
||||
t_binned_list.append(t_binned)
|
||||
return t_binned_list
|
||||
|
||||
class NetworkOutputGradSanityCheck(Callback):
|
||||
def __init__(self, call_n_times=0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.call_n_times = call_n_times
|
||||
self.call_count = 0
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
|
||||
if self.call_count < self.call_n_times:
|
||||
self.call_count += 1
|
||||
r_projection_weight = pl_module.model.model.diffusion_module.atom_attention_decoder.to_r_update[1].weight
|
||||
ic(
|
||||
torch.linalg.norm(r_projection_weight) if r_projection_weight is not None else None,
|
||||
torch.linalg.norm(r_projection_weight.grad) if r_projection_weight.grad is not None else None,
|
||||
)
|
||||
|
||||
class MonitorActivations(Callback):
|
||||
|
||||
def make_hook(self, label):
|
||||
def hook(module, args, kwargs, output):
|
||||
activation_metrics = {
|
||||
|
||||
f'{label}:inter_batch_cosine_similarity': F.cosine_similarity(
|
||||
torch.flatten(output[0]),
|
||||
torch.flatten(output[1]),
|
||||
dim=0,
|
||||
),
|
||||
f'{label}:inter_batch_cosine_similarity': F.cosine_similarity(
|
||||
torch.flatten(output[0]),
|
||||
torch.flatten(output[1]),
|
||||
dim=0,
|
||||
),
|
||||
f'{label}:intra_batch_cosine_similarity_to_elem_0': F.cosine_similarity(
|
||||
output[0][0:1],
|
||||
output[0],
|
||||
).mean(),
|
||||
}
|
||||
self.log_dict(activation_metrics)
|
||||
return hook
|
||||
|
||||
|
||||
def setup(self, trainer, pl_module, stage):
|
||||
self.pl_module = pl_module
|
||||
self.trainer = trainer
|
||||
|
||||
pl_module.model.model.diffusion_module.atom_attention_decoder.register_forward_hook(
|
||||
self.make_hook(
|
||||
'diffusion_module.atom_attention_decoder',
|
||||
),
|
||||
with_kwargs=True
|
||||
)
|
||||
|
||||
class FindUnusedParameters(Callback):
|
||||
def __init__(self, only_once=True, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.only_once = only_once
|
||||
self.called = False
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
if self.called and self.only_once:
|
||||
return
|
||||
self.called=True
|
||||
# Calculate unused parameters after each batch
|
||||
unused_params = [name for name, param in pl_module.named_parameters() if param.grad is None]
|
||||
|
||||
# Log unused parameters
|
||||
logging.info(f'global_step={pl_module.global_step}: parameters with no gradient: {json.dumps(unused_params, indent=4)}')
|
||||
if unused_params:
|
||||
raise Exception('storp')
|
||||
193
rf2aa/config/train/af3_repro.yaml
Normal file
193
rf2aa/config/train/af3_repro.yaml
Normal file
@@ -0,0 +1,193 @@
|
||||
defaults:
|
||||
# - af3
|
||||
- rf2aa
|
||||
- _self_
|
||||
|
||||
experiment:
|
||||
name: rf2aa-af3-repro
|
||||
trainer: af3_repro
|
||||
output_dir: null
|
||||
|
||||
loss:
|
||||
sigma_data: ${model.diffusion_module.sigma_data}
|
||||
alpha_dna: 5
|
||||
alpha_rna: 5
|
||||
alpha_ligand: 10
|
||||
edm_lambda: False
|
||||
se3_invariant_loss: True
|
||||
|
||||
ddp_params:
|
||||
batch_size: 1
|
||||
|
||||
loader_params:
|
||||
p_msa_mask: 0.0
|
||||
|
||||
interpolant:
|
||||
sigma_data: 16
|
||||
|
||||
min_t: 1e-2
|
||||
separate_t: False
|
||||
provide_kappa: False
|
||||
hierarchical_t: False
|
||||
codesign_separate_t: False
|
||||
codesign_forward_fold_prop: 0.0
|
||||
codesign_inverse_fold_prop: 0.0
|
||||
|
||||
twisting:
|
||||
use: False
|
||||
|
||||
rots:
|
||||
corrupt: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
exp_rate: 10
|
||||
|
||||
trans:
|
||||
corrupt: True
|
||||
batch_ot: True
|
||||
train_schedule: linear
|
||||
sample_schedule: linear
|
||||
sample_temp: 1.0
|
||||
vpsde_bmin: 0.1
|
||||
vpsde_bmax: 20.0
|
||||
potential: null
|
||||
potential_t_scaling: False
|
||||
rog:
|
||||
weight: 10.0
|
||||
cutoff: 5.0
|
||||
|
||||
aatypes:
|
||||
corrupt: False
|
||||
schedule: linear
|
||||
schedule_exp_rate: 10
|
||||
temp: 1.0
|
||||
noise: 0.0
|
||||
do_purity: False
|
||||
train_extra_mask: 0.0
|
||||
interpolant_type: masking
|
||||
num_tokens: 80
|
||||
|
||||
sampling:
|
||||
num_timesteps: 20
|
||||
do_sde: False
|
||||
self_condition: False
|
||||
|
||||
model_globals:
|
||||
l_max: 1000
|
||||
model:
|
||||
c_s: 384
|
||||
c_z: 128
|
||||
c_atom: 128
|
||||
c_atompair: 16
|
||||
c_s_inputs: 449
|
||||
feature_initializer:
|
||||
c_s_inputs: ${model.c_s_inputs}
|
||||
input_feature_embedder:
|
||||
features:
|
||||
- restype
|
||||
- profile
|
||||
- deletion_mean
|
||||
atom_attention_encoder:
|
||||
c_token: 384
|
||||
c_atom_1d_features: 389
|
||||
c_tokenpair: ${model.c_z}
|
||||
atom_1d_features:
|
||||
- ref_pos
|
||||
- ref_charge
|
||||
- ref_mask
|
||||
- ref_element
|
||||
- ref_atom_name_chars
|
||||
atom_transformer:
|
||||
n_queries: 32
|
||||
n_keys: 128
|
||||
l_max: ${model_globals.l_max}
|
||||
diffusion_transformer:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
relative_position_encoding:
|
||||
r_max: 32
|
||||
s_max: 2
|
||||
recycler:
|
||||
n_pairformer_blocks: 48
|
||||
pairformer_block:
|
||||
p_drop: 0.25
|
||||
c: 128
|
||||
attention_pair_bias:
|
||||
n_head: 16
|
||||
template_embedder:
|
||||
n_block: 2
|
||||
c: 64
|
||||
msa_module:
|
||||
n_block: 4
|
||||
c_m: 64
|
||||
diffusion_module:
|
||||
sigma_data: ${interpolant.sigma_data}
|
||||
c_token: 768
|
||||
f_pred: edm
|
||||
diffusion_conditioning:
|
||||
c_s_inputs: ${model.c_s_inputs}
|
||||
c_t_embed: 256
|
||||
relative_position_encoding:
|
||||
r_max: 32
|
||||
s_max: 2
|
||||
atom_attention_encoder:
|
||||
c_tokenpair: ${model.c_z}
|
||||
c_atom_1d_features: 389
|
||||
atom_1d_features:
|
||||
- ref_pos
|
||||
- ref_charge
|
||||
- ref_mask
|
||||
- ref_element
|
||||
- ref_atom_name_chars
|
||||
atom_transformer:
|
||||
n_queries: 32
|
||||
n_keys: 128
|
||||
l_max: ${model_globals.l_max}
|
||||
diffusion_transformer:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
diffusion_transformer:
|
||||
n_block: 24
|
||||
diffusion_transformer_block:
|
||||
n_head: 16
|
||||
atom_attention_decoder:
|
||||
atom_transformer:
|
||||
n_queries: 32
|
||||
n_keys: 128
|
||||
l_max: ${model_globals.l_max}
|
||||
diffusion_transformer:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
|
||||
optimizer:
|
||||
type: Adam
|
||||
params:
|
||||
lr: 1.8e-3
|
||||
betas: [0.9, 0.95]
|
||||
eps: 1.0e-8
|
||||
|
||||
logger:
|
||||
save_dir: csv_logs
|
||||
use_wandb: False
|
||||
sublogger:
|
||||
project: af3-debug
|
||||
|
||||
callbacks:
|
||||
log_metrics: {}
|
||||
|
||||
af3_data_prep:
|
||||
D: 12
|
||||
sigma_data: ${model.diffusion_module.sigma_data}
|
||||
s_trans: 1
|
||||
random_augmentation: True
|
||||
only_ca: False
|
||||
|
||||
recycling:
|
||||
max_cycle: 4
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accumulate_grad_batches: 25
|
||||
63
rf2aa/config/train/af3_repro_single_example.yaml
Normal file
63
rf2aa/config/train/af3_repro_single_example.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
defaults:
|
||||
- af3_repro
|
||||
- _self_
|
||||
|
||||
loader_params:
|
||||
datapkl: /home/ahern/projects/RF2-allatom/rf2aa/subsampled/dataset_20240318_n-1.pkl
|
||||
no_match_okay: True
|
||||
|
||||
training_params:
|
||||
from_scratch: True
|
||||
learning_rate_schedule:
|
||||
decay_rate: 1.
|
||||
|
||||
dataset_params:
|
||||
n_train: 500
|
||||
validate_every_n_epochs: 100
|
||||
validate_after_first_epoch: False
|
||||
fraction_pdb: 0.
|
||||
fraction_fb: 0.
|
||||
fraction_compl: 0.
|
||||
fraction_neg_compl: 0.
|
||||
fraction_na_compl: 0.
|
||||
fraction_neg_na_compl: 0.
|
||||
fraction_distil_tf: 0.
|
||||
fraction_tf: 0.
|
||||
fraction_neg_tf: 0.
|
||||
fraction_rna: 0.
|
||||
fraction_dna: 0.
|
||||
fraction_sm_compl: 1.
|
||||
fraction_metal_compl: 0.
|
||||
fraction_sm_compl_multi: 0.
|
||||
fraction_sm_compl_covale: 0.
|
||||
fraction_sm: 0.
|
||||
fraction_atomize_pdb: 0.
|
||||
fraction_atomize_complex: 0.
|
||||
fraction_sm_compl_asmb: 0.
|
||||
|
||||
n_valid_pdb: 0
|
||||
n_valid_homo: 0
|
||||
n_valid_dslf: 0
|
||||
n_valid_compl: 0
|
||||
n_valid_neg_compl: 0
|
||||
n_valid_na_compl: 0
|
||||
n_valid_neg_na_compl: 0
|
||||
n_valid_distil_tf: 0
|
||||
n_valid_tf: 0
|
||||
n_valid_neg_tf: 0
|
||||
n_valid_rna: 0
|
||||
n_valid_dna: 0
|
||||
n_valid_sm_compl: 2
|
||||
n_valid_metal_compl: 0
|
||||
n_valid_sm_compl_multi: 0
|
||||
n_valid_sm_compl_covale: 0
|
||||
n_valid_sm_compl_strict: 0
|
||||
n_valid_sm: 0
|
||||
n_valid_atomize_pdb: 0
|
||||
n_valid_atomize_complex: 0
|
||||
n_valid_sm_compl_asmb: 0
|
||||
n_valid_fb: 0
|
||||
|
||||
log_params:
|
||||
use_wandb: true
|
||||
wandb_project: rf2aa-debug
|
||||
32
rf2aa/config/train/af3_repro_single_example_extra_small.yaml
Normal file
32
rf2aa/config/train/af3_repro_single_example_extra_small.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
defaults:
|
||||
- af3_repro_single_example_small
|
||||
- _self_
|
||||
|
||||
cpu_training: true
|
||||
training_params:
|
||||
ddp_backend: gloo
|
||||
|
||||
model:
|
||||
recycler:
|
||||
n_pairformer_blocks: 2
|
||||
template_embedder:
|
||||
n_block: 2
|
||||
msa_module:
|
||||
n_block: 2
|
||||
diffusion_module:
|
||||
atom_attention_encoder:
|
||||
atom_transformer:
|
||||
diffusion_transformer:
|
||||
n_block: 2
|
||||
diffusion_transformer:
|
||||
n_block: 2
|
||||
atom_attention_decoder:
|
||||
atom_transformer:
|
||||
diffusion_transformer:
|
||||
n_block: 2
|
||||
|
||||
logger:
|
||||
sublogger:
|
||||
project: af3-debug-cpu
|
||||
|
||||
autograd_detect_anomaly: true
|
||||
28
rf2aa/config/train/af3_repro_single_example_small.yaml
Normal file
28
rf2aa/config/train/af3_repro_single_example_small.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
defaults:
|
||||
- af3_repro_single_example
|
||||
- _self_
|
||||
|
||||
model:
|
||||
recycler:
|
||||
n_pairformer_blocks: 10
|
||||
|
||||
recycling:
|
||||
max_cycle: 1
|
||||
|
||||
loader_params:
|
||||
maxcycle: 1
|
||||
# template_embedder:
|
||||
# n_block: 2
|
||||
# msa_module:
|
||||
# n_block: 2
|
||||
# diffusion_module:
|
||||
# atom_attention_encoder:
|
||||
# atom_transformer:
|
||||
# diffusion_transformer:
|
||||
# n_block: 2
|
||||
# diffusion_transformer:
|
||||
# n_block: 2
|
||||
# atom_attention_decoder:
|
||||
# atom_transformer:
|
||||
# diffusion_transformer:
|
||||
# n_block: 2
|
||||
@@ -0,0 +1,12 @@
|
||||
# Status: unvetted
|
||||
|
||||
defaults:
|
||||
- af3_repro_single_example_small
|
||||
- _self_
|
||||
|
||||
af3_data_prep:
|
||||
D: 12
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accumulate_grad_batches: 1
|
||||
@@ -12,6 +12,7 @@ model:
|
||||
legacy_model: null
|
||||
dataset_params:
|
||||
validate_every_n_epochs: 0
|
||||
validate_after_first_epoch: False
|
||||
fraction_pdb: 0
|
||||
fraction_fb: 0
|
||||
fraction_compl: 0
|
||||
@@ -89,6 +90,7 @@ loader_params:
|
||||
min_metal_contacts: 0
|
||||
min_metal_contact_dist: 2.6
|
||||
sampler_class: DistributedWeightedSampler
|
||||
no_match_okay: False # If False, asserts that the dataset params in the cached dataset pickle match the config dataset params
|
||||
dataloader_kwargs:
|
||||
shuffle: False
|
||||
num_workers: 0
|
||||
@@ -146,3 +148,12 @@ chem_params:
|
||||
use_lj_params_for_atoms: False
|
||||
|
||||
metrics: ["mean_pae", "mean_plddt"]
|
||||
|
||||
hydra:
|
||||
job_logging:
|
||||
formatters:
|
||||
simple:
|
||||
format: '[%(asctime)s][%(filename)25s:%(lineno)4s][%(name)30s:%(funcName)30s()][%(levelname)s] %(message)s'
|
||||
|
||||
cpu_training: false
|
||||
autograd_detect_anomaly: false
|
||||
|
||||
61
rf2aa/config/train/generative_refinement_debug.yaml
Normal file
61
rf2aa/config/train/generative_refinement_debug.yaml
Normal file
@@ -0,0 +1,61 @@
|
||||
defaults:
|
||||
- generative_refinement
|
||||
- _self_
|
||||
|
||||
loader_params:
|
||||
datapkl: /home/ahern/projects/RF2-allatom/rf2aa/subsampled/dataset_20240318_n-1.pkl
|
||||
no_match_okay: True
|
||||
|
||||
dataset_params:
|
||||
n_train: 500
|
||||
validate_every_n_epochs: 100
|
||||
validate_after_first_epoch: False
|
||||
fraction_pdb: 0.
|
||||
fraction_fb: 0.
|
||||
fraction_compl: 0.
|
||||
fraction_neg_compl: 0.
|
||||
fraction_na_compl: 0.
|
||||
fraction_neg_na_compl: 0.
|
||||
fraction_distil_tf: 0.
|
||||
fraction_tf: 0.
|
||||
fraction_neg_tf: 0.
|
||||
fraction_rna: 0.
|
||||
fraction_dna: 0.
|
||||
fraction_sm_compl: 1.
|
||||
fraction_metal_compl: 0.
|
||||
fraction_sm_compl_multi: 0.
|
||||
fraction_sm_compl_covale: 0.
|
||||
fraction_sm: 0.
|
||||
fraction_atomize_pdb: 0.
|
||||
fraction_atomize_complex: 0.
|
||||
fraction_sm_compl_asmb: 0.
|
||||
|
||||
n_valid_pdb: 0
|
||||
n_valid_homo: 0
|
||||
n_valid_dslf: 0
|
||||
n_valid_compl: 0
|
||||
n_valid_neg_compl: 0
|
||||
n_valid_na_compl: 0
|
||||
n_valid_neg_na_compl: 0
|
||||
n_valid_distil_tf: 0
|
||||
n_valid_tf: 0
|
||||
n_valid_neg_tf: 0
|
||||
n_valid_rna: 0
|
||||
n_valid_dna: 0
|
||||
n_valid_sm_compl: 2
|
||||
n_valid_metal_compl: 0
|
||||
n_valid_sm_compl_multi: 0
|
||||
n_valid_sm_compl_covale: 0
|
||||
n_valid_sm_compl_strict: 0
|
||||
n_valid_sm: 0
|
||||
n_valid_atomize_pdb: 0
|
||||
n_valid_atomize_complex: 0
|
||||
n_valid_sm_compl_asmb: 0
|
||||
n_valid_fb: 0
|
||||
|
||||
log_params:
|
||||
use_wandb: True
|
||||
wandb_project: rf2aa-debug
|
||||
|
||||
training_params:
|
||||
from_scratch: True
|
||||
@@ -1,8 +1,12 @@
|
||||
import numpy as np
|
||||
from icecream import ic
|
||||
import pickle
|
||||
import torch.utils.data as data
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from rf2aa.data.data_loader import get_train_valid_set, loader_pdb, loader_complex, loader_na_complex, \
|
||||
loader_distil_tf, loader_tf_complex, loader_fb, loader_dna_rna, \
|
||||
@@ -138,7 +142,7 @@ def get_distilled_dataset(dataset_params, loader_params):
|
||||
chid2hash,
|
||||
chid2taxid,
|
||||
chid2smpartners,
|
||||
) = get_train_valid_set(loader_params)
|
||||
) = get_train_valid_set(loader_params, no_match_okay=loader_params['no_match_okay'])
|
||||
|
||||
# define atomize_pdb train/valid sets, which use the same examples as pdb set
|
||||
train_ID_dict["atomize_pdb"] = train_ID_dict["pdb"]
|
||||
|
||||
@@ -10,6 +10,8 @@ from itertools import permutations
|
||||
from typing import Dict, Optional, Tuple, List, Set, Any
|
||||
from pathlib import Path
|
||||
from os.path import exists
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(script_dir)
|
||||
@@ -39,7 +41,14 @@ from rf2aa.util import get_nxgraph, get_atom_frames, get_bond_feats, get_protein
|
||||
reindex_protein_feats_after_atomize, get_residue_contacts, atomize_discontiguous_residues, pop_protein_feats, \
|
||||
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, \
|
||||
is_protein, is_nucleic, is_RNA, is_DNA, is_atom
|
||||
from rf2aa.data.cluster_dataset import cluster_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from rf2aa.data.cluster_dataset import cluster_factory
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to import cluster_factory from rf2aa.data.cluster_dataset: if you are rebuilding the dataset .pkl expect failure: ' + repr(e))
|
||||
|
||||
assert "rf2aa" in os.path.abspath(cifutils.__file__)
|
||||
|
||||
|
||||
|
||||
565
rf2aa/data/dataloader_adaptor_af3.py
Normal file
565
rf2aa/data/dataloader_adaptor_af3.py
Normal file
@@ -0,0 +1,565 @@
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from icecream import ic
|
||||
|
||||
from rf2aa.kinematics import xyz_to_t2d
|
||||
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs
|
||||
from rf2aa.util import is_atom, \
|
||||
Ls_from_same_chain_2d, xyz_t_to_frame_xyz, get_prot_sm_mask
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.flow_matching import data_transforms
|
||||
from rf2aa.debug import pretty_describe_dict
|
||||
from rf2aa.util import rigid_from_3_points
|
||||
from rf2aa.flow_matching.rigid_utils import rot_vec_mul
|
||||
from rf2aa.set_seed import seed_all
|
||||
from rf2aa.util import writepdb
|
||||
from rf2aa import pymol_tools
|
||||
from rf2aa.pymol import cmd
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def within_group_unique_ids(group_ids, element_ids):
|
||||
# Initialize a dictionary to store unique element mappings for each group
|
||||
unique_mappings = {}
|
||||
unique_id_counter = 0
|
||||
|
||||
# Initialize a list to store the resulting within-group unique ids
|
||||
within_group_unique = []
|
||||
|
||||
# Iterate over the group ids and element ids simultaneously
|
||||
for group_id, element_id in zip(group_ids, element_ids):
|
||||
# Check if the current group_id is already in the unique_mappings dictionary
|
||||
if group_id in unique_mappings:
|
||||
# If the element_id is already mapped to a unique id within this group, use it
|
||||
if element_id in unique_mappings[group_id]:
|
||||
within_group_unique.append(unique_mappings[group_id][element_id])
|
||||
# If the element_id is not yet mapped to a unique id within this group, assign a new unique id
|
||||
else:
|
||||
unique_mappings[group_id][element_id] = unique_id_counter
|
||||
within_group_unique.append(unique_id_counter)
|
||||
unique_id_counter += 1
|
||||
# If the current group_id is not yet in the unique_mappings dictionary, add it
|
||||
else:
|
||||
unique_mappings[group_id] = {element_id: unique_id_counter}
|
||||
within_group_unique.append(unique_id_counter)
|
||||
unique_id_counter += 1
|
||||
|
||||
# Convert the resulting list to a PyTorch tensor
|
||||
within_group_unique_tensor = torch.tensor(within_group_unique)
|
||||
|
||||
return within_group_unique_tensor
|
||||
|
||||
def integer_tokenize(iterable):
|
||||
# Create a dictionary mapping unique elements to integers
|
||||
unique_elements = list(set(iterable))
|
||||
mapping = {element: i for i, element in enumerate(unique_elements)}
|
||||
|
||||
# Convert iterable to integer tensor
|
||||
int_tensor = torch.tensor([mapping[element] for element in iterable])
|
||||
|
||||
return int_tensor
|
||||
|
||||
af3_num2aa = [
|
||||
# Amino acids:
|
||||
'ALA','ARG','ASN','ASP','CYS',
|
||||
'GLN','GLU','GLY','HIS','ILE',
|
||||
'LEU','LYS','MET','PHE','PRO',
|
||||
'SER','THR','TRP','TYR','VAL',
|
||||
'UNK', # 20 + 1
|
||||
# DNA
|
||||
'DA','DC','DG','DT', 'DUNK',
|
||||
# RNA
|
||||
'RA','RC','RG',' RU', 'RUNK',
|
||||
# GAP
|
||||
'GAP',
|
||||
]
|
||||
|
||||
af3_aa2num = {x:i for i,x in enumerate(af3_num2aa)}
|
||||
|
||||
aa_coarse_from_fine = {
|
||||
'ALA':'ALA',
|
||||
'ARG':'ARG',
|
||||
'ASN':'ASN',
|
||||
'ASP':'ASP',
|
||||
'CYS':'CYS',
|
||||
'GLN':'GLN',
|
||||
'GLU':'GLU',
|
||||
'GLY':'GLY',
|
||||
'HIS':'HIS',
|
||||
'ILE':'ILE',
|
||||
'LEU':'LEU',
|
||||
'LYS':'LYS',
|
||||
'MET':'MET',
|
||||
'PHE':'PHE',
|
||||
'PRO':'PRO',
|
||||
'SER':'SER',
|
||||
'THR':'THR',
|
||||
'TRP':'TRP',
|
||||
'TYR':'TYR',
|
||||
'VAL':'VAL',
|
||||
'UNK':'UNK',
|
||||
# 'MAS':'MAS',
|
||||
' DA':'DA',
|
||||
' DC':'DC',
|
||||
' DG':'DG',
|
||||
' DT':'DT',
|
||||
' DX':'DUNK',
|
||||
' RA':'RA',
|
||||
' RC':'RC',
|
||||
' RG':'RG',
|
||||
' RU':'RU',
|
||||
' RX':'RUNK',
|
||||
# 'HIS_D':'UNK',
|
||||
'Al':'UNK',
|
||||
'As':'UNK',
|
||||
'Au':'UNK',
|
||||
'B':'UNK',
|
||||
'Be':'UNK',
|
||||
'Br':'UNK',
|
||||
'C':'C',
|
||||
'Ca':'UNK',
|
||||
'Cl':'UNK',
|
||||
'Co':'UNK',
|
||||
'Cr':'UNK',
|
||||
'Cu':'UNK',
|
||||
'F':'UNK',
|
||||
'Fe':'UNK',
|
||||
'Hg':'UNK',
|
||||
'I':'UNK',
|
||||
'Ir':'UNK',
|
||||
'K':'UNK',
|
||||
'Li':'UNK',
|
||||
'Mg':'UNK',
|
||||
'Mn':'UNK',
|
||||
'Mo':'UNK',
|
||||
'N':'UNK',
|
||||
'Ni':'UNK',
|
||||
'O':'UNK',
|
||||
'Os':'UNK',
|
||||
'P':'UNK',
|
||||
'Pb':'UNK',
|
||||
'Pd':'UNK',
|
||||
'Pr':'UNK',
|
||||
'Pt':'UNK',
|
||||
'Re':'UNK',
|
||||
'Rh':'UNK',
|
||||
'Ru':'UNK',
|
||||
'S':'UNK',
|
||||
'Sb':'UNK',
|
||||
'Se':'UNK',
|
||||
'Si':'UNK',
|
||||
'Sn':'UNK',
|
||||
'Tb':'UNK',
|
||||
'Te':'UNK',
|
||||
'U':'UNK',
|
||||
'W':'UNK',
|
||||
'V':'UNK',
|
||||
'Y':'UNK',
|
||||
'Zn':'UNK',
|
||||
'ATM':'ATM'
|
||||
}
|
||||
from enum import Enum
|
||||
|
||||
class TokenType(Enum):
|
||||
PROTEIN = 1
|
||||
DNA = 2
|
||||
RNA = 3
|
||||
LIGAND = 4
|
||||
|
||||
aa_restype_from_fine = {
|
||||
'ALA': TokenType.PROTEIN,
|
||||
'ARG': TokenType.PROTEIN,
|
||||
'ASN': TokenType.PROTEIN,
|
||||
'ASP': TokenType.PROTEIN,
|
||||
'CYS': TokenType.PROTEIN,
|
||||
'GLN': TokenType.PROTEIN,
|
||||
'GLU': TokenType.PROTEIN,
|
||||
'GLY': TokenType.PROTEIN,
|
||||
'HIS': TokenType.PROTEIN,
|
||||
'ILE': TokenType.PROTEIN,
|
||||
'LEU': TokenType.PROTEIN,
|
||||
'LYS': TokenType.PROTEIN,
|
||||
'MET': TokenType.PROTEIN,
|
||||
'PHE': TokenType.PROTEIN,
|
||||
'PRO': TokenType.PROTEIN,
|
||||
'SER': TokenType.PROTEIN,
|
||||
'THR': TokenType.PROTEIN,
|
||||
'TRP': TokenType.PROTEIN,
|
||||
'TYR': TokenType.PROTEIN,
|
||||
'VAL': TokenType.PROTEIN,
|
||||
'UNK': TokenType.PROTEIN,
|
||||
# 'MAS':'MAS',
|
||||
' DA': TokenType.DNA,
|
||||
' DC': TokenType.DNA,
|
||||
' DG': TokenType.DNA,
|
||||
' DT': TokenType.DNA,
|
||||
' DX': TokenType.DNA,
|
||||
' RA': TokenType.RNA,
|
||||
' RC': TokenType.RNA,
|
||||
' RG': TokenType.RNA,
|
||||
' RU': TokenType.RNA,
|
||||
' RX': TokenType.RNA,
|
||||
# 'HIS_D':'UNK',
|
||||
'Al': TokenType.LIGAND,
|
||||
'As': TokenType.LIGAND,
|
||||
'Au': TokenType.LIGAND,
|
||||
'B': TokenType.LIGAND,
|
||||
'Be': TokenType.LIGAND,
|
||||
'Br': TokenType.LIGAND,
|
||||
'C': TokenType.LIGAND,
|
||||
'Ca': TokenType.LIGAND,
|
||||
'Cl': TokenType.LIGAND,
|
||||
'Co': TokenType.LIGAND,
|
||||
'Cr': TokenType.LIGAND,
|
||||
'Cu': TokenType.LIGAND,
|
||||
'F': TokenType.LIGAND,
|
||||
'Fe': TokenType.LIGAND,
|
||||
'Hg': TokenType.LIGAND,
|
||||
'I': TokenType.LIGAND,
|
||||
'Ir': TokenType.LIGAND,
|
||||
'K': TokenType.LIGAND,
|
||||
'Li': TokenType.LIGAND,
|
||||
'Mg': TokenType.LIGAND,
|
||||
'Mn': TokenType.LIGAND,
|
||||
'Mo': TokenType.LIGAND,
|
||||
'N': TokenType.LIGAND,
|
||||
'Ni': TokenType.LIGAND,
|
||||
'O': TokenType.LIGAND,
|
||||
'Os': TokenType.LIGAND,
|
||||
'P': TokenType.LIGAND,
|
||||
'Pb': TokenType.LIGAND,
|
||||
'Pd': TokenType.LIGAND,
|
||||
'Pr': TokenType.LIGAND,
|
||||
'Pt': TokenType.LIGAND,
|
||||
'Re': TokenType.LIGAND,
|
||||
'Rh': TokenType.LIGAND,
|
||||
'Ru': TokenType.LIGAND,
|
||||
'S': TokenType.LIGAND,
|
||||
'Sb': TokenType.LIGAND,
|
||||
'Se': TokenType.LIGAND,
|
||||
'Si': TokenType.LIGAND,
|
||||
'Sn': TokenType.LIGAND,
|
||||
'Tb': TokenType.LIGAND,
|
||||
'Te': TokenType.LIGAND,
|
||||
'U': TokenType.LIGAND,
|
||||
'W': TokenType.LIGAND,
|
||||
'V': TokenType.LIGAND,
|
||||
'Y': TokenType.LIGAND,
|
||||
'Zn': TokenType.LIGAND,
|
||||
'ATM': TokenType.LIGAND,
|
||||
}
|
||||
|
||||
element_codes = [
|
||||
'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', # Atomic numbers 1-10
|
||||
'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', # Atomic numbers 11-20
|
||||
'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', # Atomic numbers 21-30
|
||||
'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', # Atomic numbers 31-40
|
||||
'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', # Atomic numbers 41-50
|
||||
'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', # Atomic numbers 51-60
|
||||
'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', # Atomic numbers 61-70
|
||||
'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', # Atomic numbers 71-80
|
||||
'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', # Atomic numbers 81-90
|
||||
'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', # Atomic numbers 91-100
|
||||
'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', # Atomic numbers 101-110
|
||||
'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', # Atomic numbers 111-118
|
||||
'Uue', 'Ubn', 'Ubu', 'Ubb', 'Ubt', 'Ubq', 'Ubp', 'Ubh', # Atomic numbers 119-126
|
||||
'the_element', 'of_surprise' # Non-existant elements 127-128
|
||||
]
|
||||
|
||||
|
||||
element_code_to_num = {x:i for i,x in enumerate(element_codes)}
|
||||
|
||||
|
||||
def discretize_distance_matrix(distance_matrix, num_bins=38, min_distance=3.25, max_distance=50.75):
|
||||
# Calculate the bin width
|
||||
bin_width = (max_distance - min_distance) / num_bins
|
||||
|
||||
# Discretize distances into bins
|
||||
bins = torch.floor((distance_matrix - min_distance) / bin_width).clamp_max(num_bins)
|
||||
|
||||
# Assign larger distances to the last bin
|
||||
bins = torch.where(bins < num_bins, bins, torch.tensor(num_bins))
|
||||
|
||||
return bins
|
||||
|
||||
def torch_vectorize(pyfunc):
|
||||
def f(*args, **kwargs):
|
||||
out_np = np.vectorize(pyfunc)(*args, **kwargs)
|
||||
return torch.tensor(out_np)
|
||||
return f
|
||||
|
||||
def prepare_input_af3(inputs, D, s_trans, sigma_data, random_augmentation, only_ca, device="cpu",):
|
||||
logger.debug('prepare_input_af3 input:\n' + pretty_describe_dict(inputs))
|
||||
(
|
||||
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
|
||||
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
|
||||
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
|
||||
) = inputs
|
||||
# transfer inputs to device
|
||||
B, _, N, I = msa.shape
|
||||
logger.debug('\n\n\n\n\n')
|
||||
logger.debug('prepare_input_af3 input:\n' + pretty_describe_dict(dict(
|
||||
seq=seq,
|
||||
msa=msa,
|
||||
msa_masked=msa_masked,
|
||||
msa_full=msa_full,
|
||||
mask_msa=mask_msa,
|
||||
true_crds=true_crds,
|
||||
mask_crds=mask_crds,
|
||||
idx_pdb=idx_pdb,
|
||||
xyz_t=xyz_t,
|
||||
t1d=t1d,
|
||||
mask_t=mask_t,
|
||||
xyz_prev=xyz_prev,
|
||||
mask_prev=mask_prev,
|
||||
same_chain=same_chain,
|
||||
unclamp=unclamp,
|
||||
negative=negative,
|
||||
atom_frames=atom_frames,
|
||||
bond_feats=bond_feats,
|
||||
dist_matrix=dist_matrix,
|
||||
chirals=chirals,
|
||||
ch_label=ch_label,
|
||||
symmgp=symmgp,
|
||||
task=task,
|
||||
item=item,
|
||||
)))
|
||||
|
||||
NUM_TEMPLATE_DISTOGRAM_BINS = 38
|
||||
|
||||
# Strip batch dimension
|
||||
msa = msa[0,0]
|
||||
idx_pdb = idx_pdb[0]
|
||||
ch_label = ch_label[0]
|
||||
true_crds = true_crds[0,0]
|
||||
seq = seq[0,0]
|
||||
xyz_t = xyz_t[0]
|
||||
mask_t = mask_t[0]
|
||||
mask_crds = mask_crds[0,0]
|
||||
bond_feats = bond_feats[0]
|
||||
|
||||
N_token = seq.shape[0]
|
||||
|
||||
logger.debug(f'{ch_label[:5]}')
|
||||
aa = [ChemData().num2aa[num] for num in seq]
|
||||
aa1 = [ChemData().to1letter.get(k, k) for k in aa]
|
||||
# aa_af3 = [aa_coarse_from_fine[s] for s in aa]
|
||||
|
||||
|
||||
# Converts a ChemData() sequence token to an af3 token
|
||||
@torch_vectorize
|
||||
def af3num_from_num(num):
|
||||
chemdata_code = ChemData().num2aa[num]
|
||||
coarse = aa_coarse_from_fine[chemdata_code]
|
||||
return af3_aa2num[coarse]
|
||||
|
||||
@np.vectorize
|
||||
def get_token_type(num):
|
||||
code3 = ChemData().num2aa[num]
|
||||
return aa_restype_from_fine[code3]
|
||||
|
||||
f = {}
|
||||
|
||||
### Residue level ###
|
||||
f['residue_index'] = idx_pdb
|
||||
f['token_index'] = torch.arange(N_token)
|
||||
f['asym_id'] = ch_label
|
||||
# Hacked:
|
||||
f['entity_id'] = torch.zeros(N_token)
|
||||
f['sym_id'] = within_group_unique_ids(f['entity_id'], f['asym_id'])
|
||||
# f['restype'] = F.one_hot(torch.tensor([af3_aa2num[aa_af3]]), len(af3_num2aa))
|
||||
f['restype'] = F.one_hot(af3num_from_num(seq), len(af3_num2aa))
|
||||
# token_type = torch.tensor([aa_restype_from_fine[aa]])
|
||||
# token_type = np.vectorize(aa_restype_from_fine.__getitem__)(aa)
|
||||
token_type = get_token_type(seq)
|
||||
f['is_protein'] = torch.tensor(token_type == TokenType.PROTEIN)
|
||||
f['is_rna'] = torch.tensor(token_type == TokenType.RNA)
|
||||
f['is_dna'] = torch.tensor(token_type == TokenType.DNA)
|
||||
f['is_ligand'] = torch.tensor(token_type == TokenType.LIGAND)
|
||||
|
||||
|
||||
### Atom level ###
|
||||
allatom_mask = ChemData().allatom_mask.to(device, non_blocking=True)
|
||||
|
||||
# remove symmetry dimension
|
||||
# if len(true_crds.shape) == 4:
|
||||
# true_crds = true_crds[0:1]
|
||||
# mask_crds = mask_crds[0:1]
|
||||
# true_crds = true_crds[0]
|
||||
|
||||
# want to unroll the coordinate tensors to get the full coordinates in (atoms, 3)
|
||||
is_real_atom = allatom_mask[seq].bool()
|
||||
|
||||
if only_ca:
|
||||
is_real_atom[:] = False
|
||||
is_real_atom[:, 1] = True
|
||||
|
||||
tok_idx = is_real_atom.nonzero()[:,0]
|
||||
within_tok_idx = is_real_atom.nonzero()[:,1]
|
||||
N_atom = len(tok_idx)
|
||||
f['tok_idx'] = tok_idx
|
||||
# atom_mask = mask_crds[is_real_atom]
|
||||
# t = interpolant.sample_t(D)
|
||||
|
||||
# Hacked:
|
||||
f['ref_pos'] = torch.rand((N_atom, 3))
|
||||
f['ref_mask'] = torch.arange(N_atom)
|
||||
|
||||
element = [ChemData().aa2elt[seq[tok]][within_tok] for tok, within_tok in zip(tok_idx, within_tok_idx)]
|
||||
f['ref_element'] = F.one_hot(torch.tensor([element_code_to_num[e] for e in element]), len(element_codes))
|
||||
|
||||
# Hacked:
|
||||
f['ref_charge'] = torch.zeros((N_atom))
|
||||
f['ref_atom_name_chars'] = torch.zeros((N_atom, 4, 64))
|
||||
f['ref_atom_name_chars'][:,0,0] = torch.arange(N_atom)
|
||||
|
||||
f['ref_space_uid'] = integer_tokenize(list(zip(f['asym_id'], f['residue_index'])))[tok_idx]
|
||||
|
||||
### MSA ###
|
||||
f['msa'] = F.one_hot(af3num_from_num(msa), len(af3_num2aa))
|
||||
# Hacked
|
||||
N_msa = msa.shape[0]
|
||||
f['has_deletion'] = torch.zeros((N_msa, N_token))
|
||||
f['deletion_value'] = torch.zeros((N_msa, N_token))
|
||||
f['profile'] = torch.zeros((N_token, 32))
|
||||
f['deletion_mean'] = torch.zeros((N_token))
|
||||
|
||||
### Templates ###
|
||||
N_templ = xyz_t.shape[0]
|
||||
# Hacked:
|
||||
template_seq = t1d[0].argmax(dim=-1) # [T, I]
|
||||
assert (template_seq < ChemData().NPROTAAS - 1).all() # only 20 AA + 1 UNK (No mask)
|
||||
|
||||
f['template_restype'] = F.one_hot(af3num_from_num(template_seq), len(af3_num2aa))
|
||||
|
||||
template_is_protein = torch.tensor(get_token_type(template_seq) == TokenType.PROTEIN)
|
||||
template_is_gly = template_seq == ChemData().aa2num['GLY']
|
||||
template_atom_name = np.where(template_is_gly, ' CA ', ' CB ')
|
||||
template_protein_beta_idx = torch_vectorize(lambda token, atom_name: ChemData().aa2long[token].index(atom_name))(
|
||||
template_seq[template_is_protein],
|
||||
template_atom_name[template_is_protein]
|
||||
)
|
||||
template_beta_idx = torch.full((N_templ, N_token), 0)
|
||||
template_beta_idx[template_is_protein] = template_protein_beta_idx
|
||||
template_beta_exists = torch.gather(mask_t, dim=2, index=template_beta_idx[..., None]).squeeze(-1)
|
||||
f['template_pseudo_beta_mask'] = template_beta_exists * template_is_protein # .reshape?
|
||||
f['template_backbone_frame_mask'] = mask_t[:, :, torch.tensor([0,1,2])].all(dim=-1)
|
||||
# Reshape index_tensor to match the dimensions of xyz_t except for the last dimension
|
||||
index_tensor_expanded = template_beta_idx[...,None,None].expand(-1, -1, -1, 3)
|
||||
template_pseudo_beta = torch.gather(xyz_t, dim=2, index=index_tensor_expanded).squeeze(-2) #.squeeze(-1)
|
||||
template_pseudo_beta_distogram = torch.cdist(template_pseudo_beta, template_pseudo_beta)
|
||||
template_pseudo_beta_distogram *= f['template_pseudo_beta_mask'].unsqueeze(-1) * f['template_pseudo_beta_mask'].unsqueeze(-2)
|
||||
f['template_distogram'] = discretize_distance_matrix(
|
||||
template_pseudo_beta_distogram,
|
||||
num_bins=NUM_TEMPLATE_DISTOGRAM_BINS,
|
||||
)
|
||||
|
||||
CA_IDX = 1
|
||||
template_ca = xyz_t[..., CA_IDX, :]
|
||||
template_ca_disp = template_ca.unsqueeze(-2) - template_ca.unsqueeze(-3)
|
||||
template_ca_disp_unit = template_ca_disp / torch.linalg.norm(template_ca_disp, dim=-1, keepdim=True)
|
||||
template_R, _ = rigid_from_3_points(
|
||||
xyz_t[:,:,0],
|
||||
xyz_t[:,:,1],
|
||||
xyz_t[:,:,2],
|
||||
)
|
||||
|
||||
has_ca = mask_t[..., CA_IDX]
|
||||
both_have_ca = has_ca[..., None, :] * has_ca[..., None]
|
||||
template_unit_vector = rot_vec_mul(template_R[:,:, None], template_ca_disp_unit)
|
||||
template_unit_vector[both_have_ca] = 0
|
||||
f['template_unit_vector'] = template_unit_vector
|
||||
|
||||
has_ligand_2d = (f['is_ligand'].unsqueeze(-2) + f['is_ligand'].unsqueeze(-1)).bool()
|
||||
# is_ligand_ligand = f['is_ligand'].unsqueeze(-2) * f['is_ligand'].unsqueeze(-1)
|
||||
|
||||
# Hacked (as covalent bonds are not represented in bond_feats and 2.4A filter not applied)
|
||||
f['token_bonds'] = has_ligand_2d * (bond_feats > 0)
|
||||
|
||||
X_gt_L = true_crds[is_real_atom]
|
||||
atom_mask = mask_crds[is_real_atom]
|
||||
t = sigma_data * torch.exp(-1.2 + 1.5 * torch.normal(mean=0, std=1, size=(D,)))
|
||||
|
||||
X_gt_L = centre(X_gt_L, atom_mask)
|
||||
X_gt_L = X_gt_L.tile(D,1,1)
|
||||
|
||||
if random_augmentation:
|
||||
X_gt_L = get_random_augmentation(X_gt_L, s_trans=s_trans)
|
||||
|
||||
_, L, _ = X_gt_L.shape
|
||||
t_tiled = t[:, None, None].tile(1, L, 3)
|
||||
noise = torch.normal(mean=0, std=t_tiled)
|
||||
X_noisy_L = X_gt_L + noise
|
||||
|
||||
return (
|
||||
# network input
|
||||
dict(
|
||||
X_noisy_L=X_noisy_L,
|
||||
t=t,
|
||||
f=f,
|
||||
),
|
||||
# loss input (trues)
|
||||
dict(
|
||||
X_gt_L=X_gt_L,
|
||||
crd_mask_I = is_real_atom,
|
||||
seq=seq,
|
||||
bond_feats=bond_feats,
|
||||
)
|
||||
)
|
||||
|
||||
def centre(X_L, X_exists_L):
|
||||
X_L = X_L.clone()
|
||||
X_L[X_exists_L] = X_L[X_exists_L] - torch.mean(X_L[X_exists_L], dim=-2, keepdim=True)
|
||||
X_L[~X_exists_L] = 0.0
|
||||
return X_L
|
||||
|
||||
def get_random_augmentation(X_L, s_trans):
|
||||
'''
|
||||
Inputs:
|
||||
X_L [D, L, 3]: Batched atom coordinates
|
||||
s_trans (float): standard deviation of a global translation to be applied for each
|
||||
element in the batch
|
||||
'''
|
||||
D, L, _ = X_L.shape
|
||||
R = uniform_random_rotation((D,))
|
||||
noise = s_trans * torch.normal(mean=0, std=1, size=(D,1,3))
|
||||
return rot_vec_mul(R[:,None], X_L) + noise
|
||||
|
||||
def uniform_random_rotation(size):
|
||||
# Sample random angles for rotations around X, Y, and Z axes
|
||||
theta_x = torch.rand(size) * 2 * math.pi
|
||||
theta_y = torch.rand(size) * 2 * math.pi
|
||||
theta_z = torch.rand(size) * 2 * math.pi
|
||||
|
||||
# Calculate the cosines and sines of the angles
|
||||
cos_x = torch.cos(theta_x)
|
||||
sin_x = torch.sin(theta_x)
|
||||
cos_y = torch.cos(theta_y)
|
||||
sin_y = torch.sin(theta_y)
|
||||
cos_z = torch.cos(theta_z)
|
||||
sin_z = torch.sin(theta_z)
|
||||
|
||||
# Create the rotation matrices around X, Y, and Z axes
|
||||
rotation_x = torch.stack([torch.tensor([[1, 0, 0],
|
||||
[0, c, -s],
|
||||
[0, s, c]]) for c, s in zip(cos_x, sin_x)])
|
||||
|
||||
rotation_y = torch.stack([torch.tensor([[c, 0, s],
|
||||
[0, 1, 0],
|
||||
[-s, 0, c]]) for c, s in zip(cos_y, sin_y)])
|
||||
|
||||
rotation_z = torch.stack([torch.tensor([[c, -s, 0],
|
||||
[s, c, 0],
|
||||
[0, 0, 1]]) for c, s in zip(cos_z, sin_z)])
|
||||
|
||||
# Combine the rotation matrices
|
||||
rotation_matrix = torch.matmul(rotation_z, torch.matmul(rotation_y, rotation_x))
|
||||
|
||||
return rotation_matrix
|
||||
@@ -6,6 +6,9 @@ import numpy as np
|
||||
from torch.utils import data
|
||||
from typing import Dict
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_id_lengths(
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import torch
|
||||
import tree
|
||||
import json
|
||||
from icecream import ic
|
||||
|
||||
|
||||
def debug_nans(latent_feats):
|
||||
@@ -32,3 +35,36 @@ def debug_grads(model):
|
||||
def debug_nan_params(model):
|
||||
for name, param in model.named_parameters():
|
||||
print(f"{name}: {torch.sum(param.isnan())}")
|
||||
|
||||
def pretty_describe_dict(d):
|
||||
mapped = describe_dict(d)
|
||||
mapped = tree.map_structure(str, mapped)
|
||||
return json.dumps(mapped, indent=4)
|
||||
|
||||
def describe_dict(d):
|
||||
return tree.map_structure(describe, d)
|
||||
|
||||
def describe(t: torch.Tensor):
|
||||
out = [f'type:{type(t)}']
|
||||
if hasattr(t, 'shape'):
|
||||
out.append(f'shape:{str(t.shape)}')
|
||||
if hasattr(t, 'dtype'):
|
||||
out.append(f'dtype:{str(t.dtype)}')
|
||||
if hasattr(t, 'device'):
|
||||
out.append(f'device:{str(t.device)}')
|
||||
return ' '.join(out)
|
||||
|
||||
def safe_shape(t: torch.Tensor):
|
||||
if hasattr(t, 'shape'):
|
||||
return t.shape
|
||||
return None
|
||||
|
||||
def log_in_out(f):
|
||||
def wrapped(*args, **kwargs):
|
||||
o = f(*args, **kwargs)
|
||||
ic(
|
||||
args, kwargs,
|
||||
o
|
||||
)
|
||||
return o
|
||||
return wrapped
|
||||
75
rf2aa/loggers.py
Normal file
75
rf2aa/loggers.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import logging
|
||||
|
||||
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
|
||||
from lightning.pytorch.utilities import rank_zero_only
|
||||
from icecream import ic
|
||||
from lightning.pytorch.loggers.csv_logs import CSVLogger
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
from rf2aa.debug import pretty_describe_dict
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mean_over(df, grouper, metrics):
|
||||
out = df[[grouper] + metrics].groupby(grouper).mean(numeric_only=True).stack().to_dict()
|
||||
out = {f'{grouper}={grouper_v}.{metric}': v for (grouper_v, metric), v in out.items()}
|
||||
return out
|
||||
|
||||
class LitLogger(Logger):
|
||||
|
||||
def __init__(self, save_dir, use_wandb, sublogger):
|
||||
self.use_wandb = use_wandb
|
||||
if self.use_wandb:
|
||||
self.sublogger = WandbLogger(**sublogger)
|
||||
else:
|
||||
self.sublogger = CSVLogger(**sublogger)
|
||||
super().__init__()
|
||||
|
||||
def log_df(self, df, stratifications=None):
|
||||
global_step = df['global_step'].iloc[0]
|
||||
mean_over_step = df.groupby('global_step').mean(numeric_only=True)
|
||||
assert len(mean_over_step) == 1
|
||||
mean_over_step = mean_over_step.iloc[0]
|
||||
mean_over_step = mean_over_step.to_dict()
|
||||
|
||||
stratified = {}
|
||||
for groupers, values in stratifications.items():
|
||||
# TODO: enable stratification over multiple keys
|
||||
assert len(groupers) == 1
|
||||
stratified.update(mean_over(df, groupers[0], values))
|
||||
|
||||
stratified = mean_over_step | stratified
|
||||
self.sublogger.log_metrics(stratified, step=global_step)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "MyLogger"
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
# Return the experiment version, int or str.
|
||||
return "0.1"
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
# params is an argparse.Namespace
|
||||
# your code to record hyperparameters goes here
|
||||
self.sublogger.log_hyperparams(params)
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
ic(step)
|
||||
# metrics is a dictionary of metric names and values
|
||||
# your code to record metrics goes here
|
||||
self.sublogger.log_metrics(metrics, step)
|
||||
|
||||
@rank_zero_only
|
||||
def save(self):
|
||||
# Optional. Any code necessary to save logger data goes here
|
||||
self.sublogger.save()
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status):
|
||||
# Optional. Any code that needs to be run after training
|
||||
# finishes goes here
|
||||
self.sublogger.finalize(status)
|
||||
File diff suppressed because it is too large
Load Diff
34
rf2aa/model/AF3_structure_wrapper.py
Normal file
34
rf2aa/model/AF3_structure_wrapper.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from torch import relu
|
||||
|
||||
from rf2aa.debug import debug_nans
|
||||
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
||||
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
||||
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
||||
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
|
||||
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
|
||||
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
|
||||
from rf2aa.training.checkpoint import create_custom_forward
|
||||
from rf2aa.util_module import Dropout
|
||||
from rf2aa.model.AF3_structure import AtomAttentionEncoder, AtomAttentionDecoder
|
||||
|
||||
|
||||
class NonEquivariantAtomEncoder(nn.Module):
|
||||
|
||||
def __init__(self, block_params):
|
||||
super().__init__()
|
||||
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
|
||||
self.model = AtomAttentionEncoder(**block_params)
|
||||
|
||||
|
||||
class NonEquivariantAtomDecoder(nn.Module):
|
||||
|
||||
def __init__(self, block_params):
|
||||
super().__init__()
|
||||
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
|
||||
self.model = AtomAttentionDecoder(**block_params)
|
||||
|
||||
56
rf2aa/pymol.py
Normal file
56
rf2aa/pymol.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import sys
|
||||
import os
|
||||
from icecream import ic
|
||||
|
||||
import xmlrpc.client as xmlrpclib
|
||||
|
||||
class XMLRPCWrapperProxy(object):
|
||||
def __init__(self, wrapped=None):
|
||||
self.name = 'cmd'
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self.wrapped, name)
|
||||
# ic(type(self), attr)
|
||||
wrapped = type(self)(attr)
|
||||
wrapped.name = name
|
||||
return wrapped
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
try:
|
||||
return self.wrapped(*args, **kw)
|
||||
except Exception as e:
|
||||
all_args = tuple(map(str, args))
|
||||
all_args += tuple(f'{k}={v}' for k,v in kw.items())
|
||||
raise Exception(f"cmd.{self.name}('{','.join(all_args)})'") from e
|
||||
|
||||
def get_cmd(pymol_url='http://localhost:9123'):
|
||||
cmd = xmlrpclib.ServerProxy(pymol_url)
|
||||
if not ('ipd' in pymol_url or 'localhost' in pymol_url):
|
||||
make_network_cmd(cmd)
|
||||
return cmd
|
||||
|
||||
cmd = None
|
||||
def init(pymol_url='http://localhost:9123'):
|
||||
global cmd
|
||||
cmd_inner = get_cmd(pymol_url)
|
||||
if cmd is None:
|
||||
cmd = XMLRPCWrapperProxy(cmd_inner)
|
||||
else:
|
||||
cmd.wrapped = cmd_inner
|
||||
|
||||
|
||||
def make_network_cmd(cmd):
|
||||
# old_load = cmd.load
|
||||
def new_load(*args, **kwargs):
|
||||
path = args[0]
|
||||
with open(path) as f:
|
||||
contents = f.read()
|
||||
# args[0] = contents
|
||||
args = (contents,) + args[1:]
|
||||
#print('writing contents')
|
||||
cmd.read_pdbstr(*args, **kwargs)
|
||||
cmd.is_network = True
|
||||
cmd.load = new_load
|
||||
|
||||
init()
|
||||
65
rf2aa/pymol_tools.py
Normal file
65
rf2aa/pymol_tools.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from rf2aa.pymol import cmd
|
||||
from rf2aa.util import writepdb
|
||||
|
||||
def clear():
|
||||
cmd.reinitialize('everything')
|
||||
cmd.delete('all')
|
||||
# cmd.do(f'cd {REPO_DIR}/pymol_config')
|
||||
cmd.do('@./pymolrc')
|
||||
|
||||
|
||||
# def pseudoatom(
|
||||
# pos: list = [0,0,0],
|
||||
# label='origin',
|
||||
# ):
|
||||
# cmd.pseudoatom(label,'', 'PS1','PSD', '1', 'P',
|
||||
# 'PSDO', 'PS', -1.0, 1, 0.0, 0.0, '',
|
||||
# '', pos)
|
||||
# # cmd.do(f'label {label}, "{label}"')
|
||||
# return label
|
||||
|
||||
def pseudoatom(
|
||||
cmd,
|
||||
pos: list = [0,0,0],
|
||||
label='origin',
|
||||
):
|
||||
cmd.pseudoatom(label,'', 'PS1','PSD', '1', 'P',
|
||||
'PSDO', 'PS', -1.0, 1, 0.0, 0.0, '',
|
||||
'', pos)
|
||||
# cmd.do(f'label {label}, "{label}"')
|
||||
return label
|
||||
|
||||
def show_origin():
|
||||
pa = pseudoatom(cmd, label='the_origin')
|
||||
cmd.center(pa)
|
||||
cmd.color('red', pa)
|
||||
cmd.set('grid_slot', -2, pa)
|
||||
|
||||
def show_pymol(
|
||||
true_crds,
|
||||
seq,
|
||||
bond_feats,
|
||||
label='unlabeled'
|
||||
):
|
||||
pdb_path = f"tmp/true_0.pdb"
|
||||
writepdb(
|
||||
pdb_path,
|
||||
true_crds,
|
||||
seq.long(),
|
||||
bond_feats=bond_feats[None],
|
||||
)
|
||||
cmd.load(os.path.abspath(pdb_path), label)
|
||||
show_origin()
|
||||
|
||||
|
||||
def to_atom37(X_L, atom_mask):
|
||||
assert X_L.shape[-1] == 3
|
||||
assert X_L.numel() / 3 == atom_mask.sum(), f'{X_L.numel()/3=} != {atom_mask.sum()=}. {X_L.shape=} {atom_mask.shape=}'
|
||||
L, _ = atom_mask.shape[-2:]
|
||||
X_I = torch.zeros(atom_mask.shape + (3,), dtype=torch.float) - 10
|
||||
X_I[atom_mask] = X_L
|
||||
return X_I
|
||||
@@ -167,3 +167,8 @@ def cmp(got, want, **kwargs):
|
||||
if dd:
|
||||
return dd
|
||||
return ''
|
||||
|
||||
def assert_cmp(got, want, **kwargs):
|
||||
diff = cmp(got, want, **kwargs)
|
||||
if diff:
|
||||
raise AssertionError(diff)
|
||||
31
rf2aa/tests/test_align.py
Normal file
31
rf2aa/tests/test_align.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
import torch
|
||||
import pytest
|
||||
from icecream import ic
|
||||
from rf2aa.alignment import weighted_rigid_align, get_rmsd
|
||||
from rf2aa.util import kabsch
|
||||
|
||||
def pseudobatched_kabsch(xyz1, xyz2):
|
||||
B = xyz1.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
out.append(kabsch(xyz1[i], xyz2[i])[0])
|
||||
return torch.stack(out)
|
||||
|
||||
def test_align():
|
||||
torch.manual_seed(0)
|
||||
|
||||
B = 9
|
||||
L = 5
|
||||
x_from = torch.rand((B, L, 3))
|
||||
x_to = torch.rand((B, L, 3))
|
||||
w = torch.ones((B, L))
|
||||
|
||||
rmsd_kabsch = pseudobatched_kabsch(x_from, x_to)
|
||||
x_from_align = weighted_rigid_align(x_from, x_to, w)
|
||||
rmsd_weighted_rigid = get_rmsd(x_to, x_from_align)
|
||||
ic(rmsd_weighted_rigid, rmsd_kabsch)
|
||||
assert (torch.abs(rmsd_weighted_rigid - rmsd_kabsch) < 1e-5).all(), f'{rmsd_weighted_rigid} != {rmsd_kabsch}'
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_align()
|
||||
165
rf2aa/tests/test_diffusion_module.py
Normal file
165
rf2aa/tests/test_diffusion_module.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import torch
|
||||
import pytest
|
||||
from icecream import ic
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
from hydra import initialize, compose
|
||||
from pl_bolts.utils import BatchGradientVerification
|
||||
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
from rf2aa.debug import pretty_describe_dict
|
||||
from rf2aa.model.AF3_structure import Model, DiffusionModule, AtomTransformer
|
||||
from rf2aa.tensor_util import assert_shape, assert_cmp
|
||||
|
||||
|
||||
def test_batch_leakage():
|
||||
|
||||
conf_overrides = []
|
||||
with initialize(version_base=None, config_path="../config/train"):
|
||||
conf = compose(config_name='af3_repro', overrides=conf_overrides)
|
||||
|
||||
c_s = conf.model.c_s
|
||||
c_z = conf.model.c_z
|
||||
|
||||
model = DiffusionModule(
|
||||
c_atom=128,
|
||||
c_atompair=16,
|
||||
c_s=c_s,
|
||||
c_z=c_z,
|
||||
**conf.model.diffusion_module)
|
||||
|
||||
verification = BatchGradientVerification(model)
|
||||
|
||||
D = 3
|
||||
I = 80
|
||||
L = 160
|
||||
C_s_inputs = conf.model.diffusion_module.diffusion_conditioning.c_s_inputs
|
||||
C_s_trunk = conf.model.c_s
|
||||
|
||||
inputs = dict(
|
||||
X_noisy_L = torch.rand((D, L, 3)),
|
||||
t = torch.rand((D,)),
|
||||
f = {
|
||||
'asym_id': torch.zeros(I),
|
||||
'residue_index': torch.arange(I),
|
||||
'entity_id': torch.zeros(I),
|
||||
'token_index': torch.arange(I),
|
||||
'sym_id': torch.zeros(I).long(),
|
||||
'tok_idx': torch.arange(L) // 2,
|
||||
'ref_pos': torch.rand((L, 3)),
|
||||
'ref_charge': torch.rand((L,)),
|
||||
'ref_mask': torch.arange(L) // 2,
|
||||
'ref_element': one_hot(torch.randint(127, (L,)), 128),
|
||||
'ref_atom_name_chars': torch.zeros((L, 4, 64)),
|
||||
'ref_space_uid': torch.zeros((L,))
|
||||
},
|
||||
S_inputs_I = torch.rand((I, C_s_inputs)),
|
||||
S_trunk_I = torch.rand((I, C_s_trunk)),
|
||||
Z_trunk_II = torch.rand((I, I, c_z))
|
||||
)
|
||||
|
||||
batched_inputs = default_input_mapping(inputs)
|
||||
assert len(batched_inputs) == 2 and batched_inputs[0].shape[0] == D, f'default input mapping (should contain X_noisy_L and t:\n' + pretty_describe_dict(batched_inputs)
|
||||
print(f'{verification.NORM_LAYER_CLASSES=}')
|
||||
|
||||
# Assert that there are no cross-batch gradients.
|
||||
# See: https://lightning-bolts.readthedocs.io/en/latest/callbacks/monitor.html
|
||||
# for details.
|
||||
valid = verification.check(
|
||||
input_array=inputs, sample_idx=0)
|
||||
|
||||
# Assert that the model produces the same output when run batched/unbatched.
|
||||
out_batched = model(**inputs)
|
||||
inputs_unbatched = inputs
|
||||
inputs_unbatched['X_noisy_L'] = inputs['X_noisy_L'][:1]
|
||||
inputs_unbatched['t'] = inputs['t'][:1]
|
||||
out_single = model(**inputs_unbatched)
|
||||
|
||||
assert_cmp(out_single, out_batched[0:1])
|
||||
ic(out_single.shape, out_batched.shape)
|
||||
|
||||
# Assert that the batch outputs are different.
|
||||
assert torch.norm(out_batched[0] - out_batched[1]) > 1
|
||||
assert valid
|
||||
|
||||
|
||||
def plot_attention_map(attn, diag=True):
|
||||
colors = ['indigo', 'yellow']
|
||||
if diag:
|
||||
attn[np.diag_indices_from(attn)] = 2
|
||||
colors = ['indigo', 'yellow', 'green']
|
||||
cmap = mcolors.ListedColormap(colors)
|
||||
plt.matshow(attn, cmap=cmap)
|
||||
plt.axis('off') # Turn off axis
|
||||
plt.show()
|
||||
|
||||
def test_sequence_local_atom_attention():
|
||||
|
||||
conf_overrides = []
|
||||
with initialize(version_base=None, config_path="../config/train"):
|
||||
conf = compose(config_name='af3_repro', overrides=conf_overrides)
|
||||
conf = conf.model.diffusion_module.atom_attention_encoder.atom_transformer
|
||||
|
||||
# Show the model's attenion map.
|
||||
show_full = False
|
||||
if show_full:
|
||||
# Get full size attention map.
|
||||
conf.l_max = 200
|
||||
atom_transformer = AtomTransformer(
|
||||
c_atom=10,
|
||||
c_atompair=11,
|
||||
**conf
|
||||
)
|
||||
Beta_lm = atom_transformer.Beta_lm
|
||||
attn = (Beta_lm == 0).long()
|
||||
plot_attention_map(attn)
|
||||
|
||||
|
||||
# Show af3 supplement-style attention map.
|
||||
show_supp = False
|
||||
if show_supp:
|
||||
atom_transformer = AtomTransformer(
|
||||
c_atom=10,
|
||||
c_atompair=11,
|
||||
l_max=200,
|
||||
n_queries=32,
|
||||
n_keys=64,
|
||||
diffusion_transformer=conf.diffusion_transformer,
|
||||
)
|
||||
Beta_lm = atom_transformer.Beta_lm
|
||||
attn = (Beta_lm == 0).long()
|
||||
plot_attention_map(attn)
|
||||
|
||||
atom_transformer = AtomTransformer(
|
||||
c_atom=10,
|
||||
c_atompair=11,
|
||||
l_max=10,
|
||||
n_queries=2,
|
||||
n_keys=4,
|
||||
diffusion_transformer=conf.diffusion_transformer,
|
||||
)
|
||||
L = 6
|
||||
Beta_lm = atom_transformer.Beta_lm
|
||||
Beta_lm = Beta_lm[:L, :L]
|
||||
|
||||
# Show small test-case attention map.
|
||||
show_test_case = False
|
||||
if show_test_case:
|
||||
plot_attention_map((Beta_lm==0).long(), diag=False)
|
||||
|
||||
o = 0
|
||||
x = -1e10
|
||||
want_Beta_lm = torch.tensor([
|
||||
[o,o,o,x,x,x],
|
||||
[o,o,o,x,x,x],
|
||||
[x,o,o,o,o,x],
|
||||
[x,o,o,o,o,x],
|
||||
[x,x,x,o,o,o],
|
||||
[x,x,x,o,o,o],
|
||||
])
|
||||
|
||||
assert_cmp(Beta_lm, want_Beta_lm)
|
||||
29
rf2aa/tests/test_dropout.py
Normal file
29
rf2aa/tests/test_dropout.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
import torch
|
||||
import pytest
|
||||
from rf2aa.util_module import Dropout
|
||||
|
||||
|
||||
def test_dropout():
|
||||
torch.manual_seed(0)
|
||||
drop_row = Dropout(broadcast_dim=0, p_drop=0.5)
|
||||
d = 8
|
||||
x = torch.rand((d, d))
|
||||
x = drop_row(x)
|
||||
print('x:')
|
||||
print(x)
|
||||
|
||||
assert not torch.all(x==0)
|
||||
for i in range(d):
|
||||
row = x[i]
|
||||
if torch.any(row==0):
|
||||
assert torch.all(row==0)
|
||||
|
||||
has_all_zero_row = False
|
||||
for i in range(d):
|
||||
row = x[i]
|
||||
if torch.all(row==0):
|
||||
has_all_zero_row = True
|
||||
|
||||
assert has_all_zero_row
|
||||
|
||||
222
rf2aa/trainer_lightning.py
Normal file
222
rf2aa/trainer_lightning.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import re
|
||||
import random
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing as mp
|
||||
from icecream import ic
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
import hydra
|
||||
import os
|
||||
import time
|
||||
import omegaconf
|
||||
from contextlib import nullcontext
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
import certifi
|
||||
import warnings
|
||||
import wandb
|
||||
import logging
|
||||
import tree
|
||||
import lightning as L
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
|
||||
from rf2aa.data.dataloader_adaptor_af3 import prepare_input_af3
|
||||
from rf2aa.flow_matching.interpolant import Interpolant
|
||||
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
|
||||
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads, pretty_describe_dict
|
||||
from rf2aa.training.EMA import EMA, count_parameters
|
||||
from rf2aa.loss.loss import translation_vector_field
|
||||
from rf2aa.loss.loss_factory import get_loss_and_misc
|
||||
from rf2aa.training.optimizer import add_weight_decay
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward, recycle_step_generic
|
||||
from rf2aa.model.network import RosettaFold
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
|
||||
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
|
||||
import rf2aa.util as util
|
||||
from rf2aa.util_module import XYZConverter
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.chemical import initialize_chemdata
|
||||
from rf2aa.set_seed import seed_all
|
||||
from rf2aa.model import AF3_structure
|
||||
from rf2aa.callbacks import LogMetrics, FindUnusedParameters, NetworkOutputGradSanityCheck, MonitorActivations
|
||||
from rf2aa.loggers import LitLogger
|
||||
|
||||
ic.configureOutput(includeContext=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#TODO: control environment variables from config
|
||||
# limit thread counts
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||
# Update environment variable with correct path (needed for W&B upload)
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
|
||||
## To reproduce errors
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
def get_n_params(model):
|
||||
pp=0
|
||||
for p in list(model.parameters()):
|
||||
nn=1
|
||||
for s in list(p.size()):
|
||||
nn = nn*s
|
||||
pp += nn
|
||||
return pp
|
||||
|
||||
def get_param_sizes(model):
|
||||
o = {}
|
||||
for k, p in model.named_parameters():
|
||||
o[k] = (np.array(p.size()).prod(), p.size())
|
||||
return o
|
||||
|
||||
# define the LightningModule
|
||||
class LitAF3Repro(L.LightningModule):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# self.model = torch.nn.Linear(2, 3).to(device)
|
||||
self.model = AF3_structure.Model(**self.config.model)
|
||||
print_n_params = False
|
||||
if print_n_params:
|
||||
logger.info(f'{get_n_params(self.model)=}')
|
||||
for k, v in sorted(get_param_sizes(self.model).items(), key=lambda item: item[1]):
|
||||
n_param, size = v
|
||||
# n_param = np.array(p.size()).prod()
|
||||
logger.info(f'{n_param=} {k=} {size=}')
|
||||
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
|
||||
def should_ignore(param_name):
|
||||
ignore_regexes = [
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_s_trunk\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_z\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_r\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias.ln_1\..*'),
|
||||
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.linear_output_project\..*'),
|
||||
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.ada_ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.atom_attention_decoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
]
|
||||
return any(regex.match(param_name) for regex in ignore_regexes)
|
||||
params_to_ignore = []
|
||||
for param_name, param in self.model.named_parameters():
|
||||
if should_ignore(param_name):
|
||||
params_to_ignore.append(param_name)
|
||||
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
|
||||
self.model,
|
||||
params_to_ignore
|
||||
)
|
||||
assert len(params_to_ignore)
|
||||
|
||||
self.loss = AF3_structure.Loss(**self.config.loss)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
|
||||
logger.debug('batch:\n' + pretty_describe_dict(batch))
|
||||
|
||||
# TODO: move data processing to dataset
|
||||
batch = tree.map_structure(lambda x: x.detach().cpu() if hasattr(x, 'cpu') else x, batch)
|
||||
network_input, loss_input = prepare_input_af3(
|
||||
batch,
|
||||
**self.config.af3_data_prep,
|
||||
)
|
||||
# TODO: move data processing to dataset
|
||||
network_input = tree.map_structure(lambda x: x.to(self.device), network_input)
|
||||
loss_input = tree.map_structure(lambda x: x.to(self.device), loss_input)
|
||||
|
||||
logger.debug('network_input:\n' + pretty_describe_dict(network_input))
|
||||
logger.debug('loss_input:\n' + pretty_describe_dict(loss_input))
|
||||
|
||||
n_cycle = random.randint(1, self.config.recycling.max_cycle)
|
||||
|
||||
X_L = self.model(
|
||||
network_input,
|
||||
n_cycle,
|
||||
no_sync=self.model.no_sync,
|
||||
)
|
||||
|
||||
loss, loss_dict, loss_dict_batched = self.loss(
|
||||
f=network_input['f'],
|
||||
t=network_input['t'],
|
||||
X_L=X_L,
|
||||
X_gt_L=loss_input['X_gt_L'],
|
||||
)
|
||||
self.log('loss', loss, prog_bar=True)
|
||||
return dict(
|
||||
loss=loss,
|
||||
X_L=X_L,
|
||||
) | loss_dict_batched | network_input | loss_input
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = getattr(torch.optim, self.config.optimizer.type)(
|
||||
self.model.parameters(),
|
||||
**self.config.optimizer.params,
|
||||
)
|
||||
scheduler = get_stepwise_decay_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=1000,
|
||||
num_steps_decay=5e4,
|
||||
decay_rate=0.95,
|
||||
)
|
||||
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
||||
|
||||
def configure_callbacks(self):
|
||||
return [
|
||||
LogMetrics(self.config, **self.config.callbacks.log_metrics),
|
||||
NetworkOutputGradSanityCheck(),
|
||||
MonitorActivations(),
|
||||
LearningRateMonitor(logging_interval='step'),
|
||||
]
|
||||
|
||||
class LitDataModule(L.LightningDataModule):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.init = partial(initialize_chemdata, config.chem_params)
|
||||
self.init()
|
||||
|
||||
def train_dataloader(self, rank=None, num_replicas=None):
|
||||
train_loader, train_sampler, valid_loaders, valid_samplers = compose_dataset(
|
||||
self.init, self.config.dataset_params, self.config.loader_params,
|
||||
rank or 0,
|
||||
num_replicas or 1,
|
||||
)
|
||||
return train_loader
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/train')
|
||||
def main(config):
|
||||
if config.autograd_detect_anomaly:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
model = LitAF3Repro(config)
|
||||
datamodule = LitDataModule(config)
|
||||
trainer_logger = LitLogger(**config.logger)
|
||||
|
||||
model_checkpoint = ModelCheckpoint(
|
||||
every_n_train_steps=1000,
|
||||
dirpath='checkpoints',
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
logger=trainer_logger,
|
||||
log_every_n_steps=1,
|
||||
gradient_clip_val=10,
|
||||
callbacks=[model_checkpoint],
|
||||
**config.lightning.trainer
|
||||
)
|
||||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,7 +1,9 @@
|
||||
import re
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing as mp
|
||||
from icecream import ic
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
import hydra
|
||||
@@ -13,17 +15,21 @@ import datetime
|
||||
from datetime import timedelta
|
||||
import certifi
|
||||
import warnings
|
||||
import wandb
|
||||
import logging
|
||||
import tree
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
|
||||
from rf2aa.data.dataloader_adaptor_af3 import prepare_input_af3
|
||||
from rf2aa.flow_matching.interpolant import Interpolant
|
||||
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
|
||||
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads
|
||||
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads, pretty_describe_dict
|
||||
from rf2aa.training.EMA import EMA, count_parameters
|
||||
from rf2aa.loss.loss import translation_vector_field
|
||||
from rf2aa.loss.loss_factory import get_loss_and_misc
|
||||
from rf2aa.training.optimizer import add_weight_decay
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward
|
||||
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward, recycle_step_generic
|
||||
from rf2aa.model.network import RosettaFold
|
||||
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
|
||||
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
|
||||
@@ -32,6 +38,9 @@ from rf2aa.util_module import XYZConverter
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.chemical import initialize_chemdata
|
||||
from rf2aa.set_seed import seed_all
|
||||
from rf2aa.model import AF3_structure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#TODO: control environment variables from config
|
||||
# limit thread counts
|
||||
@@ -75,6 +84,8 @@ class Trainer:
|
||||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.training_params.use_amp)
|
||||
|
||||
def load_checkpoint(self, rank):
|
||||
if self.config.training_params.from_scratch:
|
||||
return False
|
||||
checkpoint_path = f"{self.output_dir}/{self.config.experiment.name}_last.pt"
|
||||
# 'checkpoint_path' takes priority ...
|
||||
if self.config.eval_params.checkpoint_path:
|
||||
@@ -178,13 +189,18 @@ class Trainer:
|
||||
|
||||
if ("SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
|
||||
world_size = int(os.environ["SLURM_NTASKS"])
|
||||
self.world_size = world_size
|
||||
rank = int (os.environ["SLURM_PROCID"])
|
||||
print ("Launched from slurm", rank, world_size)
|
||||
self.train_model(rank, world_size)
|
||||
|
||||
else:
|
||||
print ("Launched from interactive")
|
||||
world_size = torch.cuda.device_count()
|
||||
world_size = max(torch.cuda.device_count(), 1)
|
||||
if self.config.cpu_training:
|
||||
world_size = 1
|
||||
|
||||
self.world_size = world_size
|
||||
|
||||
if world_size == 0:
|
||||
print ("Error! No GPUs found!")
|
||||
@@ -195,10 +211,13 @@ class Trainer:
|
||||
mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
|
||||
|
||||
def init_process_group(self, rank, world_size):
|
||||
gpu = rank % torch.cuda.device_count()
|
||||
gpu = rank % self.world_size
|
||||
dist.init_process_group(backend=self.config.training_params.ddp_backend, timeout=timedelta(seconds=1800), world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device("cuda:%d"%gpu)
|
||||
return gpu
|
||||
device = 'cpu'
|
||||
if torch.cuda.device_count():
|
||||
device = "cuda:%d"%gpu
|
||||
torch.cuda.set_device(device)
|
||||
return device
|
||||
|
||||
def cleanup(self):
|
||||
if dist.is_initialized():
|
||||
@@ -240,6 +259,7 @@ class Trainer:
|
||||
self.valid_loaders = valid_loaders
|
||||
|
||||
# move global information to device
|
||||
# if torch.cuda.device_count():
|
||||
self.move_constants_to_device(gpu)
|
||||
|
||||
self.construct_model(device=gpu)
|
||||
@@ -251,6 +271,7 @@ class Trainer:
|
||||
self.construct_scaler()
|
||||
start_epoch = 0
|
||||
loaded_checkpoint = self.load_checkpoint(gpu)
|
||||
logger.info(f'Loaded checkpoint: {loaded_checkpoint}')
|
||||
if loaded_checkpoint:
|
||||
start_epoch = self.checkpoint["epoch"]
|
||||
self.load_model()
|
||||
@@ -279,6 +300,7 @@ class Trainer:
|
||||
if (
|
||||
self.config.dataset_params.validate_every_n_epochs > 0
|
||||
and epoch % self.config.dataset_params.validate_every_n_epochs==0
|
||||
and (epoch!=start_epoch or self.config.dataset_params.validate_after_first_epoch)
|
||||
):
|
||||
self.valid_epoch(epoch, rank, world_size)
|
||||
|
||||
@@ -318,6 +340,20 @@ class Trainer:
|
||||
if (p.grad is not None):
|
||||
print (n, torch.max( torch.abs(p.flatten()) ), torch.max( torch.abs(p.grad.flatten()) ))
|
||||
exit(1)
|
||||
|
||||
find_no_grad_parameters = False
|
||||
if find_no_grad_parameters:
|
||||
no_grad_parameters = []
|
||||
for n,p in self.model.module.model.named_parameters():
|
||||
if p.grad is None:
|
||||
no_grad_parameters.append(n)
|
||||
|
||||
if no_grad_parameters:
|
||||
print('Parameters with grad == None:')
|
||||
for n in no_grad_parameters:
|
||||
print(n)
|
||||
print(f'Fraction with grad == None: {len(no_grad_parameters)}/{len(list(self.model.module.model.named_parameters()))}')
|
||||
|
||||
|
||||
train_time = time.time() - start_time
|
||||
|
||||
@@ -400,11 +436,12 @@ class Trainer:
|
||||
def log_intermediate_losses(self, inputs, loss_dict, n_cycle, Nex, Nepoch, runtime):
|
||||
item = inputs[-1]
|
||||
max_mem = torch.cuda.max_memory_allocated()/1e9
|
||||
print(f"Models: {Nex} of: {Nepoch} Max_Memory: {max_mem:.4f} Runtime: {runtime:.4f}")
|
||||
print(f"Models: {Nex} of: {Nepoch} Max_Memory: {max_mem:.4f}Gb Runtime: {runtime:.4f}")
|
||||
print(f"Example: {item} Recycle:{n_cycle}\n"+
|
||||
"\t".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
|
||||
#print(f"Models: {Nex} Example: {item['CHAINID']} "+" ".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
#print(f"Models: {Nex} Example: {item['CHAINID']}"+" ".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def log_validation_losses(self, dataset_name, loss_dict):
|
||||
print(f"Dataset: {dataset_name} "+
|
||||
@@ -429,9 +466,12 @@ class LegacyTrainer(Trainer):
|
||||
cb_tor = ChemData().cb_torsion_t.to(device),
|
||||
|
||||
).to(device)
|
||||
device_ids = [device]
|
||||
if device == 'cpu':
|
||||
device_ids = None
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
|
||||
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=False, broadcast_buffers=False)
|
||||
|
||||
|
||||
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
|
||||
@@ -627,11 +667,115 @@ class FlowMatchingTrainer(Trainer):
|
||||
# If using W&B, log the validation losses (note: this is only done for rank = 0)
|
||||
if self.config.log_params.use_wandb:
|
||||
wandb.log(valid_loss_dict)
|
||||
def get_n_params(model):
|
||||
pp=0
|
||||
for p in list(model.parameters()):
|
||||
nn=1
|
||||
for s in list(p.size()):
|
||||
nn = nn*s
|
||||
pp += nn
|
||||
return pp
|
||||
|
||||
def get_param_sizes(model):
|
||||
o = {}
|
||||
for k, p in model.named_parameters():
|
||||
o[k] = (np.array(p.size()).prod(), p.size())
|
||||
return o
|
||||
|
||||
class AF3Trainer(FlowMatchingTrainer):
|
||||
|
||||
def construct_model(self, device="cpu"):
|
||||
# self.model = torch.nn.Linear(2, 3).to(device)
|
||||
self.model = AF3_structure.Model(**self.config.model).to(device)
|
||||
print_n_params = False
|
||||
if print_n_params:
|
||||
logger.info(f'{get_n_params(self.model)=}')
|
||||
for k, v in sorted(get_param_sizes(self.model).items(), key=lambda item: item[1]):
|
||||
n_param, size = v
|
||||
# n_param = np.array(p.size()).prod()
|
||||
logger.info(f'{n_param=} {k=} {size=}')
|
||||
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
|
||||
def should_ignore(param_name):
|
||||
ignore_regexes = [
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_s_trunk\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_z\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_r\..*'),
|
||||
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias.ln_1\..*'),
|
||||
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.linear_output_project\..*'),
|
||||
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.ada_ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
re.compile(r'model\.diffusion_module\.atom_attention_decoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
|
||||
]
|
||||
return any(regex.match(param_name) for regex in ignore_regexes)
|
||||
params_to_ignore = []
|
||||
for param_name, param in self.model.named_parameters():
|
||||
if should_ignore(param_name):
|
||||
params_to_ignore.append(param_name)
|
||||
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
|
||||
self.model,
|
||||
params_to_ignore
|
||||
)
|
||||
assert len(params_to_ignore)
|
||||
|
||||
device_ids = [device]
|
||||
if device == 'cpu':
|
||||
device_ids = None
|
||||
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=False, broadcast_buffers=False)
|
||||
self.sampler = AllAtomSampler(self.model,
|
||||
self.config.interpolant.sampling.num_timesteps,
|
||||
self.config.interpolant.min_t,
|
||||
self.interpolant,
|
||||
self.xyz_converter,
|
||||
is_training=True)
|
||||
self.loss = AF3_structure.Loss(**self.config.loss)
|
||||
|
||||
def move_constants_to_device(self, gpu):
|
||||
self.interpolant = Interpolant(self.config.interpolant)
|
||||
self.interpolant.set_device(gpu)
|
||||
super().move_constants_to_device(gpu)
|
||||
|
||||
def train_step(self, inputs, n_cycle, no_grads=False, return_outputs=False):
|
||||
gpu = self.model.device
|
||||
|
||||
D = 12
|
||||
network_input, loss_input = prepare_input_af3(
|
||||
inputs,
|
||||
self.config.interpolant,
|
||||
D,
|
||||
)
|
||||
network_input=tree.map_structure(lambda x: x.to(gpu) if hasattr(x, 'cpu') else x, network_input)
|
||||
logger.debug('network_input:\n' + pretty_describe_dict(network_input))
|
||||
logger.debug('loss_input:\n' + pretty_describe_dict(loss_input))
|
||||
|
||||
output_i = self.model(
|
||||
network_input,
|
||||
n_cycle,
|
||||
no_sync=self.model.no_sync,
|
||||
)
|
||||
|
||||
loss, loss_dict, loss_dict_batched = self.loss(
|
||||
f=network_input['f'],
|
||||
t=network_input['t'],
|
||||
X_L=output_i,
|
||||
X_gt_L=loss_input['X_gt_L'].tile((D,1,1)).to(gpu)
|
||||
)
|
||||
return loss, loss_dict
|
||||
|
||||
def construct_optimizer(self):
|
||||
self.optimizer = getattr(torch.optim, self.config.optimizer.type)(
|
||||
self.model.parameters(),
|
||||
**self.config.optimizer.params,
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/train')
|
||||
def main(config):
|
||||
if config.autograd_detect_anomaly:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
seed_all()
|
||||
trainer = trainer_factory[config.experiment.trainer](config=config)
|
||||
|
||||
@@ -650,7 +794,8 @@ def main(config):
|
||||
trainer_factory = {
|
||||
"legacy": LegacyTrainer,
|
||||
"composed": ComposedTrainer,
|
||||
"flow_matching": FlowMatchingTrainer
|
||||
"flow_matching": FlowMatchingTrainer,
|
||||
"af3_repro": AF3Trainer
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -75,6 +75,26 @@ def recycle_step_gen(ddp_model, input, n_cycle, use_amp, nograds=False, force_de
|
||||
output_i = unpack_outputs(rf_outputs, rf_latents, return_raw)
|
||||
return output_i
|
||||
|
||||
def recycle_step_generic(model, input, n_cycle, use_amp, nograds=False, force_device=None):
|
||||
'''
|
||||
Runs recycling with gradients only on final recycle.
|
||||
|
||||
Assums model has methods:
|
||||
pre_recycle: input --> recycling_input
|
||||
recycle: recycling_input --> recycling_input
|
||||
post_recycle: recycling_input --> output
|
||||
'''
|
||||
assert not nograds, 'not implemented'
|
||||
assert not use_amp, 'not implemented'
|
||||
recycling_input = model.pre_recycle(**input)
|
||||
for i_cycle in range(n_cycle):
|
||||
with ExitStack() as stack:
|
||||
if i_cycle < n_cycle -1:
|
||||
stack.enter_context(torch.no_grad())
|
||||
stack.enter_context(model.no_sync())
|
||||
recycling_input = model.recycle(**recycling_input)
|
||||
return model.post_recycle(**recycling_input)
|
||||
|
||||
|
||||
def run_model_forward(model, network_input, use_checkpoint=False, device="cpu"):
|
||||
""" run model forward pass, no recycling, no ddp (for tests) """
|
||||
|
||||
Reference in New Issue
Block a user