Cannot differentiate through a multivariate log normal probability density with Zygote

I’ve built a loss function based on the negative log-likelihood of a multivariate log-normal distribution. However, when I try to calculate the gradients with Zygote I hit the error below, which I think stems from the handling of the covariance matrix: as the message shows, an invalid argument ends up being passed to a LAPACK routine. The loss function itself evaluates fine and returns the correct values, so I’m assuming there’s a mistake in an rrule somewhere, or maybe one is missing?

Happy to put together an MWE, but I first wanted to check whether anyone has seen this before and can point me to what’s going wrong.

The main loss function looks like this:

using LinearAlgebra  # for LowerTriangular

function gbmdistloss2(data, response)
    Σb, μb, X₀b, Yb = data      # targets: covariances, drifts, starting values, paths
    Σbhat, μbhat = response     # network predictions
    loss = zero(eltype(Yb))     # type-stable accumulator (avoids Int-to-Float32 promotion)
    T = size(Yb, 1)             # the number of timesteps into the future we're predicting
    for j in 1:size(μb, 2)      # loop over the batch
        Σ, μ = Σbhat[:, :, j], μbhat[:, j]
        # L * Lᵀ keeps the predicted covariance symmetric positive semidefinite
        ΣΣ = LowerTriangular(Σ) * transpose(LowerTriangular(Σ))
        X₀ = X₀b[:, j]
        Y = Yb[:, :, j]
        # negative log likelihood accumulated over the prediction horizon
        loss += -sum([mvlognormpdf(Y[t, :], X₀, t, μ, ΣΣ) for t in 1:T])
    end
    loss
end
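
For context, mvlognormpdf evaluates the log density of the multivariate log-normal at horizon t. What follows is a simplified sketch rather than my exact code (the GBM-style drift correction and the placement of the opnorm(Σ, 2) term are illustrative), but the spectral norm is the call the opnorm frames in the trace point at:

using LinearAlgebra

# Simplified sketch of mvlognormpdf, not the exact code: the GBM-style
# drift correction and the spectral-norm term are illustrative assumptions.
function mvlognormpdf(Y, X₀, t, μ, Σ)
    Σt = Σ .* t                                # covariance of log(Y) after t steps
    m = log.(X₀) .+ (μ .- diag(Σ) ./ 2) .* t  # GBM drift of the log process
    z = log.(Y) .- m
    # opnorm(·, 2) is the spectral norm; differentiating it runs an svd,
    # which is where the LAPACK call in the stack trace originates
    reg = opnorm(Σ, 2)
    -(length(Y) * log(2π) + logdet(Σt) + dot(z, Σt \ z)) / 2 - sum(log.(Y)) - reg
end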

Stack trace below:

ERROR: ArgumentError: invalid argument #4 to LAPACK call
Stacktrace:
  [1] chklapackerror
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:38 [inlined]
  [2] gesdd!(job::Char, A::Matrix{Float32})
    @ LinearAlgebra.LAPACK ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/lapack.jl:1714
  [3] _svd!
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:125 [inlined]
  [4] svd!(A::Matrix{Float32}; full::Bool, alg::LinearAlgebra.DivideAndConquer)
    @ LinearAlgebra ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:105
  [5] svd!
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:100 [inlined]
  [6] svd(A::Matrix{Float32}; full::Bool, alg::LinearAlgebra.DivideAndConquer)
    @ LinearAlgebra ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:179
  [7] svd
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/svd.jl:178 [inlined]
  [8] rrule
    @ ~/.julia/packages/ChainRules/DbWAz/src/rulesets/LinearAlgebra/factorization.jl:284 [inlined]
  [9] rrule
    @ ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138 [inlined]
 [10] chain_rrule
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
 [12] _pullback
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87 [inlined]
 [13] opnorm2
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:685 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::typeof(LinearAlgebra.opnorm2), args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [15] opnorm
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:745 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(opnorm), ::Matrix{Float32}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [17] opnorm
    @ ~/src/progs/julia-1.11.2/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:744 [inlined]
 [18] mvlognormpdf
    @ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:136 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::typeof(mvlognormpdf), ::Vector{…}, ::Vector{…}, ::Int64, ::Vector{…}, ::Matrix{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [20] #80
    @ ./none:0 [inlined]
 [21] _pullback(ctx::Zygote.Context{false}, f::var"#80#81"{Matrix{…}, Vector{…}, Matrix{…}, Vector{…}}, args::Int64)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [22] #666
    @ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188 [inlined]
 [23] iterate
    @ ./generator.jl:48 [inlined]
 [24] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [25] collect_similar
    @ ./array.jl:720 [inlined]
 [26] map
    @ ./abstractarray.jl:3371 [inlined]
 [27] ∇map
    @ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188 [inlined]
 [28] _pullback
    @ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:231 [inlined]
 [29] gbmdistloss2
    @ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:336 [inlined]
 [30] _pullback(::Zygote.Context{…}, ::typeof(gbmdistloss2), ::@NamedTuple{…}, ::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [31] #82
    @ ~/src/aal/StochasticGenie/experiments/gbmneuralnettest.jl:349 [inlined]
 [32] _pullback(ctx::Zygote.Context{false}, f::var"#82#83"{@NamedTuple{…}}, args::GBMNet{Parallel{…}, Chain{…}})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [33] pullback(f::Function, cx::Zygote.Context{false}, args::GBMNet{Parallel{typeof(vcat), Tuple{…}}, Chain{Tuple{…}}})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
 [34] pullback
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88 [inlined]
 [35] withgradient(f::Function, args::GBMNet{Parallel{typeof(vcat), Tuple{…}}, Chain{Tuple{…}}})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
 [36] #withgradient#5
    @ ~/.julia/packages/Flux/5vIRy/src/gradient.jl:182 [inlined]
 [37] withgradient
    @ ~/.julia/packages/Flux/5vIRy/src/gradient.jl:169 [inlined]

Hi! Can you provide a complete runnable example with all necessary imports and object definitions?

Working on it :blush:

Just to follow up on this: I’m fairly convinced that the long error message was caused by NaNs in my gradients. In the MWE I put together I can see this clearly. The funny thing is that it seems to happen primarily when using AdamW; with Adam it often works out fine. I’m still investigating. If people are interested I can post my small example.
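
In the meantime, the mechanism is easy to reproduce: once NaNs reach the matrix that opnorm sees, the svd used to differentiate it fails with exactly this error. A minimal check (made-up input):

using LinearAlgebra, Zygote

# Differentiating opnorm(A, 2) runs svd on A; LAPACK's gesdd! rejects NaN input
Zygote.gradient(A -> opnorm(A, 2), fill(NaN32, 3, 3))
# throws ArgumentError: invalid argument #4 to LAPACK call, as in the trace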

As a sanity check, can you confirm the epsilon and lambda fields on the instantiated AdamW rule have the values you expect?
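
For example, something like the following, assuming a recent Optimisers.jl where AdamW is a standalone rule carrying those fields:

using Optimisers

rule = Optimisers.AdamW()    # construct it the same way your training loop does
rule.epsilon, rule.lambda    # confirm these match what you intended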