# Zygote calculate adjoint only for one of several arguments

Hey,

Is there a way to make Zygote skip computing adjoints that are not needed at all?

Let’s say we have a function `f(a, b)` and only want the gradient with respect to the first argument:

```julia
b = rand(10)  # placeholder for some fixed value
g(a) = f(a, b)
```

As far as I can tell from my `rrule` definition, Zygote still computes the adjoints with respect to both `a` and `b`, but the latter is never needed and is just wasted compute time.
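A minimal sketch of how to observe this, using a hypothetical `weighted(a, b)` in place of `f` and two counters (also hypothetical) that record which branches of the pullback actually run:

```julia
using ChainRulesCore, Zygote

# Hypothetical stand-in for f(a, b): a weighted sum.
weighted(a, b) = sum(a .* b)

# Counters recording which adjoint branches of the pullback ran.
const A_CALLS = Ref(0)
const B_CALLS = Ref(0)

function ChainRulesCore.rrule(::typeof(weighted), a, b)
    y = weighted(a, b)
    function weighted_pullback(ybar)
        A_CALLS[] += 1
        abar = ybar .* b          # needed for the gradient w.r.t. a
        B_CALLS[] += 1
        bbar = ybar .* a          # computed even though b is held fixed
        return NoTangent(), abar, bbar
    end
    return y, weighted_pullback
end

function demo()
    a = [1.0, 2.0, 3.0]
    b = [4.0, 5.0, 6.0]
    g(x) = weighted(x, b)         # b is fixed, as in the question
    (ga,) = Zygote.gradient(g, a)
    return ga
end

ga = demo()
```

Here `ga` equals `b`, but `B_CALLS[]` is also incremented: the adjoint for `b` is computed and then thrown away.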

Is there a way to prevent that?

Full working example:
```julia
julia> using FFTW, ChainRulesCore, Zygote

julia> function conv(u::AbstractArray{T, N}, v::AbstractArray{D, M}, dims=ntuple(identity, max(N, M))) where {T, D, N, M}
           return real(ifft(fft(u, dims) .* fft(v, dims), dims))
       end

julia> function ChainRulesCore.rrule(::typeof(conv), u::AbstractArray{T, N}, v::AbstractArray{D, M},
                                     dims=ntuple(identity, max(N, M))) where {T, D, N, M}
           Y = conv(u, v, dims)
           function conv_pullback(barx)
               # print("lo") marks each time the v-slot is evaluated (it returns `nothing`)
               return NoTangent(), conv(barx, conj(v), dims), print("lo"), NoTangent()
           end
           return Y, conv_pullback
       end

julia> function main()
           u = randn((512, 512, 23))
           v = copy(u)

           f1(u) = sum(conv(u, v))
           f2(u, v) = sum(conv(u, v))

           @time f1(u)
           @time f1(u)
           @time f2(u, v)
           @time f2(u, v)
           return 0
       end
main (generic function with 1 method)

julia> main()
0.385862 seconds (106 allocations: 276.907 MiB, 6.92% gc time)
0.355024 seconds (106 allocations: 276.907 MiB)
0.499096 seconds (106 allocations: 276.907 MiB, 28.23% gc time)
0.368945 seconds (106 allocations: 276.907 MiB, 4.22% gc time)
lo  0.830857 seconds (216 allocations: 599.815 MiB, 7.08% gc time)
lo  0.871777 seconds (216 allocations: 599.815 MiB, 16.17% gc time)
lo  0.760347 seconds (216 allocations: 599.815 MiB, 3.82% gc time)
lo  0.804685 seconds (216 allocations: 599.815 MiB, 7.13% gc time)
lo  0.912386 seconds (216 allocations: 599.815 MiB, 16.41% gc time)
lo  0.804314 seconds (216 allocations: 599.815 MiB, 4.46% gc time)
```
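As an aside, and assuming real-valued `u` and `v`: the `u`-adjoint in the pullback above is `conv(barx, conj(v), dims)`, which for real `v` is just another convolution with `v`. The adjoint of circular convolution with `v` is circular correlation, i.e. multiplication by `conj(fft(v))` in Fourier space. Below is a sketch of a rule along those lines that also wraps each tangent in ChainRulesCore's `@thunk`, followed by a dot-product test; the helper `correlate` and the test variables are illustrative names, not part of any library.

```julia
using FFTW, ChainRulesCore

# N-d circular convolution via FFTs, following the post (real inputs assumed).
conv(u, v) = conv(u, v, ntuple(identity, max(ndims(u), ndims(v))))
conv(u, v, dims) = real(ifft(fft(u, dims) .* fft(v, dims), dims))

# Adjoint of circular convolution with w: circular correlation,
# i.e. multiply by conj(fft(w)) in Fourier space.
correlate(ybar, w, dims) = real(ifft(fft(ybar, dims) .* conj.(fft(w, dims)), dims))

function ChainRulesCore.rrule(::typeof(conv), u, v, dims)
    y = conv(u, v, dims)
    function conv_pullback(ybar)
        Ybar = unthunk(ybar)
        ubar = @thunk correlate(Ybar, v, dims)   # evaluated only when unthunked
        vbar = @thunk correlate(Ybar, u, dims)
        return NoTangent(), ubar, vbar, NoTangent()
    end
    return y, conv_pullback
end

# Dot-product test: <ybar, conv(du, v)> should equal <ubar, du> for any du.
dims = (1, 2)
u = randn(4, 5)
v = randn(4, 5)
ybar = randn(4, 5)
du = randn(4, 5)
_, pb = rrule(conv, u, v, dims)
_, ubar, vbar, _ = pb(ybar)
lhs = sum(ybar .* conv(du, v, dims))
rhs = sum(unthunk(ubar) .* du)
```

`dims` is passed explicitly to the rule so that its four tangents always line up with its four arguments; the two-argument `conv(u, v)` simply forwards to it.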

Thanks,

Felix


I think it is not currently possible. ChainRulesCore has a macro, `@thunk`, that should enable this, but it is effectively disabled in Zygote.
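For reference, this is roughly what a rule using `@thunk` looks like. In the minimal sketch below, the function `lazy_weighted`, the helpers `grad_a`/`grad_b`, and the counters are all hypothetical. The pullback returns unevaluated thunks, and calling the rule directly through ChainRulesCore shows that only the tangent that gets `unthunk`ed is ever computed:

```julia
using ChainRulesCore

# Hypothetical example function plus counters showing which adjoints run.
lazy_weighted(a, b) = sum(a .* b)
const A_WORK = Ref(0)
const B_WORK = Ref(0)

function grad_a(ybar, b)
    A_WORK[] += 1
    return ybar .* b
end

function grad_b(ybar, a)
    B_WORK[] += 1
    return ybar .* a
end

function ChainRulesCore.rrule(::typeof(lazy_weighted), a, b)
    y = lazy_weighted(a, b)
    function lazy_weighted_pullback(ybar)
        # Each tangent is wrapped in a Thunk: nothing is computed yet.
        return NoTangent(), @thunk(grad_a(ybar, b)), @thunk(grad_b(ybar, a))
    end
    return y, lazy_weighted_pullback
end

a, b = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
y, pb = rrule(lazy_weighted, a, b)
_, abar, bbar = pb(1.0)
abar_value = unthunk(abar)   # only the a-adjoint is materialised
```

Whether this saves work depends on the AD system consuming the rule: as said above, Zygote does not currently exploit thunks, so under Zygote `grad_b` would still run.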


Has there been any progress on this front?

You can do it if you write a custom `rrule` for your function. Then you can return `@not_implemented("foo")` for the terms you don’t want.
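A minimal sketch of that approach, with a hypothetical `partial_weighted(a, b)` whose rule never computes the adjoint for `b`. ChainRulesCore's `@not_implemented` returns a `NotImplemented` marker in its place, which, unlike `NoTangent()`, records that the derivative was skipped rather than claiming it is zero. The rule is called directly here, and the assumption is that nothing downstream ever needs the `b` gradient:

```julia
using ChainRulesCore

# Hypothetical example: only the gradient w.r.t. `a` is ever wanted.
partial_weighted(a, b) = sum(a .* b)

function ChainRulesCore.rrule(::typeof(partial_weighted), a, b)
    y = partial_weighted(a, b)
    function partial_weighted_pullback(ybar)
        abar = ybar .* b
        # The b-adjoint is never computed; a NotImplemented marker is returned instead.
        bbar = ChainRulesCore.@not_implemented("gradient w.r.t. b is not needed here")
        return NoTangent(), abar, bbar
    end
    return y, partial_weighted_pullback
end

a, b = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
_, pb = rrule(partial_weighted, a, b)
_, abar, bbar = pb(1.0)
```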

Otherwise you are relying on the compiler to eliminate “dead code” for unused results.

In contrast, Enzyme.jl has separate explicit rules for constant arguments; see the post by @ChrisRackauckas contrasting Enzyme and Zygote rules in the thread “What's the state of Automatic Differentiation in Julia January 2023?” (post #6).
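On the calling side, Enzyme expresses this with activity annotations. A minimal sketch, assuming a recent Enzyme.jl and a hypothetical loop-based `dotprod(a, b)` standing in for `f`: wrapping `b` in `Const` tells Enzyme not to compute an adjoint for it, while `Duplicated(a, da)` accumulates the gradient with respect to `a` into the shadow array `da`.

```julia
using Enzyme

# Hypothetical stand-in for f(a, b): a plain loop-based weighted sum.
function dotprod(a, b)
    s = zero(eltype(a))
    for i in eachindex(a, b)
        s += a[i] * b[i]
    end
    return s
end

a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
da = zero(a)   # shadow buffer; Enzyme accumulates the gradient w.r.t. a here

# Duplicated(a, da): differentiate w.r.t. a.
# Const(b): b is treated as constant, so no adjoint is computed for it.
Enzyme.autodiff(Reverse, dotprod, Active, Duplicated(a, da), Const(b))
```

Inside custom rules (EnzymeRules), the same annotations reach the rule as argument types, so a rule can check whether an argument is `Const` and skip its adjoint.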

I’ll take a look, thanks!

There is Zygote#966, which perhaps someone can push over the finish line.
