switch to templateV2 module instead of template module. During refactoring i previously accidentally took template module instead of templatemodulev2. the difference is that teamplatemodule masks out inter-chain templates (this only affects the refolding step)

This commit is contained in:
Hannes Stärk
2026-02-09 14:11:26 -05:00
committed by GitHub
parent 135a4a5a97
commit 50be67f942

View File

@@ -286,6 +286,7 @@ class TemplateModule(nn.Module):
)
self.u_proj = nn.Linear(template_dim, token_z, bias=False)
if miniformer_blocks:
self.pairformer = MiniformerNoSeqModule(
template_dim,
@@ -330,7 +331,6 @@ class TemplateModule(nn.Module):
"""
# Load relevant features
asym_id = feats["asym_id"]
res_type = feats["template_restype"]
frame_rot = feats["template_frame_rot"]
frame_t = feats["template_frame_t"]
@@ -338,6 +338,7 @@ class TemplateModule(nn.Module):
cb_coords = feats["template_cb"]
ca_coords = feats["template_ca"]
cb_mask = feats["template_mask_cb"]
visibility_ids = feats["visibility_ids"]
template_mask = feats["template_mask"].any(dim=2).float()
num_templates = template_mask.sum(dim=1)
num_templates = num_templates.clamp(min=1)
@@ -351,8 +352,9 @@ class TemplateModule(nn.Module):
# Compute asym mask, template features only attend within the same chain
B, T = res_type.shape[:2] # noqa: N806
asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)
tmlp_pair_mask = (
visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
).float()
# Compute template features
with torch.autocast(device_type="cuda", enabled=False):
@@ -375,7 +377,8 @@ class TemplateModule(nn.Module):
# Concatenate input features
a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
a_tij = torch.cat(a_tij, dim=-1)
a_tij = a_tij * asym_mask.unsqueeze(-1)
a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)
res_type_i = res_type[:, :, :, None]
res_type_j = res_type[:, :, None, :]
res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)