I’ve built a loss function based on the negative log-likelihood of a multivariate log-normal distribution. However, when I try to calculate the gradients, I hit the following error, which I suspect stems from the handling of the covariance matrix. As the error message shows, the main issue seems to be that we end up passing an invalid argument to a LAPACK routine (`gesdd!`, reached via `svd` from `opnorm`). The loss function itself evaluates fine and gives the correct values, so I’m assuming there’s a mistake in an `rrule` somewhere, or that one is missing?
I’m happy to put together a minimal working example (MWE); I just wanted to check whether anyone has seen this before and can point me to what’s going wrong.
The main loss function looks like this:
"""
    gbmdistloss2(data, response)

Return the batch loss `-Σⱼ Σₜ mvlognormpdf(...)`, summed over every sample `j` in the
batch and every future timestep `t`.

# Arguments
- `data`: destructured as `(Σb, μb, X₀b, Yb)`. `Σb` is not read here.
  - `last(size(μb))` gives the number of samples `j`.
  - `X₀b[:, j]` is passed as the second argument of `mvlognormpdf`.
  - `Yb[:, :, j]` holds sample `j`. Its row `t` is the observation passed for timestep `t`.
  - `size(Yb, 1)` is the number of timesteps.
- `response`: destructured as `(Σbhat, μbhat)`.
  - Only the lower triangle of `Σbhat[:, :, j]` is used. It acts as a factor `L`, and
    `L * transpose(L)` is passed as the matrix argument of `mvlognormpdf`.
  - `μbhat[:, j]` is passed as its fourth argument.

# Returns
The accumulated scalar loss. It starts from `float(zero(eltype(μbhat)))`, so an empty
batch yields a floating-point zero.

NOTE(review): `mvlognormpdf` is defined elsewhere. Presumably it returns a
log-density, since the sum is used as a negative log-likelihood. Confirm this.
"""
function gbmdistloss2(data, response)
    _, μb, X₀b, Yb = data
    Σbhat, μbhat = response
    T = size(Yb, 1)  # number of future timesteps being predicted
    # Start from a float zero of the prediction eltype so `loss` does not change
    # from Int to a float type on the first iteration (type stability).
    loss = float(zero(eltype(μbhat)))
    for j in 1:last(size(μb))
        μ = μbhat[:, j]
        L = LowerTriangular(Σbhat[:, :, j])
        # NOTE(review): `L * transpose(L)` is only positive *semi*-definite. A zero
        # (or vanishing) diagonal entry of `L` makes it singular.
        #
        # The reported error ("invalid argument #4" from `gesdd!`, reached via
        # `opnorm` inside `mvlognormpdf`) is what LAPACK's gesdd returns when its
        # input contains NaN. Check this matrix, and the network output, for
        # non-finite values; NaN gradients from an earlier step would produce them.
        #
        # Consider forcing a positive diagonal on `L`, and avoiding `opnorm(·, 2)`,
        # whose svd-based rrule is ill-conditioned for repeated singular values.
        ΣΣ = L * transpose(L)
        X₀ = X₀b[:, j]
        Y = Yb[:, :, j]
        # Zygote differentiates this comprehension via its map-based adjoint (∇map).
        loss += -sum([mvlognormpdf(Y[t, :], X₀, t, μ, ΣΣ) for t in 1:T])
    end
    return loss
end
The stack trace is below:
ERROR: ArgumentError: invalid argument #4 to LAPACK call
Stacktrace:
[1] chklapackerror
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:38 [inlined]
[2] gesdd!(job::Char, A::Matrix{Float32})
@ LinearAlgebra.LAPACK ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:1714
[3] _svd!
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:125 [inlined]
[4] svd!(A::Matrix{Float32}; full::Bool, alg::LinearAlgebra.DivideAndConquer)
@ LinearAlgebra ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:105
[5] svd!
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:100 [inlined]
[6] svd(A::Matrix{Float32}; full::Bool, alg::LinearAlgebra.DivideAndConquer)
@ LinearAlgebra ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:179
[7] svd
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:178 [inlined]
[8] rrule
@ ~/.julia/packages/ChainRules/DbWAz/src/rulesets/LinearAlgebra/factorization.jl:284 [inlined]
[9] rrule
@ ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138 [inlined]
[10] chain_rrule
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
[11] macro expansion
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
[12] _pullback
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87 [inlined]
[13] opnorm2
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:685 [inlined]
[14] _pullback(ctx::Zygote.Context{false}, f::typeof(LinearAlgebra.opnorm2), args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[15] opnorm
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:745 [inlined]
[16] _pullback(::Zygote.Context{false}, ::typeof(opnorm), ::Matrix{Float32}, ::Int64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[17] opnorm
@ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:744 [inlined]
[18] mvlognormpdf
@ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:136 [inlined]
[19] _pullback(::Zygote.Context{…}, ::typeof(mvlognormpdf), ::Vector{…}, ::Vector{…}, ::Int64, ::Vector{…}, ::Matrix{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[20] #80
@ ./none:0 [inlined]
[21] _pullback(ctx::Zygote.Context{false}, f::var"#80#81"{Matrix{…}, Vector{…}, Matrix{…}, Vector{…}}, args::Int64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[22] #666
@ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188 [inlined]
[23] iterate
@ ./generator.jl:48 [inlined]
[24] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
@ Base ./array.jl:811
[25] collect_similar
@ ./array.jl:720 [inlined]
[26] map
@ ./abstractarray.jl:3371 [inlined]
[27] ∇map
@ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188 [inlined]
[28] _pullback
@ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:231 [inlined]
[29] gbmdistloss2
@ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:336 [inlined]
[30] _pullback(::Zygote.Context{…}, ::typeof(gbmdistloss2), ::@NamedTuple{…}, ::Tuple{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[31] #82
@ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:349 [inlined]
[32] _pullback(ctx::Zygote.Context{false}, f::var"#82#83"{@NamedTuple{…}}, args::GBMNet{Parallel{…}, Chain{…}})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[33] pullback(f::Function, cx::Zygote.Context{false}, args::GBMNet{Parallel{typeof(vcat), Tuple{…}}, Chain{Tuple{…}}})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
[34] pullback
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88 [inlined]
[35] withgradient(f::Function, args::GBMNet{Parallel{typeof(vcat), Tuple{…}}, Chain{Tuple{…}}})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
[36] #withgradient#5
@ ~/.julia/packages/Flux/5vIRy/src/gradient.jl:182 [inlined]
[37] withgradient
@ ~/.julia/packages/Flux/5vIRy/src/gradient.jl:169 [inlined]