Skip to content

Commit 1a080ae

Browse files
committed
Fix metric accumulation (cast to float32)
1 parent 10e2ecc commit 1a080ae

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def loss_classifier(self, params, dropout_rng, batch: Batch, is_training: bool):
221221
"loss": (loss * batch_size, batch_size),
222222
"accuracy": (correct_pred.sum(), batch_size),
223223
}
224+
metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
224225

225226
return loss, metrics
226227

@@ -246,6 +247,7 @@ def loss_autoregressor(self, params, dropout_rng, batch: Batch, is_training: boo
246247

247248
batch_size = batch.patches.shape[0]
248249
metrics = {"loss": (loss * batch_size, batch_size)}
250+
metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
249251

250252
return loss, metrics
251253

0 commit comments

Comments
 (0)