[FluxML + Functors] How to walk the model and ∇model simultaneously?

The following two lines of code are commonly used when training a model in Julia.

∇model = gradient(m -> loss(m, x, y), model)[1]
opt_state, model = Optimisers.update!(opt_state, model, ∇model)

For debugging purposes, we can compute statistics of the model parameters with the fmap and fmapstructure functions from Functors.

For example, I can see the average values of parameters (for each weight matrix and for each bias in all Dense layers) using the following very generic line of code:

stats = fmapstructure(mean, model, exclude=x->x isa Union{Array, CUDA.CuArray} && Functors.isleaf(x))

Output:
(model = (layers = ((weight = -0.00033604607f0, bias = 1.9988879f0, σ = ()), (weight = -0.0045190137f0, bias = -0.0030693f0, σ = ()), (weight = -0.002681108f0, bias = -0.0048592985f0, σ = ()), (weight = -0.015205776f0, bias = -0.012156701f0, 
σ = ()), (paths = NamedTuple{(:weight, :bias, :σ), Tuple{Float32, Float32, Tuple{}}}[(weight = 0.021802468, bias = -0.001723722, σ = ()), (weight = 0.055412512, bias = -0.0015859936, σ = ()), (weight = -0.05021051, bias = 0.0025667339, σ = ()), (weight = -0.04975954, bias = 0.0021612642, σ = ()), (weight = 0.034353513, bias = -0.002381153, σ = ()), (weight = -0.021517549, bias = 0.0015386648, σ = ())],)),),)

Question:
I would like to compute the element-wise ratio model/∇model, because I want to see how the gradients relate to the original parameters. From that ratio we can get the mean, min and max ratio in each layer. The question is: how can we generically walk the model and ∇model at the same time?

I know Optimisers.jl does something similar, so maybe @mcabbott can help.

I have just found a simple solution.

First, recursively transform the model into a tree of NamedTuples, just like ∇model.

model_nt = fmapstructure(identity, model)

Now model_nt and ∇model have the same tree structure.
The second step is to apply fmap to both trees:

mean_ratios = fmap((v1, v2)->mean(v1./v2), model_nt, ∇model, exclude=x->x isa Union{Array, CUDA.CuArray} && Functors.isleaf(x))

Output:

(model = (layers = ((weight = 4.7656164f0, bias = 0.0f0, σ = NNlib.relu), (weight = 6.3872375f0, bias = 0.0f0, σ = NNlib.relu), (weight = NaN32, bias = 0.0f0, σ = NNlib.relu), (weight = NaN32, bias = 0.0f0, σ = NNlib.relu), (paths = NamedTuple{(:weight, :bias, :σ), Tuple{Float32, Float32, typeof(identity)}}[(weight = 5.61777, bias = 0.0, σ = identity), (weight = 3.6640332, bias = 0.0, σ = identity), (weight = -13.421525, bias = 0.0, σ = identity), (weight = -28.083906, bias = 0.0, σ = identity), (weight = -7.2097316, bias = 0.0, σ = identity), (weight = -0.7657783, bias = 0.0, σ = identity)],)),))
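
For readers without a Flux model at hand, here is a minimal sketch of the same two steps on a toy struct. The `Affine` type and the hand-written `∇layer` NamedTuple are hypothetical stand-ins (for a Dense layer and for a Zygote gradient), so this only illustrates the mechanics, not the real model above.

```julia
using Functors, Statistics

# Hypothetical two-field layer, standing in for Flux's Dense.
struct Affine
    weight
    bias
end
@functor Affine

layer = Affine([1.0 2.0; 3.0 4.0], [0.5, 0.5])

# Hand-written stand-in for a Zygote gradient: a NamedTuple with the same field names.
∇layer = (weight = fill(0.5, 2, 2), bias = [0.25, 0.25])

# Step 1: turn the struct into a tree of NamedTuples.
layer_nt = fmapstructure(identity, layer)

# Step 2: walk both trees together and reduce each pair of arrays to one number.
ratios = fmap((p, g) -> mean(p ./ g), layer_nt, ∇layer)
# ratios == (weight = 5.0, bias = 2.0)
```

Here the default leaf test is enough because every leaf is an array; with a real model the `exclude` keyword from the line above is still needed.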

Remark:
I have found no dependency of Zygote on Functors. Still, Zygote outputs the gradients in the same form as fmapstructure (a tree of NamedTuples). The problem is that Zygote might one day change its gradient output format, in which case the above solution would stop working. How likely is this to happen?

That looks right. There is no reason we couldn’t have multi-arg fmapstructure(f, x, ys...), except that nobody had a need for it. You could use exclude = Optimisers.isnumeric as the leaf-like condition.
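
To make the suggestion concrete, here is a small sketch of what `Optimisers.isnumeric` accepts, based on my reading of its docstring (arrays of non-integer numbers that are also Functors leaves). It is not exported, so it has to be qualified with the module name.

```julia
using Optimisers

Optimisers.isnumeric([1.0, 2.0])   # true:  floating-point array
Optimisers.isnumeric([1, 2])       # false: integer arrays are skipped
Optimisers.isnumeric(tanh)         # false: not an array
Optimisers.isnumeric(1.0)          # false: a scalar, not an array
```

Because the check is on `AbstractArray`, it is not tied to `Array` or `CuArray` specifically, which is what makes it a convenient replacement for the hand-written `exclude` predicate.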

Note that what fmap(f, x, ys...) won’t handle is gradients which are missing branches. This is one of the things making Optimisers.jl more complicated than seems necessary:

julia> using Zygote, Functors, Statistics

julia> x = (a=[1.0, 2.0], b=(3.0, [4.0]));

julia> dx1 = gradient(x -> prod(x.a) + x.b[1] + sum(x.b[2]), x)[1]  # no problem
(a = [2.0, 1.0], b = (1.0, Fill(1.0, 1)))

julia> fmap((v1, v2) -> mean(v1./v2), x, dx1)
(a = 1.25, b = (3.0, 4.0))

julia> dx2 = gradient(x -> prod(x.a), x)[1]  # problem
(a = [2.0, 1.0], b = nothing)

julia> fmap((v1, v2) -> mean(v1./v2), x, dx2)
ERROR: MethodError: no method matching length(::Nothing)

Zygote changing its output format is unlikely, but other AD packages might use a different one. Diffractor will return a nested set of Tangent types, instead of NamedTuples. Functors needs to know how to walk both of these, to make Optimisers.jl work… but see #46 for gory details.

You may be able to get away with just fmap((v1, v2)->mean(v1./v2), ∇model, model). fmap reconstructs into the structure of the first argument, so if you lead with a tree of (named)tuples you’ll end up with the same. As Michael mentioned though, this is mostly specific to Zygote.

If we ever change the output format from nested plain old Julia objects to something else, you can expect a heads up far in advance and some sort of transition plan.

Yes, it works. And I have found it also solves the missing gradient branches issue @mcabbott mentioned previously.
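
A sketch of why the argument order matters, reusing the `x` and `dx2` values from the example above. The gradient is written out by hand so that Zygote is not needed, and `exclude = Optimisers.isnumeric` is my assumption about what makes the missing branch harmless: `nothing` is not numeric, so `f` is never called on it, and it has no children to recurse into.

```julia
using Functors, Optimisers, Statistics

x   = (a = [1.0, 2.0], b = (3.0, [4.0]))
dx2 = (a = [2.0, 1.0], b = nothing)   # hand-written copy of the gradient with a missing branch

# Lead with the gradient tree: the walk follows dx2, so x.b is never visited.
ratios = fmap((g, p) -> mean(p ./ g), dx2, x; exclude = Optimisers.isnumeric)
# ratios.a == 1.25, and ratios.b should stay `nothing`
```

With the default `exclude` the function would still be called on the `nothing` leaf, so the custom predicate is part of the fix, not just the argument order.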

The following code averages the gradients over all batches of a dataset. It works nicely, but every call to fmap allocates a fresh set of arrays, so a lot of memory is wasted and the function is very slow.
This suggests the need for an fmap! that is allowed to update the arrays in place.

function ∇global(mtw, data_loader)
    xf, yf = first(data_loader) |> gpu
    ∇mtw = gradient(m -> m(xf, yf), mtw)[1]
    # init all gradients to zero
    ∇mtw = fmap(-, ∇mtw, ∇mtw, exclude=Optimisers.isnumeric)

    # sum up all the gradients
    for batch in data_loader
        x, y = batch |> gpu
        ∇mtw_xy = gradient(m -> m(x, y), mtw)[1]
        ∇mtw = fmap((g, gxy)->g .+ gxy, ∇mtw, ∇mtw_xy, exclude=Optimisers.isnumeric) 
    end 

    cnt = length(data_loader)
    ∇mtw = fmap(x->x/cnt, ∇mtw, exclude=Optimisers.isnumeric)
    
    return ∇mtw 
end

Problem solved again. fmap only needs a function that does in-place operations.

∇mg = fmap((v1, v2)->broadcast!(+, v1, v1, v2), ∇mg, ∇m, exclude=Optimisers.isnumeric)

Below is the MWE and its output. We can see that the in-place version still does some allocations, but it is not allocating arrays.

using Flux
using Optimisers

struct Foo
    x
    y
    z
end

@functor Foo
(f::Foo)(a) = sum(f.z(f.y(f.x(a))))

function test()
    m = Foo(Dense(250=>200), Dense(200=>150), Dense(150=>100))
    x = rand(250)
    ∇m = gradient(m_->m_(x), m)[1]
    # init gradients to zeros; use Array(zero(x)) to transform FillArrays types into vectors
    ∇mg = fmap(x->Array(zero(x)), ∇m, exclude=Optimisers.isnumeric)
    # in place operations
    @time ∇mg = fmap((v1, v2)->broadcast!(+, v1, v1, v2), ∇mg, ∇m, exclude=Optimisers.isnumeric)
    # memory allocations
    @time ∇mg2 = fmap((v1, v2)-> v1 .+ v2, ∇mg, ∇m, exclude=Optimisers.isnumeric)
end

test()

The timing output:

0.000170 seconds (41 allocations: 1.797 KiB)
0.000208 seconds (50 allocations: 375.078 KiB)
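
Besides timing, one can check directly that the accumulator arrays are reused. This is a minimal sketch on hand-written NamedTuples (`acc` and `g` are made-up stand-ins for ∇mg and ∇m), using `.+=` as an equivalent of the `broadcast!` call above.

```julia
using Functors, Optimisers

acc = (w = zeros(2, 2), b = zeros(2))     # running sum, updated in place
g   = (w = ones(2, 2),  b = [1.0, 2.0])   # one "gradient" to add

w_before = acc.w                          # keep a reference to the original array

# `a .+= gi` mutates `a` and returns it, so fmap rebuilds the tree around the same arrays.
acc = fmap((a, gi) -> (a .+= gi), acc, g; exclude = Optimisers.isnumeric)

acc.w === w_before   # true: no new array was allocated for the weights
```

The outer NamedTuple is still rebuilt on every call, which would be consistent with the small number of allocations that remain in the timing above.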