Hey,

I’m working on an optimization problem and I’m using Zygote to calculate the gradients (code).

Initially, many of the chain rules were implemented manually, but then I switched to letting Zygote handle them.

Unfortunately, performance seemed to drop even though I defined the adjoints of the computationally intensive parts explicitly. I would have expected Zygote to only apply a few chain rules and summations on top of those.

To find the source of the slowdown, I profiled the code.

That `fft.jl` uses a lot of time is reasonable, since all the computationally expensive calculations are FFTs. I’m more puzzled by `tuple.jl` and `getindex`. Clicking on `getindex` showed that it is also connected to tuples.

In my source code I don’t use tuples directly, except in the adjoint definitions, for example:

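```
using ChainRulesCore

# Thin wrapper so that a custom rrule can be attached to the convolution step
function conv_aux(conv, rec, otf)
    return conv(rec, otf)
end

function ChainRulesCore.rrule(::typeof(conv_aux), conv, rec, otf)
    Y = conv_aux(conv, rec, otf)
    function conv_aux_pullback(barx)
        # Assuming conv multiplies the spectrum of its first argument by otf,
        # its adjoint is the same operation with conj(otf).
        # conv_aux, conv and otf get no tangent (otf is treated as a constant).
        return NoTangent(), NoTangent(), conv(unthunk(barx), conj(otf)), NoTangent()
    end
    return Y, conv_aux_pullback
end
```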

and here.
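A hand-written pullback like `conv_aux_pullback` can be checked with a dot-product test: for a linear map A and its adjoint, ⟨ȳ, A x⟩ = ⟨Aᵀ ȳ, x⟩. The sketch below assumes that `conv(rec, otf)` multiplies the spectrum of `rec` by `otf` and keeps the real part of the inverse transform. It is 1-D only and replaces FFTW's `fft`/`ifft` with a naive DFT so that it runs with the standard library alone; `dft`, `idft`, `conv_otf` and `conv_otf_pullback` are hypothetical helpers, not part of any package.

```julia
using LinearAlgebra  # dot
using Random         # seeded RNG for reproducibility

# Naive O(n^2) 1-D DFT and inverse DFT; hypothetical stand-ins for FFTW's fft/ifft
dft(x)  = [sum(x[j] * cis(-2π * (j - 1) * (k - 1) / length(x)) for j in eachindex(x)) for k in eachindex(x)]
idft(X) = [sum(X[k] * cis(2π * (j - 1) * (k - 1) / length(X)) for k in eachindex(X)) for j in eachindex(X)] ./ length(X)

# Assumed forward operator: multiply the spectrum of x by otf, transform back, keep the real part
conv_otf(x, otf) = real(idft(dft(x) .* otf))

# Pullback with the same structure as conv_aux_pullback: reuse the operator with conj(otf)
conv_otf_pullback(ybar, otf) = conv_otf(ybar, conj(otf))

Random.seed!(42)
n    = 16
x    = randn(n)
ybar = randn(n)
otf  = dft(randn(n))   # spectrum of a random real kernel (complex-valued)

lhs = dot(ybar, conv_otf(x, otf))           # <ybar, A x>
rhs = dot(conv_otf_pullback(ybar, otf), x)  # <A' ybar, x>
println(isapprox(lhs, rhs; rtol = 1e-8))    # true when the pullback is the adjoint
```

If the pullback used `otf` instead of `conj(otf)`, the two inner products would generally differ for a non-symmetric kernel.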

I’m not really experienced in profiling or using AD, and I would be really happy if someone could give me some hints on how to improve performance and on why the tuples seem to cause issues.

Thanks a lot,

Felix


I think I found the reason for the many tuple calls.

It was caused by a `max(0.1, conv_aux(conv, rec, otf))`; I could reproduce that in a smaller example.

From a mathematical point of view it’s plausible that this causes trouble, because `max(0.1, x)` is not differentiable at x = 0.1.
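The kink can be seen numerically: at x = 0.1 the slope from the left is 0 and the slope from the right is 1, so there is no unique derivative there and an AD tool has to return some value for it at that point (a subgradient). A small Base-only sketch with one-sided finite differences; `g` is a hypothetical name for `max(0.1, x)` viewed as a function of `x`:

```julia
# g is max(0.1, x) viewed as a function of its second argument
g(x) = max(0.1, x)

h = 1e-6
left  = (g(0.1) - g(0.1 - h)) / h   # slope from the left: exactly 0.0
right = (g(0.1 + h) - g(0.1)) / h   # slope from the right: approximately 1.0 (rounding)
println((left, right))
```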

I would still be interested in why Zygote ends up making so many tuple calls.

Thanks, and sorry that I didn’t provide a small standalone example,

Felix


Hey,

I can now provide a minimal working example:

```
using Zygote
using FFTW
using Profile
using StatProfilerHTML

N = 1000
global psf = randn((N, N))
global img = randn((N, N))

# FFT-based circular convolution of img with psf
function conv(img, psf)
    f_psf = fft(psf)
    f_img = fft(img)
    conv_res = real(ifft(f_psf .* f_img))
    #return max.(0.0, conv_res)  # variant with max.
    return conv_res
end

f(img) = sum(conv(img, psf));
Zygote.gradient(f, img)[1]               # first call includes compilation
@profilehtml Zygote.gradient(f, img)[1]  # profile a second, compiled call
```

The profiling results with `max.`:

[profile screenshot]

And without:

[profile screenshot]
I’m not sure whether this is the right place to discuss this.
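For the variant without `max.`, the gradient has a closed form that is handy for checking `Zygote.gradient(f, img)[1]`: summing a circular convolution gives `sum(img) * sum(psf)`, so the gradient with respect to `img` is an array filled with `sum(psf)`. The sketch below checks this in 1-D with a direct circular convolution instead of FFTs; `circconv`, `f_1d`, `img_1d` and `psf_1d` are hypothetical names, and the 2-D case of the MWE works the same way.

```julia
# Direct 1-D circular convolution (hypothetical helper). For real inputs it equals
# real(ifft(fft(img) .* fft(psf))) from the MWE, up to rounding.
function circconv(img, psf)
    n = length(img)
    return [sum(img[mod1(k - j + 1, n)] * psf[j] for j in 1:n) for k in 1:n]
end

img_1d = randn(8)
psf_1d = rand(8)   # positive entries, so sum(psf_1d) is safely away from zero
f_1d(x) = sum(circconv(x, psf_1d))

# Summing a circular convolution gives sum(img) * sum(psf) ...
println(isapprox(f_1d(img_1d), sum(img_1d) * sum(psf_1d); rtol = 1e-10, atol = 1e-10))

# ... so every partial derivative of f_1d equals sum(psf_1d). Because f_1d is
# linear in x, a finite-difference step of 1.0 is exact up to rounding.
e3 = zeros(8); e3[3] = 1.0
fd = f_1d(img_1d .+ e3) - f_1d(img_1d)
println(isapprox(fd, sum(psf_1d); rtol = 1e-10))
```

In the MWE itself, `Zygote.gradient(f, img)[1] ≈ fill(sum(psf), size(img))` should then hold.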

Thanks,

Felix

I think what you’re seeing here is that Zygote’s broadcasting isn’t very fast, except in cases where it has a special rule for that function:

```
julia> using Zygote, BenchmarkTools
julia> @btime gradient(x -> sum(1 .+ max.(0.0,x)), mat) setup=(mat=randn(100,100));
990.097 μs (50054 allocations: 1.84 MiB)
julia> using Flux # Flux.relu has a special rule
julia> @btime gradient(x -> sum(1 .+ relu.(x)), mat) setup=(mat=randn(100,100));
17.888 μs (6 allocations: 234.61 KiB)
julia> using Tullio # handles gradients without Broadcast.
julia> @btime gradient(x -> sum(@tullio y[i,j] := 1 + max(0.0,x[i,j])), mat) setup=(mat=randn(100,100));
22.290 μs (18 allocations: 156.80 KiB)
julia> using Tracker # different broadcasting, perhaps using ForwardDiff
julia> @btime Tracker.gradient(x -> sum(1 .+ max.(0.0,x)), mat) setup=(mat=randn(100,100));
94.865 μs (173 allocations: 709.87 KiB)
```
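The fast variants in these benchmarks avoid Zygote’s generic broadcasting path by providing a derivative for the whole array operation at once. A minimal Base-only sketch of what such a rule could look like for `max.(lo, x)`: `max_floor_with_pullback` is a hypothetical helper that returns the result together with a pullback, mirroring the `(y, pullback)` pair that `ChainRulesCore.rrule` returns. To make Zygote use it, it would still have to be wrapped in an `rrule` method like the `conv_aux` rule in the first post; that wrapper is not shown here.

```julia
# Hypothetical helper: elementwise max(lo, x) together with a hand-written pullback,
# returned as a (y, pullback) pair, similar in shape to what ChainRulesCore.rrule returns
function max_floor_with_pullback(lo::Real, x::AbstractArray{<:Real})
    y = max.(lo, x)                   # forward pass: one fused broadcast
    function max_floor_pullback(ybar)
        # Gradient passes through where x > lo and is zero elsewhere;
        # at x == lo this picks 0, which is one valid subgradient.
        return ybar .* (x .> lo)      # one fused broadcast, one output array
    end
    return y, max_floor_pullback
end

x = [-1.0, 0.5, 0.0, 2.0]
y, back = max_floor_with_pullback(0.0, x)
println(y)              # [0.0, 0.5, 0.0, 2.0]
println(back(ones(4)))  # [0.0, 1.0, 0.0, 1.0]
```

The design choice is that the pullback only needs `x` and a single fused broadcast, instead of storing a separate per-element pullback closure for every entry.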


Hey,

thanks for the details.

Is it `max.(0.0, x)` or `1 .+ ...` that causes the problems? I guess the first one, right?

Thanks,

Felix

Notice the crazy number of allocations in the first of your examples: 50054 allocations for a 10000-element array, i.e. about 5 per element. The other examples, which avoid Zygote’s generic broadcasting, allocate far less.

I often observe a lot of allocations with Zygote, and so far I haven’t really figured out how to fix that properly.

Usually the forward passes are type stable in these situations.
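Whether a forward pass is type-stable can be checked directly. One thing worth ruling out in the MWE is that `psf` is a non-const global, which makes `f` infer as `Any`. A minimal Base-only sketch with hypothetical names, comparing a non-const and a `const` global via `Base.return_types` (`@code_warntype f(img)` in the REPL shows the same information):

```julia
# A non-const global (like psf in the MWE) versus a const global
psf_nonconst = randn(4)
const PSF_CONST = randn(4)

g_nonconst(x) = sum(x .* psf_nonconst)
g_const(x)    = sum(x .* PSF_CONST)

# Inferred return types for a Vector{Float64} argument
println(Base.return_types(g_nonconst, (Vector{Float64},)))  # Any[Any]
println(Base.return_types(g_const,    (Vector{Float64},)))  # Any[Float64]
```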