## Summary

I am looking for some help with implementing a custom validation loop using `FluxTraining`. Briefly, I am using `Flux` to optimize a matrix to satisfy a data-driven cost function. Since the cost function depends on the data, it’s better for me to use an ML framework than a regular optimization package. The model is a single matrix, so in Flux parlance:

```
model = Dense(m => n, identity; bias=false)
```

Due to the structure of my loss function and evaluation metrics, I need to access the weights (the matrix) during computation of the loss and evaluation.

I am running into an error where I am being blocked *by FluxTraining* from accessing the model weights during eval (but not during training).

```
Epoch 1 MyValidationPhase() ...
ERROR: LoadError: FluxTraining.ProtectedException("Read access to Learner.model disallowed.")
Stacktrace:
[1] getfieldperm(data::Learner, field::Symbol, perm::Nothing)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:63
[2] getproperty(protected::FluxTraining.Protected{Learner}, field::Symbol)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:18
[3] step!(metric::ReconstructionMetric{Number}, learner::FluxTraining.Protected{Learner}, phase::MyValidationPhase)
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:260
[4] on(#unused#::FluxTraining.Events.StepEnd, phase::MyValidationPhase, metrics::Metrics, learner::FluxTraining.Protected{Learner})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/metrics.jl:74
[5] _on(e::FluxTraining.Events.StepEnd, p::MyValidationPhase, cb::Metrics, learner::Learner)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
[6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.StepEnd, phase::MyValidationPhase, learner::Learner)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
[7] (::FluxTraining.var"#handlefn#82"{Learner, MyValidationPhase})(e::FluxTraining.Events.StepEnd)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:129
[8] runstep(stepfn::var"#9#10"{Learner}, learner::Learner, phase::MyValidationPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Matrix{Float32}, Matrix{Float32}}})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:134
[9] step!(learner::Learner, phase::MyValidationPhase, batch::Tuple{Matrix{Float32}, Matrix{Float32}})
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:301
[10] (::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}})(#unused#::Function)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:24
[11] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, MyValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::MyValidationPhase)
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:105
[12] epoch!(learner::Learner, phase::MyValidationPhase, dataiter::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}})
@ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22
[13] main()
@ Main ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:449
[14] top-level scope
@ ~/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453
in expression starting at /home/alec/.guild/runs/1d889bb02af249a1a4fcf77833cc2296/train.jl:453
```

## System/Julia information

I am running on Pop!_OS 22.04 with Julia 1.9.0. All code is being run on the CPU.

Here is my Manifest.toml file.

## The Training Phase (this works)

My custom training phase is working fine.

```
"""
    pack_loss(state, loss_names, loss_vals)

Store the total loss and each named loss component on `state`.

The summed loss is written to `state.loss` (the field FluxTraining reads for
logging/backprop bookkeeping), and each component is attached to `state`
under `Symbol(loss_names[i])` so it can be retrieved later with
`unpack_loss`. Returns the mutated `state`.

Throws an `ArgumentError` when the number of names and values disagree,
instead of failing with an opaque `BoundsError` mid-step.
"""
function pack_loss(state::FluxTraining.PropDict, loss_names, loss_vals)
    length(loss_names) == length(loss_vals) ||
        throw(ArgumentError("got $(length(loss_names)) loss names for $(length(loss_vals)) loss values"))
    # `sum` is the idiomatic reduction for the scalar total.
    state.loss = sum(loss_vals)
    for (name, val) in zip(loss_names, loss_vals)
        setproperty!(state, Symbol(name), val)
    end
    return state
end
"""
    unpack_loss(state, loss_name)

Retrieve the loss component stored on `state` under `Symbol(loss_name)`
(the counterpart of `pack_loss`).
"""
unpack_loss(state::FluxTraining.PropDict, loss_name) =
    getproperty(state, Symbol(loss_name))
"""
Custom training phase
that exposes the model parameters.
"""
struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end
"""
FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
`step!` function for the custom training phase.
The `runstep` function returns the `state`, but the `state` is also packaged
into `learner.step`, which OnlineStats.means that we can access it later.
This function handles the custom loss function with keyword arguments
and also handles the state, adding the loss components to the `state`.
"""
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
xs, ys = batch
# `runstep` has the signature: runstep(stepfn, learner, phase, batch)
# where `stepfn` has the signature: stepfn(handle, state::PropDict)
FluxTraining.runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
# `stepfn` calls the `_gradient` method which is a fallback
# for different formats of Flux models,
# but requires a lossfn as the first input, which has the signature: lossfn(model)
state.grads = FluxTraining._gradient(learner.optimizer,
learner.model,
learner.params) do model
# get the model outputs
state.ŷs = model(state.xs)
handle(FluxTraining.LossBegin())
# compute the loss using the custom loss function
loss_names, loss_vals = learner.lossfn(state.ŷs, state.ys; m = learner.model)
# pack the loss in the state
state = pack_loss(state, loss_names, loss_vals)
handle(FluxTraining.BackwardBegin())
# return the total loss
return state.loss
end
handle(BackwardEnd())
# update parameters
learner.params, learner.model = FluxTraining._update!(learner.optimizer,
learner.params,
learner.model, state.grads)
end
end
```

The main differences between this and the standard training phase baked into `FluxTraining` are that the loss function returns multiple components of the loss, and that the loss function has the signature `lossfn(y_hat, y; m)`, where `m` is the model.

## The Validation Phase (this is broken)

For reasons of protecting unpublished work, my advisor has asked me to obscure some of the details here, but I am certain that nothing I am leaving out is causing the problem. Thank you for understanding.

