I’m testing out gradient-based optimization of a (quantum) dynamical system using AD over a simulation with `OrdinaryDiffEq`. The “quantum” here really just means that the dynamical state is complex-valued, not real-valued as in standard optimal control.

See the minimal working example below, also available as a standalone script, mwe_zygote.jl.

There is also a tutorial notebook with the full context (intended for physics students, but probably understandable to anyone with a basic grasp of linear algebra, as long as you don’t get caught up in the details).

## Setup

```
using Pkg; Pkg.activate(temp=true)
Pkg.add(["Optimization", "OptimizationNLopt", "ComponentArrays", "UnPack", "OrdinaryDiffEq", "Zygote", "SciMLSensitivity", "Plots"])
using Optimization, OptimizationNLopt
using ComponentArrays: ComponentVector
using UnPack: @unpack
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity
const 𝕚 = 1im
```

## Model

```
E(t; E₀, t₁, t₂, a) = (E₀/2) * (tanh(a*(t-t₁)) - tanh(a*(t-t₂)));  # smooth box-shaped pulse

"""ODE function"""
function f(Ψ, p, t; sign="+", a=1000.0)
    ΔT₁, ΔT₂, ΔT₃, ϕ₁, ϕ₂, ϕ₃, E₀₁, E₀₂, E₀₃ = p
    # cumulative switching times of the three pulses
    T₁ = ΔT₁
    T₂ = ΔT₁ + ΔT₂
    T₃ = ΔT₁ + ΔT₂ + ΔT₃
    μ = (sign == "-" ? -1 : 1)
    E₁ = E(t; E₀=E₀₁, t₁=0.0, t₂=T₁, a)
    E₂ = E(t; E₀=E₀₂, t₁=T₁, t₂=T₂, a)
    E₃ = E(t; E₀=E₀₃, t₁=T₂, t₂=T₃, a)
    F = (-𝕚 * μ) * [  # -𝕚 * H (RHS of the Schrödinger Eq. rewritten as an ODE)
        0.0                E₁ * exp(𝕚 * ϕ₁)   E₃ * exp(𝕚 * ϕ₃)
        E₁ * exp(-𝕚 * ϕ₁)  0.0                E₂ * exp(𝕚 * ϕ₂)
        E₃ * exp(-𝕚 * ϕ₃)  E₂ * exp(-𝕚 * ϕ₂)  0.0
    ]
    return F * Ψ
end

f₊(Ψ, p, t) = f(Ψ, p, t; sign="+", a=1000.0);
f₋(Ψ, p, t) = f(Ψ, p, t; sign="-", a=1000.0);
```
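As an aside, here is a quick sanity check of the RHS that is not part of the MWE (`p_test` is just an arbitrary parameter vector I made up for illustration): since H is Hermitian, F = -𝕚H should conserve the norm of Ψ.

```
p_test = [0.2, 0.4, 0.3, π, π, π, 5.0, 5.0, 5.0]  # hypothetical flat parameters (ΔT, ϕ, E₀)
Ψ_test = ComplexF64[1, 𝕚, 0.5]                    # arbitrary (unnormalized) state
dΨ = f₊(Ψ_test, p_test, 0.5)
# d/dt ‖Ψ‖² = 2 Re⟨Ψ, FΨ⟩ should vanish for a Hermitian H:
@show real(Ψ_test' * dΨ)  # ≈ 0
```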

## Loss function

```
function loss_zygote(x)
    Ψ₀ = ComplexF64[1, 0, 0]
    tspan = (0.0, 1.0)
    prob₊ = ODEProblem(f₊, Ψ₀, tspan, x)
    prob₋ = ODEProblem(f₋, Ψ₀, tspan, x)
    Ψ₊ = OrdinaryDiffEq.solve(prob₊, DP5(), verbose=true).u[end]
    Ψ₋ = OrdinaryDiffEq.solve(prob₋, DP5(), verbose=true).u[end]
    # XXX With verbose=true, this produces warnings!
    fid = (abs2(Ψ₊[1]) + abs2(Ψ₋[3])) / 2
    return 1 - fid
end
```

## Optimization Problem

```
guess = ComponentVector(
    ΔT=[0.2, 0.4, 0.3],
    ϕ=[π, π, π],
    E₀=[5.0, 5.0, 5.0]
);

prob_zygote = OptimizationProblem(
    OptimizationFunction((x, _) -> loss_zygote(x), AutoZygote()),
    guess,
    nothing;
    lb=zeros(9),
    ub=ComponentVector(
        ΔT=[1.0, 1.0, 1.0],
        ϕ=[2π, 2π, 2π],
        E₀=[10.0, 10.0, 10.0]
    ),
    stopval=(1 - 0.99),
);
```
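Not part of the MWE, but for debugging, the loss and its Zygote gradient can also be evaluated once in isolation, before involving the optimizer. I’d expect something like this to work, at least for spotting NaNs or obviously wrong values:

```
@show loss_zygote(guess)
∇ = Zygote.gradient(loss_zygote, guess)[1]  # same shape as `guess`
@show ∇
```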

## Running the solver

```
obtained_fidelities = Float64[];  # for keeping track of the fidelity in each iteration

function callback(state, loss_val)
    global obtained_fidelities
    fid = 1 - loss_val
    push!(obtained_fidelities, fid)
    print("Iteration: $(length(obtained_fidelities)), current fidelity $(round(fid; digits=4))\r")
    return false
end

res_zygote = Optimization.solve(prob_zygote, NLopt.LD_LBFGS(), maxiters=500, callback=callback)
```
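If I read the Optimization.jl API right, the outcome can then be inspected via

```
res_zygote.objective  # final loss, i.e., 1 - fidelity
res_zygote.u          # optimized parameters (a ComponentVector like `guess`)
```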

## Plotting the convergence

```
using Plots
plot(obtained_fidelities; marker=:cross, label="", xlabel="optimization iteration", ylabel="fidelity")
```

That just doesn’t look right to me. I haven’t used NLopt’s `LBFGS` implementation much, but shouldn’t the convergence be monotonic? Unless the callback is also called during line-search iterations.

So I’m not sure I can trust that gradient… I’ve had problems with complex-valued computations in Zygote before, so maybe it’s that. There actually is a method for calculating this gradient exactly for quantum dynamics. I’ve only partially implemented that in `QuantumGradientGenerators`, but at some point I’ll be able to check these gradients against it.

But even beyond that, I’ve also tried the alternative auto-AD options, and `Zygote` is actually the only one that doesn’t crash. Not even `AutoFiniteDiff()` works, and in my previous experimentation with AD I’ve found FiniteDifferences to be pretty robust and a good way to test Zygote gradients.
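For reference, by “alternative auto-AD options” I mean swapping the backend in the `OptimizationFunction`, e.g.

```
# identical to prob_zygote, but with finite differences instead of Zygote;
# this is one of the variants that crashes for me
prob_fd = OptimizationProblem(
    OptimizationFunction((x, _) -> loss_zygote(x), AutoFiniteDiff()),
    guess,
    nothing;
    lb=prob_zygote.lb,
    ub=prob_zygote.ub,
)
```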

There’s also the issue that I’m seeing a lot of

```
Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
```

That makes me think that I’m using this wrong. Is there anything in the MWE that isn’t set up the way it’s supposed to be?
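I’m also wondering whether I should be passing an explicit `sensealg` to `solve` inside the loss function instead of relying on the automatic choice, e.g. something like this (just a guess on my part; I haven’t verified that this is the appropriate adjoint here):

```
# inside loss_zygote, with an explicit adjoint-sensitivity choice from SciMLSensitivity
Ψ₊ = OrdinaryDiffEq.solve(
    prob₊, DP5();
    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),
).u[end]
```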

Oh, and if I tweak the functional to use `abs` instead of `abs2`, the optimization completely goes off the rails (which could be Zygote getting the gradient wrong).
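That is, with the fidelity line in `loss_zygote` changed to

```
fid = (abs(Ψ₊[1]) + abs(Ψ₋[3])) / 2  # abs instead of abs2
```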

For what it’s worth, as you can see in the full tutorial notebook, a gradient-free optimization of this problem works fine. So the issue is definitely with the gradient itself or with the way I’m passing that gradient to the optimizer.