
Commit 1f43ea8

Merge pull request #756 from ZhouLong0/Modified
Added 'getting Pytorch to run on Apple Arm GPU' guide
2 parents c60e58f + 1c585d1 commit 1f43ea8

File tree

1 file changed: +83 −3 lines

00_pytorch_fundamentals.ipynb

Lines changed: 83 additions & 3 deletions
@@ -3501,6 +3501,86 @@
     "Knowing the number of GPUs PyTorch has access to is helpful in case you wanted to run a specific process on one GPU and another process on another (PyTorch also has features to let you run a process across *all* GPUs)."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "\n",
+    "### 2.1 Getting PyTorch to run on Apple ARM GPUs\n",
+    "\n",
+    "To run PyTorch on Apple's M1/M2 GPUs, you can use the [`torch.backends.mps`](https://pytorch.org/docs/stable/notes/mps.html) backend.\n",
+    "\n",
+    "Make sure your macOS and PyTorch versions are up to date.\n",
+    "\n",
+    "You can test whether PyTorch has access to the GPU using `torch.backends.mps.is_available()`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Check for ARM GPU\n",
+    "import torch\n",
+    "torch.backends.mps.is_available()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'mps'"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Set device type\n",
+    "device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
+    "device"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As before, if the above outputs `\"mps\"`, it means we can set all of our PyTorch code to use the available Apple ARM GPU."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "    device = 'cuda'\n",
+    "elif torch.backends.mps.is_available():\n",
+    "    device = 'mps'\n",
+    "else:\n",
+    "    device = 'cpu'"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -3524,7 +3604,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 74,
+   "execution_count": 9,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -3543,10 +3623,10 @@
     {
      "data": {
       "text/plain": [
-       "tensor([1, 2, 3], device='cuda:0')"
+       "tensor([1, 2, 3], device='mps:0')"
       ]
      },
-     "execution_count": 74,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
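
The updated output above (`tensor([1, 2, 3], device='mps:0')`) comes from moving a tensor onto the selected device. A minimal sketch of that step, assuming the device-agnostic cell added earlier in this commit has already run; the variable names are illustrative, not taken from the notebook:

```python
import torch

# Pick the best available device, mirroring the device-agnostic cell added in this commit
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

tensor = torch.tensor([1, 2, 3])      # tensors are created on the CPU by default
tensor_on_device = tensor.to(device)  # copy the tensor to the chosen device
print(tensor_on_device)               # e.g. tensor([1, 2, 3], device='mps:0') on Apple Silicon
```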
