As recently announced, the `solve` command is now compatible with Flux and should be used. I am confused by the gradient values I obtain in these two very simple examples.
```julia
using Flux, DiffEqFlux, DifferentialEquations

# parameter
p = [5.0f0]
# initial state
u0 = [2.0f0]

# right-hand side of du/dt = p*u (broadcast, since p and u are both vectors)
function f(u, p, t)
    p .* u
end

dt = 0.005f0
tspan = (0.0, dt)

# version 1: the ODEProblem is built inside the differentiated function
function dynamics1(p)
    prob1 = ODEProblem(f, u0, tspan, p)
    solve(prob1, Heun(), dt = dt, saveat = [dt], adaptive = false)[1, :]
end

# version 2: the ODEProblem is built once, outside
prob2 = ODEProblem(f, u0, tspan, p)
function dynamics2(p)
    solve(prob2, Heun(), dt = dt, saveat = [dt], adaptive = false)[1, :]
end

# some loss function
loss1(p) = sum(abs.(dynamics1(p) - u0))
loss2(p) = sum(abs.(dynamics2(p) - u0))

θ = Flux.params(p)
gs1 = gradient(θ) do
    loss1(p)
end
gs2 = gradient(θ) do
    loss2(p)
end

# gs1 is exactly twice as large as gs2
println("grads 1:", gs1[p])  # Float32[0.020503124]
println("grads 2:", gs2[p])  # Float32[0.010251562]
```
The only difference is that `prob1` is constructed inside the function being differentiated, while `prob2` is constructed outside it. In the first case (`gs1`) I get a gradient exactly twice as large as in the second (`gs2`). Where does this factor of 2 come from? When I use the old `concrete_solve`, the two results are identical. Could this be related to the fact that `solve` does not take the parameters `p` as an explicit argument?
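As a sanity check on which of the two values to expect, here is a minimal sketch that reproduces the single fixed Heun step by hand, with no DifferentialEquations or AD involved, and estimates d(loss)/dp with a central finite difference. It assumes the solver takes exactly one step from 0 to `dt` and that `saveat = [dt]` returns only the state at `dt`; the helper names `heun_step`, `loss_manual` and `fd_grad` are hypothetical and it works in Float64 rather than Float32. Under those assumptions the loss is quadratic in p, namely u0·(dt·p + dt²·p²/2), so its derivative is u0·(dt + dt²·p).

```julia
# Hypothetical helper: one explicit Heun (improved Euler) step for u' = p*u.
function heun_step(u, p, dt)
    k1 = p .* u                  # slope at the start of the step
    k2 = p .* (u .+ dt .* k1)    # slope at the Euler predictor point
    return u .+ (dt / 2) .* (k1 .+ k2)
end

# Same loss as in the question, with the solver replaced by one manual step.
loss_manual(p; u0 = [2.0], dt = 0.005) = sum(abs.(heun_step(u0, p, dt) .- u0))

# Central finite difference with respect to the scalar parameter p.
fd_grad(loss, p; h = 1e-4) = (loss(p + h) - loss(p - h)) / (2h)

g_fd = fd_grad(loss_manual, 5.0)
g_exact = 2.0 * (0.005 + 0.005^2 * 5.0)   # u0 * (dt + dt^2 * p)

println("finite difference: ", g_fd)
println("closed form:       ", g_exact)
```

Both numbers come out at about 0.01025, which is close to `gs2` (0.010251562) and about half of `gs1`. The small remaining gap to `gs2` presumably comes from how the AD path through `solve` computes the sensitivity, which this sketch does not try to model.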