In PyTorch, I can define a network like this:
```python
from torch import nn

class Network(nn.Module):
    def __init__(self, ...):
        ...
        self.ln_1 = nn.Linear(64, 32)
        self.ln_2 = nn.Linear(32, 16)
        ...
```
The `named_parameters` method (Module — PyTorch 1.9.0 documentation) lets you iterate through the parameters of a network and access them (and their gradients) by name (e.g., `ln_1.weight`).
Does Flux have similar functionality? MWE:
```julia
using Flux

network = Chain(Dense(64, 32, tanh), Dense(32, 16, tanh))
ps = Flux.params(network)
point = ...      # data point
criterion = ...  # loss function

gs = Flux.gradient(ps) do
    criterion(point...)
end
```
Then I can access the first layer's gradients with `gs[network[1].weight]`, but this is maybe a little less interpretable than in PyTorch, since `gs`'s keys are arrays, not strings.
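For what it's worth, one way to approximate `named_parameters` is to build a name-to-gradient mapping yourself, since `Chain` supports indexing and `Dense` exposes `weight`/`bias` fields (in Flux ≥ 0.12; older versions used `W`/`b`). A minimal sketch, with a dummy data point and `mse` loss standing in for the elided `point`/`criterion`:

```julia
using Flux

# Same network as in the MWE
network = Chain(Dense(64, 32, tanh), Dense(32, 16, tanh))
ps = Flux.params(network)

# Dummy data point and loss (assumptions, not from the original post)
x = randn(Float32, 64)
y = randn(Float32, 16)
criterion(x, y) = Flux.mse(network(x), y)

gs = Flux.gradient(() -> criterion(x, y), ps)

# First layer's weight gradient, accessed by indexing into the Chain
gW1 = gs[network[1].weight]

# A rough analogue of named_parameters: string name => gradient pairs
named = Dict("layer$(i).$(f)" => gs[getfield(layer, f)]
             for (i, layer) in enumerate(network.layers),
                 f in (:weight, :bias))
```

This at least recovers string keys like `"layer1.weight"`, though it hard-codes the field names rather than discovering them the way PyTorch does.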