mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Format msa_pairing
PiperOrigin-RevId: 794545379 Change-Id: I40ca93f7e6864a5906a40168250261c9c56c59ff
This commit is contained in:
committed by
Copybara-Service
parent
6544d69fc3
commit
79a0210693
@@ -15,7 +15,7 @@
|
||||
"""Pairing logic for multimer data pipeline."""
|
||||
|
||||
import collections
|
||||
from typing import cast, Dict, Iterable, List, Sequence
|
||||
from typing import Dict, Iterable, List, Sequence, cast
|
||||
|
||||
from alphafold.common import residue_constants
|
||||
from alphafold.data import pipeline
|
||||
@@ -27,30 +27,48 @@ MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
|
||||
SEQUENCE_GAP_CUTOFF = 0.5
|
||||
SEQUENCE_SIMILARITY_CUTOFF = 0.9
|
||||
|
||||
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
|
||||
'msa_mask_all_seq': 1,
|
||||
'deletion_matrix_all_seq': 0,
|
||||
'deletion_matrix_int_all_seq': 0,
|
||||
'msa': MSA_GAP_IDX,
|
||||
'msa_mask': 1,
|
||||
'deletion_matrix': 0,
|
||||
'deletion_matrix_int': 0}
|
||||
MSA_PAD_VALUES = {
|
||||
'msa_all_seq': MSA_GAP_IDX,
|
||||
'msa_mask_all_seq': 1,
|
||||
'deletion_matrix_all_seq': 0,
|
||||
'deletion_matrix_int_all_seq': 0,
|
||||
'msa': MSA_GAP_IDX,
|
||||
'msa_mask': 1,
|
||||
'deletion_matrix': 0,
|
||||
'deletion_matrix_int': 0,
|
||||
}
|
||||
|
||||
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
|
||||
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
|
||||
'all_atom_mask', 'seq_mask', 'between_segment_residues',
|
||||
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
|
||||
'sym_id', 'entity_mask', 'deletion_mean',
|
||||
'prediction_atom_mask',
|
||||
'literature_positions', 'atom_indices_to_group_indices',
|
||||
'rigid_group_default_frame')
|
||||
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
|
||||
'template_all_atom_mask')
|
||||
SEQ_FEATURES = (
|
||||
'residue_index',
|
||||
'aatype',
|
||||
'all_atom_positions',
|
||||
'all_atom_mask',
|
||||
'seq_mask',
|
||||
'between_segment_residues',
|
||||
'has_alt_locations',
|
||||
'has_hetatoms',
|
||||
'asym_id',
|
||||
'entity_id',
|
||||
'sym_id',
|
||||
'entity_mask',
|
||||
'deletion_mean',
|
||||
'prediction_atom_mask',
|
||||
'literature_positions',
|
||||
'atom_indices_to_group_indices',
|
||||
'rigid_group_default_frame',
|
||||
)
|
||||
TEMPLATE_FEATURES = (
|
||||
'template_aatype',
|
||||
'template_all_atom_positions',
|
||||
'template_all_atom_mask',
|
||||
)
|
||||
CHAIN_FEATURES = ('num_alignments', 'seq_length')
|
||||
|
||||
|
||||
def create_paired_features(
|
||||
chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
|
||||
chains: Iterable[pipeline.FeatureDict],
|
||||
) -> List[pipeline.FeatureDict]:
|
||||
"""Returns the original chains with paired NUM_SEQ features.
|
||||
|
||||
Args:
|
||||
@@ -65,22 +83,22 @@ def create_paired_features(
|
||||
|
||||
if len(chains) < 2:
|
||||
return chains
|
||||
else:
|
||||
updated_chains = []
|
||||
paired_chains_to_paired_row_indices = pair_sequences(chains)
|
||||
paired_rows = reorder_paired_rows(
|
||||
paired_chains_to_paired_row_indices)
|
||||
|
||||
for chain_num, chain in enumerate(chains):
|
||||
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
|
||||
for feature_name in chain_keys:
|
||||
if feature_name.endswith('_all_seq'):
|
||||
feats_padded = pad_features(chain[feature_name], feature_name)
|
||||
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
|
||||
new_chain['num_alignments_all_seq'] = np.asarray(
|
||||
len(paired_rows[:, chain_num]))
|
||||
updated_chains.append(new_chain)
|
||||
return updated_chains
|
||||
updated_chains = []
|
||||
paired_chains_to_paired_row_indices = pair_sequences(chains)
|
||||
paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)
|
||||
|
||||
for chain_num, chain in enumerate(chains):
|
||||
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
|
||||
for feature_name in chain_keys:
|
||||
if feature_name.endswith('_all_seq'):
|
||||
feats_padded = pad_features(chain[feature_name], feature_name)
|
||||
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
|
||||
new_chain['num_alignments_all_seq'] = np.asarray(
|
||||
len(paired_rows[:, chain_num])
|
||||
)
|
||||
updated_chains.append(new_chain)
|
||||
return updated_chains
|
||||
|
||||
|
||||
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
|
||||
@@ -97,11 +115,16 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
|
||||
The feature with an additional padding row.
|
||||
"""
|
||||
assert feature.dtype != np.dtype(np.bytes_)
|
||||
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
|
||||
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
|
||||
if feature_name in (
|
||||
'msa_all_seq',
|
||||
'msa_mask_all_seq',
|
||||
'deletion_matrix_all_seq',
|
||||
'deletion_matrix_int_all_seq',
|
||||
):
|
||||
num_res = feature.shape[1]
|
||||
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
|
||||
feature.dtype)
|
||||
padding = MSA_PAD_VALUES[feature_name] * np.ones(
|
||||
[1, num_res], feature.dtype
|
||||
)
|
||||
elif feature_name == 'msa_species_identifiers_all_seq':
|
||||
padding = [b'']
|
||||
else:
|
||||
@@ -114,17 +137,19 @@ def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
|
||||
"""Makes dataframe with msa features needed for msa pairing."""
|
||||
chain_msa = chain_features['msa_all_seq']
|
||||
query_seq = chain_msa[0]
|
||||
per_seq_similarity = np.sum(
|
||||
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
|
||||
per_seq_similarity = np.sum(query_seq[None] == chain_msa, axis=-1) / float(
|
||||
len(query_seq)
|
||||
)
|
||||
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
|
||||
msa_df = pd.DataFrame({
|
||||
'msa_species_identifiers':
|
||||
chain_features['msa_species_identifiers_all_seq'],
|
||||
'msa_row':
|
||||
np.arange(len(
|
||||
chain_features['msa_species_identifiers_all_seq'])),
|
||||
'msa_species_identifiers': chain_features[
|
||||
'msa_species_identifiers_all_seq'
|
||||
],
|
||||
'msa_row': np.arange(
|
||||
len(chain_features['msa_species_identifiers_all_seq'])
|
||||
),
|
||||
'msa_similarity': per_seq_similarity,
|
||||
'gap': per_seq_gap
|
||||
'gap': per_seq_gap,
|
||||
})
|
||||
return msa_df
|
||||
|
||||
@@ -137,8 +162,9 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
|
||||
return species_lookup
|
||||
|
||||
|
||||
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
|
||||
) -> List[List[int]]:
|
||||
def _match_rows_by_sequence_similarity(
|
||||
this_species_msa_dfs: List[pd.DataFrame],
|
||||
) -> List[List[int]]:
|
||||
"""Finds MSA sequence pairings across chains based on sequence similarity.
|
||||
|
||||
Each chain's MSA sequences are first sorted by their sequence similarity to
|
||||
@@ -155,12 +181,16 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
|
||||
"""
|
||||
all_paired_msa_rows = []
|
||||
|
||||
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
|
||||
if species_df is not None]
|
||||
num_seqs = [
|
||||
len(species_df)
|
||||
for species_df in this_species_msa_dfs
|
||||
if species_df is not None
|
||||
]
|
||||
take_num_seqs = np.min(num_seqs)
|
||||
|
||||
sort_by_similarity = (
|
||||
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
|
||||
sort_by_similarity = lambda x: x.sort_values(
|
||||
'msa_similarity', axis=0, ascending=False
|
||||
)
|
||||
|
||||
for species_df in this_species_msa_dfs:
|
||||
if species_df is not None:
|
||||
@@ -173,8 +203,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
|
||||
return all_paired_msa_rows
|
||||
|
||||
|
||||
def pair_sequences(examples: List[pipeline.FeatureDict]
|
||||
) -> Dict[int, np.ndarray]:
|
||||
def pair_sequences(
|
||||
examples: List[pipeline.FeatureDict],
|
||||
) -> Dict[int, np.ndarray]:
|
||||
"""Returns indices for paired MSA sequences across chains."""
|
||||
|
||||
num_examples = len(examples)
|
||||
@@ -211,23 +242,28 @@ def pair_sequences(examples: List[pipeline.FeatureDict]
|
||||
continue
|
||||
|
||||
if np.any(
|
||||
np.array([len(species_df) for species_df in
|
||||
this_species_msa_dfs if
|
||||
isinstance(species_df, pd.DataFrame)]) > 600):
|
||||
np.array([
|
||||
len(species_df)
|
||||
for species_df in this_species_msa_dfs
|
||||
if isinstance(species_df, pd.DataFrame)
|
||||
])
|
||||
> 600
|
||||
):
|
||||
continue
|
||||
|
||||
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
|
||||
all_paired_msa_rows.extend(paired_msa_rows)
|
||||
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
|
||||
all_paired_msa_rows_dict = {
|
||||
num_examples: np.array(paired_msa_rows) for
|
||||
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
|
||||
num_examples: np.array(paired_msa_rows)
|
||||
for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
|
||||
}
|
||||
return all_paired_msa_rows_dict
|
||||
|
||||
|
||||
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
|
||||
) -> np.ndarray:
|
||||
def reorder_paired_rows(
|
||||
all_paired_msa_rows_dict: Dict[int, np.ndarray],
|
||||
) -> np.ndarray:
|
||||
"""Creates a list of indices of paired MSA rows across chains.
|
||||
|
||||
Args:
|
||||
@@ -264,13 +300,16 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
|
||||
def _correct_post_merged_feats(
|
||||
np_example: pipeline.FeatureDict,
|
||||
np_chains_list: Sequence[pipeline.FeatureDict],
|
||||
pair_msa_sequences: bool) -> pipeline.FeatureDict:
|
||||
pair_msa_sequences: bool,
|
||||
) -> pipeline.FeatureDict:
|
||||
"""Adds features that need to be computed/recomputed post merging."""
|
||||
|
||||
np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
|
||||
dtype=np.int32)
|
||||
np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
|
||||
dtype=np.int32)
|
||||
np_example['seq_length'] = np.asarray(
|
||||
np_example['aatype'].shape[0], dtype=np.int32
|
||||
)
|
||||
np_example['num_alignments'] = np.asarray(
|
||||
np_example['msa'].shape[0], dtype=np.int32
|
||||
)
|
||||
|
||||
if not pair_msa_sequences:
|
||||
# Generate a bias that is 1 for the first row of every block in the
|
||||
@@ -285,31 +324,35 @@ def _correct_post_merged_feats(
|
||||
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
|
||||
|
||||
# Initialize the BERT mask with the off-diagonal blocks masked out.
|
||||
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
|
||||
for x in np_chains_list]
|
||||
msa_masks = [
|
||||
np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
|
||||
]
|
||||
|
||||
np_example['bert_mask'] = block_diag(
|
||||
*msa_masks, pad_value=0)
|
||||
np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
|
||||
else:
|
||||
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
|
||||
np_example['cluster_bias_mask'][0] = 1
|
||||
|
||||
# Initialize the BERT mask with the off-diagonal blocks masked out.
|
||||
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
|
||||
x in np_chains_list]
|
||||
msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
|
||||
x in np_chains_list]
|
||||
msa_masks = [
|
||||
np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
|
||||
]
|
||||
msa_masks_all_seq = [
|
||||
np.ones(x['msa_all_seq'].shape, dtype=np.float32)
|
||||
for x in np_chains_list
|
||||
]
|
||||
|
||||
msa_mask_block_diag = block_diag(
|
||||
*msa_masks, pad_value=0)
|
||||
msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
|
||||
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
|
||||
np_example['bert_mask'] = np.concatenate(
|
||||
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
|
||||
[msa_mask_all_seq, msa_mask_block_diag], axis=0
|
||||
)
|
||||
return np_example
|
||||
|
||||
|
||||
def _pad_templates(chains: Sequence[pipeline.FeatureDict],
|
||||
max_templates: int) -> Sequence[pipeline.FeatureDict]:
|
||||
def _pad_templates(
|
||||
chains: Sequence[pipeline.FeatureDict], max_templates: int
|
||||
) -> Sequence[pipeline.FeatureDict]:
|
||||
"""For each chain pad the number of templates to a fixed size.
|
||||
|
||||
Args:
|
||||
@@ -331,14 +374,14 @@ def _pad_templates(chains: Sequence[pipeline.FeatureDict],
|
||||
|
||||
|
||||
def _merge_features_from_multiple_chains(
|
||||
chains: Sequence[pipeline.FeatureDict],
|
||||
pair_msa_sequences: bool) -> pipeline.FeatureDict:
|
||||
chains: Sequence[pipeline.FeatureDict], pair_msa_sequences: bool
|
||||
) -> pipeline.FeatureDict:
|
||||
"""Merge features from multiple chains.
|
||||
|
||||
Args:
|
||||
chains: A list of feature dictionaries that we want to merge.
|
||||
pair_msa_sequences: Whether to concatenate MSA features along the
|
||||
num_res dimension (if True), or to block diagonalize them (if False).
|
||||
pair_msa_sequences: Whether to concatenate MSA features along the num_res
|
||||
dimension (if True), or to block diagonalize them (if False).
|
||||
|
||||
Returns:
|
||||
A feature dictionary for the merged example.
|
||||
@@ -352,7 +395,8 @@ def _merge_features_from_multiple_chains(
|
||||
merged_example[feature_name] = np.concatenate(feats, axis=1)
|
||||
else:
|
||||
merged_example[feature_name] = block_diag(
|
||||
*feats, pad_value=MSA_PAD_VALUES[feature_name])
|
||||
*feats, pad_value=MSA_PAD_VALUES[feature_name]
|
||||
)
|
||||
elif feature_name_split in SEQ_FEATURES:
|
||||
merged_example[feature_name] = np.concatenate(feats, axis=0)
|
||||
elif feature_name_split in TEMPLATE_FEATURES:
|
||||
@@ -365,7 +409,8 @@ def _merge_features_from_multiple_chains(
|
||||
|
||||
|
||||
def _merge_homomers_dense_msa(
|
||||
chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
|
||||
chains: Iterable[pipeline.FeatureDict],
|
||||
) -> Sequence[pipeline.FeatureDict]:
|
||||
"""Merge all identical chains, making the resulting MSA dense.
|
||||
|
||||
Args:
|
||||
@@ -387,12 +432,14 @@ def _merge_homomers_dense_msa(
|
||||
grouped_chains.append(chains)
|
||||
chains = [
|
||||
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
|
||||
for chains in grouped_chains]
|
||||
for chains in grouped_chains
|
||||
]
|
||||
return chains
|
||||
|
||||
|
||||
def _concatenate_paired_and_unpaired_features(
|
||||
example: pipeline.FeatureDict) -> pipeline.FeatureDict:
|
||||
example: pipeline.FeatureDict,
|
||||
) -> pipeline.FeatureDict:
|
||||
"""Merges paired and block-diagonalised features."""
|
||||
features = MSA_FEATURES
|
||||
for feature_name in features:
|
||||
@@ -401,14 +448,15 @@ def _concatenate_paired_and_unpaired_features(
|
||||
feat_all_seq = example[feature_name + '_all_seq']
|
||||
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
|
||||
example[feature_name] = merged_feat
|
||||
example['num_alignments'] = np.array(example['msa'].shape[0],
|
||||
dtype=np.int32)
|
||||
example['num_alignments'] = np.array(example['msa'].shape[0], dtype=np.int32)
|
||||
return example
|
||||
|
||||
|
||||
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
|
||||
pair_msa_sequences: bool,
|
||||
max_templates: int) -> pipeline.FeatureDict:
|
||||
def merge_chain_features(
|
||||
np_chains_list: List[pipeline.FeatureDict],
|
||||
pair_msa_sequences: bool,
|
||||
max_templates: int,
|
||||
) -> pipeline.FeatureDict:
|
||||
"""Merges features for multiple chains to single FeatureDict.
|
||||
|
||||
Args:
|
||||
@@ -419,25 +467,27 @@ def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
|
||||
Returns:
|
||||
Single FeatureDict for entire complex.
|
||||
"""
|
||||
np_chains_list = _pad_templates(
|
||||
np_chains_list, max_templates=max_templates)
|
||||
np_chains_list = _pad_templates(np_chains_list, max_templates=max_templates)
|
||||
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
|
||||
# Unpaired MSA features will always be block-diagonalised; paired MSA
|
||||
# features will be concatenated.
|
||||
np_example = _merge_features_from_multiple_chains(
|
||||
np_chains_list, pair_msa_sequences=False)
|
||||
np_chains_list, pair_msa_sequences=False
|
||||
)
|
||||
if pair_msa_sequences:
|
||||
np_example = _concatenate_paired_and_unpaired_features(np_example)
|
||||
np_example = _correct_post_merged_feats(
|
||||
np_example=np_example,
|
||||
np_chains_list=np_chains_list,
|
||||
pair_msa_sequences=pair_msa_sequences)
|
||||
pair_msa_sequences=pair_msa_sequences,
|
||||
)
|
||||
|
||||
return np_example
|
||||
|
||||
|
||||
def deduplicate_unpaired_sequences(
|
||||
np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
|
||||
np_chains: List[pipeline.FeatureDict],
|
||||
) -> List[pipeline.FeatureDict]:
|
||||
"""Removes unpaired sequences which duplicate a paired sequence."""
|
||||
|
||||
feature_names = np_chains[0].keys()
|
||||
|
||||
Reference in New Issue
Block a user