Hi all,
I have the following problem: given a function f(x::Vector), I know that I can evaluate its gradient using any AD package (in particular, I use ReverseDiff.jl). What I need now is to evaluate the gradient on a custom type.
For the sake of an example, imagine that the function is
f(x) = x^2
and that I have a custom type
struct foo
    x::Float64
    y::Float64
end
that has all the custom operations defined (e.g. *(a::Number, b::foo) = foo(a*b.x, a^2*b.y)), and something similar for +, -, /, ... as well as the math intrinsics (sin, cos, ...).
What I need is to evaluate
df(x) = 2*x
using my “rules” for the defined type foo (and therefore returning a foo for the gradient).
Is there an easy way to achieve this, or would I need to hook into ChainRules.jl and construct the forward/backward pass by hand?
Thanks,
PS: For efficiency reasons, I need to use backpropagation; forward-mode AD is not efficient for my problem.
It works fine with Zygote (which does reverse-mode AD) as long as you define foo as a subtype of Number and define enough Number-like methods:
struct foo <: Number
x::Float64
y::Float64
end
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)
Base.:*(x::Real, b::foo) = foo(x*b.x, x*b.y)
Base.one(a::foo) = foo(1,0)
Base.conj(a::foo) = foo(a.x, -a.y)
f(x) = x^2
using Zygote
gradient(f, foo(3,4))
gives the correct answer (foo(6.0, -8.0),)
.
Hi,
Many thanks! I cannot manage to make it work for a function of several arguments:
f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0
julia> input = Vector{foo}(undef, 2);
julia> input[1] = foo(6.0, -8.0);
julia> input[2] = foo(1.0, -2.0);
julia> df = x -> gradient(f, x);
julia> df(rand(2))
([0.8613523955990562, 0.09654726844074117],)
julia> df(input)
ERROR: MethodError: no method matching AbstractFloat(::foo)
Closest candidates are:
(::Type{T})(::T) where T<:Number at boot.jl:760
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
(::Type{T})(::Base.TwicePrecision) where T<:Number at twiceprecision.jl:243
Also, does Zygote have an in-place version of gradient? I have to call this several times, and I would prefer not to allocate memory. Any other tips to make Zygote fast? I am used to having a pre-compiled tape for running the forward/backward passes more efficiently.
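For reference, the pre-compiled-tape workflow I mean is the one from ReverseDiff.jl (a minimal sketch using the example function above; the input values are just placeholders):

```julia
using ReverseDiff

f(x) = x[1]^2 + sin(x[2])*x[1] - 1.0

# Record the operations of f once into a tape, then compile it so that
# repeated gradient evaluations avoid re-tracing the function.
tape  = ReverseDiff.GradientTape(f, rand(2))
ctape = ReverseDiff.compile(tape)

out = zeros(2)                                 # pre-allocated result buffer
ReverseDiff.gradient!(out, ctape, [3.0, 1.0])  # in-place, reuses the tape
```

(Note that compiled tapes only record one branch of any control flow, so this is safe here only because f has no branches.)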
Thanks,
Define Base.float(x::foo) = x
(since foo
is already floating-point in your case, it doesn’t need to do anything… more generally, you’d call float
recursively on the fields).
Basically, you need to define enough methods for foo
to act like a number type to satisfy Zygote. Whenever you get a MethodError
, it’s an indication that you are missing a Number
-like function. (e.g. you may need to define promotion rules as well).
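For example, promotion rules for foo could look like this (a sketch, assuming a real number should fill only the first field, consistent with the complex-style arithmetic above):

```julia
struct foo <: Number
    x::Float64
    y::Float64
end

Base.:+(a::foo, b::foo) = foo(a.x + b.x, a.y + b.y)
Base.:*(a::foo, b::foo) = foo(a.x*b.x - a.y*b.y, a.y*b.x + a.x*b.y)

# Tell Julia that mixing a Real with a foo yields a foo,
# and how to lift a Real into a foo.
Base.promote_rule(::Type{foo}, ::Type{T}) where {T<:Real} = foo
Base.convert(::Type{foo}, r::Real) = foo(r, 0.0)

Base.float(a::foo) = a  # already floating-point, nothing to do

2 + foo(1.0, 2.0)  # works via promotion: foo(3.0, 2.0)
```

With these in place, mixed Real/foo arithmetic goes through the usual promote-then-operate machinery, which is what Zygote's pullbacks rely on internally.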
Many thanks,
I have managed to get Zygote.jl to work, but unfortunately it is much slower than ReverseDiff.jl (even for simple real-valued functions). I have opened a new question about this, to see if I am missing anything (Zygote terribly slow. What am I doing wrong?).
Thanks!