Skip to content

Commit bdc4dac

Browse files
committed
added project.toml
1 parent 7e9f9c2 commit bdc4dac

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[deps]
2+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3+
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
4+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/deeponet.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, Random, Flux
1+
using Test, Random, Flux, MAT
22

33
@testset "DeepONet" begin
44
@testset "dimensions" begin
@@ -14,4 +14,50 @@ using Test, Random, Flux
1414
# Accept only Int as architecture parameters
1515
@test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
1616
@test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
17-
end
17+
end
18+
19+
#Just the first 16 datapoints from the Burgers' equation dataset
20+
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771]
21+
sensors = collect(range(0, 1, length=16))'
22+
23+
model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
24+
25+
model(a,sensors)
26+
27+
#forward pass
28+
@test size(model(a, sensors)) == (1, 16)
29+
30+
mgrad = Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)
31+
32+
#gradients
33+
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[1])
34+
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[2])
35+
36+
#training
37+
vars = matread("burgerset.mat")
38+
39+
xtrain = vars["a"][1:280, :]'
40+
xval = vars["a"][end-19:end, :]'
41+
42+
ytrain = vars["u"][1:280, :]
43+
yval = vars["u"][end-19:end, :]
44+
45+
grid = collect(range(0, 1, length=1024))'
46+
model = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu)
47+
48+
learning_rate = 0.001
49+
opt = ADAM(learning_rate)
50+
51+
parameters = params(model)
52+
53+
loss(xtrain,ytrain,sensor) = Flux.Losses.mse(model(xtrain,sensor),ytrain)
54+
55+
evalcb() = @show(loss(xval,yval,grid))
56+
57+
Flux.@epochs 400 Flux.train!(loss, parameters, [(xtrain,ytrain,grid)], opt, cb = evalcb)
58+
59+
= model(xval, grid)
60+
61+
diffvec = vec(abs.((yval .- ỹ)))
62+
mean_diff = sum(diffvec)/length(diffvec)
63+
@test mean_diff < 0.4

0 commit comments

Comments
 (0)