Skip to content

Commit dc08835

Browse files
committed
updated to TorchSharp 0.103.0
1 parent ed7bf5d commit dc08835

File tree

5 files changed

+15
-14
lines changed

5 files changed

+15
-14
lines changed

app/MinGPTProgram.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Diagnostics;
1+
using System.Diagnostics;
42
using System.IO;
5-
using System.Linq;
63
using System.Text;
74

85
using LostTech.Torch.NN;
96

107
using ShellProgressBar;
118

12-
using TorchSharp;
13-
149
using static TorchSharp.torch;
1510
using static TorchSharp.torch.nn;
1611

@@ -29,7 +24,7 @@
2924
}
3025
});
3126

32-
byte[] itob = vocab.ToArray();
27+
byte[] itob = [.. vocab];
3328
var btoi = Enumerable.Range(0, vocab.Count).ToDictionary(i => itob[i], i => i);
3429
var gpt = new GPT(vocabularySize: vocab.Count,
3530
embeddingSize: 128,
@@ -111,6 +106,7 @@ double TrainOnFile(string filePath, ProgressBar parentProgressBar, out int batch
111106
ShowEstimatedDuration = true,
112107
};
113108
using var progressBar = parentProgressBar.Spawn(batches, "", displayOptions);
109+
var stopwatch = Stopwatch.StartNew();
114110
for (int batchIndex = 0; batchIndex < batches; batchIndex++) {
115111
using var _ = torch.NewDisposeScope();
116112
var (@in, @out) = GetBatch(batchIndex);
@@ -131,7 +127,10 @@ double TrainOnFile(string filePath, ProgressBar parentProgressBar, out int batch
131127
using var noGrad = no_grad();
132128
totalLoss += loss.detach().cpu().mean().ToDouble();
133129

134-
progressBar.Tick($"loss: {totalLoss / (batchIndex + 1):0.00}");
130+
int tokensPerSecond = (int)(step * batchSize / stopwatch.Elapsed.TotalSeconds);
131+
132+
progressBar.Tick($"loss: {totalLoss / (batchIndex + 1):0.00}" +
133+
$" {tokensPerSecond} tokens/s");
135134
}
136135

137136
return totalLoss;

app/app.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
</PropertyGroup>
1414

1515
<ItemGroup>
16-
<PackageReference Include="libtorch-cuda-11.7-win-x64" Version="1.13.0.1" />
1716
<PackageReference Include="ShellProgressBar" Version="5.2.0" />
18-
<PackageReference Include="System.Drawing.Common" Version="6.0.0" />
17+
<PackageReference Include="System.Drawing.Common" Version="8.0.8" />
18+
<PackageReference Include="TorchSharp-cuda-windows" Version="0.103.0" />
1919
</ItemGroup>
2020

2121
<ItemGroup>

src/GPT.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ public GPT(int vocabularySize, int blockSize, int blockCount, int embeddingSize,
5656
this.Register(out this.finalNorm, LayerNorm(new long[] { embeddingSize }));
5757
this.Register(out this.decoderHead, Linear(embeddingSize, vocabularySize, hasBias: false));
5858

59+
this.RegisterComponents();
60+
5961
this.apply(InitWeights);
6062
}
6163

src/MinGPT.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939

4040
<ItemGroup>
4141
<!-- The following is recommended for public projects -->
42-
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
43-
<PackageReference Include="TorchSharp" Version="0.99.6" />
42+
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
43+
<PackageReference Include="TorchSharp" Version="0.103.0" />
4444
</ItemGroup>
4545

4646
</Project>

test/MinGPT.Tests.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
</PropertyGroup>
1313

1414
<ItemGroup>
15-
<PackageReference Include="libtorch-cpu" Version="1.13.0.1" />
1615
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
17-
<PackageReference Include="System.Drawing.Common" Version="6.0.0" />
16+
<PackageReference Include="System.Drawing.Common" Version="8.0.8" />
17+
<PackageReference Include="TorchSharp-cuda-windows" Version="0.103.0" />
1818
<PackageReference Include="xunit" Version="2.6.3" />
1919
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.5">
2020
<PrivateAssets>all</PrivateAssets>

0 commit comments

Comments
 (0)