## Summary

I am looking for some help with implementing a custom validation loop using `FluxTraining`. Briefly, I am using `Flux` to optimize a matrix to satisfy a data-driven cost function. Since the cost function depends on the data, it’s better for me to use an ML framework than a regular optimization package. The model is a single matrix, so in Flux parlance:

```
model = Dense(m => n, identity; bias=false)
```

Due to the structure of my loss function and evaluation metrics, I need to access the weights (the matrix) during computation of the loss and evaluation.

I am running into an error where I am being blocked *by FluxTraining* from accessing the model weights during eval (but not during training).

```
Epoch 1 MyValidationPhase() ...
ERROR: LoadError: FluxTraining.ProtectedException("Read access to Learner.model disallowed.")
Stacktrace:
[1] getfieldperm(data::Learner, field::Symbol, perm::Nothing)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:63
[2] getproperty(protected::FluxTraining.Protected{Learner}, field::Symbol)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:18
[3] step!(metric::ReconstructionMetric{Number}, learner::FluxTraining.Protected{Learner}, phase::MyValidationPhase)
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:260
[4] on(#unused#::FluxTraining.Events.StepEnd, phase::MyValidationPhase, metrics::Metrics, learner::FluxTraining.Protected{Learner})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/metrics.jl:74
[5] _on(e::FluxTraining.Events.StepEnd, p::MyValidationPhase, cb::Metrics, learner::Learner)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
[6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.StepEnd, phase::MyValidationPhase, learner::Learner)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
[7] (::FluxTraining.var"#handlefn#82"{Learner, MyValidationPhase})(e::FluxTraining.Events.StepEnd)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:129
[8] runstep(stepfn::var"#9#10"{Learner}, learner::Learner, phase::MyValidationPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Matrix{Float32}, Matrix{Float32}}})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:134
[9] step!(learner::Learner, phase::MyValidationPhase, batch::Tuple{Matrix{Float32}, Matrix{Float32}})
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:301
[10] (::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}})(#unused#::Function)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:24
[11] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::MyValidationPhase)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:105
[12] epoch!(learner::Learner, phase::MyValidationPhase, dataiter::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22
[13] main()
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:449
[14] top-level scope
@ ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453
in expression starting at /home/alec/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453
```

## System/Julia information

I am running on Pop!_OS 22.04 with Julia 1.9.0. All code is being run on the CPU.

Here is my Manifest.toml file.

## The Training Phase (this works)

My custom training phase is working fine.

```
"""
    pack_loss(state, loss_names, loss_vals)

Store the total loss and each named loss component on `state`.

The summed loss is written to `state.loss` (the field FluxTraining reads for
logging/backprop bookkeeping), and each component is attached to `state`
under `Symbol(loss_names[i])` so it can be retrieved later with
`unpack_loss`. Returns the mutated `state`.

Throws an `ArgumentError` when the number of names and values disagree,
instead of failing with an opaque `BoundsError` mid-step.
"""
function pack_loss(state::FluxTraining.PropDict, loss_names, loss_vals)
    length(loss_names) == length(loss_vals) ||
        throw(ArgumentError("got $(length(loss_names)) loss names for $(length(loss_vals)) loss values"))
    # `sum` is the idiomatic reduction for the scalar total.
    state.loss = sum(loss_vals)
    for (name, val) in zip(loss_names, loss_vals)
        setproperty!(state, Symbol(name), val)
    end
    return state
end
"""
    unpack_loss(state, loss_name)

Retrieve the loss component stored on `state` under `Symbol(loss_name)`
(the counterpart of `pack_loss`).
"""
unpack_loss(state::FluxTraining.PropDict, loss_name) =
    getproperty(state, Symbol(loss_name))
"""
Custom training phase
that exposes the model parameters.
"""
struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end
"""
FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
`step!` function for the custom training phase.
The `runstep` function returns the `state`, but the `state` is also packaged
into `learner.step`, which OnlineStats.means that we can access it later.
This function handles the custom loss function with keyword arguments
and also handles the state, adding the loss components to the `state`.
"""
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
xs, ys = batch
# `runstep` has the signature: runstep(stepfn, learner, phase, batch)
# where `stepfn` has the signature: stepfn(handle, state::PropDict)
FluxTraining.runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
# `stepfn` calls the `_gradient` method which is a fallback
# for different formats of Flux models,
# but requires a lossfn as the first input, which has the signature: lossfn(model)
state.grads = FluxTraining._gradient(learner.optimizer,
learner.model,
learner.params) do model
# get the model outputs
state.ŷs = model(state.xs)
handle(FluxTraining.LossBegin())
# compute the loss using the custom loss function
loss_names, loss_vals = learner.lossfn(state.ŷs, state.ys; m = learner.model)
# pack the loss in the state
state = pack_loss(state, loss_names, loss_vals)
handle(FluxTraining.BackwardBegin())
# return the total loss
return state.loss
end
handle(BackwardEnd())
# update parameters
learner.params, learner.model = FluxTraining._update!(learner.optimizer,
learner.params,
learner.model, state.grads)
end
end
```

The main differences between this and the standard training phase baked into `FluxTraining` are that the loss function returns multiple components of the loss, and that the loss function has the signature `lossfn(y_hat, y; m)`, where `m` is the model.

## The Validation Phase (this is broken)

For reasons of protecting unpublished work, my advisor has asked me to obscure some of the details here, but I am certain that nothing I am leaving out is causing the problem. Thank you for understanding.

Here’s the validation phase code. I’m pretty sure this is working fine.

