Simple custom recurrent layer in Flux

Hi,

I am learning Flux and trying to understand its recurrent functionality. The following code defines a first-order low-pass filter with Flux.Recur:

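using Flux
using Flux.Tracker

c = 0.5   # filter coefficient

# One filter step: returns (new state, output), as Flux.Recur expects
function LP(h, x)
    h = c * x + (1.0 - c) * h
    return h, h
end

h = 0.
recurLP = Flux.Recur(LP, h)   # wraps LP and keeps the state between calls

function model(x)
    out = zeros(1000,1)
    for ii = 1:1000
        out[ii] = recurLP(x[ii])
    end
    return out
end

# Sum of squared errors between filter output and target
function loss(x, y)
    return sum((model(x) .- y).^2)
end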
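For readers new to Flux.Recur, here is a Flux-free sketch of what Recur(LP, h) is assumed to do: keep an initial and a current state, and on every call run cell(state, x), store the returned state and return the output. It is modeled on the behaviour of Recur in the Tracker-era Flux used in this thread, but MiniRecur, mini_reset! and lowpass_cell are hypothetical names, not Flux API. The coefficient is passed in explicitly instead of being read from a global.

```julia
# Hypothetical stand-in for the behaviour of Flux.Recur (not Flux's source):
# it stores an initial state and a current state, and each call runs
# cell(state, x) -> (new_state, output) and keeps the new state.
mutable struct MiniRecur{F,S}
    cell::F
    init::S
    state::S
end
MiniRecur(cell, h) = MiniRecur(cell, h, h)

function (m::MiniRecur)(x)
    h, y = m.cell(m.state, x)
    m.state = h          # the state carries over to the next call
    return y
end

# Hypothetical analogue of Flux.reset!: go back to the initial state.
mini_reset!(m::MiniRecur) = (m.state = m.init; m)

# The low-pass cell from the question, with c as an explicit argument.
function lowpass_cell(c)
    return function (h, x)
        h_new = c * x + (1 - c) * h
        return h_new, h_new
    end
end

r = MiniRecur(lowpass_cell(0.5), 0.0)
println([r(v) for v in [1.0, 0.0, 0.0]])   # impulse response, one call per sample
mini_reset!(r)                             # start the next sequence from h = 0
```

Because the state lives inside the wrapper, each call depends on all previous calls, which is why the state has to be reset before filtering an independent sequence.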

Let’s filter some noise and calculate the loss.

inp = randn(1000,1);
loss(inp, inp)

317.4691466179187

The filter is clearly filtering something. So far so good. Now I would like to optimize the filter coefficient so that the filter stops filtering: for c = 1.0 the output equals the input, so the loss should become 0.
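Written out, with N = 1000 and y = inp in the call loss(inp, inp), the filter, the loss and the derivative that automatic differentiation has to propagate through the recursion are as follows. The symbol s_t for the derivative of the state with respect to c is notation introduced here.

$$
\begin{aligned}
h_t &= c\,x_t + (1-c)\,h_{t-1}, \qquad h_0 = 0,\\
L(c) &= \sum_{t=1}^{N} (h_t - y_t)^2,\\
s_t &= \frac{\partial h_t}{\partial c} = x_t - h_{t-1} + (1-c)\,s_{t-1}, \qquad s_0 = 0,\\
\frac{\partial L}{\partial c} &= 2\sum_{t=1}^{N} (h_t - y_t)\,s_t .
\end{aligned}
$$

With y = x the residual satisfies h_t - x_t = (1-c)(h_{t-1} - x_t), so at c = 1 every residual vanishes and L = 0.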
However, neither:

Flux.train!(loss, params(model), [(inp, inp)], ADAM(0.1))

nor

c = param(c)
gs = Tracker.gradient(() -> loss(inp, inp), params(c))

seems to work.

I get the error messages
MethodError: no method matching back!(::Float64)
and
MethodError: no method matching Float64(::Tracker.TrackedReal{Float64}), respectively.

What am I missing?

Thanks in advance

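The problem is that your loss is a regular Float64 instead of a TrackedReal number, so Tracker has nothing to backpropagate through. In the train! call, c is still a plain Float64 (and params(model) cannot find it, since model is a plain function), hence back!(::Float64).
Also, writing into a preallocated array causes trouble: out = zeros(1000,1) can only hold Float64 values, so once c is a param, storing a tracked output in it needs a conversion to Float64 that does not exist (the Float64(::TrackedReal) error). It is better to push! the outputs into an untyped array: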

function model(x)
    out = []   # untyped (Vector{Any}), so it can hold tracked values
    for feature in x
        push!(out, recurLP(feature))
    end
    return out
end
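The element-type issue can be reproduced without Flux. In this sketch, a Rational coefficient stands in for a non-Float64 number such as TrackedReal: the preallocated zeros array converts every result to Float64 on assignment, whereas map lets the element type follow the values. map is used here in place of push! into []; both avoid fixing the element type up front. lp_step, filter_prealloc and filter_map are hypothetical helpers.

```julia
# Hypothetical helper: one step of the first-order low-pass filter.
lp_step(h, x, c) = c * x + (1 - c) * h

# Preallocated output: eltype is fixed to Float64, so each
# assignment out[i] = h calls convert(Float64, h).
function filter_prealloc(x, c)
    out = zeros(length(x))
    h = zero(c)
    for i in eachindex(x)
        h = lp_step(h, x[i], c)
        out[i] = h
    end
    return out
end

# Non-mutating output: map builds the array from the values it sees,
# so the element type follows the type of h.
function filter_map(x, c)
    h = zero(c)
    return map(x) do xi
        h = lp_step(h, xi, c)
    end
end

x = [1, 0, 0]
println(filter_prealloc(x, 1//2))   # Float64 results
println(filter_map(x, 1//2))        # Rational results
```

A Rational can be converted to Float64, so the first version merely loses the type; a TrackedReal has no such conversion, so the same assignment throws.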

Hi,
Thanks for the help. I had another look and have now made it work. Instead of looping over the time series, we can also use the dot (broadcast) notation, and the problem with Tracker is solved as well. Here is the code:

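using Flux
using Flux.Tracker

c = 0.5   # filter coefficient (a global read by LP)

# One filter step: returns (new state, output), as Flux.Recur expects
function LP(h, x)
    h = c * x + (1.0 - c) * h
    return h, h
end

h = 0.
recurLP = Flux.Recur(LP, h)

# Broadcasting calls recurLP once per element, threading the state through
function model(x)
    return recurLP.(x)
end

function loss(x, y)
    return sum((model(x) .- y).^2)
end

inp = randn(1000,1);
println(loss(inp, inp))

Flux.reset!(recurLP)   # back to the initial state h = 0
println(recurLP.state)

c = param(c)   # tracked, so Tracker can differentiate with respect to c

for i = 1:50
    Flux.reset!(recurLP)   # every pass starts from h = 0
    gs = Tracker.gradient(() -> loss(inp, inp), params(c))
    Δ = gs[c]              # gradient at the current value of c
    Tracker.update!(c, -0.00001Δ)   # gradient-descent step
    if i % 10 == 0
        Flux.reset!(recurLP)
        println("loss= ", loss(inp, inp))
        println("c= ", c)
        println()
    end
end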
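As a Flux-free cross-check of the training loop, this sketch computes the loss and dL/dc by hand, carrying the sensitivity s_t = dh_t/dc through the recursion (forward-mode differentiation), and recomputes the gradient at every step of plain gradient descent. loss_and_grad and fit_lowpass are hypothetical helpers, not Flux functions; the default step size 1e-5 mirrors the 0.00001 used above, and steps is illustrative.

```julia
# Hypothetical helper: loss L(c) = sum((h_t - y_t)^2) of the filter
# h_t = c*x_t + (1-c)*h_{t-1}, h_0 = 0, plus dL/dc by forward-mode
# differentiation through the recursion (s_t = dh_t/dc).
function loss_and_grad(c, x, y)
    h = zero(c)    # filter state h_{t-1}
    s = zero(c)    # sensitivity s_{t-1} = dh_{t-1}/dc
    L = zero(c)
    dL = zero(c)
    for t in eachindex(x, y)
        s = x[t] - h + (1 - c) * s    # uses the old state, so update s first
        h = c * x[t] + (1 - c) * h
        r = h - y[t]
        L += r^2
        dL += 2 * r * s
    end
    return L, dL
end

# Plain gradient descent on c: every step starts again from h_0 = 0
# and uses the gradient at the current c.
function fit_lowpass(x, y; c = 0.5, lr = 1e-5, steps = 50)
    for _ in 1:steps
        _, g = loss_and_grad(c, x, y)
        c -= lr * g
    end
    return c
end

x = randn(1000)
c_fit = fit_lowpass(x, x)
println("c = ", c_fit, ", loss = ", loss_and_grad(c_fit, x, x)[1])
```

Comparing dL/dc from this function with Tracker's gs[c] on the same input is a quick way to check that the Flux version differentiates what you expect.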