As recently announced, the solve command is now compatible with Flux and should be used. I am confused about the values of the gradients I obtain in these two very simple examples.
using Flux, DiffEqFlux, DifferentialEquations

# parameter
p = [5.0f0]
# initial state
u0 = [2.0f0]

# right-hand side of the ODE u' = p*u
function f(u, p, t)
    p[1] * u
end

dt = 0.005f0
tspan = (0.0, dt)

# version 1: the problem is constructed inside the differentiated function
function dynamics1(p)
    prob1 = ODEProblem(f, u0, tspan, p)
    solve(prob1, Heun(), dt=dt, saveat=[dt], adaptive=false)[1, :]
end

# version 2: the problem is constructed once, outside the differentiated function
prob2 = ODEProblem(f, u0, tspan, p)
function dynamics2(p)
    solve(prob2, Heun(), dt=dt, saveat=[dt], adaptive=false)[1, :]
end

# some loss function
loss1(p) = sum(abs.(dynamics1(p) - u0))
loss2(p) = sum(abs.(dynamics2(p) - u0))

θ = Flux.params(p)
gs1 = gradient(θ) do
    loss1(p)
end
gs2 = gradient(θ) do
    loss2(p)
end

# gs1 is exactly twice as large as gs2
println("grads 1:", gs1[p]) # Float32[0.020503124]
println("grads 2:", gs2[p]) # Float32[0.010251562]
The only difference is that prob1 is defined inside the gradient loop and prob2 outside. In the first case (gs1) I get a gradient exactly twice as large as gs2. Where does this factor of 2 come from? When I use the old concrete_solve, the results come out identical. Could it be linked to the fact that solve does not explicitly take the parameters p as an argument?
Yeah, something is odd here. If you explicitly connect p to the problem, you get 2x the gradient. Looks like a Zygote issue:
using Flux, DiffEqFlux, DifferentialEquations

# parameter
p = [5.0f0]
# initial state
u0 = [2.0f0]

# right-hand side of the ODE u' = p*u
function f(u, p, t)
    p[1] * u
end

dt = 0.005f0
tspan = (0.0, dt)

# problem built inside the function, p passed to ODEProblem
function dynamics1(p)
    prob1 = ODEProblem(f, u0, tspan, p)
    solve(prob1, Heun(), dt=dt, saveat=[dt], adaptive=false)[1, :]
end

prob2 = ODEProblem(f, u0, tspan, p)

# problem built outside, argument p not used in the body
function dynamics2(p)
    solve(prob2, Heun(), dt=dt, saveat=[dt], adaptive=false)[1, :]
end

# problem built outside, p passed explicitly to solve
function dynamics3(p)
    solve(prob2, Heun(), p=p, dt=dt, saveat=[dt], adaptive=false)[1, :]
end

# problem built outside, no argument at all
function dynamics4()
    solve(prob2, Heun(), dt=dt, saveat=[dt], adaptive=false)[1, :]
end

# some loss function
loss1(p) = sum(abs.(dynamics1(p) - u0))
loss2(p) = sum(abs.(dynamics2(p) - u0))
loss3(p) = sum(abs.(dynamics3(p) - u0))
loss4() = sum(abs.(dynamics4() - u0))

θ = Flux.params(p)
gs1 = gradient(θ) do
    loss1(p)
end
gs2 = gradient(θ) do
    loss2(p)
end
gs3 = gradient(θ) do
    loss3(p)
end
gs4 = gradient(θ) do
    loss4()
end

# gs1 and gs3 (p explicitly connected) are exactly twice as large as gs2 and gs4
println("grads 1:", gs1[p]) # Float32[0.020503124]
println("grads 2:", gs2[p]) # Float32[0.010251562]
println("grads 3:", gs3[p]) # Float32[0.020503124]
println("grads 4:", gs4[p]) # Float32[0.010251562]
Thanks for having a look and opening an issue on GitHub! Yeah, the smaller value should be correct (it can be computed analytically).
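For anyone who wants to check the analytic value: a minimal sketch, assuming the solver call above takes exactly one fixed Heun step of size dt (which is what tspan = (0.0, dt) with adaptive=false suggests). The helper heun_step is hand-written for illustration and is not part of any package; the computation is in Float64, so it matches the Float32 output only to a few digits.

```julia
# One explicit Heun step for u' = p*u, written out by hand (no packages needed).
# heun_step is a hypothetical helper for this check, not a library function.
function heun_step(u, p, dt)
    k1 = p * u                 # slope at the start of the step
    k2 = p * (u + dt * k1)     # slope at the Euler-predicted end point
    return u + dt / 2 * (k1 + k2)
end

u0 = 2.0
p  = 5.0
dt = 0.005

# Same loss as in the thread, for a scalar state.
loss(p) = abs(heun_step(u0, p, dt) - u0)

# For u0, p, dt > 0 the step gives u1 = u0*(1 + p*dt + (p*dt)^2/2),
# so loss(p) = u0*(p*dt + (p*dt)^2/2) and d(loss)/dp = u0*dt*(1 + p*dt).
grad_analytic = u0 * dt * (1 + p * dt)

# Central finite difference as an independent check.
h = 1e-6
grad_fd = (loss(p + h) - loss(p - h)) / (2h)

println("analytic:          ", grad_analytic)
println("finite difference: ", grad_fd)
```

Both numbers come out at about 0.01025, which agrees with gs2 and gs4 and not with the doubled value.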