Use loggers

Add utilities for training on a single-entry dataset.

Allow validation skipping.

WIP AF3 Non-equivariant structure encoder/decoder

Add flag to force training from scratch

Force training from scratch in debug config

All modules in diffusion module implemented

Document behavior of dropout with test

Finish majority of model trunk

Convert some ModuleLists to nn.Sequential

Add RelativePositionEncoding and WIP af3_repro config

Fix ref_space_uid embedding in AtomEncoder

Put Model together with fake MSAModule and TemplateEmbedder

AF3 repro loads model.

WIP af3 data-adaptor, AF3_structure fixes

Feature initializer working

Standardize S_inputs_I

Fix pairformer stack

Forward pass working, WIP: backward pass stale reference fixing

Add dataloader_adaptor_af3.py

Backward pass working, WIP: still some unused params

Backprop working

Training runs

Add pytorch lightning training and some wandb logging

Training converging for single example.

Run:
/home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif
trainer_lightning.py --config-name af3_repro_single_example_small
logger.use_wandb=True af3_data_prep.D=6

Log loss

Training working for single example.

Run: /home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif
trainer_lightning.py --config-name
af3_repro_single_example_small_working_4 logger.use_wandb=True

on an a4000

Add test_diffusion_module.py
This commit is contained in:
Woody Ahern
2024-05-14 14:22:43 -07:00
committed by Rohith Krishna
parent d33c097ff5
commit 11101963df
27 changed files with 2734 additions and 249 deletions

View File

