I built a simple classification neural net in Flux (6 inputs into a `Dense` layer, a `BatchNorm` hidden layer with 13 neurons, and another `Dense` output layer with 13 neurons), and I was looking for some way to correlate the outputs with the inputs. For instance, based on the weights learned by Flux, which input neurons are the most important?
For example, let’s say one of the outputs is “cat”, another is “mouse”, and the inputs are height, number of legs, etc. By looking at the weights, maybe I can figure out that the most important distinguishing factor between a “cat” and a “mouse” is height. So I’m looking for an elegant way to plot that.
I noticed that for the `Dense` layers I can use `layer.weight`, but I’m not sure what to use for the `BatchNorm` layer. Also, the weight matrices are large, so I’m looking for a clever way to see all the weights for a given input at once. Maybe a graph where the lines connecting the neurons are colored by their weights?
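One possible starting point is to fold the three layers into a single input-to-output matrix. `BatchNorm` has no weight matrix: in Flux it stores per-neuron parameters `γ` and `β` plus running statistics `μ` and `σ²`, and in test mode it only rescales and shifts each neuron. The sketch below assumes the model is `Chain(Dense(6 => 13), BatchNorm(13), Dense(13 => 13))` with Flux's default identity activations (Flux ≥ 0.13 for the `=>` syntax). `effective_weights` and `contrast_ranking` are hypothetical helper names, and using output indices 1 and 2 for “cat” and “mouse” is made up for illustration. If the real layers use nonlinear activations, the product is only a rough linearized heuristic, not an exact importance measure.

```julia
using Flux, LinearAlgebra

# Hypothetical model matching the description (assumption, not the actual net):
# 6 inputs -> Dense -> BatchNorm(13) -> Dense -> 13 outputs, identity activations.
model = Chain(Dense(6 => 13), BatchNorm(13), Dense(13 => 13))

# Collapse the chain into one (outputs × inputs) matrix.
# In test mode BatchNorm computes y = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β,
# i.e. a per-neuron scale plus a shift. With identity activations the whole
# chain is therefore affine and this product is its exact weight matrix;
# with nonlinear activations it is only a rough heuristic.
function effective_weights(m::Chain)
    d1, bn, d2 = m[1], m[2], m[3]
    scale = bn.γ ./ sqrt.(bn.σ² .+ bn.ϵ)            # BatchNorm's per-neuron scaling
    return d2.weight * Diagonal(scale) * d1.weight   # size (13, 6)
end

# Rank inputs by how strongly they push output `a` up relative to output `b`.
# Comparing magnitudes across inputs assumes the inputs are on comparable
# scales (e.g. standardized); otherwise the ranking mostly reflects units.
function contrast_ranking(W::AbstractMatrix, a::Int, b::Int)
    delta = W[a, :] .- W[b, :]
    return sortperm(abs.(delta); rev = true), delta
end

Flux.testmode!(model)                       # use BatchNorm's running statistics
W = effective_weights(model)
order, delta = contrast_ranking(W, 1, 2)    # 1 = "cat", 2 = "mouse" (hypothetical)
```

Each row of `W` is an output and each column an input, so a heatmap of `abs.(W)` (for example with Plots.jl's `heatmap`) shows every input-output strength at once, which may be easier to read than a network diagram with colored edges.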
Thanks a lot!