Skip to content

Commit d95bf63

Browse files
Add support for Enzyme autodiff (#531)
1 parent fab15f7 commit d95bf63

File tree

4 files changed

+123
-74
lines changed

4 files changed

+123
-74
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)
99

1010
- Improve Bloch sphere rendering for animation. ([#520])
11+
- Add support to `Enzyme.jl` for `sesolve` and `mesolve`. ([#531])
1112

1213
## [v0.34.0]
1314
Release date: 2025-07-29
@@ -295,3 +296,4 @@ Release date: 2024-11-13
295296
[#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515
296297
[#517]: https://github.com/qutip/QuantumToolbox.jl/issues/517
297298
[#520]: https://github.com/qutip/QuantumToolbox.jl/issues/520
299+
[#531]: https://github.com/qutip/QuantumToolbox.jl/issues/531

src/time_evolution/mesolve.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ function mesolve(
178178
kwargs...,
179179
)
180180

181+
# Move sensealg argument to solve for Enzyme.jl support.
182+
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
183+
sensealg = get(kwargs, :sensealg, nothing)
184+
kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,))
185+
181186
prob = mesolveProblem(
182187
H,
183188
ψ0,
@@ -188,14 +193,19 @@ function mesolve(
188193
params = params,
189194
progress_bar = progress_bar,
190195
inplace = inplace,
191-
kwargs...,
196+
kwargs_filtered...,
192197
)
193198

194-
return mesolve(prob, alg)
199+
# TODO: Remove sensealg when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed
200+
if isnothing(sensealg)
201+
return mesolve(prob, alg)
202+
else
203+
return mesolve(prob, alg; sensealg = sensealg)
204+
end
195205
end
196206

197-
function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
198-
sol = solve(prob.prob, alg)
207+
function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
208+
sol = solve(prob.prob, alg; kwargs...)
199209

200210
# No type instabilities since `isoperket` is a Val, and so it is known at compile time
201211
if getVal(prob.kwargs.isoperket)

src/time_evolution/sesolve.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ function sesolve(
135135
inplace::Union{Val,Bool} = Val(true),
136136
kwargs...,
137137
)
138+
139+
# Move sensealg argument to solve for Enzyme.jl support.
140+
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
141+
sensealg = get(kwargs, :sensealg, nothing)
142+
kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,))
143+
138144
prob = sesolveProblem(
139145
H,
140146
ψ0,
@@ -143,14 +149,19 @@ function sesolve(
143149
params = params,
144150
progress_bar = progress_bar,
145151
inplace = inplace,
146-
kwargs...,
152+
kwargs_filtered...,
147153
)
148154

149-
return sesolve(prob, alg)
155+
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
156+
if isnothing(sensealg)
157+
return sesolve(prob, alg)
158+
else
159+
return sesolve(prob, alg; sensealg = sensealg)
160+
end
150161
end
151162

152-
function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
153-
sol = solve(prob.prob, alg)
163+
function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
164+
sol = solve(prob.prob, alg; kwargs...)
154165

155166
ψt = map-> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u)
156167

Lines changed: 92 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,77 @@
1-
@testset "Autodiff" verbose=true begin
2-
@testset "sesolve" verbose=true begin
3-
ψ0 = fock(2, 1)
4-
t_max = 10
5-
tlist = range(0, t_max, 100)
1+
# ---- SESOLVE ----
2+
const ψ0_sesolve = fock(2, 1)
3+
t_max = 10
4+
const tlist_sesolve = range(0, t_max, 100)
65

7-
# For direct Forward differentiation
8-
function my_f_sesolve_direct(p)
9-
H = p[1] * sigmax()
10-
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))
6+
# For direct Forward differentiation
7+
function my_f_sesolve_direct(p)
8+
H = p[1] * sigmax()
9+
sol = sesolve(H, ψ0_sesolve, tlist_sesolve, progress_bar = Val(false))
1110

12-
return real(expect(projection(2, 0, 0), sol.states[end]))
13-
end
11+
return real(expect(projection(2, 0, 0), sol.states[end]))
12+
end
1413

15-
# For SciMLSensitivity.jl
16-
coef_Ω(p, t) = p[1]
17-
H_evo = QobjEvo(sigmax(), coef_Ω)
18-
19-
function my_f_sesolve(p)
20-
sol = sesolve(
21-
H_evo,
22-
ψ0,
23-
tlist,
24-
progress_bar = Val(false),
25-
params = p,
26-
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
27-
)
28-
29-
return real(expect(projection(2, 0, 0), sol.states[end]))
30-
end
14+
# For SciMLSensitivity.jl
15+
coef_Ω(p, t) = p[1]
16+
const H_evo = QobjEvo(sigmax(), coef_Ω)
17+
18+
function my_f_sesolve(p)
19+
sol = sesolve(
20+
H_evo,
21+
ψ0_sesolve,
22+
tlist_sesolve,
23+
progress_bar = Val(false),
24+
params = p,
25+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
26+
)
27+
28+
return real(expect(projection(2, 0, 0), sol.states[end]))
29+
end
3130

32-
# Analytical solution
33-
my_f_analytic(Ω) = abs2(sin* t_max))
34-
my_f_analytic_deriv(Ω) = 2 * t_max * sin* t_max) * cos* t_max)
31+
# Analytical solution
32+
my_f_analytic(Ω) = abs2(sin* t_max))
33+
my_f_analytic_deriv(Ω) = 2 * t_max * sin* t_max) * cos* t_max)
34+
35+
# ---- MESOLVE ----
36+
const N = 20
37+
const a = destroy(N)
38+
const ψ0_mesolve = fock(N, 0)
39+
const tlist_mesolve = range(0, 40, 100)
40+
41+
# For direct Forward differentiation
42+
function my_f_mesolve_direct(p)
43+
H = p[1] * a' * a + p[2] * (a + a')
44+
c_ops = [sqrt(p[3]) * a]
45+
sol = mesolve(H, ψ0_mesolve, tlist_mesolve, c_ops, progress_bar = Val(false))
46+
return real(expect(a' * a, sol.states[end]))
47+
end
48+
49+
# For SciMLSensitivity.jl
50+
coef_Δ(p, t) = p[1]
51+
coef_F(p, t) = p[2]
52+
coef_γ(p, t) = sqrt(p[3])
53+
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
54+
c_ops = [QobjEvo(a, coef_γ)]
55+
const L = liouvillian(H, c_ops)
56+
57+
function my_f_mesolve(p)
58+
sol = mesolve(
59+
L,
60+
ψ0_mesolve,
61+
tlist_mesolve,
62+
progress_bar = Val(false),
63+
params = p,
64+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
65+
)
66+
67+
return real(expect(a' * a, sol.states[end]))
68+
end
3569

70+
# Analytical solution
71+
n_ss(Δ, F, γ) = abs2(F /+ 1im * γ / 2))
72+
73+
@testset "Autodiff" verbose=true begin
74+
@testset "sesolve" verbose=true begin
3675
Ω = 1.0
3776
params = [Ω]
3877

@@ -52,46 +91,21 @@
5291

5392
@test grad_qt grad_exact atol=1e-6
5493
end
55-
end
5694

57-
@testset "mesolve" verbose=true begin
58-
N = 20
59-
a = destroy(N)
60-
ψ0 = fock(N, 0)
61-
tlist = range(0, 40, 100)
62-
63-
# For direct Forward differentiation
64-
function my_f_mesolve_direct(p)
65-
H = p[1] * a' * a + p[2] * (a + a')
66-
c_ops = [sqrt(p[3]) * a]
67-
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
68-
return real(expect(a' * a, sol.states[end]))
69-
end
95+
@testset "Enzyme.jl" begin
96+
dparams = Enzyme.make_zero(params)
97+
Enzyme.autodiff(
98+
Enzyme.set_runtime_activity(Enzyme.Reverse),
99+
my_f_sesolve,
100+
Active,
101+
Duplicated(params, dparams),
102+
)[1]
70103

71-
# For SciMLSensitivity.jl
72-
coef_Δ(p, t) = p[1]
73-
coef_F(p, t) = p[2]
74-
coef_γ(p, t) = sqrt(p[3])
75-
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
76-
c_ops = [QobjEvo(a, coef_γ)]
77-
L = liouvillian(H, c_ops)
78-
79-
function my_f_mesolve(p)
80-
sol = mesolve(
81-
L,
82-
ψ0,
83-
tlist,
84-
progress_bar = Val(false),
85-
params = p,
86-
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
87-
)
88-
89-
return real(expect(a' * a, sol.states[end]))
104+
@test dparams grad_exact atol=1e-6
90105
end
106+
end
91107

92-
# Analytical solution
93-
n_ss(Δ, F, γ) = abs2(F /+ 1im * γ / 2))
94-
108+
@testset "mesolve" verbose=true begin
95109
Δ = 1.0
96110
F = 1.0
97111
γ = 1.0
@@ -111,5 +125,17 @@
111125
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]
112126
@test grad_qt grad_exact atol=1e-6
113127
end
128+
129+
@testset "Enzyme.jl" begin
130+
dparams = Enzyme.make_zero(params)
131+
Enzyme.autodiff(
132+
Enzyme.set_runtime_activity(Enzyme.Reverse),
133+
my_f_mesolve,
134+
Active,
135+
Duplicated(params, dparams),
136+
)[1]
137+
138+
@test dparams grad_exact atol=1e-6
139+
end
114140
end
115141
end

0 commit comments

Comments
 (0)