Zygote gradient results in slow Tuple getindex calls

Hey,

I’m working on an optimization problem and I’m using Zygote to calculate the gradients (code).

Initially, many of the chain rules were implemented by hand, but then I switched so that Zygote handles this.
Unfortunately, performance seemed to drop, even though I defined the adjoints of the computationally intensive parts explicitly. I would therefore have expected Zygote to only have to apply a few chain rules and summations.

To find the source of the performance drop, I profiled the code.

It is reasonable that fft.jl takes a lot of time, since all the computationally expensive calculations are FFTs. What puzzles me more are tuple.jl and getindex. Clicking on getindex in the profile shows that it is also related to tuples.

In my source code I don’t use tuples directly, except in the adjoint definitions, for example

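using ChainRulesCore

# conv is passed in as a function; otf is the optical transfer function
function conv_aux(conv, rec, otf)
    return conv(rec, otf)
end

function ChainRulesCore.rrule(::typeof(conv_aux), conv, rec, otf)
    Y = conv_aux(conv, rec, otf)
    function conv_aux_pullback(barx)
        # The adjoint of convolving with otf is convolving with conj(otf).
        # No tangents for conv_aux itself, for conv, or for otf.
        # (NoTangent() replaces the older NO_FIELDS / DoesNotExist() names.)
        return NoTangent(), NoTangent(), conv(unthunk(barx), conj(otf)), NoTangent()
    end
    return Y, conv_aux_pullback
end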

and here.
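A cheap way to check a hand-written pullback such as conv_aux_pullback is the dot-product (adjoint) test: for random real x and y, ⟨A x, y⟩ must equal ⟨x, A* y⟩, where A* is what the pullback applies. The sketch below assumes that conv(rec, otf) computes real(ifft(fft(rec) .* otf)), i.e. that otf is already in Fourier space; that definition is not shown in the post. The naive DFT helpers dft, idft and conv_fourier are hypothetical stand-ins for FFTW so that the example only needs the standard library, and it works in 1D for brevity.

```julia
using LinearAlgebra  # dot
using Random         # Random.seed!

# Naive O(n^2) DFT and inverse DFT (stand-ins for FFTW's fft/ifft).
dft(x) = [sum(x[k] * cis(-2π * (j - 1) * (k - 1) / length(x)) for k in eachindex(x))
          for j in eachindex(x)]
idft(X) = [sum(X[k] * cis(2π * (j - 1) * (k - 1) / length(X)) for k in eachindex(X)) / length(X)
           for j in eachindex(X)]

# Assumed forward operator: multiply by otf in Fourier space, keep the real part.
conv_fourier(rec, otf) = real(idft(dft(rec) .* otf))

Random.seed!(42)
n = 16
otf = randn(ComplexF64, n)  # arbitrary complex transfer function
x = randn(n)
y = randn(n)

# <A x, y> using the forward operator ...
lhs = dot(conv_fourier(x, otf), y)
# ... must equal <x, A* y> using the pullback rule conv(barx, conj(otf)).
rhs = dot(x, conv_fourier(y, conj(otf)))

println(isapprox(lhs, rhs; rtol = 1e-8, atol = 1e-10))
```

If the two numbers disagree, the pullback does not match the forward operator, independent of any performance question.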

I’m not really experienced in profiling or using AD, and I would be really happy if someone could give me some hints on how to improve performance and on why the tuples seem to cause issues.

Thanks a lot,

Felix


I think I found the reason for the many tuple calls.
They were connected to a max(0.1, conv_aux(conv, rec, otf)), and I could reproduce this in a smaller example.

From a mathematical point of view it is plausible that this causes trouble, because max(0.1, x) is not differentiable at x = 0.1.
I would still be interested in why Zygote ends up making so many tuple calls.
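For a scalar, the derivative of max(c, x) is 1 where x > c and 0 where x < c; at the kink an AD system just has to return some conventional value, so the non-differentiable point by itself is unlikely to explain a slowdown. A more plausible source of many tuple getindex calls is how generic reverse-mode broadcasting can work when there is no array-level rule: roughly, each element gets its own (value, pullback) tuple in the forward pass, and each pullback returns a tuple of partials in the backward pass, which then has to be unzipped. The plain-Julia sketch below is a simplified model of that pattern, not Zygote’s actual implementation; max0_with_pullback, naive_broadcast_gradient and fused_gradient are hypothetical names.

```julia
# Hypothetical scalar rule: max(0.0, x) together with its pullback closure.
function max0_with_pullback(x::Float64)
    y = max(0.0, x)
    # Partial derivative: 1 where x > 0, 0 otherwise (0 chosen at the kink).
    pullback(ybar) = (x > 0.0 ? ybar : 0.0,)  # 1-tuple of partials
    return (y, pullback)
end

# Simplified model of element-by-element reverse-mode broadcasting.
function naive_broadcast_gradient(xs::Vector{Float64}, ybars::Vector{Float64})
    # Forward pass: one (value, closure) tuple stored per element.
    y_and_pb = map(max0_with_pullback, xs)
    ys = map(first, y_and_pb)                      # tuple getindex per element
    # Backward pass: call every stored closure, one tuple of partials each.
    partials = map((t, ybar) -> t[2](ybar), y_and_pb, ybars)
    xbars = map(p -> p[1], partials)               # tuple getindex again (unzip)
    return ys, xbars
end

# Array-level rule for comparison: one fused pass, no per-element tuples.
fused_gradient(xs, ybars) = ifelse.(xs .> 0.0, ybars, 0.0)

ys, xbars = naive_broadcast_gradient([-1.0, 0.5, 2.0], [1.0, 1.0, 1.0])
println(ys, " ", xbars)
```

The fused version computes the same gradient in a single broadcast, which is roughly what a dedicated rule (such as the one for relu mentioned further down) makes possible.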

Thanks, and sorry that I haven’t provided a small standalone example yet,

Felix


Hey,

I can now provide a minimal working example:

using Zygote
using FFTW
using Profile
using StatProfilerHTML

const N = 1000
# psf is read inside f, so make it const to avoid an untyped global
const psf = randn(N, N)
img = randn(N, N)

function conv(img, psf)
    f_psf = fft(psf)
    f_img = fft(img)
    conv_res = real(ifft(f_psf .* f_img))
    # return max.(0.0, conv_res)  # the version with max. shows the tuple calls
    return conv_res
end

f(img) = sum(conv(img, psf))

Zygote.gradient(f, img)[1]  # first call includes compilation
@profilehtml Zygote.gradient(f, img)[1]

The results with max.:

And without:

So I’m not sure whether this is the right place to discuss it.

Thanks,

Felix

I think what you’re seeing here is that Zygote’s broadcasting isn’t very fast, except in cases where it has a special rule for that function:

julia> using BenchmarkTools, Zygote

julia> @btime gradient(x -> sum(1 .+ max.(0.0,x)), mat)  setup=(mat=randn(100,100));
  990.097 μs (50054 allocations: 1.84 MiB)

julia> using Flux # Flux.relu has a special rule

julia> @btime gradient(x -> sum(1 .+ relu.(x)), mat)  setup=(mat=randn(100,100));
  17.888 μs (6 allocations: 234.61 KiB)

julia> using Tullio # handles gradients without Broadcast.

julia> @btime gradient(x -> sum(@tullio y[i,j] := 1 + max(0.0,x[i,j])), mat)  setup=(mat=randn(100,100));
  22.290 μs (18 allocations: 156.80 KiB)

julia> using Tracker # different broadcasting, perhaps using ForwardDiff

julia> @btime Tracker.gradient(x -> sum(1 .+ max.(0.0,x)), mat)  setup=(mat=randn(100,100));
  94.865 μs (173 allocations: 709.87 KiB)
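Following the point that Zygote is fast when a function has its own rule, one option is to wrap the clamping in an array-level function with a hand-written rrule, in the same style as conv_aux in the first post, so that max is never differentiated element by element. A minimal sketch, assuming ChainRulesCore ≥ 1.0 (for NoTangent); clamp_below is a hypothetical name, and the threshold lo is treated as a constant with no gradient.

```julia
using ChainRulesCore

# Hypothetical helper: elementwise max(lo, x) as a single array-level call.
clamp_below(x::AbstractArray{<:Real}, lo::Real) = max.(lo, x)

function ChainRulesCore.rrule(::typeof(clamp_below), x::AbstractArray{<:Real}, lo::Real)
    y = clamp_below(x, lo)
    function clamp_below_pullback(ybar)
        ybar = unthunk(ybar)
        # Pass the cotangent through where x > lo, zero elsewhere (0 at the kink).
        xbar = ifelse.(x .> lo, ybar, zero(eltype(y)))
        # No tangent for the function itself; lo is treated as a constant.
        return NoTangent(), xbar, NoTangent()
    end
    return y, clamp_below_pullback
end

y, pb = rrule(clamp_below, [-1.0, 0.5, 2.0], 0.1)
println(y)
```

With Zygote loaded, something like gradient(x -> sum(1 .+ clamp_below(x, 0.0)), mat) should pick up this rule instead of generic broadcasting; how close it gets to the relu timing would have to be measured.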

Hey,

thanks for the details.

Is it max.(0.0, x) or 1 .+ ... that causes the problem? I guess the former, right?

Thanks,

Felix

Notice the crazy number of allocations in the first of your examples: 50054, i.e. about 5 per element of a 10000-element array. The other examples, which avoid Zygote’s generic broadcasting, have far fewer allocations.

I often observe a lot of allocations with Zygote, and so far I haven’t really figured out how to fix that properly.

Usually the forward passes are type stable in these situations.
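To check the forward pass in isolation, the standard library is enough: Test.@inferred throws if the return type cannot be inferred concretely, and @allocated reports the bytes allocated by one call. A minimal sketch; forward and measure_bytes are hypothetical names, and calling measure_bytes on the Zygote gradient instead would give the number to compare against.

```julia
using Test  # provides @inferred

# Hypothetical forward pass matching the benchmark expressions above.
forward(x) = sum(1 .+ max.(0.0, x))

# Run once to compile, then measure the allocations of a second call.
function measure_bytes(f, x)
    f(x)
    return @allocated f(x)
end

x = randn(100, 100)
y = @inferred forward(x)   # throws if forward is not type stable
bytes = measure_bytes(forward, x)
println("forward pass: ", bytes, " bytes")
```

The fused broadcast should allocate roughly one 100×100 Float64 array (about 80 kB), which makes the 1.84 MiB reported for the gradient call easy to put in perspective.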