Hello, folks.
I built a simple classification neural net in Flux (a `Dense` layer taking 6 inputs, a `BatchNorm` hidden layer with 13 neurons, and another `Dense` output layer with 13 neurons), and I'm looking for a way to relate the outputs back to the inputs. For instance, based on the weights Flux has learned, which input neurons are the most important?
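In code, the model looks roughly like this (a simplified sketch; the activations and the `Dense(in => out)` syntax are just for illustration):

```julia
using Flux

# Rough shape of the model: 6 input features, a 13-neuron hidden layer
# with BatchNorm, 13 output classes.
model = Chain(
    Dense(6 => 13),        # input -> hidden weights
    BatchNorm(13, relu),   # hidden normalization (+ activation)
    Dense(13 => 13),       # hidden -> output weights
    softmax,
)
```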
To make it concrete: say one of the outputs is “cat”, another is “mouse”, and the inputs are height, number of legs, etc. By looking at the weights, maybe I could figure out that the most important factor distinguishing a “cat” from a “mouse” is the height. So I’m looking for an elegant way to plot that.
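One naive idea I've been toying with (not sure it's sound) is to ignore the nonlinearity, fold the `BatchNorm` scaling into the two weight matrices, and read a rough class-vs-input importance off the resulting matrix:

```julia
using LinearAlgebra

# Per-neuron scaling applied by BatchNorm (γ over the running std).
bn = model[2]
scale = bn.γ ./ sqrt.(bn.σ² .+ bn.ϵ)

# Approximate linear map from the 6 inputs to the 13 outputs:
# (13×13) * (13×13 diagonal) * (13×6)  =>  13×6
W_eff = model[3].weight * Diagonal(scale) * model[1].weight

# W_eff[i, j] then roughly says how much input j pushes output class i,
# e.g. argmax(abs.(W_eff[cat_idx, :] .- W_eff[mouse_idx, :])) for the
# feature that best separates "cat" from "mouse" (first-order only).
```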
I noticed that for the `Dense` layers I can use `layer.weight`, but I’m not sure what the equivalent is for the `BatchNorm` layer. Also, the weight matrices are large, so I’m looking for a clever way to view all the weights for a given input at once. Maybe a graph where the lines connecting the neurons are colored by their weights?
Thanks a lot!