mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Inline softmax implementation so that we don't break with old JAX
PiperOrigin-RevId: 877965714 Change-Id: I949f0472c20c62dd9d681157c8bacd6304e8cd85
This commit is contained in:
committed by
Copybara-Service
parent
dc61fb4a31
commit
636a18bae7
@@ -17,10 +17,16 @@
|
||||
import json
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from jax.scipy import special
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _softmax(x: np.ndarray, axis: Optional[int] = None):
|
||||
x = np.asarray(x)
|
||||
x_max = np.max(x, axis=axis, keepdims=True)
|
||||
exp_x_shifted = np.exp(x - x_max)
|
||||
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
||||
|
||||
|
||||
def compute_plddt(logits: np.ndarray) -> np.ndarray:
|
||||
"""Computes per-residue pLDDT from logits.
|
||||
|
||||
@@ -33,7 +39,7 @@ def compute_plddt(logits: np.ndarray) -> np.ndarray:
|
||||
num_bins = logits.shape[-1]
|
||||
bin_width = 1.0 / num_bins
|
||||
bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
|
||||
probs = np.array(special.softmax(logits, axis=-1))
|
||||
probs = np.array(_softmax(logits, axis=-1))
|
||||
predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
|
||||
return predicted_lddt_ca * 100
|
||||
|
||||
@@ -135,7 +141,7 @@ def compute_predicted_aligned_error(
|
||||
error for each pair of residues.
|
||||
max_predicted_aligned_error: The maximum predicted error possible.
|
||||
"""
|
||||
aligned_confidence_probs = np.array(special.softmax(logits, axis=-1))
|
||||
aligned_confidence_probs = np.array(_softmax(logits, axis=-1))
|
||||
predicted_aligned_error, max_predicted_aligned_error = (
|
||||
_calculate_expected_aligned_error(
|
||||
alignment_confidence_breaks=breaks,
|
||||
@@ -215,7 +221,7 @@ def predicted_tm_score(
|
||||
d0 = 1.24 * (clipped_num_res - 15) ** (1.0 / 3) - 1.8
|
||||
|
||||
# Convert logits to probs.
|
||||
probs = np.array(special.softmax(logits, axis=-1))
|
||||
probs = np.array(_softmax(logits, axis=-1))
|
||||
|
||||
# TM-Score term for every bin.
|
||||
tm_per_bin = 1.0 / (1 + np.square(bin_centers) / np.square(d0))
|
||||
|
||||
Reference in New Issue
Block a user