Pullback method error for Diagonal() call inside array generator

Here’s a minimal example of an error I’ve encountered with my code using Flux.

using Flux, LinearAlgebra

const x = zeros(Float32, 10, 4)
p = Flux.params(x)
function f()
    x1 = softmax(x, dims=2)
    x_d = [Diagonal(x1[:, k]) for k = 1:4]
    x_d = cat(x_d..., dims=3)
    # (10,10,4) array with dims=(1,2) diagonalized
    return sum(x_d.^2)
end
g = gradient(p) do
    loss = f()
end
g = gradient(p) do
    loss = f()
end

While the forward pass through the function f() runs without any problems, the gradient computation throws this kind of error:

ERROR: MethodError: no method matching _Diagonal_pullback(::Array{Float32, 3})
Closest candidates are:
  _Diagonal_pullback(::Diagonal) at ~/.julia/packages/ChainRules/3yDBX/src/rulesets/LinearAlgebra/structured.jl:54
  _Diagonal_pullback(::AbstractMatrix) at ~/.julia/packages/ChainRules/3yDBX/src/rulesets/LinearAlgebra/structured.jl:53
  _Diagonal_pullback(::ChainRulesCore.Tangent) at ~/.julia/packages/ChainRules/3yDBX/src/rulesets/LinearAlgebra/structured.jl:55
  ...
Stacktrace:
...

My intention is to diagonalize a multidimensional array (that depends on some trainable parameters) along a specific axis before operating on it further. Any suggestions for debugging/rewriting this?

After some tinkering, I found that replacing Diagonal with diagm, and replacing the splatted cat() call with a reduce, returns the correct result with no errors. I also took the liberty of replacing params(x) with params([x]) to make it easier to verify the value of the gradient obtained.

const x = zeros(Float32, 10, 4)
p = Flux.params([x])
function f()
    x1 = softmax(x, dims=2)
    x_d = LinearAlgebra.diagm.([x1[:, k] for k = 1:4])
    x_d = reduce((x1, x2) -> cat(x1, x2, dims=3), x_d)
    # (10,10,4) array with dims=(1,2) diagonalized
    return sum(x_d.^2)
end
g = gradient(p) do
    loss = f()
end

While I could see how the structured diagonal matrix returned by Diagonal might cause a problem, having to replace the call to cat with reduce doesn’t really make sense. A better scenario would be to have a multi-dimensional version of diagm, similar to torch.diag_embed in PyTorch.
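For reference, such a helper can be pieced together from the pattern that worked above; the name diag_embed and its layout here are my own choice, a minimal sketch rather than an existing API:

using LinearAlgebra

# Hypothetical diag_embed: for an (n, m) input, place each column on the
# diagonal of an (n, n) slice, giving an (n, n, m) array.
function diag_embed(x::AbstractMatrix)
    n, m = size(x)
    slices = [diagm(x[:, k]) for k in 1:m]
    # hcat-then-reshape instead of cat(...; dims=3), which is the call
    # that triggers the pullback error above
    return reshape(reduce(hcat, slices), n, n, m)
end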

I think the error is due to this issue with Zygote’s rule for cat.
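If I’m reading it right, the mismatch can be reproduced in isolation: cat’s pullback hands each argument a dense 3-dimensional slice of the output gradient, which Diagonal’s pullback has no method for. Something like this minimal sketch should hit the same MethodError:

using Zygote, LinearAlgebra

# cat's adjoint indexes the incoming gradient with all three indices,
# so each Diagonal receives an (n, n, 1) dense array as its tangent,
# which _Diagonal_pullback cannot handle:
gradient(rand(3)) do v
    sum(abs2, cat(Diagonal(v), Diagonal(v); dims = 3))
end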

One way you could work around it for now is to insert reshape, whose gradient rule will fix the shape:

julia> gradient(reshape(1:12,3,4)) do x
        y = cat([reshape(Diagonal(x[:, k]),3,3) for k = 1:4]...; dims=3)
        sum(abs2, y)
       end[1]
3×4 Matrix{Float64}:
 2.0   8.0  14.0  20.0
 4.0  10.0  16.0  22.0
 6.0  12.0  18.0  24.0

julia> using TransmuteDims  # a package which does generalised diagm

julia> gradient(reshape(1:12,3,4)) do x
        y = transmute(x, (1,1,2))
        sum(abs2, y)
       end[1]
3×4 Matrix{Float64}:
 2.0   8.0  14.0  20.0
 4.0  10.0  16.0  22.0
 6.0  12.0  18.0  24.0

julia> transmute(rand(3,2), (1,1,2))
3×3×2 transmute(::Matrix{Float64}, (1, 1, 2)) with eltype Float64:
[:, :, 1] =
 0.0825788   ⋅         ⋅ 
  ⋅         0.948785   ⋅ 
  ⋅          ⋅        0.1827

[:, :, 2] =
 0.36514   ⋅         ⋅ 
  ⋅       0.791012   ⋅ 
  ⋅        ⋅        0.726368

@mcabbott thanks for the reply, and indeed great work on the TransmuteDims package. It makes the code much more presentable and outperforms the other methods for diagonalising a multidimensional array:

julia> x = reshape(1:12,3,4);

julia> @btime transmute(x, (1, 1, 2, 3));
  309.012 ns (4 allocations: 176 bytes)

julia> @btime begin
               y = [LinearAlgebra.diagm(x[:, k]) for k = 1:4]
               reshape(reduce(hcat, y), 3, 3, 4) #performs better than reduce(cat...)
              end;
  1.599 μs (19 allocations: 1.48 KiB)

julia> @btime begin
               y = [reshape(Diagonal(x[:, k]), 3, 3) for k = 1:4]
               reshape(reduce(hcat, y), 3, 3, 4) #performs better than reduce(cat...)
              end;
  3.719 μs (69 allocations: 5.50 KiB)

However, gradient computation is on the slower side:

julia> @btime gradient(reshape(1:12,3,4)) do x
                      y = transmute(x, (1,1,2))
                      sum(abs2, y)
                     end;
  287.887 μs (840 allocations: 39.48 KiB)

julia> @btime gradient(reshape(1:12,3,4)) do x
                      y = [LinearAlgebra.diagm(x[:, k]) for k=1:4]
                      y = reshape(reduce(hcat, y), 3, 3, 4)
                      sum(abs2, y)
                   end;
  6.104 μs (85 allocations: 8.06 KiB)

I’d definitely love to use it in my code if the speeds became comparable.

It’s a type-stability problem, I think. With transmute(x, Val((1,1,2))) it’s much quicker, <1μs.
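Spelled out, the Val just moves the permutation into the type domain, so the shape (and ndims) of the transmuted array is inferrable; something like:

using TransmuteDims, Zygote

# Passing the permutation as a Val makes the transmuted shape known
# at compile time, so the gradient call is type-stable:
g = gradient(reshape(1:12, 3, 4)) do x
    y = transmute(x, Val((1, 1, 2)))
    sum(abs2, y)
end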

That said, if you can entirely avoid big arrays which are mostly zero, that’s probably even better.
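For the MWE at the top, for instance, the off-diagonal entries contribute nothing to the loss, so the whole (10,10,4) array can be skipped; a minimal sketch of the same gradient without it:

using Flux

x = zeros(Float32, 10, 4)
g = gradient(x) do x
    x1 = softmax(x, dims=2)
    # equal to sum(x_d.^2) from the original f(), since every
    # off-diagonal entry of x_d is zero
    sum(abs2, x1)
end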
