1
- using Test, Random, Flux
1
+ using Test, Random, Flux, MAT
2
2
3
3
@testset " DeepONet" begin
4
4
@testset " dimensions" begin
@@ -14,4 +14,50 @@ using Test, Random, Flux
14
14
# Accept only Int as architecture parameters
15
15
@test_throws MethodError DeepONet ((32.5 ,64 ,72 ), (24 ,48 ,72 ), σ, tanh)
16
16
@test_throws MethodError DeepONet ((32 ,64 ,72 ), (24.1 ,48 ,72 ))
17
- end
17
+ end
18
+
# Smoke test on a small problem: checks the shape of the forward-pass output
# and that the gradients with respect to both inputs are not identically zero.
@testset "forward pass and gradients" begin
    # Just the first 16 datapoints from the Burgers' equation dataset
    a = [
        0.83541104, 0.83479851, 0.83404712, 0.83315711,
        0.83212979, 0.83096755, 0.82967374, 0.82825263,
        0.82670928, 0.82504949, 0.82327962, 0.82140651,
        0.81943734, 0.81737952, 0.8152405, 0.81302771,
    ]
    # 16 equally spaced points on [0, 1], laid out as a 1×16 row
    sensors = collect(range(0, 1, length=16))'

    model = DeepONet(
        (16, 22, 30), (1, 16, 24, 30), σ, tanh;
        init_branch=Flux.glorot_normal, bias_trunk=false
    )

    # forward pass
    @test size(model(a, sensors)) == (1, 16)

    # gradients: differentiate once and check the gradient of each input,
    # instead of recomputing the same gradient for every assertion
    grad_a, grad_sensors = Flux.Zygote.gradient((x, p) -> sum(model(x, p)), a, sensors)
    @test !iszero(grad_a)
    @test !iszero(grad_sensors)
end
35
+
# End-to-end training check: fit a DeepONet to the Burgers' equation data read
# from `burgerset.mat` and require a small mean absolute error on the samples
# held out for validation.
@testset "training" begin
    # Fix the seed so that the random weight initialisation, and with it the
    # error compared against the threshold below, is reproducible between runs.
    Random.seed!(0)

    vars = matread("burgerset.mat")

    # First 280 rows for training, last 20 rows for validation. The inputs
    # are transposed, the targets are used as stored.
    xtrain = vars["a"][1:280, :]'
    xval = vars["a"][end-19:end, :]'

    ytrain = vars["u"][1:280, :]
    yval = vars["u"][end-19:end, :]

    # 1024 equally spaced points on [0, 1], laid out as a 1×1024 row
    grid = collect(range(0, 1, length=1024))'
    model = DeepONet((1024, 1024, 1024), (1, 1024, 1024), gelu, gelu)

    learning_rate = 0.001
    opt = ADAM(learning_rate)

    parameters = params(model)

    # Mean squared error between the model output and the target
    loss(x, y, sensor) = Flux.Losses.mse(model(x, sensor), y)

    # Print the validation loss whenever the training callback fires
    evalcb() = @show(loss(xval, yval, grid))

    # The whole training set is passed as a single batch, so every epoch
    # performs exactly one optimisation step.
    data = [(xtrain, ytrain, grid)]
    Flux.@epochs 400 Flux.train!(loss, parameters, data, opt, cb=evalcb)

    ỹ = model(xval, grid)

    # Mean absolute error over every entry of the validation prediction
    err = yval .- ỹ
    mean_diff = sum(abs, err) / length(err)
    @test mean_diff < 0.4
end
0 commit comments