Speeding up gradients for custom neural network - currently much slower than in PyTorch

I am currently trying to implement stable neural ODEs according to https://arxiv.org/pdf/2001.06116.pdf. This includes an ODE function,

(m::StableDynamics)(x) = begin
    if x == zero(x)
        return zero(x)  # keep the return type consistent with the nonzero-x branch
    end
    vx = Zygote.gradient(x->sum(m.v(x)), x)[1]
    return m.f_hat(x) - vx * relu(vx'*m.f_hat(x) + 0.9 * m.v(x)[1]) / sum(abs2, vx)
end
model = StableDynamics()
p,re = Flux.destructure(model)
f = (u,p)->re(p)(u)

Here, f_hat(x) is a standard dense network and v(x) is an input convex neural network. For backpropagation we need to call gradient(u->f(u,p), u) and gradient(p->f(u,p), p) repeatedly. Unfortunately, computing 1000 gradients with Zygote,

@time for i in 1:1000
    gradient(x->sum(f(x,p)),x.+randn(size(x)))
    gradient(p->sum(f(x.+randn(size(x)),p)),p)
end

takes more than 50x as long as the corresponding PyTorch code. Any suggestions on how to speed things up? Apart from Zygote, I also tried ReverseDiff, but it threw an error. From a related discussion I know that a new AD package should be released soon. However, as time for my project is running out, I would very much appreciate a temporary solution. Would it be worth trying to fix the problem with ReverseDiff, or are there better options?
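For reference, here is the projection that the forward pass above computes, written in notation introduced only for this note: V for m.v, \hat f for m.f_hat, and α for the constant 0.9 (m.alpha in the full code). The factor ∇V(x) is exactly the inner Zygote.gradient call, so every outer gradient has to differentiate through a gradient, i.e. nested (second-order) AD.

```latex
f(x) = \hat f(x) - \nabla V(x)\,
  \frac{\operatorname{ReLU}\!\big(\nabla V(x)^\top \hat f(x) + \alpha\, V(x)\big)}{\lVert \nabla V(x) \rVert_2^2},
\qquad f(0) = 0
```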

For completeness I’ll append the full code below:

using Flux, Zygote

data_dim = 2
################################################################################
# soft ReLU
function soft_relu(d)
    x -> max.(clamp.(sign.(x) .* 1/(2*d) .* x.^2, 0, d/2), x .- d/2)
end
################################################################################
# input convex neural network
struct ICNNLayer
    W
    U
    b
    act
end
# constructor
ICNNLayer(z_in::Integer, x_in::Integer, out::Integer, activation) =
    ICNNLayer(randn(out, z_in), randn(out, x_in), randn(out), activation)
# forward pass
(m::ICNNLayer)(z, x) = m.act(m.W*z + softplus.(m.U)*x + m.b)

# track params
Flux.@functor ICNNLayer

# Input Convex Neural Network (ICNN)
struct ICNN
    InLayer
    HLayer1
    HLayer2
    act
end
# constructor
ICNN(input_dim::Integer, layer_sizes::Vector, activation) = begin
    InLayer = Dense(input_dim, layer_sizes[1])
    HLayers = []
    if length(layer_sizes) > 1
        i = 1
        for out in layer_sizes[2:end]
            push!(HLayers, ICNNLayer(layer_sizes[i], input_dim, out, activation))
            i += 1
        end
        push!(HLayers, ICNNLayer(layer_sizes[end], input_dim, 1, activation))
    end
    ICNN(InLayer, HLayers[1], HLayers[2], activation)
end
# forward pass
(m::ICNN)(x) = begin
    z = m.act(m.InLayer(x))
    z = m.HLayer1(z, x)
    z = m.HLayer2(z, x)
    return z
end
Flux.@functor ICNN

################################################################################
# Lyapunov Function
struct Lyapunov
    icnn
    act
    d
    eps
end
# constructor
Lyapunov(input_dim::Integer; d=0.1, eps=1e-3, layer_sizes::Vector=[32,32], act=soft_relu(d)) = begin
    icnn = ICNN(input_dim, layer_sizes, act)
    Lyapunov(icnn, act, d, eps)
end
# forward pass
(m::Lyapunov)(x) = begin
    g = m.icnn(x)
    g0 = m.icnn(zeros(size(x)))
    z = m.act(g - g0) .+ m.eps * (x'*x)
    return z
end
Flux.@functor Lyapunov

################################################################################
# dynamics
struct StableDynamics
    v
    f_hat
    alpha
    nr_delays
    τ
    p
    grad
end

# constructor
StableDynamics(;nr_delays=0, τ=1, p=1.1,alpha=0.9, act=soft_relu(0.1), grad="zygote_reverse") = begin
    v = Lyapunov(data_dim, act=act)
    # v = Chain(Dense(data_dim, 10, tanh), Dense(10, 1))
    f_hat = Chain(Dense((data_dim) * (nr_delays + 1), 32, tanh),
               Dense(32, 32, tanh),
               Dense(32, data_dim))
    StableDynamics(v, f_hat, alpha, nr_delays, τ, p, grad)
end
# forward pass
(m::StableDynamics)(x) = begin
    if x == zero(x)
        return zero(x)  # keep the return type consistent with the nonzero-x branch
    end
    vx = Zygote.gradient(x->sum(m.v(x)), x)[1]

    return m.f_hat(x) - vx * relu(vx'*m.f_hat(x) + m.alpha * m.v(x)[1]) / sum(abs2, vx)
end
Flux.@functor StableDynamics

# example
x = [1.0,1.0]
model = StableDynamics()
p, re = Flux.destructure(model)
f = (x,p) -> re(p)(x)
f(x,p)
@time for i in 1:1000
    gradient(x->sum(f(x,p)),x.+randn(size(x)))
    gradient(p->sum(f(x.+randn(size(x)),p)),p)
end

Thanks in advance for any help!

Your structs aren’t typed, so everything here is type-unstable. I would suggest reading the Julia performance tips.
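A minimal sketch of what "typed" means here, using made-up layer structs rather than the ones from the question: fields without annotations are `Any`, and annotating a field with a parametric type but leaving out its parameters (e.g. `Matrix`, or `Dense` / `ICNN` in Flux code) is still abstract. Putting the concrete types into type parameters lets the compiler infer the forward pass. Only Base Julia is used; `UntypedLayer`, `HalfTypedLayer`, `TypedLayer` and `forward` are hypothetical names.

```julia
using LinearAlgebra

# Same forward pass, three ways of declaring the fields (illustrative names).
struct UntypedLayer
    W          # no annotation: field type is Any
    b
end

struct HalfTypedLayer
    W::Matrix  # Matrix without its element type is still an abstract UnionAll
    b::Vector
end

struct TypedLayer{M<:AbstractMatrix, V<:AbstractVector}
    W::M       # concrete type fixed by the type parameters
    b::V
end

# Identical forward pass for all three structs
forward(m, x) = m.W * x .+ m.b

W = randn(Float32, 3, 2)
b = randn(Float32, 3)
x = randn(Float32, 2)

u = UntypedLayer(W, b)
h = HalfTypedLayer(W, b)
t = TypedLayer(W, b)

# Which field types can the compiler see?
@show fieldtypes(UntypedLayer)
@show isconcretetype(fieldtype(HalfTypedLayer, :W))
@show fieldtypes(typeof(t))

# Inferred return type of the forward pass for each struct
for T in (UntypedLayer, HalfTypedLayer, typeof(t))
    @show T Base.return_types(forward, (T, Vector{Float32}))
end
```

The fully parameterized version is the pattern the Julia performance tips recommend for struct fields; a quick check with `isconcretetype` or `@code_warntype` shows whether a given struct actually achieves it.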

Also, your code doesn’t show an ODE solve at all? This is just a comparison of Flux to PyTorch without the ODEs involved?


@ChrisRackauckas Thanks a lot for your answer!

Also, your code doesn’t show an ODE solve at all? This is just a comparison of Flux to PyTorch without the ODEs involved?

Yes, sorry about the title; it was misleading, so I changed it. I do want to use this function in a neural ODE and need the gradients for the adjoint equation. When I looked at the computation time, it turned out that the gradient calls are really slow compared to the PyTorch implementation (without the neural ODE). So it would be great if I could speed this up a bit, so that the adjoint method becomes more efficient :slight_smile:

Your structs aren’t typed, so everything here is type-unstable. I would suggest reading the Julia performance tips.

You’re right. I followed the description of custom layers in the Flux documentation and completely forgot about the types. I have now specified them, which gave me a speedup of approx. 20% (still far off PyTorch performance). This is the updated code:

cd(@__DIR__)
using Pkg; Pkg.activate("../../."); Pkg.instantiate(); using Revise
using Flux, Zygote

data_dim = 2
################################################################################
# soft ReLU versions
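function soft_relu(x)
    d = Float32(0.1)
    oftype(x, ifelse(x <= d, ifelse(x>=0, 1/(2*d) * x^2, zero(x)), x - d/2))
end
using Zygote: @adjoint
const d = 0.1f0  # const Float32 global, same value as the d inside soft_relu
# derivative of soft_relu: x/d on [0, d] (since d/dx x^2/(2d) = x/d), 1 for x > d, 0 for x < 0
@adjoint soft_relu(x) = soft_relu(x), y -> (oftype(x, y*ifelse(x <= d, ifelse(x>=0, x/d, zero(x)), one(x))),)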
################################################################################
# input convex neural network
struct ICNNLayer{R<:Array{Float32},S<:Array{Float32}, T<:Array{Float32}, F}
    W::R
    U::S
    b::T
    act::F
end
# constructor
ICNNLayer(z_in::Integer, x_in::Integer, out::Integer, activation) =
    ICNNLayer(randn(Float32, out, z_in), randn(Float32, out, x_in), randn(Float32, out), activation)
# forward pass
(m::ICNNLayer)(z::AbstractArray, x::AbstractArray) = m.act.(m.W*z + softplus.(m.U)*x + m.b)
# track params
Flux.@functor ICNNLayer

# Input Convex Neural Network (ICNN)
struct ICNN{I, H1, H2, F}
    InLayer::I   # concrete layer types as parameters (Dense/ICNNLayer alone are abstract UnionAlls)
    HLayer1::H1
    HLayer2::H2
    act::F
end
# constructor
ICNN(input_dim::Integer, layer_sizes::Vector, activation) = begin
    InLayer = Dense(input_dim, layer_sizes[1])
    HLayer1 = ICNNLayer(layer_sizes[1], input_dim, layer_sizes[2], activation)
    HLayer2 = ICNNLayer(layer_sizes[2], input_dim, 1, activation)
    ICNN(InLayer, HLayer1, HLayer2, activation)
end
# forward pass
(m::ICNN)(x::AbstractArray) = begin
    z = m.act.(m.InLayer(x))
    z = m.HLayer1(z, x)
    z = m.HLayer2(z, x)
    return z
end
Flux.@functor ICNN  # needed so that destructure also collects the ICNN parameters
################################################################################
# Lyapunov Function
struct Lyapunov{I, S, T<:AbstractFloat, U<:AbstractFloat}
    icnn::I   # concrete ICNN type as a parameter (ICNN alone is a UnionAll)
    act::S
    d::T
    eps::U
end
# constructor
Lyapunov(input_dim::Integer; d=0.1, eps=1e-3, layer_sizes::Vector=[32,32], act=soft_relu) = begin
    icnn = ICNN(input_dim, layer_sizes, act)
    Lyapunov(icnn, act, Float32(d), Float32(eps))
end
# forward pass
(m::Lyapunov)(x::AbstractArray) = begin
    g = m.icnn(x)
    g0 = m.icnn(zero(x))
    z = m.act.(g - g0) .+ m.eps * x'*x
    return z
end
Flux.@functor Lyapunov
################################################################################
# dynamics
struct StableDynamics{V, S, T<:AbstractFloat}
    v::V      # concrete Lyapunov type as a parameter (Lyapunov alone is a UnionAll)
    f_hat::S
    alpha::T
end
# constructor
StableDynamics(data_dim;alpha=0.9, act=soft_relu) = begin
    v = Lyapunov(data_dim, act=act)
    f_hat = Chain(Dense(data_dim, 32, tanh),
               Dense(32, 32, tanh),
               Dense(32, data_dim))
    StableDynamics(v, f_hat, Float32(alpha))
end
# forward pass
(m::StableDynamics)(x::AbstractArray) = begin
    if x == zero(x)
        return zero(x)
    end
    vx = Zygote.gradient(x->sum(m.v(x)), x)[1]
    return m.f_hat(x) - vx * relu(vx'*m.f_hat(x) + m.alpha * m.v(x)[1]) / sum(abs2, vx)
end
Flux.@functor StableDynamics

# example
x = Array{Float32}([1.0,1.0])
model = StableDynamics(2)
p, re = Flux.destructure(model)
f = (x,p) -> re(p)(x)
f(x,p)
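# warm-up calls (compilation) before timing; Float32 noise avoids promoting x to Float64
gradient(x->sum(f(x,p)), x .+ randn(Float32, size(x)))
gradient(p->sum(f(x .+ randn(Float32, size(x)), p)), p)
@time for i in 1:100
    gradient(x->sum(f(x,p)), x .+ randn(Float32, size(x)))
    gradient(p->sum(f(x .+ randn(Float32, size(x)), p)), p)
end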
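Since custom adjoints like the one for soft_relu are written by hand, a cheap sanity check is to compare them against a central finite difference at a few points away from the kinks at 0 and d. This is a minimal Base-Julia sketch; `soft_relu_ref`, `soft_relu_deriv`, `central_diff` and `dwidth` are hypothetical names, and the piecewise definition is the one used in the thread (the derivative of x^2/(2d) is x/d).

```julia
# Piecewise soft ReLU with smoothing width d, as in the thread:
#   0 for x < 0,  x^2/(2d) for 0 <= x <= d,  x - d/2 for x > d
soft_relu_ref(x, d) = x < 0 ? zero(x) : (x <= d ? x^2 / (2d) : x - d / 2)

# Hand-written derivative: x/d on [0, d], 1 above d, 0 below 0
soft_relu_deriv(x, d) = x < 0 ? zero(x) : (x <= d ? x / d : one(x))

# Central finite difference as an independent reference
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

dwidth = 0.1
# Test points chosen away from the kinks at 0 and dwidth
for x in (-0.05, 0.03, 0.07, 0.5)
    analytic = soft_relu_deriv(x, dwidth)
    numeric  = central_diff(t -> soft_relu_ref(t, dwidth), x)
    println("x = $x: analytic = $analytic, finite difference = $numeric")
end
```

A factor-of-two slip in the middle branch (e.g. 2x/d instead of x/d) shows up immediately in such a comparison.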

Any suggestions for further optimizations? Could a reason be my naive matrix multiplications in the custom ICNNLayer?

Yeah, no worries; it’s just that a title with ODEs is going to scare away pure ML developers like @dhairyagandhi96, who could really answer the question you have here.


Yeah sure, thanks for the hint.

The typed version of the code works now (and I edited my last post). So there is a speedup, but it’s still considerably slower than PyTorch.

Just curious: what is the speedup vs. PyTorch now, after you have added types (and other performance improvements)?

Am I understanding it correctly that your forward pass uses a gradient call?

For 1000 gradient calls it is still a factor of 10: in Julia it takes 23.46 s compared to 2.35 s in PyTorch.


Yes, exactly: it is the gradient of a custom neural network, and this gradient seems to be the problem. Without that gradient call in the forward pass, the backward pass is as fast as in PyTorch.

Keno recently gave a talk on Diffractor.jl, the new compiler-based AD that should come out soon, and it focused on efficient higher-order AD. Why? Because we know that Zygote has major performance problems with nesting. Essentially, chaining two code-generation passes before optimization seems to produce code large enough that the optimization heuristics can fail. In other words, Zygote is just slow in the case you’re looking at here. It’s sad, and the solution is described in the talk, but it won’t be ready right now.

If this is for physics-informed neural networks, note from the talk that PINNs are precisely what motivated this line of R&D, and the optics formulation is a much better answer than what we had before. FWIW, doing mixed mode, i.e. computing the PDE derivatives with forward mode while the loss is differentiated in reverse mode, is an asymptotically good strategy anyway, so that’s the workaround I’d recommend for now.
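A rough cost sketch of why this forward/reverse split fits the model in the question (a standard AD cost argument with constants omitted, not taken from the talk): with only n = data_dim = 2 inputs, the inner gradient ∇V needs n forward (JVP) passes, one per basis direction e_i, while the outer gradient with respect to the many parameters p needs a single reverse pass. Here L(p) = Σ_i f_i(x, p) is the scalar loss from the benchmark loop.

```latex
\nabla_x V(x) = \begin{pmatrix} J_V(x)\,e_1 \\ \vdots \\ J_V(x)\,e_n \end{pmatrix},
\qquad
\mathrm{cost}\big(\nabla_x V\big) = \mathcal{O}\big(n \cdot \mathrm{cost}(V)\big)
\quad \text{(forward mode, one JVP per } e_i\text{)}

\nabla_p L(p), \quad L(p) = \sum_i f_i(x, p),
\qquad
\mathrm{cost}\big(\nabla_p L\big) = \mathcal{O}\big(\mathrm{cost}(L)\big)
\quad \text{(reverse mode, one pass for all of } p\text{)}
```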


Thanks a lot for this answer and the reference to the talk. This helps a lot! Using forward_jacobian did not work due to a mutation error in Zygote’s forward_jacobian function. However, implementing forward-mode differentiation for the inner gradients from scratch did the trick. Now performance is even better than in PyTorch, and ReverseDiff works again too, which is nice for differentiating through the ODE solvers :slight_smile:
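The code for the from-scratch forward mode was not posted, so here is a minimal sketch, under stated assumptions, of what it can look like: a dual number carries a value and one directional derivative, and the gradient of a scalar V: R^n -> R is assembled from n forward passes seeded with e_1, ..., e_n. The `Dual` type, `forward_gradient`, and the toy `V` with weights `W`, `b` are hypothetical and only cover `+`, `-`, `*` and `tanh`; ForwardDiff.jl provides a complete, tested implementation of the same idea.

```julia
using LinearAlgebra

# Minimal forward-mode AD via dual numbers (illustrative only).
struct Dual{T<:Real} <: Number
    val::T   # primal value
    der::T   # directional derivative (tangent)
end

# Propagation rules for the few primitives the toy function needs
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.tanh(a::Dual) = (t = tanh(a.val); Dual(t, (1 - t * t) * a.der))

# Let plain real numbers (weights, biases) mix with duals via promotion
Base.convert(::Type{Dual{T}}, x::Real) where {T} = Dual(convert(T, x), zero(T))
Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T, S<:Real} = Dual{promote_type(T, S)}

# Gradient of a scalar function V with n forward passes,
# one per seed direction e_i (tangent 1 in component i, 0 elsewhere).
function forward_gradient(V, x::AbstractVector{T}) where {T<:Real}
    n = length(x)
    g = zeros(T, n)
    for i in 1:n
        xd = [Dual(x[j], j == i ? one(T) : zero(T)) for j in 1:n]
        g[i] = V(xd).der
    end
    return g
end

# Toy stand-in for the inner network (hypothetical weights)
W = [0.5 -1.0; 0.3 0.8; -0.7 0.2]
b = [0.1, -0.2, 0.05]
V(x) = (h = tanh.(W * x .+ b); sum(h .* h))

x = [0.4, -0.3]
g_fwd = forward_gradient(V, x)

# Closed form for comparison: grad V = 2 W' (h .* (1 .- h.^2)), h = tanh.(W*x .+ b)
h = tanh.(W * x .+ b)
g_exact = 2 .* (W' * (h .* (1 .- h .^ 2)))
```

In the mixed-mode setup only this inner gradient uses duals; the outer gradient with respect to p is still taken in reverse mode, which then has to differentiate through the dual-number arithmetic.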


Sorry for spamming with this. Unfortunately, ReverseDiff still causes problems. Gradients of the ODE function with respect to the input and the parameters work, but when I try to differentiate through the solver, this error occurs:

MethodError: ReverseDiff.ForwardOptimize{typeof(+)}(+)(::ReverseDiff.TrackedReal{Float64,Float32,Nothing}, ::ReverseDiff.TrackedReal{Float64,Float64,Nothing}) is ambiguous. Candidates:
  (self::ReverseDiff.ForwardOptimize{F})(x::Real, t::ReverseDiff.TrackedReal{V,D,O} where O) where {F, V, D} in ReverseDiff at /home/andrschl/.julia/packages/ReverseDiff/NoIPU/src/macros.jl:109
  (self::ReverseDiff.ForwardOptimize{F})(t::ReverseDiff.TrackedReal{V,D,O} where O, x::Real) where {F, V, D} in ReverseDiff at /home/andrschl/.julia/packages/ReverseDiff/NoIPU/src/macros.jl:121
Possible fix, define
  (::ReverseDiff.ForwardOptimize{F})(::ReverseDiff.TrackedReal{V,D,O} where O, ::ReverseDiff.TrackedReal{V,D,O} where O) where {F, V, D, V, D}

It would be great if I could get it to run with ReverseDiff. Any ideas how to fix this?

Make all of your numbers either Float64 or Float32. That’s usually easier than mixing them.
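The two TrackedReal arguments in the error above differ only in a Float32 vs Float64 type parameter. A minimal Base-Julia sketch of how such mixing sneaks in and one way to normalize it: a Float64 time span combined with a Float32 state silently promotes to Float64. The variable names (`u0`, `p`, `tspan`) and the helper `to32` are hypothetical; no solver or AD package is involved here.

```julia
# In Base Julia, mixing precisions silently promotes to Float64.
a = 1.0f0            # Float32
b = 0.1              # Float64 literal
c = a + b            # promoted to Float64

# Hypothetical ODE setup values (illustrative names, no solver involved)
u0    = Float32[1.0, 1.0]      # initial state
p     = randn(Float32, 10)      # parameters, e.g. as returned by destructure
tspan = (0.0, 1.0)              # Float64 time span: easy to overlook

# Small helper to convert everything to a single precision
to32(x::Real) = Float32(x)
to32(x::AbstractArray{<:Real}) = Float32.(x)
to32(x::Tuple) = map(to32, x)

tspan32 = to32(tspan)

@show typeof(c)                      # Float64
@show eltype(u0 .* tspan[2])         # Float64: Float32 state times Float64 time
@show eltype(u0 .* tspan32[2])       # Float32
```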


Thanks! I forgot to change the time to Float32… Looks like it works now :ok_hand: It seems like Zygote gives wrong gradients in this example: ReverseDiff and FiniteDifferences approximately agree, but Zygote does not.

That would be good to isolate. Can you share your final code here for @dhairyagandhi96? Silent errors are the ones we need to investigate the most.


Sorry, please forget about this. It was due to an error of mine in a custom adjoint definition. However, I have been facing another issue with Zygote. Suppose f(x,p) is a neural network with a linear layer on top and df(x,p) is its derivative with respect to the input x. Then, in d/dp sum(df(x,p)), Zygote drops the (zero) derivative w.r.t. the top-layer bias instead of returning 0 for it. I assume this is because Zygote returns nothing for the gradient of a constant function. Here is an example:

using Flux, Zygote, ReverseDiff
using FiniteDifferences, Test
m = Chain(Dense(2,2,tanh), Dense(2,1))
p,re = Flux.destructure(m)
f = (x,p)->re(p)(x)
x = rand(Float32, 2)  # example input (x was not defined in the original snippet)
df1 = (x,p) -> ReverseDiff.gradient(x->sum(f(x,p)), x)
df2 = (x,p) -> Zygote.gradient(x->sum(f(x,p)), x)[1]
@test isapprox(df1(x,p),df2(x,p))
g1 = ReverseDiff.gradient(p-> sum(df1(x,p)),p) # this doesn't work -> g1 = 0
g2 = Zygote.gradient(p-> sum(df2(x,p)),p)[1] # this drops the gradient component for the top-layer bias (-> since sum(df2(x,p)) does not depend on it)
g3 = grad(central_fdm(5,1),p-> sum(df1(x,p)),p)[1]
@test isapprox(g2, g3[1:end-1], atol=1e-3)
@test size(g3) == size(g2) # fails
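Until this is handled upstream, one possible workaround is to keep gradients structured per parameter and replace `nothing` (which, as suspected above, marks a gradient that is structurally zero) with zeros of the parameter's shape before flattening, so the flat gradient always has the same length as the flat parameter vector. This is a Base-Julia sketch with placeholder values; `params`, `grads`, `fill_nothing` and `flatten_grads` are hypothetical names, not Flux or Zygote API.

```julia
# Hypothetical per-parameter gradients for Chain(Dense(2,2,tanh), Dense(2,1)),
# with `nothing` for the top-layer bias that sum(df(x,p)) does not depend on.
params = (W1 = randn(2, 2), b1 = randn(2), W2 = randn(1, 2), b2 = randn(1))
grads  = (W1 = randn(2, 2), b1 = randn(2), W2 = randn(1, 2), b2 = nothing)

# Treat `nothing` as "zero gradient": substitute zeros of the parameter's shape
fill_nothing(g, p) = g === nothing ? zero(p) : g

# Flatten in the same order as the parameters, so the lengths always match
function flatten_grads(grads::NamedTuple, params::NamedTuple)
    filled = map(fill_nothing, grads, params)
    return reduce(vcat, [vec(g) for g in filled])
end

flat_p = reduce(vcat, [vec(q) for q in params])
flat_g = flatten_grads(grads, params)
```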