Evaluate gradient (backpropagation) on a user defined type

Hi all,

I have the following problem. Given a function f(x::Vector), I know that I can evaluate the gradient using any AD package (in particular, I use ReverseDiff.jl). What I need now is to evaluate the gradient on a custom type.

For the sake of example, imagine that the function is

f(x) = x^2

And that I have a custom type

struct foo
   x::Float64
   y::Float64
end

that has all the custom operations defined (e.g. *(a::Number, b::foo) = foo(a*b.x, a^2*b.y)), with something similar for +, -, /, ... as well as the math intrinsics (sin, cos, ...).
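One detail worth noting: in Julia these operations have to extend the existing `Base` functions (`Base.:*`, `Base.:+`, ...). Writing `*(a::Number, b::foo) = ...` without the `Base.` prefix either creates a new function that shadows Base's `*` or fails with an error asking you to import it, depending on whether `*` was already used in that module. A minimal sketch of the question's type with the scalar rule quoted above; the componentwise `+` is a hypothetical rule chosen only for illustration, since the question doesn't say what its `+` does. Run it in its own session, because the answer below redefines `foo` as a subtype of `Number`.

```julia
# Sketch of the question's custom type (not a subtype of Number here).
struct foo
    x::Float64
    y::Float64
end

# Scalar rule quoted in the question: a*b = foo(a*b.x, a^2*b.y).
# Extend Base's `*` rather than defining a new top-level function.
Base.:*(a::Number, b::foo) = foo(a * b.x, a^2 * b.y)

# Hypothetical componentwise addition, only for illustration.
Base.:+(a::foo, b::foo) = foo(a.x + b.x, a.y + b.y)

b = foo(1.0, 2.0)
c = 3 * b + b   # foo(3.0, 18.0) + foo(1.0, 2.0) == foo(4.0, 20.0)
```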

What I need is to evaluate

df(x) = 2*x

using my “rules” for the type foo (and therefore returning a value of type foo for the gradient).

Is there an easy way to achieve this, or would I need to hook into ChainRules.jl and construct the forward/backward pass by hand?

Thanks,

PS: For efficiency reasons I need to use backpropagation; forward-mode AD is not efficient for my problem.

It works fine with Zygote (which does reverse-mode AD) as long as you define foo as a subtype of Number and define enough Number-like methods:

struct foo <: Number
    x::Float64
    y::Float64
end
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)
Base.:*(x::Real, b::foo) = foo(x*b.x, x*b.y)
Base.one(a::foo) = foo(1,0)
Base.conj(a::foo) = foo(a.x, -a.y)

f(x) = x^2

using Zygote
gradient(f, foo(3,4))

gives the correct answer (foo(6.0, -8.0),).
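The four methods above make `foo` multiply like a complex number, so the result can be checked by hand without Zygote: `foo(6.0, -8.0)` equals `conj(2 * foo(3, 4))`, the conjugate of the analytic derivative 2z at z = 3 + 4i, and Base's `Complex` arithmetic gives the same numbers. A minimal sketch using only the answer's definitions (no Zygote involved):

```julia
struct foo <: Number
    x::Float64
    y::Float64
end
# The answer's definitions: foo multiplies like a complex number.
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)
Base.:*(x::Real, b::foo) = foo(x*b.x, x*b.y)
Base.one(a::foo) = foo(1, 0)
Base.conj(a::foo) = foo(a.x, -a.y)

z  = foo(3, 4)
sq = z * z                 # foo(-7.0, 24.0), like (3 + 4im)^2 == -7 + 24im
g  = conj(2 * z)           # conjugate of the derivative 2z: foo(6.0, -8.0)
w  = conj(2 * (3 + 4im))   # same computation with Base's Complex{Int}: 6 - 8im
```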


Hi,

Many thanks. However, I cannot get it to work for a function of several variables:

f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0
julia> input = Vector{foo}(undef, 2);
julia> input[1] = foo(6.0, -8.0);
julia> input[2] = foo(1.0, -2.0);
julia> df = x -> gradient(f, x);
julia> df(rand(2))
([0.8613523955990562, 0.09654726844074117],)
julia> df(input)
ERROR: MethodError: no method matching AbstractFloat(::foo)
Closest candidates are:
  (::Type{T})(::T) where T<:Number at boot.jl:760
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  (::Type{T})(::Base.TwicePrecision) where T<:Number at twiceprecision.jl:243
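For reference, the gradient of this `f` is easy to derive by hand: ∂f/∂x[1] = 2*x[1] + sin(x[2]) and ∂f/∂x[2] = x[1]*cos(x[2]). A small hand-coded version (the name `df_manual` is hypothetical) can be used to check the `Float64` result above before moving on to `foo` inputs:

```julia
f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0

# Hand-derived gradient of f (hypothetical helper for checking AD output).
df_manual(x) = [2x[1] + sin(x[2]), x[1]*cos(x[2])]

df_manual([1.0, 0.0])   # [2.0, 1.0]
```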

Also, does Zygote have an in-place version of gradient? I have to call it many times, and I would prefer not to allocate memory. Any other tips for making Zygote fast? I am used to having a precompiled tape for running the forward/backward pass more efficiently.

Thanks,

Define Base.float(x::foo) = x. Since foo is already floating-point in your case, this method doesn't need to do anything; more generally, you'd call float recursively on the fields.
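The error message suggests that some generic method called `float(x)`, whose fallback in Base goes through `AbstractFloat(x)`, and no such conversion exists for `foo`. A sketch of both cases described above: the trivial method for the `Float64`-based `foo`, and, for a hypothetical parametric type `Bar{T}` (not in the thread), a method that calls `float` on each field:

```julia
struct foo <: Number
    x::Float64
    y::Float64
end

# foo already stores Float64 fields, so float is the identity.
Base.float(x::foo) = x

# Hypothetical parametric variant: convert each field with float.
struct Bar{T<:Real} <: Number
    x::T
    y::T
end
Base.float(b::Bar) = Bar(float(b.x), float(b.y))

float(foo(1.0, 2.0))   # returns foo(1.0, 2.0) without calling AbstractFloat(::foo)
float(Bar(1, 2))       # Bar{Int} -> Bar{Float64} with fields 1.0 and 2.0
```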

Basically, you need to define enough methods for foo to act like a number type to satisfy Zygote. Whenever you get a MethodError, it indicates that you are missing a Number-like function (e.g. you may need to define promotion rules as well).
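As an illustration of the promotion-rule point: with a `convert` method from `Real` and a `promote_rule`, Base's generic `+(x::Number, y::Number)` and `-(x::Number, y::Number)` fallbacks promote mixed arguments, so an expression like `... - 1.0` only needs `foo`-`foo` methods. A minimal sketch, assuming a real number r should map to `foo(r, 0)` and that `+`/`-` act componentwise (both are choices made here for illustration, not stated in the thread):

```julia
struct foo <: Number
    x::Float64
    y::Float64
end

# foo-foo arithmetic (componentwise, assumed for illustration).
Base.:+(a::foo, b::foo) = foo(a.x + b.x, a.y + b.y)
Base.:-(a::foo, b::foo) = foo(a.x - b.x, a.y - b.y)

# Assumed embedding of reals: r -> foo(r, 0).
Base.convert(::Type{foo}, r::Real) = foo(r, 0)
# Mixing foo with any Real promotes both arguments to foo.
Base.promote_rule(::Type{foo}, ::Type{<:Real}) = foo

r = foo(1.0, 2.0) + 3.0   # promoted to foo(1.0, 2.0) + foo(3.0, 0.0) == foo(4.0, 2.0)
s = 3 - foo(1.0, 2.0)     # promoted to foo(3.0, 0.0) - foo(1.0, 2.0) == foo(2.0, -2.0)
```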