I am currently trying to implement stable neural ODEs according to https://arxiv.org/pdf/2001.06116.pdf. This includes an ODE function,
(m::StableDynamics)(x) = begin
    if x == zero(x)
        return zero(x)  # return a zero vector rather than a scalar 0, so the output shape matches x
    end
    vx = Zygote.gradient(x -> sum(m.v(x)), x)[1]
    return m.f_hat(x) - vx * relu(vx' * m.f_hat(x) + 0.9 * m.v(x)[1]) / sum(abs2, vx)
end
model = StableDynamics()
p,re = Flux.destructure(model)
f = (u,p)->re(p)(u)
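As far as I can tell, this forward pass implements the stability projection from the paper,

f(x) = f_hat(x) - ∇V(x) * ReLU(∇V(x)' * f_hat(x) + α * V(x)) / ‖∇V(x)‖²

with α = 0.9, so a gradient of the Lyapunov network v has to be computed inside every forward pass.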
Here, f_hat(x) is a standard dense network and v(x) is an input convex neural network. For backpropagation we need to call gradient(u -> f(u,p), u) and gradient(p -> f(u,p), p) repeatedly. Unfortunately, compared to the PyTorch implementation, computing 1000 gradients with Zygote,
@time for i in 1:1000
    gradient(x -> sum(f(x, p)), x .+ randn(size(x)))
    gradient(p -> sum(f(x .+ randn(size(x)), p)), p)
end
takes more than 50x as long as the corresponding PyTorch code. Any suggestions on how to speed things up? Apart from Zygote I also tried ReverseDiff, but that threw an error. From a related discussion I know that a new AD package should be released soon. However, since time for my project is running out, I would very much appreciate a temporary solution. Would it be worth trying to fix the problem with ReverseDiff, or are there better options?
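For reference, my ReverseDiff attempt was essentially the following (a sketch from memory, using the same f, x and p as above; I don't have the exact call or stack trace at hand):

using ReverseDiff
ReverseDiff.gradient(p -> sum(f(x, p)), p)  # errors, presumably because of the Zygote.gradient nested inside the forward pass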
For completeness I’ll append the full code below:
using Flux, Zygote
data_dim = 2
################################################################################
# soft ReLU
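# (smoothed ReLU: 0 for x ≤ 0, quadratic x^2/(2d) on [0, d], linear x - d/2 above)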
function soft_relu(d)
    x -> max.(clamp.(sign.(x) .* 1/(2*d) .* x.^2, 0, d/2), x .- d/2)
end
################################################################################
# input convex neural network
struct ICNNLayer
    W
    U
    b
    act
end
# constructor
ICNNLayer(z_in::Integer, x_in::Integer, out::Integer, activation) =
    ICNNLayer(randn(out, z_in), randn(out, x_in), randn(out), activation)
# forward pass
(m::ICNNLayer)(z, x) = m.act(m.W*z + softplus.(m.U)*x + m.b)
# track params
Flux.@functor ICNNLayer
# Input Convex Neural Network (ICNN)
struct ICNN
    InLayer
    HLayer1
    HLayer2
    act
end
# constructor (assumes layer_sizes has exactly two entries, matching HLayer1/HLayer2)
ICNN(input_dim::Integer, layer_sizes::Vector, activation) = begin
    InLayer = Dense(input_dim, layer_sizes[1])
    HLayers = []
    if length(layer_sizes) > 1
        i = 1
        for out in layer_sizes[2:end]
            push!(HLayers, ICNNLayer(layer_sizes[i], input_dim, out, activation))
            i += 1
        end
        push!(HLayers, ICNNLayer(layer_sizes[end], input_dim, 1, activation))
    end
    ICNN(InLayer, HLayers[1], HLayers[2], activation)
end
# forward pass
(m::ICNN)(x) = begin
    z = m.act(m.InLayer(x))
    z = m.HLayer1(z, x)
    z = m.HLayer2(z, x)
    return z
end
Flux.@functor ICNN
################################################################################
# Lyapunov Function
struct Lyapunov
    icnn
    act
    d
    eps
end
# constructor
Lyapunov(input_dim::Integer; d=0.1, eps=1e-3, layer_sizes::Vector=[32,32], act=soft_relu(d)) = begin
    icnn = ICNN(input_dim, layer_sizes, act)
    Lyapunov(icnn, act, d, eps)
end
# forward pass
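# V(x) = act(g(x) - g(0)) + eps * ||x||^2, so V(0) = 0 and V(x) > 0 for x ≠ 0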
(m::Lyapunov)(x) = begin
    g = m.icnn(x)
    g0 = m.icnn(zeros(size(x)))
    z = m.act(g - g0) .+ m.eps * (x'*x)
    return z
end
Flux.@functor Lyapunov
################################################################################
# dynamics
struct StableDynamics
    v
    f_hat
    alpha
    nr_delays
    τ
    p
    grad
end
# constructor
StableDynamics(; nr_delays=0, τ=1, p=1.1, alpha=0.9, act=soft_relu(0.1), grad="zygote_reverse") = begin
    v = Lyapunov(data_dim, act=act)
    # v = Chain(Dense(data_dim, 10, tanh), Dense(10, 1))
    f_hat = Chain(Dense((data_dim) * (nr_delays + 1), 32, tanh),
                  Dense(32, 32, tanh),
                  Dense(32, data_dim))
    StableDynamics(v, f_hat, alpha, nr_delays, τ, p, grad)
end
# forward pass
(m::StableDynamics)(x) = begin
    if x == zero(x)
        return zero(x)  # zero vector, not a scalar, so the output shape matches x
    end
    vx = Zygote.gradient(x -> sum(m.v(x)), x)[1]
    return m.f_hat(x) - vx * relu(vx' * m.f_hat(x) + m.alpha * m.v(x)[1]) / sum(abs2, vx)
end
Flux.@functor StableDynamics
# example
x = [1.0,1.0]
model = StableDynamics()
p, re = Flux.destructure(model)
f = (x,p) -> re(p)(x)
f(x,p)
@time for i in 1:1000
    gradient(x -> sum(f(x, p)), x .+ randn(size(x)))
    gradient(p -> sum(f(x .+ randn(size(x)), p)), p)
end
Thanks in advance for any help!