The non-mutating function `Training.apply_gradients` is implemented as follows in Lux.jl:
```julia
function apply_gradients(ts::TrainState, grads)
    optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
    @set! ts.parameters = ps
    @set! ts.optimizer_state = optimizer_state
    @set! ts.step = ts.step + 1
    return ts
end
```
My understanding is that `@set!` modifies the struct in place. How, then, does this still behave as a non-mutating function? The function returns the updated `ts`, yet the `ts` passed into the function remains unchanged.
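To make the question concrete, here is a minimal, hypothetical sketch (not Lux code) of the same pattern using Setfield.jl's `@set!` on a plain immutable struct; the `Counter` type and `bump` function are made up for illustration:

```julia
using Setfield  # provides the @set! macro used in the Lux snippet

struct Counter  # immutable, like TrainState
    step::Int
end

function bump(c::Counter)
    @set! c.step = c.step + 1  # rebinds the local variable `c` to a new Counter
    return c
end

c0 = Counter(0)
c1 = bump(c0)
c0.step  # 0 — the caller's struct is untouched
c1.step  # 1 — only the returned value carries the update
```

This is exactly the behavior I'm asking about: the caller's `c0` is unaffected, even though `@set!` looks like an in-place assignment inside `bump`.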