Is there a way to define custom Zygote adjoints for only some arguments of a function?

Suppose I have a function for which I define custom Zygote adjoints, but I don't need the gradients for some of the arguments yet and would like to implement them later. Currently, I could define the adjoint to return nothing for those arguments, and this works, but if I accidentally take a gradient with respect to them, it silently fails and gives me 0/nothing. Instead, is there a way to do something like the following?

```julia
using Zygote: @adjoint, gradient

f(x,y) = x*y
# suppose the gradient w.r.t. x is easy but I haven't coded up y yet
@adjoint f(x,y) = f(x,y), Δ->(Δ*y, error("not implemented yet"))

gradient(x->f(x,1), 1) # should work
gradient(y->f(1,y), 1) # should error "not implemented yet"
```

The above does not currently work because the error is always evaluated, even when we don't need the gradient w.r.t. y. Could something similar to this work?
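
The reason is that Julia evaluates every element of a tuple before the tuple is built. A minimal plain-Julia sketch, with no Zygote involved and using eager_back as a hypothetical stand-in for the pullback above, throws as soon as it is called, regardless of which entry the caller actually wants:

```julia
# Hypothetical stand-in for the pullback in the question, written as an
# ordinary function so it runs without Zygote.
eager_back(Δ, y) = (Δ * y, error("not implemented yet"))

# Both tuple elements are evaluated before the tuple exists, so the error
# fires even though we only ask for the first entry.
try
    first(eager_back(1, 2))
catch err
    println("threw: ", err.msg)  # prints "threw: not implemented yet"
end
```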

Basically, no, but this would be a great reason to have thunks (i.e. returning a zero-argument closure that lazily evaluates to the gradient when needed, or in this case throws an error). We've discussed adding that a fair bit; ChainRules already has a design, so it'll happen eventually.
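
As a rough illustration of the thunk idea, here is a plain-Julia sketch assuming a hand-rolled pullback that returns one zero-argument closure per argument. The names f_pullback_lazy, grad_x and grad_y are hypothetical, and this is not Zygote's or ChainRules' actual API; it only shows that an unimplemented gradient errors when, and only when, it is requested:

```julia
# Hypothetical sketch: not Zygote's or ChainRules' API.
f(x, y) = x * y

# Pullback that returns one thunk (zero-argument closure) per argument
# instead of eagerly computed gradients.
function f_pullback_lazy(x, y)
    back(Δ) = (() -> Δ * y,                         # ∂f/∂x, implemented
               () -> error("not implemented yet"))  # ∂f/∂y, not written yet
    return f(x, y), back
end

# Each helper calls the pullback with Δ = 1 and forces only the thunk
# for the argument it cares about.
grad_x(x, y) = f_pullback_lazy(x, y)[2](1)[1]()
grad_y(x, y) = f_pullback_lazy(x, y)[2](1)[2]()

println(grad_x(1, 1))  # prints 1
try
    grad_y(1, 1)
catch err
    println("grad_y failed: ", err.msg)  # prints "grad_y failed: not implemented yet"
end
```

The cost (or the error) moves from the moment the pullback runs to the moment a specific gradient is demanded, which is exactly what the question asks for.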
