Unable to calculate the gradient of opnorm/svdvals/svd on the GPU?

Hello,
I’m able to calculate the gradient of the spectral norm on the CPU but not on the GPU :disappointed:

This works fine:

julia> using CUDA, Zygote, LinearAlgebra

julia> X = randn(Float32, 5, 1);

julia> function snorm(W::AbstractArray)
           S = svdvals(W)
           return maximum(S)
       end
snorm (generic function with 1 method)

julia> Zygote.gradient(x -> snorm(x), X)
(Float32[0.06287575; -0.03134879; … ; -0.20840532; -0.61453784;;],)

If I move X to the GPU, it errors out:

julia> using Flux

julia> X = X |> gpu
5×1 CuArray{Float32, 2, CUDA.DeviceMemory}:
  0.12716164
 -0.06340073
  1.5322151
 -0.42148516
 -1.2428597

julia> Zygote.gradient(x -> snorm(x), X)
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
  [5] getindex
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:50 [inlined]
  [6] scalar_getindex
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:36 [inlined]
  [7] _getindex
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:19 [inlined]
  [8] getindex
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/indexing.jl:17 [inlined]
  [9] macro expansion
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:351 [inlined]
 [10] macro expansion
    @ ./simdloop.jl:77 [inlined]
 [11] __muldiag!(out::CuArray{Float32, 2, CUDA.DeviceMemory}, D::Diagonal{Float32, ChainRules.OneElement{…}}, B::CuArray{Float32, 2, CUDA.DeviceMemory}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:350
 [12] _mul_diag!
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:424 [inlined]
 [13] _mul!
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:430 [inlined]
 [14] _mul!
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/bidiag.jl:434 [inlined]
 [15] mul!
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [16] mul!
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [17] *
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [18] _tri_matmul(A::CuArray{Float32, 2, CUDA.DeviceMemory}, B::Diagonal{Float32, ChainRules.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}, C::CuArray{Float32, 2, CUDA.DeviceMemory}, δ::Nothing)
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:1142
 [19] _tri_matmul
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:1136 [inlined]
 [20] *
    @ ~/.julia/juliaup/julia-1.11.4+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:1132 [inlined]
 [21] svdvals_pullback
    @ ~/.julia/packages/ChainRules/H7bwg/src/rulesets/LinearAlgebra/factorization.jl:290 [inlined]
 [22] ZBack
    @ ~/.julia/packages/Zygote/kdCjv/src/compiler/chainrules.jl:222 [inlined]
 [23] snorm
    @ ./REPL[4]:2 [inlined]
 [24] #3
    @ ./REPL[9]:1 [inlined]
 [25] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{var"#3#4", CuArray{Float32, 2, CUDA.DeviceMemory}}, Tuple{Zygote.Pullback{Tuple{typeof(snorm), CuArray{…}}, Tuple{Zygote.ZBack{…}, Zygote.ZBack{…}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/kdCjv/src/compiler/interface.jl:97
 [26] gradient(f::Function, args::CuArray{Float32, 2, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/kdCjv/src/compiler/interface.jl:154
 [27] top-level scope
    @ REPL[9]:1

If I take the sum of the singular values instead of the maximum, it works on the GPU:

julia> function snorm(W::AbstractArray)
           S = svdvals(W)
           return sum(S)
       end
snorm (generic function with 1 method)

julia> X = X |> gpu
5×1 CuArray{Float32, 2, CUDA.DeviceMemory}:
  0.12716164
 -0.06340073
  1.5322151
 -0.42148516
 -1.2428597

julia> Zygote.gradient(x -> snorm(x), X)
(Float32[0.06287575; -0.03134879; … ; -0.20840532; -0.61453784;;],)

However, the sum is not what I want here :slightly_smiling_face:. Do you know how I can make it work? Thanks

This question is related to: Implementation of Spectral Normalization for Machine Learning

The issue you have here is that the derivative of the max function is not implemented in a way that works on the GPU. Not a big problem, though: Julia’s SVD routines return the singular values in descending order, so the spectral norm is just svdvals(X)[1]. However, scalar indexing is disallowed on GPU arrays, so you still have to jump through a bit of a hoop (see the quick check below).
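A quick check of both claims (a minimal sketch; W and s are throwaway names of mine, and it assumes a CUDA-capable session):

using CUDA, LinearAlgebra

W = CUDA.randn(Float32, 4, 2)
s = svdvals(W)      # singular values are returned sorted in descending order
# s[1]              # scalar getindex: triggers the scalar-indexing error shown above
v = view(s, 1)      # zero-dimensional view; no scalar read happens here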
The way I see it, you have two simple ways to solve this and implement spectral normalization:

1. Use a view to get the spectral norm:
function snorm(X::AbstractMatrix{<:Real})
    s = svdvals(X)
    return view(s, 1)
end

The one issue is that your output will not be a scalar but a zero-dimensional array. However, it behaves identically to a scalar under broadcasting. Running Zygote.jacobian then gives me the gradient with no problem, so you can normalize your outputs by just doing

function spectral_normalization(X::AbstractMatrix{<:Real})
    return X ./ snorm(X)
end

and I have tested that Zygote.gradient works! Zygote is the standard backend used by Flux/Lux, so you should be able to train with no problems.
2. Since snorm is a scalar-valued function, you can write your own snorm and dsnorm functions and connect them with ChainRules.jl’s @scalar_rule macro and ForwardDiff.jl’s Dual numbers; a sketch of this idea follows below. This is the least error-prone way to go, I’d imagine, but it is a bit more involved. Tell me if you have issues with the first method.
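
For reference, a minimal sketch of option 2 (an illustration only, not tested in this thread: the names snorm2 and snorm2_pullback are mine, and since snorm maps a matrix to a scalar I use a full ChainRulesCore.rrule rather than @scalar_rule; the rule relies on the standard identity that the gradient of the largest singular value σ₁ with respect to W is u₁v₁ᵀ, valid when σ₁ is simple):

using LinearAlgebra, ChainRulesCore

# Largest singular value via a GPU-friendly reduction (no scalar indexing).
snorm2(W::AbstractMatrix) = maximum(svdvals(W))

function ChainRulesCore.rrule(::typeof(snorm2), W::AbstractMatrix)
    F = svd(W)
    σ1 = maximum(F.S)      # largest singular value
    u1 = F.U[:, 1:1]       # leading left singular vector, kept as an n×1 slice
    v1 = F.V[:, 1:1]       # leading right singular vector, kept as an m×1 slice
    snorm2_pullback(Δ) = (NoTangent(), Δ .* (u1 * v1'))
    return σ1, snorm2_pullback
end

With this rule in place, Zygote.gradient(snorm2, W) bypasses the generic svdvals pullback entirely.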

Thanks for the reply! I got it to work with your solution by adding a sum over the view:

# On the CPU
julia> using Flux, CUDA, Zygote, LinearAlgebra

julia> X = randn(Float32, 4, 2);

julia> Y = randn(Float32, 1, 2);

julia> model = Dense(4=>1)
Dense(4 => 1)       # 5 parameters

julia> function snorm(W::AbstractArray)
           S = svdvals(W)
           return view(S,1)
       end
snorm (generic function with 1 method)

julia> function loss(model, X, Y)
           reg = sum(snorm(model.weight))
           Flux.mse(model(X),Y) + reg
       end
loss (generic function with 1 method)

julia> g = Zygote.gradient(m -> loss(m, X, Y), model)
((weight = Float32[-6.5172167 0.67096055 1.8280481 -0.010930598], bias = Float32[3.718598], σ = nothing),)
# On the GPU
julia> X = X |> gpu
4×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
 0.295092  -1.28734
 0.195742   0.136765
 0.875829   1.96808
 0.577029   1.07548

julia> Y = Y |> gpu
1×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
 -0.598539  -2.2533

julia> model = model |> gpu
Dense(4 => 1)       # 5 parameters

julia> g = Zygote.gradient(m -> loss(m, X, Y), model)
((weight = Float32[-3.6229815 -0.3276634 4.686322 1.8717637], bias = Float32[2.225389], σ = nothing),)

The weird thing is that the gradient calculated on the CPU is not the same as the one on the GPU. I guess this is not normal.

Definitely not normal. I will run your MWE here and try to debug what is going on. On my end I looked at the Jacobian, and whether I use sum or nothing at all, the gradients are still pretty close…
EDIT: On my side I got identical results doing the same as you, so I am not sure where the discrepancy is coming from.
Could you try running this in your REPL and see if anything changes?

using Flux, CUDA, Zygote, LinearAlgebra

X = randn(Float32, 4, 2);
Y = randn(Float32, 1, 2);

model = Flux.Dense(4=>1)

function snorm(W::AbstractArray)
    S = svdvals(W)
    return view(S,1)
end

function loss(model, X, Y)
    reg = sum(snorm(model.weight))
    Flux.mse(model(X),Y) + reg
end

g = first(Zygote.gradient(m -> loss(m, X, Y), model))

X_gpu = X|>gpu
Y_gpu = Y|>gpu
model_gpu = model|>gpu

g_gpu = first(Zygote.gradient(m -> loss(m, X_gpu, Y_gpu), model_gpu))

g_cpu = g_gpu|>cpu
maximum(abs, g.weight - g_cpu.weight)  # 0f0 on my machine
maximum(abs, g.bias - g_cpu.bias)      # 0f0 on my machine

It does not work anymore, and I don’t know why:

julia> X = randn(Float32, 4, 2);

julia> Y = randn(Float32, 1, 2);

julia> model = Flux.Dense(4=>1)
Dense(4 => 1)       # 5 parameters

julia> function snorm(W::AbstractArray)
           S = svdvals(W)
           return view(S,1)
       end
snorm (generic function with 1 method)

julia> function loss(model, X, Y)
           reg = sum(snorm(model.weight))
           Flux.mse(model(X),Y) + reg
       end
loss (generic function with 1 method)

julia> g = first(Zygote.gradient(m -> loss(m, X, Y), model))
(weight = Float32[-1.302959 1.2842926 -2.1166115 -1.1868622], bias = Float32[2.390978], σ = nothing)

julia> X_gpu = X|>gpu
4×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
 -0.969246  -0.266507
  0.512716   0.23067
  0.196472  -0.79053
  1.38998   -1.03224

julia> Y_gpu = Y|>gpu
1×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
 -0.241278  -1.15451

julia> model_gpu = model|>gpu
Dense(4 => 1)       # 5 parameters

julia> g_gpu = first(Zygote.gradient(m -> loss(m, X_gpu, Y_gpu), model_gpu))
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}}}, which is not a bitstype:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}} which is not isbits.
      .x is of type Array{Float32, 0} which is not isbits.
        .ref is of type MemoryRef{Float32} which is not isbits.
          .mem is of type Memory{Float32} which is not isbits.


Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.

Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/validation.jl:108
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:87 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Tracy/GcShf/src/tracepoint.jl:158 [inlined]
  [4] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:85
  [5] compile_unhooked
    @ ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:80 [inlined]
  [6] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:67
  [7] compile
    @ ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:55 [inlined]
  [8] #1181
    @ ~/.julia/packages/CUDA/ja0IX/src/compiler/compilation.jl:250 [inlined]
  [9] JuliaContext(f::CUDA.var"#1181#1184"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:34
 [10] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/driver.jl:25
 [11] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/ja0IX/src/compiler/compilation.jl:249
 [12] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/execution.jl:245
 [13] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Emuht/src/execution.jl:159
 [14] macro expansion
    @ ~/.julia/packages/CUDA/ja0IX/src/compiler/execution.jl:373 [inlined]
 [15] macro expansion
    @ ./lock.jl:273 [inlined]
 [16] cufunction(f::GPUArrays.var"#gpu_broadcast_kernel_linear#38", tt::Type{Tuple{…}}; kwargs::@Kwargs{always_inline::Bool, maxthreads::Nothing})
    @ CUDA ~/.julia/packages/CUDA/ja0IX/src/compiler/execution.jl:368
 [17] macro expansion
    @ ~/.julia/packages/CUDA/ja0IX/src/compiler/execution.jl:112 [inlined]
 [18] (::KernelAbstractions.Kernel{…})(::CuArray{…}, ::Vararg{…}; ndrange::Tuple{…}, workgroupsize::Nothing)
    @ CUDA.CUDAKernels ~/.julia/packages/CUDA/ja0IX/src/CUDAKernels.jl:122
 [19] _copyto!
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/broadcast.jl:71 [inlined]
 [20] materialize!
    @ ~/.julia/packages/GPUArrays/uiVyU/src/host/broadcast.jl:38 [inlined]
 [21] materialize!
    @ ./broadcast.jl:880 [inlined]
 [22] ∇getindex!(dx::CuArray{Float32, 1, CUDA.DeviceMemory}, dy::Array{Float32, 0}, inds::UnitRange{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/H7bwg/src/rulesets/Base/indexing.jl:180
 [23] ∇getindex(x::CuArray{Float32, 1, CUDA.DeviceMemory}, dy::Array{Float32, 0}, inds::UnitRange{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/H7bwg/src/rulesets/Base/indexing.jl:89
 [24] #672
    @ ~/.julia/packages/ChainRules/H7bwg/src/rulesets/Base/indexing.jl:69 [inlined]
 [25] unthunk
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:213 [inlined]
 [26] unthunk
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:252 [inlined]
 [27] (::ChainRules.var"#svdvals_pullback#1187"{ChainRulesCore.ProjectTo{…}, CuArray{…}, CuArray{…}})(s::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#671#673"{…}})
    @ ChainRules ~/.julia/packages/ChainRules/H7bwg/src/rulesets/LinearAlgebra/factorization.jl:289
 [28] ZBack
    @ ~/.julia/packages/Zygote/kdCjv/src/compiler/chainrules.jl:222 [inlined]
 [29] snorm
    @ ./REPL[7]:2 [inlined]
 [30] loss
    @ ./REPL[8]:2 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/kdCjv/src/compiler/interface2.jl:0
 [32] #3
    @ ./REPL[13]:1 [inlined]
 [33] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{var"#3#4", Dense{…}}, Tuple{Zygote.var"#2006#back#208"{…}, Zygote.Pullback{…}, Zygote.var"#2006#back#208"{…}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/kdCjv/src/compiler/interface.jl:97
 [34] gradient(f::Function, args::Dense{typeof(identity), CuArray{Float32, 2, CUDA.DeviceMemory}, CuArray{Float32, 1, CUDA.DeviceMemory}})
    @ Zygote ~/.julia/packages/Zygote/kdCjv/src/compiler/interface.jl:154
 [35] top-level scope
    @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.

This is my machine:

julia> versioninfo()
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) W-11955M CPU @ 2.60GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 16 virtual cores)

julia> CUDA.versioninfo()
CUDA runtime 12.9, artifact installation
CUDA driver 12.9
NVIDIA driver 550.144.3

CUDA libraries:
- CUBLAS: 12.9.0
- CURAND: 10.3.10
- CUFFT: 11.4.0
- CUSOLVER: 11.7.4
- CUSPARSE: 12.5.9
- CUPTI: 2025.2.0 (API 27.0.0)
- NVML: 12.0.0+550.144.3

Julia packages:
- CUDA: 5.8.2
- CUDA_Driver_jll: 0.13.1+0
- CUDA_Runtime_jll: 0.17.0+0

Toolchain:
- Julia: 1.11.4
- LLVM: 16.0.6

1 device:
  0: NVIDIA RTX A5000 Laptop GPU (sm_86, 15.382 GiB / 16.000 GiB available)

The environment is minimal:

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
  [587475ba] Flux v0.16.4
  [e88e6eb3] Zygote v0.7.9
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0

That is very odd; I just ran your code on my machine line by line and no problems occurred. My versions of the relevant libraries are almost exactly the same as yours. I am running Julia 1.11.5, but that should not break anything, because I’ve been using this svdvals + view idea for some time…

I’m narrowing down the problem:

I updated Julia to 1.11.5 just in case.

It works with Flux v0.14.25 & Zygote v0.6.77:

(tmp2) pkg> st
Status `~/.julia/dev/tmp2/Project.toml`
  [052768ef] CUDA v5.8.2
⌅ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.77
  [02a925ec] cuDNN v1.4.3
  [37e2e46d] LinearAlgebra v1.11.0

[...]
julia> g_gpu = first(Zygote.gradient(m -> loss(m, X_gpu, Y_gpu), model_gpu))
(weight = Float32[1.2598536 -2.5608957 -5.783833 -1.5805628], bias = Float32[2.2775662], σ = nothing)

julia> g_cpu = g_gpu|>cpu
(weight = Float32[1.2598536 -2.5608957 -5.783833 -1.5805628], bias = Float32[2.2775662], σ = nothing)

julia> maximum(abs, g.weight - g_cpu.weight) 
4.7683716f-7

julia> maximum(abs, g.bias - g_cpu.bias)
0.0f0

It also works with Flux v0.15.2 & Zygote v0.6.77:

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
⌅ [587475ba] Flux v0.15.2
⌅ [e88e6eb3] Zygote v0.6.77
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

[...]

julia> g_gpu = first(Zygote.gradient(m -> loss(m, X_gpu, Y_gpu), model_gpu))
(weight = Float32[-0.6557447 1.9090092 0.6884821 -0.8441588], bias = Float32[-2.0864716], σ = nothing)

julia> g_cpu = g_gpu|>cpu
(weight = Float32[-0.6557447 1.9090092 0.6884821 -0.8441588], bias = Float32[-2.0864716], σ = nothing)

julia> maximum(abs, g.weight - g_cpu.weight)
5.9604645f-8

julia> maximum(abs, g.bias - g_cpu.bias)
0.0f0

It does not work with Flux v0.16.4 & Zygote v0.7.9:

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
  [587475ba] Flux v0.16.4
  [e88e6eb3] Zygote v0.7.9
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0

[...]

julia> g_gpu = first(Zygote.gradient(m -> loss(m, X_gpu, Y_gpu), model_gpu))
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}}}, which is not a bitstype:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}} which is not isbits.
      .x is of type Array{Float32, 0} which is not isbits.
        .ref is of type MemoryRef{Float32} which is not isbits.
          .mem is of type Memory{Float32} which is not isbits.


Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.

So I guess Zygote v0.7.9 is the problem?
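
Until this is fixed upstream, one possible workaround (a sketch using the standard Pkg API; pinning Zygote may force the resolver to downgrade Flux for compatibility) is to pin Zygote to the last version that worked here:

using Pkg
Pkg.add(name="Zygote", version="0.6.77")  # install the known-good release identified above
Pkg.pin("Zygote")                         # stop the resolver from upgrading it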

I can confirm at least that my combination here (Flux v0.16.4 + Zygote v0.6.77) is not yielding any issues, so I do think you have isolated the problem correctly.
EDIT: apologies for the typos; it is late here. Anyway, this is probably sufficient for training right now.

Just adding some more debugging:

It does not work with Flux v0.16.4 & Zygote v0.7.8:

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
  [587475ba] Flux v0.16.4
⌅ [e88e6eb3] Zygote v0.7.8
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0

It does not work with Flux v0.16.4 & Zygote v0.7.7:

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
  [587475ba] Flux v0.16.4
⌅ [e88e6eb3] Zygote v0.7.7
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0

I think the solution is to go back to using maximum, but applied to the view. Judging from the stacktrace, the pullback of sum over the zero-dimensional view hands a CPU-side Array{Float32, 0} to the GPU broadcast kernel, which is not an isbits type; the pullback of maximum apparently avoids this.

With this setup (latest Flux and Zygote):

(tmp) pkg> st
Status `~/.julia/dev/tmp/Project.toml`
  [052768ef] CUDA v5.8.2
  [082447d4] ChainRules v1.72.4
  [587475ba] Flux v0.16.4
  [e88e6eb3] Zygote v0.7.9
  [02a925ec] cuDNN v1.4.3
  [76a88914] CUDA_Runtime_jll v0.17.0+0
  [37e2e46d] LinearAlgebra v1.11.0
julia> using Flux, CUDA, Zygote, LinearAlgebra

julia> X = randn(Float32, 4, 2);

julia> Y = randn(Float32, 1, 2);

julia> model = Flux.Dense(4=>1)
Dense(4 => 1)       # 5 parameters

julia> function snorm(W::AbstractArray)
           S = svdvals(W)
           return view(S,1) # important 
       end
snorm (generic function with 1 method)

julia> function loss1(model, X, Y)
           reg = sum(snorm(model.weight))
           Flux.mse(model(X),Y) + reg
       end  # this will work only on the CPU
loss1 (generic function with 1 method)

julia> function loss2(model, X, Y)
           reg = maximum(snorm(model.weight))
           Flux.mse(model(X),Y) + reg
       end  # this will work on both CPU and GPU
loss2 (generic function with 1 method)

julia> g = first(Zygote.gradient(m -> loss1(m, X, Y), model))
(weight = Float32[4.0736847 -2.1274047 -0.38781103 0.3700309], bias = Float32[3.694799], σ = nothing)

julia> g = first(Zygote.gradient(m -> loss2(m, X, Y), model))
(weight = Float32[4.0736847 -2.1274047 -0.38781103 0.3700309], bias = Float32[3.694799], σ = nothing)

julia> X_gpu = X|>gpu
4×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
  1.10269    0.656666
 -0.426718  -0.572763
 -0.475849   0.668518
 -0.642911   0.802882

julia> Y_gpu = Y|>gpu
1×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
 -1.0348  -0.511742

julia> model_gpu = model|>gpu
Dense(4 => 1)       # 5 parameters

julia> g_gpu = first(Zygote.gradient(m -> loss2(m, X_gpu, Y_gpu), model_gpu))
(weight = Float32[4.0736847 -2.1274047 -0.3878111 0.37003076], bias = Float32[3.6947992], σ = nothing)

julia> g_gpu = first(Zygote.gradient(m -> loss1(m, X_gpu, Y_gpu), model_gpu))
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}}}, which is not a bitstype:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Array{Float32, 0}, Tuple{}, Tuple{}} which is not isbits.
      .x is of type Array{Float32, 0} which is not isbits.
        .ref is of type MemoryRef{Float32} which is not isbits.
          .mem is of type Memory{Float32} which is not isbits.


Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.

Thanks for the help :slight_smile: