EnsembleProblem with SteadyStateProblem in Zygote

Is there a premade parallel version of SteadyStateProblem (I'm currently using a for loop)? I'm assuming there isn't, because EnsembleProblem expects a tspan, which I don't provide. My goal is to obtain the steady-state solutions, with adjoints, for training a neural ODE. Thanks!

It should work. Share an MWE.

using DifferentialEquations
using SciMLSensitivity
using Plots
using Zygote

# Some model with multiple steady states
function some_model(u, p, t)
    dxdt = p[1]*u[1]*u[2]
    dydt = -p[2]*u[2]
    return [dxdt, dydt]
end

# Specify the initial conditions for each simulation
function prob_func(prob, i, repeat)
    # x0 is read from global scope; no `global` declaration is needed just to read it
    remake(prob, u0=x0[i, :])
end

x0 = hcat(rand(10), ones(10)) # Initialize some random initial conditions
u0 = x0[1, :] # u0 for the base problem; prob_func overrides it for each trajectory
foop = [1., 1.] # Parameters
tspan = LinRange(0, 50, 25)

# Take the gradient of an arbitrary loss to check that the ensemble is differentiable w.r.t. the parameters
foo_gs = gradient(foop) do foop
    foo_prob = ODEProblem(some_model, u0, (tspan[1], tspan[end]), foop) # Works
    foo_ensemble = EnsembleProblem(foo_prob, prob_func=prob_func)
    foo_sol = solve(
        foo_ensemble, 
        Tsit5(), 
        EnsembleThreads(), 
        trajectories=size(x0)[1];
        saveat = tspan,
        sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    )
    Zygote.ignore_derivatives() do
        plot(foo_sol) # Plot to see what is going on
    end

    # Some random function
    foo_val = 0.
    for some_sol in foo_sol.u
        _some_sol_mat = mapreduce(permutedims, vcat, some_sol.u)
        foo_val += sum(abs2, _some_sol_mat)
    end
    @show foo_val
    return foo_val
end

# The following doesn't work
foo_gs_SS = gradient(foop) do foop
    foo_SS_prob = SteadyStateProblem(some_model, u0, foop)
    foo_SS_ensemble = EnsembleProblem(foo_SS_prob, prob_func=prob_func)
    foo_SS_sol = solve(
        foo_SS_ensemble,
        DynamicSS(Tsit5(), abstol=1e-4, reltol=1e-3, tspan=tspan[end]),
        EnsembleThreads(), 
        trajectories=size(x0)[1];
        sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    )
    # @show Array(foo_SS_sol.u[1])
    
    # Some random function
    foo_val = 0.
    for some_sol in foo_SS_sol.u
        foo_val += sum(abs2, some_sol)
    end
    @show foo_val
    return foo_val
end

I’m not too familiar with the internals of AD packages, but using Zygote seems to break the ensemble problem; solving it without taking gradients works just fine. The error message is “ERROR: type SteadyStateProblem has no field tspan”.

Thanks! I see. In general Zygote works with ensemble problems (it’s tested with ODEs, SDEs, DDEs, DAEs), but this combination seems to require something extra. Can you open an issue on SciMLSensitivity.jl? I can solve this by the end of the week.


Solved, in a way. The issue is that InterpolatingAdjoint is a time-based adjoint, which needs a tspan; for steady-state problems you should use SteadyStateAdjoint instead. Pull Request #705 on SciML/SciMLSensitivity.jl (“Throw a better error for time-based adjoint on no-time problem”, by ChrisRackauckas) solves this by throwing a very explicit error message saying exactly that.
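A minimal sketch of the fix, assuming a SciMLSensitivity.jl version recent enough to accept SteadyStateAdjoint in an ensemble solve: compared with the failing block, the essential change is the sensealg. One caveat: the MWE's some_model has a whole line of steady states (y = 0, any x), so the Jacobian ∂f/∂u is singular there. SteadyStateAdjoint relies on the implicit function theorem (at f(u*, p) = 0 it solves a linear system involving ∂f/∂u), so this sketch swaps in a hypothetical relaxation model, relax_model, whose single steady state u* = p[1]/p[2] is isolated. The name relax_model and the parameter values are illustrative, not from the thread.

```julia
using DifferentialEquations  # SteadyStateProblem, DynamicSS, Tsit5, EnsembleProblem
using SciMLSensitivity       # SteadyStateAdjoint
using Zygote

# Hypothetical model (not from the thread): linear relaxation with the single,
# isolated steady state u* = p[1] / p[2]; ∂f/∂u = -p[2] I is invertible
relax_model(u, p, t) = p[1] .- p[2] .* u

x0 = hcat(rand(10), ones(10))  # one initial condition per row
p = [1.0, 2.0]                 # illustrative parameter values

# Give each trajectory its own initial condition
prob_func(prob, i, repeat) = remake(prob, u0 = x0[i, :])

dp = gradient(p) do p
    ss_prob = SteadyStateProblem(relax_model, x0[1, :], p)
    ensemble = EnsembleProblem(ss_prob, prob_func = prob_func)
    ens_sol = solve(ensemble, DynamicSS(Tsit5()), EnsembleThreads();
                    trajectories = size(x0, 1),
                    sensealg = SteadyStateAdjoint())  # instead of InterpolatingAdjoint
    # Same loss as the MWE: sum of squared steady states over all trajectories
    val = 0.0
    for s in ens_sol.u
        val += sum(abs2, s)
    end
    return val
end
```

As a sanity check: every trajectory relaxes to u* = [p[1]/p[2], p[1]/p[2]], so the loss is 20 (p[1]/p[2])^2 and its exact gradient is [40 p[1]/p[2]^2, -40 p[1]^2/p[2]^3], which is [10, -5] at p = [1, 2]. dp[1] should come out close to that, up to the convergence tolerance of DynamicSS.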