Format msa_pairing

PiperOrigin-RevId: 794545379
Change-Id: I40ca93f7e6864a5906a40168250261c9c56c59ff
This commit is contained in:
Augustin Zidek
2025-08-13 06:06:55 -07:00
committed by Copybara-Service
parent 6544d69fc3
commit 79a0210693

View File

@@ -15,7 +15,7 @@
"""Pairing logic for multimer data pipeline."""
import collections
from typing import cast, Dict, Iterable, List, Sequence
from typing import Dict, Iterable, List, Sequence, cast
from alphafold.common import residue_constants
from alphafold.data import pipeline
@@ -27,30 +27,48 @@ MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_PAD_VALUES = {
'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0,
}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
SEQ_FEATURES = (
'residue_index',
'aatype',
'all_atom_positions',
'all_atom_mask',
'seq_mask',
'between_segment_residues',
'has_alt_locations',
'has_hetatoms',
'asym_id',
'entity_id',
'sym_id',
'entity_mask',
'deletion_mean',
'prediction_atom_mask',
'literature_positions',
'atom_indices_to_group_indices',
'rigid_group_default_frame',
)
TEMPLATE_FEATURES = (
'template_aatype',
'template_all_atom_positions',
'template_all_atom_mask',
)
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
chains: Iterable[pipeline.FeatureDict],
) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
@@ -65,22 +83,22 @@ def create_paired_features(
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num])
)
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
@@ -97,11 +115,16 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.bytes_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
if feature_name in (
'msa_all_seq',
'msa_mask_all_seq',
'deletion_matrix_all_seq',
'deletion_matrix_int_all_seq',
):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
padding = MSA_PAD_VALUES[feature_name] * np.ones(
[1, num_res], feature.dtype
)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
@@ -114,17 +137,19 @@ def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_similarity = np.sum(query_seq[None] == chain_msa, axis=-1) / float(
len(query_seq)
)
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_species_identifiers_all_seq'])),
'msa_species_identifiers': chain_features[
'msa_species_identifiers_all_seq'
],
'msa_row': np.arange(
len(chain_features['msa_species_identifiers_all_seq'])
),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
'gap': per_seq_gap,
})
return msa_df
@@ -137,8 +162,9 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
return species_lookup
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
def _match_rows_by_sequence_similarity(
this_species_msa_dfs: List[pd.DataFrame],
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
@@ -155,12 +181,16 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
num_seqs = [
len(species_df)
for species_df in this_species_msa_dfs
if species_df is not None
]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
sort_by_similarity = lambda x: x.sort_values(
'msa_similarity', axis=0, ascending=False
)
for species_df in this_species_msa_dfs:
if species_df is not None:
@@ -173,8 +203,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict]
) -> Dict[int, np.ndarray]:
def pair_sequences(
examples: List[pipeline.FeatureDict],
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
@@ -211,23 +242,28 @@ def pair_sequences(examples: List[pipeline.FeatureDict]
continue
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
np.array([
len(species_df)
for species_df in this_species_msa_dfs
if isinstance(species_df, pd.DataFrame)
])
> 600
):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
num_examples: np.array(paired_msa_rows)
for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
def reorder_paired_rows(
all_paired_msa_rows_dict: Dict[int, np.ndarray],
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
@@ -264,13 +300,16 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
def _correct_post_merged_feats(
np_example: pipeline.FeatureDict,
np_chains_list: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
pair_msa_sequences: bool,
) -> pipeline.FeatureDict:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
dtype=np.int32)
np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
dtype=np.int32)
np_example['seq_length'] = np.asarray(
np_example['aatype'].shape[0], dtype=np.int32
)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0], dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
@@ -285,31 +324,35 @@ def _correct_post_merged_feats(
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list]
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0)
np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list]
msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list]
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.float32)
for x in np_chains_list
]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0)
msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
[msa_mask_all_seq, msa_mask_block_diag], axis=0
)
return np_example
def _pad_templates(chains: Sequence[pipeline.FeatureDict],
max_templates: int) -> Sequence[pipeline.FeatureDict]:
def _pad_templates(
chains: Sequence[pipeline.FeatureDict], max_templates: int
) -> Sequence[pipeline.FeatureDict]:
"""For each chain pad the number of templates to a fixed size.
Args:
@@ -331,14 +374,14 @@ def _pad_templates(chains: Sequence[pipeline.FeatureDict],
def _merge_features_from_multiple_chains(
chains: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
chains: Sequence[pipeline.FeatureDict], pair_msa_sequences: bool
) -> pipeline.FeatureDict:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
pair_msa_sequences: Whether to concatenate MSA features along the num_res
dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
@@ -352,7 +395,8 @@ def _merge_features_from_multiple_chains(
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
*feats, pad_value=MSA_PAD_VALUES[feature_name]
)
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
@@ -365,7 +409,8 @@ def _merge_features_from_multiple_chains(
def _merge_homomers_dense_msa(
chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
chains: Iterable[pipeline.FeatureDict],
) -> Sequence[pipeline.FeatureDict]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
@@ -387,12 +432,14 @@ def _merge_homomers_dense_msa(
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
for chains in grouped_chains
]
return chains
def _concatenate_paired_and_unpaired_features(
example: pipeline.FeatureDict) -> pipeline.FeatureDict:
example: pipeline.FeatureDict,
) -> pipeline.FeatureDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
@@ -401,14 +448,15 @@ def _concatenate_paired_and_unpaired_features(
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
example['num_alignments'] = np.array(example['msa'].shape[0], dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
def merge_chain_features(
np_chains_list: List[pipeline.FeatureDict],
pair_msa_sequences: bool,
max_templates: int,
) -> pipeline.FeatureDict:
"""Merges features for multiple chains to single FeatureDict.
Args:
@@ -419,25 +467,27 @@ def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _pad_templates(np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
np_chains_list, pair_msa_sequences=False
)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
pair_msa_sequences=pair_msa_sequences,
)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
np_chains: List[pipeline.FeatureDict],
) -> List[pipeline.FeatureDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()