Here’s the validation phase code. I’m pretty sure this is working fine.

```
struct MyValidationPhase <: FluxTraining.AbstractValidationPhase end
"""
    FluxTraining.step!(learner, phase::MyValidationPhase, batch)

Validation step: run the forward pass, compute the multi-component loss
(passing the model as keyword `m`), and attach every component to the step
state via `pack_loss`. No gradients are taken.
"""
function FluxTraining.step!(learner, phase::MyValidationPhase, batch)
    inputs, targets = batch
    FluxTraining.runstep(learner, phase, (; xs = inputs, ys = targets)) do _, state
        # Forward pass only — validation takes no gradients.
        state.ŷs = learner.model(state.xs)
        # lossfn returns (names, values); store each component on the state.
        names, values = learner.lossfn(state.ŷs, state.ys; m = learner.model)
        state = pack_loss(state, names, values)
    end
end
```

Here’s the code for the custom eval metric callback that’s causing me problems…

```
"""
    ReconstructionMetric{T} <: FluxTraining.AbstractMetric

Online metric that scores signal reconstruction from the model's weight
matrix during a chosen phase. `T` is the value type tracked by the
`OnlineStats` accumulator.

NOTE(review): the `Any`-typed fields prevent specialization; parametrizing
them would be a performance win but changes the constructor — left as-is.
"""
mutable struct ReconstructionMetric{T} <: FluxTraining.AbstractMetric
decisionfn::Any                      # presumably maps (weights, target repr) -> decision; confirm
decisionfn_kwargs::NamedTuple        # keyword args forwarded to decisionfn
reconstructionfn::Any                # presumably maps (decision, weights) -> reconstruction; confirm
reconstructionfn_kwargs::NamedTuple  # keyword args forwarded to reconstructionfn
statistic::OnlineStats.OnlineStat{T} # running accumulator, refreshed each epoch by reset!
_statistic::Any                      # pristine template copied into `statistic` on reset!
name::Any                            # display name returned by metricname
device::Any                          # NOTE(review): never read in the visible code — confirm needed
P::Any                               # phase type the metric is active for (checked via `isa`)
last::Union{Nothing, T}              # most recent step value; nothing outside phase P
target_signal_repr::Any              # reference signal the reconstruction is compared against
end
"""
    FluxTraining.reset!(metric::ReconstructionMetric)

Restore the running statistic from the pristine template so each epoch
starts accumulating from scratch.
"""
function FluxTraining.reset!(metric::ReconstructionMetric)
    # deepcopy: OnlineStat objects are mutable, so the template must be cloned.
    metric.statistic = deepcopy(metric._statistic)
    return metric.statistic
end
"""
    FluxTraining.step!(metric::ReconstructionMetric, learner, phase)

Per-step metric update. Outside the configured phase `metric.P`, only clears
`metric.last`; inside it, reconstructs the target signal from the model's
weight matrix and accumulates the correlation with the reference signal.

NOTE(review): reading `learner.model.weight` here is what raises the
`ProtectedException` — see the note on `stateaccess` in FluxTraining's
callback docs; `learner` arrives wrapped in `FluxTraining.Protected`.
"""
function FluxTraining.step!(metric::ReconstructionMetric, learner, phase)
    # Guard clause: this metric only runs during its configured phase.
    if !(phase isa metric.P)
        metric.last = nothing
        return
    end
    decision = metric.decisionfn(learner.model.weight,
                                 metric.target_signal_repr;
                                 metric.decisionfn_kwargs...)
    reconstruction = metric.reconstructionfn(decision,
                                             learner.model.weight;
                                             metric.reconstructionfn_kwargs...)
    # Score: correlation between reconstruction and the reference signal.
    metric.last = StatsBase.cor(reconstruction, metric.target_signal_repr)
    OnlineStats.fit!(metric.statistic, metric.last)
end
"""
    FluxTraining.stateaccess(::ReconstructionMetric)

Declare which `Learner` fields this metric wants to read or write.

NOTE(review): FluxTraining queries `stateaccess` on *callbacks*; an
`AbstractMetric` runs inside the built-in `Metrics` callback, whose own
`stateaccess` governs the permissions the protected `learner` carries — so
this method is likely never consulted, which would explain the
`ProtectedException` on `learner.model`. Verify against FluxTraining's
`src/callbacks/metrics.jl` / callback docs; a workaround is to stash what
the metric needs on `state` during `step!` (the `step` field is readable).
"""
function FluxTraining.stateaccess(::ReconstructionMetric)
    return (;
        model = FluxTraining.Read(),
        params = FluxTraining.Read(),
        cbstate = (metricsstep = FluxTraining.Write(),
                   metricsepoch = FluxTraining.Write(),
                   history = FluxTraining.Read()),
        step = FluxTraining.Read(),
    )
end
# Run after the Recorder so step history is up to date when the metric fires.
FluxTraining.runafter(::ReconstructionMetric) = (FluxTraining.Recorder,)
# Per-step value shown in logs; `nothing` when the phase doesn't match.
FluxTraining.stepvalue(metric::ReconstructionMetric) = metric.last
# Display name used in the metrics table.
FluxTraining.metricname(metric::ReconstructionMetric) = metric.name
"""
    FluxTraining.epochvalue(metric::ReconstructionMetric)

Return the accumulated epoch statistic, or `nothing` when the metric was
inactive for the last step (i.e. the phase didn't match).
"""
function FluxTraining.epochvalue(metric::ReconstructionMetric)
    return isnothing(metric.last) ? nothing : OnlineStats.value(metric.statistic)
end
```

Any advice would be appreciated!