mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Add link to issue for deepspeed_evo_attention test.
This commit is contained in:
@@ -315,6 +315,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
|
||||
# Move the recycling dimension to the end
|
||||
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
|
||||
batch = tensor_tree_map(move_dim, batch)
|
||||
# Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
|
||||
# https://github.com/aqlaboratory/openfold/issues/532
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
model.globals.use_deepspeed_evo_attention = False
|
||||
|
||||
Reference in New Issue
Block a user