Hello, I want to train a neural network that improves the pixels of an image, but with a sparse loss function: the true pixel values are known only at a sparse set of positions, and those positions differ for each training example.
How do I write the loss function and the optimiser so that only the right weights are updated?
A minimal 1D example that I think does not do what I want it to do:
using Flux, MLUtils

# Create some random data, including a mask that tells where the target is known
X = rand(Float32, 64, 1000);
Y = rand(Float32, 64, 1000);

# The mask defines the positions in which the label exists.
# Only for these pixels should the corresponding gradients be updated
MASK = rand(Bool, 64, 1000);
train_loader = DataLoader((X, Y, MASK), batchsize = 100)

# Set up a very basic network
model = Chain(Dense(64, 128), Dense(128, 64))
opt_state = Flux.setup(Flux.Adam(), model)

for (x, y, mask) in train_loader
    loss, grads = Flux.withgradient(model) do m
        y_hat = m(x)
        Flux.Losses.mse(y_hat[mask], y[mask])
    end
    # I assume that the gradients are computed for the entire network
    # and not just for the weights that affect mask
    Flux.update!(opt_state, model, grads[1])
end
Is your complaint about the present code that it masks the output rather than the input? Then you might just want y_hat = m(x .* mask).
Or can you clarify what “does not do what I want it to do” means? I believe what’s written will only update some entries of model[2].weight, because the loss does not depend on all of them.
Thank you for your reply. I want to mask the output, not the input. I have the full 2D image that I want to postprocess, but I have the ground truth for only a few pixels.
If my code actually updates only some of the weights because the loss depends on only some of them, my problem is actually solved… I thought that in my code Zygote would not see the mask as it was hidden away, but then it is smarter than I had thought. I will test this more and then mark your answer as the solution.
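To make the point about the mask concrete, here is a small sketch in plain Julia (no Flux needed) that compares the indexed loss from the example with an equivalent "multiply by the mask" form. The toy arrays are made up for illustration, and the sketch assumes that `Flux.Losses.mse` with its default aggregation is the mean of the squared differences. The hand-written gradient with respect to `y_hat` is exactly zero wherever the mask is false, which is why unlabeled pixels cannot push on the weights.

```julia
using Statistics

# Hypothetical toy data: 4 pixels (rows), 3 samples (columns)
y_hat = Float32[0.1 0.2 0.3; 0.4 0.5 0.6; 0.7 0.8 0.9; 1.0 1.1 1.2]
y     = zeros(Float32, 4, 3)
mask  = Bool[1 0 0; 0 1 0; 0 0 0; 1 1 0]

# Form used in the thread: index with the mask, then take the mean squared error
loss_indexed = mean(abs2, y_hat[mask] .- y[mask])

# Equivalent form: zero out unlabeled pixels, normalise by the number of labeled ones
loss_weighted = sum(abs2, (y_hat .- y) .* mask) / count(mask)

# Gradient of that loss with respect to y_hat, derived by hand:
# d/dy_hat [ sum(((y_hat - y) * mask)^2) / n ] = 2 * (y_hat - y) * mask / n
dloss_dyhat = 2f0 .* (y_hat .- y) .* mask ./ count(mask)

println("indexed loss:  ", loss_indexed)
println("weighted loss: ", loss_weighted)
println("gradient is zero at every unlabeled pixel: ", all(iszero, dloss_dyhat[.!mask]))
```

The multiplicative form can be handy when boolean indexing is inconvenient (for example when the array shape has to be preserved), but both forms should give the same gradients.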
I guess it does update everything, as you have a different mask per image in the batch. If you use one mask for all, then you can see that only some rows of the last layer's weight matrix (model[2].weight) are updated. The batch just does many such updates at the same time.
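One way to check this is to look at the gradient of the last layer directly. The sketch below reuses the model from the question and compares a single mask shared by the whole batch with a different random mask per sample. The helper names `last_layer_grad` and `rows_touched` are made up for this illustration, and the sketch assumes the gradient returned for a `Chain` can be read as `grads[1].layers[2].weight`.

```julia
using Flux

model = Chain(Dense(64, 128), Dense(128, 64))
x = rand(Float32, 64, 100)
y = rand(Float32, 64, 100)

# Hypothetical helper: gradient of the masked loss w.r.t. the last layer's weight
function last_layer_grad(model, x, y, mask)
    grads = Flux.gradient(model) do m
        y_hat = m(x)
        Flux.Losses.mse(y_hat[mask], y[mask])
    end
    return grads[1].layers[2].weight
end

# Hypothetical helper: number of rows that contain at least one nonzero entry
rows_touched(g) = count(row -> any(!iszero, row), eachrow(g))

# Case 1: one mask shared by the whole batch (only the first 8 pixels are labeled)
shared = zeros(Bool, 64, 100)
shared[1:8, :] .= true
g_shared = last_layer_grad(model, x, y, shared)

# Case 2: a different random mask for every sample in the batch
per_sample = rand(Bool, 64, 100)
g_per_sample = last_layer_grad(model, x, y, per_sample)

println("shared mask:     ", rows_touched(g_shared), " of 64 rows have a nonzero gradient")
println("per-sample mask: ", rows_touched(g_per_sample), " of 64 rows have a nonzero gradient")
```

With the shared mask, only the 8 rows belonging to labeled pixels should come out nonzero. With per-sample masks, practically every row should be touched, because each output pixel is labeled in at least one sample of the batch. The first layer receives a gradient in both cases, since every hidden unit feeds every output pixel.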