I have the following problem: given a function `f(x::Vector)`, I know that I can evaluate its gradient using any AD package (in particular, I use ReverseDiff.jl). What I need now is to evaluate the gradient on a custom type.
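For reference, this is what I do today on plain vectors (a minimal sketch; the quadratic `f` here is just a placeholder):

```julia
using ReverseDiff

# Scalar-valued function of a vector; gradient via reverse-mode AD:
f(x) = sum(abs2, x)
x = rand(3)
g = ReverseDiff.gradient(f, x)   # Vector{Float64}, equals 2 .* x
```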
For the sake of an example, imagine that the function is

```julia
f(x) = x^2
```

and that I have a custom type

```julia
struct foo
    x::Float64
    y::Float64
end
```

that has all the custom operations defined, e.g.

```julia
*(a::Number, b::foo) = foo(a * b.x, a^2 * b.y)
```

and something similar for `+`, `-`, `/`, etc., as well as the math intrinsics.
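For instance, assuming the definitions above are in scope (with `import Base: *` so the operator can be extended), this rule gives:

```julia
2 * foo(3.0, 1.0)   # foo(2 * 3.0, 2^2 * 1.0) == foo(6.0, 4.0)
```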
What I need is to evaluate the gradient

```julia
df(x) = 2 * x
```

using my “rules” for the defined type `foo` (and therefore returning a `foo` for the gradient).
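In other words, the desired behaviour would look like this (`gradient_of_f` is a hypothetical stand-in for whatever mechanism makes this possible):

```julia
x0 = foo(3.0, 1.0)

# Hypothetical: a reverse-mode gradient that applies the foo rules, so
# that for f(x) = x^2 the result is 2 * x0, which by the * rule above
# is foo(2 * 3.0, 2^2 * 1.0) == foo(6.0, 4.0).
g = gradient_of_f(x0)   # should be foo(6.0, 4.0)
```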
Is there an easy way to achieve this, or would I need to hook into ChainRules.jl and construct the forward/backward pass by hand?
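For concreteness, my understanding is that the ChainRules.jl route would mean writing something like the following for every primitive operation on `foo`; I'd hope to avoid that boilerplate if there is an easier way. The pullback here is derived mechanically from the `*` rule above, treating `foo` as a pair of reals; whether that is the right notion of cotangent for `foo` is itself part of my question.

```julia
using ChainRulesCore

# One rrule per primitive operation on foo, starting with *:
function ChainRulesCore.rrule(::typeof(*), a::Number, b::foo)
    y = a * b                                # uses my custom * rule
    function times_pullback(ȳ)
        ā = ȳ.x * b.x + 2a * ȳ.y * b.y       # ∂L/∂a
        b̄ = Tangent{foo}(x = a * ȳ.x, y = a^2 * ȳ.y)   # ∂L/∂b
        return NoTangent(), ā, b̄
    end
    return y, times_pullback
end
```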
PS: For efficiency reasons I need to use reverse-mode AD (backpropagation); forward-mode AD is not efficient for my problem.