Remove some old debugging code

This commit is contained in:
Kevin Wu
2022-10-03 13:52:13 -07:00
parent 7991fc4a42
commit 9be2cb6b87

View File

@@ -443,14 +443,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
# pl.utilities.rank_zero_debug("Train status", self.training)
# pl.utilities.rank_zero_debug(
# "Inputs", inputs.device, timestep.device, attention_mask.device
# )
# pl.utilities.rank_zero_debug(
# "Inputs", inputs.dtype, timestep.dtype, attention_mask.dtype
# )
output_attentions = (
output_attentions
if output_attentions is not None
@@ -482,10 +474,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
.type_as(timestep)
)
# pl.utilities.rank_zero_debug(
# "Position IDs", position_ids.device, position_ids.dtype
# )
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads. This code is taken
# from huggingface modeling_utils
@@ -598,7 +586,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
Training step, runs once per batch
"""
loss_terms = self._get_loss_terms(batch)
# pl.utilities.rank_zero_debug("Training stacked and gathered loss", loss_terms)
avg_loss = torch.mean(loss_terms)
# L1 loss implementation