Skip to content

Commit 42ac608

Browse files
committed
Avoid using dataset.save() and .restore()
1 parent 509798c commit 42ac608

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def parse_args():
282282
load_dataset(get_val_files(), args.patch_size, validation_transformations)
283283
.batch(args.batch_size, drop_remainder=True)
284284
.prefetch(tf.data.AUTOTUNE)
285-
.as_numpy_iterator()
286285
)
287286

288287
args.dummy_batch = next(iter(train_ds))

trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,9 @@ def train_model(self, train_loader, val_loader):
378378
if idx > first_step and (idx + 1) % self.eval_every_n_steps == 0:
379379
self.save_checkpoint(step=idx + 1)
380380
# Evaluate the model and return the eval metrics
381-
ckpt_val_loader = val_loader.save()
382381
eval_metrics = self.eval_model(
383-
eval_fn, self.state, prefetch(val_loader)
382+
eval_fn, self.state, prefetch(val_loader.as_numpy_iterator())
384383
)
385-
val_loader.restore(ckpt_val_loader)
386384

387385
for metric_name in eval_metrics.keys():
388386
metric_value = jax.device_get(eval_metrics)[metric_name][0]

0 commit comments

Comments
 (0)