Hi all,

I have the following problem: given a function `f(x::Vector)`, I know that I can evaluate its gradient using any AD package (in particular, I use `ReverseDiff.jl`). What I need now is to evaluate the gradient on a custom type.

For the sake of an example, imagine that the function is

```
f(x) = x^2
```

And that I have a custom type

```
struct foo
    x::Float64
    y::Float64
end
```

that has all the custom operations defined (e.g. `*(a::Number,b::foo) = foo(a*b.x,a^2*b.y)`), and something similar for `+,-,/,...`, as well as the math intrinsics (`sin,cos,...`).

What I need is to evaluate

```
df(x) = 2*x
```

using my "rules" for the defined type `foo` (and therefore returning a `foo` for the gradient).

Is there an easy way to achieve this, or would I need to hook into `ChainRules.jl` and construct the forward/backward passes by hand?

Thanks,

PS: For efficiency reasons, I would need to use backpropagation. Forward mode AD is not efficient for my problem.
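For context, here is roughly what I imagine the by-hand `ChainRules.jl` route would look like (a hedged sketch; `square` is a stand-in function, and it assumes `*` and `conj` are already defined for `foo`):

```
using ChainRulesCore

square(x) = x * x  # stand-in for the real function

# Hand-written reverse rule for `square` on `foo`; assumes
# `*(::Number, ::foo)`, `*(::foo, ::foo)` and `conj(::foo)` exist.
function ChainRulesCore.rrule(::typeof(square), x::foo)
    y = square(x)
    square_pullback(ȳ) = (NoTangent(), 2 * conj(x) * ȳ)
    return y, square_pullback
end
```

Doing this for every primitive my functions use seems like a lot of work, hence the question.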

It works fine with Zygote (which does reverse-mode AD) as long as you define `foo` as a subtype of `Number` and define enough `Number`-like methods:

```
struct foo <: Number
x::Float64
y::Float64
end
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)
Base.:*(x::Real, b::foo) = foo(x*b.x, x*b.y)
Base.one(a::foo) = foo(1,0)
Base.conj(a::foo) = foo(a.x, -a.y)
f(x) = x^2
using Zygote
gradient(f, foo(3,4))
```

gives the correct answer `(foo(6.0, -8.0),)`.


Hi,

Many thanks. I cannot manage to get it to work for a function of several arguments:

```
julia> f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0;

julia> input = Vector{foo}(undef, 2);

julia> input[1] = foo(6.0, -8.0);

julia> input[2] = foo(1.0, -2.0);

julia> df = x -> gradient(f, x);

julia> df(rand(2))
([0.8613523955990562, 0.09654726844074117],)

julia> df(input)
ERROR: MethodError: no method matching AbstractFloat(::foo)
Closest candidates are:
  (::Type{T})(::T) where T<:Number at boot.jl:760
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  (::Type{T})(::Base.TwicePrecision) where T<:Number at twiceprecision.jl:243
```

Also, does `Zygote` have an in-place version of `gradient`? I have to call this several times, and I would prefer not to allocate memory. Any other tips to make `Zygote` fast? I am used to having a pre-compiled tape for running the forward/backward passes more efficiently.
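For reference, the tape pattern I have in mind is the one from `ReverseDiff.jl` (a sketch with plain `Float64` inputs, reusing the `f` from above):

```
using ReverseDiff

f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0

# Record the operations of f once, then compile the tape.
const tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, rand(2)))
const out  = zeros(2)

# Re-uses the tape and writes into `out`, so repeated calls avoid
# re-tracing f and allocating a fresh gradient vector each time.
ReverseDiff.gradient!(out, tape, [3.0, 1.0])
```

Is there an equivalent of this with Zygote?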

Thanks,

Define `Base.float(x::foo) = x` (since `foo` is already floating-point in your case, it doesn't need to do anything; more generally, you'd call `float` recursively on the fields).

Basically, you need to define enough methods for `foo` to act like a number type to satisfy Zygote. Whenever you get a `MethodError`, it's an indication that you are missing a `Number`-like function (e.g. you may need to define promotion rules as well).
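A minimal sketch of the kind of methods I mean (the `sin` rule here assumes `foo` follows complex-number semantics, consistent with the `*` method above; adapt it to your actual rules):

```
Base.float(x::foo) = x                      # foo is already floating-point
Base.zero(::Type{foo}) = foo(0.0, 0.0)
Base.:+(a::foo, b::foo) = foo(a.x + b.x, a.y + b.y)
Base.:-(a::foo, b::Real) = foo(a.x - b, a.y)
# Example math intrinsic, assuming complex-style semantics:
Base.sin(a::foo) = foo(sin(a.x)*cosh(a.y), cos(a.x)*sinh(a.y))
```

You add such definitions one by one until the `MethodError`s stop.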

Many thanks,

I have managed to get `Zygote.jl` to work, but unfortunately it is much slower than `ReverseDiff.jl` (even for simple real-valued functions). I have opened a new question about this, to see if I am missing anything (Zygote terribly slow. What am I doing wrong?).

Thanks!