How to read model weights using FluxTraining (stateaccess issues)?

Summary

I am looking for some help with implementing a custom validation loop using FluxTraining. Briefly, I am using Flux to optimize a matrix to satisfy a data-driven cost-function. Since the cost function depends on the data, it’s better for me to use an ML framework than a regular optimization package. The model is a single matrix, so in Flux parlance:

# single bias-free linear layer; the matrix being optimized is `model.weight`
model = Dense(m => n, identity; bias=false)

Due to the structure of my loss function and evaluation metrics, I need to access the weights (the matrix) during computation of the loss and evaluation.

I am running into an error where FluxTraining blocks me from accessing the model weights during evaluation (but not during training).

Epoch 1 MyValidationPhase() ...
ERROR: LoadError: FluxTraining.ProtectedException("Read access to Learner.model disallowed.")
Stacktrace:
  [1] getfieldperm(data::Learner, field::Symbol, perm::Nothing)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:63
  [2] getproperty(protected::FluxTraining.Protected{Learner}, field::Symbol)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:18
  [3] step!(metric::ReconstructionMetric{Number}, learner::FluxTraining.Protected{Learner}, phase::MyValidationPhase)
    @ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:260
  [4] on(#unused#::FluxTraining.Events.StepEnd, phase::MyValidationPhase, metrics::Metrics, learner::FluxTraining.Protected{Learner})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/metrics.jl:74
  [5] _on(e::FluxTraining.Events.StepEnd, p::MyValidationPhase, cb::Metrics, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
  [6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.StepEnd, phase::MyValidationPhase, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
  [7] (::FluxTraining.var"#handlefn#82"{Learner, MyValidationPhase})(e::FluxTraining.Events.StepEnd)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:129
  [8] runstep(stepfn::var"#9#10"{Learner}, learner::Learner, phase::MyValidationPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Matrix{Float32}, Matrix{Float32}}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:134
  [9] step!(learner::Learner, phase::MyValidationPhase, batch::Tuple{Matrix{Float32}, Matrix{Float32}})
    @ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:301
 [10] (::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}})(#unused#::Function)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:24
 [11] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::MyValidationPhase)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:105
 [12] epoch!(learner::Learner, phase::MyValidationPhase, dataiter::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22
 [13] main()
    @ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:449
 [14] top-level scope
    @ ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453
in expression starting at /home/alec/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453

System/Julia information

I am running on Pop!_OS 22.04 with Julia 1.9.0. All code is being run on the CPU.
Here is my Manifest.toml file.

The Training Phase (this works)

My custom training phase is working fine.


"""
    pack_loss(state::FluxTraining.PropDict, loss_names, loss_vals)

Store the loss components of one step in `state`.

Sets `state.loss` to the sum of `loss_vals`. For each position `i`, it also stores
`loss_vals[i]` as the property `Symbol(loss_names[i])`. `state` is modified in place
and returned.

Throws a `DimensionMismatch` if `loss_names` and `loss_vals` have different lengths.
Previously, extra values were silently dropped, or a `BoundsError` was raised when
there were fewer values than names.
"""
function pack_loss(state::FluxTraining.PropDict, loss_names, loss_vals)
    length(loss_names) == length(loss_vals) || throw(DimensionMismatch(
        "got $(length(loss_names)) loss names but $(length(loss_vals)) loss values"))
    state.loss = sum(loss_vals)
    for i in eachindex(loss_names)
        setproperty!(state, Symbol(loss_names[i]), loss_vals[i])
    end
    return state
end

"""
    unpack_loss(state::FluxTraining.PropDict, loss_name)

Fetch the loss component stored in `state` under `loss_name`. Any value accepted by
`Symbol`, e.g. a `String` or a `Symbol`, can be used. This is the read-side
counterpart of `pack_loss`.
"""
function unpack_loss(state::FluxTraining.PropDict, loss_name)
    key = Symbol(loss_name)
    return getproperty(state, key)
end

"""
Custom training phase
that exposes the model parameters.
"""
struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end

"""
    FluxTraining.step!(learner, phase::MyTrainingPhase, batch)

`step!` function for the custom training phase.
The `runstep` function returns the `state`, but the `state` is also packaged
into `learner.step`, which OnlineStats.means that we can access it later.

This function handles the custom loss function with keyword arguments
and also handles the state, adding the loss components to the `state`.
"""
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch

    # `runstep` has the signature: runstep(stepfn, learner, phase, batch)
    # where `stepfn` has the signature: stepfn(handle, state::PropDict)
    FluxTraining.runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
        # `stepfn` calls the `_gradient` method which is a fallback
        # for different formats of Flux models,
        # but requires a lossfn as the first input, which has the signature: lossfn(model)
        state.grads = FluxTraining._gradient(learner.optimizer,
            learner.model,
            learner.params) do model
            # get the model outputs
            state.ŷs = model(state.xs)
            handle(FluxTraining.LossBegin())
            # compute the loss using the custom loss function
            loss_names, loss_vals = learner.lossfn(state.ŷs, state.ys; m = learner.model)
            # pack the loss in the state
            state = pack_loss(state, loss_names, loss_vals)
            handle(FluxTraining.BackwardBegin())
            # return the total loss
            return state.loss
        end
        handle(BackwardEnd())
        # update parameters
        learner.params, learner.model = FluxTraining._update!(learner.optimizer,
            learner.params,
            learner.model, state.grads)
    end
end

The main differences between this and the standard training phase baked into FluxTraining are that the loss function returns multiple components of the loss, and that the loss function has the signature lossfn(y_hat, y; m) where m is the model.

The Validation Phase (this is broken)

For reasons of protecting unpublished work, my advisor has asked me to obscure some of the details here, but I am certain that nothing I am leaving out is causing the problem. Thank you for understanding.

Here’s the validation phase code. I’m pretty sure this is working fine.

"""
    MyValidationPhase()

Custom validation phase. Its `step!` method passes the model to the loss function,
mirroring `MyTrainingPhase`.
"""
struct MyValidationPhase <: FluxTraining.AbstractValidationPhase
end

"""
    FluxTraining.step!(learner, phase::MyValidationPhase, batch)

Validation step for the custom phase. It runs the model on the batch, calls
`learner.lossfn(ŷs, ys; m = learner.model)` and records the loss components in the
step state via `pack_loss`. No gradients are computed. Returns the step state
produced by `runstep`.
"""
function FluxTraining.step!(learner, phase::MyValidationPhase, batch)
    inputs, targets = batch
    initial = (; xs = inputs, ys = targets)
    return FluxTraining.runstep(learner, phase, initial) do _, state
        net = learner.model
        state.ŷs = net(state.xs)
        component_names, component_vals = learner.lossfn(state.ŷs, state.ys; m = net)
        return pack_loss(state, component_names, component_vals)
    end
end

Here’s the code for the custom evaluation metric that’s causing me problems…

"""
    ReconstructionMetric{T}

Metric computed from the model weights rather than from the step outputs.

In each step of a phase of type `P`:
1. `decisionfn(weight, target_signal_repr; decisionfn_kwargs...)` produces `y`.
2. `reconstructionfn(y, weight; reconstructionfn_kwargs...)` produces `x`.
3. The step value is `StatsBase.cor(x, target_signal_repr)`.
"""
mutable struct ReconstructionMetric{T} <: FluxTraining.AbstractMetric
    # called as decisionfn(weight, target_signal_repr; decisionfn_kwargs...)
    decisionfn::Any
    decisionfn_kwargs::NamedTuple
    # called as reconstructionfn(y, weight; reconstructionfn_kwargs...)
    reconstructionfn::Any
    reconstructionfn_kwargs::NamedTuple
    # running statistic updated with each step value (`OnlineStats.fit!`)
    statistic::OnlineStats.OnlineStat{T}
    # untouched copy of the initial statistic; `reset!` restores `statistic` from it
    _statistic::Any
    # name reported by `metricname`
    name::Any
    # not read by the methods of this metric shown here
    device::Any
    # phase type in which the metric is computed (`phase isa metric.P`)
    P::Any
    # value from the most recent step, or `nothing` after a step outside phase `P`
    last::Union{Nothing, T}
    # reference signal passed to `decisionfn` and correlated against the reconstruction
    target_signal_repr::Any
end

"""
    FluxTraining.reset!(metric::ReconstructionMetric)

Throw away everything accumulated in `metric.statistic` by replacing it with a
fresh copy of the pristine template `metric._statistic`. Returns that copy.
"""
function FluxTraining.reset!(metric::ReconstructionMetric)
    fresh = deepcopy(metric._statistic)
    metric.statistic = fresh
    return fresh
end

"""
    FluxTraining.step!(metric::ReconstructionMetric, learner, phase)

Per-step update of the reconstruction metric.

In a phase of type `metric.P`, it derives `y` from the model weights with
`decisionfn`, reconstructs `x` from `y` and the weights with `reconstructionfn`,
and records `cor(x, target_signal_repr)` both in `metric.last` and in
`metric.statistic`. In any other phase it sets `metric.last = nothing`.

This reads `learner.model`, so the callback that runs this metric must be granted
read access to `model` through its `stateaccess` permissions.
"""
function FluxTraining.step!(metric::ReconstructionMetric, learner, phase)
    if !(phase isa metric.P)
        metric.last = nothing
        return nothing
    end
    weights = learner.model.weight
    target = metric.target_signal_repr
    decision = metric.decisionfn(weights, target; metric.decisionfn_kwargs...)
    rebuilt = metric.reconstructionfn(decision, weights; metric.reconstructionfn_kwargs...)
    metric.last = StatsBase.cor(rebuilt, target)
    return OnlineStats.fit!(metric.statistic, metric.last)
end

# Note: in the failing run the learner was wrapped according to the permissions of
# the `Metrics` callback (see the `_on(..., cb::Metrics, ...)` frame in the
# stacktrace). Defining this per-metric method alone therefore did not grant read
# access to `model`.
"""
    FluxTraining.stateaccess(::ReconstructionMetric)

Permissions this metric needs: read access to `model`, `params` and `step`, plus
the metrics/history entries of `cbstate`.
"""
function FluxTraining.stateaccess(::ReconstructionMetric)
    rd, wr = FluxTraining.Read, FluxTraining.Write
    metrics_cbstate = (metricsstep = wr(), metricsepoch = wr(), history = rd())
    return (;
        model = rd(),
        params = rd(),
        cbstate = metrics_cbstate,
        step = rd(),
    )
end

# Ordering hint for FluxTraining: run after `FluxTraining.Recorder`.
function FluxTraining.runafter(::ReconstructionMetric)
    return (FluxTraining.Recorder,)
end
# Value reported for the current step (`nothing` outside the metric's phase).
function FluxTraining.stepvalue(metric::ReconstructionMetric)
    return metric.last
end
# Name under which the metric is reported.
function FluxTraining.metricname(metric::ReconstructionMetric)
    return metric.name
end

"""
    FluxTraining.epochvalue(metric::ReconstructionMetric)

Epoch summary of the metric. Returns the current value of `metric.statistic`, or
`nothing` when the most recent step ran outside the metric's phase.
"""
FluxTraining.epochvalue(metric::ReconstructionMetric) =
    isnothing(metric.last) ? nothing : OnlineStats.value(metric.statistic)

Any advice would be appreciated!

It looks like you have enough to create a MWE without revealing any unpublished details. Doing that would help immensely with tracking down what the issue is. As-is, having to mentally trace through the execution of each code path with just the snippets you’ve provided is a bit too much effort.

Fair enough, here is a MWE. Any help you can provide would be appreciated. Thanks!

using Flux
using FluxTraining
using OnlineStats
using LinearAlgebra

"""
Custom type for my metric.
This is basically a duplicate of the standard Metric.
"""
mutable struct MyMetric{T} <: FluxTraining.AbstractMetric
    statistic::OnlineStats.OnlineStat{T}
    _statistic::Any
    name::Any
    device::Any
    P::Any
    last::Union{Nothing, T}
end

"""
Outer constructor for MyMetric.
"""
function MyMetric(;
        statistic = OnlineStats.Mean(Float32),
        device = cpu,
        phase = ValidationPhase,
        name = "MyMetric")
    return MyMetric(statistic, deepcopy(statistic), name, device, phase, nothing)
end

"""
Reset MyMetric back to the initial value.
"""
function FluxTraining.reset!(metric::MyMetric{T}) where T
    metric.statistic = deepcopy(metric._statistic)
end

"""
We will use the L1 norm as an "example function"
that requires access to model weights.
"""
function l1_metric(W::Matrix)
    return norm(W, 1) / size(W, 1)
end

"""
Compute the metric by taking the L1 norm of the model weight matrix.
"""
function FluxTraining.step!(metric::MyMetric, learner, phase)
    if phase isa metric.P
        metric.last = l1_metric(learner.model.weight)
        OnlineStats.fit!(metric.statistic, metric.last)
    else
        metric.last = nothing
    end
end

# Compact one-line display: `Metric(<name>)`.
function Base.show(io::IO, metric::MyMetric)
    print(io, "Metric(")
    print(io, metric.name)
    print(io, ")")
    return nothing
end

# Ordering hint for FluxTraining: run after `FluxTraining.Recorder`.
function FluxTraining.runafter(::MyMetric)
    return (FluxTraining.Recorder,)
end
# Value reported for the current step (`nothing` outside the metric's phase).
function FluxTraining.stepvalue(metric::MyMetric)
    return metric.last
end
# Name under which the metric is reported.
function FluxTraining.metricname(metric::MyMetric)
    return metric.name
end

"""
    FluxTraining.epochvalue(metric::MyMetric)

Epoch summary of the metric. Returns the current value of `metric.statistic`, or
`nothing` when the most recent step ran outside the metric's phase.
"""
FluxTraining.epochvalue(metric::MyMetric) =
    isnothing(metric.last) ? nothing : OnlineStats.value(metric.statistic)

# Note: in the failing run the learner was wrapped according to the permissions of
# the `Metrics` callback that runs this metric, not this per-metric method, so this
# definition alone does not grant read access to `model`.
"""
    FluxTraining.stateaccess(::MyMetric)

Permissions `MyMetric` needs: read access to `model`, `params` and `step`, plus the
metrics/history entries of `cbstate`.
"""
function FluxTraining.stateaccess(::MyMetric)
    rd, wr = FluxTraining.Read, FluxTraining.Write
    metrics_cbstate = (metricsstep = wr(), metricsepoch = wr(), history = rd())
    return (;
        model = rd(),
        params = rd(),
        cbstate = metrics_cbstate,
        step = rd(),
    )
end

"""
    main(; n_epochs = 3)

Minimal reproducible example. It fits a bias-free `Dense` layer on random data with
FluxTraining, alternating one training epoch and one validation epoch `n_epochs`
times. The custom `MyMetric`, which reads `learner.model.weight`, is tracked by a
`FluxTraining.Metrics` callback.
"""
function main(; n_epochs::Integer = 3)
    in_dim = 10
    out_dim = 1
    n_samples = 64
    model = Dense(in_dim => out_dim, identity; bias = false)

    X = rand(in_dim, n_samples) |> f32
    y = rand(out_dim, n_samples) |> f32
    train_dataloader = Flux.DataLoader((X, y))
    # The loaders only read the data, so a second loader over the same arrays
    # suffices; deep-copying the whole loader is unnecessary.
    val_dataloader = Flux.DataLoader((X, y))

    callbacks = [FluxTraining.Metrics(MyMetric())]
    # an optimiser rule (not an optimiser state)
    optimiser = Flux.Adam(1f-4)

    learner = FluxTraining.Learner(model, Flux.mse;
        callbacks = callbacks,
        optimizer = optimiser)

    for _ in 1:n_epochs
        FluxTraining.epoch!(learner, FluxTraining.TrainingPhase(), train_dataloader)
        FluxTraining.epoch!(learner, FluxTraining.ValidationPhase(), val_dataloader)
    end
    return nothing
end

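The issue is that `stateaccess` needs to be defined for `::Metrics`, not for `::MyMetric`. FluxTraining checks state-access permissions against the callback that is running. Here that is the `Metrics` callback wrapping the metric, as the `_on(..., cb::Metrics, ...)` frame in the stacktrace shows. A `stateaccess` method on the individual metric is not what that check uses.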

# Grant the `Metrics` callback read access to the model (and its params), so that
# metrics it runs, such as `MyMetric`, can read `learner.model.weight` in `step!`.
#
# Caveat: both `stateaccess` and `Metrics` belong to FluxTraining, so this is type
# piracy. It widens the permissions of every `Metrics` callback in the session. If
# FluxTraining already defines a method for this exact signature, this definition
# overwrites it (Julia prints a "Method definition ... overwritten" warning).
function FluxTraining.stateaccess(::FluxTraining.Metrics)
    rd, wr = FluxTraining.Read, FluxTraining.Write
    metrics_cbstate = (metricsstep = wr(), metricsepoch = wr(), history = rd())
    return (;
        model = rd(),
        params = rd(),
        cbstate = metrics_cbstate,
        step = rd(),
    )
end