Skip to content

Commit 22b937b

Browse files
committed
Fixed bug about metric accumulation
1 parent 1a080ae commit 22b937b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

trainer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,9 @@ def loss_classifier(self, params, dropout_rng, batch: Batch, is_training: bool):
218218
correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
219219
batch_size = batch.patches.shape[0]
220220
metrics = {
221-
"loss": (loss * batch_size, batch_size),
222-
"accuracy": (correct_pred.sum(), batch_size),
221+
"loss": ((loss * batch_size).astype(jnp.float32), batch_size),
222+
"accuracy": (correct_pred.sum().astype(jnp.float32), batch_size),
223223
}
224-
metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
225224

226225
return loss, metrics
227226

@@ -246,8 +245,7 @@ def loss_autoregressor(self, params, dropout_rng, batch: Batch, is_training: boo
246245
loss = jnp.mean(loss)
247246

248247
batch_size = batch.patches.shape[0]
249-
metrics = {"loss": (loss * batch_size, batch_size)}
250-
metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
248+
metrics = {"loss": ((loss * batch_size).astype(jnp.float32), batch_size)}
251249

252250
return loss, metrics
253251

0 commit comments

Comments
 (0)