mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: update plddt reshaping and permute
This commit is contained in:
committed by
Rohith Krishna
parent
022df83dad
commit
44324493fd
@@ -153,6 +153,7 @@ model:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
no_residual_connection_between_attention_and_transition: True
|
||||
relative_position_encoding:
|
||||
r_max: 32
|
||||
s_max: 2
|
||||
@@ -171,6 +172,7 @@ model:
|
||||
n_block: 2
|
||||
raw_template_dim: 108
|
||||
c: 64
|
||||
p_drop: 0.25
|
||||
msa_module:
|
||||
n_block: 4
|
||||
c_m: 64
|
||||
@@ -242,10 +244,13 @@ model:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
no_residual_connection_between_attention_and_transition: True
|
||||
broadcast_trunk_feats_on_1dim_old: True
|
||||
diffusion_transformer:
|
||||
n_block: 24
|
||||
diffusion_transformer_block:
|
||||
n_head: 16
|
||||
no_residual_connection_between_attention_and_transition: True
|
||||
atom_attention_decoder:
|
||||
atom_transformer:
|
||||
n_queries: 32
|
||||
@@ -255,6 +260,7 @@ model:
|
||||
n_block: 3
|
||||
diffusion_transformer_block:
|
||||
n_head: 4
|
||||
no_residual_connection_between_attention_and_transition: True
|
||||
distogram_head:
|
||||
bins: 65
|
||||
confidence_head:
|
||||
|
||||
@@ -121,11 +121,11 @@ class ConfidenceLoss(nn.Module):
|
||||
plddt_logit_stack = network_output["plddt"]
|
||||
plddt_per_structure = unbin_logits(
|
||||
plddt_logit_stack.reshape(
|
||||
plddt_logit_stack.shape[0],
|
||||
-1,
|
||||
plddt_logit_stack.shape[1],
|
||||
I,
|
||||
ChemData().NHEAVY,
|
||||
).float(),
|
||||
self.plddt.n_bins,
|
||||
).permute(0, 3, 1, 2).float(),
|
||||
self.plddt.max_value,
|
||||
self.plddt.n_bins,
|
||||
)
|
||||
|
||||
@@ -40,11 +40,11 @@ class WriteAF3Confidence(Metric):
|
||||
# reorder the input tensors to be in (B, n_bins, ...) format for unbinning
|
||||
plddt = unbin_logits(
|
||||
plddt_logit_stack.reshape(
|
||||
plddt_logit_stack.shape[0],
|
||||
-1,
|
||||
plddt_logit_stack.shape[1],
|
||||
ChemData().NHEAVY,
|
||||
).float(),
|
||||
self.plddt.n_bins,
|
||||
).permute(0, 3, 1, 2).float(),
|
||||
self.plddt.max_value,
|
||||
self.plddt.n_bins,
|
||||
)
|
||||
@@ -187,8 +187,8 @@ class GetConfidenceIndices(Metric):
|
||||
# Reshape logits to B, K, L, NHEAVY
|
||||
is_real_atom = network_output["confidence"]["is_real_atom"]
|
||||
plddt_logits = plddt_logits.reshape(
|
||||
plddt_logits.shape[0], -1, plddt_logits.shape[1], ChemData().NHEAVY
|
||||
).float()
|
||||
-1, plddt_logits.shape[1], ChemData().NHEAVY, confidence_loss.plddt.n_bins
|
||||
).permute(0, 3, 1, 2).float()
|
||||
# Reshape the pae and pde logits to B, K, L, L
|
||||
pae_logits = confidence["pae_logits"].permute(0, 3, 1, 2).float()
|
||||
pde_logits = confidence["pde_logits"].permute(0, 3, 1, 2).float()
|
||||
|
||||
Reference in New Issue
Block a user