@@ -0,0 +1,228 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9986c90c",
"metadata": {},
"source": [
"# Dataset subsampling\n",
"\n",
"This notebook can be used to subsample a dataset pickle to debug training runs, especially for new archictures by attempting to train the architecture on a single example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "32cf6d8c",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import pickle\n",
"import gc\n",
"import copy\n",
"import numpy as np\n",
"\n",
"root_dir = os.path.join(\n",
" os.getcwd(),\n",
" '..',\n",
")\n",
"\n",
"sys.path.append(root_dir)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "960c344d",
"metadata": {},
"outputs": [],
"source": [
"# Load the existing large dataset\n",
"\n",
"with open(os.path.join(root_dir, 'rf2aa/dataset_20240318.pkl'), 'rb') as fh:\n",
" dataset_clean = pickle.load(fh)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "31f55497",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ii=1154 len(dici[\"PARTNERS\"])=1 dici[\"LEN_EXIST\"]=65\n"
]
}
],
"source": [
"# Define filters\n",
"\n",
"def get_index_passing_filter(train_dict, f):\n",
" for i, row in train_dict.iterrows():\n",
" if f(row):\n",
" return i\n",
" raise Exception('No index passing filter')\n",
"\n",
"def get_short_single_ligand(row):\n",
" if row['LEN_EXIST'] > 100:\n",
" return False\n",
" if len(row['PARTNERS']) != 1:\n",
" return False\n",
" return True\n",
"\n",
"ii = get_index_passing_filter(dataset_clean['train_dict']['sm_compl'], get_short_single_ligand)\n",
"dici = dataset_clean['train_dict']['sm_compl'].loc[ii]\n",
"print(f'{ii=} {len(dici[\"PARTNERS\"])=} {dici[\"LEN_EXIST\"]=}')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "04c40e33",
"metadata": {},
"outputs": [],
"source": [
"# Subsample the dataset\n",
"\n",
"def subsample_v2(p, n_per_dataset=1, filter_by_dataset=dict()):\n",
" assert n_per_dataset==1\n",
" train_ID_dict = p['train_ID_dict']\n",
" valid_ID_dict = p['valid_ID_dict']\n",
" weights_dict = p['weights_dict']\n",
" train_dict = p['train_dict']\n",
" valid_dict = p['valid_dict']\n",
" homo = p['homo']\n",
" chid2hash = p['chid2hash']\n",
" chid2taxid = p['chid2taxid']\n",
" chid2smpartners = p['chid2smpartners']\n",
"\n",
" all_chain_ids = set()\n",
" datasets = set(train_ID_dict.keys())\n",
" \n",
" for k, v in train_ID_dict.items():\n",
" if k in datasets:\n",
" v = v[:n_per_dataset]\n",
" dic = train_dict[k]\n",
" row_filter = filter_by_dataset.get(k, lambda x: True)\n",
" i = [get_index_passing_filter(train_dict[k], row_filter)]\n",
" dic = dic.loc[i]\n",
" v = np.array(dic.loc[i]['CLUSTER'])\n",
" \n",
" else:\n",
" v = []\n",
" dic = train_dict[k][0:0]\n",
" weights_dict[k] = []\n",
" if 'CHAINID' in dic:\n",
" chain_ids = dic['CHAINID']\n",
" for ch_id in chain_ids:\n",
" all_chain_ids.add(ch_id)\n",
" for ch_id in ch_id.split(':'):\n",
" all_chain_ids.add(ch_id)\n",
"\n",
" train_ID_dict[k] = v\n",
" train_dict[k] = dic\n",
" weights_dict[k] = weights_dict[k][:n_per_dataset]\n",
" \n",
" for k, v in valid_ID_dict.items():\n",
" v = v[:n_per_dataset]\n",
" dic = valid_dict[k]\n",
" dic = dic[dic.CLUSTER.isin(v)].reset_index(drop=True)\n",
" if 'CHAINID' in dic:\n",
" chain_ids = dic['CHAINID']\n",
" for ch_id in chain_ids:\n",
" all_chain_ids.add(ch_id)\n",
" for ch_id in ch_id.split(':'):\n",
" all_chain_ids.add(ch_id)\n",
"\n",
" valid_ID_dict[k] = v\n",
" valid_dict[k] = dic\n",
" \n",
" homo = homo.sample(n_per_dataset)\n",
" chid2hash = {k:v for k,v in chid2hash.items() if k in all_chain_ids}\n",
" chid2taxid = {k:v for k,v in chid2taxid.items() if k in all_chain_ids}\n",
" chid2smpartners = {k:v for k,v in chid2smpartners.items() if k in all_chain_ids}\n",
" return dict(\n",
" train_ID_dict=train_ID_dict,\n",
" valid_ID_dict=valid_ID_dict,\n",
" weights_dict=weights_dict,\n",
" train_dict=train_dict,\n",
" valid_dict=valid_dict,\n",
" homo=homo,\n",
" chid2hash=chid2hash,\n",
" chid2taxid=chid2taxid,\n",
" chid2smpartners=chid2smpartners,\n",
" )\n",
"\n",
"dataset = copy.deepcopy(dataset_clean)\n",
"dataset_subsampled = subsample_v2(dataset, filter_by_dataset={'sm_compl': get_short_single_ligand})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3a4c76c4",
"metadata": {},
"outputs": [],
"source": [
"# Override the train dict with the valid dict\n",
"def use_train_as_valid(dataset):\n",
" dataset['valid_ID_dict'] = {k: [] for k,v in dataset['valid_ID_dict'].items()}\n",
" dataset['valid_ID_dict'].update(dataset['train_ID_dict'])\n",
" dataset['valid_dict'] = {k: v.iloc[0:0] for k,v in dataset['valid_dict'].items()}\n",
" dataset['valid_dict'].update(dataset['train_dict'])\n",
"\n",
"use_train_as_valid(dataset_subsampled)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c7c8cd07",
"metadata": {},
"outputs": [],
"source": [
"# Write the subsampled dataset pickle, which can now be used to debug architectures\n",
"# by specifying loader_param.datapkl=YOUR_PKL_PATH\n",
"\n",
"dataset_subsampled_path = os.path.join(root_dir, 'rf2aa/subsampled/dataset_20240318_n-2.pkl')\n",
"assert not os.path.exists(dataset_subsampled_path)\n",
"with open(dataset_subsampled_path, 'wb') as fh:\n",
" pickle.dump(dataset_subsampled, fh)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "343b8cca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

47
rf2aa/alignment.py Normal file
View File

@@ -0,0 +1,47 @@
import logging
import numpy as np
import torch
from rf2aa.debug import pretty_describe_dict
logger = logging.getLogger(__name__)
def weighted_rigid_align(
X_L, # [B, L, 3]
X_gt_L, # [B, L, 3]
w_L, # [B, L]
):
'''
Weighted rigid body alignment of X_L onto X_gt_L
Returns:
X_align_L: [B, L, 3]
'''
assert X_L.shape == X_gt_L.shape
assert X_L.shape[:-1] == w_L.shape
u_X = torch.mean(X_L * w_L.unsqueeze(-1), dim=-2) / torch.mean(w_L, dim=-1, keepdim=True)
u_X_gt = torch.mean(X_gt_L * w_L.unsqueeze(-1), dim=-2) / torch.mean(w_L, dim=-1, keepdim=True)
X_L = X_L - u_X.unsqueeze(-2)
X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2)
# Computation of the covariance matrix
C = torch.transpose(X_gt_L, -1, -2) @ X_L
U, S, V = torch.linalg.svd(C)
R = U @ V
B, _, _ = X_L.shape
F = torch.eye(3,3, device=X_L.device)[None].tile((B,1,1,))
F[...,-1, -1] = torch.sign(torch.linalg.det(R))
R = U @ F @ V
X_align_L = X_L @ R.transpose(-1, -2) + u_X_gt.unsqueeze(-2)
return X_align_L.detach()
def get_rmsd(xyz1, xyz2, eps=1e-4):
L = xyz1.shape[-2]
rmsd = torch.sqrt(torch.sum((xyz2-xyz1)*(xyz2-xyz1), axis=(-1, -2)) / L + eps)
return rmsd

200
rf2aa/callbacks.py Normal file
View File

@@ -0,0 +1,200 @@
import os
import csv
import shutil
import logging
from collections import defaultdict
import json
import numpy as np
from scipy.stats import norm
from icecream import ic
import torch
import torch.nn.functional as F
import tree
import pandas as pd
import numpy as np
from lightning.pytorch.callbacks import Callback
from lightning import Trainer, LightningModule
from lightning_fabric.loggers.csv_logs import _ExperimentWriter
from rf2aa.tensor_util import apply_to_tensors
from rf2aa.debug import pretty_describe_dict
from rf2aa.model.AF3_structure import Loss
from rf2aa import pymol_tools
logger = logging.getLogger(__name__)
def flatten_dictionary(dictionary, parent_key='', separator='.'):
flattened_dict = {}
for key, value in dictionary.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict):
flattened_dict.update(flatten_dictionary(value, new_key, separator))
else:
flattened_dict[new_key] = value
return flattened_dict
class LogMetrics(Callback):
def __init__(self, config):
super().__init__()
self.config = config
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int) -> None:
logger.debug('on_train_batch_end outputs:\n' + pretty_describe_dict(outputs))
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
o = {}
stratifications = defaultdict(list)
for metric in [diffusion_losses]:
metric_d, stratification_keys = metric(self.config, outputs)
stratifications[stratification_keys].extend(metric_d.keys())
o.update(metric_d)
o['t'] = outputs['t']
o['t_quantile_4'] = get_t_quantiles(outputs['t'], self.config.loss.sigma_data, 4)
df = pd.DataFrame.from_dict(o)
df = df.reindex(sorted(df.columns), axis=1)
D, = outputs['t'].shape
df['batch_idx'] = batch_idx
df['data_idx'] = np.arange(D)
df['global_step'] = trainer.global_step
trainer.logger.log_df(df, stratifications=stratifications)
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
def diffusion_losses(config, outputs):
loss = Loss(**config.loss)
loss_dict_by_type = {}
t = outputs['t']
X_noisy_L = outputs['X_noisy_L']
sigma_data = 16
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[...,None,None] * X_noisy_L
sigma_gt = torch.var(outputs['X_gt_L'], dim=(1,2))**0.5
for input_type, X_L in (
('pred', outputs['X_L']),
# ('input', outputs['X_noisy_L']),
('true', outputs['X_gt_L']),
('null_pred', null_pred),
):
l_total, _, loss_dict_batched = loss(
outputs['f'],
X_L,
outputs['X_gt_L'],
outputs['t'],
)
# loss_dict_by_type[input_type] = loss_dict_batched
loss_dict_batched_prefixed = {f'{k}.{input_type}':v for k,v in loss_dict_batched.items()}
loss_dict_by_type.update(loss_dict_batched_prefixed)
# Correcting for EDM : AF3 lambda conversion
edm_corr = (t+loss.sigma_data)**2 / (t*loss.sigma_data)**2
loss_dict_batched_edm = {k:v * edm_corr for k,v in loss_dict_batched.items()}
loss_dict_batched_prefixed_edm = {f'{k}_edm.{input_type}':v for k,v in loss_dict_batched_edm.items()}
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
# Correcting for Var(gt) != sigma_data
expected_loss_gt = 1 / (loss.sigma_data**2 + t**2) * (loss.sigma_data**2 + t**2 * sigma_gt**2 / loss.sigma_data**2)
loss_dict_batched_edm_gt_corr = {k: edm_corr * v / expected_loss_gt for k,v in loss_dict_batched.items()}
loss_dict_batched_prefixed_edm = {f'{k}_edm_gt_corr.{input_type}':v for k,v in loss_dict_batched_edm_gt_corr.items()}
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
o = flatten_dictionary(loss_dict_by_type)
o['pred_over_null_pred'] = o['diffusion_loss.pred'] / o['diffusion_loss.null_pred']
o['pred_over_null_pred_norm'] = o['diffusion_loss_edm_gt_corr.pred'] / o['diffusion_loss_edm_gt_corr.null_pred']
return o, ('t_quantile_4',)
def get_normal_quantiles(n):
# Generate n evenly spaced probabilities between 0 and 1
probabilities = np.linspace(0, 1, n)
# Use the percent point function (inverse CDF) of the standard normal distribution
return norm.ppf(probabilities)
def get_t_quantiles(t, sigma_data, n):
bins = sigma_data * np.exp(-1.2 + 1.5 * get_normal_quantiles(n+1))
t_binned_list = []
for t in t:
t_bin = np.digitize(t, bins) - 1
bin_start = bins[t_bin]
bin_end = bins[t_bin+1]
t_binned = f't=[{bin_start:.2f},{bin_end:.2f})'
t_binned_list.append(t_binned)
return t_binned_list
class NetworkOutputGradSanityCheck(Callback):
def __init__(self, call_n_times=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.call_n_times = call_n_times
self.call_count = 0
def on_after_backward(self, trainer, pl_module):
if self.call_count < self.call_n_times:
self.call_count += 1
r_projection_weight = pl_module.model.model.diffusion_module.atom_attention_decoder.to_r_update[1].weight
ic(
torch.linalg.norm(r_projection_weight) if r_projection_weight is not None else None,
torch.linalg.norm(r_projection_weight.grad) if r_projection_weight.grad is not None else None,
)
class MonitorActivations(Callback):
def make_hook(self, label):
def hook(module, args, kwargs, output):
activation_metrics = {
f'{label}:inter_batch_cosine_similarity': F.cosine_similarity(
torch.flatten(output[0]),
torch.flatten(output[1]),
dim=0,
),
f'{label}:inter_batch_cosine_similarity': F.cosine_similarity(
torch.flatten(output[0]),
torch.flatten(output[1]),
dim=0,
),
f'{label}:intra_batch_cosine_similarity_to_elem_0': F.cosine_similarity(
output[0][0:1],
output[0],
).mean(),
}
self.log_dict(activation_metrics)
return hook
def setup(self, trainer, pl_module, stage):
self.pl_module = pl_module
self.trainer = trainer
pl_module.model.model.diffusion_module.atom_attention_decoder.register_forward_hook(
self.make_hook(
'diffusion_module.atom_attention_decoder',
),
with_kwargs=True
)
class FindUnusedParameters(Callback):
def __init__(self, only_once=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.only_once = only_once
self.called = False
def on_after_backward(self, trainer, pl_module):
if self.called and self.only_once:
return
self.called=True
# Calculate unused parameters after each batch
unused_params = [name for name, param in pl_module.named_parameters() if param.grad is None]
# Log unused parameters
logging.info(f'global_step={pl_module.global_step}: parameters with no gradient: {json.dumps(unused_params, indent=4)}')
if unused_params:
raise Exception('storp')

View File

@@ -0,0 +1,193 @@
defaults:
# - af3
- rf2aa
- _self_
experiment:
name: rf2aa-af3-repro
trainer: af3_repro
output_dir: null
loss:
sigma_data: ${model.diffusion_module.sigma_data}
alpha_dna: 5
alpha_rna: 5
alpha_ligand: 10
edm_lambda: False
se3_invariant_loss: True
ddp_params:
batch_size: 1
loader_params:
p_msa_mask: 0.0
interpolant:
sigma_data: 16
min_t: 1e-2
separate_t: False
provide_kappa: False
hierarchical_t: False
codesign_separate_t: False
codesign_forward_fold_prop: 0.0
codesign_inverse_fold_prop: 0.0
twisting:
use: False
rots:
corrupt: True
train_schedule: linear
sample_schedule: linear
exp_rate: 10
trans:
corrupt: True
batch_ot: True
train_schedule: linear
sample_schedule: linear
sample_temp: 1.0
vpsde_bmin: 0.1
vpsde_bmax: 20.0
potential: null
potential_t_scaling: False
rog:
weight: 10.0
cutoff: 5.0
aatypes:
corrupt: False
schedule: linear
schedule_exp_rate: 10
temp: 1.0
noise: 0.0
do_purity: False
train_extra_mask: 0.0
interpolant_type: masking
num_tokens: 80
sampling:
num_timesteps: 20
do_sde: False
self_condition: False
model_globals:
l_max: 1000
model:
c_s: 384
c_z: 128
c_atom: 128
c_atompair: 16
c_s_inputs: 449
feature_initializer:
c_s_inputs: ${model.c_s_inputs}
input_feature_embedder:
features:
- restype
- profile
- deletion_mean
atom_attention_encoder:
c_token: 384
c_atom_1d_features: 389
c_tokenpair: ${model.c_z}
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
atom_transformer:
n_queries: 32
n_keys: 128
l_max: ${model_globals.l_max}
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
relative_position_encoding:
r_max: 32
s_max: 2
recycler:
n_pairformer_blocks: 48
pairformer_block:
p_drop: 0.25
c: 128
attention_pair_bias:
n_head: 16
template_embedder:
n_block: 2
c: 64
msa_module:
n_block: 4
c_m: 64
diffusion_module:
sigma_data: ${interpolant.sigma_data}
c_token: 768
f_pred: edm
diffusion_conditioning:
c_s_inputs: ${model.c_s_inputs}
c_t_embed: 256
relative_position_encoding:
r_max: 32
s_max: 2
atom_attention_encoder:
c_tokenpair: ${model.c_z}
c_atom_1d_features: 389
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
atom_transformer:
n_queries: 32
n_keys: 128
l_max: ${model_globals.l_max}
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
diffusion_transformer:
n_block: 24
diffusion_transformer_block:
n_head: 16
atom_attention_decoder:
atom_transformer:
n_queries: 32
n_keys: 128
l_max: ${model_globals.l_max}
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
optimizer:
type: Adam
params:
lr: 1.8e-3
betas: [0.9, 0.95]
eps: 1.0e-8
logger:
save_dir: csv_logs
use_wandb: False
sublogger:
project: af3-debug
callbacks:
log_metrics: {}
af3_data_prep:
D: 12
sigma_data: ${model.diffusion_module.sigma_data}
s_trans: 1
random_augmentation: True
only_ca: False
recycling:
max_cycle: 4
lightning:
trainer:
accumulate_grad_batches: 25

View File

@@ -0,0 +1,63 @@
defaults:
- af3_repro
- _self_
loader_params:
datapkl: /home/ahern/projects/RF2-allatom/rf2aa/subsampled/dataset_20240318_n-1.pkl
no_match_okay: True
training_params:
from_scratch: True
learning_rate_schedule:
decay_rate: 1.
dataset_params:
n_train: 500
validate_every_n_epochs: 100
validate_after_first_epoch: False
fraction_pdb: 0.
fraction_fb: 0.
fraction_compl: 0.
fraction_neg_compl: 0.
fraction_na_compl: 0.
fraction_neg_na_compl: 0.
fraction_distil_tf: 0.
fraction_tf: 0.
fraction_neg_tf: 0.
fraction_rna: 0.
fraction_dna: 0.
fraction_sm_compl: 1.
fraction_metal_compl: 0.
fraction_sm_compl_multi: 0.
fraction_sm_compl_covale: 0.
fraction_sm: 0.
fraction_atomize_pdb: 0.
fraction_atomize_complex: 0.
fraction_sm_compl_asmb: 0.
n_valid_pdb: 0
n_valid_homo: 0
n_valid_dslf: 0
n_valid_compl: 0
n_valid_neg_compl: 0
n_valid_na_compl: 0
n_valid_neg_na_compl: 0
n_valid_distil_tf: 0
n_valid_tf: 0
n_valid_neg_tf: 0
n_valid_rna: 0
n_valid_dna: 0
n_valid_sm_compl: 2
n_valid_metal_compl: 0
n_valid_sm_compl_multi: 0
n_valid_sm_compl_covale: 0
n_valid_sm_compl_strict: 0
n_valid_sm: 0
n_valid_atomize_pdb: 0
n_valid_atomize_complex: 0
n_valid_sm_compl_asmb: 0
n_valid_fb: 0
log_params:
use_wandb: true
wandb_project: rf2aa-debug

View File

@@ -0,0 +1,32 @@
defaults:
- af3_repro_single_example_small
- _self_
cpu_training: true
training_params:
ddp_backend: gloo
model:
recycler:
n_pairformer_blocks: 2
template_embedder:
n_block: 2
msa_module:
n_block: 2
diffusion_module:
atom_attention_encoder:
atom_transformer:
diffusion_transformer:
n_block: 2
diffusion_transformer:
n_block: 2
atom_attention_decoder:
atom_transformer:
diffusion_transformer:
n_block: 2
logger:
sublogger:
project: af3-debug-cpu
autograd_detect_anomaly: true

View File

@@ -0,0 +1,28 @@
defaults:
- af3_repro_single_example
- _self_
model:
recycler:
n_pairformer_blocks: 10
recycling:
max_cycle: 1
loader_params:
maxcycle: 1
# template_embedder:
# n_block: 2
# msa_module:
# n_block: 2
# diffusion_module:
# atom_attention_encoder:
# atom_transformer:
# diffusion_transformer:
# n_block: 2
# diffusion_transformer:
# n_block: 2
# atom_attention_decoder:
# atom_transformer:
# diffusion_transformer:
# n_block: 2

View File

@@ -0,0 +1,12 @@
# Status: unvetted
defaults:
- af3_repro_single_example_small
- _self_
af3_data_prep:
D: 12
lightning:
trainer:
accumulate_grad_batches: 1

View File

@@ -12,6 +12,7 @@ model:
legacy_model: null
dataset_params:
validate_every_n_epochs: 0
validate_after_first_epoch: False
fraction_pdb: 0
fraction_fb: 0
fraction_compl: 0
@@ -89,6 +90,7 @@ loader_params:
min_metal_contacts: 0
min_metal_contact_dist: 2.6
sampler_class: DistributedWeightedSampler
no_match_okay: False # If False, asserts that the dataset params in the cached dataset pickle match the config dataset params
dataloader_kwargs:
shuffle: False
num_workers: 0
@@ -146,3 +148,12 @@ chem_params:
use_lj_params_for_atoms: False
metrics: ["mean_pae", "mean_plddt"]
hydra:
job_logging:
formatters:
simple:
format: '[%(asctime)s][%(filename)25s:%(lineno)4s][%(name)30s:%(funcName)30s()][%(levelname)s] %(message)s'
cpu_training: false
autograd_detect_anomaly: false

View File

@@ -0,0 +1,61 @@
defaults:
- generative_refinement
- _self_
loader_params:
datapkl: /home/ahern/projects/RF2-allatom/rf2aa/subsampled/dataset_20240318_n-1.pkl
no_match_okay: True
dataset_params:
n_train: 500
validate_every_n_epochs: 100
validate_after_first_epoch: False
fraction_pdb: 0.
fraction_fb: 0.
fraction_compl: 0.
fraction_neg_compl: 0.
fraction_na_compl: 0.
fraction_neg_na_compl: 0.
fraction_distil_tf: 0.
fraction_tf: 0.
fraction_neg_tf: 0.
fraction_rna: 0.
fraction_dna: 0.
fraction_sm_compl: 1.
fraction_metal_compl: 0.
fraction_sm_compl_multi: 0.
fraction_sm_compl_covale: 0.
fraction_sm: 0.
fraction_atomize_pdb: 0.
fraction_atomize_complex: 0.
fraction_sm_compl_asmb: 0.
n_valid_pdb: 0
n_valid_homo: 0
n_valid_dslf: 0
n_valid_compl: 0
n_valid_neg_compl: 0
n_valid_na_compl: 0
n_valid_neg_na_compl: 0
n_valid_distil_tf: 0
n_valid_tf: 0
n_valid_neg_tf: 0
n_valid_rna: 0
n_valid_dna: 0
n_valid_sm_compl: 2
n_valid_metal_compl: 0
n_valid_sm_compl_multi: 0
n_valid_sm_compl_covale: 0
n_valid_sm_compl_strict: 0
n_valid_sm: 0
n_valid_atomize_pdb: 0
n_valid_atomize_complex: 0
n_valid_sm_compl_asmb: 0
n_valid_fb: 0
log_params:
use_wandb: True
wandb_project: rf2aa-debug
training_params:
from_scratch: True

View File

@@ -1,8 +1,12 @@
import numpy as np
from icecream import ic
import pickle
import torch.utils.data as data
from collections import OrderedDict
import os
import logging
logger = logging.getLogger(__name__)
from rf2aa.data.data_loader import get_train_valid_set, loader_pdb, loader_complex, loader_na_complex, \
loader_distil_tf, loader_tf_complex, loader_fb, loader_dna_rna, \
@@ -138,7 +142,7 @@ def get_distilled_dataset(dataset_params, loader_params):
chid2hash,
chid2taxid,
chid2smpartners,
) = get_train_valid_set(loader_params)
) = get_train_valid_set(loader_params, no_match_okay=loader_params['no_match_okay'])
# define atomize_pdb train/valid sets, which use the same examples as pdb set
train_ID_dict["atomize_pdb"] = train_ID_dict["pdb"]

View File

@@ -10,6 +10,8 @@ from itertools import permutations
from typing import Dict, Optional, Tuple, List, Set, Any
from pathlib import Path
from os.path import exists
import logging
import traceback
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
@@ -39,7 +41,14 @@ from rf2aa.util import get_nxgraph, get_atom_frames, get_bond_feats, get_protein
reindex_protein_feats_after_atomize, get_residue_contacts, atomize_discontiguous_residues, pop_protein_feats, \
is_atom, get_atom_template_indices, reassign_symmetry_after_cropping, expand_xyz_sm_to_ntotal, Ls_from_same_chain_2d, \
is_protein, is_nucleic, is_RNA, is_DNA, is_atom
from rf2aa.data.cluster_dataset import cluster_factory
logger = logging.getLogger(__name__)
try:
from rf2aa.data.cluster_dataset import cluster_factory
except Exception as e:
logger.warning(f'Failed to import cluster_factory from rf2aa.data.cluster_dataset: if you are rebuilding the dataset .pkl expect failure: ' + repr(e))
assert "rf2aa" in os.path.abspath(cifutils.__file__)

View File

@@ -0,0 +1,565 @@
import os
import logging
import math
import numpy as np
import torch
import torch.nn.functional as F
from icecream import ic
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.symmetry import symm_subunit_matrix, find_symm_subs
from rf2aa.util import is_atom, \
Ls_from_same_chain_2d, xyz_t_to_frame_xyz, get_prot_sm_mask
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.flow_matching import data_transforms
from rf2aa.debug import pretty_describe_dict
from rf2aa.util import rigid_from_3_points
from rf2aa.flow_matching.rigid_utils import rot_vec_mul
from rf2aa.set_seed import seed_all
from rf2aa.util import writepdb
from rf2aa import pymol_tools
from rf2aa.pymol import cmd
logger = logging.getLogger(__name__)
def within_group_unique_ids(group_ids, element_ids):
# Initialize a dictionary to store unique element mappings for each group
unique_mappings = {}
unique_id_counter = 0
# Initialize a list to store the resulting within-group unique ids
within_group_unique = []
# Iterate over the group ids and element ids simultaneously
for group_id, element_id in zip(group_ids, element_ids):
# Check if the current group_id is already in the unique_mappings dictionary
if group_id in unique_mappings:
# If the element_id is already mapped to a unique id within this group, use it
if element_id in unique_mappings[group_id]:
within_group_unique.append(unique_mappings[group_id][element_id])
# If the element_id is not yet mapped to a unique id within this group, assign a new unique id
else:
unique_mappings[group_id][element_id] = unique_id_counter
within_group_unique.append(unique_id_counter)
unique_id_counter += 1
# If the current group_id is not yet in the unique_mappings dictionary, add it
else:
unique_mappings[group_id] = {element_id: unique_id_counter}
within_group_unique.append(unique_id_counter)
unique_id_counter += 1
# Convert the resulting list to a PyTorch tensor
within_group_unique_tensor = torch.tensor(within_group_unique)
return within_group_unique_tensor
def integer_tokenize(iterable):
# Create a dictionary mapping unique elements to integers
unique_elements = list(set(iterable))
mapping = {element: i for i, element in enumerate(unique_elements)}
# Convert iterable to integer tensor
int_tensor = torch.tensor([mapping[element] for element in iterable])
return int_tensor
af3_num2aa = [
# Amino acids:
'ALA','ARG','ASN','ASP','CYS',
'GLN','GLU','GLY','HIS','ILE',
'LEU','LYS','MET','PHE','PRO',
'SER','THR','TRP','TYR','VAL',
'UNK', # 20 + 1
# DNA
'DA','DC','DG','DT', 'DUNK',
# RNA
'RA','RC','RG',' RU', 'RUNK',
# GAP
'GAP',
]
af3_aa2num = {x:i for i,x in enumerate(af3_num2aa)}
aa_coarse_from_fine = {
'ALA':'ALA',
'ARG':'ARG',
'ASN':'ASN',
'ASP':'ASP',
'CYS':'CYS',
'GLN':'GLN',
'GLU':'GLU',
'GLY':'GLY',
'HIS':'HIS',
'ILE':'ILE',
'LEU':'LEU',
'LYS':'LYS',
'MET':'MET',
'PHE':'PHE',
'PRO':'PRO',
'SER':'SER',
'THR':'THR',
'TRP':'TRP',
'TYR':'TYR',
'VAL':'VAL',
'UNK':'UNK',
# 'MAS':'MAS',
' DA':'DA',
' DC':'DC',
' DG':'DG',
' DT':'DT',
' DX':'DUNK',
' RA':'RA',
' RC':'RC',
' RG':'RG',
' RU':'RU',
' RX':'RUNK',
# 'HIS_D':'UNK',
'Al':'UNK',
'As':'UNK',
'Au':'UNK',
'B':'UNK',
'Be':'UNK',
'Br':'UNK',
'C':'C',
'Ca':'UNK',
'Cl':'UNK',
'Co':'UNK',
'Cr':'UNK',
'Cu':'UNK',
'F':'UNK',
'Fe':'UNK',
'Hg':'UNK',
'I':'UNK',
'Ir':'UNK',
'K':'UNK',
'Li':'UNK',
'Mg':'UNK',
'Mn':'UNK',
'Mo':'UNK',
'N':'UNK',
'Ni':'UNK',
'O':'UNK',
'Os':'UNK',
'P':'UNK',
'Pb':'UNK',
'Pd':'UNK',
'Pr':'UNK',
'Pt':'UNK',
'Re':'UNK',
'Rh':'UNK',
'Ru':'UNK',
'S':'UNK',
'Sb':'UNK',
'Se':'UNK',
'Si':'UNK',
'Sn':'UNK',
'Tb':'UNK',
'Te':'UNK',
'U':'UNK',
'W':'UNK',
'V':'UNK',
'Y':'UNK',
'Zn':'UNK',
'ATM':'ATM'
}
from enum import Enum
class TokenType(Enum):
PROTEIN = 1
DNA = 2
RNA = 3
LIGAND = 4
aa_restype_from_fine = {
'ALA': TokenType.PROTEIN,
'ARG': TokenType.PROTEIN,
'ASN': TokenType.PROTEIN,
'ASP': TokenType.PROTEIN,
'CYS': TokenType.PROTEIN,
'GLN': TokenType.PROTEIN,
'GLU': TokenType.PROTEIN,
'GLY': TokenType.PROTEIN,
'HIS': TokenType.PROTEIN,
'ILE': TokenType.PROTEIN,
'LEU': TokenType.PROTEIN,
'LYS': TokenType.PROTEIN,
'MET': TokenType.PROTEIN,
'PHE': TokenType.PROTEIN,
'PRO': TokenType.PROTEIN,
'SER': TokenType.PROTEIN,
'THR': TokenType.PROTEIN,
'TRP': TokenType.PROTEIN,
'TYR': TokenType.PROTEIN,
'VAL': TokenType.PROTEIN,
'UNK': TokenType.PROTEIN,
# 'MAS':'MAS',
' DA': TokenType.DNA,
' DC': TokenType.DNA,
' DG': TokenType.DNA,
' DT': TokenType.DNA,
' DX': TokenType.DNA,
' RA': TokenType.RNA,
' RC': TokenType.RNA,
' RG': TokenType.RNA,
' RU': TokenType.RNA,
' RX': TokenType.RNA,
# 'HIS_D':'UNK',
'Al': TokenType.LIGAND,
'As': TokenType.LIGAND,
'Au': TokenType.LIGAND,
'B': TokenType.LIGAND,
'Be': TokenType.LIGAND,
'Br': TokenType.LIGAND,
'C': TokenType.LIGAND,
'Ca': TokenType.LIGAND,
'Cl': TokenType.LIGAND,
'Co': TokenType.LIGAND,
'Cr': TokenType.LIGAND,
'Cu': TokenType.LIGAND,
'F': TokenType.LIGAND,
'Fe': TokenType.LIGAND,
'Hg': TokenType.LIGAND,
'I': TokenType.LIGAND,
'Ir': TokenType.LIGAND,
'K': TokenType.LIGAND,
'Li': TokenType.LIGAND,
'Mg': TokenType.LIGAND,
'Mn': TokenType.LIGAND,
'Mo': TokenType.LIGAND,
'N': TokenType.LIGAND,
'Ni': TokenType.LIGAND,
'O': TokenType.LIGAND,
'Os': TokenType.LIGAND,
'P': TokenType.LIGAND,
'Pb': TokenType.LIGAND,
'Pd': TokenType.LIGAND,
'Pr': TokenType.LIGAND,
'Pt': TokenType.LIGAND,
'Re': TokenType.LIGAND,
'Rh': TokenType.LIGAND,
'Ru': TokenType.LIGAND,
'S': TokenType.LIGAND,
'Sb': TokenType.LIGAND,
'Se': TokenType.LIGAND,
'Si': TokenType.LIGAND,
'Sn': TokenType.LIGAND,
'Tb': TokenType.LIGAND,
'Te': TokenType.LIGAND,
'U': TokenType.LIGAND,
'W': TokenType.LIGAND,
'V': TokenType.LIGAND,
'Y': TokenType.LIGAND,
'Zn': TokenType.LIGAND,
'ATM': TokenType.LIGAND,
}
element_codes = [
'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', # Atomic numbers 1-10
'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', # Atomic numbers 11-20
'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', # Atomic numbers 21-30
'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', # Atomic numbers 31-40
'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', # Atomic numbers 41-50
'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', # Atomic numbers 51-60
'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', # Atomic numbers 61-70
'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', # Atomic numbers 71-80
'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', # Atomic numbers 81-90
'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', # Atomic numbers 91-100
'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', # Atomic numbers 101-110
'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', # Atomic numbers 111-118
'Uue', 'Ubn', 'Ubu', 'Ubb', 'Ubt', 'Ubq', 'Ubp', 'Ubh', # Atomic numbers 119-126
'the_element', 'of_surprise' # Non-existant elements 127-128
]
element_code_to_num = {x:i for i,x in enumerate(element_codes)}
def discretize_distance_matrix(distance_matrix, num_bins=38, min_distance=3.25, max_distance=50.75):
# Calculate the bin width
bin_width = (max_distance - min_distance) / num_bins
# Discretize distances into bins
bins = torch.floor((distance_matrix - min_distance) / bin_width).clamp_max(num_bins)
# Assign larger distances to the last bin
bins = torch.where(bins < num_bins, bins, torch.tensor(num_bins))
return bins
def torch_vectorize(pyfunc):
def f(*args, **kwargs):
out_np = np.vectorize(pyfunc)(*args, **kwargs)
return torch.tensor(out_np)
return f
def prepare_input_af3(inputs, D, s_trans, sigma_data, random_augmentation, only_ca, device="cpu",):
logger.debug('prepare_input_af3 input:\n' + pretty_describe_dict(inputs))
(
seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
xyz_t, t1d, mask_t, xyz_prev, mask_prev, same_chain, unclamp, negative,
atom_frames, bond_feats, dist_matrix, chirals, ch_label, symmgp, task, item
) = inputs
# transfer inputs to device
B, _, N, I = msa.shape
logger.debug('\n\n\n\n\n')
logger.debug('prepare_input_af3 input:\n' + pretty_describe_dict(dict(
seq=seq,
msa=msa,
msa_masked=msa_masked,
msa_full=msa_full,
mask_msa=mask_msa,
true_crds=true_crds,
mask_crds=mask_crds,
idx_pdb=idx_pdb,
xyz_t=xyz_t,
t1d=t1d,
mask_t=mask_t,
xyz_prev=xyz_prev,
mask_prev=mask_prev,
same_chain=same_chain,
unclamp=unclamp,
negative=negative,
atom_frames=atom_frames,
bond_feats=bond_feats,
dist_matrix=dist_matrix,
chirals=chirals,
ch_label=ch_label,
symmgp=symmgp,
task=task,
item=item,
)))
NUM_TEMPLATE_DISTOGRAM_BINS = 38
# Strip batch dimension
msa = msa[0,0]
idx_pdb = idx_pdb[0]
ch_label = ch_label[0]
true_crds = true_crds[0,0]
seq = seq[0,0]
xyz_t = xyz_t[0]
mask_t = mask_t[0]
mask_crds = mask_crds[0,0]
bond_feats = bond_feats[0]
N_token = seq.shape[0]
logger.debug(f'{ch_label[:5]}')
aa = [ChemData().num2aa[num] for num in seq]
aa1 = [ChemData().to1letter.get(k, k) for k in aa]
# aa_af3 = [aa_coarse_from_fine[s] for s in aa]
# Converts a ChemData() sequence token to an af3 token
@torch_vectorize
def af3num_from_num(num):
chemdata_code = ChemData().num2aa[num]
coarse = aa_coarse_from_fine[chemdata_code]
return af3_aa2num[coarse]
@np.vectorize
def get_token_type(num):
code3 = ChemData().num2aa[num]
return aa_restype_from_fine[code3]
f = {}
### Residue level ###
f['residue_index'] = idx_pdb
f['token_index'] = torch.arange(N_token)
f['asym_id'] = ch_label
# Hacked:
f['entity_id'] = torch.zeros(N_token)
f['sym_id'] = within_group_unique_ids(f['entity_id'], f['asym_id'])
# f['restype'] = F.one_hot(torch.tensor([af3_aa2num[aa_af3]]), len(af3_num2aa))
f['restype'] = F.one_hot(af3num_from_num(seq), len(af3_num2aa))
# token_type = torch.tensor([aa_restype_from_fine[aa]])
# token_type = np.vectorize(aa_restype_from_fine.__getitem__)(aa)
token_type = get_token_type(seq)
f['is_protein'] = torch.tensor(token_type == TokenType.PROTEIN)
f['is_rna'] = torch.tensor(token_type == TokenType.RNA)
f['is_dna'] = torch.tensor(token_type == TokenType.DNA)
f['is_ligand'] = torch.tensor(token_type == TokenType.LIGAND)
### Atom level ###
allatom_mask = ChemData().allatom_mask.to(device, non_blocking=True)
# remove symmetry dimension
# if len(true_crds.shape) == 4:
# true_crds = true_crds[0:1]
# mask_crds = mask_crds[0:1]
# true_crds = true_crds[0]
# want to unroll the coordinate tensors to get the full coordinates in (atoms, 3)
is_real_atom = allatom_mask[seq].bool()
if only_ca:
is_real_atom[:] = False
is_real_atom[:, 1] = True
tok_idx = is_real_atom.nonzero()[:,0]
within_tok_idx = is_real_atom.nonzero()[:,1]
N_atom = len(tok_idx)
f['tok_idx'] = tok_idx
# atom_mask = mask_crds[is_real_atom]
# t = interpolant.sample_t(D)
# Hacked:
f['ref_pos'] = torch.rand((N_atom, 3))
f['ref_mask'] = torch.arange(N_atom)
element = [ChemData().aa2elt[seq[tok]][within_tok] for tok, within_tok in zip(tok_idx, within_tok_idx)]
f['ref_element'] = F.one_hot(torch.tensor([element_code_to_num[e] for e in element]), len(element_codes))
# Hacked:
f['ref_charge'] = torch.zeros((N_atom))
f['ref_atom_name_chars'] = torch.zeros((N_atom, 4, 64))
f['ref_atom_name_chars'][:,0,0] = torch.arange(N_atom)
f['ref_space_uid'] = integer_tokenize(list(zip(f['asym_id'], f['residue_index'])))[tok_idx]
### MSA ###
f['msa'] = F.one_hot(af3num_from_num(msa), len(af3_num2aa))
# Hacked
N_msa = msa.shape[0]
f['has_deletion'] = torch.zeros((N_msa, N_token))
f['deletion_value'] = torch.zeros((N_msa, N_token))
f['profile'] = torch.zeros((N_token, 32))
f['deletion_mean'] = torch.zeros((N_token))
### Templates ###
N_templ = xyz_t.shape[0]
# Hacked:
template_seq = t1d[0].argmax(dim=-1) # [T, I]
assert (template_seq < ChemData().NPROTAAS - 1).all() # only 20 AA + 1 UNK (No mask)
f['template_restype'] = F.one_hot(af3num_from_num(template_seq), len(af3_num2aa))
template_is_protein = torch.tensor(get_token_type(template_seq) == TokenType.PROTEIN)
template_is_gly = template_seq == ChemData().aa2num['GLY']
template_atom_name = np.where(template_is_gly, ' CA ', ' CB ')
template_protein_beta_idx = torch_vectorize(lambda token, atom_name: ChemData().aa2long[token].index(atom_name))(
template_seq[template_is_protein],
template_atom_name[template_is_protein]
)
template_beta_idx = torch.full((N_templ, N_token), 0)
template_beta_idx[template_is_protein] = template_protein_beta_idx
template_beta_exists = torch.gather(mask_t, dim=2, index=template_beta_idx[..., None]).squeeze(-1)
f['template_pseudo_beta_mask'] = template_beta_exists * template_is_protein # .reshape?
f['template_backbone_frame_mask'] = mask_t[:, :, torch.tensor([0,1,2])].all(dim=-1)
# Reshape index_tensor to match the dimensions of xyz_t except for the last dimension
index_tensor_expanded = template_beta_idx[...,None,None].expand(-1, -1, -1, 3)
template_pseudo_beta = torch.gather(xyz_t, dim=2, index=index_tensor_expanded).squeeze(-2) #.squeeze(-1)
template_pseudo_beta_distogram = torch.cdist(template_pseudo_beta, template_pseudo_beta)
template_pseudo_beta_distogram *= f['template_pseudo_beta_mask'].unsqueeze(-1) * f['template_pseudo_beta_mask'].unsqueeze(-2)
f['template_distogram'] = discretize_distance_matrix(
template_pseudo_beta_distogram,
num_bins=NUM_TEMPLATE_DISTOGRAM_BINS,
)
CA_IDX = 1
template_ca = xyz_t[..., CA_IDX, :]
template_ca_disp = template_ca.unsqueeze(-2) - template_ca.unsqueeze(-3)
template_ca_disp_unit = template_ca_disp / torch.linalg.norm(template_ca_disp, dim=-1, keepdim=True)
template_R, _ = rigid_from_3_points(
xyz_t[:,:,0],
xyz_t[:,:,1],
xyz_t[:,:,2],
)
has_ca = mask_t[..., CA_IDX]
both_have_ca = has_ca[..., None, :] * has_ca[..., None]
template_unit_vector = rot_vec_mul(template_R[:,:, None], template_ca_disp_unit)
template_unit_vector[both_have_ca] = 0
f['template_unit_vector'] = template_unit_vector
has_ligand_2d = (f['is_ligand'].unsqueeze(-2) + f['is_ligand'].unsqueeze(-1)).bool()
# is_ligand_ligand = f['is_ligand'].unsqueeze(-2) * f['is_ligand'].unsqueeze(-1)
# Hacked (as covalent bonds are not represented in bond_feats and 2.4A filter not applied)
f['token_bonds'] = has_ligand_2d * (bond_feats > 0)
X_gt_L = true_crds[is_real_atom]
atom_mask = mask_crds[is_real_atom]
t = sigma_data * torch.exp(-1.2 + 1.5 * torch.normal(mean=0, std=1, size=(D,)))
X_gt_L = centre(X_gt_L, atom_mask)
X_gt_L = X_gt_L.tile(D,1,1)
if random_augmentation:
X_gt_L = get_random_augmentation(X_gt_L, s_trans=s_trans)
_, L, _ = X_gt_L.shape
t_tiled = t[:, None, None].tile(1, L, 3)
noise = torch.normal(mean=0, std=t_tiled)
X_noisy_L = X_gt_L + noise
return (
# network input
dict(
X_noisy_L=X_noisy_L,
t=t,
f=f,
),
# loss input (trues)
dict(
X_gt_L=X_gt_L,
crd_mask_I = is_real_atom,
seq=seq,
bond_feats=bond_feats,
)
)
def centre(X_L, X_exists_L):
X_L = X_L.clone()
X_L[X_exists_L] = X_L[X_exists_L] - torch.mean(X_L[X_exists_L], dim=-2, keepdim=True)
X_L[~X_exists_L] = 0.0
return X_L
def get_random_augmentation(X_L, s_trans):
'''
Inputs:
X_L [D, L, 3]: Batched atom coordinates
s_trans (float): standard deviation of a global translation to be applied for each
element in the batch
'''
D, L, _ = X_L.shape
R = uniform_random_rotation((D,))
noise = s_trans * torch.normal(mean=0, std=1, size=(D,1,3))
return rot_vec_mul(R[:,None], X_L) + noise
def uniform_random_rotation(size):
# Sample random angles for rotations around X, Y, and Z axes
theta_x = torch.rand(size) * 2 * math.pi
theta_y = torch.rand(size) * 2 * math.pi
theta_z = torch.rand(size) * 2 * math.pi
# Calculate the cosines and sines of the angles
cos_x = torch.cos(theta_x)
sin_x = torch.sin(theta_x)
cos_y = torch.cos(theta_y)
sin_y = torch.sin(theta_y)
cos_z = torch.cos(theta_z)
sin_z = torch.sin(theta_z)
# Create the rotation matrices around X, Y, and Z axes
rotation_x = torch.stack([torch.tensor([[1, 0, 0],
[0, c, -s],
[0, s, c]]) for c, s in zip(cos_x, sin_x)])
rotation_y = torch.stack([torch.tensor([[c, 0, s],
[0, 1, 0],
[-s, 0, c]]) for c, s in zip(cos_y, sin_y)])
rotation_z = torch.stack([torch.tensor([[c, -s, 0],
[s, c, 0],
[0, 0, 1]]) for c, s in zip(cos_z, sin_z)])
# Combine the rotation matrices
rotation_matrix = torch.matmul(rotation_z, torch.matmul(rotation_y, rotation_x))
return rotation_matrix

View File

@@ -6,6 +6,9 @@ import numpy as np
from torch.utils import data
from typing import Dict
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
def get_id_lengths(

View File

@@ -1,4 +1,7 @@
import torch
import tree
import json
from icecream import ic
def debug_nans(latent_feats):
@@ -32,3 +35,36 @@ def debug_grads(model):
def debug_nan_params(model):
for name, param in model.named_parameters():
print(f"{name}: {torch.sum(param.isnan())}")
def pretty_describe_dict(d):
mapped = describe_dict(d)
mapped = tree.map_structure(str, mapped)
return json.dumps(mapped, indent=4)
def describe_dict(d):
return tree.map_structure(describe, d)
def describe(t: torch.Tensor):
out = [f'type:{type(t)}']
if hasattr(t, 'shape'):
out.append(f'shape:{str(t.shape)}')
if hasattr(t, 'dtype'):
out.append(f'dtype:{str(t.dtype)}')
if hasattr(t, 'device'):
out.append(f'device:{str(t.device)}')
return ' '.join(out)
def safe_shape(t: torch.Tensor):
if hasattr(t, 'shape'):
return t.shape
return None
def log_in_out(f):
def wrapped(*args, **kwargs):
o = f(*args, **kwargs)
ic(
args, kwargs,
o
)
return o
return wrapped

75
rf2aa/loggers.py Normal file
View File

@@ -0,0 +1,75 @@
import logging
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only
from icecream import ic
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers import WandbLogger
from rf2aa.debug import pretty_describe_dict
logger = logging.getLogger(__name__)
def mean_over(df, grouper, metrics):
out = df[[grouper] + metrics].groupby(grouper).mean(numeric_only=True).stack().to_dict()
out = {f'{grouper}={grouper_v}.{metric}': v for (grouper_v, metric), v in out.items()}
return out
class LitLogger(Logger):
def __init__(self, save_dir, use_wandb, sublogger):
self.use_wandb = use_wandb
if self.use_wandb:
self.sublogger = WandbLogger(**sublogger)
else:
self.sublogger = CSVLogger(**sublogger)
super().__init__()
def log_df(self, df, stratifications=None):
global_step = df['global_step'].iloc[0]
mean_over_step = df.groupby('global_step').mean(numeric_only=True)
assert len(mean_over_step) == 1
mean_over_step = mean_over_step.iloc[0]
mean_over_step = mean_over_step.to_dict()
stratified = {}
for groupers, values in stratifications.items():
# TODO: enable stratification over multiple keys
assert len(groupers) == 1
stratified.update(mean_over(df, groupers[0], values))
stratified = mean_over_step | stratified
self.sublogger.log_metrics(stratified, step=global_step)
@property
def name(self):
return "MyLogger"
@property
def version(self):
# Return the experiment version, int or str.
return "0.1"
@rank_zero_only
def log_hyperparams(self, params):
# params is an argparse.Namespace
# your code to record hyperparameters goes here
self.sublogger.log_hyperparams(params)
@rank_zero_only
def log_metrics(self, metrics, step):
ic(step)
# metrics is a dictionary of metric names and values
# your code to record metrics goes here
self.sublogger.log_metrics(metrics, step)
@rank_zero_only
def save(self):
# Optional. Any code necessary to save logger data goes here
self.sublogger.save()
@rank_zero_only
def finalize(self, status):
# Optional. Any code that needs to be run after training
# finishes goes here
self.sublogger.finalize(status)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from functools import partial
import numpy as np
from torch import relu
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util_module import Dropout
from rf2aa.model.AF3_structure import AtomAttentionEncoder, AtomAttentionDecoder
class NonEquivariantAtomEncoder(nn.Module):
def __init__(self, block_params):
super().__init__()
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
self.model = AtomAttentionEncoder(**block_params)
class NonEquivariantAtomDecoder(nn.Module):
def __init__(self, block_params):
super().__init__()
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
self.model = AtomAttentionDecoder(**block_params)

56
rf2aa/pymol.py Normal file
View File

@@ -0,0 +1,56 @@
import sys
import os
from icecream import ic
import xmlrpc.client as xmlrpclib
class XMLRPCWrapperProxy(object):
def __init__(self, wrapped=None):
self.name = 'cmd'
self.wrapped = wrapped
def __getattr__(self, name):
attr = getattr(self.wrapped, name)
# ic(type(self), attr)
wrapped = type(self)(attr)
wrapped.name = name
return wrapped
def __call__(self, *args, **kw):
try:
return self.wrapped(*args, **kw)
except Exception as e:
all_args = tuple(map(str, args))
all_args += tuple(f'{k}={v}' for k,v in kw.items())
raise Exception(f"cmd.{self.name}('{','.join(all_args)})'") from e
def get_cmd(pymol_url='http://localhost:9123'):
cmd = xmlrpclib.ServerProxy(pymol_url)
if not ('ipd' in pymol_url or 'localhost' in pymol_url):
make_network_cmd(cmd)
return cmd
cmd = None
def init(pymol_url='http://localhost:9123'):
global cmd
cmd_inner = get_cmd(pymol_url)
if cmd is None:
cmd = XMLRPCWrapperProxy(cmd_inner)
else:
cmd.wrapped = cmd_inner
def make_network_cmd(cmd):
# old_load = cmd.load
def new_load(*args, **kwargs):
path = args[0]
with open(path) as f:
contents = f.read()
# args[0] = contents
args = (contents,) + args[1:]
#print('writing contents')
cmd.read_pdbstr(*args, **kwargs)
cmd.is_network = True
cmd.load = new_load
init()

65
rf2aa/pymol_tools.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import torch
from rf2aa.pymol import cmd
from rf2aa.util import writepdb
def clear():
cmd.reinitialize('everything')
cmd.delete('all')
# cmd.do(f'cd {REPO_DIR}/pymol_config')
cmd.do('@./pymolrc')
# def pseudoatom(
# pos: list = [0,0,0],
# label='origin',
# ):
# cmd.pseudoatom(label,'', 'PS1','PSD', '1', 'P',
# 'PSDO', 'PS', -1.0, 1, 0.0, 0.0, '',
# '', pos)
# # cmd.do(f'label {label}, "{label}"')
# return label
def pseudoatom(
cmd,
pos: list = [0,0,0],
label='origin',
):
cmd.pseudoatom(label,'', 'PS1','PSD', '1', 'P',
'PSDO', 'PS', -1.0, 1, 0.0, 0.0, '',
'', pos)
# cmd.do(f'label {label}, "{label}"')
return label
def show_origin():
pa = pseudoatom(cmd, label='the_origin')
cmd.center(pa)
cmd.color('red', pa)
cmd.set('grid_slot', -2, pa)
def show_pymol(
true_crds,
seq,
bond_feats,
label='unlabeled'
):
pdb_path = f"tmp/true_0.pdb"
writepdb(
pdb_path,
true_crds,
seq.long(),
bond_feats=bond_feats[None],
)
cmd.load(os.path.abspath(pdb_path), label)
show_origin()
def to_atom37(X_L, atom_mask):
assert X_L.shape[-1] == 3
assert X_L.numel() / 3 == atom_mask.sum(), f'{X_L.numel()/3=} != {atom_mask.sum()=}. {X_L.shape=} {atom_mask.shape=}'
L, _ = atom_mask.shape[-2:]
X_I = torch.zeros(atom_mask.shape + (3,), dtype=torch.float) - 10
X_I[atom_mask] = X_L
return X_I

View File

@@ -167,3 +167,8 @@ def cmp(got, want, **kwargs):
if dd:
return dd
return ''
def assert_cmp(got, want, **kwargs):
diff = cmp(got, want, **kwargs)
if diff:
raise AssertionError(diff)

31
rf2aa/tests/test_align.py Normal file
View File

@@ -0,0 +1,31 @@
import os
import torch
import pytest
from icecream import ic
from rf2aa.alignment import weighted_rigid_align, get_rmsd
from rf2aa.util import kabsch
def pseudobatched_kabsch(xyz1, xyz2):
B = xyz1.shape[0]
out = []
for i in range(B):
out.append(kabsch(xyz1[i], xyz2[i])[0])
return torch.stack(out)
def test_align():
torch.manual_seed(0)
B = 9
L = 5
x_from = torch.rand((B, L, 3))
x_to = torch.rand((B, L, 3))
w = torch.ones((B, L))
rmsd_kabsch = pseudobatched_kabsch(x_from, x_to)
x_from_align = weighted_rigid_align(x_from, x_to, w)
rmsd_weighted_rigid = get_rmsd(x_to, x_from_align)
ic(rmsd_weighted_rigid, rmsd_kabsch)
assert (torch.abs(rmsd_weighted_rigid - rmsd_kabsch) < 1e-5).all(), f'{rmsd_weighted_rigid} != {rmsd_kabsch}'
if __name__ == '__main__':
test_align()

View File

@@ -0,0 +1,165 @@
import os
import torch
import pytest
from icecream import ic
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from hydra import initialize, compose
from pl_bolts.utils import BatchGradientVerification
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping
from torch.nn.functional import one_hot
from rf2aa.debug import pretty_describe_dict
from rf2aa.model.AF3_structure import Model, DiffusionModule, AtomTransformer
from rf2aa.tensor_util import assert_shape, assert_cmp
def test_batch_leakage():
conf_overrides = []
with initialize(version_base=None, config_path="../config/train"):
conf = compose(config_name='af3_repro', overrides=conf_overrides)
c_s = conf.model.c_s
c_z = conf.model.c_z
model = DiffusionModule(
c_atom=128,
c_atompair=16,
c_s=c_s,
c_z=c_z,
**conf.model.diffusion_module)
verification = BatchGradientVerification(model)
D = 3
I = 80
L = 160
C_s_inputs = conf.model.diffusion_module.diffusion_conditioning.c_s_inputs
C_s_trunk = conf.model.c_s
inputs = dict(
X_noisy_L = torch.rand((D, L, 3)),
t = torch.rand((D,)),
f = {
'asym_id': torch.zeros(I),
'residue_index': torch.arange(I),
'entity_id': torch.zeros(I),
'token_index': torch.arange(I),
'sym_id': torch.zeros(I).long(),
'tok_idx': torch.arange(L) // 2,
'ref_pos': torch.rand((L, 3)),
'ref_charge': torch.rand((L,)),
'ref_mask': torch.arange(L) // 2,
'ref_element': one_hot(torch.randint(127, (L,)), 128),
'ref_atom_name_chars': torch.zeros((L, 4, 64)),
'ref_space_uid': torch.zeros((L,))
},
S_inputs_I = torch.rand((I, C_s_inputs)),
S_trunk_I = torch.rand((I, C_s_trunk)),
Z_trunk_II = torch.rand((I, I, c_z))
)
batched_inputs = default_input_mapping(inputs)
assert len(batched_inputs) == 2 and batched_inputs[0].shape[0] == D, f'default input mapping (should contain X_noisy_L and t:\n' + pretty_describe_dict(batched_inputs)
print(f'{verification.NORM_LAYER_CLASSES=}')
# Assert that there are no cross-batch gradients.
# See: https://lightning-bolts.readthedocs.io/en/latest/callbacks/monitor.html
# for details.
valid = verification.check(
input_array=inputs, sample_idx=0)
# Assert that the model produces the same output when run batched/unbatched.
out_batched = model(**inputs)
inputs_unbatched = inputs
inputs_unbatched['X_noisy_L'] = inputs['X_noisy_L'][:1]
inputs_unbatched['t'] = inputs['t'][:1]
out_single = model(**inputs_unbatched)
assert_cmp(out_single, out_batched[0:1])
ic(out_single.shape, out_batched.shape)
# Assert that the batch outputs are different.
assert torch.norm(out_batched[0] - out_batched[1]) > 1
assert valid
def plot_attention_map(attn, diag=True):
colors = ['indigo', 'yellow']
if diag:
attn[np.diag_indices_from(attn)] = 2
colors = ['indigo', 'yellow', 'green']
cmap = mcolors.ListedColormap(colors)
plt.matshow(attn, cmap=cmap)
plt.axis('off') # Turn off axis
plt.show()
def test_sequence_local_atom_attention():
conf_overrides = []
with initialize(version_base=None, config_path="../config/train"):
conf = compose(config_name='af3_repro', overrides=conf_overrides)
conf = conf.model.diffusion_module.atom_attention_encoder.atom_transformer
# Show the model's attenion map.
show_full = False
if show_full:
# Get full size attention map.
conf.l_max = 200
atom_transformer = AtomTransformer(
c_atom=10,
c_atompair=11,
**conf
)
Beta_lm = atom_transformer.Beta_lm
attn = (Beta_lm == 0).long()
plot_attention_map(attn)
# Show af3 supplement-style attention map.
show_supp = False
if show_supp:
atom_transformer = AtomTransformer(
c_atom=10,
c_atompair=11,
l_max=200,
n_queries=32,
n_keys=64,
diffusion_transformer=conf.diffusion_transformer,
)
Beta_lm = atom_transformer.Beta_lm
attn = (Beta_lm == 0).long()
plot_attention_map(attn)
atom_transformer = AtomTransformer(
c_atom=10,
c_atompair=11,
l_max=10,
n_queries=2,
n_keys=4,
diffusion_transformer=conf.diffusion_transformer,
)
L = 6
Beta_lm = atom_transformer.Beta_lm
Beta_lm = Beta_lm[:L, :L]
# Show small test-case attention map.
show_test_case = False
if show_test_case:
plot_attention_map((Beta_lm==0).long(), diag=False)
o = 0
x = -1e10
want_Beta_lm = torch.tensor([
[o,o,o,x,x,x],
[o,o,o,x,x,x],
[x,o,o,o,o,x],
[x,o,o,o,o,x],
[x,x,x,o,o,o],
[x,x,x,o,o,o],
])
assert_cmp(Beta_lm, want_Beta_lm)

View File

@@ -0,0 +1,29 @@
import os
import torch
import pytest
from rf2aa.util_module import Dropout
def test_dropout():
torch.manual_seed(0)
drop_row = Dropout(broadcast_dim=0, p_drop=0.5)
d = 8
x = torch.rand((d, d))
x = drop_row(x)
print('x:')
print(x)
assert not torch.all(x==0)
for i in range(d):
row = x[i]
if torch.any(row==0):
assert torch.all(row==0)
has_all_zero_row = False
for i in range(d):
row = x[i]
if torch.all(row==0):
has_all_zero_row = True
assert has_all_zero_row

222
rf2aa/trainer_lightning.py Normal file
View File

@@ -0,0 +1,222 @@
import re
import random
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from icecream import ic
import numpy as np
from functools import partial
import hydra
import os
import time
import omegaconf
from contextlib import nullcontext
import datetime
from datetime import timedelta
import certifi
import warnings
import wandb
import logging
import tree
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
from rf2aa.data.dataloader_adaptor_af3 import prepare_input_af3
from rf2aa.flow_matching.interpolant import Interpolant
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads, pretty_describe_dict
from rf2aa.training.EMA import EMA, count_parameters
from rf2aa.loss.loss import translation_vector_field
from rf2aa.loss.loss_factory import get_loss_and_misc
from rf2aa.training.optimizer import add_weight_decay
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward, recycle_step_generic
from rf2aa.model.network import RosettaFold
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
import rf2aa.util as util
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.chemical import initialize_chemdata
from rf2aa.set_seed import seed_all
from rf2aa.model import AF3_structure
from rf2aa.callbacks import LogMetrics, FindUnusedParameters, NetworkOutputGradSanityCheck, MonitorActivations
from rf2aa.loggers import LitLogger
ic.configureOutput(includeContext=True)
logger = logging.getLogger(__name__)
#TODO: control environment variables from config
# limit thread counts
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
# Update environment variable with correct path (needed for W&B upload)
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
## To reproduce errors
torch.set_num_threads(4)
def get_n_params(model):
pp=0
for p in list(model.parameters()):
nn=1
for s in list(p.size()):
nn = nn*s
pp += nn
return pp
def get_param_sizes(model):
o = {}
for k, p in model.named_parameters():
o[k] = (np.array(p.size()).prod(), p.size())
return o
# define the LightningModule
class LitAF3Repro(L.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
# self.model = torch.nn.Linear(2, 3).to(device)
self.model = AF3_structure.Model(**self.config.model)
print_n_params = False
if print_n_params:
logger.info(f'{get_n_params(self.model)=}')
for k, v in sorted(get_param_sizes(self.model).items(), key=lambda item: item[1]):
n_param, size = v
# n_param = np.array(p.size()).prod()
logger.info(f'{n_param=} {k=} {size=}')
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
def should_ignore(param_name):
ignore_regexes = [
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_s_trunk\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_z\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_r\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias.ln_1\..*'),
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.linear_output_project\..*'),
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.ada_ln_1\..*'),
re.compile(r'model\.diffusion_module\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
re.compile(r'model\.diffusion_module\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
re.compile(r'model\.diffusion_module\.atom_attention_decoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
]
return any(regex.match(param_name) for regex in ignore_regexes)
params_to_ignore = []
for param_name, param in self.model.named_parameters():
if should_ignore(param_name):
params_to_ignore.append(param_name)
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
self.model,
params_to_ignore
)
assert len(params_to_ignore)
self.loss = AF3_structure.Loss(**self.config.loss)
def training_step(self, batch, batch_idx):
logger.debug('batch:\n' + pretty_describe_dict(batch))
# TODO: move data processing to dataset
batch = tree.map_structure(lambda x: x.detach().cpu() if hasattr(x, 'cpu') else x, batch)
network_input, loss_input = prepare_input_af3(
batch,
**self.config.af3_data_prep,
)
# TODO: move data processing to dataset
network_input = tree.map_structure(lambda x: x.to(self.device), network_input)
loss_input = tree.map_structure(lambda x: x.to(self.device), loss_input)
logger.debug('network_input:\n' + pretty_describe_dict(network_input))
logger.debug('loss_input:\n' + pretty_describe_dict(loss_input))
n_cycle = random.randint(1, self.config.recycling.max_cycle)
X_L = self.model(
network_input,
n_cycle,
no_sync=self.model.no_sync,
)
loss, loss_dict, loss_dict_batched = self.loss(
f=network_input['f'],
t=network_input['t'],
X_L=X_L,
X_gt_L=loss_input['X_gt_L'],
)
self.log('loss', loss, prog_bar=True)
return dict(
loss=loss,
X_L=X_L,
) | loss_dict_batched | network_input | loss_input
def configure_optimizers(self):
optimizer = getattr(torch.optim, self.config.optimizer.type)(
self.model.parameters(),
**self.config.optimizer.params,
)
scheduler = get_stepwise_decay_schedule_with_warmup(
optimizer,
num_warmup_steps=1000,
num_steps_decay=5e4,
decay_rate=0.95,
)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def configure_callbacks(self):
return [
LogMetrics(self.config, **self.config.callbacks.log_metrics),
NetworkOutputGradSanityCheck(),
MonitorActivations(),
LearningRateMonitor(logging_interval='step'),
]
class LitDataModule(L.LightningDataModule):
def __init__(self, config):
super().__init__()
self.config = config
self.init = partial(initialize_chemdata, config.chem_params)
self.init()
def train_dataloader(self, rank=None, num_replicas=None):
train_loader, train_sampler, valid_loaders, valid_samplers = compose_dataset(
self.init, self.config.dataset_params, self.config.loader_params,
rank or 0,
num_replicas or 1,
)
return train_loader
@hydra.main(version_base=None, config_path='config/train')
def main(config):
if config.autograd_detect_anomaly:
torch.autograd.set_detect_anomaly(True)
model = LitAF3Repro(config)
datamodule = LitDataModule(config)
trainer_logger = LitLogger(**config.logger)
model_checkpoint = ModelCheckpoint(
every_n_train_steps=1000,
dirpath='checkpoints',
)
trainer = L.Trainer(
logger=trainer_logger,
log_every_n_steps=1,
gradient_clip_val=10,
callbacks=[model_checkpoint],
**config.lightning.trainer
)
trainer.fit(model=model, datamodule=datamodule)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,9 @@
import re
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from icecream import ic
import numpy as np
from functools import partial
import hydra
@@ -13,17 +15,21 @@ import datetime
from datetime import timedelta
import certifi
import warnings
import wandb
import logging
import tree
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items, prepare_input_fm_allatom
from rf2aa.data.dataloader_adaptor_af3 import prepare_input_af3
from rf2aa.flow_matching.interpolant import Interpolant
from rf2aa.flow_matching.sampler import Sampler, AllAtomSampler
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads
from rf2aa.debug import debug_unused_params, debug_used_params, debug_grads, pretty_describe_dict
from rf2aa.training.EMA import EMA, count_parameters
from rf2aa.loss.loss import translation_vector_field
from rf2aa.loss.loss_factory import get_loss_and_misc
from rf2aa.training.optimizer import add_weight_decay
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward
from rf2aa.training.recycling import recycle_step_legacy, recycle_step_packed, recycle_step_gen, recycle_sampling, run_model_forward, recycle_step_generic
from rf2aa.model.network import RosettaFold
from rf2aa.model.RoseTTAFoldModel import LegacyRoseTTAFoldModule
from rf2aa.training.scheduler import get_stepwise_decay_schedule_with_warmup
@@ -32,6 +38,9 @@ from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.chemical import initialize_chemdata
from rf2aa.set_seed import seed_all
from rf2aa.model import AF3_structure
logger = logging.getLogger(__name__)
#TODO: control environment variables from config
# limit thread counts
@@ -75,6 +84,8 @@ class Trainer:
self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.training_params.use_amp)
def load_checkpoint(self, rank):
if self.config.training_params.from_scratch:
return False
checkpoint_path = f"{self.output_dir}/{self.config.experiment.name}_last.pt"
# 'checkpoint_path' takes priority ...
if self.config.eval_params.checkpoint_path:
@@ -178,13 +189,18 @@ class Trainer:
if ("SLURM_NTASKS" in os.environ and "SLURM_PROCID" in os.environ):
world_size = int(os.environ["SLURM_NTASKS"])
self.world_size = world_size
rank = int (os.environ["SLURM_PROCID"])
print ("Launched from slurm", rank, world_size)
self.train_model(rank, world_size)
else:
print ("Launched from interactive")
world_size = torch.cuda.device_count()
world_size = max(torch.cuda.device_count(), 1)
if self.config.cpu_training:
world_size = 1
self.world_size = world_size
if world_size == 0:
print ("Error! No GPUs found!")
@@ -195,10 +211,13 @@ class Trainer:
mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
def init_process_group(self, rank, world_size):
gpu = rank % torch.cuda.device_count()
gpu = rank % self.world_size
dist.init_process_group(backend=self.config.training_params.ddp_backend, timeout=timedelta(seconds=1800), world_size=world_size, rank=rank)
torch.cuda.set_device("cuda:%d"%gpu)
return gpu
device = 'cpu'
if torch.cuda.device_count():
device = "cuda:%d"%gpu
torch.cuda.set_device(device)
return device
def cleanup(self):
if dist.is_initialized():
@@ -240,6 +259,7 @@ class Trainer:
self.valid_loaders = valid_loaders
# move global information to device
# if torch.cuda.device_count():
self.move_constants_to_device(gpu)
self.construct_model(device=gpu)
@@ -251,6 +271,7 @@ class Trainer:
self.construct_scaler()
start_epoch = 0
loaded_checkpoint = self.load_checkpoint(gpu)
logger.info(f'Loaded checkpoint: {loaded_checkpoint}')
if loaded_checkpoint:
start_epoch = self.checkpoint["epoch"]
self.load_model()
@@ -279,6 +300,7 @@ class Trainer:
if (
self.config.dataset_params.validate_every_n_epochs > 0
and epoch % self.config.dataset_params.validate_every_n_epochs==0
and (epoch!=start_epoch or self.config.dataset_params.validate_after_first_epoch)
):
self.valid_epoch(epoch, rank, world_size)
@@ -318,6 +340,20 @@ class Trainer:
if (p.grad is not None):
print (n, torch.max( torch.abs(p.flatten()) ), torch.max( torch.abs(p.grad.flatten()) ))
exit(1)
find_no_grad_parameters = False
if find_no_grad_parameters:
no_grad_parameters = []
for n,p in self.model.module.model.named_parameters():
if p.grad is None:
no_grad_parameters.append(n)
if no_grad_parameters:
print('Parameters with grad == None:')
for n in no_grad_parameters:
print(n)
print(f'Fraction with grad == None: {len(no_grad_parameters)}/{len(list(self.model.module.model.named_parameters()))}')
train_time = time.time() - start_time
@@ -400,11 +436,12 @@ class Trainer:
def log_intermediate_losses(self, inputs, loss_dict, n_cycle, Nex, Nepoch, runtime):
item = inputs[-1]
max_mem = torch.cuda.max_memory_allocated()/1e9
print(f"Models: {Nex} of: {Nepoch} Max_Memory: {max_mem:.4f} Runtime: {runtime:.4f}")
print(f"Models: {Nex} of: {Nepoch} Max_Memory: {max_mem:.4f}Gb Runtime: {runtime:.4f}")
print(f"Example: {item} Recycle:{n_cycle}\n"+
"\t".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
#print(f"Models: {Nex} Example: {item['CHAINID']} "+" ".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
torch.cuda.reset_peak_memory_stats()
#print(f"Models: {Nex} Example: {item['CHAINID']}"+" ".join([f"{k}: {v:.4f}" for k,v in loss_dict.items()]))
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def log_validation_losses(self, dataset_name, loss_dict):
print(f"Dataset: {dataset_name} "+
@@ -429,9 +466,12 @@ class LegacyTrainer(Trainer):
cb_tor = ChemData().cb_torsion_t.to(device),
).to(device)
device_ids = [device]
if device == 'cpu':
device_ids = None
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
self.model = DDP(self.model, device_ids=[device], find_unused_parameters=False, broadcast_buffers=False)
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=False, broadcast_buffers=False)
def train_step(self, inputs, n_cycle, nograds=False, return_outputs=False):
@@ -627,11 +667,115 @@ class FlowMatchingTrainer(Trainer):
# If using W&B, log the validation losses (note: this is only done for rank = 0)
if self.config.log_params.use_wandb:
wandb.log(valid_loss_dict)
def get_n_params(model):
pp=0
for p in list(model.parameters()):
nn=1
for s in list(p.size()):
nn = nn*s
pp += nn
return pp
def get_param_sizes(model):
o = {}
for k, p in model.named_parameters():
o[k] = (np.array(p.size()).prod(), p.size())
return o
class AF3Trainer(FlowMatchingTrainer):
def construct_model(self, device="cpu"):
# self.model = torch.nn.Linear(2, 3).to(device)
self.model = AF3_structure.Model(**self.config.model).to(device)
print_n_params = False
if print_n_params:
logger.info(f'{get_n_params(self.model)=}')
for k, v in sorted(get_param_sizes(self.model).items(), key=lambda item: item[1]):
n_param, size = v
# n_param = np.array(p.size()).prod()
logger.info(f'{n_param=} {k=} {size=}')
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
def should_ignore(param_name):
ignore_regexes = [
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_s_trunk\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_z\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.process_r\..*'),
re.compile(r'model\.feature_initializer\.input_feature_embedder\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias.ln_1\..*'),
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.linear_output_project\..*'),
re.compile(r'model\.recycler\.pairformer_stack\.\d+\.attention_pair_bias\.ada_ln_1\..*'),
re.compile(r'model\.diffusion_module\.atom_attention_encoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
re.compile(r'model\.diffusion_module\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
re.compile(r'model\.diffusion_module\.atom_attention_decoder\.atom_transformer\.diffusion_transformer\.blocks\.\d+\.attention_pair_bias\.ln_1\..*'),
]
return any(regex.match(param_name) for regex in ignore_regexes)
params_to_ignore = []
for param_name, param in self.model.named_parameters():
if should_ignore(param_name):
params_to_ignore.append(param_name)
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
self.model,
params_to_ignore
)
assert len(params_to_ignore)
device_ids = [device]
if device == 'cpu':
device_ids = None
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=False, broadcast_buffers=False)
self.sampler = AllAtomSampler(self.model,
self.config.interpolant.sampling.num_timesteps,
self.config.interpolant.min_t,
self.interpolant,
self.xyz_converter,
is_training=True)
self.loss = AF3_structure.Loss(**self.config.loss)
def move_constants_to_device(self, gpu):
self.interpolant = Interpolant(self.config.interpolant)
self.interpolant.set_device(gpu)
super().move_constants_to_device(gpu)
def train_step(self, inputs, n_cycle, no_grads=False, return_outputs=False):
gpu = self.model.device
D = 12
network_input, loss_input = prepare_input_af3(
inputs,
self.config.interpolant,
D,
)
network_input=tree.map_structure(lambda x: x.to(gpu) if hasattr(x, 'cpu') else x, network_input)
logger.debug('network_input:\n' + pretty_describe_dict(network_input))
logger.debug('loss_input:\n' + pretty_describe_dict(loss_input))
output_i = self.model(
network_input,
n_cycle,
no_sync=self.model.no_sync,
)
loss, loss_dict, loss_dict_batched = self.loss(
f=network_input['f'],
t=network_input['t'],
X_L=output_i,
X_gt_L=loss_input['X_gt_L'].tile((D,1,1)).to(gpu)
)
return loss, loss_dict
def construct_optimizer(self):
self.optimizer = getattr(torch.optim, self.config.optimizer.type)(
self.model.parameters(),
**self.config.optimizer.params,
)
@hydra.main(version_base=None, config_path='config/train')
def main(config):
if config.autograd_detect_anomaly:
torch.autograd.set_detect_anomaly(True)
seed_all()
trainer = trainer_factory[config.experiment.trainer](config=config)
@@ -650,7 +794,8 @@ def main(config):
trainer_factory = {
"legacy": LegacyTrainer,
"composed": ComposedTrainer,
"flow_matching": FlowMatchingTrainer
"flow_matching": FlowMatchingTrainer,
"af3_repro": AF3Trainer
}
if __name__ == "__main__":

View File

@@ -75,6 +75,26 @@ def recycle_step_gen(ddp_model, input, n_cycle, use_amp, nograds=False, force_de
output_i = unpack_outputs(rf_outputs, rf_latents, return_raw)
return output_i
def recycle_step_generic(model, input, n_cycle, use_amp, nograds=False, force_device=None):
'''
Runs recycling with gradients only on final recycle.
Assums model has methods:
pre_recycle: input --> recycling_input
recycle: recycling_input --> recycling_input
post_recycle: recycling_input --> output
'''
assert not nograds, 'not implemented'
assert not use_amp, 'not implemented'
recycling_input = model.pre_recycle(**input)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
if i_cycle < n_cycle -1:
stack.enter_context(torch.no_grad())
stack.enter_context(model.no_sync())
recycling_input = model.recycle(**recycling_input)
return model.post_recycle(**recycling_input)
def run_model_forward(model, network_input, use_checkpoint=False, device="cpu"):
""" run model forward pass, no recycling, no ddp (for tests) """