foundry/rf2aa/model/AF3_structure_wrapper.py
Woody Ahern 11101963df Use loggers
Add utilities for training on a single-entry dataset.

Allow validation skipping.

WIP AF3 Non-equivariant structure encoder/decoder

Add flag to force training from scratch

Force training from scratch in debug config

All modules in diffusion module implemented

Document behavior of dropout with test

Finish majority of model trunk

Convert some ModuleLists to nn.Sequential

Add RelativePositionEncoding and WIP af3_repro config

Fix ref_space_uid embedding in AtomEncoder

Put Model together with fake MSAModule and TemplateEmbedder

AF3 repro loads model.

WIP af3 data-adaptor, AF3_structure fixes

Feature initializer working

Standardize S_inputs_I

Fix pairformer stack

Forward pass working, WIP: backward pass stale reference fixing

Add dataloader_adaptor_af3.py

Backward pass working, WIP: still some unused params

Backprop working

Training runs

Add PyTorch Lightning training and some wandb logging

Training converging for single example.

Run:
/home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small logger.use_wandb=True af3_data_prep.D=6

Log loss

Training working for single example.

Run:
/home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small_working_4 logger.use_wandb=True

on an A4000

Add test_diffusion_module.py
2024-06-20 17:25:32 -07:00


import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from functools import partial
import numpy as np
from torch import relu
from rf2aa.debug import debug_nans
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
from rf2aa.model.layers.structure_bias import structure_bias_factory
from rf2aa.model.layers.Attention_module import (
    BiasedAxialAttention, FeedForwardLayer, MSAColAttention,
    MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention,
    OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention,
)
from rf2aa.model.layers.outer_product import OuterProductMean  # need to code this correctly
from rf2aa.training.checkpoint import create_custom_forward
from rf2aa.util_module import Dropout
from rf2aa.model.AF3_structure import AtomAttentionEncoder, AtomAttentionDecoder


class NonEquivariantAtomEncoder(nn.Module):
    """Wraps AtomAttentionEncoder; block_params is unpacked as keyword arguments."""

    def __init__(self, block_params):
        super().__init__()
        # c_atom, c_atompair, c_token = block_params.c_atom, block_params.c_atom_pair, block_params.c_token
        self.model = AtomAttentionEncoder(**block_params)

    def forward(self, *args, **kwargs):
        # Pass all inputs through unchanged to the wrapped AtomAttentionEncoder.
        return self.model(*args, **kwargs)


class NonEquivariantAtomDecoder(nn.Module):
    """Wraps AtomAttentionDecoder; block_params is unpacked as keyword arguments."""

    def __init__(self, block_params):
        super().__init__()
        # c_atom, c_atompair, c_token = block_params.c_atom, block_params.c_atom_pair, block_params.c_token
        self.model = AtomAttentionDecoder(**block_params)

    def forward(self, *args, **kwargs):
        # Pass all inputs through unchanged to the wrapped AtomAttentionDecoder.
        return self.model(*args, **kwargs)
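Both wrappers do little more than unpack `block_params` into the inner AF3 atom-attention module. The sketch below illustrates that pattern in isolation. Because the real `AtomAttentionEncoder` constructor signature is not shown in this file, it uses a hypothetical `StandInAtomEncoder` whose keyword arguments (`c_atom`, `c_atompair`, `c_token`, loosely borrowed from the commented-out line) are assumptions, and the channel sizes are illustrative only. It also delegates `forward()` to the inner module, which is an assumption about how such a wrapper would be called.

```python
import torch
import torch.nn as nn


class StandInAtomEncoder(nn.Module):
    """Hypothetical stand-in for AtomAttentionEncoder. Its constructor
    arguments (c_atom, c_atompair, c_token) are assumptions for illustration."""

    def __init__(self, c_atom, c_atompair, c_token):
        super().__init__()
        self.c_atompair = c_atompair  # stored only; unused by this toy forward
        self.proj = nn.Linear(c_atom, c_token)

    def forward(self, atom_feats):
        # Per-atom features [N_atom, c_atom] -> [N_atom, c_token]
        return self.proj(atom_feats)


class WrappedEncoder(nn.Module):
    """Same wrapping pattern as NonEquivariantAtomEncoder: unpack a config
    mapping into the inner module and delegate forward() to it."""

    def __init__(self, block_params):
        super().__init__()
        self.model = StandInAtomEncoder(**block_params)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


# Illustrative channel sizes; not taken from any rf2aa config.
block_params = {"c_atom": 128, "c_atompair": 16, "c_token": 384}
enc = WrappedEncoder(block_params)
out = enc(torch.randn(10, 128))
print(out.shape)  # torch.Size([10, 384])
print(sorted(enc.state_dict().keys()))  # ['model.proj.bias', 'model.proj.weight']
```

Because the inner module is stored as `self.model`, its parameters are registered under a `model.` prefix. A checkpoint saved from the wrapper will therefore only load into a bare inner module after that prefix is stripped.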