make 2oken_2d_features/embedder optional

This commit is contained in:
Raktim Mitra
2026-02-16 12:13:11 -08:00
parent d7fe89b8f5
commit 22d24edc5b
3 changed files with 13 additions and 7 deletions

View File

@@ -50,9 +50,9 @@ class TokenInitializer(nn.Module):
pairformer_block,
downcast,
token_1d_features,
token_2d_features,
atom_1d_features,
atom_transformer,
token_2d_features=None,
use_chunked_pll=False, # New parameter for memory optimization
):
super().__init__()
@@ -64,7 +64,10 @@ class TokenInitializer(nn.Module):
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
if token_2d_features != None:
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
else:
self.token_2d_embedder = None
self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
self.transition_post_token = Transition(c=c_s, n=2)
@@ -206,7 +209,8 @@ class TokenInitializer(nn.Module):
f["ref_pos"][f["is_ca"]], valid_mask
)
# Add extra token pair features
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
if self.token_2d_embedder != None:
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
# Run a small transformer to provide position encodings to single.
for block in self.transformer_stack:

View File

@@ -803,7 +803,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
token_2d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme="atom14",
):
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_2d_features = token_2d_features
self.association_scheme = association_scheme
@@ -866,6 +866,8 @@ class AddAdditional2dFeaturesToFeats(Transform):
if "feats" not in data.keys():
data["feats"] = {}
# Only apply for features that the model is expecting:
if self.token_2d_features == None:
return data
for feature_name, n_dims in self.token_2d_features.items():
data = self.generate_token_feature(feature_name, n_dims, data)

View File

@@ -353,7 +353,7 @@ def build_atom14_base_pipeline_(
center_option: str,
atom_1d_features: dict | None,
token_1d_features: dict | None,
token_2d_features: dict | None,
token_2d_features: dict | None = None,
# PPI features
max_ppi_hotspots_frac_to_provide: float,
ppi_hotspot_max_distance: float,
@@ -373,7 +373,7 @@ def build_atom14_base_pipeline_(
"""
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Add any data necessary for downstream transforms
transforms = [
AddData(
@@ -653,7 +653,7 @@ def build_atom14_base_pipeline(
kwargs.setdefault("residue_cache_dir", None)
# TODO: Delete these once all checkpoints are updated with the latest defaults
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
kwargs.setdefault("generate_conformers_for_non_protein_only", False)
kwargs.setdefault("return_atom_array", True)
kwargs.setdefault("provide_elements_for_unindexed_components", False)
kwargs.setdefault("center_option", "all")