Hey,
I’m working on an optimization problem and I’m using Zygote to calculate the gradients (code).
Initially, many of the chain rules were implemented by hand, but I then switched to letting Zygote handle them.
Unfortunately, performance seemed to drop even though I defined the computationally intensive adjoints explicitly. I would have expected Zygote to only need to apply a few chain rules and summations.
To find the source of the performance drop, I profiled the code.
It is reasonable that fft.jl takes a lot of time, since all the computationally expensive calculations are FFTs. I'm more puzzled by tuple.jl and getindex. Clicking on getindex showed that it is also connected to tuples.
In my source code I’m not using tuples directly except for the adjoint definitions. For example
function conv_aux(conv, rec, otf)
    return conv(rec, otf)
end
function ChainRulesCore.rrule(::typeof(conv_aux), conv, rec, otf)
    Y = conv_aux(conv, rec, otf)
    function conv_aux_pullback(barx)
        return NO_FIELDS, DoesNotExist(), conv(barx, conj(otf)), DoesNotExist()
    end
    return Y, conv_aux_pullback
end
and here.
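For reference, here is a short sketch of why the pullback above has the form conv(barx, conj(otf)). It assumes that conv(rec, otf) multiplies the Fourier transform of rec by otf and transforms back, with otf already given in Fourier space. That is an assumption about conv, which is not shown in this post. The symbols F, N and the bar notation are introduced only for this sketch.

```latex
% Assumption: conv(rec, otf) = F^{-1}( otf \odot F(rec) ), with F the unnormalized DFT over N samples.
\begin{aligned}
Y &= F^{-1}\,\operatorname{diag}(\mathrm{otf})\,F\,\mathrm{rec},
  \qquad F^{-1} = \tfrac{1}{N}F^{H} \\
\overline{\mathrm{rec}}
  &= \bigl(F^{-1}\operatorname{diag}(\mathrm{otf})\,F\bigr)^{H}\,\bar{Y}
   = F^{H}\operatorname{diag}(\overline{\mathrm{otf}})\,\tfrac{1}{N}F\,\bar{Y} \\
  &= F^{-1}\bigl(\overline{\mathrm{otf}} \odot F\,\bar{Y}\bigr)
   = \operatorname{conv}\bigl(\bar{Y},\ \operatorname{conj}(\mathrm{otf})\bigr)
\end{aligned}
```

Under that assumption the pullback costs the same two FFTs as the forward pass, so the rule itself should not be the bottleneck.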
I'm not very experienced with profiling or AD, and I would be really happy if someone could give me some hints on how to improve performance and on why the tuples seem to cause issues.
Thanks a lot,
Felix
I think I found the reason for the many calls to tuple.
It was connected to a max(0.1, conv_aux(conv, rec, otf)). I could reproduce that in a smaller example.
From a mathematical point of view it is plausible that this causes trouble, because the max function is not differentiable at 0.1.
I would still be interested in why Zygote ends up making so many tuple calls.
Thanks and sorry that I didn’t really provide a small standalone example,
Felix
Hey,
I can now provide a minimal working example:
using Zygote
using FFTW
using Profile
using StatProfilerHTML
N = 1000
global psf = randn((N, N))
global img = randn((N, N))
function conv(img, psf)
    f_psf = fft(psf)
    f_img = fft(img)
    conv_res = real(ifft(f_psf .* f_img))
    #return max.(0.0, conv_res)
    return conv_res
end
f(img) = sum(conv(img, psf));
Zygote.gradient(f, img)[1]
@profilehtml Zygote.gradient(f, img)[1]
The results with max.:
And without:
So I’m not sure whether I’m at the right place to discuss this.
Thanks,
Felix
I think what you're seeing here is that Zygote's broadcasting isn't very fast, except in cases where it has a special rule for the function being broadcast:
julia> @btime gradient(x -> sum(1 .+ max.(0.0,x)), mat) setup=(mat=randn(100,100));
990.097 μs (50054 allocations: 1.84 MiB)
julia> using Flux # Flux.relu has a special rule
julia> @btime gradient(x -> sum(1 .+ relu.(x)), mat) setup=(mat=randn(100,100));
17.888 μs (6 allocations: 234.61 KiB)
julia> using Tullio # handles gradients without Broadcast.
julia> @btime gradient(x -> sum(@tullio y[i,j] := 1 + max(0.0,x[i,j])), mat) setup=(mat=randn(100,100));
22.290 μs (18 allocations: 156.80 KiB)
julia> using Tracker # different broadcasting, perhaps using ForwardDiff
julia> @btime Tracker.gradient(x -> sum(1 .+ max.(0.0,x)), mat) setup=(mat=randn(100,100));
94.865 μs (173 allocations: 709.87 KiB)
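As a rough sketch of what such a special rule could look like for the max case, one can wrap the broadcast in a function and give that function a hand-written adjoint, so Zygote never has to differentiate through the broadcast itself. The name lower_clamp is made up for this illustration, and the rule treats the threshold as a constant.

```julia
using Zygote

# Hypothetical helper (not from the thread): elementwise max against a scalar
# threshold, wrapped in a function so that one rule covers the whole array.
lower_clamp(lo::Real, x::AbstractArray) = max.(lo, x)

# Hand-written adjoint: the incoming gradient passes through where x > lo and
# is zero elsewhere. The threshold lo is treated as a constant (nothing).
Zygote.@adjoint function lower_clamp(lo::Real, x::AbstractArray)
    y = max.(lo, x)
    lower_clamp_pullback(dy) = (nothing, dy .* (x .> lo))
    return y, lower_clamp_pullback
end

mat = randn(100, 100)
g = gradient(x -> sum(1 .+ lower_clamp(0.0, x)), mat)[1]
```

The gradient at x == lo is set to zero here, which is one of several valid subgradient choices.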
Hey,
thanks for the details.
Is it max.(0.0, x) or 1 .+ ... which causes the problems? I guess the first one, right?
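One way to check would be to measure the two broadcasts separately. Here is a minimal sketch using only Zygote and Base; the names f_max and f_add are made up for the illustration, and the numbers will depend on the Zygote version.

```julia
using Zygote

mat = randn(100, 100)

# Each candidate broadcast in isolation (illustrative names).
f_max(x) = sum(max.(0.0, x))
f_add(x) = sum(1 .+ x)

# Run once first so that compilation is not counted in the measurement.
grad_max = gradient(f_max, mat)[1]
grad_add = gradient(f_add, mat)[1]

# Bytes allocated by one gradient call for each variant.
alloc_max = @allocated gradient(f_max, mat)
alloc_add = @allocated gradient(f_add, mat)
println("max.(0.0, x): ", alloc_max, " bytes;  1 .+ x: ", alloc_add, " bytes")
```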
Thanks,
Felix
Notice the very large number of allocations in the first of your examples: about 5 per element of a 10,000-entry array. The other examples have far fewer allocations.
I often observe a lot of allocations with Zygote, and so far I haven't really figured out how to fix that properly.
Usually the forward passes are type stable in these situations.