Zygote with Tullio gives wrong gradients/pullbacks using CUDA

using Tullio, Zygote, CUDA, KernelAbstractions, OMEinsum

# Show outer product of strings
A = ["x", "y", "z", "w"]
res = Array{String}(undef, length(A), length(A))
for (i, r) in enumerate(A)
    for (j, c) in enumerate(A)
        res[i, j] = string(r, c)
    end
end
display(res)

# Test outer product using einstein summation
A = rand(length(A), 100) # Last dim is batch
batchmul(A, B) = @tullio C[i,j,k] := A[i,k] * B[j,k]
# batchmul(A, B) = ein"ik,jk->ijk"(A, B)
outer_prod(A, B) = reshape(batchmul(A, B), size(A, 1)*size(B, 1), size(A, 2))
@show reshape(outer_prod(A, A), 4, 4, :) == batchmul(A, A)
loss, back = pullback(p -> sum(outer_prod(p, p)), A)
gs = back(one(loss))[1]
display(gs)

# CUDA
A_cu = CuArray(Float32.(A))
loss, back = pullback(p -> sum(outer_prod(p, p)), A_cu)
gs = back(one(loss))[1]
display(gs)

I'm trying to compute polynomial terms (represented by a vector; is there a package for this?) in a way that can be automatically differentiated. In this code, I compute the terms by applying an outer product to two vectors, with an extra batch dimension for composability with neural networks (Lux). Using OMEinsum with CUDA gives consistent and correct results.
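(As an aside on the package question: the reshaped column-wise outer product here is, I believe, the Khatri-Rao product. A minimal Base-only sketch, where `khatrirao` is just an illustrative name I made up, not a function from a package:)

```julia
using LinearAlgebra  # for kron

# Column-wise Kronecker (Khatri-Rao) product, one kron per batch column.
# Note the argument order: Julia's column-major reshape makes i the
# fastest-varying index, which corresponds to kron(B[:,k], A[:,k]).
khatrirao(A, B) = reduce(hcat, [kron(view(B, :, k), view(A, :, k)) for k in axes(A, 2)])

A = rand(3, 5); B = rand(4, 5)
C3 = reshape(A, size(A, 1), 1, :) .* reshape(B, 1, size(B)...)  # C3[i,j,k] = A[i,k] * B[j,k]
khatrirao(A, B) ≈ reshape(C3, size(A, 1) * size(B, 1), :)       # true
```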

Problem: The pullback gives different results when I use CUDA with Tullio.
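(For what it's worth, this loss has a closed-form gradient, which gives a check independent of Tullio, OMEinsum, and CUDA: with f(p) = sum over i,j,k of p[i,k]*p[j,k] = sum over k of (column sum)^2, every row of the gradient equals twice the column sums of p. A CPU-only sketch:)

```julia
using Zygote

# Same loss as above, written with a plain broadcast so no Tullio/CUDA is involved.
f(p) = sum(reshape(p, size(p, 1), 1, :) .* reshape(p, 1, size(p)...))

p = rand(Float32, 4, 100)
# Closed-form gradient: every row equals 2 * column sums of p.
analytic = repeat(2 .* sum(p; dims = 1), size(p, 1), 1)
gradient(f, p)[1] ≈ analytic  # true; the CPU and OMEinsum results agree with this
```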


Seems like a bug; it's probably worth opening an issue on GitHub.

But you can rewrite it like this, I believe:

julia> mybatchmul(A, B) = PermutedDimsArray(A .* 
                             PermutedDimsArray(reshape(B, (1, size(B, 1), size(B, 2))), 
                                              (1, 3, 2)),
                                            (1, 3, 2))
mybatchmul (generic function with 1 method)

julia> mybatchmul(A, B)
2×2×3 PermutedDimsArray(::Array{Float64, 3}, (1, 3, 2)) with eltype Float64:
[:, :, 1] =
 -0.0758972  1.16069
 -0.0902016  1.37944

[:, :, 2] =
  0.025559   0.248286
 -0.147447  -1.43234

[:, :, 3] =
  0.073602  -0.331459
 -0.062974   0.283597

julia> batchmul(A, B)
2×2×3 Array{Float64, 3}:
[:, :, 1] =
 -0.0758972  1.16069
 -0.0902016  1.37944

[:, :, 2] =
  0.025559   0.248286
 -0.147447  -1.43234

[:, :, 3] =
  0.073602  -0.331459
 -0.062974   0.283597

julia> A_cu
4×100 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.545619  0.822741  0.0427121  0.28609   0.241786  0.548029   0.299155  0.344738   …  0.973161  0.152359  0.543345  0.785741  0.927969   0.231272  0.430827
 0.686676  0.536921  0.613779   0.74993   0.86244   0.0551125  0.75397   0.0116247     0.598434  0.448246  0.729716  0.183434  0.0310681  0.200019  0.910467
 0.347076  0.928443  0.204908   0.189597  0.551348  0.393395   0.561077  0.954183      0.307319  0.766647  0.410037  0.656858  0.456261   0.439622  0.608114
 0.180162  0.131455  0.152134   0.966236  0.112832  0.587855   0.870898  0.923121      0.258794  0.483758  0.669112  0.934022  0.293671   0.516583  0.901604

julia> gradient(p -> sum(batchmul(p, p)), A_cu) # wrong
(Float32[1.7779145 2.182403 … 0.6625629 1.7721217; 1.7779145 2.182403 … 0.6625629 1.7721217; 1.7779145 2.182403 … 0.6625629 1.7721217; 1.7779145 2.182403 … 0.6625629 1.7721217],)

julia> gradient(p -> sum(mybatchmul(p, p)), A_cu)
(Float32[3.5190659 4.83912 … 2.774992 5.7020245; 3.5190659 4.83912 … 2.774992 5.7020245; 3.5190659 4.83912 … 2.774992 5.7020245; 3.5190659 4.83912 … 2.774992 5.7020245],)

julia> gradient(p -> sum(mybatchmul(p, p)), Array(A_cu))
(Float32[3.519066 4.83912 … 2.774992 5.7020245; 3.519066 4.83912 … 2.774992 5.7020245; 3.519066 4.83912 … 2.774992 5.7020245; 3.519066 4.83912 … 2.774992 5.7020245],)

EDIT:
@mcabbott had a shorter and faster solution:


julia> bcmul(A, B) = reshape(A, size(A,1), 1, :) .* reshape(B, 1, size(B)...);

julia> A = rand(Float32, 4, 2); B = rand(Float32, 4, 2);

julia> bcmul(A, B) ≈ mybatchmul(A, B)
true
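To plug this back into the original code, only the batched multiply needs swapping. A CPU-only sketch (so it runs without a GPU) that checks the pullback against the closed-form gradient of this particular loss:

```julia
using Zygote

# Broadcast-based replacement for the Tullio batchmul, then the same reshape.
bcmul(A, B) = reshape(A, size(A, 1), 1, :) .* reshape(B, 1, size(B)...)
outer_prod(A, B) = reshape(bcmul(A, B), size(A, 1) * size(B, 1), size(A, 2))

A = rand(Float32, 4, 100)
loss, back = Zygote.pullback(p -> sum(outer_prod(p, p)), A)
gs = back(one(loss))[1]
# For this loss, every row of the gradient is twice the column sums of A.
gs ≈ repeat(2 .* sum(A; dims = 1), size(A, 1), 1)  # true
```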