mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Replace custom block_diagonal with jax.scipy.linalg.block_diag.
PiperOrigin-RevId: 824421836 Change-Id: I036ee6f72d0364c56fe62a24296d7fba4ae521fd
This commit is contained in:
committed by
Copybara-Service
parent
cd0357af72
commit
a138beeaf1
@@ -20,6 +20,7 @@ from typing import Iterable, List, Mapping, Sequence
|
||||
|
||||
from alphafold.common import residue_constants
|
||||
from alphafold.data import pipeline
|
||||
from jax.scipy import linalg
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -348,54 +349,12 @@ def reorder_paired_rows(
|
||||
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
|
||||
"""Like scipy.linalg.block_diag but with an optional padding value."""
|
||||
ones_arrs = [np.ones_like(x) for x in arrs]
|
||||
off_diag_mask = 1.0 - block_diagonal(*ones_arrs)
|
||||
diag = block_diagonal(*arrs)
|
||||
off_diag_mask = 1.0 - np.array(linalg.block_diag(*ones_arrs))
|
||||
diag = np.array(linalg.block_diag(*arrs))
|
||||
diag += (off_diag_mask * pad_value).astype(diag.dtype)
|
||||
return diag
|
||||
|
||||
|
||||
def block_diagonal(*arrays: np.ndarray) -> np.ndarray:
|
||||
"""Creates a block diagonal matrix from a list of numpy arrays.
|
||||
|
||||
Args:
|
||||
*arrays: A sequence of NumPy arrays. They can be 1D or 2D, or higher
|
||||
dimensional, in which case dimensions before the last two are considered
|
||||
batch dimensions.
|
||||
|
||||
Returns:
|
||||
A NumPy array with blocks formed by input arrays on its diagonal.
|
||||
"""
|
||||
if not arrays:
|
||||
arrays = ([],)
|
||||
arrays = [np.atleast_2d(a) for a in arrays]
|
||||
|
||||
# Broadcast leading dimensions if they exist.
|
||||
batch_shapes = [array.shape[:-2] for array in arrays]
|
||||
batch_shape = np.broadcast_shapes(*batch_shapes)
|
||||
arrays = [
|
||||
np.broadcast_to(array, batch_shape + array.shape[-2:]) for array in arrays
|
||||
]
|
||||
|
||||
output_dtype = np.result_type(*(array.dtype for array in arrays))
|
||||
block_shapes = np.array([array.shape[-2:] for array in arrays])
|
||||
# The output shape is batch_shape + (sum of rows, sum of cols).
|
||||
output_array = np.zeros(
|
||||
batch_shape + tuple(np.sum(block_shapes, axis=0)), dtype=output_dtype
|
||||
)
|
||||
current_row = 0
|
||||
current_col = 0
|
||||
for array, (num_rows, num_cols) in zip(arrays, block_shapes, strict=True):
|
||||
# Place array block i in the output array.
|
||||
output_array[
|
||||
...,
|
||||
current_row : current_row + num_rows,
|
||||
current_col : current_col + num_cols,
|
||||
] = array
|
||||
current_row += num_rows
|
||||
current_col += num_cols
|
||||
return output_array
|
||||
|
||||
|
||||
def _correct_post_merged_feats(
|
||||
np_example: pipeline.FeatureDict,
|
||||
np_chains_list: Sequence[pipeline.FeatureDict],
|
||||
|
||||
Reference in New Issue
Block a user