Mutate Zygote Gradients with a Custom Mask Before Update?

I’m new to Flux/Zygote and I’m trying to train a model with dense layers such that the associated weight matrices are lower triangular. It seems like the way to do this is to write a custom `train!` function and zero the gradients associated with the “unwanted” entries of the weight matrices, e.g. via broadcast multiplication by an appropriate mask. I can’t quite figure out how to accomplish this, though, since it seems like the Zygote gradients are immutable. Here is my current code:

```julia
function custom_train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    for d in data
        train_loss, back = pullback(() -> loss(d...), ps)
        gs = back(one(train_loss))
        # zero all gradients outside the lower triangular part of each dense layer's weight matrix
        for i in 1:2:length(ps)
            nrows, ncols = size(ps[i])
            mask = [x >= y ? 1.0 : 0.0 for x in 1:nrows, y in 1:ncols]
            gs[ps[i]] .*= nograd(mask)
        end
        update!(opt, ps, gs)
        cb()
    end
end
```

This throws the error: `ERROR: Only reference types can be differentiated with Params`

Is there some way to “detach” the gradients (e.g. via converting them to a new type), mutating them via the mask, and then converting them back to Zygote.gradient objects?

Any advice is very much appreciated, especially if there is a better way to impose arbitrary constraints on weight matrices – I’ve been down a rabbit hole for hours trying to figure out how to make this work…

There’s no need to do that for the code you have above, because the only place gradients are “tracked” is inside the closure passed to `pullback`. In other words, there’s no such thing as a `Zygote.gradient` object: `gs` is just a dictionary-like collection of plain arrays, which you can mutate freely 🙂.
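To illustrate, here is a minimal sketch of masking the returned gradients in place (the toy model, data, and optimiser are made up for illustration; it assumes Flux `Dense` layers, so the implicit parameters alternate weight matrix, bias vector):

```julia
using Flux, Zygote
using Flux.Optimise: Descent, update!

model = Chain(Dense(4, 4), Dense(4, 2))
ps = Flux.params(model)
x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)
opt = Descent(0.1)

loss() = Flux.Losses.mse(model(x), y)
train_loss, back = Zygote.pullback(loss, ps)
gs = back(one(train_loss))

# gs[p] is an ordinary Array, so it can be masked in place -- no detaching needed
for p in ps
    ndims(p) == 2 || continue  # skip bias vectors
    nrows, ncols = size(p)
    mask = [r >= c ? 1f0 : 0f0 for r in 1:nrows, c in 1:ncols]
    gs[p] .*= mask
end
update!(opt, ps, gs)
```

Note that masking the gradient only stops the upper-triangular entries from *changing*; if the weights aren’t initialised lower triangular to begin with, you’d also want to mask the weights themselves once before training.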

With that in mind, the error has to be occurring within the call stack of loss somewhere. If you share a complete MWE with the loss and all the auxiliary training code, we can troubleshoot that part.