Files
foundry/rf2aa/model/AF3_structure_wrapper.py
2025-02-04 21:44:04 -08:00

19 lines
659 B
Python

import torch.nn as nn
from rf2aa.model.AF3_structure import AtomAttentionDecoder, AtomAttentionEncoder
class NonEquivariantAtomEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around ``AtomAttentionEncoder``.

    The wrapped encoder is stored as ``self.model``, so its parameters keep the
    ``model.`` prefix in ``state_dict``. Calling this module forwards all
    arguments to the wrapped encoder unchanged.

    Args:
        block_params: Mapping of keyword arguments passed verbatim as
            ``AtomAttentionEncoder(**block_params)``. Which keys are required
            is decided by ``AtomAttentionEncoder`` (not visible here).
    """

    def __init__(self, block_params):
        super().__init__()
        self.model = AtomAttentionEncoder(**block_params)

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped encoder and return its output.

        Without this override, calling the wrapper raises ``NotImplementedError``
        because ``nn.Module`` provides no default ``forward``. The call goes
        through ``self.model(...)`` rather than ``self.model.forward(...)`` so
        the submodule's registered hooks still run.
        """
        return self.model(*args, **kwargs)
class NonEquivariantAtomDecoder(nn.Module):
    """Thin ``nn.Module`` wrapper around ``AtomAttentionDecoder``.

    The wrapped decoder is stored as ``self.model``, so its parameters keep the
    ``model.`` prefix in ``state_dict``. Calling this module forwards all
    arguments to the wrapped decoder unchanged.

    Args:
        block_params: Mapping of keyword arguments passed verbatim as
            ``AtomAttentionDecoder(**block_params)``. Which keys are required
            is decided by ``AtomAttentionDecoder`` (not visible here).
    """

    def __init__(self, block_params):
        super().__init__()
        self.model = AtomAttentionDecoder(**block_params)

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped decoder and return its output.

        Without this override, calling the wrapper raises ``NotImplementedError``
        because ``nn.Module`` provides no default ``forward``. The call goes
        through ``self.model(...)`` rather than ``self.model.forward(...)`` so
        the submodule's registered hooks still run.
        """
        return self.model(*args, **kwargs)