Dispatch on the last layer type in Flux

I’m implementing an Evidential Deep Learning model in Flux and I would like to have a function predict(model, x), where x is the input data. The model is a Chain which may consist of many layers. I want the predict function to change behavior depending on the type of the last layer in the Chain. Is that possible?

Alternatively, I could create a struct per model type that contains the Chain and a type, and then dispatch on that type. But it seems like there’s probably a better design.

Right now I simply have

m1 = Chain(Dense(10 => 100, tanh), MyLayerA(100 => 5))
m2 = Chain(Dense(10 => 100, tanh), MyLayerB(100 => 5))

predict(m, x) = m(x)

As you can see, the predict function does not differentiate between the two models. Another way could be to do an if statement in the predict function based on the type of the last layer, but that also seems rather hacky.

Probably easiest to define predict(m, x) = predict(typeof(m.layers[end]), m, x)
and write dispatches for e.g. predict(::Type{<:Dense}, m, x) = [...]
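
To make that concrete, here is a minimal sketch of how it could look. MyLayerA and MyLayerB below are hypothetical stand-ins that just wrap a Dense layer (the real layers from the question are not shown), and the tagged NamedTuple return value is only a placeholder for whatever each model type should actually do.

```julia
using Flux

# Hypothetical stand-ins for the custom layers from the question;
# each just wraps a Dense layer so that the example runs.
struct MyLayerA{D}
    dense::D
end
MyLayerA(io::Pair) = MyLayerA(Dense(io))
(l::MyLayerA)(x) = l.dense(x)

struct MyLayerB{D}
    dense::D
end
MyLayerB(io::Pair) = MyLayerB(Dense(io))
(l::MyLayerB)(x) = l.dense(x)

# Entry point: look up the type of the last layer and forward to a
# three-argument method that dispatches on it.
predict(m::Chain, x) = predict(typeof(m.layers[end]), m, x)

# One method per last-layer type. The bodies are placeholders.
predict(::Type{<:MyLayerA}, m, x) = (kind = :A, output = m(x))
predict(::Type{<:MyLayerB}, m, x) = (kind = :B, output = m(x))

m1 = Chain(Dense(10 => 100, tanh), MyLayerA(100 => 5))
m2 = Chain(Dense(10 => 100, tanh), MyLayerB(100 => 5))

x = rand(Float32, 10, 3)   # 10 features, batch of 3
predict(m1, x).kind        # :A
predict(m2, x).kind        # :B
```

The `<:` in `Type{<:MyLayerA}` matters here because the stand-in layers are parametric, so the concrete type of the last layer is something like `MyLayerA{Dense{...}}` rather than plain `MyLayerA`.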


Cool, that looks somewhat like what I was aiming for.

I’m almost certain one could make this work without the dynamic dispatch, because Chain has all the layer types as parameters in its own type. But I can’t figure out how right now.
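
One possible direction, offered as a sketch rather than a definitive answer: a Chain built from positional arguments stores its layers in a tuple, so the type of the last layer is already part of the Chain's own type. Passing the last layer itself (instead of its type) to an inner method should therefore let the compiler resolve the call statically in most cases, which can be checked with `@code_warntype predict(m1, x)`. As before, MyLayerA and MyLayerB are hypothetical stand-ins and the return values are placeholders.

```julia
using Flux

# Hypothetical stand-ins for the custom layers from the question.
struct MyLayerA{D}
    dense::D
end
MyLayerA(io::Pair) = MyLayerA(Dense(io))
(l::MyLayerA)(x) = l.dense(x)

struct MyLayerB{D}
    dense::D
end
MyLayerB(io::Pair) = MyLayerB(Dense(io))
(l::MyLayerB)(x) = l.dense(x)

# Forward the last layer itself; for a tuple of layers its concrete type
# is known from the type of the Chain.
predict(m::Chain, x) = predict(last(m.layers), m, x)

# Dispatch on the layer instance. The bodies are placeholders.
predict(::MyLayerA, m, x) = (kind = :A, output = m(x))
predict(::MyLayerB, m, x) = (kind = :B, output = m(x))

# Fallback for any other last layer: just run the model.
predict(::Any, m, x) = m(x)

m1 = Chain(Dense(10 => 100, tanh), MyLayerA(100 => 5))
m2 = Chain(Dense(10 => 100, tanh), MyLayerB(100 => 5))
m3 = Chain(Dense(10 => 5))

x = rand(Float32, 10, 3)   # 10 features, batch of 3
predict(m1, x).kind        # :A
predict(m2, x).kind        # :B
size(predict(m3, x))       # (5, 3), via the fallback method
```

Note that this only holds when the layers are stored in a tuple or NamedTuple. If the Chain is built from a Vector of layers, the element types are not known statically and the dispatch would happen at run time.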