Given this function:
# Log-likelihood of `data` under a K-component normal mixture:
#   sum_n log( sum_k weights[k] * normal_pdf(data[n], means[k], stds[k]^2) )
# `params` is laid out as [weights; means; stds], each block of length
# K = length(params) ÷ 3. Whether length(params) is a multiple of 3 is not checked.
# NOTE(review): `AV` is presumably an alias for `AbstractVector` defined elsewhere — confirm.
# NOTE(review): passing `stds .^2` suggests normal_pdf's 3rd argument is a variance;
# normal_pdf's definition is not shown — confirm.
function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
K = length(params) ÷ 3
# @views: the three slices reference `params` rather than copying it.
weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
# `data` is a column (N) and `means'`/`stds'` are rows (1×K), so the broadcast gives N×K.
# Julia lowers a literal integer power `x^2` to `Base.literal_pow(^, x, Val(2))`.
# That `Val{2}` is the `::Type{Val{2}}` in Yota's MethodError: frame [11] of the trace
# shows it inside the `Broadcasted` built from `stds' .^2`.
# The minimal `normal_pdf.(xs, mu, 1.0)` example fails the same way, so normal_pdf
# presumably contains a literal `^2` as well (its body is not shown).
# Untested workaround idea: write `abs2(x)` or `x * x` instead of `x^2`.
mat = normal_pdf.(data, means', stds' .^2) # (N, K)
# Weight each column by weights[k] and sum over k (dims=2) to get the per-observation
# mixture density (N×1). Take the log of each entry, then sum to a scalar.
sum(
mat .* weights', dims=2
) .|> log |> sum
end
…and differentiating like this:
# Close over `data` so the objective takes `params` as its only argument.
objective = params -> mixture_loglikelihood(params, data)
# The destructuring expects Yota.grad to return (value, gradients), with `gradients` a
# 2-tuple whose second entry is the gradient w.r.t. `params0`.
# NOTE(review): the first entry is presumably the gradient w.r.t. the closure itself —
# confirm against the Yota docs.
_, (_, grad_storage) = Yota.grad(objective, params0)
…Yota produces this error:
[ Info: Computing gradient w/ Yota
ERROR: LoadError: MethodError: no method matching length(::Type{Val{2}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{ArrayInterfaceCore.BidiagonalIndex, ArrayInterfaceCore.TridiagonalIndex}) at ~/.julia/packages/ArrayInterfaceCore/7kMjZ/src/ArrayInterfaceCore.jl:594
length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/8Dz3j/src/abstractarray.jl:1
...
Stacktrace:
[1] unzip(tuples::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}})
@ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:92
[2] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:49
[3] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:48
[4] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[5] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
[6] record_or_recurse!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:85
[7] trace!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
[8] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, fargtypes::Tuple{typeof(normal_pdf), Tuple{DataType, DataType, DataType}}, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
[9] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:136
[10] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:170
[11] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::LinearAlgebra.Adjoint{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, LinearAlgebra.Adjoint{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}}, Base.RefValue{Val{2}}}})
@ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:98
[12] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[13] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
[14] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:49
[15] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:193
[16] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
[17] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:202
[18] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
[19] trace(f::Function, args::Vector{Float64}; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
[20] #gradtape#90
@ ~/.julia/packages/Yota/VCIzN/src/grad.jl:243 [inlined]
[21] grad(f::var"#12#13", args::Vector{Float64}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
[22] grad(f::var"#12#13", args::Vector{Float64})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
[23] top-level scope
@ ~/test/autodiff_bench/code.jl:118
[24] include(fname::String)
@ Base.MainInclude ./client.jl:476
[25] top-level scope
@ REPL[2]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:116
When I use the version with Tullio:
# Same mixture log-likelihood as the broadcasting version, written with Tullio:
#   tmp[n] = sum_k weights[k] * normal_pdf(data[n], means[k], stds[k]^2)
# The result is sum(log, tmp).
# NOTE(review): this assumes Tullio reduces over `k` because `k` appears only on the
# right-hand side, and that `grad=Dual` makes Tullio derive its gradient with dual
# numbers. Both are per the Tullio docs — confirm.
# NOTE(review): under Yota this fails in the backward pass with no rrule for
# `convert(::DataType, ::Float64)` (error shown below). This body has no explicit
# `convert`, so the call presumably comes from the @tullio expansion.
function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
K = length(params) ÷ 3
# Slices are views into `params`: [weights; means; stds], each of length K.
weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
Tullio.@tullio tmp[n] := weights[k] * normal_pdf(data[n], means[k], stds[k]^2) grad=Dual
sum(log, tmp)
end
…Yota produces this error:
ERROR: LoadError: No deriative rule found for op %78 = convert(%3, %76)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:222
[5] #gradtape#90
@ ~/.julia/packages/Yota/VCIzN/src/grad.jl:244 [inlined]
[6] grad(f::var"#21#22", args::Vector{Float64}; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
[7] grad(f::var"#21#22", args::Vector{Float64})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
[8] top-level scope
@ ~/test/autodiff_bench/code.jl:114
[9] include(fname::String)
@ Base.MainInclude ./client.jl:476
[10] top-level scope
@ REPL[2]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:112
Both stack traces end inside Yota's own code, so it seems Yota cannot differentiate either version of the function…
It looks like Yota has trouble with broadcasting. In this minimal example, Zygote works fine:
julia> xs = randn(200);
julia> Zygote.gradient(mu->sum(log, normal_pdf.(xs, mu, 1.0)), 1.0)
(-169.05854117272128,)
But Yota fails:
julia> Yota.grad(mu->sum(log, normal_pdf.(xs, mu, 1.0)), 1.0)
ERROR: MethodError: no method matching length(::Type{Val{2}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{ArrayInterfaceCore.BidiagonalIndex, ArrayInterfaceCore.TridiagonalIndex}) at ~/.julia/packages/ArrayInterfaceCore/7kMjZ/src/ArrayInterfaceCore.jl:594
length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/8Dz3j/src/abstractarray.jl:1
...
Stacktrace:
[1] unzip(tuples::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}})
@ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:92
[2] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:49
[3] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:48
[4] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[5] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
[6] record_or_recurse!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:85
[7] trace!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
[8] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, fargtypes::Tuple{typeof(normal_pdf), Tuple{DataType, DataType, DataType}}, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
[9] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:136
[10] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:170
[11] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::Float64, ::Float64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:98
[12] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
[13] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
[14] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:49
[15] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:193
[16] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
[17] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
[18] #gradtape#90
@ ~/.julia/packages/Yota/VCIzN/src/grad.jl:243 [inlined]
[19] grad(f::var"#85#86", args::Float64; seed::Int64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
[20] grad(f::var"#85#86", args::Float64)
@ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
[21] top-level scope
@ REPL[28]:1
As a user, I have no idea what `::Type{Val{2}}` is or where it came from. I don't think it appears anywhere in my code.