Sparse loss with Flux

I guess it does update everything because you have a different mask per image in the batch. If you use one mask for all images, you can see that only some rows of the last layer's weight matrix are updated: the rows belonging to masked-out outputs get zero gradient. A batch just combines many such per-image gradients into a single update.
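Both claims can be checked directly on the gradient, without running the optimizer. Below is a minimal sketch that assumes the same all-ones model as the example in this thread. `masked_mse` and `last_weight_grad` are hypothetical helpers written only for this illustration: `masked_mse` averages the squared error over the entries where a per-sample mask `M` (same size as `y`) is nonzero, which is one common way to write a sparse loss.

```julia
using Flux

# Same architecture and all-ones initialisation as the example in this thread
model = Chain(Dense(6, 12; init=ones32), Dense(12, 6; init=ones32))
x = rand(Float32, 6, 2)   # 2 samples, 6 outputs each
y = rand(Float32, 6, 2)

# Hypothetical helper: MSE over the entries selected by M (M has the size of y)
masked_mse(ŷ, y, M) = sum(abs2, (ŷ .- y) .* M) / sum(M)

# Hypothetical helper: gradient of the last layer's 6×12 weight for a 6×2 mask
last_weight_grad(M) =
    Flux.gradient(m -> masked_mse(m(x), y, M), model)[1].layers[2].weight

MASK = [true, false, true, false, true, false]

# Case 1: one mask shared by every sample
g_shared = last_weight_grad(Float32.(repeat(MASK, 1, 2)))
println(all(iszero, g_shared[.!MASK, :]))   # true: masked-out rows get no gradient

# Case 2: a different mask per sample, whose union covers every output
g_per_sample = last_weight_grad(Float32.(hcat(MASK, .!MASK)))
println(all(r -> any(!iszero, r), eachrow(g_per_sample)))   # true: every row is touched
```

With a shared mask the upstream gradient is zero for the same output rows in every sample, so those rows of the weight gradient stay exactly zero. With per-sample masks each row is selected by at least one sample, so summing over the batch touches all of them.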

```julia
using Flux

X = rand(Float32, 6, 10);  # batch of 10 "images" in 1D, 6 pixels each
Y = rand(Float32, 6, 10);

MASK = [true, false, true, false, true, false]  # same mask for all images
train_loader = Flux.DataLoader((X, Y), batchsize = 1)

# All-ones initial weights make it easy to see which entries changed
model = Chain(Dense(6, 12; init=ones32), Dense(12, 6; init=ones32))
opt_state = Flux.setup(Flux.Adam(), model)

for (x, y) in train_loader
    _, grads = Flux.withgradient(model) do m
        y_hat = m(x)
        # Only the masked output rows enter the loss (x and y are 6×1 here;
        # indexing rows with [MASK, :] also works for batchsize > 1)
        Flux.Losses.mse(y_hat[MASK, :], y[MASK, :])
    end
    Flux.update!(opt_state, model, grads[1])
end

model[2].weight

#=

julia> model[2].weight  # only some rows changed
6×12 Matrix{Float32}:
 0.990377  0.990377  0.990377  0.990377  0.990377  …  0.990377  0.990377  0.990377  0.990377
 1.0       1.0       1.0       1.0       1.0          1.0       1.0       1.0       1.0
 0.990374  0.990374  0.990374  0.990374  0.990374     0.990374  0.990374  0.990374  0.990374
 1.0       1.0       1.0       1.0       1.0          1.0       1.0       1.0       1.0
 0.990375  0.990375  0.990375  0.990375  0.990375     0.990375  0.990375  0.990375  0.990375
 1.0       1.0       1.0       1.0       1.0       …  1.0       1.0       1.0       1.0

julia> model[1].weight  # all changed
12×6 Matrix{Float32}:
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847

=#
```
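The rows for the masked-out outputs stay at exactly `1.0` in the output above, not merely close to it, because Adam turns a zero gradient into a zero step: both moment estimates stay at zero, so the update `η * m̂ / (√v̂ + ϵ)` is zero. A toy sketch on a plain 2×3 array (a made-up parameter, not part of the original example), assuming Flux's default `Adam()` learning rate of 0.001:

```julia
using Flux

W = ones(Float32, 2, 3)                 # toy parameter, all ones like the example
opt_state = Flux.setup(Flux.Adam(), W)  # default learning rate 0.001

g = zeros(Float32, 2, 3)
g[1, :] .= 1                            # only the first row has a nonzero gradient

for _ in 1:10
    Flux.update!(opt_state, W, g)
end

println(W[2, :])   # Float32[1.0, 1.0, 1.0]: zero gradient, zero Adam step
println(W[1, :])   # each entry ≈ 0.99: ten steps of roughly 0.001 each
```

With a constant gradient, each Adam step has size close to the learning rate, which roughly matches the ≈0.99 values left in the updated rows after the ten single-image steps of the original example.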