mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Add utilities for training on a single-entry dataset. Allow validation skipping. WIP AF3 Non-equivariant structure encoder/decoder Add flag to force training from scratch Force training from scratch in debug config All modules in diffusion module implemented Document behavior of dropout with test Finish majority of model trunk Convert some ModuleLists to nn.Sequential Add RelativePositionEncoding and WIP af3_repro config Fix ref_space_uid embedding in AtomEncoder Put Model together with fake MSAModule and TemplateEmbedder AF3 repro loads model. WIP af3 data-adaptor, AF3_structure fixes Feature initializer working Standardize S_inputs_I Fix pairformer stack Forward pass working, WIP: backward pass stale reference fixing Add dataloader_adaptor_af3.py Backward pass working, WIP: still some unused params Backprop working Training runs Add pytorch lightning training and some wandb logging Training converging for single example. Run: /home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small logger.use_wandb=True af3_data_prep.D=6 Log loss Training working for single example. Run: /home/ahern/reclone/rf_diffusion_staging/rf_diffusion/exec/rf_diffusion_aa_2.sif trainer_lightning.py --config-name af3_repro_single_example_small_working_4 logger.use_wandb=True on an a4000 Add test_diffusion_module.py
35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.checkpoint as checkpoint
|
|
from functools import partial
|
|
import numpy as np
|
|
from torch import relu
|
|
|
|
from rf2aa.debug import debug_nans
|
|
from rf2aa.model.layers.SE3_network import FullyConnectedSE3, FullyConnectedSE3_noR
|
|
from rf2aa.model.layers.structure_bias import structure_bias_factory
|
|
from rf2aa.model.layers.Attention_module import BiasedAxialAttention, FeedForwardLayer, MSAColAttention, \
|
|
MSARowAttentionWithBias, TriangleMultiplication, MSAColGlobalAttention, \
|
|
OldMSAColAttention, OldMSAColGlobalAttention, BiasedUntiedAxialAttention, TriangleAttention
|
|
from rf2aa.model.layers.outer_product import OuterProductMean # need to code this correctly
|
|
from rf2aa.training.checkpoint import create_custom_forward
|
|
from rf2aa.util_module import Dropout
|
|
from rf2aa.model.AF3_structure import AtomAttentionEncoder, AtomAttentionDecoder
|
|
|
|
|
|
class NonEquivariantAtomEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``AtomAttentionEncoder``.

    The wrapped encoder is registered as the ``model`` submodule. No
    ``forward`` method is visible in this class.
    """

    def __init__(self, block_params):
        """Construct the wrapped encoder.

        Args:
            block_params: Mapping of keyword arguments; it is unpacked with
                ``**`` directly into ``AtomAttentionEncoder``.
        """
        super().__init__()
        # NOTE(review): presumably block_params carries the c_atom /
        # c_atompair / c_token sizes -- confirm against the config.
        encoder = AtomAttentionEncoder(**block_params)
        self.model = encoder
|
|
|
|
|
|
class NonEquivariantAtomDecoder(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``AtomAttentionDecoder``.

    The wrapped decoder is registered as the ``model`` submodule. No
    ``forward`` method is visible in this class.
    """

    def __init__(self, block_params):
        """Construct the wrapped decoder.

        Args:
            block_params: Mapping of keyword arguments; it is unpacked with
                ``**`` directly into ``AtomAttentionDecoder``.
        """
        super().__init__()
        # NOTE(review): presumably block_params carries the c_atom /
        # c_atompair / c_token sizes -- confirm against the config.
        decoder = AtomAttentionDecoder(**block_params)
        self.model = decoder
|
|
|