Accessing a specific layer's weights in a Flux Chain

In PyTorch, I can define a network like this:

from torch import nn

class Network(nn.Module):
    def __init__(self, ...):
        ...
        self.ln_1 = nn.Linear(64, 32)
        self.ln_2 = nn.Linear(32, 16)
        ...

The named_parameters method (Module, PyTorch 1.12 documentation) lets you iterate through the parameters in a network and access them (and their gradients) by name (e.g., ln_1.weight).

Does Flux have similar functionality? MWE:

using Flux 
network = Chain(Dense(64, 32, tanh), Dense(32, 16, tanh))
ps = Flux.params(network)
point = ... # data point
criterion = ... # loss function
gs = Flux.gradient(ps) do
    loss = criterion(point...)
    return loss
end

Then I can access the gradient of the first layer’s weights with gs[network[1].weight], but this is perhaps less interpretable than in PyTorch, since the keys of gs are arrays, not strings.

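This is exactly what explicit parameters were designed for. Differentiate with respect to the model itself rather than Flux.params, and the gradient comes back as a nested structure that mirrors the model: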

...
criterion(m, ...) = ... # loss function
gs = Flux.gradient(network) do m
    loss = criterion(m, point...)
    return loss
end

# gs is a tuple with one entry per differentiated argument; gs[1] mirrors the Chain
size(network[1].weight) == size(gs[1].layers[1].weight) # true
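
For name-based access closer to PyTorch's, one option is to give the layers names when building the Chain. The sketch below assumes a recent Flux release (0.13 or later) that accepts keyword arguments in Chain and the `in => out` form of Dense. The names ln_1 and ln_2 are borrowed from the PyTorch example, and the input, target, and mse loss are placeholders standing in for the elided data point and criterion, not the original poster's.

```julia
using Flux

# Named layers: the keywords become the field names of network.layers.
network = Chain(ln_1 = Dense(64 => 32, tanh), ln_2 = Dense(32 => 16, tanh))

# Hypothetical data point and loss, for illustration only.
x = rand(Float32, 64)
y = rand(Float32, 16)
criterion(m, x, y) = Flux.Losses.mse(m(x), y)

# Explicit-style gradient: differentiate with respect to the model itself.
gs = Flux.gradient(m -> criterion(m, x, y), network)

# gs is a 1-tuple; gs[1] is a nested NamedTuple with the same layout as the Chain.
g_ln_1 = gs[1].layers.ln_1

# Gradients and parameters can now be looked up by the same layer name.
size(g_ln_1.weight) == size(network[:ln_1].weight)
```

Here gs[1].layers.ln_1.weight and gs[1].layers.ln_1.bias play roughly the role of the .grad entries reached through named_parameters in PyTorch, while network[:ln_1] retrieves the layer itself.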