fix: update plddt reshaping and permute

This commit is contained in:
Tuscan Thompson
2025-02-18 13:05:54 -08:00
committed by Rohith Krishna
parent 022df83dad
commit 44324493fd
3 changed files with 13 additions and 7 deletions

View File

@@ -153,6 +153,7 @@ model:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: True
relative_position_encoding:
r_max: 32
s_max: 2
@@ -171,6 +172,7 @@ model:
n_block: 2
raw_template_dim: 108
c: 64
p_drop: 0.25
msa_module:
n_block: 4
c_m: 64
@@ -242,10 +244,13 @@ model:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: True
broadcast_trunk_feats_on_1dim_old: True
diffusion_transformer:
n_block: 24
diffusion_transformer_block:
n_head: 16
no_residual_connection_between_attention_and_transition: True
atom_attention_decoder:
atom_transformer:
n_queries: 32
@@ -255,6 +260,7 @@ model:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: True
distogram_head:
bins: 65
confidence_head:

View File

@@ -121,11 +121,11 @@ class ConfidenceLoss(nn.Module):
plddt_logit_stack = network_output["plddt"]
plddt_per_structure = unbin_logits(
plddt_logit_stack.reshape(
plddt_logit_stack.shape[0],
-1,
plddt_logit_stack.shape[1],
I,
ChemData().NHEAVY,
).float(),
self.plddt.n_bins,
).permute(0, 3, 1, 2).float(),
self.plddt.max_value,
self.plddt.n_bins,
)

View File

@@ -40,11 +40,11 @@ class WriteAF3Confidence(Metric):
# reorder the input tensors to be in (B, n_bins, ...) format for unbinning
plddt = unbin_logits(
plddt_logit_stack.reshape(
plddt_logit_stack.shape[0],
-1,
plddt_logit_stack.shape[1],
ChemData().NHEAVY,
).float(),
self.plddt.n_bins,
).permute(0, 3, 1, 2).float(),
self.plddt.max_value,
self.plddt.n_bins,
)
@@ -187,8 +187,8 @@ class GetConfidenceIndices(Metric):
# Reshape logits to B, K, L, NHEAVY
is_real_atom = network_output["confidence"]["is_real_atom"]
plddt_logits = plddt_logits.reshape(
plddt_logits.shape[0], -1, plddt_logits.shape[1], ChemData().NHEAVY
).float()
-1, plddt_logits.shape[1], ChemData().NHEAVY, confidence_loss.plddt.n_bins
).permute(0, 3, 1, 2).float()
# Reshape the pae and pde logits to B, K, L, L
pae_logits = confidence["pae_logits"].permute(0, 3, 1, 2).float()
pde_logits = confidence["pde_logits"].permute(0, 3, 1, 2).float()