Thanks @ChrisRackauckas and @avikpal !!yes here I made an error ! Hovewer the code
using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays
using CUDA, LuxCUDA
x = rand(10,10)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 1)
# Define the neural network
model = Chain(Dense(10, 32, tanh), Dense(32, 1))
dev= gpu_device()
ps, st = Lux.setup(Random.default_rng(), model) |> dev
CUDA.@allowscalar ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)
function callback(state, l)
state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
return l < 1e-1 ## Terminate if loss is small
end
function loss(ps, data)
d1=CuArray(data[1])
d2=CuArray(data[2])
print("\n typeof(d1) $(typeof(d1)) typeof(d2) $(typeof(d1)) \n")
ypred=smodel(d1, ps)
return sum(abs2, ypred .- d2)
end
optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)
res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)
give error
ERROR: ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
[1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:127
[2] macro expansion
@ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:211 [inlined]
[3] unrolled_mapreduce
@ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:198 [inlined]
[4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:189
[5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:161
[6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ MLDataDevices /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/public.jl:372
[7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ LuxLib /usr/local/share/julia/packages/LuxLib/wAt3f/src/traits.jl:210
[8] select_fastest_activation(::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:128
[9] rrule(::typeof(LuxLib.Impl.select_fastest_activation), ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:138
[10] rrule(::Zygote.ZygoteRuleConfig{…}, ::Function, ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ ChainRulesCore /usr/local/share/julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
[11] chain_rrule
@ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
[12] macro expansion
@ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
[13] _pullback(::Zygote.Context{…}, ::typeof(LuxLib.Impl.select_fastest_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87
[14] fused_dense_bias_activation
@ /usr/local/share/julia/packages/LuxLib/wAt3f/src/api/dense.jl:35 [inlined]
[15] _pullback(::Zygote.Context{…}, ::typeof(fused_dense_bias_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[16] Dense
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/basic.jl:343 [inlined]
[17] _pullback(::Zygote.Context{…}, ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[18] apply
@ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
[19] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[20] applychain
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:0 [inlined]
[21] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[22] Chain
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:480 [inlined]
[23] _pullback(::Zygote.Context{…}, ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[24] apply
@ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
[25] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[26] StatefulLuxLayer
@ /usr/local/share/julia/packages/Lux/JbRSn/src/helpers/stateful.jl:119 [inlined]
[27] _pullback(::Zygote.Context{…}, ::StatefulLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[28] loss
@ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:25 [inlined]
[29] _pullback(::Zygote.Context{…}, ::typeof(loss), ::ComponentVector{…}, ::Tuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[30] pullback(::Function, ::Zygote.Context{false}, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
[31] pullback(::Function, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Tuple{Matrix{…}, Matrix{…}})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
[32] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
[33] value_and_gradient
@ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:81 [inlined]
[34] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:99
[35] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…}, p::Tuple{…})
@ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:58
[36] __solve(cache::OptimizationCache{…})
@ Optimization /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:81
[37] solve!(cache::OptimizationCache{…})
@ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:186
[38] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
@ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
[39] top-level scope
@ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:32
Some type information was truncated. Use `show(err)` to see complete types.
It is most probably related to the fact that component array for some reson is not on CUDA
julia> ps_ca
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.57011324 0.638751 … -0.19225933 -0.12354956; 0.20622618 -0.12579119 … 0.644531 0.44288114; … ; 0.69624126 0.7610798 … 0.34806055 -0.15016149; 0.08788869 -0.011543258 … 0.256488 -0.7310434], bias = Float32[-0.11362305, 0.10988254, 0.23947266, 0.29654834, 0.19410786, 0.029232662, 0.19418657, 0.2171803, 0.056266934, -0.074235305 … 0.07852854, -0.04083943, -0.1578278, -0.20590375, -0.22106005, -0.19504064, 0.09862628, 0.095733166, 0.29774165, 0.06439602]), layer_2 = (weight = Float32[-0.12554047 0.004761639 … -0.06296449 -0.29443085], bias = Float32[0.09707039]))
Although it is created from ps that is on GPU.
Morover even when I manually cast is to CuArray like
CUDA.@allowscalar ps_ca["layer_1"]["weight"]=CuArray(ps_ca["layer_1"]["weight"])
ps_ca
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.57011324 0.638751 … -0.19225933 -0.12354956; 0.20622618 -0.12579119 … 0.644531 0.44288114; … ; 0.69624126 0.7610798 … 0.34806055 -0.15016149; 0.08788869 -0.011543258 … 0.256488 -0.7310434], bias = Float32[-0.11362305, 0.10988254, 0.23947266, 0.29654834, 0.19410786, 0.029232662, 0.19418657, 0.2171803, 0.056266934, -0.074235305 … 0.07852854, -0.04083943, -0.1578278, -0.20590375, -0.22106005, -0.19504064, 0.09862628, 0.095733166, 0.29774165, 0.06439602]), layer_2 = (weight = Float32[-0.12554047 0.004761639 … -0.06296449 -0.29443085], bias = Float32[0.09707039]))
Component array still remain on CPU, as it would
Other idea so to get a dataloader on device
using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays
using CUDA, LuxCUDA
x = rand(10,10)
y = sin.(x)
dev= gpu_device()
data = MLUtils.DataLoader((x, y), batchsize = 1)|> dev
# Define the neural network
model = Chain(Dense(10, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
CUDA.@allowscalar ps_ca = ComponentArray(ps)
ps_ca
CUDA.@allowscalar ps_ca["layer_1"]["weight"]=CuArray(ps_ca["layer_1"]["weight"])
ps_ca
smodel = StatefulLuxLayer{true}(model, nothing, st)
function callback(state, l)
state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
return l < 1e-1 ## Terminate if loss is small
end
function loss(ps, data)
d1=data[1]
d2=data[2]
print("\n typeof(d1) $(typeof(d1)) typeof(d2) $(typeof(d1)) \n")
ypred=smodel(d1, ps)
return sum(abs2, ypred .- d2)
end
optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)
res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)
get similar error
ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
[1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:127
[2] macro expansion
@ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:211 [inlined]
[3] unrolled_mapreduce
@ /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:198 [inlined]
[4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:189
[5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ MLDataDevices.Internal /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/internal.jl:161
[6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ MLDataDevices /usr/local/share/julia/packages/MLDataDevices/MFmTU/src/public.jl:372
[7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}, SubArray{…}})
@ LuxLib /usr/local/share/julia/packages/LuxLib/wAt3f/src/traits.jl:210
[8] select_fastest_activation(::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:128
[9] rrule(::typeof(LuxLib.Impl.select_fastest_activation), ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ LuxLib.Impl /usr/local/share/julia/packages/LuxLib/wAt3f/src/impl/activation.jl:138
[10] rrule(::Zygote.ZygoteRuleConfig{…}, ::Function, ::Function, ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ ChainRulesCore /usr/local/share/julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
[11] chain_rrule
@ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
[12] macro expansion
@ /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
[13] _pullback(::Zygote.Context{…}, ::typeof(LuxLib.Impl.select_fastest_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87
[14] fused_dense_bias_activation
@ /usr/local/share/julia/packages/LuxLib/wAt3f/src/api/dense.jl:35 [inlined]
[15] _pullback(::Zygote.Context{…}, ::typeof(fused_dense_bias_activation), ::typeof(tanh_fast), ::Base.ReshapedArray{…}, ::CuArray{…}, ::SubArray{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[16] Dense
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/basic.jl:343 [inlined]
[17] _pullback(::Zygote.Context{…}, ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[18] apply
@ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
[19] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[20] applychain
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:0 [inlined]
[21] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[22] Chain
@ /usr/local/share/julia/packages/Lux/JbRSn/src/layers/containers.jl:480 [inlined]
[23] _pullback(::Zygote.Context{…}, ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[24] apply
@ /usr/local/share/julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
[25] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[26] StatefulLuxLayer
@ /usr/local/share/julia/packages/Lux/JbRSn/src/helpers/stateful.jl:119 [inlined]
[27] _pullback(::Zygote.Context{…}, ::StatefulLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[28] loss
@ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:31 [inlined]
[29] _pullback(::Zygote.Context{…}, ::typeof(loss), ::ComponentVector{…}, ::Tuple{…})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[30] pullback(::Function, ::Zygote.Context{false}, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
[31] pullback(::Function, ::ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Tuple{CuArray{…}, CuArray{…}})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
[32] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
@ Zygote /usr/local/share/julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
[33] value_and_gradient
@ /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:81 [inlined]
[34] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt /usr/local/share/julia/packages/DifferentiationInterface/DSrNZ/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:99
[35] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…}, p::Tuple{…})
@ OptimizationZygoteExt /usr/local/share/julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:58
[36] __solve(cache::OptimizationCache{…})
@ Optimization /usr/local/share/julia/packages/Optimization/cfp9i/src/sophia.jl:81
[37] solve!(cache::OptimizationCache{…})
@ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:186
[38] solve(::OptimizationProblem{…}, ::Optimization.Sophia; kwargs::@Kwargs{…})
@ SciMLBase /usr/local/share/julia/packages/SciMLBase/XzPx0/src/solve.jl:94
[39] top-level scope
@ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/debug_optimazation/to_remove.jl:38
Some type information was truncated. Use `show(err)` to see complete types.