Merge remote-tracking branch 'refs/remotes/jnwei/pl_upgrades' into pl_upgrades

This commit is contained in:
Jennifer Wei
2024-05-06 08:42:37 +00:00
2 changed files with 7 additions and 8 deletions

View File

@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):
def _superimpose_single(reference, coords):
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
def superimpose(reference, coords, mask):

View File

@@ -682,9 +682,9 @@ if __name__ == "__main__":
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
trainer_group.add_argument(
"--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
args = parser.parse_args()
@@ -700,5 +700,4 @@ if __name__ == "__main__":
raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path")
main(args)