Merge branch 'main' of https://github.com/minkbaek/BFF into main

This commit is contained in:
Rohith
2022-07-14 23:27:59 -07:00

View File

@@ -846,7 +846,7 @@ def get_atom_frames(msa, mol, G):
frames_with_n = [frame for frame in frames if n in frame]
# if the atom isn't in a 3 atom frame, it should be ignored in loss calc, set all the atoms to n
if not frames_with_n:
selected_frames.append([n,n,n])
selected_frames.append([(n,0),(n,0),(n, 0)])
continue
frame_priorities = []
for frame in frames_with_n:
@@ -861,7 +861,7 @@ def get_atom_frames(msa, mol, G):
frame = [(frame-n, 1) for frame in frames_with_n[sorted_indices[0]]]
selected_frames.append(frame)
assert msa.shape[0] == len(selected_frames)
return torch.Tensor(selected_frames).long()
return torch.tensor(selected_frames).long()
### Generate bond features for small molecules ###