diff --git a/models/rfd3/src/rfd3/model/inference_sampler.py b/models/rfd3/src/rfd3/model/inference_sampler.py index cbdeeae..6ed957e 100644 --- a/models/rfd3/src/rfd3/model/inference_sampler.py +++ b/models/rfd3/src/rfd3/model/inference_sampler.py @@ -9,12 +9,12 @@ from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwis from rfd3.model.cfg_utils import strip_X from foundry.common import exists +from foundry.utils.alignment import weighted_rigid_align from foundry.utils.ddp import RankedLogger from foundry.utils.rotation_augmentation import ( rot_vec_mul, uniform_random_rotation, ) -from foundry.utils.alignment import weighted_rigid_align ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index 4338922..aeac08c 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -274,7 +274,7 @@ def get_sparse_attention_indices_with_inter_chain( other_chain_atoms = torch.where(other_chain_mask)[0] if len(other_chain_atoms) > 0: - # Get distances to other chains + # Get distances to other chains distances_to_other = D_LL[b, c, other_chain_atoms] # Select k_inter closest atoms from other chains