Value of gradients with DiffEqFlux and solve

As recently announced, the solve command is now compatible with Flux and should be used. I am confused about the values of the gradients that I obtain in these two very simple examples.

using Flux, DiffEqFlux, DifferentialEquations

#parameter
p=[5.0f0]
#initial state
u0=[2.0f0]

function f(u,p,t) 
   p[1]*u
end

dt=0.005f0
tspan = (0.0,dt)

function dynamics1(p)
   prob1 = ODEProblem(f,u0,tspan,p)
   solve(prob1,Heun(), dt=dt,saveat=[dt],adaptive=false)[1,:]
end

prob2 = ODEProblem(f,u0,tspan,p)
function dynamics2(p)
   solve(prob2,Heun(),dt=dt,saveat=[dt],adaptive=false)[1,:]
end

#some loss function
loss1(p) = sum(abs.(dynamics1(p)-u0))
loss2(p) = sum(abs.(dynamics2(p)-u0))

θ = Flux.params(p)
gs1 = gradient(θ) do
   loss1(p)
end

gs2 = gradient(θ) do
   loss2(p)
end
#gs1 is exactly twice as large as gs2
println("grads 1:", gs1[p] ) #Float32[0.020503124]
println("grads 2:", gs2[p] ) #Float32[0.010251562]

The only difference is that prob1 is constructed inside the function being differentiated, while prob2 is constructed outside it. In the first case (gs1) I get a gradient exactly twice as large as gs2. Where does this factor of 2 come from? When I use the old concrete_solve, the results are identical. Could it be linked to the fact that solve does not explicitly take the parameters p as an argument?

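Yeah, something is odd here. If you explicitly connect p to the problem, either by constructing the ODEProblem inside the differentiated function (dynamics1) or by passing p=p to solve (dynamics3), you get 2x the gradient. Looks like a Zygote issue: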

using Flux, DiffEqFlux, DifferentialEquations

#parameter
p=[5.0f0]
#initial state
u0=[2.0f0]

function f(u,p,t)
   p[1]*u
end

dt=0.005f0
tspan = (0.0,dt)

function dynamics1(p)
   prob1 = ODEProblem(f,u0,tspan,p)
   solve(prob1,Heun(), dt=dt,saveat=[dt],adaptive=false)[1,:]
end

prob2 = ODEProblem(f,u0,tspan,p)
function dynamics2(p)
   solve(prob2,Heun(),dt=dt,saveat=[dt],adaptive=false)[1,:]
end

function dynamics3(p)
   solve(prob2,Heun(),p=p,dt=dt,saveat=[dt],adaptive=false)[1,:]
end

function dynamics4()
   solve(prob2,Heun(),dt=dt,saveat=[dt],adaptive=false)[1,:]
end

#some loss function
loss1(p) = sum(abs.(dynamics1(p)-u0))
loss2(p) = sum(abs.(dynamics2(p)-u0))
loss3(p) = sum(abs.(dynamics3(p)-u0))
loss4() = sum(abs.(dynamics4()-u0))

θ = Flux.params(p)
gs1 = gradient(θ) do
   loss1(p)
end

gs2 = gradient(θ) do
   loss2(p)
end

gs3 = gradient(θ) do
   loss3(p)
end

gs4 = gradient(θ) do
   loss4()
end

#gs1 and gs3 are exactly twice as large as gs2 and gs4
println("grads 1:", gs1[p] ) #Float32[0.020503124]
println("grads 2:", gs2[p] ) #Float32[0.010251562]
println("grads 3:", gs3[p] ) #Float32[0.020503124]
println("grads 4:", gs4[p] ) #Float32[0.010251562]

Looks like the smaller one is correct. I'm tracking it in SciML/SciMLSensitivity.jl issue #271, "Gradients double if connection to `p` is explicit?".

Thanks for having a look and opening an issue on GitHub! Yeah, the smaller value should be correct (it can be computed analytically).
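For reference, here is a sketch of that analytic check. It assumes the setup from the code above (f(u,p,t) = p*u, u0 = 2, p = 5, a single step of length T = dt = 0.005). The symbols T, L, and u_1 are introduced here only for the derivation. Since u(T) > u0 for p > 0, the absolute value in the loss can be dropped.

```latex
% Exact solution of u' = p u and the resulting loss gradient
u(T) = u_0 e^{pT}, \qquad
L(p) = \lvert u(T) - u_0 \rvert = u_0 \left( e^{pT} - 1 \right), \qquad
\frac{dL}{dp} = u_0 T e^{pT} = 2 \cdot 0.005 \cdot e^{0.025} \approx 0.010253

% Same quantity for one explicit Heun step of size T
u_1 = u_0 \left( 1 + pT + \tfrac{1}{2} (pT)^2 \right), \qquad
\frac{dL}{dp} = u_0 \left( T + p T^2 \right) = 2 \left( 0.005 + 5 \cdot 0.005^2 \right) = 0.01025
```

Both values agree with the smaller result (0.010251562) to roughly three or four significant digits, and neither is compatible with 0.020503124.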