mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
rename config flag and reset default
This commit is contained in:
@@ -200,7 +200,7 @@ model:
|
||||
c_weighted_average: 32
|
||||
c_msa_embed: ${model.recycler.msa_module.c_m}
|
||||
c_z: ${model.c_z}
|
||||
bugfix: True
|
||||
separate_gate_for_every_channel: False
|
||||
msa_transition:
|
||||
n: 4
|
||||
c: ${model.recycler.msa_module.c_m}
|
||||
|
||||
@@ -247,7 +247,7 @@ class MsaSubsampleEmbedder(nn.Module):
|
||||
class MsaPairWeightedAverage(nn.Module):
|
||||
"""implements Algorithm 10 from AF3 paper"""
|
||||
|
||||
def __init__(self, c_weighted_average, n_heads, c_msa_embed, c_z, bugfix):
|
||||
def __init__(self, c_weighted_average, n_heads, c_msa_embed, c_z, separate_gate_for_every_channel):
|
||||
super(MsaPairWeightedAverage, self).__init__()
|
||||
self.weighted_average_channels = c_weighted_average
|
||||
self.n_heads = n_heads
|
||||
@@ -258,8 +258,8 @@ class MsaPairWeightedAverage(nn.Module):
|
||||
self.norm_pair = nn.LayerNorm(self.pair_channels)
|
||||
self.to_bias = nn.Linear(self.pair_channels, self.n_heads, bias=False)
|
||||
|
||||
self.bugfix = bugfix
|
||||
if bugfix:
|
||||
self.separate_gate_for_every_channel = separate_gate_for_every_channel
|
||||
if self.separate_gate_for_every_channel:
|
||||
self.to_gate = nn.Linear(self.msa_channels, self.weighted_average_channels*self.n_heads, bias=False)
|
||||
else:
|
||||
self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)
|
||||
@@ -285,7 +285,7 @@ class MsaPairWeightedAverage(nn.Module):
|
||||
gate_SIH = torch.sigmoid(self.to_gate(msa_SI))
|
||||
|
||||
# compute weighted average & apply gate
|
||||
if self.bugfix:
|
||||
if self.separate_gate_for_every_channel:
|
||||
weights = torch.einsum( "ijh,sjhc->sihc", w_IIH, v_SIH).reshape(S, I, -1)
|
||||
o_SIH = gate_SIH * weights
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user