mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge remote-tracking branch 'refs/remotes/jnwei/pl_upgrades' into pl_upgrades
This commit is contained in:
@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):
|
||||
|
||||
|
||||
def _superimpose_single(reference, coords):
|
||||
reference_np = reference.detach().to(torch.float).cpu().numpy()
|
||||
coords_np = coords.detach().to(torch.float).cpu().numpy()
|
||||
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
|
||||
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
|
||||
reference_np = reference.detach().to(torch.float).cpu().numpy()
|
||||
coords_np = coords.detach().to(torch.float).cpu().numpy()
|
||||
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
|
||||
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
|
||||
|
||||
|
||||
def superimpose(reference, coords, mask):
|
||||
|
||||
@@ -682,9 +682,9 @@ if __name__ == "__main__":
|
||||
trainer_group.add_argument(
|
||||
"--reload_dataloaders_every_n_epochs", type=int, default=1,
|
||||
)
|
||||
|
||||
trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
|
||||
help="Accumulate gradients over k batches before next optimizer step.")
|
||||
trainer_group.add_argument(
|
||||
"--accumulate_grad_batches", type=int, default=1,
|
||||
help="Accumulate gradients over k batches before next optimizer step.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -700,5 +700,4 @@ if __name__ == "__main__":
|
||||
raise ValueError(
|
||||
"Choose between loading pretrained Jax-weights and a checkpoint-path")
|
||||
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user