fix training script

This commit is contained in:
samsledje
2022-06-23 12:03:39 -04:00
parent 1b63b008fd
commit 6000faccfb
2 changed files with 5 additions and 4 deletions

View File

@@ -98,7 +98,7 @@ Training
[--batch-size BATCH_SIZE] [--weight-decay WEIGHT_DECAY]
[--lr LR] [--lambda INTERACTION_WEIGHT] [--topsy-turvy]
[--glider-weight GLIDER_WEIGHT]
[--glider-thresh GLIDER_THRESH] [-o OUTPUT]
[--glider-thresh GLIDER_THRESH] [-o OUTFILE]
[--save-prefix SAVE_PREFIX] [-d DEVICE]
[--checkpoint CHECKPOINT]

View File

@@ -174,7 +174,7 @@ def add_args(parser):
# Output
misc_grp.add_argument(
"-o", "--output", help="output file path (default: stdout)"
"-o", "--outfile", help="output file path (default: stdout)"
)
misc_grp.add_argument(
"--save-prefix", help="path prefix for saving models"
@@ -328,7 +328,8 @@ def interaction_grad(
if use_cuda:
y = y.cpu()
p_hat = p_hat.cpu()
g_score = g_score.cpu()
if run_tt:
g_score = g_score.cpu()
with torch.no_grad():
guess_cutoff = 0.5
@@ -673,7 +674,7 @@ def main(args):
:meta private:
"""
output = args.output
output = args.outfile
if output is None:
output = sys.stdout
else: