I’m new to Flux/Zygote and I’m trying to train a model with dense layers whose weight matrices are constrained to be lower triangular. It seems like the way to do this is to write a custom `train!` function and zero the gradients of the “unwanted” entries of the weight matrices, e.g. by broadcast-multiplying them with the appropriate mask. I can't quite figure out how to do this, though, since the Zygote gradients seem to be immutable. Here is my current code:
```julia
function custom_train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    for d in data
        train_loss, back = pullback(() -> loss(d...), ps)
        gs = back(one(train_loss))
        # zero all gradients outside of lower triangular weight matrices in each dense layer
        for i in 1:2:length(ps)
            nrows, ncols = size(ps[i])
            mask = [x >= y ? 1.0 : 0. for x in 1:nrows, y in 1:ncols]
            gs[ps[i]] .*= nograd(mask)
        end
        update!(opt, ps, gs)
        cb()
    end
end
```
This throws the error:

```
ERROR: Only reference types can be differentiated with Params
```
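For reference, the mask comprehension in the loop keeps entries whose row index is at least the column index, i.e. the diagonal and everything below it. Flux's `Dense` stores its weight with size (out, in), so this is the lower triangle of the weight as used in `W * x`. A small check, using only Base and LinearAlgebra (`lowertri_mask` is a hypothetical name for the same comprehension), suggests that `LinearAlgebra.tril` builds the same mask and that ordinary arrays can be masked in place:

```julia
using LinearAlgebra

# Hypothetical helper: the comprehension from the training loop, wrapped in a function.
# Keeps entries with row index x >= column index y (on or below the diagonal).
lowertri_mask(nrows, ncols) = [x >= y ? 1.0 : 0.0 for x in 1:nrows, y in 1:ncols]

M = lowertri_mask(3, 4)     # mask for a (out, in) = (3, 4) weight
g = ones(3, 4)              # stand-in for that weight's gradient
g .*= M                     # an ordinary Array can be masked in place

# tril(ones(3, 4)) gives the same 0/1 lower-trapezoidal mask
println(M == tril(ones(3, 4)))
```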
Is there some way to “detach” the gradients (e.g. by converting them to a new type), mutate them with the mask, and then convert them back into `Zygote.Grads` objects?
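A minimal sketch, assuming Zygote's implicit-parameter API (`Zygote.Params`, `Zygote.pullback`) and a single made-up weight `W` with a toy loss: indexing the `Grads` object with a parameter should return that parameter's gradient as an ordinary `Array`, which can be mutated in place, so detaching and converting back may not be needed at all.

```julia
using Zygote, LinearAlgebra

W = randn(3, 3)              # toy weight (made up for illustration)
x = randn(3)                 # toy input
ps = Zygote.Params([W])

train_loss, back = Zygote.pullback(() -> sum(abs2, W * x), ps)
gs = back(one(train_loss))   # a Zygote.Grads, indexed by the parameter arrays

g = gs[W]                    # an ordinary Matrix{Float64}, not a frozen copy
g .*= tril(ones(size(W)))    # zero the entries above the diagonal, in place

# gs[W] returns the same array object, so the masked gradient is what an
# optimiser step would see afterwards.
println(gs[W] === g)
```

One caveat with masking only the gradients: it keeps entries that are already zero at zero, but it does not create those zeros. Flux's default initialisation fills the whole weight matrix, so the weights would presumably also need a one-off projection (e.g. `tril!(W)`) before training.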
Any advice is very much appreciated, especially if there is a better way to impose arbitrary constraints on weight matrices. I’ve been down a rabbit hole trying to figure out how to make this work for hours…
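One alternative is to apply the mask in the forward pass, `(W .* M) * x`, instead of editing gradients afterwards. Entries of `W` above the diagonal then never affect the loss, so their gradient is exactly zero. The sketch below assumes Zygote's explicit `Zygote.gradient` and a hand-written SGD step instead of Flux's `train!`/`update!`; the names `M`, `loss`, and `η` and the toy data are illustrative assumptions, not Flux API.

```julia
using Zygote, LinearAlgebra

W = randn(4, 3)              # Dense-style weight of size (out, in)
b = zeros(4)
M = tril(ones(size(W)))      # fixed 0/1 lower-triangular mask (any 0/1 pattern works)
W .*= M                      # optional: only the masked product W .* M is ever used

xs = [randn(3) for _ in 1:8] # toy data (made up for illustration)
ys = [randn(4) for _ in 1:8]

# The mask is applied inside the forward pass, so entries above the diagonal
# never influence the loss and Zygote returns a zero gradient for them.
loss(W, b, x, y) = sum(abs2, (W .* M) * x .+ b .- y)

η = 0.01                     # learning rate for a plain SGD step
for (x, y) in zip(xs, ys)
    gW, gb = Zygote.gradient((W, b) -> loss(W, b, x, y), W, b)
    W .-= η .* gW            # upper-triangular entries receive zero updates
    b .-= η .* gb
end

println(istril(W))
```

Because the effective weight is always `W .* M`, the constraint holds in the forward pass even if `W` itself were never initialised as lower triangular. The same pattern should extend to other fixed sparsity patterns by changing `M`.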