What is the recommended way to @unpack parameters in a differential equation when taking the gradient with respect to those parameters?
I could not get named tuples or dicts to work; this is what I tried:
using DiffEqSensitivity
using DifferentialEquations
using Flux
using Parameters
#p = [1.0, 2.0] # works, but can't @unpack this array
p = (a=1.0, b=2.0) # named tuple: ERROR: MethodError: no method matching similar(::NamedTuple{(:a, :b),Tuple{Float64,Float64}})
#p = Dict(:a=>1.0, :b=>2.0) # dictionary: ERROR: MethodError: no method matching similar(::Dict{Symbol,Float64})
function f!(du, u, p, t)
    # a, b = p  # works, but @unpack would be prettier
    @unpack a, b = p  # causes the gradient to fail with the errors noted above
    @. du = a * u + b
end

@show gradient(
    () -> sum(solve(
        ODEProblem(f!, [0.0], (0.0, 10.0), p),
        Tsit5(), sensealg = QuadratureAdjoint())),
    params(p))
a, b = p works perfectly, but then I have to keep the arguments in the correct order. I am trying to keep track of about 100 different arguments, so @unpack would be much cleaner.
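For illustration (the field names here are just placeholders, not my actual model), this is why @unpack scales better for me than positional destructuring:

using Parameters

p = (a = 1.0, b = 2.0)

a, b = p            # positional: silently wrong if the field order ever changes
@unpack b, a = p    # by name: a == 1.0 and b == 2.0 regardless of the order listed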
The difficulty is that you need something that is integrable, i.e. something that supports the array operations the solver and adjoint call (the NamedTuple and Dict errors above are missing similar methods). @jonniedie, does @unpack work with ComponentArrays? That might be the solution here, since the adjoints work.
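Roughly what I mean by integrable, as a quick sketch (assuming ComponentArrays is loaded; the values are just for illustration): a ComponentArray supports the array operations that failed above while still exposing its fields by name.

using ComponentArrays

p = ComponentArray(a = 1.0, b = 2.0)

similar(p)    # works, unlike the NamedTuple and Dict attempts above
p[1], p.a     # (1.0, 1.0): indexable like a vector, addressable by name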
@unpack works. It looks like I was missing a convert method for Zygote, though (which, oddly enough, I’ve never run into in similar problems). The fix is being registered now.
But I’m getting zero for the gradients. Not sure what’s missing:
using ComponentArrays
using DiffEqSensitivity
using DifferentialEquations
using Flux
using Parameters
p = ComponentArray(a=1.0, b=2.0)
function f!(du, u, p, t)
    # a, b = p  # works, but @unpack would be prettier
    @unpack a, b = p
    @. du = a * u + b
end

@show grad = Flux.gradient(
    (par) -> sum(solve(
        ODEProblem(f!, [0.0], (0.0, 10.0), par),
        Tsit5(),
        sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()),
    )),
    p)
Things seem to be working fine with Zygote:
julia> Zygote.gradient(x-> sum(convert(typeof(x), x).^2), p)
((a = 2.0, b = 4.0),)
julia> function f(u, p, t)
           @unpack a, b = p
           a .* u .+ b
       end
f (generic function with 1 method)
julia> gradient(x->sum(f([0.0], x, 0.0)), p)
((a = 0.0, b = 1.0),)
@unpack works for me as well, thanks! I’ll check out your fix later.
When I use QuadratureAdjoint I am running into:

ERROR: type TrackedArray has no field a
getproperty(::ReverseDiff.TrackedArray{Int64,Float64,1,ComponentArray{Int64,1,Array{Int64,1},Tuple{Axis{(a = 1, b = 2)}}},ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(a = 1, b = 2)}}}}, ::Symbol) at .\Base.jl:33
p = ComponentArray(a=1, b=2)
function f!(du, u, p, t)
    @unpack a, b = p
    @. du = a * u + b
end

@show gradient(
    () -> sum(solve(
        ODEProblem(f!, [0.0], (0.0, 10.0), p),
        Tsit5(), sensealg = QuadratureAdjoint())),
    params(p))
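I will also try forcing the Zygote VJP as in your earlier example, in case the default VJP choice is what is tripping over getproperty here (just a guess on my part, not verified):

@show gradient(
    () -> sum(solve(
        ODEProblem(f!, [0.0], (0.0, 10.0), p),
        Tsit5(),
        sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))),
    params(p))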