The `Training.apply_gradients` non-mutating function is implemented as follows in Lux.jl:
```julia
function apply_gradients(ts::TrainState, grads)
    optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
    @set! ts.parameters = ps
    @set! ts.optimizer_state = optimizer_state
    @set! ts.step = ts.step + 1
    return ts
end
```
My understanding is that `@set!` modifies the struct in place. But how does the function still behave as non-mutating? The output of the function is the updated `ts`, yet the `ts` passed into the function remains unchanged.