For now, pass the sensitivity algorithm to `solve` inside the loss function:
"""
    loss(prob, p)

Solve `prob` with its parameters replaced by `p` and reduce the solution to a scalar.

The problem is rebuilt with `remake`, integrated with `Tsit5()` saving every `0.1`,
summed along the first dimension, averaged, and scaled by `tspan[end]`.

`sensealg`, `cb`, and `tspan` are not arguments; they are read from the enclosing scope.
"""
function loss(prob, p)
    updated = remake(prob; p=p)
    trajectory = solve(updated, Tsit5(); sensealg=sensealg, callback=cb, saveat=0.1)
    column_sums = sum(trajectory; dims=1)
    return mean(column_sums) * tspan[end]
end
# Evaluate the loss and its gradient with respect to `params`, holding `prob` fixed.
Zygote.withgradient(params) do p
    loss(prob, p)
end
@ChrisRackauckas @frankschae Changing the `get_FakeIntegrator` definition to
# Build a `FakeIntegrator` for the `ReverseDiffVJP` case from the state `u`,
# parameters `p`, and the times `t` and `tprev`.
function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
    # Rebuild `u` element by element into a fresh array instead of passing it through.
    # NOTE(review): presumably this turns a tracked array into a plain array of
    # tracked elements — confirm against ReverseDiff's array types.
    state = [elem for elem in u]
    return FakeIntegrator(state, p, t, tprev)
end
fixes the ReverseDiff problem.