How to use Lux's `apply` function?

I am trying to get Lux to work, but am stumbling at a relatively early stage of the tutorials. I am following a combination of the quickstart and getting started pages (Getting Started | Lux.jl Docs).

Here, I am trying to simply fit a neural network to match data from a simple R1 → R1 function:

# Fetch packages.
using Lux, Random, Optimisers, Zygote, StableRNGs, Plots

# Create training data and true function.
f(x) = 0.5 * (x^3)/(x^3 + 0.7^3)
xs = collect(0.0:0.1:1.0)
ys = [(0.9 + 0.2rand())*f(x) for x in xs]

Following the tutorial I get to

# Set RNG.
rng = StableRNG(1234)

# Create Neural network.
model = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
)

# Prepares the model for training.
ps, st = Lux.setup(rng, model)

However, things go wrong at the next step. I try:

x = reshape(xs, 1, length(xs))
y, st = Lux.apply(model, xs, ps, st)

but get an error

ERROR: DimensionMismatch: matrix A has axes (Base.OneTo(3),Base.OneTo(1)), matrix B has axes (Base.OneTo(11),Base.OneTo(1))

What should I do here? My impression is that the first dimension of x should be the number of inputs to the network (in my case 1). What the second dimension represents is never really stated. My best guess was that, since I am training on 11 data points, it should be 11. This is obviously wrong.

In addition to reading the tutorial, I tried

?Lux.apply

but it just calls x "the input". I also tried looking through the manual, but didn’t really find anything.

Exactly what should x be here?

I think you meant x instead of xs in the second argument on the second line.
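
For what it's worth, your guess about the shape looks right to me: Lux's Dense layers take a matrix as (features, batch), so 1×11 is what you want here. When the plain vector xs is passed instead, it is read as a single sample with 11 features, which does not fit the 3×1 weight matrix of the first layer, hence the DimensionMismatch. Below is a minimal sketch of the corrected call using your model. The Float32 conversion and the use of Random.default_rng() are my own additions (to match Lux's default parameter type and to drop the StableRNGs dependency), not something from your post.

```julia
using Lux, Random

# Assumption: any RNG works for the shape check, so the default one is used here.
rng = Random.default_rng()

model = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
)
ps, st = Lux.setup(rng, model)

# 11 scalar data points, converted to Float32 (my addition, see above).
xs = Float32.(collect(0.0:0.1:1.0))

# Reshape to (features, batch) = (1, 11) and pass x, not xs.
x = reshape(xs, 1, length(xs))
y, st = Lux.apply(model, x, ps, st)

size(y)  # (1, 11): one output per data point
```

The output y has the same layout as the input, one column per data point, so vec(y) gives predictions that line up with ys.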

Brilliant!
(and very embarrassing)

Thanks a lot, works well :slight_smile: