I have a call to `cumsum` in my Flux model that works as intended in the forward pass, but throws an error when I try to differentiate it. A minimal working example:
```julia
using Flux
n = 5
x = [rand(1, n), rand(1, n)] # dummy data
vcat(cumsum(x)...) # works as expected
Flux.gradient(x -> vcat(cumsum(x)...) |> sum, x) # error
```
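For context on what the forward pass computes: calling `cumsum` on a `Vector` of matrices accumulates with `+`, so it returns a `Vector` of running elementwise sums, and the splat `...` then passes those sums to `vcat` as separate arguments. A small pure-Base illustration, with hypothetical matrices `A` and `B` standing in for the random data:

```julia
# Two 1×3 matrices standing in for x[1] and x[2] (hypothetical values).
A = [1.0 2.0 3.0]
B = [10.0 20.0 30.0]

# cumsum over a Vector of matrices yields running sums: [A, A + B].
c = cumsum([A, B])
@assert c[1] == A
@assert c[2] == A + B

# Splatting into vcat stacks those running sums row by row into a 2×3 matrix.
println(vcat(c...))   # [1.0 2.0 3.0; 11.0 22.0 33.0]
```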
The specific error message that I receive is:
```
ERROR: MethodError: no method matching reverse(::ChainRulesCore.Tangent{Any, Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}; dims::Int64)

Closest candidates are:
  reverse(::ChainRulesCore.Tangent) got unsupported keyword argument "dims"
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/structural_tangent.jl:406
  reverse(::Union{SubString{String}, String}) got unsupported keyword argument "dims"
   @ Base strings/substring.jl:174
  reverse(::CUDA.AnyCuVector{T}) where T got unsupported keyword argument "dims"
   @ CUDA ~/.julia/packages/CUDA/htRwP/src/reverse.jl:151
  ...

Stacktrace:
  [1] cumsum_pullback
    @ ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:226 [inlined]
  [2] (::Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}})(dy::Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211
  [3] Pullback
    @ ./REPL[110]:1 [inlined]
  [4] (::Zygote.Pullback{Tuple{var"#42#43", Vector{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Matrix{Float64}, typeof(sum)}, Tuple{Zygote.var"#2993#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1395"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}, Val{2}}}}}, Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [5] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#42#43", Vector{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Matrix{Float64}, typeof(sum)}, Tuple{Zygote.var"#2993#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1395"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}, Val{2}}}}}, Zygote.ZBack{ChainRules.var"#cumsum_pullback#1664"{1, Int64, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}, Tuple{Base.OneTo{Int64}}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
  [6] gradient(f::Function, args::Vector{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
  [7] top-level scope
    @ REPL[110]:1
```
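The trace points at `cumsum_pullback` in ChainRules (`mapreduce.jl`), which calls `reverse(...; dims)` on the incoming cotangent. The reversal is there because the adjoint of a cumulative sum is a cumulative sum taken from the end. According to the error message, the cotangent it receives here is a `ChainRulesCore.Tangent` wrapping a `Tuple` (presumably produced by the splat in `vcat(cumsum(x)...)`), and `reverse` has no `dims` method for that type. A minimal pure-Base sketch of the adjoint, using a hypothetical helper `cumsum_adjoint` (an illustration, not the ChainRules source), checked against the explicit Jacobian; the last lines show the analogous `MethodError` for a plain `Tuple`:

```julia
# Hypothetical helper (not the ChainRules source): vector-Jacobian product of
# y = cumsum(x; dims). Each x[i] contributes to every y[j] with j >= i, so
# dx[i] = sum(dy[j] for j >= i), i.e. a cumulative sum of dy taken from the end.
cumsum_adjoint(dy::AbstractArray; dims::Integer) =
    reverse(cumsum(reverse(dy; dims=dims); dims=dims); dims=dims)

# Explicit Jacobian of cumsum on a length-4 vector: J[j, i] = 1 when i <= j.
n = 4
J = [i <= j ? 1.0 : 0.0 for j in 1:n, i in 1:n]
x0 = [0.5, -1.0, 2.0, 3.0]
@assert J * x0 ≈ cumsum(x0)      # J really is the Jacobian of cumsum

dy = [1.0, 2.0, 3.0, 4.0]
dx = cumsum_adjoint(dy; dims=1)
@assert dx ≈ J' * dy             # matches the transposed Jacobian
println(dx)                      # [10.0, 9.0, 7.0, 4.0]

# The rule relies on reverse(dy; dims), which exists for arrays only.
# A plain Tuple (analogous to the Tuple-backed Tangent above) has no such method:
err = try
    reverse((1.0, 2.0); dims=1)
    nothing
catch e
    e
end
@assert err isa MethodError
```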
Can `cumsum` be used in Flux models? Many thanks in advance for any comments or suggestions.
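One possible workaround, sketched under the assumption that the splatted tuple cotangent is what breaks the rule: stack the inputs into a single matrix with `reduce(vcat, x)` and take `cumsum` along `dims=1`. This gives the same values as `vcat(cumsum(x)...)` while handing the `cumsum` rule a plain array. The helper name `stacked_cumsum` is hypothetical, and the sketch has not been checked against every Flux/Zygote version:

```julia
using Flux

n = 5
x = [rand(1, n), rand(1, n)]  # dummy data, as in the example above

# Hypothetical helper: same forward result as vcat(cumsum(x)...),
# row k holds x[1] + ... + x[k], but cumsum sees a plain Matrix.
stacked_cumsum(x) = cumsum(reduce(vcat, x); dims=1)
@assert stacked_cumsum(x) ≈ vcat(cumsum(x)...)

# Differentiate through the stacked form instead of the splatted one.
gx = Flux.gradient(x -> sum(stacked_cumsum(x)), x)[1]

# x[1] feeds both output rows and x[2] only the second row,
# so the expected gradients are all 2s and all 1s respectively.
@assert all(gx[1] .≈ 2)
@assert all(gx[2] .≈ 1)
```

Avoiding the splat on the output of `cumsum` is the key design choice here: `reduce(vcat, ...)` keeps the cotangent as an array rather than a tuple.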