I’m new to Flux/Zygote and I’m trying to train a model with dense layers whose weight matrices are lower triangular. It seems like the way to do this is to write a custom `train!` function and zero the gradients associated with the “unwanted” entries of the weight matrices, e.g. via broadcast multiplication by an appropriate mask. I can’t quite figure out how to accomplish this, though, since the Zygote gradients appear to be immutable. Here is my current code:
```julia
function custom_train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    for d in data
        train_loss, back = pullback(() -> loss(d...), ps)
        gs = back(one(train_loss))
        # zero all gradients outside of the lower-triangular part of the
        # weight matrix in each dense layer
        for i in 1:2:length(ps)
            nrows, ncols = size(ps[i])
            mask = [x >= y ? 1.0 : 0.0 for x in 1:nrows, y in 1:ncols]
            gs[ps[i]] .*= nograd(mask)
        end
        update!(opt, ps, gs)
        cb()
    end
end
```
This throws the error: `ERROR: Only reference types can be differentiated with …`
Is there some way to “detach” the gradients (e.g. by converting them to a different type), mutate them via the mask, and then convert them back into Zygote gradient objects?
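For context, here is a minimal sketch of the mutation I have in mind. It assumes a recent Zygote in which the `Grads` object returned by `back` maps each parameter to a plain, mutable array, so that outside the pullback no AD machinery is involved and ordinary in-place broadcasting should be legal (the tiny `Chain`, data, and `Descent` optimiser here are just placeholders for illustration):

```julia
using Flux, Zygote
using Flux.Optimise: update!

# toy model and data, purely for illustration
model = Chain(Dense(3 => 3), Dense(3 => 1))
opt   = Descent(0.1)
x, y  = rand(Float32, 3, 8), rand(Float32, 1, 8)
loss() = Flux.mse(model(x), y)

ps = Flux.params(model)
l, back = Zygote.pullback(loss, ps)
gs = back(one(l))

for p in ps
    ndims(p) == 2 || continue  # skip bias vectors, mask only weight matrices
    mask = [i >= j ? 1f0 : 0f0 for i in 1:size(p, 1), j in 1:size(p, 2)]
    gs[p] .*= mask             # mutate the plain gradient array in place
end

update!(opt, ps, gs)
```

If the gradient arrays are indeed ordinary `Array`s, no `nograd`/detach step should be needed at this point, since the mask multiplication happens after the pullback has already run.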
Any advice is very much appreciated, especially if there is a better way to impose arbitrary constraints on weight matrices – I’ve been down a rabbit hole for hours trying to figure out how to make this work…
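One alternative I have been considering is applying the mask to the weights in the forward pass instead of to the gradients: if the layer computes with `W .* mask`, the chain rule zeroes the corresponding gradient entries automatically, and no custom training loop is needed. A sketch of this idea (the `MaskedDense` wrapper and `lowertri_mask` helper are hypothetical names, not Flux API):

```julia
using Flux

# Hypothetical wrapper: applies a fixed 0/1 mask to the weights of an
# inner Dense layer during the forward pass. The mask itself is not trained.
struct MaskedDense{D,M}
    layer::D
    mask::M
end

Flux.@functor MaskedDense (layer,)  # only `layer` carries trainable params

function (m::MaskedDense)(x)
    W = m.layer.weight .* m.mask            # masked entries contribute nothing,
    m.layer.σ.(W * x .+ m.layer.bias)       # so their gradients are zero too
end

# lower-triangular 0/1 mask of the given size
lowertri_mask(n, k) = [i >= j ? 1f0 : 0f0 for i in 1:n, j in 1:k]

m = MaskedDense(Dense(4 => 4), lowertri_mask(4, 4))
```

With this approach the masked entries never influence the output, so the standard `Flux.train!` loop can be used unchanged; the only caveat is that the masked weight entries still occupy memory and receive (zero) optimiser updates.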