I directly copied the Lux example for training a neural ODE on MNIST (MNIST Classification using NeuralODE - Lux.jl) and simply replaced CUDA with Metal so that it runs on my Mac M1 GPU.
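To be concrete, the swap was only of this form (the Metal lines appear verbatim in the script below; the CUDA lines are what the tutorial had):

using Metal         # instead of: using CUDA
Metal.functional()  # instead of: CUDA.functional()

However, with just this change the code breaks. Here is the full script: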
cd(@__DIR__)
using Pkg
Pkg.activate(".")
Pkg.instantiate()
using Lux
using ComponentArrays, SciMLSensitivity, Metal, Optimisers, OrdinaryDiffEq, Random,
    Statistics, Zygote, OneHotArrays
import MLDatasets: MNIST
import MLUtils: DataLoader, splitobs
Metal.functional()  # sanity check: returns true if the Metal GPU is usable
function loadmnist(batchsize, train_split)
    # Load MNIST: Only 1500 for demonstration purposes
    N = 1500
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize=batchsize, shuffle=false))
end
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
        solver=Tsit5(),
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
        tspan=(0.0f0, 1.0f0),
        kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, kwargs)
end
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(u, p, st)
        return u_
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end
function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:MtlArray}}) where {T, N}
    dev = gpu_device()
    return dropdims(dev(x); dims=3)
end
diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)
function create_model()
    # Construct the Neural ODE Model
    model = Chain(FlattenLayer(),
        Dense(784, 20, tanh),
        NeuralODE(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false),
        diffeqsol_to_array,
        Dense(20, 10))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    dev = gpu_device()
    ps = ComponentArray(ps) |> dev
    st = st |> dev
    return model, ps, st
end
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return logitcrossentropy(y_pred, y), st
end
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    # `CuIterator` is CUDA-specific (a leftover from the original example), so with
    # Metal each batch is moved onto the device by hand instead
    dev = gpu_device()
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        x = dev(x)
        target_class = onecold(cpu_dev(y))
        predicted_class = onecold(cpu_dev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
function train()
    model, ps, st = create_model()

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    opt = Optimisers.Adam(0.001f0)  # `ADAM` is a deprecated alias; `Adam` is current
    st_opt = Optimisers.setup(opt, ps)
    dev = gpu_device()

    ### Warmup the Model
    img, lab = dev(train_dataloader.data[1][:, :, :, 1:1]),
        dev(train_dataloader.data[2][:, 1:1])
    loss(img, lab, model, ps, st)
    (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
    back((one(l), nothing))
    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            x = dev(x)
            y = dev(y)
            (l, st), back = pullback(p -> loss(x, y, model, p, st), ps)
            ### We need to add `nothing`s equal to the number of returned values - 1
            gs = back((one(l), nothing))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)
        end
        ttime = time() - stime

        println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
                "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " *
                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%")
    end
end
train()
Running the script returns the following error:
ERROR: ArgumentError: cannot take the CPU address of a MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}
Stacktrace:
[1] unsafe_convert(#unused#::Type{Ptr{Float32}}, x::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ Metal ~/.julia/packages/Metal/qeZqc/src/array.jl:139
[2] unsafe_convert(#unused#::Type{Ptr{Float32}}, V::SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true})
@ Base ./subarray.jl:437
[3] unsafe_convert(#unused#::Type{Ptr{Float32}}, a::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Base ./reshapedarray.jl:283
[4] gemm!(transA::Char, transB::Char, alpha::Float32, A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, beta::Float32, C::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ LinearAlgebra.BLAS /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
[5] gemm_wrapper!(C::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, tA::Char, tB::Char, A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:674
[6] mul!
@ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:161 [inlined]
[7] mul!
@ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
[8] *(A::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, B::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
@ LinearAlgebra /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:148
[9] Dense
@ ~/.julia/packages/Lux/5YzHA/src/layers/basic.jl:223 [inlined]
[10] apply(model::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, ps::ComponentVector{Float32, SubArray{Float32, 1, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))}}}, st::NamedTuple{(), Tuple{}})
@ LuxCore ~/.julia/packages/LuxCore/yC3wg/src/LuxCore.jl:100
[11] macro expansion
@ ./abstractarray.jl:0 [inlined]
[12] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{FlattenLayer, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, NeuralODE{Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float32, Float32, Bool}}}}, WrappedFunction{typeof(diffeqsol_to_array)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, x::MtlArray{Float32, 4, Metal.MTL.MTLResourceStorageModePrivate}, ps::ComponentVector{Float32, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
@ Lux ~/.julia/packages/Lux/5YzHA/src/layers/containers.jl:493
[13] Chain
@ ~/.julia/packages/Lux/5YzHA/src/layers/containers.jl:491 [inlined]
[14] loss(x::MtlArray{Float32, 4, Metal.MTL.MTLResourceStorageModePrivate}, y::MtlMatrix{Bool, Metal.MTL.MTLResourceStorageModePrivate}, model::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{FlattenLayer, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, NeuralODE{Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol, :save_start), Tuple{Bool, Float32, Float32, Bool}}}}, WrappedFunction{typeof(diffeqsol_to_array)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, ps::ComponentVector{Float32, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
@ Main ~/Research/constitutive_history/examples/mnist.jl:99
[15] train()
@ Main ~/Research/constitutive_history/examples/mnist.jl:131
[16] top-level scope
@ ~/Research/constitutive_history/examples/mnist.jl:155
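If I am reading the stack trace right, the first Dense layer's weight reaches the matmul as a Base.ReshapedArray wrapping a SubArray of an MtlVector (that is how ComponentArrays views into the flat parameter vector), so frames [4]-[8] fall through to the generic CPU BLAS gemm! path instead of Metal's own matrix multiply. Here is a stripped-down sketch of what I think is the same failure (my own minimal example, not from the tutorial, and untested):

using Metal, LinearAlgebra
v = mtl(rand(Float32, 6))        # flat parameter vector on the GPU
W = reshape(view(v, 1:6), 2, 3)  # ReshapedArray of a SubArray of an MtlVector, as in frame [4]
x = mtl(rand(Float32, 3, 4))
W * x                            # should throw: cannot take the CPU address of a MtlVector

Is there a missing dispatch for this wrapper combination on non-CUDA GPUs, or do I need to set up the parameters differently for Metal?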