From 2893a73ebb7ddf01b6be5cf76b2ee9e5b3d226a4 Mon Sep 17 00:00:00 2001 From: Simon Mathis Date: Wed, 13 Aug 2025 00:25:29 +0100 Subject: [PATCH] clean: remove unused embedding blocks --- src/modelhub/model/embedding_blocks.py | 228 ------------------------- 1 file changed, 228 deletions(-) delete mode 100644 src/modelhub/model/embedding_blocks.py diff --git a/src/modelhub/model/embedding_blocks.py b/src/modelhub/model/embedding_blocks.py deleted file mode 100644 index 0866355..0000000 --- a/src/modelhub/model/embedding_blocks.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.nn as nn - -from modelhub.chemical import ChemicalData as ChemData -from modelhub.model.layers.Embeddings import ( - Bond_emb, - Extra_emb, - MSA_emb, - MSA_emb_nostate, - Templ_emb, - Templ_emb_NoPtwise, - recycling_factory, -) - - -class RF2_embedding(nn.Module): - def __init__(self, global_params, block_params): - super(RF2_embedding, self).__init__() - d_msa, d_msa_full, d_pair, d_state = ( - global_params["d_msa"], - global_params["d_msa_full"], - global_params["d_pair"], - global_params["d_state"], - ) - self.latent_emb = MSA_emb( - d_msa=d_msa, - d_pair=d_pair, - d_state=d_state, - p_drop=block_params.p_drop, - use_same_chain=block_params.use_same_chain, - ) - self.full_emb = Extra_emb( - d_msa=d_msa_full, - d_init=ChemData().NAATOKENS - - 1 - + 4, # HACK: should define this freom the config (4: ins/del,nterm/cterm feats) - p_drop=block_params.p_drop, - ) - self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES) - - self.templ_emb = Templ_emb( - d_pair=d_pair, - d_templ=block_params.d_templ, - d_state=d_state, - n_head=block_params.n_head_templ, - d_hidden=block_params.d_hidden_templ, - p_drop=block_params.templ_p_drop, - symmetrize_repeats=block_params.symmetrize_repeats, # repeat protein stuff - repeat_length=block_params.repeat_length, - symmsub_k=block_params.symmsub_k, - sym_method=block_params.sym_method, - main_block=block_params.main_block, - copy_main_block=block_params.copy_main_block_template, - additional_dt1d=block_params.additional_dt1d, - ) - - ## Update inputs with outputs from previous forward pass - self.recycle = recycling_factory[block_params.recycling_type]( - d_msa=d_msa, d_pair=d_pair, d_state=d_state - ) - self.recycling_type = block_params.recycling_type - assert ( - self.recycling_type != "all" - ), "no backward compatibility to recycling state" - - def _unpack_inputs(self, rf_inputs): - msa_latent, msa_full, seq, idx, bond_feats, dist_matrix = ( - rf_inputs["msa_latent"], - rf_inputs["msa_full"], - rf_inputs["seq"], - rf_inputs["idx"], - rf_inputs["bond_feats"], - rf_inputs["dist_matrix"], - ) - ## recycling inputs - msa_prev, pair_prev, state_prev, xyz, sctors, mask_recycle = ( - rf_inputs["msa_prev"], - rf_inputs["pair_prev"], - None, - rf_inputs["xyz"], - rf_inputs["sctors"], - rf_inputs["mask_recycle"], - ) - return ( - msa_latent, - msa_full, - seq, - idx, - bond_feats, - dist_matrix, - msa_prev, - pair_prev, - state_prev, - xyz, - sctors, - mask_recycle, - ) - - def _add_templ_features(self, rf_inputs, pair, state): - t1d, t2d, alpha_t, xyz_t, mask_t = ( - rf_inputs["t1d"], - rf_inputs["t2d"], - rf_inputs["alpha_t"], - rf_inputs["xyz_t"], - rf_inputs["mask_t"], - ) - pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state) - return pair, state - - def forward(self, rf_inputs): - ( - msa_latent, - msa_full, - seq, - idx, - bond_feats, - dist_matrix, - msa_prev, - pair_prev, - state_prev, - xyz, - sctors, - mask_recycle, - ) = self._unpack_inputs(rf_inputs) - B, N, L = msa_latent.shape[:3] - - dtype = msa_latent.dtype - - msa_latent, pair, state = self.latent_emb( - msa_latent, seq, idx, bond_feats, dist_matrix - ) - msa_full = self.full_emb(msa_full, seq, idx) - pair = pair + self.bond_emb(bond_feats) - - msa_latent, pair = msa_latent.to(dtype), pair.to(dtype) - msa_full = msa_full.to(dtype) - if state is not None: - state = state.to(dtype) - - if msa_prev is None: - msa_prev = torch.zeros_like(msa_latent[:, 0]) - if pair_prev is None: - pair_prev = torch.zeros_like(pair) - if ( - state_prev is None or self.recycling_type == "msa_pair" - ): # explicitly remove state features if only recycling msa and pair - state_prev = torch.zeros_like(msa_latent[:, 0]) - - msa_recycle, pair_recycle, state_recycle = self.recycle( - msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle - ) - - msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype) - - msa_latent[:, 0] = msa_latent[:, 0] + msa_recycle.reshape(B, L, -1) - pair = pair + pair_recycle - - pair, state = self._add_templ_features(rf_inputs, pair, state) - return {"msa": msa_latent, "msa_full": msa_full, "pair": pair, "state": state} - - -class RF2_embedding_no_ptwise(RF2_embedding): - def __init__(self, global_params, block_params): - super(RF2_embedding_no_ptwise, self).__init__(global_params, block_params) - _d_msa, _d_msa_full, d_pair, d_state = ( - global_params["d_msa"], - global_params["d_msa_full"], - global_params["d_pair"], - global_params["d_state"], - ) - self.templ_emb = Templ_emb_NoPtwise( - d_pair=d_pair, - d_templ=block_params.d_templ, - d_state=d_state, - n_head=block_params.n_head_templ, - d_hidden=block_params.d_hidden_templ, - p_drop=block_params.templ_p_drop, - symmetrize_repeats=block_params.symmetrize_repeats, # repeat protein stuff - repeat_length=block_params.repeat_length, - symmsub_k=block_params.symmsub_k, - sym_method=block_params.sym_method, - main_block=block_params.main_block, - copy_main_block=block_params.copy_main_block_template, - additional_dt1d=block_params.additional_dt1d, - ) - - -class RF2_embedding_nostate(RF2_embedding): - def __init__(self, global_params, block_params): - super(RF2_embedding_nostate, self).__init__(global_params, block_params) - d_msa, _d_msa_full, d_pair, d_state = ( - global_params["d_msa"], - global_params["d_msa_full"], - global_params["d_pair"], - global_params["d_state"], - ) - self.latent_emb = MSA_emb_nostate( - d_msa=d_msa, - d_pair=d_pair, - d_state=d_state, - p_drop=block_params.p_drop, - use_same_chain=block_params.use_same_chain, - ) - self.templ_emb = None - - def _add_templ_features(self, rf_inputs, pair, state): - # identity - return pair, state - - -# Null module for overloading existing modules with a no-op -class Noop(nn.Module): - def forward(*args, **kwargs): - return torch.tensor([0.0]) - - -class RF2_embedding_no_ptwise_no_full(RF2_embedding_no_ptwise): - def __init__(self, global_params, block_params): - super(RF2_embedding_no_ptwise, self).__init__(global_params, block_params) - self.full_emb = Noop() - - -embedding_factory = { - "rf2aa": RF2_embedding, - "rf2aa_noptwise": RF2_embedding_no_ptwise, - "rf2aa_noptwise_no_full": RF2_embedding_no_ptwise_no_full, - "rf2aa_nostate": RF2_embedding_nostate, -}