I guess it does update everything, because you have a different mask per image in the batch. If you use one mask for all images, you can see that only some rows of the last layer's weight matrix are updated; a batch just performs many such single-image updates at the same time. (A sketch of the per-image-mask case follows the output below.)
using Flux

X = rand(Float32, 6, 10);  # batch of 10 "images" in 1D
Y = rand(Float32, 6, 10);
MASK = [true, false, true, false, true, false];  # same mask for all images

train_loader = Flux.DataLoader((X, Y), batchsize = 1)

# all-ones init makes it easy to spot which entries change
model = Chain(Dense(6, 12; init=ones32), Dense(12, 6; init=ones32))
opt_state = Flux.setup(Flux.Adam(), model)

for (x, y) in train_loader
    _, grads = Flux.withgradient(model) do m
        y_hat = m(x)
        Flux.Losses.mse(y_hat[MASK], y[MASK])  # loss only sees the masked entries
    end
    Flux.update!(opt_state, model, grads[1])
end
model[2].weight
#=
julia> model[2].weight # only some rows changed
6×12 Matrix{Float32}:
 0.990377  0.990377  0.990377  0.990377  0.990377  …  0.990377  0.990377  0.990377  0.990377
 1.0       1.0       1.0       1.0       1.0          1.0       1.0       1.0       1.0
 0.990374  0.990374  0.990374  0.990374  0.990374     0.990374  0.990374  0.990374  0.990374
 1.0       1.0       1.0       1.0       1.0          1.0       1.0       1.0       1.0
 0.990375  0.990375  0.990375  0.990375  0.990375     0.990375  0.990375  0.990375  0.990375
 1.0       1.0       1.0       1.0       1.0       …  1.0       1.0       1.0       1.0
julia> model[1].weight # all changed
12×6 Matrix{Float32}:
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
 0.990828  0.99101  0.9915  0.991691  0.99049  0.990847
=#
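Why the even rows stay at exactly 1.0: each row of the last Dense weight feeds exactly one output, and the masked-out outputs never enter the loss, so those rows (and the matching bias entries) receive an exactly-zero gradient, which Adam turns into a zero update. A quick way to check this (my sketch, taking the gradient over the whole batch in one call, and assuming the usual field layout Zygote reports for a Chain):

_, grads = Flux.withgradient(m -> Flux.Losses.mse(m(X)[MASK, :], Y[MASK, :]), model)
grads[1].layers[2].weight  # expect rows 2, 4 and 6 to be all zeros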
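And to see the first point in action, here is a minimal variation (an untested sketch) that draws a fresh random mask for each image; different rows get selected across the batch, so every row of model[2].weight should end up updated:

using Flux

X = rand(Float32, 6, 10);
Y = rand(Float32, 6, 10);
model = Chain(Dense(6, 12; init=ones32), Dense(12, 6; init=ones32))
opt_state = Flux.setup(Flux.Adam(), model)

for (x, y) in Flux.DataLoader((X, Y), batchsize = 1)
    mask = rand(Bool, 6)       # a different mask per image
    any(mask) || continue      # skip the unlikely all-false mask
    _, grads = Flux.withgradient(model) do m
        Flux.Losses.mse(m(x)[mask], y[mask])
    end
    Flux.update!(opt_state, model, grads[1])
end

model[2].weight  # most likely every row has now moved away from 1.0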