Getting gradients with loss using for-loop is slow in Flux.jl

Hello.

I am trying to train a custom model on sequential data using Flux.jl. I want to use a loss function with a for loop that recursively feeds the output of the neural network back in as input at each data point. However, when I write the code as follows, the gradient calculation takes a long time. It seems strange that the code recompiles every time, even when run multiple times. I'm not sure how to fix it.

using Flux

function loss(model1, model2, xs, y0, ŷs)
    l = zero(eltype(xs[begin]))
    y = y0
    for (i,x) in enumerate(xs)
        h = model1(vcat(x, y))
        y = model2(h)
        l += Flux.mse(ŷs[i], sum(y))
    end
    return l
end

xs = [randn(Float32, (16)) for _ in 1:8]
y0 = randn(Float32, (16))
ŷs = [[1f0] for _ in 1:8]

m1 = Dense(32=>16)
m2 = Dense(16=>16)

julia> @time gradient(m1,m2) do m1,m2
           loss(m1, m2, xs, y0, ŷs)
       end
  3.957895 seconds (27.08 M allocations: 1.390 GiB, 8.21% gc time, 99.87% compilation time: 7% of which was recompilation)

julia> @time gradient(m1,m2) do m1,m2
           loss(m1, m2, xs, y0, ŷs)
       end
  0.059321 seconds (379.88 k allocations: 20.261 MiB, 98.93% compilation time)

I would appreciate it if you could provide guidance on how to address this problem.

Thank you.

Untyped global variables are a performance problem. Try making those values local to a function:

using Flux

function loss(model1, model2, xs, y0, ŷs)
    l = zero(eltype(xs[begin]))
    y = y0
    for (i,x) in enumerate(xs)
        h = model1(vcat(x, y))
        y = model2(h)
        l += Flux.mse(ŷs[i], sum(y))
    end
    return l
end

function testit()
    xs = [randn(Float32, (16)) for _ in 1:8]
    y0 = randn(Float32, (16))
    ŷs = [[1f0] for _ in 1:8]

    m1 = Dense(32=>16)
    m2 = Dense(16=>16)

    gradient(m1,m2) do m1,m2
        loss(m1, m2, xs, y0, ŷs)
    end
end
julia> @time testit();
  4.467115 seconds (11.37 M allocations: 661.364 MiB, 6.02% gc time, 99.83% compilation time)

julia> @time testit();
  0.002532 seconds (2.63 k allocations: 153.891 KiB)

Or pass the data and models as arguments to a function:

xs = [randn(Float32, (16)) for _ in 1:8]
y0 = randn(Float32, (16))
ŷs = [[1f0] for _ in 1:8]

m1 = Dense(32=>16)
m2 = Dense(16=>16)

function testitagain(xs, y0, ŷs, m1, m2)
    gradient(m1,m2) do m1,m2
        loss(m1, m2, xs, y0, ŷs)
    end
end

julia> @time testitagain(xs, y0, ŷs, m1, m2)
0.233817 seconds (782.57 k allocations: 42.406 MiB, 6.90% gc time, 99.66% compilation time)

julia> @time testitagain(xs, y0, ŷs, m1, m2)
0.000642 seconds (2.58 k allocations: 145.219 KiB)

Thank you very much!

I now understand that the code example I provided mishandled global variables, which caused recompilation on every execution. However, in my original code I was already passing the training data and other parameters as function arguments, and it still recompiled every time.

Upon re-examining the code, I realized that the recompilation was caused by the use of Flux.params(model1) for L2 regularization. For instance, consider the following loss function:

function loss2(model1, model2, xs, y0, ŷs)
    # minimal reproducer: the L2 penalty term alone triggers recompilation
    0.01f0 * sum(x -> sum(x .^ 2), Flux.params(model1))
end

Regarding this phenomenon, do you happen to know any effective methods for regularization?

L2 regularization is typically done in the optimiser, via WeightDecay, rather than as an explicit penalty term in the loss.
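A minimal sketch with the explicit-style API (the coefficient 0.01f0 is arbitrary, matching your penalty above):

```julia
using Flux

m1 = Dense(32 => 16)

# WeightDecay(λ) adds λ .* p to each parameter's gradient, which gives the
# same update as an explicit 0.5f0 * λ * sum(abs2, p) penalty in the loss.
opt_state = Flux.setup(OptimiserChain(WeightDecay(0.01f0), Adam()), m1)

# In the training loop, no penalty term is then needed in the loss:
# grads = gradient(m -> loss(m, ...), m1)[1]
# Flux.update!(opt_state, m1, grads)
```

This keeps the loss function free of Flux.params, so it avoids the recompilation you observed.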

I apologize for the delay in my response. Thank you very much!

I didn't know that L2 regularization could be handled by the optimiser. It's convenient that an explicit implementation in the loss is not necessary!