mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
switch to templateV2 module instead of template module. During refactoring i previously accidentally took template module instead of templatemodulev2. the difference is that teamplatemodule masks out inter-chain templates (this only affects the refolding step)
This commit is contained in:
@@ -286,6 +286,7 @@ class TemplateModule(nn.Module):
|
||||
)
|
||||
self.u_proj = nn.Linear(template_dim, token_z, bias=False)
|
||||
|
||||
|
||||
if miniformer_blocks:
|
||||
self.pairformer = MiniformerNoSeqModule(
|
||||
template_dim,
|
||||
@@ -330,7 +331,6 @@ class TemplateModule(nn.Module):
|
||||
|
||||
"""
|
||||
# Load relevant features
|
||||
asym_id = feats["asym_id"]
|
||||
res_type = feats["template_restype"]
|
||||
frame_rot = feats["template_frame_rot"]
|
||||
frame_t = feats["template_frame_t"]
|
||||
@@ -338,6 +338,7 @@ class TemplateModule(nn.Module):
|
||||
cb_coords = feats["template_cb"]
|
||||
ca_coords = feats["template_ca"]
|
||||
cb_mask = feats["template_mask_cb"]
|
||||
visibility_ids = feats["visibility_ids"]
|
||||
template_mask = feats["template_mask"].any(dim=2).float()
|
||||
num_templates = template_mask.sum(dim=1)
|
||||
num_templates = num_templates.clamp(min=1)
|
||||
@@ -351,8 +352,9 @@ class TemplateModule(nn.Module):
|
||||
|
||||
# Compute asym mask, template features only attend within the same chain
|
||||
B, T = res_type.shape[:2] # noqa: N806
|
||||
asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
|
||||
asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)
|
||||
tmlp_pair_mask = (
|
||||
visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
|
||||
).float()
|
||||
|
||||
# Compute template features
|
||||
with torch.autocast(device_type="cuda", enabled=False):
|
||||
@@ -375,7 +377,8 @@ class TemplateModule(nn.Module):
|
||||
# Concatenate input features
|
||||
a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
|
||||
a_tij = torch.cat(a_tij, dim=-1)
|
||||
a_tij = a_tij * asym_mask.unsqueeze(-1)
|
||||
a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)
|
||||
|
||||
res_type_i = res_type[:, :, :, None]
|
||||
res_type_j = res_type[:, :, None, :]
|
||||
res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
|
||||
|
||||
Reference in New Issue
Block a user