Taking gradients when propagating through an ODE solver

I need to differentiate the parameters of my model (a neural network) through an ODE solver. Apparently DiffEqFlux allows me to do that, but it seems that gradients are computed only if the network parameters enter directly as the parameters of ODEProblem(…). If these parameters enter indirectly (through some functional dependence), Flux returns ‘nothing’.

To make this clearer, here is an example where I take one step of an ODE implemented in two slightly different ways.

using Flux, DiffEqFlux, DifferentialEquations


model = Chain(Dense(1, 1, identity))
# set the weight and bias by hand
model[1].W[1] = 2.0f0
model[1].b[1] = 1.0f0
p1, re = Flux.destructure(model)

# initial state
u0 = [2.0f0]

function f1(u, p, t) # the coefficient is evaluated inside the body of f1
    alpha = re(p)(u0)[1] # note: the network is evaluated at the fixed u0
    return alpha * u
end

function f2(u,p,t) #p is directly the coefficient
    return p[1]*u
end

dt = 0.005f0
tspan = (0.0f0, dt) # keep the time span in Float32 to match u0


prob1 = ODEProblem(f1, u0, tspan, p1) # model parameters enter the ODEProblem directly

p2 = re(p1)(u0)
prob2 = ODEProblem(f2, u0, tspan, p2) # model parameters enter the ODEProblem only indirectly

function dynamics1()
   solve(prob1, Heun(), dt=dt, saveat=[dt], adaptive=false)[1,:]
end

function dynamics2()
   solve(prob2, Heun(), dt=dt, saveat=[dt], adaptive=false)[1,:]
end

# a simple loss function for each problem
criterion1() = sum(abs.(dynamics1()-u0))
criterion2() = sum(abs.(dynamics2()-u0))

# analytic grad w.r.t. W:
#   2*2*0.005*exp((2*2+1)*0.005) = 0.020506302410488578
# analytic grad w.r.t. b:
#   2*0.005*exp((2*2+1)*0.005) = 0.010253151205244289
θ = Flux.params(p1)
gs1 = gradient(θ) do
   criterion1()
end

gs2 = gradient(θ) do
   criterion2()
end

println("grads f1:", gs1[p1] ) #Float32[0.020503124, 0.010251562]
println("grads f2:", gs2[p1] ) #nothing

In prob1 the network parameters p1 enter the ODEProblem(…, p1) directly and are then used inside the body of f1. In prob2 they enter the ODEProblem only indirectly, through the parameters p2, which depend on p1.

The results of the forward pass, dynamics1() and dynamics2(), are the same (and correct, as can be checked analytically). However, when I take gradients with respect to the network parameters p1, prob1 gives the correct result while prob2 gives ‘nothing’.

Is there a way to make Flux compute the gradients for the prob2 implementation as well?


If you make your parameter vector p = [p1; p2] and then grab the right parts, like re(p[1:N]), it’ll work.
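For concreteness, a minimal sketch of that idea (f3, prob3, and N are hypothetical names introduced here; the key point is that the network evaluation happens inside the ODE function, so the solver's adjoint sees the full dependence on p):

N = length(p1)
function f3(u, p, t)
    alpha = re(p[1:N])(u0)[1] # rebuild the network from its slice of p
    return alpha * u          # all dependence on p stays inside f3
end
prob3 = ODEProblem(f3, u0, tspan, [p1; p2])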

Thanks for your reply! However, I must still be missing something. I understand that I need to pass both p1 and p2 as Flux parameters, but when I do as you suggest I still get zero gradients for the network parameters p1.

These are the parts of the code that I have changed:

function f2(u,p,t)
    return p[end]*u
end

p = [p1;p2]
prob2 = ODEProblem(f2,u0,tspan,p) 

θ = Flux.params(p)
gs2 = gradient(θ) do
   criterion2()
end

println("grads:", gs2[p] ) #Float32[0.0, 0.0, 0.010251562]

Thanks!

criterion2 only uses part of the parameters, and correctly computes the gradient for that part. Did you mean to run both criterion1 and criterion2?

That’s the thing: I would like to use criterion2() and get the gradients of p1 as well. I somehow thought it would be possible, since p2 is a function of p1.

Oh I missed that. You can’t differentiate that in isolation then. You have to differentiate the full process if you want the derivative of the full process.

# criterion2 now takes the coefficient explicitly and overrides the problem's p in solve:
criterion2(p) = sum(abs.(solve(prob2, Heun(), p=p, dt=dt, saveat=[dt], adaptive=false)[1,:] - u0))

θ = Flux.params([p1, p2])
gs2 = gradient(θ) do
   p2 = re(p1)(u0) # recompute p2 from p1 inside the differentiated code
   criterion2(p2)
end
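With p2 recomputed from p1 inside the gradient call, gs2[p1] should now match the gs1[p1] values from the first post (approximately Float32[0.0205, 0.01025]), since the same end-to-end process is being differentiated.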

Great, makes sense… Thanks for the help!
