Zygote: calculate the adjoint for only one of several arguments

Hey,

is there a way to prevent Zygote from calculating adjoints that are not needed at all?

Let’s say we have a function f(a, b) for which we only want the gradient with respect to the first argument:

b = some_value  # placeholder: the fixed value of the second argument
# Wrap f so that only `a` is a differentiated argument.
g(a) = f(a, b)
Zygote.gradient(g, a)

As far as I can tell from my rrule definition, Zygote computes the adjoints with respect to both a and b, even though the latter is not needed at all and only wastes computing time.

Is there a way to prevent that?

Full Working Example
julia> function conv(u::AbstractArray{T, N}, v::AbstractArray{D, M},
                     dims=ntuple(+, max(N, M))) where {T, D, N, M}
           # FFT-based convolution over `dims` (default: every dimension 1:max(N, M)):
           # multiply the two spectra pointwise, transform back, keep the real part.
           spectrum = fft(u, dims) .* fft(v, dims)
           return real(ifft(spectrum, dims))
       end

julia> function ChainRulesCore.rrule(::typeof(conv), u::AbstractArray{T, N}, v::AbstractArray{D, M},
                                     dims=ntuple(+, max(N, M))) where {T, D, N, M}
           # Reverse-mode rule for the FFT-based convolution `conv`.
           # NOTE(review): assumes real-valued `u` and `v` of the same size; tangents are not
           # summed back down to the shape of a smaller, broadcast operand.
           Y = conv(u, v, dims)
           function conv_pullback(barx)
               spec_barx = fft(ChainRulesCore.unthunk(barx), dims)
               # The adjoint of a circular convolution is a circular *correlation*: multiply
               # by the conjugated spectrum of the other operand, i.e. conj(fft(v)) rather
               # than fft(conj(v)) (the two coincide only for symmetric v).
               bar_u = ChainRulesCore.@thunk real(ifft(spec_barx .* conj(fft(v, dims)), dims))
               # Thunked so that the v-adjoint is computed only if someone asks for it.
               # Zygote currently unthunks eagerly, so there it still runs every time;
               # print("lo") is a debug probe showing when this tangent is evaluated.
               bar_v = ChainRulesCore.@thunk begin
                   print("lo")
                   real(ifft(spec_barx .* conj(fft(u, dims)), dims))
               end
               # The function itself and the integer `dims` tuple have no tangent.
               return ChainRulesCore.NoTangent(), bar_u, bar_v, ChainRulesCore.NoTangent()
           end
           return Y, conv_pullback
       end

julia> function main()
           u = randn((512, 512, 23))
           v = copy(u)
       
           # f1 closes over `v`, so only `u` is a differentiated argument;
           # f2 receives both operands explicitly.
           f1(u) = sum(conv(u, v))
           f2(u, v) = sum(conv(u, v))
       
           # Forward passes: each variant timed twice.
           for _ in 1:2
               @time f1(u)
           end
           for _ in 1:2
               @time f2(u, v)
           end
           # Reverse passes: each variant timed three times.
           for _ in 1:3
               @time Zygote.gradient(f1, u)
           end
           for _ in 1:3
               @time Zygote.gradient(f2, u, v)
           end
           return 0
       end
main (generic function with 1 method)

julia> main()
  0.385862 seconds (106 allocations: 276.907 MiB, 6.92% gc time)
  0.355024 seconds (106 allocations: 276.907 MiB)
  0.499096 seconds (106 allocations: 276.907 MiB, 28.23% gc time)
  0.368945 seconds (106 allocations: 276.907 MiB, 4.22% gc time)
lo  0.830857 seconds (216 allocations: 599.815 MiB, 7.08% gc time)
lo  0.871777 seconds (216 allocations: 599.815 MiB, 16.17% gc time)
lo  0.760347 seconds (216 allocations: 599.815 MiB, 3.82% gc time)
lo  0.804685 seconds (216 allocations: 599.815 MiB, 7.13% gc time)
lo  0.912386 seconds (216 allocations: 599.815 MiB, 16.41% gc time)
lo  0.804314 seconds (216 allocations: 599.815 MiB, 4.46% gc time)

Thanks,

Felix

2 Likes

I think this is not currently possible. There is a macro `@thunk` in ChainRulesCore that should enable this, but it is disabled in Zygote.

1 Like