Incorporating SVD factorization into GPUs with CUDA

I’m trying to run the following code to optimize the spectral norm (largest singular value) of the weight matrices of a neural network:

using CUDA, Random, Lux, ComponentArrays, Optimization, OptimizationOptimisers, OptimizationOptimJL, LinearAlgebra
rng = Random.default_rng()

nn = Lux.Chain(Lux.Dense(5,10,tanh),Lux.Dense(10,5))
pinit,st = Lux.setup(rng,nn)

st = st |> gpu
p64 = Float64.(Lux.gpu(ComponentArray(pinit)))

function snorm(X)
    return CUDA.@allowscalar svd(X).S[1]
end

function loss(p)
    # Extract each layer's weight matrix from the flat parameter block
    # (the bias entries sit at the end of each layer's block)
    W1 = CUDA.@allowscalar reshape(p.layer_1[1:(end - length(pinit.layer_1[2]))], size(pinit.layer_1[1]))
    W2 = CUDA.@allowscalar reshape(p.layer_2[1:(end - length(pinit.layer_2[2]))], size(pinit.layer_2[1]))
    return CUDA.@allowscalar snorm(W1) * snorm(W2)
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf,p64)

Optimization.solve(optprob,ADAM(0.005),maxiters = 10)

This code works correctly when running on the CPU. However, when trying to run it on the GPU, I get the following error:

ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.

Where else do I need to allow scalar indexing? I have put @allowscalar everywhere it might be needed in the code.

This looks like a missing AD rule or a gap in the CUDA ↔ LinearAlgebra integration. It’s hard to tell without the full stacktrace, so make sure to provide that. I’d also recommend creating an MWE that uses only Zygote and CUDA, because the other libraries here add a lot of extra layers to sift through, which may slow down or prevent a good solution from being found.

Thanks for the input. I’m still getting used to Zygote’s internals. I managed to reproduce the error with only CUDA, LinearAlgebra, and Zygote:

using CUDA, LinearAlgebra, Zygote

function snorm(X)
    return CUDA.@allowscalar svd(X).S[1]
end

dL(W) = gradient(X->snorm(X),W)
dL(CUDA.rand(3,2))

Yields:

ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/6STCb/src/host/indexing.jl:9 [inlined]
  [4] svd_rev(USV::SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Ū::ChainRulesCore.ZeroTangent, s̄::Zygote.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, V̄::ChainRulesCore.ZeroTangent)
    @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/LinearAlgebra/factorization.jl:256
  [5] _svd_pullback
    @ ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/LinearAlgebra/factorization.jl:219 [inlined]
  [6] svd_pullback
    @ ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/LinearAlgebra/factorization.jl:225 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}})(dy::NamedTuple{(:U, :S, :Vt), Tuple{Nothing, Zygote.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:211
  [8] Pullback
    @ ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:125 [inlined]
  [9] (::Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#ad_pullback#50"{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:263
 [11] (::ChainRules.var"#task_local_storage_pullback#1257"{Zygote.var"#ad_pullback#50"{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}}}})(dy::Float32)
    @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/base.jl:261
 [12] (::Zygote.ZBack{ChainRules.var"#task_local_storage_pullback#1257"{Zygote.var"#ad_pullback#50"{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}}}}})(dy::Float32)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:211
 [13] macro expansion
    @ ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:124 [inlined]
 [14] Pullback
    @ ~/NODE_Community_Forecast/test.jl:4 [inlined]
 [15] (::Zygote.Pullback{Tuple{typeof(snorm), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{ChainRules.var"#task_local_storage_pullback#1257"{Zygote.var"#ad_pullback#50"{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}}}}}, Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, GPUArraysCore.ScalarIndexing}}, Zygote.var"#2100#back#226"{Zygote.Jnew{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Nothing, false}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/NODE_Community_Forecast/test.jl:7 [inlined]
 [17] (::Zygote.var"#60#61"{Zygote.Pullback{Tuple{var"#3#4", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(snorm), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{ChainRules.var"#task_local_storage_pullback#1257"{Zygote.var"#ad_pullback#50"{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.ZBack{ChainRules.var"#svd_pullback#2113"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#getproperty_svd_pullback#2114"{SVD{Float32, Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Symbol}}, Zygote.var"#2077#back#218"{Zygote.var"#back#217"{:X, Zygote.Context{false}, var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Val{1}}, Tuple{Zygote.var"#2427#back#375"{Zygote.var"#379#381"{1, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Int64}}}}}}}}}}, Zygote.var"#1923#back#149"{Zygote.var"#147#148"{Zygote.Context{false}, GlobalRef, GPUArraysCore.ScalarIndexing}}, Zygote.var"#2100#back#226"{Zygote.Jnew{var"#1#2"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Nothing, false}}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:45
 [18] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:97
 [19] dL(W::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Main ~/NODE_Community_Forecast/test.jl:7
 [20] top-level scope
    @ ~/NODE_Community_Forecast/test.jl:8
in expression starting at /home/jarroyoesquivel/NODE_Community_Forecast/test.jl:8
srun: error: vgpu-002: task 0: Exited with exit code 1

Besides the error: is there any reason you’re doing this by computing the full SVD, rather than using power iteration? (Which is probably much more efficient for reasonably sized matrices.)

I tried the opnorm function, which would be the most efficient option. Sadly, Zygote doesn’t have a chain rule for it. This link suggested using the SVD instead.

Power iteration may be an alternative, but I’d be concerned about its convergence speed.

opnorm just computes an SVD (it sucks, see use power iteration for opnorm2 by oscardssmith · Pull Request #49487 · JuliaLang/julia · GitHub). Power iteration will almost certainly be faster unless your two largest singular values are very close together.
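
For reference, the basic scheme on the CPU looks something like this (a minimal sketch; the name snorm_power and the fixed iteration count are just illustrative choices, not part of any library):

using LinearAlgebra

# Power-iteration sketch for the leading singular value: repeatedly
# apply X'X to a random vector and renormalize. The iterate converges
# to the top right-singular vector at a rate set by the gap between
# the two largest singular values.
function snorm_power(X, niters = 100)
    v = randn(eltype(X), size(X, 2))
    for _ in 1:niters
        v = X' * (X * v)
        v /= norm(v)
    end
    return norm(X * v)  # ≈ σ₁(X)
end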

Thanks, that’s a benchmark I hadn’t considered. I switched from the SVD to power iteration:

using CUDA, LinearAlgebra, Zygote

function loss(X, niters=10)
    x = rand(size(X, 2))
    A = X' * X
    y = x
    for i in 1:niters
        CUDA.@allowscalar y = A * y
        CUDA.@allowscalar y = y / norm(y)
    end
    return CUDA.@allowscalar dot(y, A * y) / dot(y, y)
end

dL(W) = gradient(X->loss(X),W)
dL(CUDA.rand(3,2))

Still getting the same error:

ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:103
  [3] getindex(xs::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/6STCb/src/host/indexing.jl:9
  [4] generic_matvecmul!(C::Vector{Float64}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::Vector{Float64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /carnegie/binaries/centos7/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:791
  [5] mul!
    @ /carnegie/binaries/centos7/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:115 [inlined]
  [6] mul!
    @ /carnegie/binaries/centos7/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
  [7] *
    @ /carnegie/binaries/centos7/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:101 [inlined]
  [8] (::ChainRules.var"#1478#1481"{Vector{Float64}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/arraymath.jl:37
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:223 [inlined]
 [12] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:111 [inlined]
 [13] (::Zygote.ZBack{ChainRules.var"#times_pullback#1479"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}})(dy::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:211
 [14] Pullback
    @ ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:125 [inlined]
 [15] (::Zygote.Pullback{Tuple{var"#3#6"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#ad_pullback#50"{Tuple{var"#3#6"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#3#6"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:263
 [17] task_local_storage_pullback
    @ ~/.julia/packages/ChainRules/bEtjZ/src/rulesets/Base/base.jl:261 [inlined]
 [18] (::Zygote.ZBack{ChainRules.var"#task_local_storage_pullback#1257"{Zygote.var"#ad_pullback#50"{Tuple{var"#3#6"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{var"#3#6"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Any}}}})(dy::Float64)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/chainrules.jl:211
 [19] macro expansion
    @ ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:124 [inlined]
 [20] Pullback
    @ ~/NODE_Community_Forecast/test.jl:11 [inlined]
 [21] (::Zygote.Pullback{Tuple{typeof(loss), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Int64}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/NODE_Community_Forecast/test.jl:4 [inlined]
 [23] Pullback
    @ ~/NODE_Community_Forecast/test.jl:14 [inlined]
 [24] (::Zygote.var"#60#61"{Zygote.Pullback{Tuple{var"#7#8", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Int64}, Any}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:45
 [25] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:97
 [26] dL(W::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Main ~/NODE_Community_Forecast/test.jl:14
 [27] top-level scope
    @ ~/NODE_Community_Forecast/test.jl:15
in expression starting at /home/jarroyoesquivel/NODE_Community_Forecast/test.jl:15
srun: error: vgpu-002: task 0: Exited with exit code 1

The problem is that x (and y) are Float64 arrays on the CPU, while X is a Float32 array on the GPU. You want:

using Random: randn!  # randn! comes from the Random stdlib

function loss(X, niters = 10)
    # similar(X, ...) allocates on the same device and with the same
    # eltype as X, so y starts out as a Float32 CuArray
    y = randn!(similar(X, size(X, 2)))
    for i in 1:niters
        tmp = X * y             # power step with X
        tmp = tmp / norm(tmp)
        y = X' * tmp            # power step with X'
        y = y / norm(y)
    end
    return norm(X * y)          # ≈ largest singular value of X
end

This makes sure your temporaries all live on the GPU and have the same eltype as X. For a 300×200 CuArray, I’m seeing the dL function take 0.003 seconds (on an RTX 2060).
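
Concretely, the timing above refers to a gradient call like this (a sketch assuming the loss function above is in scope; dL follows the earlier definition in the thread):

using CUDA, LinearAlgebra, Zygote
using Random: randn!

dL(W) = gradient(X -> loss(X), W)
dL(CUDA.rand(300, 200))  # returns a 1-tuple holding the 300×200 CuArray gradient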
