Individual activation function for each network output

Hi,

Can I use a different activation function for each output of my network?

using Lux
net = Chain(Dense(3, 10, tanh), Dense(10, 3, tanh))

I know that output 2 can only be positive, while outputs 1 and 3 can also be negative.

Is it possible to define a particular activation function for each output of the network?
Outputs 1 and 3 → tanh
Output 2 → relu

If so, how can I do that?

You can probably do this with a custom layer that applies one activation per output row:

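using Lux

# Parameter-free layer: applies activations[i] elementwise to row i of its input.
# Subtyping is needed so Chain treats it as a Lux layer
# (the supertype is called `AbstractLuxLayer` in Lux ≥ 1.0).
struct ChannelActivations{T} <: Lux.AbstractExplicitLayer
  activations::T
end

# Collect activations passed as separate arguments into a tuple.
ChannelActivations(args...) = ChannelActivations(args)

function (a::ChannelActivations)(x::AbstractMatrix, ps, st::NamedTuple)
    # permutedims turns each activated row into a 1×batch matrix, so vcat
    # rebuilds a (outputs × batch) matrix instead of one long vector.
    y = reduce(vcat, [permutedims(f.(v)) for (f, v) in zip(a.activations, eachrow(x))])
    return y, st
end

# No trainable parameters and no state.
Lux.parameterlength(::ChannelActivations) = 0

Lux.statelength(::ChannelActivations) = 0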

# no specified activation on the second Dense makes it a linear layer.
net = Chain(Dense(3, 10, tanh), Dense(10, 3), ChannelActivations(tanh, relu, tanh))
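To check the row-wise logic on its own, here is a minimal sketch of the same forward pass in plain Julia, without Lux. The helper `rowwise_activations` and the local `relu_` (a stand-in for NNlib's `relu`) are hypothetical names used only for illustration. Note that `vcat` on the plain row vectors would return one long vector; wrapping each row in `permutedims` keeps the outputs × batch shape.

```julia
# Hypothetical helper, independent of Lux: apply activations[i] to row i of x.
rowwise_activations(activations, x::AbstractMatrix) =
    reduce(vcat, [permutedims(f.(v)) for (f, v) in zip(activations, eachrow(x))])

# Local stand-in for NNlib's relu so this runs with Base Julia only.
relu_(z) = max(zero(z), z)

# 3 outputs (rows) × 2 samples (columns)
x = [-1.0  0.5;
     -2.0  3.0;
      0.0 -0.5]

y = rowwise_activations((tanh, relu_, tanh), x)

size(y)   # (3, 2): the batch dimension is preserved
y[2, :]   # [0.0, 3.0]: relu clips the negative entry of output 2
```

Rows 1 and 3 keep their sign through `tanh`, while row 2 can never go below zero, which is the behaviour asked for above.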