Hey,
is there a way to prevent Zygote from calculating adjoints that are not needed at all?
Let’s say we have a function `f(a, b)` where we only want the gradient with respect to the first argument `a`:
b = #some value
g(a) = f(a, b)
Zygote.gradient(g, a)
Zygote now calculates (as far as I can see from my `rrule` definition) the adjoints with respect to both `a` and `b`, but the latter is not needed at all and is just wasted computing time.
Is there a way to prevent that?
Full Working Example
julia> function conv(u::AbstractArray{T, N}, v::AbstractArray{D, M},
                     dims=ntuple(+, max(N, M))) where {T, D, N, M}
           # FFT-based convolution of `u` and `v` along `dims`.
           # `dims` defaults to `(1, 2, ..., max(N, M))`: unary `+` returns its
           # argument, so `ntuple(+, n)` simply enumerates 1 through n.
           # Multiply the two transforms element-wise ...
           spectrum = fft(u, dims) .* fft(v, dims)
           # ... transform back and keep only the real part of the result.
           return real(ifft(spectrum, dims))
       end
julia> function ChainRulesCore.rrule(::typeof(conv), u::AbstractArray{T, N}, v::AbstractArray{D, M},
                                     dims=ntuple(+, max(N, M))) where {T, D, N, M}
           # Hand-written reverse rule for `conv`: returns the primal result together
           # with the pullback closure that maps an output cotangent to the tuple
           # of input cotangents.
           function conv_pullback(dy)
               # Scalar zero of `u`'s element type, used as a placeholder in the
               # first and last entries of the returned tuple.
               z = zero(eltype(u))
               # Second entry: convolution of the incoming cotangent with `conj(v)`.
               grad_u = conv(dy, conj(v), dims)
               # Third entry: no adjoint is computed here. `print("lo")` only writes
               # a marker to stdout each time the pullback runs, and its return
               # value (`nothing`) is what ends up in the tuple.
               # NOTE(review): presumably this slot is the cotangent for `v`
               # (ChainRules convention: function first, then one entry per
               # argument) — confirm.
               marker = print("lo")
               return z, grad_u, marker, z
           end
           return conv(u, v, dims), conv_pullback
       end
julia> function main()
           # Timing harness: compares the forward pass and `Zygote.gradient` for a
           # one-argument closure versus an explicit two-argument function.
           u = randn(512, 512, 23)
           v = copy(u)

           # `f1` closes over `v`, so only its single argument is passed to
           # `Zygote.gradient`; `f2` takes both arrays explicitly.
           f1(x) = sum(conv(x, v))
           f2(x, y) = sum(conv(x, y))

           # Forward pass, timed twice per variant.
           for _ in 1:2
               @time f1(u)
           end
           for _ in 1:2
               @time f2(u, v)
           end

           # Gradient, timed three times per variant.
           for _ in 1:3
               @time Zygote.gradient(f1, u)
           end
           for _ in 1:3
               @time Zygote.gradient(f2, u, v)
           end

           return 0
       end
main (generic function with 1 method)
julia> main()
0.385862 seconds (106 allocations: 276.907 MiB, 6.92% gc time)
0.355024 seconds (106 allocations: 276.907 MiB)
0.499096 seconds (106 allocations: 276.907 MiB, 28.23% gc time)
0.368945 seconds (106 allocations: 276.907 MiB, 4.22% gc time)
lo 0.830857 seconds (216 allocations: 599.815 MiB, 7.08% gc time)
lo 0.871777 seconds (216 allocations: 599.815 MiB, 16.17% gc time)
lo 0.760347 seconds (216 allocations: 599.815 MiB, 3.82% gc time)
lo 0.804685 seconds (216 allocations: 599.815 MiB, 7.13% gc time)
lo 0.912386 seconds (216 allocations: 599.815 MiB, 16.41% gc time)
lo 0.804314 seconds (216 allocations: 599.815 MiB, 4.46% gc time)
Thanks,
Felix