What versions are you on? I get the following error whether or not I define the fast_max
gradient:
julia> ReverseDiff.gradient(sum∘logsumexp_avx, x)
ERROR: MethodError: no method matching similar(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof
(exp),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.Track
edReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{Reverse
Diff.TrackedStyle,Nothing,typeof(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,F
loat64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, ::Type{ReverseDiff.TrackedReal{Float64,Float64,Nothing}}, ::Tuple{Base.OneTo{I
nt64},Base.OneTo{Int64}})
Closest candidates are:
similar(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{N},Axes,F,Args} where Args<:Tuple where F where Axes, ::Type{ElType}, ::Any) where {N, ElType} at broadcast.jl:197
similar(::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayConflict,Axes,F,Args} where Args<:Tuple where F where Axes, ::Type{ElType}, ::Any) where ElType at broadcast.jl:202
similar(::Base.Broadcast.Broadcasted, ::Type{T}) where T at broadcast.jl:196
...
Stacktrace:
[1] similar(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{Base.Broadcast.Broad
casted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedAr
ray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof
(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Vector{
Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, #unused#::Type{ReverseDiff.TrackedReal{Float64,Float64,Nothing}})
@ Base.Broadcast ./broadcast.jl:196
[2] vmaterialize(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{Base.Broadcast.
Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(-),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,Trac
kedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}},Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,t
ypeof(==),Tuple{TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}},LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,TrackedArray{Float64,Float64,1,Vector{Float64},Ve
ctor{Float64}}},TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}}}}}}, #unused#::Val{:Main})
@ LoopVectorization ~/.julia/dev/LoopVectorization/src/broadcast.jl:327
[3] logsumexp_avx(mat::TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}}; dims::Int64)
@ Main ./REPL[34]:4
[4] logsumexp_avx(mat::TrackedArray{Float64,Float64,2,Matrix{Float64},Matrix{Float64}})
@ Main ./REPL[34]:2