Sensitivity of an ODEProblem defined by another package (QuantumToolbox)

I am interested in calculating derivatives of a quantum state defined by QuantumToolbox with respect to system parameters, in order to calculate the quantum Fisher information. I have had initial success with a finite-difference method: solve two copies of the system with slightly different parameters and difference the results. Here’s how that looks:

using QuantumToolbox

function final_state(A)
    #Example from QT docs

    N = 4 # Fock space truncated dimension

    ωa = 1
    ωc = 1 * ωa    # considering cavity and atom are in resonance
    σz = sigmaz() ⊗ qeye(N) # order of tensor product should be consistent throughout
    a  = qeye(2)  ⊗ destroy(N)  
    Ω  = 0.05
    σ  = sigmam() ⊗ qeye(N)

    Ha = ωa / 2 * σz
    Hc = ωc * a' * a # the symbol `'` after a `QuantumObject` act as adjoint
    Hint = Ω * (σ * a' + σ' * a)
    Hdrive = (a' + a)

    omega = 2

    coef(p, t) = A*sin(omega*t)

    H_t = QobjEvo(Hdrive, coef)

    Htot  = Ha + Hc + Hint +H_t
    e_ket = basis(2,0) 
    ψ0 = e_ket ⊗ fock(N, 0)

    tlist = 0:2.5:1000 # a list of time points of interest

    κ = 0.3
    L = sqrt(κ) * a

    sol_me  = mesolve(Htot,  ψ0, tlist,[L])
    return sol_me.states[end]
end

# One-sided (forward) finite difference of the final state with respect to A
function derivative(A, dA)
    rho1 = final_state(A).data
    rho2 = final_state(A + dA).data

    return (rho1, rho2, (rho2 - rho1) / dA)
end

(rho1, rho2, rhodot) = derivative(3,0.01)
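
For context on where the derivative is headed, here is a minimal sketch of turning a state and its parameter derivative into a quantum Fisher information value, using the standard eigen-decomposition formula for mixed states. The names `central_derivative`, `qfi`, and `toy_state` are illustrative and not part of QuantumToolbox, and the toy state is a stand-in for `final_state(A).data` so the block runs on its own. The central difference is swapped in for the one-sided difference above because its error is second order in the step.

```julia
using LinearAlgebra

# Central difference: error is O(dA^2), versus O(dA) for the one-sided difference.
central_derivative(f, A, dA) = (f(A + dA) - f(A - dA)) / (2dA)

# Quantum Fisher information from rho and drho = ∂rho/∂A, using
#   F = 2 * Σ_{i,j} |<i|drho|j>|^2 / (λ_i + λ_j),  summed over λ_i + λ_j > tol,
# where λ_i, |i> are the eigenvalues and eigenvectors of rho.
function qfi(rho, drho; tol = 1e-10)
    vals, vecs = eigen(Hermitian(Matrix(rho)))
    d = vecs' * Matrix(drho) * vecs      # drho expressed in the eigenbasis of rho
    F = 0.0
    for i in eachindex(vals), j in eachindex(vals)
        s = vals[i] + vals[j]
        if s > tol
            F += 2 * abs2(d[i, j]) / s
        end
    end
    return F
end

# Toy stand-in for final_state(A).data: a pure qubit state rotated by angle θ.
function toy_state(θ)
    ψ = ComplexF64[cos(θ / 2), sin(θ / 2)]
    return ψ * ψ'
end

drho = central_derivative(toy_state, 0.4, 1e-4)
F = qfi(toy_state(0.4), drho)   # analytic value for this toy state is 1
```

With the names from the code above, the call would be `qfi(rho1, rhodot)`. The cutoff `tol` is an assumption of this sketch; for nearly pure states the result can be sensitive to it, since small eigenvalue sums appear in the denominator.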

However, this is not going to scale well, and I see that there are lots of existing tools to solve this kind of problem. I found ODEForwardSensitivityProblem, and the time series of derivatives provided by

x,dp = extract_local_sensitivities(sol)

is exactly the functionality I am interested in. My first attempt was to create a new function within QuantumToolbox to expose the ODEProblem and construct a sensitivity problem from it, but I wonder if there is a way to perform the sensitivity analysis from the ‘outside’, without modifying QuantumToolbox. From what I can see, many of the automatic differentiation (AD) libraries can do this kind of thing by differentiating arbitrary Julia code. One problem I’ve run into in a few attempts to use an AD library is that the ODEProblem constructed by mesolve() is complex-valued, and the various AD libraries have been upset by the complex numbers.

What is the best way to get the functionality of extract_local_sensitivities() from a function which has not been created as an ODEForwardSensitivityProblem, and which is complex-valued?

If you use Zygote then complex numbers should be fine, and the adjoint/forward sensitivity overloads are fine too. You might need to use a forward/adjoint sensitivity with a numerical vjp/jvp, though I would actually be surprised, since Enzyme and Zygote handle complex numbers fine. It might be best to just make an MWE for SciMLSensitivity and I can work through it.

There are already a few test cases in the repo covering complex numbers:

Thanks for checking this out. As you guessed, the complex numbers were not the real issue.

Since QuantumToolbox wraps both the ODEProblem construction and solution within package-specific functions, I now think that the best solution will treat mesolve() as a black box and ignore the fact that mesolve() constructs and solves an ODEProblem internally.

Differentiate ‘through’ QuantumToolbox with FiniteDiff

I wrapped the time evolution into a function that takes only the parameters, and then tried to calculate the Jacobian with several DifferentiationInterface backends. The good news: FiniteDiff works, so one path forward is to finite-difference the whole trajectory.

using QuantumToolbox
using DifferentiationInterface
using SciMLSensitivity
using FiniteDiff

function final_state(p)
    f1(p, t) = p[1] * cos(p[2] * t)
    f2(p, t) = p[3] * sin(p[4] * t)
    γ(p, t)  = sqrt(p[5] * exp(-p[6] * t))

    H_t = sigmaz() + QobjEvo(sigmax(), f1) + QobjEvo(sigmay(), f2)

    c_ops = [
        QobjEvo(destroy(2), γ)
    ]

    ψ0 = basis(2, 0)
    tlist = 0:2.5:100

    sol = mesolve(H_t, ψ0, tlist, c_ops; params = p)
    return sol.states[end].data #'data' is a field of a 'Qobj' which stores its matrix representation
end

jac = DifferentiationInterface.jacobian(final_state,AutoFiniteDiff(),[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

Attempt at autodifferentiation with Zygote

If I switch the backend to Zygote and try

jac = DifferentiationInterface.jacobian(final_state,AutoZygote(),[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

I get


┌ Warning: Using fallback BLAS replacements for (["zgemv_64_", "zhemv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:3649
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/concrete_solve.jl:68

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/concrete_solve.jl:208
ERROR: MethodError: no method matching setvjp(::GaussAdjoint{0, false, Val{:central}, Bool}, ::ReverseDiffVJP{false})
The function `setvjp` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  setvjp(::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK}, ::Any) where {CS, AD, FDT, VJP, LS, LK}
   @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/sensitivity_algorithms.jl:1096
  setvjp(::GaussAdjoint{CS, AD, FDT, Nothing}, ::Any) where {CS, AD, FDT}
   @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/sensitivity_algorithms.jl:574
  setvjp(::QuadratureAdjoint{CS, AD, FDT, Nothing}, ::Any) where {CS, AD, FDT}
   @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/sensitivity_algorithms.jl:479
  ...

Stacktrace:
  [1] _concrete_solve_adjoint(::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEqTsit5.Tsit5{…}, ::Nothing, ::Vector{…}, ::Vector{…}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/KLsRs/src/concrete_solve.jl:277
  [2] _concrete_solve_adjoint
    @ ~/.julia/packages/SciMLSensitivity/KLsRs/src/concrete_solve.jl:246 [inlined]
  [3] #_solve_adjoint#67
    @ ~/.julia/packages/DiffEqBase/5hvMq/src/solve.jl:1685 [inlined]
  [4] _solve_adjoint
    @ ~/.julia/packages/DiffEqBase/5hvMq/src/solve.jl:1658 [inlined]
  [5] #rrule#4
    @ ~/.julia/packages/DiffEqBase/5hvMq/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
  [6] rrule
    @ ~/.julia/packages/DiffEqBase/5hvMq/ext/DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
  [7] rrule
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:138 [inlined]
  [8] chain_rrule
    @ ~/.julia/packages/Zygote/wfLOG/src/compiler/chainrules.jl:234 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context{…}, ::typeof(DiffEqBase.solve_up), ::SciMLBase.ODEProblem{…}, ::Nothing, ::Vector{…}, ::Vector{…}, ::OrdinaryDiffEqTsit5.Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:81
 [11] _apply
    @ ./boot.jl:946 [inlined]
 [12] adjoint
    @ ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:199 [inlined]
 [13] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [14] #solve#43
    @ ~/.julia/packages/DiffEqBase/5hvMq/src/solve.jl:1089 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#43", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEqTsit5.Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [16] _apply
    @ ./boot.jl:946 [inlined]
 [17] adjoint
    @ ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:199 [inlined]
 [18] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [19] solve
    @ ~/.julia/packages/DiffEqBase/5hvMq/src/solve.jl:1079 [inlined]
 [20] mesolve
    @ ~/.julia/packages/QuantumToolbox/FxWqe/src/time_evolution/mesolve.jl:203 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(mesolve), ::QuantumToolbox.TimeEvolutionProblem{…}, ::OrdinaryDiffEqTsit5.Tsit5{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [22] #mesolve#136
    @ ~/.julia/packages/QuantumToolbox/FxWqe/src/time_evolution/mesolve.jl:199 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::QuantumToolbox.var"##mesolve#136", ::OrdinaryDiffEqTsit5.Tsit5{…}, ::Nothing, ::Vector{…}, ::Val{…}, ::Val{…}, ::@Kwargs{}, ::typeof(mesolve), ::QuantumObjectEvolution{…}, ::QuantumObject{…}, ::StepRangeLen{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [24] mesolve
    @ ~/.julia/packages/QuantumToolbox/FxWqe/src/time_evolution/mesolve.jl:162 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(mesolve), ::QuantumObjectEvolution{…}, ::QuantumObject{…}, ::StepRangeLen{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [26] final_state
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:22 [inlined]
 [27] _pullback(ctx::Zygote.Context{false}, f::typeof(final_state), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [28] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [29] adjoint
    @ ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:199 [inlined]
 [30] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [31] call_composed
    @ ./operators.jl:1054 [inlined]
 [32] call_composed
    @ ./operators.jl:1053 [inlined]
 [33] #_#113
    @ ./operators.jl:1050 [inlined]
 [34] _pullback(::Zygote.Context{…}, ::Base.var"##_#113", ::@Kwargs{}, ::ComposedFunction{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [35] _apply
    @ ./boot.jl:946 [inlined]
 [36] adjoint
    @ ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:199 [inlined]
 [37] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [38] ComposedFunction
    @ ./operators.jl:1050 [inlined]
 [39] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [40] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:96
 [41] pullback
    @ ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:94 [inlined]
 [42] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/grad.jl:181
 [43] jacobian
    @ ~/.julia/packages/Zygote/wfLOG/src/lib/grad.jl:168 [inlined]
 [44] jacobian
    @ ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:165 [inlined]
 [45] jacobian(::typeof(final_state), ::AutoZygote, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:106
 [46] top-level scope
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:26
Some type information was truncated. Use `show(err)` to see complete types.

Attempt with Enzyme

ERROR: Enzyme compilation failed due to illegal type analysis.
 This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
 Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Caused by:
Stacktrace:
 [1] +
   @ ~/.julia/dev/QuantumToolbox/src/qobj/arithmetic_and_attributes.jl:48

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/errors.jl:366
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/errors.jl:249
  [3] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/.julia/packages/Enzyme/8d7o7/src/api.jl:338
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:1793
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:4692
  [6] codegen
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:3455 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5553
  [8] _thunk
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5553 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5605 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5716
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5901
 [12] autodiff
    @ ~/.julia/packages/Enzyme/8d7o7/src/Enzyme.jl:641 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/8d7o7/src/Enzyme.jl:525 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:726 [inlined]
 [15] #gradient#124
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:582 [inlined]
 [16] #jacobian#126
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:789 [inlined]
 [17] jacobian
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:788 [inlined]
 [18] jacobian(::typeof(final_state), ::DifferentiationInterfaceEnzymeExt.EnzymeForwardOneArgJacobianPrep{…}, ::AutoEnzyme{…}, ::Vector{…})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl:259
 [19] jacobian(::typeof(final_state), ::AutoEnzyme{Nothing, Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:106
 [20] top-level scope
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:26
Some type information was truncated. Use `show(err)` to see complete types.

Enzyme with the Enzyme.API.strictAliasing!(false) suggestion

ERROR: LLVM.LLVMException("function failed verification (4)")
Stacktrace:
  [1] handle_error(reason::Cstring)
    @ LLVM ~/.julia/packages/LLVM/2JPxT/src/core/context.jl:194
  [2] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/.julia/packages/Enzyme/8d7o7/src/api.jl:338
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:1793
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:4692
  [5] codegen
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:3455 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5553
  [7] _thunk
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5553 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5605 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5716
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/8d7o7/src/compiler.jl:5901
 [11] autodiff
    @ ~/.julia/packages/Enzyme/8d7o7/src/Enzyme.jl:641 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/8d7o7/src/Enzyme.jl:525 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:726 [inlined]
 [14] #gradient#124
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:582 [inlined]
 [15] #jacobian#126
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:789 [inlined]
 [16] jacobian
    @ ~/.julia/packages/Enzyme/8d7o7/src/sugar.jl:788 [inlined]
 [17] jacobian(::typeof(final_state), ::DifferentiationInterfaceEnzymeExt.EnzymeForwardOneArgJacobianPrep{…}, ::AutoEnzyme{…}, ::Vector{…})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl:259
 [18] jacobian(::typeof(final_state), ::AutoEnzyme{Nothing, Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:106
 [19] top-level scope
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:27
Some type information was truncated. Use `show(err)` to see complete types.

Mooncake

ERROR: No rrule!! available for foreigncall with primal argument types Tuple{Val{:jl_get_world_counter}, Val{UInt64}, Tuple{}, Val{0}, Val{:ccall}}. This problem has most likely arisen because there is a ccall somewhere in the function you are trying to differentiate, for which an rrule!! has not been explicitly written.You have three options: write an rrule!! for this foreigncall, write an rrule!! for a Julia function that calls this foreigncall, or re-write your code to avoid this foreigncall entirely. If you believe that this error has arisen for some other reason than the above, or the above does not help you to workaround this problem, please open an issue.
Stacktrace:
  [1] rrule!!(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/rrules/foreigncall.jl:12
  [2] methods
    @ ./reflection.jl:1220 [inlined]
  [3] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual, none::Mooncake.CoDual, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [4] DerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
  [5] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1827
  [6] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
  [7] _QobjEvo_generate_data
    @ ~/.julia/dev/QuantumToolbox/src/qobj/quantum_object_evo.jl:359 [inlined]
  [8] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [9] DerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
 [10] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1827
 [11] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
 [12] #QuantumObjectEvolution#13
    @ ~/.julia/dev/QuantumToolbox/src/qobj/quantum_object_evo.jl:271 [inlined]
 [13] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [14] DerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
 [15] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1827
 [16] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
 [17] final_state
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:9 [inlined]
 [18] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [19] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/interpreter/s2s_reverse_mode_ad.jl:966
 [20] prepare_pullback_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
    @ Mooncake ~/.julia/packages/Mooncake/Lwrnz/src/interface.jl:414
 [21] prepare_pullback_cache
    @ ~/.julia/packages/Mooncake/Lwrnz/src/interface.jl:407 [inlined]
 [22] prepare_pullback_nokwarg(::Val{…}, ::typeof(final_state), ::AutoMooncake{…}, ::Vector{…}, ::Tuple{…})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:14
 [23] _prepare_jacobian_aux(::Val{…}, ::DifferentiationInterface.PushforwardSlow, ::DifferentiationInterface.BatchSizeSettings{…}, ::Matrix{…}, ::Tuple{…}, ::AutoMooncake{…}, ::Vector{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:245
 [24] prepare_jacobian_nokwarg(::Val{true}, ::typeof(final_state), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:182
 [25] jacobian(::typeof(final_state), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:105
 [26] top-level scope
    @ ~/.julia/dev/SystemLevelHamiltonian/examples/autodiffToolbox.jl:26
Some type information was truncated. Use `show(err)` to see complete types.

Previous strategy: extract the ODEProblem from QuantumToolbox

When I wrote the first post I thought there would be a clean way to extract the ODEProblem and convert it to a sensitivity problem, via something like

prob = generate_ode_problem(Htot, ψ0, tlist, L; e_ops = eop_ls, params = p) #<- this is a function which would need to be added to QuantumToolbox
senseprob = ODEForwardSensitivityProblem(prob.f,prob.u0,prob.tspan,prob.p;sensealg= ForwardSensitivity()) 
sol = solve(senseprob)
x,dp = extract_local_sensitivities(sol)

Although this code runs, all the derivatives come out zero, and tracking down why leads me into the bowels of QuantumToolbox. I don’t quite understand how it constructs and solves its ODEProblem.
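
One possible cause worth ruling out (an assumption, not something verified against QuantumToolbox's internals): a forward sensitivity analysis only sees dependence that flows through the parameter object `p` handed to the ODE function. In the first example above, `coef(p, t) = A*sin(omega*t)` captures `A` from the enclosing scope and ignores `p`, so any derivative with respect to `p` would be identically zero. The toy below uses a hand-written RK4 integrator and central differences, not SciMLSensitivity, purely to show the effect; `rk4`, `rhs_closure`, `rhs_param`, and `dfinal` are illustrative names.

```julia
# Minimal fixed-step RK4 integrator for du/dt = f(u, p, t); a stand-in for a real solver.
function rk4(f, u0, p, t0, t1; n = 1000)
    h = (t1 - t0) / n
    u, t = u0, t0
    for _ in 1:n
        k1 = f(u, p, t)
        k2 = f(u + h / 2 * k1, p, t + h / 2)
        k3 = f(u + h / 2 * k2, p, t + h / 2)
        k4 = f(u + h * k3, p, t + h)
        u += h / 6 * (k1 + 2k2 + 2k3 + k4)
        t += h
    end
    return u
end

A = 3.0
rhs_closure(u, p, t) = A * sin(t)      # amplitude captured from the enclosing scope; p is ignored
rhs_param(u, p, t)   = p[1] * sin(t)   # amplitude read from the parameter vector

# Central-difference derivative of the final state u(1) with respect to p[1]
dfinal(rhs, p; h = 1e-6) =
    (rk4(rhs, 0.0, [p[1] + h], 0.0, 1.0) - rk4(rhs, 0.0, [p[1] - h], 0.0, 1.0)) / (2h)

d_closure = dfinal(rhs_closure, [A])   # exactly zero: perturbing p never reaches the dynamics
d_param   = dfinal(rhs_param, [A])     # close to 1 - cos(1), the analytic derivative
```

If the coefficient functions in the real problem do read from `p` (as in the later FiniteDiff example), this explanation does not apply and the cause lies elsewhere.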

Here is one method for mesolve() in QuantumToolbox.jl.

function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
    sol = solve(prob.prob, alg)

    # No type instabilities since `isoperket` is a Val, and so it is known at compile time
    if getVal(prob.kwargs.isoperket)
        ρt = map(ϕ -> QuantumObject(ϕ, type = OperatorKet(), dims = prob.dimensions), sol.u)
    else
        ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u)
    end

    return TimeEvolutionSol(
        prob.times,
        ρt,
        _get_expvals(sol, SaveFuncMESolve),
        sol.retcode,
        sol.alg,
        NamedTuple(sol.prob.kwargs).abstol,
        NamedTuple(sol.prob.kwargs).reltol,
    )
end

This strikes me as a restrictive wrapper: why not let the user pass kwargs through to solve()? I suppose you could just add them to the signature and pass them in ‘manually’, but I imagine this could lead to unexpected behavior, and I understand that QuantumToolbox aims to mirror QuTiP.

A nicer solution would have solve() return a solution composed of the package-specific data type Qobj. Are there any packages that achieve this effect?

Remaining Questions:

  • Is there a preferred way to construct a timeseries of derivatives with FiniteDiff?
  • Is it feasible to automatically differentiate through QuantumToolbox’s solver? What more would I need to learn about the internals of the package to figure this out?
  • How have you seen domain-specific packages successfully wrap ODEProblem so that users can cleanly work with the SciML interface if they like?
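
For the first question, a minimal sketch of one option: return every saved state from the wrapper instead of only the last one, and difference whole trajectories. This is hand-rolled central differencing rather than a FiniteDiff call, and `trajectory` is a hypothetical stand-in for a function that wraps mesolve() and returns `[s.data for s in sol.states]`.

```julia
# Hypothetical stand-in for a wrapper around mesolve: one complex matrix per saved time.
function trajectory(p; tlist = 0:0.5:2)
    return [[cos(p[1] * t) -sin(p[2] * t); sin(p[2] * t) cos(p[1] * t)] .+ 0im for t in tlist]
end

# Central differences of every saved state with respect to every parameter.
# Returns dstates such that dstates[k][n] = ∂(state at time index n) / ∂p[k].
function trajectory_sensitivities(f, p; h = 1e-6)
    map(eachindex(p)) do k
        e = zeros(length(p))
        e[k] = 1                      # unit vector selecting parameter k
        plus, minus = f(p .+ h .* e), f(p .- h .* e)
        [(a .- b) ./ (2h) for (a, b) in zip(plus, minus)]
    end
end

p = [0.3, 0.7]
dstates = trajectory_sensitivities(trajectory, p)
```

This costs two full solves per parameter, but each pair yields the derivative at every saved time, which mirrors the time series that extract_local_sensitivities() provides. It relies on both solves saving at the same tlist.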

What is your callback here? Is this package still using the FunctionCallingCallback? Last time I dove into this package was like 2019.