foundry/rf2aa/tests/test_align.py
Woody Ahern 11101963df Use loggers
Add utilities for training on a single-entry dataset.

Allow validation skipping.

WIP AF3 Non-equivariant structure encoder/decoder

Add flag to force training from scratch

Force training from scratch in debug config

All modules in diffusion module implemented

Document behavior of dropout with test

Finish majority of model trunk

Convert some ModuleLists to nn.Sequential

Add RelativePositionEncoding and WIP af3_repro config

Fix ref_space_uid embedding in AtomEncoder

Put Model together with fake MSAModule and TemplateEmbedder

AF3 repro loads model.

WIP af3 data-adaptor, AF3_structure fixes

Feature initializer working

Standardize S_inputs_I

Fix pairformer stack

Forward pass working, WIP: backward pass stale reference fixing

Add dataloader_adaptor_af3.py

Backward pass working, WIP: still some unused params

Backprop working

Training runs

Add pytorch lightning training and some wandb logging

Training converging for single example.

Run:
/home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small logger.use_wandb=True af3_data_prep.D=6

Log loss

Training working for single example.

Run (on an A4000):
/home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small_working_4 logger.use_wandb=True

Add test_diffusion_module.py
2024-06-20 17:25:32 -07:00


import torch
from icecream import ic

from rf2aa.alignment import weighted_rigid_align, get_rmsd
from rf2aa.util import kabsch


def pseudobatched_kabsch(xyz1, xyz2):
    # Run the unbatched kabsch on each example in turn and keep the first
    # element of its return value, which this test uses as the RMSD.
    B = xyz1.shape[0]
    out = []
    for i in range(B):
        out.append(kabsch(xyz1[i], xyz2[i])[0])
    return torch.stack(out)


def test_align():
    torch.manual_seed(0)
    B = 9  # batch size
    L = 5  # points per example
    x_from = torch.rand((B, L, 3))
    x_to = torch.rand((B, L, 3))
    # Uniform weights: the weighted rigid alignment should reduce to plain Kabsch.
    w = torch.ones((B, L))
    rmsd_kabsch = pseudobatched_kabsch(x_from, x_to)
    x_from_align = weighted_rigid_align(x_from, x_to, w)
    rmsd_weighted_rigid = get_rmsd(x_to, x_from_align)
    ic(rmsd_weighted_rigid, rmsd_kabsch)
    assert (torch.abs(rmsd_weighted_rigid - rmsd_kabsch) < 1e-5).all(), f'{rmsd_weighted_rigid} != {rmsd_kabsch}'


if __name__ == '__main__':
    test_align()
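The test treats `weighted_rigid_align` as a batched, weighted generalization of `kabsch`: with uniform weights `w`, both should yield the same RMSD. Below is a minimal, self-contained sketch of what such a function might compute. It assumes the shapes used in the test (`(B, L, 3)` coordinates, `(B, L)` weights) and follows the standard weighted Kabsch procedure (weighted centroids, SVD of the weighted cross-covariance, reflection correction). `weighted_rigid_align_sketch`, `weighted_rmsd` and `random_rotations` are hypothetical names, not the rf2aa implementation.

```python
import torch


def weighted_rigid_align_sketch(x_from, x_to, w):
    """Hypothetical stand-in for rf2aa.alignment.weighted_rigid_align.

    Assumes x_from, x_to: (B, L, 3) and w: (B, L). Returns x_from moved by the
    rotation + translation that minimizes sum_i w_i * |R p_i + t - q_i|^2.
    """
    w = w.unsqueeze(-1)                                    # (B, L, 1)
    w_sum = w.sum(dim=1, keepdim=True)                     # (B, 1, 1)
    mu_from = (w * x_from).sum(dim=1, keepdim=True) / w_sum
    mu_to = (w * x_to).sum(dim=1, keepdim=True) / w_sum
    p, q = x_from - mu_from, x_to - mu_to                  # weighted-centered
    H = torch.einsum("bli,blj->bij", w * p, q)             # sum_i w_i p_i q_i^T
    U, _, Vh = torch.linalg.svd(H)
    V, Ut = Vh.transpose(-1, -2), U.transpose(-1, -2)
    d = torch.sign(torch.det(V @ Ut))                      # -1 flags a reflection
    ones = torch.ones_like(d)
    D = torch.diag_embed(torch.stack([ones, ones, d], dim=-1))
    R = V @ D @ Ut                                         # (B, 3, 3), R p_i ~ q_i
    return p @ R.transpose(-1, -2) + mu_to


def weighted_rmsd(a, b, w):
    """Per-example weighted RMSD, shape (B,)."""
    sq = ((a - b) ** 2).sum(dim=-1)                        # (B, L)
    return torch.sqrt((w * sq).sum(dim=-1) / w.sum(dim=-1))


def random_rotations(B, dtype=torch.float64):
    """Random proper rotations via QR, with det forced to +1."""
    Q, _ = torch.linalg.qr(torch.randn(B, 3, 3, dtype=dtype))
    s = torch.sign(torch.det(Q))
    ones = torch.ones_like(s)
    # Scale the first column by s so that det(Q) becomes s * s = +1.
    return Q * torch.stack([s, ones, ones], dim=-1).unsqueeze(-2)


torch.manual_seed(0)
B, L = 4, 10
x_to = torch.randn(B, L, 3, dtype=torch.float64)
R_true = random_rotations(B)
t_true = torch.randn(B, 1, 3, dtype=torch.float64)
x_from = x_to @ R_true.transpose(-1, -2) + t_true          # exact rigid copy

# Zero-weight points are corrupted; they should not affect the fit.
w = torch.ones(B, L, dtype=torch.float64)
w[:, :3] = 0.0
x_from[:, :3] += 5.0 * torch.randn(B, 3, 3, dtype=torch.float64)

aligned = weighted_rigid_align_sketch(x_from, x_to, w)
print(weighted_rmsd(aligned, x_to, w))                     # near zero for every example
```

Because the fit is batched with `einsum` and `torch.linalg.svd`, it avoids the per-example Python loop in `pseudobatched_kabsch`. Setting a weight to zero removes that point from both the centroids and the covariance, which is why the corrupted points above do not disturb the alignment.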