Mirror of https://github.com/microsoft/foldingdiff.git (synced 2026-06-04 13:30:33 +08:00)
Bugfix for model not learning
This commit is contained in:
@@ -98,7 +98,7 @@ def train(
|
||||
gradient_clip: float = 1.0,
|
||||
batch_size: int = 256,
|
||||
lr: float = 5e-5,
|
||||
- l2_norm: float = 0.0,
|
||||
+ l2_norm: float = 0.01,
|
||||
loss: modelling.LOSS_KEYS = "smooth_l1",
|
||||
min_epochs: Optional[int] = None,
|
||||
max_epochs: int = 10000, # 10000, set to 100 for debug
|
||||
|
||||
@@ -924,7 +924,7 @@ class BertForAutoregressive(BertForAutoregressiveBase, pl.LightningModule):
|
||||
assert preds.ndim == 3 # batch_size, seq_length, features
|
||||
# Get the loss terms
|
||||
l = self.loss(
|
||||
- preds[:, batch["causal_idx"]],
|
||||
+ preds[torch.arange(batch["lengths"].shape[0]), batch["causal_idx"]],
|
||||
batch["causal_target"],
|
||||
)
|
||||
return l
|
||||
|
||||
Reference in New Issue
Block a user