Skip to content

Commit 9f1084e

Browse files
authored
Merge pull request #909 from utsavrai/main
Bug Fix: Convert tensors to scalars for plotting compatibility
2 parents 65ed32c + e642eb2 commit 9f1084e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

04_pytorch_custom_datasets.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,10 +2498,11 @@
24982498
" )\n",
24992499
"\n",
25002500
" # 5. Update results dictionary\n",
2501-
" results[\"train_loss\"].append(train_loss)\n",
2502-
" results[\"train_acc\"].append(train_acc)\n",
2503-
" results[\"test_loss\"].append(test_loss)\n",
2504-
" results[\"test_acc\"].append(test_acc)\n",
2501+
" # Ensure all data is moved to CPU and converted to float for storage\n",
2502+
" results[\"train_loss\"].append(train_loss.item() if isinstance(train_loss, torch.Tensor) else train_loss)\n",
2503+
" results[\"train_acc\"].append(train_acc.item() if isinstance(train_acc, torch.Tensor) else train_acc)\n",
2504+
" results[\"test_loss\"].append(test_loss.item() if isinstance(test_loss, torch.Tensor) else test_loss)\n",
2505+
" results[\"test_acc\"].append(test_acc.item() if isinstance(test_acc, torch.Tensor) else test_acc)\n",
25052506
"\n",
25062507
" # 6. Return the filled results at the end of the epochs\n",
25072508
" return results"

0 commit comments

Comments
 (0)