I am trying to get Lux to work, but I am stumbling at a relatively early stage of the tutorials. I am using a combination of the Quickstart and Getting Started pages of the Lux.jl docs.
Here, I am simply trying to fit a neural network to data from an ℝ¹ → ℝ¹ function:
```julia
# Fetch packages.
using Lux, Random, Optimisers, Zygote, StableRNGs, Plots
# Define the true function and create noisy training data.
f(x) = 0.5 * (x^3)/(x^3 + 0.7^3)
xs = collect(0.0:0.1:1.0)
# Scale each target by a random factor in [0.9, 1.1) to add noise.
ys = [(0.9 + 0.2rand())*f(x) for x in xs]
```
Following the tutorial, I get to:
```julia
# Set RNG.
rng = StableRNG(1234)
# Create neural network.
model = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
)
# Initialise the model's parameters (ps) and state (st).
ps, st = Lux.setup(rng, model)
```
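It can help to look at the shapes of the parameters that `Lux.setup` returns. A minimal sketch, assuming that `Lux.Chain` names its unnamed layers `layer_1`, `layer_2`, `layer_3`, and that each `Dense` stores its weight as an `out_dims × in_dims` matrix (with `use_bias = false`, there is no bias entry):

```julia
using Lux, StableRNGs

rng = StableRNG(1234)
model = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false),
)
ps, st = Lux.setup(rng, model)

# ps is a NamedTuple keyed by layer name (assumed: layer_1, layer_2, layer_3).
# Each Dense weight is assumed to have size (out_dims, in_dims).
for name in keys(ps)
    println(name, " => ", size(ps[name].weight))
end
```

With these assumptions the first layer's weight is `3 × 1`, so whatever is passed in as `x` must have 1 row.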
However, things go wrong at the next step. I try:
```julia
x = reshape(xs, 1, length(xs))
y, st = Lux.apply(model, xs, ps, st)
```
but get the error:
```
ERROR: DimensionMismatch: matrix A has axes (Base.OneTo(3),Base.OneTo(1)), matrix B has axes (Base.OneTo(11),Base.OneTo(1))
```
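The same kind of error can be reproduced with plain arrays. Note that the `Lux.apply` call above passes `xs` (an 11-element `Vector`) rather than the reshaped `x`; Julia's `*` treats a vector as an `11 × 1` column, so the first layer ends up multiplying a `3 × 1` weight by an `11 × 1` array. A minimal sketch, where the random matrix `W` is a hypothetical stand-in for the first layer's weight:

```julia
# Hypothetical stand-in for the first layer's weight: out_dims × in_dims = 3 × 1.
W = rand(3, 1)
xs = collect(0.0:0.1:1.0)   # 11-element Vector, treated as an 11 × 1 column by *

# 3 × 1 times 11 × 1: the inner dimensions (1 and 11) do not match.
err = try
    W * xs
    nothing
catch e
    e
end
println(typeof(err))        # DimensionMismatch

# Reshaping to 1 × 11 (features × samples) makes the product well defined.
x = reshape(xs, 1, length(xs))
println(size(W * x))        # (3, 11)
```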
What should I do here? My impression is that the first dimension of `x` should be the number of inputs to the network (in my case 1). What the second dimension should be is never really stated. My best guess was that, since I am training on 11 data points, it should be 11. This is obviously wrong.
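For reference, the impression above matches the convention used by Lux's `Dense` layers: the input is a matrix of size `in_dims × batch`, with one sample per column, and the output has size `out_dims × batch`. A minimal sketch of the call with the reshaped `x` (instead of `xs`), assuming the same model and data as above:

```julia
using Lux, StableRNGs

rng = StableRNG(1234)
model = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false),
)
ps, st = Lux.setup(rng, model)

xs = collect(0.0:0.1:1.0)
# Rows are input features (1 here), columns are samples (11 here).
x = reshape(xs, 1, length(xs))

# Pass the reshaped matrix x, not the vector xs.
y, st = Lux.apply(model, x, ps, st)
println(size(y))   # (1, 11): one output per sample
```

When a loss is computed later, the targets `ys` would presumably need the same `1 × 11` shape so that they line up element-wise with `y`.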
In addition to reading the tutorial, I tried `?Lux.apply`, but the docstring just calls `x` "the input". I also looked through the manual a bit, but didn't really find anything.
What exactly should `x` be here?