Add an option to make entity IDs unique in structure.concat

PiperOrigin-RevId: 826028193
Change-Id: I8af72087770fafa68135093d0c0fa55d6cbd03f2
This commit is contained in:
Augustin Zidek
2025-10-30 07:46:15 -07:00
committed by Copybara-Service
parent 805adc3863
commit aed6e82cb1

View File

@@ -2889,7 +2889,11 @@ class Structure(table.Database):
# We don't need to assign unique chain IDs because the bioassembly
# transform takes care of remapping chain IDs to be unique.
concatenated = concat(transformed_strucs, assign_unique_chain_ids=False)
concatenated = concat(
transformed_strucs,
assign_unique_chain_ids=False,
assign_unique_entity_ids=False,
)
# Copy over all scalar fields (e.g. name, release date, etc.) other than
# bioassembly_data because it relates only to the pre-transformed structure.
@@ -3094,6 +3098,7 @@ def concat(
*,
name: str | None = None,
assign_unique_chain_ids: bool = True,
assign_unique_entity_ids: bool = True,
) -> Structure:
"""Concatenates structures along the atom dimension.
@@ -3128,11 +3133,16 @@ def concat(
structures then they should all have the same number of models).
name: Optional name to give to the concatenated structure. If None, the name
will be concatenation of names of all concatenated structures.
assign_unique_chain_ids: Whether this function will first assign new unique
assign_unique_chain_ids: If True, this function first assigns new unique
chain IDs, entity IDs and author chain IDs to every chain in `strucs`. If
`False` then users must ensure chain IDs are already unique, otherwise an
False, you must ensure chain IDs are already unique, otherwise an
exception is raised. See `_assign_unique_chain_ids` for more information
on how this is performed.
assign_unique_entity_ids: If True, this function first assigns new unique
entity IDs to every chain in `strucs`. If False, you must ensure entity
IDs are already set in a way so that same entity ID implies for two chains
in `strucs` that they have the same residues. This option applies only if
`assign_unique_chain_ids == False`, otherwise it must be set to True.
Returns:
A new concatenated `Structure` with all of the chains in `strucs` combined
@@ -3146,6 +3156,13 @@ def concat(
"""
if not strucs:
raise ValueError('Need at least one Structure to concatenate.')
if assign_unique_chain_ids and not assign_unique_entity_ids:
raise ValueError(
'If assign_unique_chain_ids is True, assign_unique_entity_ids must be '
'True as well.'
)
if assign_unique_chain_ids:
strucs = _assign_unique_chain_ids(strucs)
@@ -3167,11 +3184,13 @@ def concat(
concatted_struc = table.concat_databases(strucs)
name = name if name is not None else '_'.join(s.name for s in strucs)
# Chain IDs (label and author) are fixed at this point, fix also entity IDs.
if assign_unique_chain_ids:
entity_id = np.char.mod('%d', np.arange(1, concatted_struc.num_chains + 1))
if assign_unique_chain_ids or assign_unique_entity_ids:
numeric_ids = np.arange(1, concatted_struc.num_chains + 1)
entity_id = np.char.mod('%d', numeric_ids).astype(object)
chains = concatted_struc.chains_table.copy_and_update(entity_id=entity_id)
else:
chains = concatted_struc.chains_table
return concatted_struc.copy_and_update(
name=name,
release_date=None,