I’m implementing an Evidential Deep Learning model in Flux, and I would like to have a function `predict(model, x)`, where `x` is the input data. The model is a `Chain`, which may consist of many layers. I want `predict` to change its behavior depending on the type of the last layer in the `Chain`. Is that possible?
Alternatively, I could create a struct per model type that contains the `Chain` and a type, and then dispatch on that type. But it seems like there’s probably a better design.
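The wrapper alternative does not need a separate struct per model kind; the kind can be a type parameter of a single wrapper, and `predict` dispatches on it. This is a minimal sketch, assuming hypothetical tag types `KindA`/`KindB` and a hypothetical `EvidentialModel` wrapper; the per-kind post-processing (`softplus` for A, identity for B) is only a placeholder for the real evidential output transformation.

```julia
using Flux

# Hypothetical tag types naming the kind of evidential model.
abstract type ModelKind end
struct KindA <: ModelKind end
struct KindB <: ModelKind end

# Hypothetical wrapper: holds the Chain and carries the kind as a type parameter K,
# so dispatch happens on the wrapper's type instead of on the last layer.
struct EvidentialModel{K<:ModelKind, C}
    chain::C
end
EvidentialModel(::K, chain::C) where {K<:ModelKind, C} = EvidentialModel{K, C}(chain)

# Forward pass simply delegates to the wrapped Chain.
(m::EvidentialModel)(x) = m.chain(x)

# Placeholder per-kind post-processing of the network output.
predict(m::EvidentialModel{KindA}, x) = softplus.(m(x))  # non-negative output
predict(m::EvidentialModel{KindB}, x) = m(x)             # raw output

chain = Chain(Dense(10 => 100, tanh), Dense(100 => 5))
mA = EvidentialModel(KindA(), chain)
mB = EvidentialModel(KindB(), chain)
x = rand(Float32, 10, 3)
yA = predict(mA, x)   # 5×3 matrix, all entries ≥ 0
yB = predict(mB, x)   # same as chain(x)
```

This keeps the kind explicit even when two kinds happen to share the same output layer. For training, the wrapper would also have to be registered with Flux so that its parameters are found (`Flux.@layer` in recent Flux versions, `Flux.@functor` in older ones).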
Right now I simply have:

```julia
using Flux

# MyLayerA and MyLayerB are my custom output layers (definitions omitted)
m1 = Chain(Dense(10 => 100, tanh), MyLayerA(100 => 5))
m2 = Chain(Dense(10 => 100, tanh), MyLayerB(100 => 5))
predict(m, x) = m(x)
```
As you can see, `predict` does not differentiate between the two models. Another option would be an `if` statement in `predict` based on `typeof(m.layers[end])`, but that also seems rather hacky.
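One way to avoid the `if` is to let multiple dispatch do the branching: `predict` splits the `Chain` into all-but-the-last layer and the last layer, then forwards to a helper method chosen by the last layer's type. The sketch below relies on Flux's `Chain` indexing (`m[end]` returns the last layer, `m[1:end-1]` returns a `Chain` of the preceding layers). The helper name `predict_head`, the placeholder definitions of `MyLayerA`/`MyLayerB` (each just wraps a `Dense`), and the per-layer post-processing are hypothetical stand-ins, not real evidential layers.

```julia
using Flux

# Hypothetical stand-ins for the custom output layers; each just wraps a Dense.
struct MyLayerA{D}
    dense::D
end
MyLayerA(p::Pair) = MyLayerA(Dense(p))
(l::MyLayerA)(x) = l.dense(x)

struct MyLayerB{D}
    dense::D
end
MyLayerB(p::Pair) = MyLayerB(Dense(p))
(l::MyLayerB)(x) = l.dense(x)

# Run every layer except the last, then dispatch on the type of the last layer.
# m[1:end-1] wraps the same layers in a new Chain (no weights are copied).
predict(m::Chain, x) = predict_head(m[end], m[1:end-1](x))

# Hypothetical per-head behavior; replace with the real evidential post-processing.
predict_head(head::MyLayerA, h) = softplus.(head(h))  # non-negative output
predict_head(head::MyLayerB, h) = head(h)             # raw output
# Fallback for any other final layer: just apply it.
predict_head(head, h) = head(h)

m1 = Chain(Dense(10 => 100, tanh), MyLayerA(100 => 5))
m2 = Chain(Dense(10 => 100, tanh), MyLayerB(100 => 5))
x = rand(Float32, 10, 3)
y1 = predict(m1, x)   # 5×3, all entries ≥ 0
y2 = predict(m2, x)   # same as m2(x)
```

With this design, supporting a new output layer only requires a new `predict_head` method, and the fallback keeps `predict` working for ordinary `Chain`s.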