rename config flag and reset default

This commit is contained in:
Rohith Krishna
2025-02-10 18:20:58 -08:00
parent 7d33e24dd5
commit 78604610a2
2 changed files with 5 additions and 5 deletions

View File

@@ -200,7 +200,7 @@ model:
c_weighted_average: 32
c_msa_embed: ${model.recycler.msa_module.c_m}
c_z: ${model.c_z}
bugfix: True
separate_gate_for_every_channel: False
msa_transition:
n: 4
c: ${model.recycler.msa_module.c_m}

View File

@@ -247,7 +247,7 @@ class MsaSubsampleEmbedder(nn.Module):
class MsaPairWeightedAverage(nn.Module):
"""implements Algorithm 10 from AF3 paper"""
def __init__(self, c_weighted_average, n_heads, c_msa_embed, c_z, bugfix):
def __init__(self, c_weighted_average, n_heads, c_msa_embed, c_z, separate_gate_for_every_channel):
super(MsaPairWeightedAverage, self).__init__()
self.weighted_average_channels = c_weighted_average
self.n_heads = n_heads
@@ -258,8 +258,8 @@ class MsaPairWeightedAverage(nn.Module):
self.norm_pair = nn.LayerNorm(self.pair_channels)
self.to_bias = nn.Linear(self.pair_channels, self.n_heads, bias=False)
self.bugfix = bugfix
if bugfix:
self.separate_gate_for_every_channel = separate_gate_for_every_channel
if self.separate_gate_for_every_channel:
self.to_gate = nn.Linear(self.msa_channels, self.weighted_average_channels*self.n_heads, bias=False)
else:
self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)
@@ -285,7 +285,7 @@ class MsaPairWeightedAverage(nn.Module):
gate_SIH = torch.sigmoid(self.to_gate(msa_SI))
# compute weighted average & apply gate
if self.bugfix:
if self.separate_gate_for_every_channel:
weights = torch.einsum( "ijh,sjhc->sihc", w_IIH, v_SIH).reshape(S, I, -1)
o_SIH = gate_SIH * weights
else: