How to force Flux parameters to stay within a certain compact set? (projection)

Hi, I’m trying to implement projected-gradient-descent-style methods.

Following the Flux documentation, I constructed my custom training loop as follows:

function train!(m::ICNN, loss, data, opt)
    ps = Flux.params(m)  # trainable parameters of the model instance
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            return training_loss
        end
        Flux.update!(opt, ps, gs)
        # projection: rebind each layer's Wz to its projected value
        for layer in m.layers
            layer.Wz = project_nonnegative(layer.Wz)  # this rebinding is what fails
        end
    end
end
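
Here project_nonnegative is meant to be a non-mutating elementwise projection onto the nonnegative orthant, i.e. something like:

# Non-mutating projection onto the nonnegative orthant:
# returns a copy of W with negative entries replaced by zero.
project_nonnegative(W) = max.(W, zero(eltype(W)))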

However, the layers are defined as immutable structs, which is why the projection step throws the following error:

ERROR: LoadError: setfield! immutable struct of type ICNN_Layer cannot be changed
Stacktrace:
 [1] setproperty!(::ICNN_Layer, ::Symbol, ::Array{Float64,2}) at ./Base.jl:34
 [2] train!(::ICNN, ::Function, ::DataLoader{Tuple{LinearAlgebra.Adjoint{Float64,Array{Float64,2}},LinearAlgebra.Adjoint{Float64,Array{Float64,2}},LinearAlgebra.Adjoint{Float64,Array{Float64,2}}}}, ::Flux.Optimise.Optimiser) at /home/jinrae/.julia/dev/GliderPathPlanning/src/InputConvexNeuralNetworks.jl:168
 [3] top-level scope at /home/jinrae/.julia/dev/GliderPathPlanning/test/test.jl:81
 [4] include(::String) at ./client.jl:457
 [5] top-level scope at REPL[1]:1
in expression starting at /home/jinrae/.julia/dev/GliderPathPlanning/test/test.jl:81
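
As far as I can tell, the root cause is that the field binding itself is immutable while the array it points to is still mutable, so in-place mutation works where rebinding fails. A minimal sketch (the struct is simplified from my actual layer):

struct ICNN_Layer
    Wz::Matrix{Float64}
end

layer = ICNN_Layer(randn(3, 3))
# layer.Wz = max.(layer.Wz, 0.0)  # ERROR: setfield! immutable struct cannot be changed
layer.Wz .= max.(layer.Wz, 0.0)   # OK: mutates the array the field points to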

What’s the best practical way to project network parameters?

EDIT: the network layers are constructed in a way similar to Flux's Dense layer.

I ended up with the following alternative, which mutates the parameter arrays in place instead of rebinding the fields, but it does not seem to generalise easily to other cases:

function train!(m::ICNN, loss, data, opt)
    ps = Flux.params(m)  # trainable parameters of the model instance
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            return training_loss
        end
        Flux.update!(opt, ps, gs)
        # projection: mutate each layer's Wz in place
        for layer in m.layers
            if isdefined(layer, :Wz)  # some layers do not have a Wz field by design
                project_nonnegative!(layer.Wz)
                if !all(layer.Wz .>= 0.0)
                    error("Projection did not work.")
                end
                end
            end
        end
    end
end

# In-place projection onto the nonnegative orthant: zeroes out the negative entries of ps.
function project_nonnegative!(ps)
    ps .-= ps .* (ps .< 0.0)
end
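
For example, this zeroes out the negative entries of a matrix in place:

W = [1.0 -2.0; -3.0 4.0]
project_nonnegative!(W)
W == [1.0 0.0; 0.0 4.0]  # true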