Skip to content

Commit 3a2a79c

Browse files
Fix Zygote autodiff errors
1 parent f0693e3 commit 3a2a79c

File tree

5 files changed

+20
-10
lines changed

5 files changed

+20
-10
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ function mcsolve(
414414
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
415415
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
416416

417+
kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
418+
417419
return TimeEvolutionMCSol(
418420
ntraj,
419421
ens_prob_mc.times,
@@ -424,7 +426,7 @@ function mcsolve(
424426
col_which,
425427
sol.converged,
426428
_sol_1.alg,
427-
_sol_1.prob.kwargs[:abstol],
428-
_sol_1.prob.kwargs[:reltol],
429+
kwargs.abstol,
430+
kwargs.reltol,
429431
)
430432
end

src/time_evolution/mesolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,16 @@ function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
204204
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u)
205205
end
206206

207+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
208+
207209
return TimeEvolutionSol(
208210
prob.times,
209211
sol.t,
210212
ρt,
211213
_get_expvals(sol, SaveFuncMESolve),
212214
sol.retcode,
213215
sol.alg,
214-
sol.prob.kwargs[:abstol],
215-
sol.prob.kwargs[:reltol],
216+
kwargs.abstol,
217+
kwargs.reltol,
216218
)
217219
end

src/time_evolution/sesolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,16 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
154154

155155
ψt = map-> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u)
156156

157+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
158+
157159
return TimeEvolutionSol(
158160
prob.times,
159161
sol.t,
160162
ψt,
161163
_get_expvals(sol, SaveFuncSESolve),
162164
sol.retcode,
163165
sol.alg,
164-
sol.prob.kwargs[:abstol],
165-
sol.prob.kwargs[:reltol],
166+
kwargs.abstol,
167+
kwargs.reltol,
166168
)
167169
end

src/time_evolution/smesolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ function smesolve(
426426
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol))
427427
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) # Stack on dimension 2 to align with QuTiP
428428

429+
kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
430+
429431
return TimeEvolutionStochasticSol(
430432
ntraj,
431433
ens_prob.times,
@@ -435,7 +437,7 @@ function smesolve(
435437
m_expvals, # Measurement expectation values
436438
sol.converged,
437439
_sol_1.alg,
438-
_sol_1.prob.kwargs[:abstol],
439-
_sol_1.prob.kwargs[:reltol],
440+
kwargs.abstol,
441+
kwargs.reltol,
440442
)
441443
end

src/time_evolution/ssesolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ function ssesolve(
418418
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
419419
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
420420

421+
kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
422+
421423
return TimeEvolutionStochasticSol(
422424
ntraj,
423425
ens_prob.times,
@@ -427,7 +429,7 @@ function ssesolve(
427429
m_expvals, # Measurement expectation values
428430
sol.converged,
429431
_sol_1.alg,
430-
_sol_1.prob.kwargs[:abstol],
431-
_sol_1.prob.kwargs[:reltol],
432+
kwargs.abstol,
433+
kwargs.reltol,
432434
)
433435
end

0 commit comments

Comments
 (0)