Zygote is very slow for a log barrier function

I’m trying to implement a log barrier function for my gradient descent code.

It has an easy derivative: the sum of elementwise logs differentiates to an elementwise reciprocal.
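Written out, and assuming the definition used in the script below (scalar bounds $l < v_i < u$ and a scale factor of $1/100$), the barrier and its gradient are:

$$
\phi(v) = -\frac{1}{100} \sum_i \left[ \log(u - v_i) + \log(v_i - l) \right],
\qquad
\frac{\partial \phi}{\partial v_i} = -\frac{1}{100} \left[ \frac{1}{v_i - u} + \frac{1}{v_i - l} \right].
$$

Here $\phi$, $l$ and $u$ are shorthand for `log_barrier`, `lower_bound` and `upper_bound`.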
Here is the script for testing:

using Zygote
using BenchmarkTools


function log_barrier(v,lower_bound,upper_bound)
    return  -1/100*sum(log.(-(v .-upper_bound))) + -1/100*sum(log.(v .- lower_bound))
end

function log_barrier_derivative(v, lower_bound, upper_bound)
    return -1/100 .*(1 ./ (v.-upper_bound) + 1 ./ (v.-lower_bound))
end

v = [1,2,3,4,5,6,7,8]
upper_bound = 10
lower_bound = 0

#Warm up JIT

display(@benchmark gradient(log_barrier,v,lower_bound,upper_bound))
println(" ")
display(@benchmark log_barrier_derivative(v, lower_bound, upper_bound))

and the output:

byakuya@seireitei ~/asdfggg> julia asdf.jl
BenchmarkTools.Trial: 
  memory estimate:  5.08 KiB
  allocs estimate:  156
  --------------
  minimum time:     11.210 μs (0.00% GC)
  median time:      13.368 μs (0.00% GC)
  mean time:        14.649 μs (1.88% GC)
  maximum time:     2.794 ms (98.78% GC)
  --------------
  samples:          10000
  evals/sample:     1 
BenchmarkTools.Trial: 
  memory estimate:  576 bytes
  allocs estimate:  4
  --------------
  minimum time:     158.838 ns (0.00% GC)
  median time:      162.543 ns (0.00% GC)
  mean time:        191.749 ns (7.44% GC)
  maximum time:     2.385 μs (92.14% GC)
  --------------
  samples:          10000
  evals/sample:     792                                  

What happened? How can a handwritten derivative be about 70x faster than Zygote? Is it possible to make Zygote, or another automatic differentiation package, faster so that I don’t need to derive all of my functions by hand?

Zygote’s handling of broadcasting isn’t ideal. Some ways to go faster:

# As in question, my computer:
@btime log_barrier($v, 0, 10) # 270.354 ns
@btime Zygote.gradient(log_barrier, $v, 0, 10) # 7.053 μs
@btime log_barrier_derivative($v, 0, 10) # 167.531 ns
7/0.167 # 42

# Bigger, even worse:
v1k = (1 .+ 8 .* rand(1000));
@btime log_barrier($v1k, 0, 10) # 15.148 μs 
@btime Zygote.gradient(log_barrier, $v1k, 0, 10) # 130.708 μs
@btime log_barrier_derivative($v1k, 0, 10) # 2.519 μs
130/2.52 # 51

# With custom gradient definition:
log_barrier_ad(v,lo,hi) = log_barrier(v,lo,hi)
Zygote.@adjoint function log_barrier_ad(v,lo,hi)
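    # The pullback returns one entry per argument; the bounds are treated as constants (nothing)
    log_barrier(v,lo,hi), dy -> (dy .* log_barrier_derivative(v,lo,hi), nothing, nothing)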
end
@btime Zygote.gradient(log_barrier_ad, $v, 0, 10) # 489.856 ns
@btime Zygote.gradient(log_barrier_ad, $v1k, 0, 10) # 18.032 μs -- Zygote overhead?

using ForwardDiff, Test # for small arrays

@btime ForwardDiff.gradient(v -> log_barrier(v, 0, 10), $v) # 828.703 ns -- faster
@btime ForwardDiff.gradient(v -> log_barrier(v, 0, 10), $v1k) # 17.080 ms -- slower

@test Zygote.gradient(log_barrier, v, 0, 10)[1] ≈ ForwardDiff.gradient(v -> log_barrier(v, 0, 10), v)

using Tracker # more overhead, but better broadcasting

@btime Tracker.gradient(log_barrier, $v, 0, 10) # 16.403 μs -- slower
@btime Tracker.gradient(log_barrier, $v1k, 0, 10) # 62.095 μs -- faster

using Tullio, LoopVectorization

function log_bar(v,lo,hi)
    @tullio s := log(($hi - v[i]) * (v[i] - $lo))
    -s/100
end
@test log_bar(v,0,10) ≈ log_barrier(v,0,10)
@test Zygote.gradient(log_bar, v, 0, 10)[1] ≈ ForwardDiff.gradient(v -> log_barrier(v, 0, 10), v)

# On small problems, still some overhead:
@btime log_bar($v, 0, 10) # 45.381 ns
@btime Zygote.gradient(log_bar, $v, 0, 10) # 6.631 μs -- surprisingly much overhead

@btime log_bar($v1k, 0, 10) # 2.084 μs, was 15
@btime Zygote.gradient(log_bar, $v1k, 0, 10) # 11.065 μs, was 130

#= # gradient looks about optimal:
julia> @tullio s := log(($hi - v[i]) * (v[i] - $lo)) verbose=true              
┌ Info: symbolic gradients
│   inbody =
│    1-element Array{Any,1}:
└     :(𝛥v[i] = 𝛥v[i] + conj(conj(𝛥ℛ[1]) * (-(inv(hi - v[i])) + inv(v[i] - lo))))
[ Info: running LoopVectorization actor 
=#

I recently ran into a case where Zygote’s AD was slow and tended to run out of memory. I was able to handle this by writing the backpropagation by hand. How do we handle this from a community perspective? Do we make PRs for Zygote with custom adjoints and add functions to Flux? What I was doing is very esoteric, so a PR would have no value, but where is the line at which “you should make a PR”?

You can use ChainRulesCore.jl to write a custom adjoint rule for just the piece of your calculation that Zygote has trouble with, and then that can plug into Zygote to do the rest of the differentiation/composition.
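As a rough sketch of what that could look like for the log barrier from this thread: the wrapper name `log_barrier_cr` is hypothetical, and the rule assumes the two bounds are constants, so it returns `NoTangent()` for them instead of computing their gradients.

```julia
using ChainRulesCore, Zygote

# Same function and handwritten derivative as earlier in the thread.
log_barrier(v, lo, hi) = -1/100 * sum(log.(-(v .- hi))) + -1/100 * sum(log.(v .- lo))
log_barrier_derivative(v, lo, hi) = -1/100 .* (1 ./ (v .- hi) + 1 ./ (v .- lo))

# Hypothetical wrapper, so the rule below does not change how the plain
# log_barrier is differentiated.
log_barrier_cr(v, lo, hi) = log_barrier(v, lo, hi)

function ChainRulesCore.rrule(::typeof(log_barrier_cr), v, lo, hi)
    y = log_barrier(v, lo, hi)
    function log_barrier_cr_pullback(dy)
        dv = dy .* log_barrier_derivative(v, lo, hi)
        # One tangent for the function itself, then one per argument.
        # Assumption: lo and hi are fixed, so they get no tangent.
        return NoTangent(), dv, NoTangent(), NoTangent()
    end
    return y, log_barrier_cr_pullback
end

v = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
grad_v = Zygote.gradient(log_barrier_cr, v, 0.0, 10.0)[1]
```

Zygote should pick up the `rrule` for `log_barrier_cr` and keep differentiating whatever surrounds the call as usual, so only the troublesome piece needs a handwritten rule.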
