Impossible to perform autodiff on simple ODE

Hello,

I was trying to differentiate a simple ODE using Zygote.jl, but it returns the following error:

ERROR: `p` is not a SciMLStructure. This is required for adjoint sensitivity analysis. For more information,
see the documentation on SciMLStructures.jl for the definition of the SciMLStructures interface.
In particular, adjoint sensitivities only applies to `Tunable`.

My code is the following:

using LinearAlgebra
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity

##

function dudt!(du, u, p, t)
    mul!(du, p.U, u)  # in-place right-hand side: du .= p.U * u
    return nothing
end

##

const T = ComplexF64
const N = 10
const u0 = ones(T, N)
H_tmp = rand(T, N, N)
const H = H_tmp + H_tmp'

function my_f(γ)
    U = -1im * H - γ * Diagonal(H)
    p = (U = U,)  # NamedTuple parameters: this is what the error complains about
    # p = [U]    # a plain vector of matrices fails in the same way
    tspan = (0.0, 1.0)
    prob = ODEProblem{true}(dudt!, u0, tspan, p)
    sol = solve(prob, Tsit5())
    return sol.u[end]
end

my_f(1)

##

gradient(my_f, 1)

I think that the problem is related to the type of the ODE parameters. I tried passing them as a vector (p = [U]), but I get the same error. However, I need a NamedTuple, because in the final implementation I will have several parameters of different types, and the NamedTuple keeps everything type-stable.
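For example, a heterogeneous Vector collapses to Vector{Any}, while a NamedTuple keeps concrete field types:

p_vec = [rand(2, 2), 1.0, true]                # Vector{Any}: type-unstable access
p_nt  = (U = rand(2, 2), γ = 1.0, flag = true)
typeof(p_nt)  # NamedTuple{(:U, :γ, :flag), Tuple{Matrix{Float64}, Float64, Bool}}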

Could you use ComponentArrays instead?

What is the advantage of using ComponentArrays instead of NamedTuples? Moreover, it seems from the examples that the ComponentArray needs to have all its fields of the same type. My params would be very generic, including non-array types as well.

ComponentArrays are stored as a single vector, which simplifies some things. Anything you want to take the gradient of will (probably) need to be of the same type.

If I’ve got a complicated parameter structure, I sometimes separate it into the part that I want to differentiate with respect to, which is a ComponentArray or a flat vector, and everything else, which I leave as a struct or similar.
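A minimal (untested) sketch of that split, with illustrative names: the tunable scalar sits in a flat ComponentArray, while the fixed matrix is simply closed over and never differentiated.

using ComponentArrays, OrdinaryDiffEq, SciMLSensitivity, Zygote

const A = [0.0 1.0; -1.0 0.0]    # fixed data: captured by closure, not differentiated

dudt(u, θ, t) = -θ.γ * (A * u)   # tunables live only in θ

function loss(γ)
    θ = ComponentArray(γ = γ)    # differentiable part: flat, single eltype
    prob = ODEProblem{false}(dudt, [1.0, 0.0], (0.0, 1.0), θ)
    sol = solve(prob, Tsit5())
    return sum(abs2, sol.u[end])
end

Zygote.gradient(loss, 1.9)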

I would highly recommend using a ComponentArray here. Maybe we can even modify the error message for NamedTuples with homogeneous eltype to directly suggest it.

Ok, I can try to be more specific. I’m integrating an ODE for an open quantum system. The function is called mesolve, and you can find it here.

The ODE itself is relatively simple. If it is time-independent, it can be written as

\dot{\mathbf{u}} = \mathcal{L} \mathbf{u}

If there is some time dependence, I need to introduce a custom struct to handle it (because I assume a time dependence of the form \mathcal{L}(t) = \sum_n c_n(t) \mathcal{O}_n).

Then, I also have a progress bar to print and a callback for saving the expectation values over time.

So, to recap, my params contains different objects:

  • any AbstractArray, from sparse matrices to dense or static vectors, etc.
  • structs of AbstractArrays, vectors, and tuples of functions
  • vectors of AbstractArray
  • vectors of Int64
  • possibly Nothing types
  • Bool types

Moreover, this ODE has also some callbacks.

Can I convert all of this structure to something compatible with Zygote.jl or Enzyme.jl?

This would be very useful for comparing against recent results from a new Python package that does the same thing.

It should be fine as long as your f is differentiable.

It is differentiable, like the example in my first comment.

Then I don’t understand what the remaining question is :sweat_smile:

The code in my first comment doesn’t work. You suggested using a ComponentArray for the parameters, but my final ODE would include other parameters of different types, custom structs, and so on. So my question was, first, how to make my first MWE work, and second, how to deal with params that contain different kinds of variables, including structs, sparse matrices, AbstractArrays, Nothing, and so on.

You can use the SciMLStructures interface to declare the tunable portion, and then it will only differentiate with respect to that subset. That subset should be a single type anyway, since you cannot differentiate Bools and Ints.

Ok, thanks. It seems that the documentation of SciMLStructures.jl doesn’t have an example. Could you give me a more specific one?

Let’s take the code of my first comment, but this time with many parameters, say

p = (
        U = U,
        e_ops = e_ops2,
        expvals = expvals,
        progr = progr,
        Hdims = H.dims,
        H_t = H_t,
        times = t_l,
        is_empty_e_ops = is_empty_e_ops,
        params...,
    )

where

  • U is the same matrix as before (dense or sparse, with eltype=T)
  • e_ops is a vector of matrices (dense or sparse, with eltype=T)
  • expvals is a dense matrix with eltype=T
  • progr is a progress-bar struct, like those in ProgressMeter.jl
  • Hdims is a Tuple of Ints
  • H_t is either Nothing or a struct (depending on whether the Liouvillian is time-dependent) containing matrices, vectors, and functions; the matrices have eltype=T
  • times is a vector of Reals
  • is_empty_e_ops is a Bool
  • params holds additional parameters needed for H_t

H_t is a struct of the type

struct TimeDependentOperatorSum{CFT<:Tuple,OST<:OperatorSum}
    coefficient_functions::CFT
    operator_sum::OST
end

where coefficient_functions is a tuple of functions, and

struct OperatorSum{CT<:AbstractVector{<:Number},OT<:Union{AbstractVector,Tuple}} <: AbstractQuantumObject
    coefficients::CT
    operators::OT
end
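Conceptually, the sum is evaluated like this (a rough sketch only; the evaluate name and the c(p, t) signature are just for illustration):

function evaluate(A::TimeDependentOperatorSum, p, t)
    cs = map(c -> c(p, t), A.coefficient_functions)  # the scalars c_n(t)
    ops = A.operator_sum.operators                   # the operators O_n
    return sum(cs[i] * ops[i] for i in eachindex(ops))
end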

As far as I understand, I should divide it into the part that I want to differentiate and the part that should stay fixed. According to this, the fixed part would contain:

  • progr
  • Hdims
  • is_empty_e_ops

and the remaining ones would be involved in the differentiation. They can all share the same eltype=T (although some of them are structs containing more than just matrices), except for times, which is always Real, while the others are usually Complex.

So, how should I write my first example code?

Is there anyone who can help me?

The documentation for SciMLStructures.jl is missing practical examples. In my case, the params would include more complicated structs.

I really need this functionality in my QuantumToolbox.jl package for quantum optimal control of open quantum systems.

I just haven’t gotten around to it. Maybe @cryptic.ax can find a bit of time.

You can track this PR: feat: add doc example for implementing the interface by AayushSabharwal · Pull Request #28 · SciML/SciMLStructures.jl · GitHub which is adding an example for the SciMLStructures interface.
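In the meantime, the gist of the interface is something like this minimal, untested sketch (MyParams and its fields are illustrative): only U is exposed as the Tunable portion, everything else is carried along untouched.

import SciMLStructures as SS

struct MyParams{UT,PT}
    U::UT        # tunable matrix
    progr::PT    # fixed extras (progress bar, flags, ...)
end

SS.isscimlstructure(::MyParams) = true
SS.ismutablescimlstructure(::MyParams) = false
SS.hasportion(::SS.Tunable, ::MyParams) = true

function SS.canonicalize(::SS.Tunable, p::MyParams)
    buffer = copy(vec(p.U))  # flat vector of the tunable values
    repack = newbuf -> MyParams(reshape(newbuf, size(p.U)), p.progr)
    return buffer, repack, false  # false: buffer does not alias p
end

SS.replace(::SS.Tunable, p::MyParams, newbuf) =
    MyParams(reshape(newbuf, size(p.U)), p.progr)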

Thanks for the link, it is very useful, and I’m starting to implement it in my full code.

While doing that, I faced some issues, so I decided to analyze the simple code here. I noticed that the following code (a revision of the first example I wrote) doesn’t give an error, but returns nothing for the gradient.

using LinearAlgebra
using OrdinaryDiffEq
using SciMLOperators
using Zygote
using SciMLSensitivity

const T = ComplexF64
const N = 10
const u0 = ones(T, N)
H_tmp = rand(T, N, N)
const H = H_tmp + H_tmp'

function my_f(γ)
    U = MatrixOperator(-1im * H - γ * Diagonal(H))
    tspan = (0.0, 1.0)
    prob = ODEProblem{true}(U, u0, tspan)
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

my_f(1.9) # -0.041501631765322705

gradient(my_f, 1.9)
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/concrete_solve.jl:24

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/concrete_solve.jl:67

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/concrete_solve.jl:207
(nothing,)

First of all, it is strange to see warnings related to Enzyme while I’m using Zygote. Then, the gradient is nothing. I thought the problem could be related to the fact that the ODE is parameter-free, so I decided to implement it in a different way:

coef(a, u, p, t) = -p[1]  # update function for the ScalarOperator: the new scalar value is -γ

function my_f(γ)
    # U = MatrixOperator(-1im * H - γ * Diagonal(H))
    U = MatrixOperator(-1im * H) + ScalarOperator(one(T), coef) * MatrixOperator(Diagonal(H))
    tspan = (0.0, 1.0)
    prob = ODEProblem{true}(U, u0, tspan, [γ])
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

my_f(1.9) # -0.04150163176532281

gradient(my_f, 1.9)
┌ Warning: Using fallback BLAS replacements for (["zgemv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/GnbhK/src/utils.jl:59
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/concrete_solve.jl:67

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/concrete_solve.jl:207
ERROR: InexactError: Float64(0.010428536019754344 + 0.048873021588869255im)
Stacktrace:
  [1] Real
    @ ./complex.jl:44 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]
  [3] setindex!
    @ ./array.jl:976 [inlined]
  [4] setindex!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:335 [inlined]
  [5] _setindex!
    @ ./abstractarray.jl:1436 [inlined]
  [6] setindex!
    @ ./abstractarray.jl:1413 [inlined]
  [7] _modify!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:91 [inlined]
  [8] _generic_matmatmul!(C::Adjoint{…}, A::Adjoint{…}, B::Matrix{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:923
  [9] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
 [10] _mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
 [11] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [12] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [13] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/XCu1T/src/gauss_adjoint.jl:469

It returns an error, together with the previous warnings.

I would like to point out that I will usually need the first case, which is much simpler. Of course, the second one is also fine.

Yes, but it has to use Enzyme VJPs on the f (what you call U here) in order to handle the mutation in the in-place form. If you use ODEProblem{false}, then that wouldn’t be the case.
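For instance, your second MWE in out-of-place form would be (an untested sketch; note ODEProblem{false} and that the RHS no longer mutates):

coef(a, u, p, t) = -p[1]

function my_f_oop(γ)
    U = MatrixOperator(-1im * H) + ScalarOperator(one(T), coef) * MatrixOperator(Diagonal(H))
    tspan = (0.0, 1.0)
    prob = ODEProblem{false}(U, u0, tspan, [γ])  # out-of-place: nothing to mutate
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

gradient(my_f_oop, 1.9)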

This is because your u0 is real-valued but your operator returns complex values. Change const u0 = ones(T, N) to const u0 = ones(T, N) .+ 0im.

Ok, clear.

But I already have T = ComplexF64, and I checked that u0 is a Vector{ComplexF64}. Nonetheless, I tried your suggestion, and I still have the same issue.

Moreover, why does the parameter-free implementation of the gradient return nothing? The parameter dependence is implicit in the MatrixOperator itself.

If the parameter is hidden from the ODEProblem, then it cannot be handled by an adjoint method. We usually catch this, though it’s hard to catch all of the cases in which a global can be introduced.
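That is, a parameter baked in when the right-hand side is built, rather than passed through p, is “hidden” (illustrative sketch):

f_hidden(u, p, t)  = (-1im * H - 1.9 * Diagonal(H)) * u   # γ baked in: invisible to the adjoint
f_visible(u, p, t) = (-1im * H - p[1] * Diagonal(H)) * u  # γ in p: visible to the adjoint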

As for your actual issue: it looks like we may need to do something in GaussAdjoint to better support complex numbers there. What if you set the sensealg to InterpolatingAdjoint(autojacvec=false)?
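That is, passing it through the sensealg keyword of solve (sketch):

sol = solve(prob, Tsit5(); sensealg = InterpolatingAdjoint(autojacvec = false))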

I get a different error. It uses ForwardDiff.jl, which apparently doesn’t support complex numbers.

ERROR: ArgumentError: Cannot create a dual over scalar type ComplexF64. If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{ComplexF64}) = true.
Stacktrace:
  [1] throw_cannot_dual(V::Type)
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:41
  [2] ForwardDiff.Dual{ForwardDiff.Tag{…}, ComplexF64, 10}(value::ComplexF64, partials::ForwardDiff.Partials{10, ComplexF64})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:18
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:673 [inlined]
  [4] _broadcast_getindex
    @ ./broadcast.jl:646 [inlined]
  [5] getindex
    @ ./broadcast.jl:605 [inlined]
  [6] macro expansion
    @ ./broadcast.jl:968 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] copyto!
    @ ./broadcast.jl:967 [inlined]
  [9] copyto!
    @ ./broadcast.jl:920 [inlined]
 [10] materialize!
    @ ./broadcast.jl:878 [inlined]
 [11] materialize!
    @ ./broadcast.jl:875 [inlined]
 [12] seed!(duals::Vector{ForwardDiff.Dual{…}}, x::Vector{ComplexF64}, seeds::NTuple{10, ForwardDiff.Partials{…}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:52

Nonetheless, even if I use T = Float64 instead of ComplexF64 (and also remove the -1im from the MatrixOperator), I get the error

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{ODEFunction{…}, Float64}, Float64, 10})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:265
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float64(::IrrationalConstants.Invπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112
  ...

Stacktrace:
  [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{ODEFunction{…}, Float64}, Float64, 10})
    @ Base ./number.jl:7
  [2] setproperty!(x::ScalarOperator{Float64, typeof(coef)}, f::Symbol, v::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 10})
    @ Base ./Base.jl:52
  [3] update_coefficients!(L::ScalarOperator{…}, u::Vector{…}, p::Vector{…}, t::Float64; kwargs::@Kwargs{})
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/scalar.jl:193
  [4] update_coefficients!(L::ScalarOperator{…}, u::Vector{…}, p::Vector{…}, t::Float64)
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/scalar.jl:192
  [5] update_coefficients!(L::SciMLOperators.ScaledOperator{…}, u::Vector{…}, p::Vector{…}, t::Float64)
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/basic.jl:251
  [6] macro expansion
    @ ~/.julia/packages/SciMLOperators/Q5dkx/src/basic.jl:462 [inlined]
  [7] update_coefficients!(L::SciMLOperators.AddedOperator{…}, u::Vector{…}, p::Vector{…}, t::Float64)
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/basic.jl:457
  [8] (::SciMLOperators.AddedOperator{…})(du::Vector{…}, u::Vector{…}, p::Vector{…}, t::Float64; kwargs::@Kwargs{})
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/interface.jl:116
  [9] (::SciMLOperators.AddedOperator{…})(du::Vector{…}, u::Vector{…}, p::Vector{…}, t::Float64)
    @ SciMLOperators ~/.julia/packages/SciMLOperators/Q5dkx/src/interface.jl:115
 [10] (::ODEFunction{…})(::Vector{…}, ::Vararg{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/hJh6T/src/scimlfunctions.jl:2355
 [11] (::ODEFunction{…})(::Vector{…}, ::Vararg{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/hJh6T/src/scimlfunctions.jl:2355