From 44324493fd426fa4e933cea22280e7e4ad57cf12 Mon Sep 17 00:00:00 2001 From: Tuscan Thompson Date: Tue, 18 Feb 2025 13:05:54 -0800 Subject: [PATCH] fix: update plddt reshaping and permute --- rf2aa/config/train/af3_repro_rollout.yaml | 6 ++++++ rf2aa/loss/af3_confidence_loss.py | 6 +++--- rf2aa/metrics/predicted_error.py | 8 ++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/rf2aa/config/train/af3_repro_rollout.yaml b/rf2aa/config/train/af3_repro_rollout.yaml index 992d344..7d11671 100644 --- a/rf2aa/config/train/af3_repro_rollout.yaml +++ b/rf2aa/config/train/af3_repro_rollout.yaml @@ -153,6 +153,7 @@ model: n_block: 3 diffusion_transformer_block: n_head: 4 + no_residual_connection_between_attention_and_transition: True relative_position_encoding: r_max: 32 s_max: 2 @@ -171,6 +172,7 @@ model: n_block: 2 raw_template_dim: 108 c: 64 + p_drop: 0.25 msa_module: n_block: 4 c_m: 64 @@ -242,10 +244,13 @@ model: n_block: 3 diffusion_transformer_block: n_head: 4 + no_residual_connection_between_attention_and_transition: True + broadcast_trunk_feats_on_1dim_old: True diffusion_transformer: n_block: 24 diffusion_transformer_block: n_head: 16 + no_residual_connection_between_attention_and_transition: True atom_attention_decoder: atom_transformer: n_queries: 32 @@ -255,6 +260,7 @@ model: n_block: 3 diffusion_transformer_block: n_head: 4 + no_residual_connection_between_attention_and_transition: True distogram_head: bins: 65 confidence_head: diff --git a/rf2aa/loss/af3_confidence_loss.py b/rf2aa/loss/af3_confidence_loss.py index f11440f..fb7871e 100644 --- a/rf2aa/loss/af3_confidence_loss.py +++ b/rf2aa/loss/af3_confidence_loss.py @@ -121,11 +121,11 @@ class ConfidenceLoss(nn.Module): plddt_logit_stack = network_output["plddt"] plddt_per_structure = unbin_logits( plddt_logit_stack.reshape( - plddt_logit_stack.shape[0], -1, - plddt_logit_stack.shape[1], + I, ChemData().NHEAVY, - ).float(), + self.plddt.n_bins, + ).permute(0, 3, 1, 2).float(), self.plddt.max_value, self.plddt.n_bins, ) diff --git a/rf2aa/metrics/predicted_error.py b/rf2aa/metrics/predicted_error.py index 6f188ee..db54951 100644 --- a/rf2aa/metrics/predicted_error.py +++ b/rf2aa/metrics/predicted_error.py @@ -40,11 +40,11 @@ class WriteAF3Confidence(Metric): # reorder the input tensors to be in (B, n_bins, ...) format for unbinning plddt = unbin_logits( plddt_logit_stack.reshape( - plddt_logit_stack.shape[0], -1, plddt_logit_stack.shape[1], ChemData().NHEAVY, - ).float(), + self.plddt.n_bins, + ).permute(0, 3, 1, 2).float(), self.plddt.max_value, self.plddt.n_bins, ) @@ -187,8 +187,8 @@ class GetConfidenceIndices(Metric): # Reshape logits to B, K, L, NHEAVY is_real_atom = network_output["confidence"]["is_real_atom"] plddt_logits = plddt_logits.reshape( - plddt_logits.shape[0], -1, plddt_logits.shape[1], ChemData().NHEAVY - ).float() + -1, plddt_logits.shape[1], ChemData().NHEAVY, confidence_loss.plddt.n_bins + ).permute(0, 3, 1, 2).float() # Reshape the pae and pde logits to B, K, L, L pae_logits = confidence["pae_logits"].permute(0, 3, 1, 2).float() pde_logits = confidence["pde_logits"].permute(0, 3, 1, 2).float()