mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
make 2oken_2d_features/embedder optional
This commit is contained in:
@@ -50,9 +50,9 @@ class TokenInitializer(nn.Module):
|
||||
pairformer_block,
|
||||
downcast,
|
||||
token_1d_features,
|
||||
token_2d_features,
|
||||
atom_1d_features,
|
||||
atom_transformer,
|
||||
token_2d_features=None,
|
||||
use_chunked_pll=False, # New parameter for memory optimization
|
||||
):
|
||||
super().__init__()
|
||||
@@ -64,7 +64,10 @@ class TokenInitializer(nn.Module):
|
||||
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
|
||||
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
|
||||
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
|
||||
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
|
||||
if token_2d_features != None:
|
||||
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
|
||||
else:
|
||||
self.token_2d_embedder = None
|
||||
|
||||
self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
|
||||
self.transition_post_token = Transition(c=c_s, n=2)
|
||||
@@ -206,7 +209,8 @@ class TokenInitializer(nn.Module):
|
||||
f["ref_pos"][f["is_ca"]], valid_mask
|
||||
)
|
||||
# Add extra token pair features
|
||||
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
|
||||
if self.token_2d_embedder != None:
|
||||
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
|
||||
|
||||
# Run a small transformer to provide position encodings to single.
|
||||
for block in self.transformer_stack:
|
||||
|
||||
@@ -803,7 +803,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
token_2d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme="atom14",
|
||||
):
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_2d_features = token_2d_features
|
||||
self.association_scheme = association_scheme
|
||||
@@ -866,6 +866,8 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
# Only apply for features that the model is expecting:
|
||||
if self.token_2d_features == None:
|
||||
return data
|
||||
for feature_name, n_dims in self.token_2d_features.items():
|
||||
data = self.generate_token_feature(feature_name, n_dims, data)
|
||||
|
||||
|
||||
@@ -353,7 +353,7 @@ def build_atom14_base_pipeline_(
|
||||
center_option: str,
|
||||
atom_1d_features: dict | None,
|
||||
token_1d_features: dict | None,
|
||||
token_2d_features: dict | None,
|
||||
token_2d_features: dict | None = None,
|
||||
# PPI features
|
||||
max_ppi_hotspots_frac_to_provide: float,
|
||||
ppi_hotspot_max_distance: float,
|
||||
@@ -373,7 +373,7 @@ def build_atom14_base_pipeline_(
|
||||
"""
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
|
||||
|
||||
# Add any data necessary for downstream transforms
|
||||
transforms = [
|
||||
AddData(
|
||||
@@ -653,7 +653,7 @@ def build_atom14_base_pipeline(
|
||||
kwargs.setdefault("residue_cache_dir", None)
|
||||
|
||||
# TODO: Delete these once all checkpoints are updated with the latest defaults
|
||||
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
|
||||
kwargs.setdefault("generate_conformers_for_non_protein_only", False)
|
||||
kwargs.setdefault("return_atom_array", True)
|
||||
kwargs.setdefault("provide_elements_for_unindexed_components", False)
|
||||
kwargs.setdefault("center_option", "all")
|
||||
|
||||
Reference in New Issue
Block a user