Lux: optimization on GPU

Hello, I am trying to run the example from "Train NN with Sophia" [1] on the GPU.

I tried to get the sample working on the GPU with:

using Optimization, Lux, Zygote, MLUtils, Random, ComponentArrays
using CUDA, LuxCUDA

x = rand(10,10)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 1)

# Define the neural network
model = Chain(Dense(10, 32, tanh), Dense(32, 1))
dev= gpu_device() 
ps, st = Lux.setup(Random.default_rng(), model) |> dev
CUDA.@allowscalar ps_ca = ComponentArray(ps)

smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
    state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
    return l < 1e-1 ## Terminate if loss is small
end

function loss(ps, data)
    ypred=smodel(CuArray(data), ps)
    return sum(abs2, ypred .- CuArray(data[2]))
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)

but it gives this error:

ERROR: MethodError: no method matching CuArray(::Tuple{Matrix{Float64}, Matrix{Float64}})
The type `CuArray` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  CuArray(::Function, Any...)
   @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/array.jl:159
  CuArray(::LinearAlgebra.QRPivoted)
   @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/lib/cusolver/linalg.jl:149
  CuArray(::LinearAlgebra.Diagonal{T, <:Vector{T}}) where T
   @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/lib/cublas/linalg.jl:309
  ...

Stacktrace:
  [1] macro expansion
    @ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Type{CuArray}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87
  [3] loss
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:22 [inlined]
  [4] _pullback(::Zygote.Context{…}, ::typeof(loss), ::ComponentVector{…}, ::Tuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
  [5] pullback(::Function, ::Zygote.Context{false}, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
  [6] pullback(::Function, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Tuple{Matrix{…}, Matrix{…}})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
  [7] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
  [8] value_and_gradient
    @ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:81 [inlined]
  [9] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:99
 [10] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…}, p::Tuple{…})
    @ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:58
 [11] __solve(cache::OptimizationCache{…})
    @ Optimization /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:81
 [12] solve!(cache::OptimizationCache{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:186
 [13] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
 [14] top-level scope
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:29
Some type information was truncated. Use `show(err)` to see complete types.

This happens even though “ps” has been moved to CUDA.

Your data is a tuple, so you need to broadcast over it: CuArray.(data)

You then probably want to reuse that result instead of calling CuArray(data[2]) separately.
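
For example, a minimal sketch of this suggestion (reusing the names from the snippet above):

function loss(ps, data)
    d1, d2 = CuArray.(data)   # broadcasting converts each element of the (x, y) tuple
    ypred = smodel(d1, ps)
    return sum(abs2, ypred .- d2)
end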


The best way here would be to do data = DataLoader(....) |> dev. This constructs a device iterator from MLDataDevices (see the Lux.jl docs), which automatically manages data movement and freeing of data.
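
A minimal sketch of that approach (note this only moves the data; the parameters must end up on the same device, see further down the thread):

dev = gpu_device()
data = MLUtils.DataLoader((x, y), batchsize = 1) |> dev   # batches arrive on the GPU automatically

function loss(ps, data)
    xb, yb = data                  # already CuArrays, no manual conversion needed
    ypred = smodel(xb, ps)
    return sum(abs2, ypred .- yb)
end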


Thanks @ChrisRackauckas and @avikpal!! Yes, I made an error there. However, the code

using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays
using CUDA, LuxCUDA

x = rand(10,10)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 1)

# Define the neural network
model = Chain(Dense(10, 32, tanh), Dense(32, 1))
dev= gpu_device() 
ps, st = Lux.setup(Random.default_rng(), model) |> dev
CUDA.@allowscalar ps_ca = ComponentArray(ps)

smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
    state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
    return l < 1e-1 ## Terminate if loss is small
end

function loss(ps, data)
    d1=CuArray(data[1])
    d2=CuArray(data[2])
    print("\n typeof(d1) $(typeof(d1)) typeof(d2) $(typeof(d1)) \n")
    ypred=smodel(d1, ps)
    return sum(abs2, ypred .- d2)
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)

gives this error:

ERROR: ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
  [1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:127
  [2] macro expansion
    @ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:211 [inlined]
  [3] unrolled_mapreduce
    @ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:198 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:189
  [5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:161
  [6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ MLDataDevices /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/public.jl:372
  [7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ LuxLib /usr/local/share/julia/packages/LuxLib/wAt3f/src/traits.jl:210
  [8] select_fastest_activation(::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:128
  [9] rrule(::typeof(LuxLib.Impl.select_fastest_activation), ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:138
 [10] rrule(::Zygote.ZygoteRuleConfig{…}, ::Function, ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ ChainRulesCore /usr/local/share/julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
 [11] chain_rrule
    @ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
 [12] macro expansion
    @ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::typeof(LuxLib.Impl.select_fastest_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87
 [14] fused_dense_bias_activation
    @ /usr/local/share/julia/packages/LuxLib/wAt3f/src/api/dense.jl:35 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::typeof(fused_dense_bias_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [16] Dense
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/basic.jl:343 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [18] apply
    @ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [20] applychain
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:0 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [22] Chain
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:480 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [24] apply
    @ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [26] StatefulLuxLayer
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/helpers/stateful.jl:119 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::StatefulLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [28] loss
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:25 [inlined]
 [29] _pullback(::Zygote.Context{…}, ::typeof(loss), ::ComponentVector{…}, ::Tuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [30] pullback(::Function, ::Zygote.Context{false}, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
 [31] pullback(::Function, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Tuple{Matrix{…}, Matrix{…}})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
 [32] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
 [33] value_and_gradient
    @ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:81 [inlined]
 [34] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:99
 [35] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…}, p::Tuple{…})
    @ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:58
 [36] __solve(cache::OptimizationCache{…})
    @ Optimization /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:81
 [37] solve!(cache::OptimizationCache{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:186
 [38] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
 [39] top-level scope
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:32
Some type information was truncated. Use `show(err)` to see complete types.

It is most probably related to the fact that the ComponentArray, for some reason, is not on CUDA:

julia> ps_ca
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.57011324 0.638751 … -0.19225933 -0.12354956; 0.20622618 -0.12579119 … 0.644531 0.44288114; … ; 0.69624126 0.7610798 … 0.34806055 -0.15016149; 0.08788869 -0.011543258 … 0.256488 -0.7310434], bias = Float32[-0.11362305, 0.10988254, 0.23947266, 0.29654834, 0.19410786, 0.029232662, 0.19418657, 0.2171803, 0.056266934, -0.074235305  …  0.07852854, -0.04083943, -0.1578278, -0.20590375, -0.22106005, -0.19504064, 0.09862628, 0.095733166, 0.29774165, 0.06439602]), layer_2 = (weight = Float32[-0.12554047 0.004761639 … -0.06296449 -0.29443085], bias = Float32[0.09707039]))

although it is created from ps, which is on the GPU.
Moreover, even when I manually cast it to a CuArray like

CUDA.@allowscalar ps_ca["layer_1"]["weight"]=CuArray(ps_ca["layer_1"]["weight"])
       ps_ca
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.57011324 0.638751 … -0.19225933 -0.12354956; 0.20622618 -0.12579119 … 0.644531 0.44288114; … ; 0.69624126 0.7610798 … 0.34806055 -0.15016149; 0.08788869 -0.011543258 … 0.256488 -0.7310434], bias = Float32[-0.11362305, 0.10988254, 0.23947266, 0.29654834, 0.19410786, 0.029232662, 0.19418657, 0.2171803, 0.056266934, -0.074235305  …  0.07852854, -0.04083943, -0.1578278, -0.20590375, -0.22106005, -0.19504064, 0.09862628, 0.095733166, 0.29774165, 0.06439602]), layer_2 = (weight = Float32[-0.12554047 0.004761639 … -0.06296449 -0.29443085], bias = Float32[0.09707039]))

the ComponentArray still remains on the CPU.

Another idea was to put the DataLoader on the device:

using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays
using CUDA, LuxCUDA

x = rand(10,10)
y = sin.(x)
dev= gpu_device() 

data = MLUtils.DataLoader((x, y), batchsize = 1)|> dev

# Define the neural network
model = Chain(Dense(10, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
CUDA.@allowscalar ps_ca = ComponentArray(ps)

ps_ca

CUDA.@allowscalar ps_ca["layer_1"]["weight"]=CuArray(ps_ca["layer_1"]["weight"])
ps_ca

smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
    state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
    return l < 1e-1 ## Terminate if loss is small
end

function loss(ps, data)
    d1=data[1]
    d2=data[2]
    print("\n typeof(d1) $(typeof(d1)) typeof(d2) $(typeof(d1)) \n")
    ypred=smodel(d1, ps)
    return sum(abs2, ypred .- d2)
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)

gives a similar error:

ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
  [1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:127
  [2] macro expansion
    @ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:211 [inlined]
  [3] unrolled_mapreduce
    @ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:198 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:189
  [5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:161
  [6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ MLDataDevices /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/public.jl:372
  [7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
    @ LuxLib /usr/local/share/julia/packages/LuxLib/wAt3f/src/traits.jl:210
  [8] select_fastest_activation(::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:128
  [9] rrule(::typeof(LuxLib.Impl.select_fastest_activation), ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:138
 [10] rrule(::Zygote.ZygoteRuleConfig{…}, ::Function, ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ ChainRulesCore /usr/local/share/julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
 [11] chain_rrule
    @ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
 [12] macro expansion
    @ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::typeof(LuxLib.Impl.select_fastest_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87
 [14] fused_dense_bias_activation
    @ /usr/local/share/julia/packages/LuxLib/wAt3f/src/api/dense.jl:35 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::typeof(fused_dense_bias_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [16] Dense
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/basic.jl:343 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [18] apply
    @ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [20] applychain
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:0 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [22] Chain
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:480 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [24] apply
    @ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [26] StatefulLuxLayer
    @ /usr/local/share/julia/packages/Lux/JbRSn/src/helpers/stateful.jl:119 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::StatefulLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [28] loss
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:31 [inlined]
 [29] _pullback(::Zygote.Context{…}, ::typeof(loss), ::ComponentVector{…}, ::Tuple{…})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [30] pullback(::Function, ::Zygote.Context{false}, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
 [31] pullback(::Function, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Tuple{CuArray{…}, CuArray{…}})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
 [32] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
    @ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
 [33] value_and_gradient
    @ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:81 [inlined]
 [34] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:99
 [35] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…}, p::Tuple{…})
    @ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:58
 [36] __solve(cache::OptimizationCache{…})
    @ Optimization /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:81
 [37] solve!(cache::OptimizationCache{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:186
 [38] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
 [39] top-level scope
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:38
Some type information was truncated. Use `show(err)` to see complete types.

See "Training Lux Models using Optimization.jl" in the Lux.jl docs: you need to move the ps to the GPU after constructing the ComponentArray, not before.

@allowscalar is hiding an error that would tell you it is trying to copy the CuArray into an array on the CPU.
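
A minimal sketch of that ordering, following the linked tutorial (the rest of the script stays as above):

dev = gpu_device()

ps, st = Lux.setup(Random.default_rng(), model)   # plain CPU NamedTuples
ps_ca = ComponentArray(ps) |> dev                 # flatten first, then move the whole vector to the GPU
st = st |> dev                                    # move the states too (harmless for stateless layers like Dense)

smodel = StatefulLuxLayer{true}(model, nothing, st)
prob = OptimizationProblem(optf, ps_ca, data)     # data from DataLoader(...) |> dev as above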


Thanks @avikpal! After applying your suggestion the Lux part seems to work well. However, it seems that the algorithm does not support CUDA acceleration in the Hessian computation:

ERROR: GPUCompiler.KernelError(GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(MethodInstance for (::GPUArrays.var"#34#36")(::CUDA.CuKernelContext, ::ComponentVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64), GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}(GPUCompiler.PTXCompilerTarget(v"8.6.0", v"7.8.0", true, nothing, nothing, nothing, nothing, false, nothing, nothing), CUDA.CUDACompilerParams(v"8.6.0", v"8.2.0"), true, nothing, :specfunc, false, 2), 0x00000000000069a1), "passing and using non-bitstype argument", "Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{ComponentArrays.CombinedAxis{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}, Base.OneTo{Int64}}}, typeof(+), Tuple{Base.Broadcast.Extruded{ComponentVector{Float32, CuDeviceVector{Float32, 1}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{DifferentiationInterface.OneElement{Int64, 1, Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:\n  .args is of type Tuple{Base.Broadcast.Extruded{ComponentVector{Float32, CuDeviceVector{Float32, 1}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{DifferentiationInterface.OneElement{Int64, 1, Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.\n    .2 is of type Base.Broadcast.Extruded{DifferentiationInterface.OneElement{Int64, 1, Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}}, Tuple{Bool}, Tuple{Int64}} which is not isbits.\n      .x is of type DifferentiationInterface.OneElement{Int64, 1, Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}}} which is not isbits.\n        .a is of type ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:352, Axis(weight = ViewAxis(1:320, ShapedAxis((32, 10))), bias = 321:352)), layer_2 = ViewAxis(353:385, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32))), bias = 33:33)))}}} which is not isbits.\n          
.data is of type CuArray{Float32, 1, CUDA.DeviceMemory} which is not isbits.\n            .data is of type GPUArrays.DataRef{CUDA.Managed{CUDA.DeviceMemory}} which is not isbits.\n              .rc is of type GPUArrays.RefCounted{CUDA.Managed{CUDA.DeviceMemory}} which is not isbits.\n                .obj is of type CUDA.Managed{CUDA.DeviceMemory} which is not isbits.\n                  .stream is of type CuStream which is not isbits.\n                    .ctx is of type Union{Nothing, CuContext} which is not isbits.\n                .finalizer is of type Any which is not isbits.\n                .count is of type Base.Threads.Atomic{Int64} which is not isbits.\n", Base.StackTraces.StackFrame[])
Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/validation.jl:92
  [2] macro expansion
    @ /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:92 [inlined]
  [3] macro expansion
    @ /usr/local/share/julia/packages/TimerOutputs/NRdsv/src/TimerOutput.jl:253 [inlined]
  [4] 
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:90
  [5] codegen
    @ /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:82 [inlined]
  [6] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:79
  [7] compile
    @ /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:74 [inlined]
  [8] #1145
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/compilation.jl:250 [inlined]
  [9] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:34
 [10] JuliaContext(f::Function)
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/driver.jl:25
 [11] compile(job::GPUCompiler.CompilerJob)
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/compilation.jl:249
 [12] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/execution.jl:237
 [13] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/execution.jl:151
 [14] macro expansion
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:380 [inlined]
 [15] macro expansion
    @ ./lock.jl:273 [inlined]
 [16] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:375
 [17] cufunction
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:372 [inlined]
 [18] macro expansion
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:112 [inlined]
 [19] #launch_heuristic#1200
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/gpuarrays.jl:17 [inlined]
 [20] launch_heuristic
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/gpuarrays.jl:15 [inlined]
 [21] _copyto!
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:78 [inlined]
 [22] copyto!
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:44 [inlined]
 [23] copy
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [24] materialize
    @ ./broadcast.jl:867 [inlined]
 [25] broadcast_preserving_zero_d
    @ ./broadcast.jl:856 [inlined]
 [26] +(A::ComponentVector{…}, B::DifferentiationInterface.OneElement{…})
    @ Base ./arraymath.jl:8
 [27] basis
    @ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/src/utils/basis.jl:64 [inlined]
 [28] basis
    @ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/src/utils/basis.jl:49 [inlined]
 [29] #46
    @ ./none:0 [inlined]
 [30] iterate
    @ ./generator.jl:48 [inlined]
 [31] collect(itr::Base.Generator{ComponentArrays.CombinedAxis{…}, DifferentiationInterface.var"#46#51"{…}})
    @ Base ./array.jl:780
 [32] _prepare_hessian_aux(batch_size_settings::DifferentiationInterface.BatchSizeSettings{…}, f::typeof(loss), backend::DifferentiationInterface.SecondOrder{…}, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterface /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/src/second_order/hessian.jl:94
 [33] prepare_hessian(f::typeof(loss), backend::DifferentiationInterface.SecondOrder{…}, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterface /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/src/second_order/hessian.jl:83
 [34] instantiate_function(f::OptimizationFunction{…}, x::ComponentVector{…}, adtype::AutoZygote, p::Tuple{…}, num_cons::Int64; g::Bool, h::Bool, hv::Bool, fg::Bool, fgh::Bool, cons_j::Bool, cons_vjp::Bool, cons_jvp::Bool, cons_h::Bool, lag_h::Bool)
    @ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:71
 [35] instantiate_function
    @ /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:21 [inlined]
 [36] #instantiate_function#38
    @ /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:281 [inlined]
 [37] OptimizationCache(prob::OptimizationProblem{…}, opt::Optimization.Sophia; callback::Function, maxiters::Int64, maxtime::Nothing, abstol::Nothing, reltol::Nothing, progress::Bool, structural_analysis::Bool, manifold::Nothing, kwargs::@Kwargs{…})
    @ OptimizationBase /usr/local/share/julia/packages/OptimizationBase/gvXsf/src/cache.jl:60
 [38] OptimizationCache
    @ /usr/local/share/julia/packages/OptimizationBase/gvXsf/src/cache.jl:25 [inlined]
 [39] #__init#32
    @ /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:25 [inlined]
 [40] __init
    @ /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:22 [inlined]
 [41] #init#726
    @ /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:172 [inlined]
 [42] init
    @ /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:170 [inlined]
 [43] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
    @ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
 [44] top-level scope
    @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:35
Some type information was truncated. Use `show(err)` to see complete types.

cc @Vaibhavdixit02 .

The way Hessians are implemented using DI won't work with Lux (or NNlib and its downstream users) in general. The way forward would be to use something like Nested Automatic Differentiation (see the Lux.jl docs), or to wait for something like "Add Reactant support" (SciML/OptimizationBase.jl#132) to be implemented.


Thanks for the help! I will probably wait for Reactant then (fantastic development). Thank you for your time!

Just to clarify, that only holds when we try to take the hessian of a loss which already contains autodiff, right? Otherwise I don’t see the issue which would prevent hessians from working with Lux?