I need to differentiate the parameters of my model (a neural network) through an ODE solver. Apparently DiffEqFlux allows me to do that, but it seems that gradients are computed only if the network parameters enter directly as the parameters of ODEProblem(…). If these parameters enter only indirectly (through some functional dependence), Flux returns ‘nothing’.
To make this clearer, here is an example where I take one step of an ODE implemented in two slightly different ways.
using Flux, DiffEqFlux, DifferentialEquations
model=Chain(Dense(1,1,identity))
# set the weight and bias by hand
model[1].W[1]=2.0f0
model[1].b[1]=1.0f0
p1,re=Flux.destructure(model)
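# p1 is a flat vector of all model parameters; re(p1) rebuilds the model from such a vector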
# initial state
u0=[2.0f0]
function f1(u,p,t) # evaluates the coefficient in the body of f1
    alpha = re(p)(u0)[1]  # the network output at u0 serves as the ODE coefficient
    return alpha*u
end
function f2(u,p,t) # p is directly the coefficient
    return p[1]*u
end
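# quick sanity check (added for clarity): both formulations encode the same ODE
# u' = alpha*u with alpha = W*u0[1] + b = 2*2 + 1 = 5
@assert f1(u0, p1, 0.0f0) ≈ f2(u0, re(p1)(u0), 0.0f0)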
dt=0.005f0
tspan = (0.0f0,dt) # keep the time span in Float32 to match u0
prob1 = ODEProblem(f1,u0,tspan,p1) # model parameters enter the ODEProblem directly
p2=re(p1)(u0)
prob2 = ODEProblem(f2,u0,tspan,p2) # model parameters enter the ODEProblem only indirectly, via p2
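# (I suspect this is relevant: p2 is computed once here, outside of any gradient call)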
function dynamics1()
    solve(prob1, Heun(), dt=dt, saveat=[dt], adaptive=false)[1,:]
end
function dynamics2()
    solve(prob2, Heun(), dt=dt, saveat=[dt], adaptive=false)[1,:]
end
# some loss function (L1 distance of the one-step solution from u0)
criterion1() = sum(abs.(dynamics1()-u0))
criterion2() = sum(abs.(dynamics2()-u0))
# analytic grad w.r.t. W:
#   2*2*0.005*exp((2*2+1)*0.005) = 0.020506302410488578
# analytic grad w.r.t. b:
#   2*0.005*exp((2*2+1)*0.005) = 0.010253151205244289
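# derivation: the exact solution is u(dt) = u0*exp(alpha*dt) with alpha = W*u0[1] + b = 5,
# so loss = u0*(exp(alpha*dt) - 1), and by the chain rule
#   dloss/dW = u0*dt*exp(alpha*dt)*u0   (since dalpha/dW = u0[1])
#   dloss/db = u0*dt*exp(alpha*dt)      (since dalpha/db = 1)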
θ = Flux.params(p1)
gs1 = gradient(θ) do
    criterion1()
end
gs2 = gradient(θ) do
    criterion2()
end
println("grads f1:", gs1[p1] ) #Float32[0.020503124, 0.010251562]
println("grads f2:", gs2[p1] ) #nothing
In prob1, the network parameters p1 enter ODEProblem(…,p1) directly and are then used inside the body of f1. In prob2, they enter the ODEProblem only indirectly, through p2, which itself depends on p1.
The forward passes, dynamics1() and dynamics2(), give the same result (which is correct, as can be checked analytically). However, when I take gradients with respect to the network parameters p1, I get the correct result for prob1 and ‘nothing’ for prob2.
Is there a way to make Flux compute the gradients for the prob2 implementation as well?
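My current guess is that p2 is computed outside the gradient call, so the AD never sees its dependence on p1. A minimal sketch of the kind of fix I have in mind, assuming solve accepts a p keyword that overrides the problem's stored parameters (I have not verified that gradients actually flow through this, and criterion2_inner is just a name I made up):

function criterion2_inner()
    p2 = re(p1)(u0) # recompute p2 inside the loss so the p1 -> p2 dependence is differentiated
    sol = solve(prob2, Heun(), p=p2, dt=dt, saveat=[dt], adaptive=false)
    return sum(abs.(sol[1,:] - u0))
end

Would something along these lines be the recommended pattern, or is there a better way?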