Accessing a specific layer's weights in a Flux Chain

In PyTorch, I can define a network like this:

from torch import nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.ln_1 = nn.Linear(64, 32)
        self.ln_2 = nn.Linear(32, 16)

The named_parameters method (Module — PyTorch 1.9.0 documentation) lets you iterate over a network's parameters and access them (and their gradients) by name (e.g., ln_1.weight).
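For reference, here is a minimal runnable sketch of that pattern, using the same layer names and sizes as the snippet above:

```python
import torch
from torch import nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = nn.Linear(64, 32)
        self.ln_2 = nn.Linear(32, 16)

net = Network()
# Parameters are addressable by dotted string name, e.g. "ln_1.weight":
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# ln_1.weight (32, 64)
# ln_1.bias (32,)
# ln_2.weight (16, 32)
# ln_2.bias (16,)
```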

Does Flux have similar functionality? MWE:

using Flux
network = Chain(Dense(64, 32, tanh), Dense(32, 16, tanh))
ps = Flux.params(network)
point = ... # data point
criterion = ... # loss function
gs = Flux.gradient(ps) do
    loss = criterion(point...)
    return loss
end

Then I can access the first layer's gradient with gs[network[1].weight], but this is maybe a little less interpretable than in PyTorch, since gs is keyed by the parameter arrays themselves rather than by strings.

This is exactly what explicit parameters were designed for:

criterion(m, ...) = ... # loss function
gs = Flux.gradient(network) do m
    loss = criterion(m, point...)
    return loss
end

Here gs is a tuple whose first element mirrors the model's structure (a Chain stores its layers in a layers field), so:

size(network[1].weight) == size(gs[1].layers[1].weight) # true
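If you want name-based access closer to PyTorch's named_parameters, newer Flux (0.13+) also lets you name the layers inside a Chain, and the gradient NamedTuple then mirrors those names. A sketch under that assumption (the layer names enc/dec, the dummy data, and the mse loss are my choices, not from the original post):

```julia
using Flux

# Named layers: network.enc and network.dec work like attributes.
network = Chain(enc = Dense(64 => 32, tanh), dec = Dense(32 => 16, tanh))

x = randn(Float32, 64)   # dummy data point
y = randn(Float32, 16)   # dummy target

gs = Flux.gradient(network) do m
    Flux.mse(m(x), y)
end

# Model and gradient are now addressable by layer name:
size(network.enc.weight) == size(gs[1].layers.enc.weight)  # true
```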