Inline softmax implementation so that we don't break with old JAX

PiperOrigin-RevId: 877965714
Change-Id: I949f0472c20c62dd9d681157c8bacd6304e8cd85
This commit is contained in:
Augustin Zidek
2026-03-03 08:39:51 -08:00
committed by Copybara-Service
parent dc61fb4a31
commit 636a18bae7

View File

@@ -17,10 +17,16 @@
import json
from typing import Dict, Optional, Tuple
from jax.scipy import special
import numpy as np
def _softmax(x: np.ndarray, axis: Optional[int] = None):
x = np.asarray(x)
x_max = np.max(x, axis=axis, keepdims=True)
exp_x_shifted = np.exp(x - x_max)
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
def compute_plddt(logits: np.ndarray) -> np.ndarray:
"""Computes per-residue pLDDT from logits.
@@ -33,7 +39,7 @@ def compute_plddt(logits: np.ndarray) -> np.ndarray:
num_bins = logits.shape[-1]
bin_width = 1.0 / num_bins
bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
probs = np.array(special.softmax(logits, axis=-1))
probs = np.array(_softmax(logits, axis=-1))
predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
return predicted_lddt_ca * 100
@@ -135,7 +141,7 @@ def compute_predicted_aligned_error(
error for each pair of residues.
max_predicted_aligned_error: The maximum predicted error possible.
"""
aligned_confidence_probs = np.array(special.softmax(logits, axis=-1))
aligned_confidence_probs = np.array(_softmax(logits, axis=-1))
predicted_aligned_error, max_predicted_aligned_error = (
_calculate_expected_aligned_error(
alignment_confidence_breaks=breaks,
@@ -215,7 +221,7 @@ def predicted_tm_score(
d0 = 1.24 * (clipped_num_res - 15) ** (1.0 / 3) - 1.8
# Convert logits to probs.
probs = np.array(special.softmax(logits, axis=-1))
probs = np.array(_softmax(logits, axis=-1))
# TM-Score term for every bin.
tm_per_bin = 1.0 / (1 + np.square(bin_centers) / np.square(d0))