Hello Friends of Zygote,
I have been trying to write an adjoint for a custom array type, so far without success.
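```julia
using Zygote

# A lazy N-dimensional array whose entries are computed by `f` from the index tuple.
struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N,Int}
    f::F
end
# AbstractArray interface: one Int index per dimension.
Base.getindex(a::Example{T,N}, idx::Vararg{Int,N}) where {T,N} = a.f(idx)
Base.size(e::Example) = e.sz

function g(o)
    f(idx) = sum((Tuple(idx) .+ o) .^ 2)   # closure capturing `o`
    return Example{Float64,2,typeof(f)}((5, 5), f)
end

gradient(x -> sum(g(x)), 1)
```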
Understandably, this leads to an error:
```
ERROR: Need an adjoint for constructor Example{Float64, 2, var"#f#193"{Int64}}. Gradient is of type FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] (::Zygote.Jnew{Example{Float64, 2, var"#f#193"{Int64}}, Nothing, false})(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\lib\lib.jl:314
  [3] (::Zygote.var"#1720#back#194"{Zygote.Jnew{Example{Float64, 2, var"#f#193"{Int64}}, Nothing, false}})(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
  [4] Pullback
    @ .\REPL[102]:2 [inlined]
  [5] (::typeof(∂(Example{Float64, 2, var"#f#193"{Int64}})))(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface2.jl:0
  [6] Pullback
    @ .\REPL[119]:3 [inlined]
  [7] (::typeof(∂(g)))(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface2.jl:0
  [8] Pullback
    @ .\REPL[121]:1 [inlined]
  [9] (::typeof(∂(#194)))(Δ::Int64)
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface2.jl:0
 [10] (::Zygote.var"#41#42"{typeof(∂(#194))})(Δ::Int64)
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:41
 [11] gradient(::Function, ::Int64, ::Vararg{Int64, N} where N)
    @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:59
 [12] top-level scope
    @ REPL[121]:1
```
However, writing the correct `@adjoint` for this case does not seem easy. Ideally, the adjoint itself would use the `Example` type. Any help would be appreciated, and a solution based on ChainRulesCore would also be great.
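For reference, here is a minimal sketch of the kind of ChainRulesCore rule in question. It assumes that Zygote picks up a config-aware `rrule` for the constructor call and that `rrule_via_ad` can differentiate the stored closure. The idea is that the array cotangent `Δ` has one entry per index `I`, so the cotangent for `f` is obtained by pulling `Δ[I]` back through `f(Tuple(I))` at every index and summing the results. `Example_pullback` is a hypothetical name, and the `g` below uses a scalar rewrite of the original `f` (same values for 2-D indices) to avoid differentiating broadcasting over a tuple.

```julia
using Zygote
using ChainRulesCore: ChainRulesCore, RuleConfig, HasReverseMode, NoTangent,
                      unthunk, rrule_via_ad

# Same array type as above; `getindex` receives one Int per dimension.
struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N,Int}
    f::F
end
Base.size(e::Example) = e.sz
Base.getindex(e::Example{T,N}, idx::Vararg{Int,N}) where {T,N} = e.f(idx)

# Sketch of a constructor rrule (assumption: Δ is an array-like cotangent of size `sz`).
# Each Δ[I] is pulled back through `f(Tuple(I))` with the calling AD (Zygote here),
# and the resulting cotangents for `f` itself are summed over all indices.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::Type{Example{T,N,F}}, sz::NTuple{N,Int},
                              f::F) where {T,N,F}
    y = Example{T,N,F}(sz, f)
    function Example_pullback(Δraw)
        Δ = unthunk(Δraw)
        f̄ = sum(CartesianIndices(sz)) do I
            _, back = rrule_via_ad(config, f, Tuple(I))
            first(back(Δ[I]))  # cotangent w.r.t. `f` (a Tangent of the closure)
        end
        # No cotangent for the type itself or for the integer sizes.
        return NoTangent(), NoTangent(), f̄
    end
    return y, Example_pullback
end

function g(o)
    # Scalar rewrite of the original `f`: same values for 2-D indices,
    # but without broadcasting over a tuple.
    f(idx) = (idx[1] + o)^2 + (idx[2] + o)^2
    return Example{Float64,2,typeof(f)}((5, 5), f)
end

@show gradient(x -> sum(g(x)), 1)
```

By hand, d/do Σ_{i,j} ((i+o)² + (j+o)²) = 20(15 + 5o), which is 400 at o = 1, so that is the value to compare against. Note that this pullback returns a structural cotangent for `f` rather than anything built on `Example`, and it calls AD once per element, so it is a starting point rather than an efficient rule. ChainRulesCore has to be added to the environment explicitly for the `using` line to work.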