Show neural net weights (Flux)

Hello, folks.
I built a simple classification neural net in Flux (a Dense layer with 6 inputs, a BatchNorm hidden layer with 13 neurons, and another Dense layer with 13 output neurons), and I'm looking for a way to relate the outputs back to the inputs. For instance, based on the weights Flux has learned, which input neurons matter most?
To make it concrete: say one of the outputs is “cat”, another is “mouse”, and the inputs are height, number of legs, etc. By looking at the weights, maybe I can figure out that the most important factor distinguishing a “cat” from a “mouse” is height. I'm looking for an elegant way to plot that.

I noticed that for the Dense layers I can use layer.weight, but I'm not sure what the equivalent is for the BatchNorm layer. Also, the weight matrices are large, so I'm looking for a convenient way to inspect all the weights for a given input at once. Maybe a graph where the edges connecting the neurons are colored by their weights?
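For concreteness, here's roughly what I mean: a minimal sketch with an untrained placeholder model, using the field names from a recent Flux release (older versions used W and b instead of weight and bias).

```julia
using Flux
using Plots  # only for the heatmap at the end

# Placeholder model matching the shape described above.
model = Chain(Dense(6 => 13), BatchNorm(13), Dense(13 => 13))

W1 = model[1].weight   # 13×6: rows = hidden neurons, columns = inputs
b1 = model[1].bias

# BatchNorm has no input weights: its learnable parameters are a
# per-neuron scale γ and shift β (plus running statistics μ and σ²).
γ, β = model[2].γ, model[2].β

W2 = model[3].weight   # 13×13: rows = output neurons, columns = hidden neurons

# Quick overview of a whole weight matrix as a heatmap.
heatmap(W1, xlabel = "input", ylabel = "hidden neuron", title = "first Dense layer weights")
```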
Thanks a lot!

Hi @Ribeiro!
Not sure I can help a lot, but for what it's worth: before being a Flux.jl question, this is a math research question. A keyword you might search for is “feature importance”. I'm not sure you'll find much that is ready to use, but there are a few packages that attempt it, such as Duff.jl, ShapML.jl, and KoalaTrees.jl.
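To give an idea of the simplest version of the concept: permutation importance shuffles one input feature at a time and checks how much the loss degrades. A rough sketch with a placeholder model, data, and loss (the packages above do this kind of thing more carefully):

```julia
using Flux, Random

# Placeholder model and data, just to illustrate the idea.
model = Chain(Dense(6 => 13, relu), BatchNorm(13), Dense(13 => 13), softmax)
X = rand(Float32, 6, 200)                    # 6 features × 200 observations
y = Flux.onehotbatch(rand(1:13, 200), 1:13)  # dummy labels

loss(m, x, y) = Flux.crossentropy(m(x), y)
baseline = loss(model, X, y)

# Shuffle one feature at a time across the dataset and see how much
# the loss degrades; a bigger increase means a more important feature.
importance = map(1:size(X, 1)) do i
    Xp = copy(X)
    Xp[i, :] = Xp[i, shuffle(1:size(X, 2))]
    loss(model, Xp, y) - baseline
end
```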

Thank you @gdalle. I did not know the term “feature importance”, so that alone is already helpful!
It seems Duff.jl might be able to do what I want, and maybe ShapML.jl too (the latter's documentation has no Flux examples, but it says the package is model-agnostic). There's also KoalaTrees.jl, but that one seems very outdated in terms of supported Julia versions.
If anyone has a working example with a Flux NN, that’d be much appreciated. Otherwise, I’ll dig into these two.
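In the meantime, going by the ShapML.jl README, I'd expect wiring it up to a Flux model to look roughly like this. This is untested on my side: the keyword names and the y_pred column come from the README example, and the model, data, and feature names are placeholders.

```julia
using Flux, ShapML, DataFrames, Statistics

# Placeholder model and data in the shape I described above.
feature_names = ["height", "n_legs", "tail_length", "ear_size", "whisker_length", "body_mass"]
X = rand(Float32, 6, 200)                        # 6 features × 200 observations
model = Chain(Dense(6 => 13, relu), BatchNorm(13), Dense(13 => 13), softmax)

Xdf = DataFrame(permutedims(X), feature_names)   # ShapML wants observations as DataFrame rows

class_of_interest = 1                            # e.g. the “cat” output neuron

# ShapML calls this with (model, data::DataFrame) and expects a one-column DataFrame back.
function predict_function(model, data)
    xs = permutedims(Matrix{Float32}(data))      # back to features × observations for Flux
    return DataFrame(y_pred = vec(model(xs)[class_of_interest, :]))
end

data_shap = ShapML.shap(explain = Xdf,
                        reference = Xdf,
                        model = model,
                        predict_function = predict_function,
                        sample_size = 60,        # Monte Carlo samples per observation
                        seed = 1)

# Mean absolute Shapley effect per feature as a global importance score.
importance = combine(groupby(data_shap, :feature_name),
                     :shap_effect => (x -> mean(abs.(x))) => :mean_abs_effect)
sort!(importance, :mean_abs_effect, rev = true)
```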
Thanks again!
