Flux: gradient of Chain returns nothing

I have defined an MLP (multilayer perceptron) using `Chain`, and this MLP is wrapped in a mutable struct.
Taking the gradient with respect to this Chain (a two-layer dense network) returns `nothing`. Is this a bug, or am I doing this the wrong way?

Minimal working example (it assumes `using Flux, LinearAlgebra` and `using Flux.Optimise: update!` have already been run):

"""
    loss(filter, x, y)

Return `norm` of the residual between the target `y` and the prediction
obtained by calling `filter.f` on `x`.
"""
function loss(filter, x, y)
    prediction = filter.f(x)
    residual = y - prediction
    return norm(residual)
end

"""
    Filter{F}

Mutable wrapper with a single field `f` of type `F`; `loss` calls `f` on its
input.
"""
mutable struct Filter{F}
    f::F
end

# Model: a Chain of two Dense layers (2 -> 10 with sigmoid, then 10 -> 2).
funct = Chain(Dense(2, 10, sigmoid), Dense(10, 2))
# Wrap the model in the Filter struct; `loss` reaches it through `filt.f`.
filt = Filter(funct)
opt = ADAM()

# Single training pair: target `y` and input `x`, as used by `loss`.
y = [1.0; 1.0]
x = [0.0; 0.0]
# NOTE(review): this calls the capitalised `Flux.Params` constructor directly;
# the reply further down recommends `Flux.params(filt.f)` instead.
ps = Flux.Params(filt.f)


# Run five update steps, printing the gradient and the weights after each one.
for i= 1:5
    # Differentiate the loss with respect to the entries of `ps`.
    grads = gradient(ps) do
        loss(filt, x, y)
    end
    # Ask the optimiser to apply `grads` to `ps`.
    update!(opt, ps, grads)
    # Gradient looked up for the first entry of `ps`.
    println(grads[ps[1]])
    # `weight` field of that same first entry, followed by a blank line.
    println(ps[1].weight, "\n")
end

which prints:

nothing
Float32[-0.30972472 0.48900542; 0.40722048 -0.47656116; 0.27757087 0.12060472; 0.23041193 -0.7000659; 0.118268944 0.45202392; -0.32605898 -0.16675131; 0.63128626 -0.24441159; -0.67778873 -0.36194095; 0.14881344 -0.52690184; -0.0017206029 -0.22051147]

nothing
Float32[-0.30972472 0.48900542; 0.40722048 -0.47656116; 0.27757087 0.12060472; 0.23041193 -0.7000659; 0.118268944 0.45202392; -0.32605898 -0.16675131; 0.63128626 -0.24441159; -0.67778873 -0.36194095; 0.14881344 -0.52690184; -0.0017206029 -0.22051147]

nothing
Float32[-0.30972472 0.48900542; 0.40722048 -0.47656116; 0.27757087 0.12060472; 0.23041193 -0.7000659; 0.118268944 0.45202392; -0.32605898 -0.16675131; 0.63128626 -0.24441159; -0.67778873 -0.36194095; 0.14881344 -0.52690184; -0.0017206029 -0.22051147]

nothing
Float32[-0.30972472 0.48900542; 0.40722048 -0.47656116; 0.27757087 0.12060472; 0.23041193 -0.7000659; 0.118268944 0.45202392; -0.32605898 -0.16675131; 0.63128626 -0.24441159; -0.67778873 -0.36194095; 0.14881344 -0.52690184; -0.0017206029 -0.22051147]

nothing
Float32[-0.30972472 0.48900542; 0.40722048 -0.47656116; 0.27757087 0.12060472; 0.23041193 -0.7000659; 0.118268944 0.45202392; -0.32605898 -0.16675131; 0.63128626 -0.24441159; -0.67778873 -0.36194095; 0.14881344 -0.52690184; -0.0017206029 -0.22051147]

As a result, the weights of the NN are never updated.
Any help would be greatly appreciated.

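Change

ps = Flux.Params(filt.f)

to

ps = Flux.params(filt.f)

The lowercase `Flux.params` walks the model and collects its trainable arrays (the weights and biases of each `Dense` layer). The capitalised `Flux.Params` constructor instead registers whatever it is given: here, the `Dense` layers themselves, which is why `ps[1].weight` works in your loop. No gradient is recorded for those layer objects, so `grads[ps[1]]` is `nothing` and `update!` has nothing to apply. Note that after this change `ps[1]` is the first weight matrix itself, so print `ps[1]` rather than `ps[1].weight`.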
2 Likes