Bugfix for model not learning

This commit is contained in:
Kevin Wu
2022-11-08 21:56:18 -08:00
parent 966a3be053
commit f044114dae
2 changed files with 2 additions and 2 deletions

View File

@@ -98,7 +98,7 @@ def train(
gradient_clip: float = 1.0,
batch_size: int = 256,
lr: float = 5e-5,
l2_norm: float = 0.0,
l2_norm: float = 0.01,
loss: modelling.LOSS_KEYS = "smooth_l1",
min_epochs: Optional[int] = None,
max_epochs: int = 10000, # 10000, set to 100 for debug

View File

@@ -924,7 +924,7 @@ class BertForAutoregressive(BertForAutoregressiveBase, pl.LightningModule):
assert preds.ndim == 3 # batch_size, seq_length, features
# Get the loss terms
l = self.loss(
preds[:, batch["causal_idx"]],
preds[torch.arange(batch["lengths"].shape[0]), batch["causal_idx"]],
batch["causal_target"],
)
return l