Replace custom block_diagonal with jax.scipy.linalg.block_diag.

PiperOrigin-RevId: 824421836
Change-Id: I036ee6f72d0364c56fe62a24296d7fba4ae521fd
This commit is contained in:
Harsh Tiku
2025-10-27 02:28:38 -07:00
committed by Copybara-Service
parent cd0357af72
commit a138beeaf1

View File

@@ -20,6 +20,7 @@ from typing import Iterable, List, Mapping, Sequence
from alphafold.common import residue_constants
from alphafold.data import pipeline
from jax.scipy import linalg
import numpy as np
@@ -348,54 +349,12 @@ def reorder_paired_rows(
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - block_diagonal(*ones_arrs)
diag = block_diagonal(*arrs)
off_diag_mask = 1.0 - np.array(linalg.block_diag(*ones_arrs))
diag = np.array(linalg.block_diag(*arrs))
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
def block_diagonal(*arrays: np.ndarray) -> np.ndarray:
"""Creates a block diagonal matrix from a list of numpy arrays.
Args:
*arrays: A sequence of NumPy arrays. They can be 1D or 2D, or higher
dimensional, in which case dimensions before the last two are considered
batch dimensions.
Returns:
A NumPy array with blocks formed by input arrays on its diagonal.
"""
if not arrays:
arrays = ([],)
arrays = [np.atleast_2d(a) for a in arrays]
# Broadcast leading dimensions if they exist.
batch_shapes = [array.shape[:-2] for array in arrays]
batch_shape = np.broadcast_shapes(*batch_shapes)
arrays = [
np.broadcast_to(array, batch_shape + array.shape[-2:]) for array in arrays
]
output_dtype = np.result_type(*(array.dtype for array in arrays))
block_shapes = np.array([array.shape[-2:] for array in arrays])
# The output shape is batch_shape + (sum of rows, sum of cols).
output_array = np.zeros(
batch_shape + tuple(np.sum(block_shapes, axis=0)), dtype=output_dtype
)
current_row = 0
current_col = 0
for array, (num_rows, num_cols) in zip(arrays, block_shapes, strict=True):
# Place array block i in the output array.
output_array[
...,
current_row : current_row + num_rows,
current_col : current_col + num_cols,
] = array
current_row += num_rows
current_col += num_cols
return output_array
def _correct_post_merged_feats(
np_example: pipeline.FeatureDict,
np_chains_list: Sequence[pipeline.FeatureDict],