Hello, folks.
I built a simple classification neural net in Flux (a Dense layer taking 6 inputs, a BatchNorm hidden layer with 13 neurons, and a Dense output layer with 13 neurons), and I'm looking for a way to relate the outputs to the inputs. In particular, based on the weights Flux has learned, which input neurons are the most important?
For example, say one of the outputs is “cat”, another is “mouse”, and the inputs are height, number of legs, etc. By looking at the weights, maybe I could figure out that height is the most important factor distinguishing a “cat” from a “mouse”. I'm looking for an elegant way to plot that.
I noticed that for the Dense layers I can use layer.weight, but I’m not sure what the equivalent is for the BatchNorm layer. Also, the weight matrices are large, so I'd like an easy way to see all the weights connected to a given input at a glance. Maybe a network graph where the lines connecting the neurons are colored by their weights?
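A minimal sketch of one way to collapse the three layers into a single outputs-by-inputs table, assuming inference mode and ignoring all activation functions and biases, so it is only a rough linear approximation unless every activation is the identity. BatchNorm has no weight matrix: per hidden neuron it stores a scale γ, a shift β and running statistics μ and σ², so its slope is γ ./ sqrt.(σ² .+ ϵ). The helper names `bn_scale`, `effective_map`, `contrast` and `ranked_inputs` are hypothetical, and the random weights are stand-ins for a trained model.

```julia
using LinearAlgebra

# Per-neuron slope of a BatchNorm layer in inference mode.
# BatchNorm computes y = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β, so the shifts μ and β
# do not affect the slope; only γ ./ sqrt.(σ² .+ ϵ) does.
bn_scale(γ::AbstractVector, σ²::AbstractVector, ϵ::Real) = γ ./ sqrt.(σ² .+ ϵ)

# Effective linear map from inputs to outputs for Dense -> BatchNorm -> Dense,
# ignoring activation functions and biases (exact only if every activation is identity).
# W1 is hidden×inputs, s has one entry per hidden neuron, W2 is outputs×hidden.
effective_map(W1::AbstractMatrix, s::AbstractVector, W2::AbstractMatrix) =
    W2 * Diagonal(s) * W1

# How much each input pushes output i up relative to output j (e.g. "cat" vs "mouse"),
# under the same linear approximation. Returns one number per input.
contrast(M::AbstractMatrix, i::Integer, j::Integer) = M[i, :] .- M[j, :]

# Input indices ordered from most to least influential for the i-vs-j contrast.
ranked_inputs(M::AbstractMatrix, i::Integer, j::Integer) =
    sortperm(abs.(contrast(M, i, j)); rev = true)

# Toy example with random weights of the sizes described in the post (6 -> 13 -> 13).
W1 = randn(13, 6)
s  = bn_scale(ones(13), ones(13), 1e-5)
W2 = randn(13, 13)
M  = effective_map(W1, s, W2)
println(size(M))               # (13, 6): one row per output, one column per input
println(ranked_inputs(M, 1, 2))
```

With a trained model built as something like `m = Chain(Dense(6 => 13), BatchNorm(13), Dense(13 => 13))`, the arrays would presumably come from `m[1].weight`, `m[2].γ`, `m[2].σ²`, `m[2].ϵ` and `m[3].weight` (field names as in recent Flux versions). A heatmap of `M`, for example with Plots.jl's `heatmap(M)`, shows every input-output pair at once and may be easier to read than a colored network graph with 6 × 13 × 13 connections.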
Thanks a lot!