Let me clarify. Say we define the following loss function:
function loss_function(m, x, ps, st)
    f = x -> m(x, ps, st)
    return sum(abs2, only(Zygote.gradient(f, x)))
end
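If we call Zygote.gradient(loss_function, ....), Zygote has to differentiate through the inner Zygote.gradient call (reverse-over-reverse). That is very slow in the best case and errors in the vast majority of cases, so we want to use a JVP over Zygote.gradient (forward-over-reverse) to compute the gradients wrt x and ps.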
function CRC.rrule(cfg::..., ::typeof(Zygote.gradient), f::F, x) where {F}
    # Here `f` is potentially a closure
    # Q1: How do we extract the values like `m`, `ps` and `st` captured in `f`?
    # Let's proceed assuming we did it
    y = Zygote.gradient(f, x)
    function gradient_pullback(Delta)
        # We use our JVP trick to compute dm, dps, dst, dx
        # Q2: What is the appropriate tangent for `f`? (called `df` below)
        # The pullback returns one tangent per primal argument:
        # `Zygote.gradient` itself, `f`, and `x` (nothing for `cfg`)
        return NoTangent(), df, dx
    end
    return y, gradient_pullback
end
Since I did not know how to solve this, what Lux does (and what I am proposing DI do as well) is to introduce a ParamsStruct that carries the function and its parameters explicitly:
function Zygote.gradient(f::ParamsStruct, x)
    # In the forward pass we just call Zygote.gradient
    return __my_custom_gradient_func(f.func, x, f.ps)
end

# Defining the rrule is now simple:
function CRC.rrule(cfg::..., ::typeof(__my_custom_gradient_func), f::F, x, ps) where {F}
    # Use the JVP trick here
end
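
For reference, here is roughly what I mean by the "JVP trick", written out as math. This is a sketch under assumptions: f is scalar-valued and twice continuously differentiable, θ stands for ps, and the symbols g, ε and Δ (the cotangent passed to the pullback) are my own notation, not anything from the Lux or DI source.

```latex
% Forward pass: the inner gradient, viewed as a function of x and \theta
g(x, \theta) = \nabla_x f(x, \theta)

% Pullback w.r.t. x: the Hessian is symmetric, so the VJP equals a JVP
\bar{x} = \left(\frac{\partial g}{\partial x}\right)^{\!\top} \Delta
        = \nabla_x^2 f(x, \theta)\, \Delta
        = \left.\frac{d}{d\varepsilon}\, \nabla_x f(x + \varepsilon \Delta, \theta)\right|_{\varepsilon = 0}

% Pullback w.r.t. \theta: mixed partials commute
\bar{\theta} = \left(\frac{\partial g}{\partial \theta}\right)^{\!\top} \Delta
             = \nabla_\theta \left( \Delta^{\top} \nabla_x f(x, \theta) \right)
             = \left.\frac{d}{d\varepsilon}\, \nabla_\theta f(x + \varepsilon \Delta, \theta)\right|_{\varepsilon = 0}
```

So both cotangents come out of a single forward-mode derivative in ε of the ordinary reverse-mode gradient of f wrt (x, θ), evaluated at x + εΔ. That is why having ps as an explicit argument matters: the rrule needs to take the gradient wrt ps directly instead of digging it out of a closure.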