mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
fix training script
This commit is contained in:
@@ -98,7 +98,7 @@ Training
|
||||
[--batch-size BATCH_SIZE] [--weight-decay WEIGHT_DECAY]
|
||||
[--lr LR] [--lambda INTERACTION_WEIGHT] [--topsy-turvy]
|
||||
[--glider-weight GLIDER_WEIGHT]
|
||||
[--glider-thresh GLIDER_THRESH] [-o OUTPUT]
|
||||
[--glider-thresh GLIDER_THRESH] [-o OUTFILE]
|
||||
[--save-prefix SAVE_PREFIX] [-d DEVICE]
|
||||
[--checkpoint CHECKPOINT]
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ def add_args(parser):
|
||||
|
||||
# Output
|
||||
misc_grp.add_argument(
|
||||
"-o", "--output", help="output file path (default: stdout)"
|
||||
"-o", "--outfile", help="output file path (default: stdout)"
|
||||
)
|
||||
misc_grp.add_argument(
|
||||
"--save-prefix", help="path prefix for saving models"
|
||||
@@ -328,7 +328,8 @@ def interaction_grad(
|
||||
if use_cuda:
|
||||
y = y.cpu()
|
||||
p_hat = p_hat.cpu()
|
||||
g_score = g_score.cpu()
|
||||
if run_tt:
|
||||
g_score = g_score.cpu()
|
||||
|
||||
with torch.no_grad():
|
||||
guess_cutoff = 0.5
|
||||
@@ -673,7 +674,7 @@ def main(args):
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
output = args.output
|
||||
output = args.outfile
|
||||
if output is None:
|
||||
output = sys.stdout
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user