# Evaluate gradient (backpropagation) on a user defined type

Hi all,

I have the following problem: given a function `f(x::Vector)`, I know I can evaluate the gradient using any AD package (in particular, I use `ReverseDiff.jl`). What I need now is to evaluate the gradient on a custom type.
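For context, this is roughly how I do it today (a minimal sketch; the objective and input here are just placeholders):

```
using ReverseDiff

f(x) = sum(abs2, x)                  # placeholder objective
x = rand(3)

g = ReverseDiff.gradient(f, x)       # one-shot gradient

# for repeated calls, a pre-compiled tape avoids re-recording:
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
out = similar(x)
ReverseDiff.gradient!(out, tape, x)  # writes the gradient into `out`
```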

For the sake of an example, imagine that the function is

```
f(x) = x^2
```

And that I have a custom type

```
struct foo
    x::Float64
    y::Float64
end
```

that has all the custom operations defined (e.g. `*(a::Number, b::foo) = foo(a*b.x, a^2*b.y)`), and something similar for `+`, `-`, `/`, etc., as well as the math intrinsics (`sin`, `cos`, ...).
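To make this concrete, here is the kind of thing I mean (placeholder formulas in the same spirit as the `*` rule above, written as if `y` propagated like a variance; the exact rules are not important for the question):

```
Base.:+(a::foo, b::foo) = foo(a.x + b.x, a.y + b.y)
Base.:-(a::foo, b::foo) = foo(a.x - b.x, a.y + b.y)

# intrinsics linearized around x, so y picks up the squared derivative
Base.sin(a::foo) = foo(sin(a.x), cos(a.x)^2 * a.y)
Base.cos(a::foo) = foo(cos(a.x), sin(a.x)^2 * a.y)
```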

What I need is to evaluate

```
df(x) = 2*x
```

using my “rules” for the type `foo` (and therefore returning a value of type `foo` for the gradient).

Is there an easy way to achieve this, or would I need to hook into `ChainRules.jl` and construct the forward/backward pass by hand?
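(To spell out the second option: I imagine it would mean writing an `rrule` per operation via `ChainRulesCore`, along these lines. This is a rough sketch; the pullback below is just the transpose of the linearization of my made-up `*` rule, and it assumes the cotangent `ȳ` is itself a `foo`.)

```
using ChainRulesCore

# sketch: hand-written reverse rule for *(a::Number, b::foo),
# viewing foo as the pair (x, y) with a*b = (a*b.x, a^2*b.y)
function ChainRulesCore.rrule(::typeof(*), a::Number, b::foo)
    y = a * b
    function times_pullback(ȳ)
        ā = ȳ.x * b.x + 2a * ȳ.y * b.y  # sensitivity w.r.t. the scalar a
        b̄ = foo(a * ȳ.x, a^2 * ȳ.y)     # sensitivity w.r.t. b, fieldwise
        return NoTangent(), ā, b̄
    end
    return y, times_pullback
end
```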

Thanks,

PS: For efficiency reasons, I need to use backpropagation; forward-mode AD is not efficient for my problem.

It works fine with Zygote (which does reverse-mode AD) as long as you define `foo` as a subtype of `Number` and define enough `Number`-like methods:

```
struct foo <: Number
    x::Float64
    y::Float64
end
# for this example, foo behaves like a complex number:
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)
Base.:*(x::Real, b::foo) = foo(x*b.x, x*b.y)
Base.one(a::foo) = foo(1, 0)
Base.conj(a::foo) = foo(a.x, -a.y)

f(x) = x^2

using Zygote
gradient(f, foo(3.0, 4.0))
```

gives the correct answer `(foo(6.0, -8.0),)`, i.e. `2conj(x)` (reverse mode propagates conjugated values, which is why the `conj` method is needed).


Hi,

Many thanks. I cannot manage to get it to work for a function of several variables:

```
f(x) = x^2 + sin(x)*x - 1.0

julia> input = Vector{foo}(undef, 2);
julia> input[1] = foo(6.0, -8.0);
julia> input[2] = foo(1.0, -2.0);
julia> df = x -> gradient(f, x);
julia> df(rand(2))
([0.8613523955990562, 0.09654726844074117],)
julia> df(input)
ERROR: MethodError: no method matching AbstractFloat(::foo)
Closest candidates are:
  (::Type{T})(::T) where T<:Number at boot.jl:760
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  (::Type{T})(::Base.TwicePrecision) where T<:Number at twiceprecision.jl:243
```

Also, does `Zygote` have an in-place version of `gradient`? I have to call this several times, and I would prefer not to allocate memory. Any other tips for making `Zygote` fast? I am used to having a pre-compiled tape for running the forward/backward passes more efficiently.

Thanks,

Define `Base.float(x::foo) = x` (since `foo` is already floating-point in your case, it doesn’t need to do anything… more generally, you’d call `float` recursively on the fields).

Basically, you need to define enough methods for `foo` to act like a number type to satisfy Zygote. Whenever you get a `MethodError`, it's an indication that you are missing a `Number`-like function (e.g. you may need to define promotion rules as well).
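For example (an illustrative rather than exhaustive set; which methods you actually need depends on what your function calls):

```
Base.float(x::foo) = x                        # foo is already floating-point
Base.convert(::Type{foo}, x::Real) = foo(x, 0)
Base.promote_rule(::Type{foo}, ::Type{T}) where {T<:Real} = foo
Base.zero(::Type{foo}) = foo(0, 0)
```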

Many thanks,

I have managed to get `Zygote.jl` to work, but unfortunately it is much slower than `ReverseDiff.jl` (even for simple real-valued functions). I have opened a new question about this, to see if I am missing anything (Zygote terribly slow. What am I doing wrong?).

Thanks!