diff --git a/RF2_allatom/util.py b/RF2_allatom/util.py index 4b1781d..419d59b 100644 --- a/RF2_allatom/util.py +++ b/RF2_allatom/util.py @@ -846,7 +846,7 @@ def get_atom_frames(msa, mol, G): frames_with_n = [frame for frame in frames if n in frame] # if the atom isn't in a 3 atom frame, it should be ignored in loss calc, set all the atoms to n if not frames_with_n: - selected_frames.append([n,n,n]) + selected_frames.append([(n,0),(n,0),(n, 0)]) continue frame_priorities = [] for frame in frames_with_n: @@ -861,7 +861,7 @@ def get_atom_frames(msa, mol, G): frame = [(frame-n, 1) for frame in frames_with_n[sorted_indices[0]]] selected_frames.append(frame) assert msa.shape[0] == len(selected_frames) - return torch.Tensor(selected_frames).long() + return torch.tensor(selected_frames).long() ### Generate bond features for small molecules ###