Differentiable @unpack-ing of function arguments

What is the recommended way to @unpack parameters inside a differential equation when taking gradients with respect to those parameters? I could not get named tuples or dicts to work. This is what I tried:

    using DiffEqSensitivity
    using DifferentialEquations
    using Flux
    using Parameters    
    #p = [1.0, 2.0] # works, but can't @unpack this array
    p = (a=1.0, b=2.0) # named tuple: ERROR: MethodError: no method matching similar(::NamedTuple{(:a, :b),Tuple{Float64,Float64}})
    #p = Dict(:a=>1.0, :b=>2.0) # dictionary: ERROR: MethodError: no method matching similar(::Dict{Symbol,Float64})
    function f!(du, u, p, t)
        #a, b = p # works, but @unpack would be prettier
        @unpack a, b = p # causes gradient to fail with errors printed above
        @. du = a * u + b
    end
    @show gradient(
        () -> sum(solve(
                ODEProblem(f!, [0.0], (0.0, 10.0), p),
                Tsit5(), sensealg = QuadratureAdjoint())),        
        params(p))

a, b = p (without the @unpack) should be fine.

a, b = p works perfectly, but then I have to keep the arguments in the correct order. I am trying to keep track of about 100 different arguments, so @unpack would be cleaner.
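The ordering problem can be shown with plain Base Julia. The sketch below contrasts positional destructuring with name-based destructuring. It uses Base's property-destructuring syntax (; a, b) = p, which requires Julia 1.7 or later, as a package-free stand-in for @unpack. For NamedTuples and structs, @unpack also looks fields up by name through getproperty. The function names positional and named are hypothetical.

```julia
# Order-based destructuring: the first value goes to `a`, the second to `b`,
# regardless of which names `p` actually uses.
function positional(p)
    a, b = p
    return (a = a, b = b)
end

# Name-based destructuring (Julia >= 1.7): looks up p.a and p.b via getproperty,
# so the order in which `p` stores its fields does not matter.
function named(p)
    (; a, b) = p
    return (a = a, b = b)
end

p = (b = 2.0, a = 1.0)   # fields deliberately stored in the "wrong" order
positional(p)            # (a = 2.0, b = 1.0): silently swapped
named(p)                 # (a = 1.0, b = 2.0)
```

With about 100 parameters, the name-based form removes the need to keep the unpacking line in sync with the storage order.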

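The difficulty is that you need something that is integrable, i.e. an array type: the errors above come from the adjoint calling similar on p, which named tuples and dicts don't support. @jonniedie, does @unpack work with ComponentArrays? That might be the solution here, since the adjoints work with them.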
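One way to picture why a ComponentArray fits both requirements is a type that is an ordinary AbstractVector (so similar and indexing work) and also answers p.a by name (so name-based unpacking works). The sketch below is a hypothetical minimal stand-in called NamedVec, not how ComponentArrays.jl is actually implemented. It uses Base's (; a, b) = p (Julia 1.7 or later) in place of @unpack so that it needs no packages.

```julia
# Hypothetical minimal stand-in for a ComponentArray, NOT its real implementation:
# a plain vector plus a list of component names.
struct NamedVec{T} <: AbstractVector{T}
    data::Vector{T}
    names::Vector{Symbol}
end

# Build from a named tuple, e.g. NamedVec((a = 1.0, b = 2.0)).
NamedVec(nt::NamedTuple) = NamedVec([values(nt)...], [keys(nt)...])

# AbstractArray interface. getfield is used because getproperty is overloaded below.
Base.size(v::NamedVec) = size(getfield(v, :data))
Base.IndexStyle(::Type{<:NamedVec}) = IndexLinear()
Base.getindex(v::NamedVec, i::Int) = getfield(v, :data)[i]
Base.setindex!(v::NamedVec, x, i::Int) = (getfield(v, :data)[i] = x)

# similar keeps the names, so buffers allocated from p stay addressable by name
# (assumes dims matches the original length).
Base.similar(v::NamedVec, ::Type{S}, dims::Dims{1}) where {S} =
    NamedVec(similar(getfield(v, :data), S, dims), copy(getfield(v, :names)))

# Name-based access: p.a returns the entry stored at the position of :a.
function Base.getproperty(v::NamedVec, name::Symbol)
    i = findfirst(==(name), getfield(v, :names))
    i === nothing && error("NamedVec has no component $name")
    return getfield(v, :data)[i]
end
Base.propertynames(v::NamedVec) = Tuple(getfield(v, :names))

# In-place RHS that unpacks by name, like @unpack a, b = p.
function rhs!(du, u, p, t)
    (; a, b) = p
    @. du = a * u + b
    return nothing
end

p = NamedVec((a = 1.0, b = 2.0))
q = similar(p)           # works; the post hit a MethodError for similar on a NamedTuple
du = zeros(1)
rhs!(du, [0.5], p, 0.0)  # du == [2.5]
```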

@unpack works. It looks like I had a convert method missing for Zygote though (which, oddly enough, I’ve never run into in similar problems). The fix is getting registered now.

But I’m getting zero for the gradients. Not sure what’s missing:

    using ComponentArrays
    using DiffEqSensitivity
    using DifferentialEquations
    using Flux
    using Parameters

    p = ComponentArray(a=1.0, b=2.0)

    function f!(du, u, p, t)
        # a, b = p # works, but @unpack would be prettier
        @unpack a, b = p
        @. du = a * u + b
    end

    @show grad = Flux.gradient(
            (par) -> sum(solve(
                    ODEProblem(f!, [0.0], (0.0, 10.0), par),
                    Tsit5(),
                    sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()),
                )),
            p)

Things seem to be working fine with Zygote:

    julia> Zygote.gradient(x-> sum(convert(typeof(x), x).^2), p)
    ((a = 2.0, b = 4.0),)

    julia> function f(u, p, t)
               @unpack a, b = p
               a .* u .+ b
           end
    f (generic function with 1 method)

    julia> gradient(x->sum(f([0.0], x, 0.0)), p)
    ((a = 0.0, b = 1.0),)
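That last result can be checked by hand. With u = [0.0] the loss reduces to sum(a * 0.0 + b) = b, so the gradient should be exactly (a = 0.0, b = 1.0). The sketch below reproduces it with central finite differences on a NamedTuple, without Zygote. It reads p.a and p.b directly instead of using @unpack so that no packages are needed. The helper fd_gradient is hypothetical.

```julia
# Out-of-place RHS from the post, with parameters read by name (p.a, p.b).
rhs(u, p, t) = p.a .* u .+ p.b

loss(p) = sum(rhs([0.0], p, 0.0))

# Hypothetical helper: central finite differences for each field of a NamedTuple.
function fd_gradient(f, p::NamedTuple; h = 1e-6)
    ks = keys(p)
    g = map(ks) do k
        up = merge(p, NamedTuple{(k,)}((p[k] + h,)))
        dn = merge(p, NamedTuple{(k,)}((p[k] - h,)))
        (f(up) - f(dn)) / (2h)
    end
    return NamedTuple{ks}(g)
end

g = fd_gradient(loss, (a = 1.0, b = 2.0))   # approximately (a = 0.0, b = 1.0)
```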

Looks like out-of-place works fine, though.

    julia> @show Flux.gradient((par->begin
                    sum(solve(ODEProblem(f, [0.0], (0.0, 10.0), par), Tsit5(), sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))
                end), p) = ((a = 682437.0287268654, b = 40552.658586328536),)
    ((a = 682437.0287268654, b = 40552.658586328536),)
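To see what these sensitivities mean, note that the ODE here, du/dt = a*u + b with u(0) = 0, has the closed form u(t) = (b/a)(exp(a*t) - 1) for a ≠ 0. The sketch below differentiates that formula by hand and cross-checks it against central finite differences at a = 1, b = 2, t = 10. It checks u at a single time only. As far as I can tell, sum(solve(...)) adds up the solution at every saved step, so the totals above are not expected to equal these single-time values. The helper names u_exact, du_da, du_db and central are hypothetical.

```julia
# Closed-form solution of du/dt = a*u + b with u(0) = 0 (assumes a != 0).
u_exact(a, b, t) = (b / a) * (exp(a * t) - 1)

# Parameter sensitivities of u(t), obtained by differentiating u_exact by hand.
du_db(a, b, t) = (exp(a * t) - 1) / a
du_da(a, b, t) = -(b / a^2) * (exp(a * t) - 1) + (b / a) * t * exp(a * t)

# Central finite differences, used only as an independent cross-check.
central(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

a, b, t = 1.0, 2.0, 10.0
fd_a = central(x -> u_exact(x, b, t), a)
fd_b = central(x -> u_exact(a, x, t), b)
println((analytic = (du_da(a, b, t), du_db(a, b, t)), finite_diff = (fd_a, fd_b)))
```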

@unpack works for me as well, thanks! I'll check out your fix later.

When I use QuadratureAdjoint() with its default settings, I run into ERROR: type TrackedArray has no field a:

    getproperty(::ReverseDiff.TrackedArray{Int64,Float64,1,ComponentArray{Int64,1,Array{Int64,1},Tuple{Axis{(a = 1, b = 2)}}},ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(a = 1, b = 2)}}}}, ::Symbol) at .\Base.jl:33

    p = ComponentArray(a=1, b=2)
    function f!(du, u, p, t)
        @unpack a, b = p
        @. du = a * u + b
    end
    @show gradient(
        () -> sum(solve(
                ODEProblem(f!, [0.0], (0.0, 10.0), p),
                Tsit5(), sensealg = QuadratureAdjoint())),
        params(p))

Yeah, it’s gotta be QuadratureAdjoint(autojacvec=ZygoteVJP()). Tracker and ReverseDiff don’t work with ComponentArrays.
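The error above ends in Base's default getproperty, which simply calls getfield. One plausible reading is that the tracking wrapper forwards indexing to the wrapped ComponentArray but not name lookups, so p.a inside @unpack is looked up on the wrapper's own fields. The sketch below illustrates that failure mode with hypothetical stand-in types (Labeled for ComponentArray, Tracked for ReverseDiff's TrackedArray). It is not ReverseDiff's actual code.

```julia
# Hypothetical stand-ins, NOT the real ComponentArray / ReverseDiff.TrackedArray types.
struct Labeled <: AbstractVector{Float64}
    data::Vector{Float64}
    names::Vector{Symbol}
end
Base.size(v::Labeled) = size(getfield(v, :data))
Base.getindex(v::Labeled, i::Int) = getfield(v, :data)[i]

# Name lookup: v.a returns data at the position of :a.
function Base.getproperty(v::Labeled, s::Symbol)
    i = findfirst(==(s), getfield(v, :names))
    i === nothing && error("no component $s")
    return getfield(v, :data)[i]
end

# A wrapper that forwards size and indexing but defines no getproperty,
# so t.a falls back to Base's default getproperty, i.e. getfield.
struct Tracked{A<:AbstractVector{Float64}} <: AbstractVector{Float64}
    value::A
end
Base.size(t::Tracked) = size(getfield(t, :value))
Base.getindex(t::Tracked, i::Int) = getfield(t, :value)[i]

p = Labeled([1.0, 2.0], [:a, :b])
t = Tracked(p)
p.a          # 1.0
t[1]         # 1.0: indexing still works through the wrapper
x, y = t     # positional destructuring also works, since it only iterates
# t.a        # throws: Tracked has no field a (same failure as the TrackedArray error)
```

This is also consistent with a, b = p working where @unpack does not: positional destructuring only needs iteration, not name lookup.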

ok, thanks for the explanation!

In-place Zygote has never really worked for us. Technically buffers (Zygote.Buffer) should work, but in practice they don’t, so we always default away from in-place with Zygote.
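The practical takeaway is to write the right-hand side out-of-place, as in the f(u, p, t) that worked above. Zygote does not support mutating arrays in the code it differentiates. The sketch below shows the two calling conventions side by side in plain Julia and checks that they compute the same derivative. Base's (; a, b) = p (Julia 1.7 or later) stands in for @unpack, and the names rhs! and rhs are hypothetical.

```julia
# In-place form (like f! above): writes the derivative into a preallocated du.
function rhs!(du, u, p, t)
    (; a, b) = p
    @. du = a * u + b
    return nothing
end

# Out-of-place form (like f above): allocates and returns a new array, no mutation.
function rhs(u, p, t)
    (; a, b) = p
    return a .* u .+ b
end

p = (a = 1.0, b = 2.0)
u = [0.0, 0.5, 1.0]
du = similar(u)
rhs!(du, u, p, 0.0)
du == rhs(u, p, 0.0)    # true: same derivative, different calling convention
```

As far as I know, ODEProblem decides between the two forms by the function's number of arguments (four for in-place, three for out-of-place). That is why passing f instead of f! was enough to switch the earlier example to the out-of-place path.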