Flux.jl Inconsistent Training on Custom Architecture

I was referring to the MWE, in which only model.K is used. Incidentally, the gradient there is nothing as well, because of an interaction between implicit parameters (i.e. the thing you get from params and pass to gradient) and the AD (Zygote). Cf.:

julia> using Zygote, LinearAlgebra

julia> X = [rand(2, 2)]
1-element Vector{Matrix{Float64}}:
 [0.6202008269430048 0.8555679159356662; 0.8423362289177463 0.09771425479926421]

# 1. this doesn't work
julia> gradient(() -> norm(X[1]), Params(X)).grads
IdDict{Any, Any} with 2 entries:
  [0.620201 0.855568; 0.842336 0.0977143] => nothing  # gradient wrt. X[1]
  :(Main.X)                               => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]

# 2. but this does
julia> gradient(() -> norm(X[1]), Params([X])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.X)                                 => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]
  [[0.620201 0.855568; 0.842336 0.0977143]] => Union{Nothing, Matrix{Float64}}[[0.45775 0.631467; 0.621701 0.0721198]]

# 3. as does this
julia> x₁ = X[1]
2×2 Matrix{Float64}:
 0.620201  0.855568
 0.842336  0.0977143

julia> gradient(() -> norm(x₁), Params(X)).grads
IdDict{Any, Any} with 2 entries:
  [0.620201 0.855568; 0.842336 0.0977143] => [0.45775 0.631467; 0.621701 0.0721198]
  :(Main.x₁)                              => [0.45775 0.631467; 0.621701 0.0721198]

# 4. and this (note explicit instead of implicit parameters;
# that is, we pass X directly and use it, instead of params(X)).
# This works with full Flux models too!
julia> gradient(x -> norm(x[1]), X)[1]
1-element Vector{Union{Nothing, Matrix{Float64}}}:
 [0.4577503206426886 0.6314672132598408; 0.6217013298363201 0.07211975463850043]
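
To make 4) concrete with an actual Flux model, here is a minimal sketch of the explicit-parameter style (the layer, loss, and data below are placeholders, not the MWE's):

julia> using Flux

julia> m = Dense(2 => 1);  # stands in for the custom model

julia> x, y = rand(Float32, 2, 8), rand(Float32, 1, 8);

julia> grads = gradient(model -> Flux.Losses.mse(model(x), y), m)[1];  # differentiate wrt the model itself

julia> grads.weight;  # the gradient mirrors the model's structure, matched by field name

On recent Flux versions, Flux.setup and Flux.update! consume exactly this kind of structural gradient, so the whole training loop can stay in the explicit style.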

The gist is that params (and Params, when given a single container) splat their argument into the underlying parameter set, so each element is stored as a separate parameter rather than the container itself:

julia> params(X).order[1]
2×2 Matrix{Float64}:  # this is X[1], where you'd expect it to be X itself
 0.620201  0.855568
 0.842336  0.0977143
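
A quick identity check confirms this (continuing the same session; the results follow from object identity, not the random values):

julia> params(X).order[1] === X[1]  # X was splatted: the stored parameter is X[1] itself
true

julia> Params([X]).order[1] === X   # wrapping X in another array keeps it whole
true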

For whatever reason, Zygote isn't smart enough to link the X[1] in the loss to the actual value of X[1] in the params. You can see I avoid this in 2) and 3): first by wrapping X so it isn't unravelled, then by hoisting X[1] into a named variable that the loss references directly.
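
Putting the hoisting fix together in one call (same session, hence the same values as in 3):

julia> K = X[1];  # hoist first, so Params and the loss both see the very same object

julia> gradient(() -> norm(K), Params([K]))[K]
2×2 Matrix{Float64}:
 0.45775   0.631467
 0.621701  0.0721198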