Using `Lux.WrappedFunction` for pre/post processing in a Lux model

Hi!

I recently learned about the Lux.WrappedFunction() layer (docs), which allows prescribing a parameterless function inside a Chain. I think this is a great feature for building preprocessing and postprocessing into the chain itself, so I started using it. For example, I define a model as

# Define normalization function for preprocessing
function normalize(
    X;
    lims::Tuple{F, F},
    method::Symbol = :shift
) where {F <: AbstractFloat}

    if method == :shift
        # Map values in [lims[1], lims[2]] to [-0.5, 0.5]
        return (X .- lims[1]) ./ (lims[2] - lims[1]) .- 0.5
    else
        throw(ArgumentError("Normalization method not implemented."))
    end
end

architecture = Lux.Chain(
    # Preprocessing step
    WrappedFunction(
        x -> [
            # TODO: problem coming from here!!!
            normalize(x[1]; lims = (0.0, 500.0)),
            normalize(x[2]; lims = (0.0, 0.6))
            ]
        ),
    Dense(2, 5, gelu),
    Dense(5, 8, gelu),
    Dense(8, 3, gelu),
    Dense(3, 1, sigmoid),
    # Postprocessing step
    WrappedFunction(y -> 10.0 .* y)
)
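
For reference, here is roughly how I evaluate the chain on a single, non-batched input (a minimal sketch; the input values are just placeholders):

using Lux, Random

rng = Random.default_rng()
ps, st = Lux.setup(rng, architecture)

# A single sample: two features as a plain vector
x = [250.0, 0.3]
y, st = architecture(x, ps, st)  # forward pass returns (output, updated states)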

This worked for the use cases I had, where each evaluation of the neural network was done on a single input at a time. However, I recently noticed that this does not work with batched inputs. To illustrate this, I grabbed the Linear Regression example from the Julia and Lux for the Uninitiated documentation and adapted it to have a WrappedFunction in the chain. Here is the MWE:

using Lux, Random
using ComponentArrays, Optimisers, Printf

model = Chain(
    WrappedFunction(x -> [x[1], x[2]]),
    Dense(2 => 1),
    Dense(1, 1, sigmoid),
    WrappedFunction(y -> 1.1 .* y)
)

rng = Random.default_rng()
Random.seed!(rng, 0)

ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

n_samples = 20
x_dim = 2
y_dim = 1

ft = Float64

W = randn(rng, ft, y_dim, x_dim)
b = randn(rng, ft, y_dim)

x_samples = randn(rng, ft, x_dim, n_samples)
y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, ft, y_dim, n_samples)
println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))

lossfn = MSELoss()

println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))

function train_model!(model, ps, st, opt, nepochs::Int)
    tstate = Training.TrainState(model, ps, st, opt)
    for i in 1:nepochs
        grads, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), lossfn, (x_samples, y_samples), tstate
        )
        if i % 1000 == 1 || i == nepochs
            @printf "Loss Value after %6d iterations: %.8f\n" i loss
        end
    end
    return tstate.model, tstate.parameters, tstate.states
end

model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)

println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))

The error comes from how the forward pass of the NN is evaluated: it returns a single response y rather than one per sample point:

ERROR: LoadError: DimensionMismatch: loss function expects size(ŷ) = (1,) to match size(y) = (1, 20)

The two questions I want to raise are:

  1. Is there a simple way to fix this? I imagine that changing how the wrapped function is written may be the key.
  2. Is this the right approach if I want to do some custom preprocessing/postprocessing on top of my chain? I find it very elegant that I can do this inside the definition of the chain, but maybe it is not really a good idea.

@avikpal do you have any comments on this matter?

Thanks!

model = Chain(
    WrappedFunction(x -> [x[1], x[2]]),
    Dense(2 => 1),
    Dense(1, 1, sigmoid),
    WrappedFunction(y -> 1.1 .* y)
)

This is taking a batch of inputs of size (2, 20), throwing away 38 elements, and keeping only the first 2. Your preprocess/postprocess steps need to ensure that they can handle batching.
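
For instance, a batch-aware version of your normalization preprocessing could operate on whole feature rows instead of indexing single elements (a rough, untested sketch reusing your normalize function):

# x is (features, batch), e.g. (2, 20): normalize each feature row as a whole
WrappedFunction(
    x -> vcat(
        normalize(x[1:1, :]; lims = (0.0, 500.0)),
        normalize(x[2:2, :]; lims = (0.0, 0.6))
    )
)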

The error comes from how the forward pass of the NN is evaluated: it returns a single response y rather than one per sample point.

Right, because in this case MSE or any other elementwise loss function isn't well defined. You can define your own custom loss function (see Utilities | Lux.jl Docs, point 2, for how it needs to be specified).
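
Roughly something like this (an untested sketch; custom_loss is just a placeholder name, and the loss function takes (model, ps, st, data) and returns (loss, updated_states, stats)):

function custom_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    # decide here how the single output ŷ should be compared against the batch y
    return sum(abs2, ŷ .- y) / length(y), st_, NamedTuple()
end

# then pass it in place of lossfn:
# Training.single_train_step!(AutoZygote(), custom_loss, (x_samples, y_samples), tstate)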

Hi @avikpal , thank you for your response.

This is taking a batch of inputs of size (2, 20), throwing away 38 elements, and keeping only the first 2. Your preprocess/postprocess steps need to ensure that they can handle batching.

I was aware that that is what was happening, but is there a simple way to fix it so that it works in both cases?

Right, because in this case MSE or any other elementwise loss function isn't well defined. You can define your own custom loss function (see Utilities | Lux.jl Docs, point 2, for how it needs to be specified).

While I certainly can define a custom loss function for this, I think a better solution is to use the ones already defined in Lux and instead solve the problem in the chain definition, right? I would like this to work in both cases.

Again, I think this would be a very elegant solution if it can be done easily with Lux without further customization. If not, it may be better to just define the preprocessing/postprocessing functions outside the chain, but it would be very elegant to do it with Lux :wink:

Instead of this, rewrite it as x -> vcat(x[1:1, :], x[2:2, :])
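
On a batch this keeps the (features, batch) layout instead of collapsing it, for example:

x = randn(2, 20)
[x[1], x[2]]                    # 2-element Vector (batch information lost)
vcat(x[1:1, :], x[2:2, :])      # 2×20 Matrix (batch preserved)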


Thank you @avikpal for the answer!

Your fix does make the code run and the loss is evaluated, but I am not sure it is doing the right job. Right now I have

model = Chain(
    WrappedFunction(x -> vcat(x[1:1, :], x[2:2, :])),
    Dense(2 => 1),
    Dense(1, 1, sigmoid),
    WrappedFunction(y -> 1.1 .* y)
)

as you suggested; however, when I evaluate the neural network it seems to be evaluating just the first data point:

smodel = StatefulLuxLayer{true}(model, ps, st)

@show smodel(x_samples[:, 1])
# 1-element Vector{Float64}:
# 5.351669698013934
@show smodel(x_samples)
# 1-element Vector{Float64}:
# 5.351669698013934

It seems to be evaluating only the first sample point… maybe the loss function internally runs this for each sample point, and the call I made with smodel is what is wrong?

I managed to get something working for those cases with this:

# Vector / column-view input: scale a single 2-element sample
function wrapped(v::Union{Vector, SubArray})
    @assert length(v) == 2
    return 3.0 .* v
end

# Matrix input: apply the vector method column by column
function wrapped(V::Matrix)
    @assert size(V, 1) == 2
    return reduce(hcat, map(wrapped, eachcol(V)))
end

model = Chain(
    # WrappedFunction(x -> vcat(x[1:1, :], x[2:2, :])),
    WrappedFunction(x -> wrapped(x)),
    Dense(2 => 1),
    Dense(1, 1, sigmoid),
    WrappedFunction(y -> 1.1 .* y)
)

However, I am not sure this is the ideal solution, and I am a bit worried it may introduce some problems in my code…
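
(For what it's worth, since this particular scaling is elementwise anyway, I suppose a single array-based method would also cover both cases, e.g. something like

function wrapped(v::AbstractVecOrMat)
    @assert size(v, 1) == 2
    return 3.0 .* v
end

but I have not tested this carefully.)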