Zygote InexactError using repeat() with inner keyword

The following function throws InexactError: Int64(2.0794415416798357):

using Zygote  # provides jacobian

function get_error(G, E, x)
    n, m = size(G)
    Ê = repeat(E, inner=(1, m))
    Ĝ = repeat(G, inner=(1, m))
    return Ĝ * prod(x.^Ê, dims=1)'
end

G = [0. 1.5 1.; 1.5 1. 1.5]  # Float matrix with non-integer values
E = [0 0 1; 0 1 0]  # Int matrix
x = [1., 2.]

gs = jacobian(get_error, G, E, x)

However, replacing get_error() with the following function works just fine. It uses the standard positional repeat() for Ê, so the computed result is different.

function get_no_error(G, E, x)
    n, m = size(G)
    Ê = repeat(E, 1, m)
    Ĝ = repeat(G, inner=(1, m))
    return Ĝ * prod(x.^Ê, dims=1)'
end

Is there a way to reformulate get_error() so that autodiff also works with the inner repeat?
A way to ignore the derivative for the integer matrix E and return only the derivatives for G and x would also work for me.

julia> using ChainRulesCore

julia> function get_error(G, E, x)
           n, m = size(G)
           Ê = @ignore_derivatives repeat(E, inner=(1, m))
           Ĝ = repeat(G, inner=(1, m))
           return Ĝ * prod(x.^Ê, dims=1)'
       end
get_error (generic function with 1 method)

julia> gs = jacobian((G,x) -> get_error(G,E,x), G, x)
([3.0 0.0 … 3.0 0.0; 0.0 3.0 … 0.0 3.0], [3.0 4.5; 4.5 3.0])

The above solution works for me, so I’m accepting it.

But I still don’t understand why inner=(1, m) throws off Zygote.

It’s a bug. Smaller example:

julia> jacobian([1, 2.0], [3]) do E,x
           x .^ repeat(E, inner=2)
       end
([3.295836866004329 0.0; 3.295836866004329 0.0; 0.0 9.887510598012987; 0.0 9.887510598012987], [1.0; 1.0; 6.0; 6.0;;])

julia> jacobian([1, 2], [3]) do E,x
           x .^ repeat(E, inner=2)
       end
ERROR: InexactError: Int64(3.295836866004329)
[... not useful]

The problem is a particular rule:

julia> using ChainRulesCore, ChainRules

julia> y, back = rrule(repeat, [1,2], inner=3)
([1, 1, 1, 2, 2, 2], ChainRules.var"#repeat_pullback#1371"{Vector{Int64}, Int64, Tuple{Int64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}([1, 2], 3, (2,), ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),))))

julia> back(ones(6))
(NoTangent(), [3.0, 3.0])

julia> back(rand(6))
ERROR: InexactError: Int64(0.31780288786437183)
Stacktrace:
 [1] Int64
   @ ./float.jl:900 [inlined]
 [2] convert
   @ ./number.jl:7 [inlined]
 [3] setindex!(A::Vector{Int64}, x::Float64, i1::Int64)
   @ Base ./array.jl:969
 [4] (::ChainRules.var"#repeat_pullback#1371"{Vector{Int64}, Int64, Tuple{Int64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}})(ȳ::Vector{Float64})
   @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/array.jl:185
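Reading the stack trace, frame [3] is a setindex! of a Float64 into a Vector{Int64}, so the pullback appears to write the incoming gradient into a buffer that has the element type of the input. That would also explain why back(ones(6)) succeeds: its values happen to be whole numbers. Here is a minimal sketch of that conversion behaviour in plain Julia. The names acc and err are made up for illustration, and this is not the actual ChainRules code:

```julia
# Hypothetical stand-in for a gradient buffer with an integer element type
acc = zeros(Int, 2)

# A whole-valued Float64 converts exactly, so this store succeeds
acc[1] += 1.0

# A fractional Float64 cannot be stored in an Int64 slot
err = try
    acc[2] += 0.5
    nothing
catch e
    e
end

println(acc)                      # the buffer is still a Vector{Int64}
println(err isa InexactError)     # the failed store raised InexactError
```

So whether the rule fails depends on the values of the gradient that reaches it, not only on the types.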

Somewhere there’s an issue in ChainRules with a better version of this rule.
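Until that rule is improved, another workaround besides @ignore_derivatives might be to convert E to floating point before it reaches repeat, since the Float64 version of the small example above runs without error. A sketch under that assumption, where Ef and get_error_float are names I made up and the conversion is done outside the differentiated function:

```julia
using Zygote

# Hypothetical variant of get_error that expects a Float64 exponent matrix
function get_error_float(G, Ef, x)
    n, m = size(G)
    Ê = repeat(Ef, inner=(1, m))
    Ĝ = repeat(G, inner=(1, m))
    return Ĝ * prod(x.^Ê, dims=1)'
end

G = [0. 1.5 1.; 1.5 1. 1.5]
E = [0 0 1; 0 1 0]
x = [1., 2.]

# Convert once, outside of the function being differentiated
Ef = float(E)

# Only G and x are differentiated; Ef is captured by the closure
J = jacobian((G, x) -> get_error_float(G, Ef, x), G, x)
```

Compared with @ignore_derivatives this still pushes a gradient through repeat for the exponents, so it does a bit of wasted work, but it needs no extra package.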
