How can I differentiate a subset of the outputs of a neural network in Flux or Lux?

Is there a way in Flux or Lux to differentiate only a subset of the outputs of a neural network model? The reason I am asking is that I have a target matrix with 6 outputs and 3000 observations (very little data), but unfortunately there are quite a few missing values scattered through the target matrix. So in any given batch I would like to backpropagate the errors from all the known targets while ignoring the missing ones. Is there a way to do that? Happy to provide some dummy code for it if it helps.
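For example, dummy data in that shape could look like this (X, Y, the Float32 element type, and the 10 input features are just placeholder choices):

X = rand(Float32, 10, 3000)                                 # 10 input features, 3000 observations
Y = Matrix{Union{Float32,Missing}}(rand(Float32, 6, 3000))  # 6 targets per observation
Y[rand(eachindex(Y), 2000)] .= missing                      # knock out ~2000 of the 18000 targets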

I guess you’re going to have to write the loss function yourself and let Zygote differentiate it, defining it as the sum of the errors over all non-missing outputs.

I’ll give that a go. Thanks. :slight_smile:

Could you share an example if you succeed?
At first I thought it was trivial, but then I couldn’t find a solution at a glance.

You may be interested in some of the previous discussions here and on GitHub about masking in losses. Those threads should have code examples.
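The basic shape of the masking approach is something like this (a rough sketch of my own, not code from those threads, assuming a target matrix Y with missing entries as above):

mask = .!ismissing.(Y)        # true where the target is known
Y0   = coalesce.(Y, 0f0)      # missing targets replaced by 0 (the mask ensures they are never used)

# Squared error over the known targets only; masked positions contribute
# zero loss and zero gradient.
masked_mse(ŷ) = sum(abs2, (ŷ .- Y0) .* mask) / sum(mask)

Precomputing mask and Y0 outside the loss means the differentiated code is just plain broadcasts, e.g. Flux.gradient(m -> masked_mse(m(X)), model).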

So this seems to work, but my loss function is a horrible hack and will certainly run extremely slowly on a GPU.

using Flux

function loss2(y, ŷ)
    # Squared error for a single entry.
    l(x, x̂) = abs2(x - x̂)
    totloss = 0.0
    for j in 1:size(y, 1)
        for i in 1:size(y, 2)
            # Only known targets contribute; missing entries are skipped.
            if !ismissing(y[j, i])
                totloss += l(y[j, i], ŷ[j, i])
            end
        end
    end
    # Note: this averages over all entries, including the missing ones.
    totloss / length(y)
end

model = Chain(Dense(size(X, 1) => 10), Dense(10 => size(Y, 1)))
opt_state = Flux.setup(AdamW(0.005), model)
@info loss2(Y, model(X))
for e in 1:10
    # Calculate the gradient of the objective
    # with respect to the parameters within the model:
    grads = Flux.gradient(model) do m
        result = m(X)
        loss2(Y, result)
    end
    # Update the parameters so as to reduce the objective,
    # according to the chosen optimisation rule:
    Flux.update!(opt_state, model, grads[1])
    @info loss2(Y, model(X))
end

Did you try something like
Flux.Losses.mse(model(x), y; agg = mean ∘ skipmissing)
?

No, but I tried this:

loss(y, ŷ) = sum(skipmissing(y - ŷ) .^ 2)

which did not work. I will try your suggestion.

I think they are basically the same.
Sorry for the inconvenience, I just can’t try this right now :frowning:
But yeah, give it a try if you don’t mind :>

You could also try something with coalesce, like

loss(ŷ, y) = mse(coalesce.(y, ŷ), ŷ)
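Spelled out (a sketch of my own, reusing model, X, Y, and opt_state from the snippet above and assuming Flux.Losses.mse): coalesce.(y, ŷ) fills every missing target with the corresponding prediction, so those positions get zero error and zero gradient, while the known targets are fitted as usual.

using Flux
using Flux.Losses: mse

# Missing targets are replaced by the predictions themselves, so they add
# zero loss and zero gradient. Note that mse still averages over all
# entries, including the filled-in ones.
loss(ŷ, y) = mse(coalesce.(y, ŷ), ŷ)

grads = Flux.gradient(model) do m
    ŷ = m(X)
    loss(ŷ, Y)
end
Flux.update!(opt_state, model, grads[1])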

So this was precisely what I hoped to achieve with skipmissing. I did not know about coalesce. Very neat. This solves my problem. Thanks guys. :slight_smile: