DiffEqFlux loss function with variable solve options

Does anyone know a good way to pass solve options to a custom loss function? We are building a package and would like to provide a DiffEqFlux-style loss function, and I would like users to be able to control the solver from the outside, e.g. in pseudocode:

function library_loss(p; solver_options...)
  sol = solve(Problem; solver_options...)
end

But this does not seem to be compatible with Zygote which complains about named tuples no matter how I try to hide this…

Edit: If needed (i.e. if no one has come across this before), I will whip up an MWE sometime next week…

Make an MWE. That should be fine, unless it’s a Zygote bug, but I thought I had already found all of the Zygote NamedTuple bugs. Worth getting it documented for @dhairyagandhi96.
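For reference, here is a minimal plain-Julia sketch of the keyword-forwarding pattern from the question. `fake_solve` is a hypothetical stand-in for `solve` so the example runs without DiffEqFlux or OrdinaryDiffEq. The library's defaults are listed before `solver_options...`; since Julia gives precedence to the rightmost occurrence of a keyword when one comes from a splat, user-supplied options override the defaults.

```julia
# Hypothetical stand-in for `solve`: it just echoes the options it received.
fake_solve(prob, alg; saveat = nothing, abstol = 1e-6, reltol = 1e-3) =
    (saveat = saveat, abstol = abstol, reltol = reltol)

# Library-side loss: defaults first, user options splatted after the `;`.
# For a keyword that appears twice (once via the splat), the rightmost
# occurrence wins, so the user's values override the library defaults.
function library_loss(p; solver_options...)
    return fake_solve(p, :Tsit5; saveat = 0.1, solver_options...)
end

library_loss(1.0)                 # uses the library default saveat = 0.1
library_loss(1.0; saveat = 0.5)   # the user's saveat overrides the default
library_loss(1.0; abstol = 1e-9)  # forwarded straight through to fake_solve
```

The important detail is the `;` before `solver_options...`: it is what makes the splat forward keywords rather than positional arguments.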


Thanks Chris for letting me know it should work! The NamedTuple error was caused by a bug in our code (a missing semicolon). The only surprising part is that the bug doesn’t prevent the code from running normally, as long as no keyword arguments are actually passed:

using DiffEqFlux
using OrdinaryDiffEq

    
function l1(p; solver_options...)
    dd_spec = solve(ODEProblem((dy,y,p,t) -> dy .= p, zeros(2), (0., 10.), p), Tsit5(), saveat=0.:0.1:10.
    , solver_options...) # Comma instead of semicolon here
    sum(dd_spec[end])
end

l1(ones(2)) # This works despite the missing semicolon

DiffEqFlux.sciml_train(
    l1,
    ones(2),
    DiffEqFlux.ADAM(0.5),
    maxiters = 2) # NamedTuple Error.

# Apparently, splatting the keyword arguments into the positional
# arguments (comma instead of semicolon) is not a good idea, and it
# is almost certainly never the intended behaviour.
# The following works as expected:

function l2(p; solver_options...)
    dd_spec = solve(ODEProblem((dy,y,p,t) -> dy .= p, zeros(2), (0., 10.), p), Tsit5(), saveat=0.:0.1:10.; solver_options...)
    sum(dd_spec[end])
end

l2(ones(2))

DiffEqFlux.sciml_train(
    l2,
    ones(2),
    DiffEqFlux.ADAM(0.5),
    maxiters = 2)
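The difference between `l1` and `l2` can be reproduced without DiffEqFlux or Zygote. In this minimal sketch, `probe` is a hypothetical stand-in for `solve` that reports how many positional arguments it got and which keyword names it saw. A splat that follows a comma is positional, so each option arrives as a `Pair` argument instead of a keyword.

```julia
# Hypothetical stand-in for `solve`: returns the number of positional
# arguments and the names of the keyword arguments it received.
probe(args...; kwargs...) = (length(args), Tuple(keys(kwargs)))

# Bug (as in l1): after a comma, `opts...` is a *positional* splat, so each
# keyword arrives as a positional Pair such as :abstol => 1e-8.
with_comma(x; opts...) = probe(x, opts...)

# Fix (as in l2): after the semicolon, `opts...` is forwarded as keywords.
with_semicolon(x; opts...) = probe(x; opts...)

with_comma(1)                     # (1, ())  looks fine when nothing is passed
with_comma(1; abstol = 1e-8)      # (2, ())  the option became a positional Pair
with_semicolon(1; abstol = 1e-8)  # (1, (:abstol,))
```

An empty splat adds no positional arguments, which is why `l1(ones(2))` runs despite the bug.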

While reducing this, I also came across the following failure to differentiate. We create a temporary array of arrays, never use it, and that throws off the adjoint. Again, this is triggered by code you don’t want to have in the first place, but which is perfectly valid.

function l3(p)
    ps = [p[1 + (n - 1) * 2:n * 2] for n in 1:10] # This is unused

    dd_spec = solve(ODEProblem((dy,y,p,t) -> dy .= p, zeros(2), (0., 10.), p[1:2]), Tsit5())
    sum(dd_spec[end])
end

l3(ones(20))

DiffEqFlux.sciml_train(
    l3,
    ones(20),
    DiffEqFlux.ADAM(0.5),
    maxiters = 2) # MethodError: no method matching iterate(::Nothing)

Yeah, open an issue on Zygote for that. @dhairyagandhi96 should take notice. But what you’re pointing out are all Zygote behaviors, not anything specific to DiffEqFlux.

Hmm, we did fix a bug with kwargs handling some time back. If this is still an issue, could you please make an MWE along the lines of:


 julia> f(x; kwargs...) = 2x .+ kwargs[1]
 f (generic function with 1 method)
 
 julia> f(5, b = true)
 11
 
 julia> gradient(x -> f(x, b = true), 40)
 (2,)

Could you share the specific named tuples error you saw?

For the other error, an MWE like this might capture it:

julia> function g(x)
         # p = zeros(size(x))
         p = [2i for i in x]
         sum(x)
       end
g (generic function with 2 methods)

julia> gradient(g, rand(3,3))
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::DataStructures.SparseIntSet, ::Any...) at /home/dhairyalgandhi/.julia/packages/DataStructures/mgePl/src/sparse_int_set.jl:147
  iterate(::Base.RegexMatchIterator) at regex.jl:552
  iterate(::Base.RegexMatchIterator, ::Any) at regex.jl:552
  ...

Yeah, that’s it. I posted an MWE for the first error here:

Edit: And the other is here:
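Until the unused-comprehension failure is fixed upstream, one possible workaround (a sketch, assuming your Zygote version provides `Zygote.ignore`; more recent versions recommend `ChainRulesCore.ignore_derivatives` for the same purpose) is to build the temporary in a block that Zygote is told not to differentiate. `g_ignored` is a hypothetical variant of `g` above; if the temporary really is unused, simply deleting it is of course the simplest fix.

```julia
using Zygote

# Same shape as `g` above, but the unused temporary is built inside
# `Zygote.ignore`, so it still runs on the forward pass while Zygote
# does not try to differentiate through it.
function g_ignored(x)
    p = Zygote.ignore(() -> [2i for i in x])  # computed, but not tracked by AD
    sum(x)
end

gradient(g_ignored, rand(3, 3))  # gradient of sum(x): a 3×3 array of ones
```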