```
struct MyValidationPhase <: FluxTraining.AbstractValidationPhase end
"""
    FluxTraining.step!(learner, phase::MyValidationPhase, batch)

Validation step: run the forward pass, compute the multi-component loss
(passing the model as keyword `m`), and attach every component to the step
state via `pack_loss`. No gradients are taken.
"""
function FluxTraining.step!(learner, phase::MyValidationPhase, batch)
    inputs, targets = batch
    FluxTraining.runstep(learner, phase, (; xs = inputs, ys = targets)) do _, state
        # Forward pass only — validation takes no gradients.
        state.ŷs = learner.model(state.xs)
        # lossfn returns (names, values); store each component on the state.
        names, values = learner.lossfn(state.ŷs, state.ys; m = learner.model)
        state = pack_loss(state, names, values)
    end
end
```

Here’s the code for the custom eval metric callback that’s causing me problems…

```
"""
    ReconstructionMetric{T} <: FluxTraining.AbstractMetric

Online metric that scores signal reconstruction from the model's weight
matrix during a chosen phase. `T` is the value type tracked by the
`OnlineStats` accumulator.

NOTE(review): the `Any`-typed fields prevent specialization; parametrizing
them would be a performance win but changes the constructor — left as-is.
"""
mutable struct ReconstructionMetric{T} <: FluxTraining.AbstractMetric
decisionfn::Any                      # presumably maps (weights, target repr) -> decision; confirm
decisionfn_kwargs::NamedTuple        # keyword args forwarded to decisionfn
reconstructionfn::Any                # presumably maps (decision, weights) -> reconstruction; confirm
reconstructionfn_kwargs::NamedTuple  # keyword args forwarded to reconstructionfn
statistic::OnlineStats.OnlineStat{T} # running accumulator, refreshed each epoch by reset!
_statistic::Any                      # pristine template copied into `statistic` on reset!
name::Any                            # display name returned by metricname
device::Any                          # NOTE(review): never read in the visible code — confirm needed
P::Any                               # phase type the metric is active for (checked via `isa`)
last::Union{Nothing, T}              # most recent step value; nothing outside phase P
target_signal_repr::Any              # reference signal the reconstruction is compared against
end
"""
    FluxTraining.reset!(metric::ReconstructionMetric)

Restore the running statistic from the pristine template so each epoch
starts accumulating from scratch.
"""
function FluxTraining.reset!(metric::ReconstructionMetric)
    # deepcopy: OnlineStat objects are mutable, so the template must be cloned.
    metric.statistic = deepcopy(metric._statistic)
    return metric.statistic
end
"""
    FluxTraining.step!(metric::ReconstructionMetric, learner, phase)

Per-step metric update. Outside the configured phase `metric.P`, only clears
`metric.last`; inside it, reconstructs the target signal from the model's
weight matrix and accumulates the correlation with the reference signal.

NOTE(review): reading `learner.model.weight` here is what raises the
`ProtectedException` — see the note on `stateaccess` in FluxTraining's
callback docs; `learner` arrives wrapped in `FluxTraining.Protected`.
"""
function FluxTraining.step!(metric::ReconstructionMetric, learner, phase)
    # Guard clause: this metric only runs during its configured phase.
    if !(phase isa metric.P)
        metric.last = nothing
        return
    end
    decision = metric.decisionfn(learner.model.weight,
                                 metric.target_signal_repr;
                                 metric.decisionfn_kwargs...)
    reconstruction = metric.reconstructionfn(decision,
                                             learner.model.weight;
                                             metric.reconstructionfn_kwargs...)
    # Score: correlation between reconstruction and the reference signal.
    metric.last = StatsBase.cor(reconstruction, metric.target_signal_repr)
    OnlineStats.fit!(metric.statistic, metric.last)
end
"""
    FluxTraining.stateaccess(::ReconstructionMetric)

Declare which `Learner` fields this metric wants to read or write.

NOTE(review): FluxTraining queries `stateaccess` on *callbacks*; an
`AbstractMetric` runs inside the built-in `Metrics` callback, whose own
`stateaccess` governs the permissions the protected `learner` carries — so
this method is likely never consulted, which would explain the
`ProtectedException` on `learner.model`. Verify against FluxTraining's
`src/callbacks/metrics.jl` / callback docs; a workaround is to stash what
the metric needs on `state` during `step!` (the `step` field is readable).
"""
function FluxTraining.stateaccess(::ReconstructionMetric)
    return (;
        model = FluxTraining.Read(),
        params = FluxTraining.Read(),
        cbstate = (metricsstep = FluxTraining.Write(),
                   metricsepoch = FluxTraining.Write(),
                   history = FluxTraining.Read()),
        step = FluxTraining.Read(),
    )
end
# Run after the Recorder so step history is up to date when the metric fires.
FluxTraining.runafter(::ReconstructionMetric) = (FluxTraining.Recorder,)
# Per-step value shown in logs; `nothing` when the phase doesn't match.
FluxTraining.stepvalue(metric::ReconstructionMetric) = metric.last
# Display name used in the metrics table.
FluxTraining.metricname(metric::ReconstructionMetric) = metric.name
"""
    FluxTraining.epochvalue(metric::ReconstructionMetric)

Return the accumulated epoch statistic, or `nothing` when the metric was
inactive for the last step (i.e. the phase didn't match).
"""
function FluxTraining.epochvalue(metric::ReconstructionMetric)
    return isnothing(metric.last) ? nothing : OnlineStats.value(metric.statistic)
end
```

Any advice would be appreciated!