Questionable Zygote gradients for quantum optimal control problem

I’m testing out a gradient-based optimization of some (quantum) dynamical system using AD over a simulation with OrdinaryDiffEq. The “quantum” here really just means that the dynamical state is complex-valued, not real-valued as in standard optimal control.

See the minimal working example (MWE) below, also available as a standalone script mwe_zygote.jl.

There is also a tutorial notebook with the full context (intended for physics students, but probably understandable to anyone with a basic understanding of linear algebra if you don’t get caught up in the details).


using Pkg; Pkg.activate(temp=true)
Pkg.add(["Optimization", "OptimizationNLopt", "ComponentArrays", "UnPack", "OrdinaryDiffEq", "Zygote", "SciMLSensitivity", "Plots"])

using Optimization, OptimizationNLopt
using ComponentArrays: ComponentVector
using UnPack: @unpack
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity

const 𝕚 = 1im


E(t; E₀, t₁, t₂, a) =  (E₀/2) * (tanh(a*(t-t₁)) - tanh(a*(t-t₂)));

"""ODE function"""
function f(Ψ, p, t; sign="+", a=1000.0)
    ΔT₁, ΔT₂, ΔT₃, ϕ₁, ϕ₂, ϕ₃, E₀₁, E₀₂, E₀₃ = p
    T₁ = ΔT₁
    T₂ = ΔT₁ + ΔT₂
    T₃ = ΔT₁ + ΔT₂ + ΔT₃
    μ = (sign == "-" ? -1 : 1)
    E₁ = E(t; E₀=E₀₁, t₁=0.0, t₂=T₁, a)
    E₂ = E(t; E₀=E₀₂, t₁=T₁, t₂=T₂, a)
    E₃ = E(t; E₀=E₀₃, t₁=T₂, t₂=T₃, a)
    F = (-𝕚 * μ) * [  # -𝕚 * H  (RHS of Schrödinger Eq. rewritten as ODE)
              0.0           E₁ * exp(𝕚 * ϕ₁)   E₃ * exp(𝕚 * ϕ₃)
        E₁ * exp(-𝕚 * ϕ₁)         0.0          E₂ * exp(𝕚 * ϕ₂)
        E₃ * exp(-𝕚 * ϕ₃)   E₂ * exp(-𝕚 * ϕ₂)         0.0
    ]
    return F * Ψ
end

f₊(Ψ, p, t) = f(Ψ, p, t; sign="+", a=1000.0);
f₋(Ψ, p, t) = f(Ψ, p, t; sign="-", a=1000.0);
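As a cheap sanity check on the model (my own addition, not part of the original MWE), one can verify that the pulse shape behaves like a smoothed box and that the matrix inside f is Hermitian, so that -i μ H is anti-Hermitian and the exact evolution preserves the norm of Ψ. The helper `hamiltonian` below is hypothetical: it just copies the matrix from f without the -i μ prefactor, writing exp(iϕ) as cis(ϕ).

```julia
using LinearAlgebra

# Same smoothed box pulse as in the MWE: for large a, each tanh switches between -1 and +1
# over a time window of width ~1/a, so E ≈ E₀ on (t₁, t₂) and E ≈ 0 outside.
E(t; E₀, t₁, t₂, a) = (E₀/2) * (tanh(a*(t-t₁)) - tanh(a*(t-t₂)))

# Hypothetical helper: the Hamiltonian H(t) from f, without the (-i μ) prefactor.
function hamiltonian(t, p; a=1000.0)
    ΔT₁, ΔT₂, ΔT₃, ϕ₁, ϕ₂, ϕ₃, E₀₁, E₀₂, E₀₃ = p
    T₁, T₂, T₃ = ΔT₁, ΔT₁ + ΔT₂, ΔT₁ + ΔT₂ + ΔT₃
    E₁ = E(t; E₀=E₀₁, t₁=0.0, t₂=T₁, a)
    E₂ = E(t; E₀=E₀₂, t₁=T₁, t₂=T₂, a)
    E₃ = E(t; E₀=E₀₃, t₁=T₂, t₂=T₃, a)
    # cis(ϕ) == exp(im * ϕ)
    return ComplexF64[
        0.0             E₁ * cis(ϕ₁)    E₃ * cis(ϕ₃)
        E₁ * cis(-ϕ₁)   0.0             E₂ * cis(ϕ₂)
        E₃ * cis(-ϕ₃)   E₂ * cis(-ϕ₂)   0.0
    ]
end

p = [0.2, 0.4, 0.3, π, π, π, 5.0, 5.0, 5.0]   # same values as `guess` in the MWE
H = hamiltonian(0.1, p)

# Exact propagator for a short step with H held fixed. It is unitary because H is
# Hermitian, so the norm of the propagated state stays 1 (up to rounding).
U = exp(-im * 0.05 * H)
Ψ = U * ComplexF64[1, 0, 0]
```

In an actual ODE solution, a norm that drifts noticeably away from 1 is a quick hint that the solver tolerances are too loose.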

Loss function

function loss_zygote(x)

    Ψ₀ = ComplexF64[1, 0, 0];
    tspan = (0.0, 1.0)

    prob₊ = ODEProblem(f₊, Ψ₀, tspan, x)
    prob₋ = ODEProblem(f₋, Ψ₀, tspan, x)

    Ψ₊ = OrdinaryDiffEq.solve(prob₊, DP5(), verbose=true).u[end]
    Ψ₋ = OrdinaryDiffEq.solve(prob₋, DP5(), verbose=true).u[end]
    # XXX With verbose=true, this produces warnings!

    fid = (abs2(Ψ₊[1]) + abs2(Ψ₋[3])) / 2
    return 1 - fid
end

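For reference, here is the MWE written out as equations, transcribed from f and loss_zygote above (the notation is mine; the pulse on interval k is E_k(t) = E(t; E₀ₖ, T_{k-1}, T_k), with T₀ = 0 and T_k = ΔT₁ + … + ΔT_k):

```latex
\begin{aligned}
\frac{d\Psi_\pm}{dt} &= \mp i\, H(t)\, \Psi_\pm(t), \qquad \Psi_\pm(0) = (1, 0, 0)^{\mathsf T}, \qquad t \in [0, 1],\\[4pt]
H(t) &= \begin{pmatrix}
  0 & E_1(t)\, e^{i\phi_1} & E_3(t)\, e^{i\phi_3} \\
  E_1(t)\, e^{-i\phi_1} & 0 & E_2(t)\, e^{i\phi_2} \\
  E_3(t)\, e^{-i\phi_3} & E_2(t)\, e^{-i\phi_2} & 0
\end{pmatrix},\\[4pt]
E(t; E_0, t_1, t_2) &= \frac{E_0}{2}\Big[\tanh\big(a(t - t_1)\big) - \tanh\big(a(t - t_2)\big)\Big], \qquad a = 1000,\\[4pt]
J &= 1 - \frac{1}{2}\Big(\big|\Psi_{+,1}(1)\big|^2 + \big|\Psi_{-,3}(1)\big|^2\Big).
\end{aligned}
```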

Optimization Problem

guess = ComponentVector(
    ΔT=[0.2, 0.4, 0.3],
    ϕ=[π, π, π],
    E₀=[5.0, 5.0, 5.0]
)

prob_zygote = OptimizationProblem(
    OptimizationFunction((x, _)->loss_zygote(x), AutoZygote()),
    guess;
    lb=zero(guess),  # lower bound of zero for all parameters
    ub=ComponentVector(
        ΔT=[1.0, 1.0, 1.0],
        ϕ=[2π, 2π, 2π],
        E₀=[10.0, 10.0, 10.0]
    )
)

Running the solver

obtained_fidelities = Float64[];  # for keeping track of the fidelity in each iteration

function callback(state, loss_val)
    global obtained_fidelities
    fid = 1 - loss_val
    push!(obtained_fidelities, fid)
    print("Iteration: $(length(obtained_fidelities)), current fidelity $(round(fid; digits=4))\r")
    return false
end

res_zygote = Optimization.solve(prob_zygote, NLopt.LD_LBFGS(), maxiters=500, callback=callback)

Plotting the convergence

using Plots

plot(obtained_fidelities; marker=:cross, label="", xlabel="optimization iteration", ylabel="fidelity")

The resulting fidelity curve just doesn’t look right to me. I haven’t used NLopt’s LBFGS implementation much, but shouldn’t the convergence be monotonic? Unless the callback is also called during line-search iterations.

So I’m not sure I can trust that gradient… I’ve had problems with complex-valued computations in Zygote before, so maybe it’s that. There actually is a method for calculating this gradient exactly for quantum dynamics. I’ve only partially implemented it in QuantumGradientGenerators, but at some point I’ll be able to check the AD gradients against it.

But even beyond that, I’ve also tried the alternative AD backends, and Zygote is actually the only one that doesn’t crash. Not even AutoFiniteDiff() works, and in my previous experimentation with AD I’ve found FiniteDifferences to be pretty robust and a good way to test Zygote gradients.
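A finite-difference check of the kind described above can be sketched without any AD package. The names `fd_gradient` and `rel_err` are hypothetical helpers (not the API of FiniteDifferences.jl or FiniteDiff.jl), and the toy loss, whose gradient is known in closed form, stands in for `loss_zygote` and its AD gradient.

```julia
using LinearAlgebra

# Hypothetical helper: central-difference gradient of a scalar loss of a real vector.
function fd_gradient(loss, x::Vector{Float64}; h=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (loss(xp) - loss(xm)) / (2h)
    end
    return g
end

# Relative discrepancy between two gradients (0 means they agree exactly).
rel_err(g1, g2) = norm(g1 - g2) / max(norm(g1), norm(g2))

# Toy loss with a known gradient, standing in for an AD-computed one.
toy_loss(x) = sum(abs2, x) + x[1] * x[2]
toy_grad(x) = [2x[1] + x[2], 2x[2] + x[1], 2x[3]]

x = [1.0, 2.0, 3.0]
discrepancy = rel_err(fd_gradient(toy_loss, x), toy_grad(x))
```

For the MWE one would compare `fd_gradient(loss_zygote, collect(guess))` with the Zygote gradient in the same way; passing a plain `Vector` there is an assumption, but it should work since f only destructures p by position.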

There’s also the issue that I’m seeing a lot of

Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs

That makes me think that I’m using this wrong. Is there anything in the MWE that isn’t set up the way it’s supposed to be?

Oh, if I tweak the functional to use abs instead of abs2, the optimization completely goes off the rails (that could be Zygote getting the gradient wrong).

For what it’s worth, as you can see in the full tutorial notebook, the gradient-free optimization of this problem works fine. So the issue is definitely with the gradient or the way I’m passing that gradient to the optimizer.


First things first, you always have to try different tolerances and I don’t see that here. Is this still an issue at lower tolerances?

Next, did you try an Enzyme version?

If both of those have issues, could you try and isolate this to an MWE on just gradients of complex solves on a solve?

We do cover complex numbers with a few test cases:


First things first, you always have to try different tolerances and I don’t see that here. Is this still an issue at lower tolerances?

Very good point! I didn’t realize that the default tolerances were so loose. Increasing the precision definitely has a significant effect: if I set reltol = 1e-9, abstol = 1e-7, I get a much saner convergence curve,

which stays the same if I continue increasing the precision all the way to 1e-14.
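A quick way to see how much the solver tolerances matter, sketched on a hypothetical two-level toy problem with a known exact solution rather than on the MWE itself: solve at a few (reltol, abstol) pairs and compare the final state with the exact one. The first pair is what I understand the OrdinaryDiffEq defaults to be (reltol = 1e-3, abstol = 1e-6).

```julia
using OrdinaryDiffEq

# Hypothetical toy model: dΨ/dt = -i Ω σx Ψ with Ψ(0) = [1, 0].
# Its exact solution is Ψ(t) = [cos(Ωt), -i sin(Ωt)].
const σx = ComplexF64[0 1; 1 0]
rabi(Ψ, Ω, t) = -im * Ω * (σx * Ψ)   # out-of-place RHS, parameter p = Ω

Ω = 2.0
prob = ODEProblem(rabi, ComplexF64[1, 0], (0.0, 1.0), Ω)
Ψ_exact = ComplexF64[cos(Ω), -im * sin(Ω)]

# Largest deviation of the final state from the exact solution, per tolerance setting.
tolerances = [(1e-3, 1e-6), (1e-6, 1e-8), (1e-9, 1e-11)]
errors = map(tolerances) do (rtol, atol)
    sol = solve(prob, DP5(); reltol=rtol, abstol=atol)
    maximum(abs, sol.u[end] - Ψ_exact)
end
```

In the MWE, the same keywords would go into both `OrdinaryDiffEq.solve` calls in `loss_zygote`, as is done in `loss_simple` later in the thread.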

There’s still the non-monotonic convergence, but I think what is happening is that the optimizer is running into the box constraints. I don’t know if NLopt’s LBFGS is supposed to behave like that (I don’t know how it actually takes box constraints into account), but I suppose this could be correct. It would be nice to also have the option to use LBFGSB to double-check, since that has the bounds built into the method, and I don’t think it would show that kind of non-monotonic convergence.

Next, did you try an Enzyme version?

Yeah, OptimizationFunction((x, _)->loss_zygote(x), AutoEnzyme()) fails with

ERROR: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, Core.apply_type, 7, 6)

after a lot of Warning: TypeAnalysisDepthLimit.

If both of those have issues, could you try and isolate this to an MWE on just gradients of complex solves on a solve?

You mean a simple state-to-state optimization like this?

using LinearAlgebra  # provides ⋅ (dot)

function loss_simple(x)
    Ψ₀ = ComplexF64[1, 0, 0]
    Ψtgt = ComplexF64[0, 0, 1]
    tspan = (0.0, 1.0)
    prob = ODEProblem(f₋, Ψ₀, tspan, x)
    Ψ = OrdinaryDiffEq.solve(prob, DP5(), verbose=false, reltol = 1e-9, abstol = 1e-7).u[end]
    fid = abs2(Ψ ⋅ Ψtgt)
    return 1 - fid
end

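A small note on the `⋅` in loss_simple: it is `LinearAlgebra.dot`, which conjugates its first argument, so `Ψ ⋅ Ψtgt` is ⟨Ψ|Ψtgt⟩. Because abs2 discards the phase, the fidelity |⟨Ψtgt|Ψ⟩|² does not depend on the argument order. A quick check on two arbitrary vectors:

```julia
using LinearAlgebra

a = ComplexF64[1, 2im, 0]
b = ComplexF64[0, 1, 1im]

overlap_ab = a ⋅ b   # sum(conj.(a) .* b)
overlap_ba = b ⋅ a   # equals conj(a ⋅ b)
```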
That gives me a very crazy (or interesting) convergence curve,

where every other iteration the fidelity drops back to exactly zero. It turns out the reason for that is that the optimizer really wants to push the first control parameter (the duration of the first sub-pulse) to the lower bound of zero, and then the system just doesn’t evolve at all, so the resulting fidelity is exactly zero. So it seems like a good idea to put the lower bound not at zero but at, e.g., 0.1. That helps a lot.

There are still non-monotonic dips, but those again occur where the optimizer pushes against the constraints.
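In terms of the MWE, the tightened lower bound could look like the following. Only the value 0.1 for the durations comes from the text above; applying it to all three ΔT and keeping zero lower bounds for ϕ and E₀ are my assumptions.

```julia
using ComponentArrays: ComponentVector

# Lower bounds: keep every sub-pulse duration at least 0.1 (assumed for all three ΔT),
# so the optimizer cannot shrink a sub-pulse to nothing.
lb = ComponentVector(
    ΔT=[0.1, 0.1, 0.1],
    ϕ=[0.0, 0.0, 0.0],
    E₀=[0.0, 0.0, 0.0]
)
```

This would then be passed as `lb=lb` to `OptimizationProblem`, together with the upper bounds from the MWE.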

So I still can’t tell for absolutely sure that the gradients are good, but it seems to me that with the increased precision they’re probably fine, and it’s more the optimizer that’s being a bit wonky (to my taste). I’ll be able to tell for sure once I put a bit more work into QuantumGradientGenerators to extend it to parameterized control fields (it’s basically a very specialized version of forward-mode AD).

In the meantime, it would also be nice to have LBFGSB in Optimization.jl (#277), because I think that’ll be less “wonky” 🤞

Oh, one more thing: OptimizationFunction((x, _)->loss_simple(x), AutoFiniteDiff()) behaves substantially differently (it gets stuck for the more complicated control problem and finds a reasonable but different solution for the simple problem). So clearly the finite-difference gradients are quite different from Zygote’s. In the past, I’ve used FiniteDifferences to check Zygote for pretty simple functions, and they’ve always matched up pretty well. Is applying finite differences to an entire ODE solve asking too much?
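One plausible factor (an assumption on my part, not something established in this thread): an adaptive ODE solve only returns the loss to within roughly the solver tolerance, and a difference quotient divides that error by the step h. The toy below models such a loss as sin(x) rounded to a grid of spacing 1e-6 (a stand-in for solver error, not how OrdinaryDiffEq actually behaves) and differentiates it with a tiny and a moderate step.

```julia
# Central difference quotient with step h.
central_diff(f, x; h) = (f(x + h) - f(x - h)) / (2h)

# Hypothetical stand-in for a loss computed by an adaptive ODE solve: the exact value
# sin(x), but only accurate to about 1e-6 (modeled by rounding to a grid of spacing 1e-6).
noisy_sin(x) = round(sin(x) / 1e-6) * 1e-6

x = 0.5
exact_derivative = cos(x)
err_tiny_h = abs(central_diff(noisy_sin, x; h=1e-9) - exact_derivative)      # noise/h dominates
err_moderate_h = abs(central_diff(noisy_sin, x; h=1e-2) - exact_derivative)  # noise/h ≤ 5e-5
```

With the tiny step, both function values round to the same grid point, so the difference quotient comes out as zero instead of cos(0.5) ≈ 0.878; with h = 1e-2 the error is only a few times 1e-5. Tightening the solver tolerances shrinks the "noise" and makes finite differences more meaningful.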


There is also still the warning

┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:104

which I see a lot unless I suppress it with verbose = false. What does that mean, exactly? Should I be worried about these warnings?
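For context on what that warning trades away: a vector-Jacobian product (VJP) vᵀJ is the basic operation a reverse-mode adjoint needs from the ODE right-hand side. Going by the message text, SciMLSensitivity first tries AD-based VJPs (it mentions ZygoteVJP) and otherwise falls back to numerical ones. The hypothetical helper below shows one numerical way to get vᵀJ; it needs two evaluations of g per input dimension, instead of roughly one forward and one backward pass with reverse-mode AD, which is presumably the "performance improvement omitted". The exact fallback scheme inside SciMLSensitivity may differ.

```julia
using LinearAlgebra

# Hypothetical helper: vᵀ J(x) for g: ℝⁿ → ℝᵐ, with J approximated column by column
# by central differences. Costs 2n evaluations of g.
function numerical_vjp(g, x::Vector{Float64}, v::Vector{Float64}; h=1e-6)
    out = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x)); e[i] = h
        Jcol = (g(x + e) - g(x - e)) / (2h)   # ≈ ∂g/∂xᵢ, the i-th column of J
        out[i] = dot(v, Jcol)
    end
    return out
end

g(x) = [x[1] * x[2], sin(x[1])]
x = [1.0, 2.0]
v = [1.0, 1.0]
# Exact: J = [x₂ x₁; cos(x₁) 0], so vᵀJ = [x₂ + cos(x₁), x₁]
vjp = numerical_vjp(g, x, v)
```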

Yup, let’s get that in. We do want to get some new native BFGS implementations based on NonlinearSolve.jl soon, but we should get every single one we can wrapped anyways for benchmarking


Small followup:

I tried LBFGSB directly. If I record the “fidelity” for all evaluations of the loss function, I get this plot:

That’s remarkably similar to the first plot in my previous post, obtained with Optimization.jl and NLopt.LD_LBFGS().

However, this includes evaluations of the loss function during the line search! If I pull out the actual values of the loss function from LBFGSB’s internals (the “iterate information” with iprint=100), I get this plot for the fidelity (1 - loss):

which is nice and monotonic and exactly what I was expecting to see.

So this leads me to conclude that the callback function in Optimization.jl is called inside line-search iterations, not just after each iteration of the optimizer. That totally explains the non-monotonic results, but it’s not generally what I would expect from a callback function (or at least, it would be nice to have the option to choose whether I get a callback on all evaluations of the loss function or only on the “iterate values”). I’ve often used the callback to check for monotonic convergence, because if I’m not seeing monotonic convergence (of the iterate values), that’s usually an indicator that something is wrong in my numerics. That’s not going to work if Optimization.jl calls the callback inside the line search.
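The distinction can be reproduced with a toy example: plain gradient descent with a backtracking (Armijo) line search on a 1D quadratic. This is a hypothetical illustration, not NLopt’s or LBFGSB’s actual algorithm. Recording the loss at every trial point of the line search gives a non-monotonic sequence, while the losses at the accepted iterates decrease monotonically.

```julia
# Toy problem (hypothetical): minimize (x - 3)² starting from x = 0.
loss(x) = (x - 3.0)^2
grad(x) = 2 * (x - 3.0)

function descend(x; iters=5, step0=1.2, c=1e-4)
    trial_losses = Float64[]    # what a callback on every trial evaluation would see
    iterate_losses = [loss(x)]  # what a callback on accepted iterates only would see
    for _ in 1:iters
        g = grad(x)
        step = step0
        while true
            trial = loss(x - step * g)
            push!(trial_losses, trial)
            # Armijo condition: accept the step once it gives sufficient decrease
            if trial <= loss(x) - c * step * g^2
                x -= step * g
                break
            end
            step /= 2   # backtrack
        end
        push!(iterate_losses, loss(x))
    end
    return trial_losses, iterate_losses
end

trial_losses, iterate_losses = descend(0.0)
```

Here `trial_losses` starts at roughly 17.64, 0.36, 0.71, 0.014, … (each rejected full step overshoots), while `iterate_losses` runs 9, 0.36, 0.014, … and shrinks by a factor of about 25 per iteration.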

(This could be NLopt’s fault, too, not necessarily Optimization.jl’s.)

So the tentative conclusion seems to be that the gradients from Zygote are probably okay, but there are some rough edges in the frameworks.

Anyway, thanks! This has been instructive!

P.S.: I opened an issue: “The `callback` appears to be called for linesearch iterations” (Issue #724, SciML/Optimization.jl on GitHub).

Yeah, that’s implementation-specific. It’s worth an issue. I’m not sure we can easily solve this since we’re mostly using the callback interfaces of each library, so it will have the properties that the solver library gives us.
