Is `cumsum` differentiable with Zygote when used in a Flux model?

I have a call to cumsum in my Flux model that works as intended in the forward pass, but throws an error when I try to differentiate it. A minimal working example:

using Flux
n = 5
x = [rand(1, n), rand(1, n)]    # dummy data
vcat(cumsum(x)...)              # works as expected
Flux.gradient(x -> vcat(cumsum(x)...) |> sum, x) # error
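To make the forward pass concrete, here is a Base-only sketch (no Flux needed) that uses fixed numbers in place of rand. cumsum over a Vector of 1×n matrices produces the running sums of the matrices, and splatting those into vcat stacks them into a 2×n matrix.

```julia
n = 3
# Two 1×n matrices with fixed values instead of rand, so the result is easy to check
x = [[1.0 2.0 3.0], [10.0 20.0 30.0]]

c = cumsum(x)      # Vector of matrices: [x[1], x[1] + x[2]]
y = vcat(c...)     # stack the running sums into a 2×n matrix

# y == [1.0 2.0 3.0; 11.0 22.0 33.0]
println(y)
```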

The specific error message that I receive is:

ERROR: MethodError: no method matching reverse(::ChainRulesCore.Tangent{Any, Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}; dims::Int64)

Closest candidates are:
  reverse(::ChainRulesCore.Tangent) got unsupported keyword argument "dims"
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/structural_tangent.jl:406
  reverse(::Union{SubString{String}, String}) got unsupported keyword argument "dims"
   @ Base strings/substring.jl:174
  reverse(::CUDA.AnyCuVector{T}) where T got unsupported keyword argument "dims"
   @ CUDA ~/.julia/packages/CUDA/htRwP/src/reverse.jl:151
  ...

Stacktrace:
 [1] cumsum_pullback
   @ ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:226 [inlined]
 [2] (::Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}})(dy::Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
   @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211
 [3] Pullback
   @ ./REPL[110]:1 [inlined]
 [4] (::Zygote.Pullback{Tuple{var"#42#43", Vector{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Matrix{Float64}, typeof(sum)}, Tuple{Zygote.var"#2993#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1395"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}, Val{2}}}}}, Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#42#43", Vector{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Matrix{Float64}, typeof(sum)}, Tuple{Zygote.var"#2993#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1395"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}, Val{2}}}}}, Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}}}}})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [6] gradient(f::Function, args::Vector{Matrix{Float64}})
   @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [7] top-level scope
   @ REPL[110]:1
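The trace points at cumsum_pullback in ChainRules, and the MethodError shows it calling reverse(...; dims=...) on the incoming cotangent dy. As the trace also shows, that cotangent is a Tuple of two Fill arrays (wrapped in a Tangent) rather than an array, and no reverse method accepts dims for it. The Base-only sketch below shows what such a pullback has to compute: the adjoint of a cumulative sum along a vector is a reversed cumulative sum of the reversed cotangent. The helper cumsum_adjoint is hypothetical, not a ChainRules function; it works once the cotangent is collected into a Vector, which the tuple-shaped dy from the splatted vcat is not.

```julia
# Hypothetical helper (not part of ChainRules): for y = cumsum(x) along a vector,
# dx[i] = sum(dy[i:end]), i.e. a reversed cumsum of the reversed cotangent.
cumsum_adjoint(dy::AbstractVector) = reverse(cumsum(reverse(dy; dims=1); dims=1); dims=1)

n = 3
# Cotangent of sum(vcat(cumsum(x)...)) w.r.t. each element of cumsum(x): all ones.
dy_tuple = (ones(1, n), ones(1, n))   # tuple-shaped, like the dy in the stack trace
dy = collect(dy_tuple)                # Vector{Matrix{Float64}}, which reverse(...; dims=1) accepts

dx = cumsum_adjoint(dy)               # [dy[1] + dy[2], dy[2]]
```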

Can cumsum be used in Flux models? Many thanks in advance for any comments or suggestions.

It looks like reduce(vcat, cumsum(x)) works. Maybe this is a bug, or maybe the splatting is the issue?
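The two spellings agree in the forward pass. One plausible reading of the stack trace (an inference, not verified here) is that the splatted call hands cumsum's pullback a tuple-shaped cotangent, while reduce(vcat, ...) takes the whole Vector in a single call. The Base-only sketch below checks the forward equivalence and uses central finite differences to confirm the gradient values the working version should produce: x[1] appears in both running sums and x[2] only in the second, so every entry of x[1] should get 2 and every entry of x[2] should get 1. The names f_splat, f_reduce and fd_grad are hypothetical; comparing against Flux.gradient(f_reduce, x)[1] with Flux loaded is a reasonable sanity check.

```julia
f_splat(x)  = sum(vcat(cumsum(x)...))        # the form from the question
f_reduce(x) = sum(reduce(vcat, cumsum(x)))   # the workaround

n = 3
x = [rand(1, n), rand(1, n)]

@assert f_splat(x) ≈ f_reduce(x)   # identical forward values

# Central finite difference w.r.t. the single entry x[i][1, j] (hypothetical helper)
function fd_grad(f, x, i, j; h = 1e-6)
    xp = deepcopy(x); xp[i][1, j] += h
    xm = deepcopy(x); xm[i][1, j] -= h
    return (f(xp) - f(xm)) / (2h)
end

g1 = [fd_grad(f_reduce, x, 1, j) for j in 1:n]   # expected ≈ 2 for each entry
g2 = [fd_grad(f_reduce, x, 2, j) for j in 1:n]   # expected ≈ 1 for each entry
```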


I think you want cumsum(reduce(vcat, x); dims=1), or, better still, to avoid making slices at all and work with one solid array throughout. Note that reduce(vcat, x) == vcat(x...).
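A Base-only sketch of the single-array version: stack the slices once with reduce(vcat, x), then take cumsum along dims=1. It gives the same 2×n matrix as the original vcat(cumsum(x)...), and no Vector of slices is ever summed or splatted. In a real model one would presumably keep the data as a matrix from the start; the names X and Y are only illustrative.

```julia
n = 5
x = [rand(1, n), rand(1, n)]   # the sliced layout from the question

X = reduce(vcat, x)            # one 2×n Matrix, same as vcat(x...)
Y = cumsum(X; dims=1)          # running sum down the rows

@assert X == vcat(x...)
@assert Y ≈ vcat(cumsum(x)...) # matches the original sliced computation
```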

Yes, 599.
