Pixelwise loss weights in Flux?

Hi all,

I’m interested in implementing a pixel-wise loss function in Flux where I up-weight the pixels that separate segments from each other, to force the network to learn those pixels preferentially. That is, I want to segment an input grayscale image into a binary mask, but I want the edges of the segments to be as correct as possible. I’m struggling to see how to pass these weights (which aren’t learned; they are precomputed) to the loss functions.

An example is panel (d) below (reproduced from the original U-Net paper):

These weights are specific to the training data, so I assume you would group them with the ground-truth binary mask (as an additional dimension) and then write a custom loss function that knows, for example, that the mask is the first slice along that dimension and the weights are the second.
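For concreteness, the arrangement I have in mind looks something like this (sizes are made up):

```julia
# Pack the ground-truth mask and the precomputed pixel weights along
# an extra slice of dimension 3, giving a W×H×2×N target array.
W, H, N = 64, 64, 8                     # made-up sizes
mask    = rand(Bool, W, H, 1, N)        # binary ground-truth mask
weights = rand(Float32, W, H, 1, N)     # precomputed per-pixel weights

y = cat(Float32.(mask), weights; dims=3)   # W×H×2×N

# A custom loss would then slice them back apart:
ymask = y[:, :, 1:1, :]                 # the mask
w     = y[:, :, 2:2, :]                 # the weights
```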

Am I on the right track with my thinking?


For the pixelwise binary classification task, the final output of the network before the sigmoid will be a WxHx1xN array; let’s call it ypred. Arrange the ground-truth labels into a WxHx1xN array y, and arrange the precomputed weight maps from Eqs. 1 and 2 of the U-Net paper into a WxHx1xN array w of the same shape. The unweighted loss is logitbinarycrossentropy(ypred, y), which you’ll find in the Flux.Losses module; the weights then multiply the elementwise loss before it is averaged, which you can do through the agg keyword of that function.
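Concretely, here is a sketch using only the standard library (the names `elementwise_bce` and `weighted_bce` are made up for illustration). Flux’s `logitbinarycrossentropy` computes the same numerically stable elementwise formula, and its `agg` keyword lets you fold the weights in as `agg = ℓ -> mean(w .* ℓ)`:

```julia
using Statistics: mean

# Numerically stable elementwise binary cross-entropy on a logit x
# against a 0/1 label y: max(x, 0) - x*y + log(1 + exp(-|x|)).
elementwise_bce(x, y) = max(x, zero(x)) - x * y + log1p(exp(-abs(x)))

# Weighted pixelwise loss: scale each pixel's loss by its precomputed
# weight before averaging. With Flux.Losses this is equivalent to
# logitbinarycrossentropy(ypred, y; agg = ℓ -> mean(w .* ℓ)).
weighted_bce(ypred, y, w) = mean(w .* elementwise_bce.(ypred, y))

W, H, N = 4, 4, 2                          # toy sizes
ypred = randn(Float32, W, H, 1, N)         # logits from the network
y     = Float32.(rand(Bool, W, H, 1, N))   # ground-truth mask
w     = fill(2.0f0, W, H, 1, N)            # precomputed weight map

loss = weighted_bce(ypred, y, w)
```

With a uniform weight map of ones this reduces exactly to the unweighted loss, which is a handy sanity check.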

If you have pixelwise multiclass classification instead, you will work with WxHxCxN arrays, with C the number of classes, and the loss is logitcrossentropy(ypred, y; dims=3).
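To illustrate the multiclass layout, here is a stdlib-only sketch (the names `logsoftmax3` and `xent` are invented here) mirroring what `logitcrossentropy(ypred, y; dims=3)` computes with its default `agg = mean`:

```julia
using Statistics: mean

# Hypothetical sizes: C classes, one-hot encoded along dimension 3.
W, H, C, N = 4, 4, 3, 2
labels = rand(1:C, W, H, N)             # integer class per pixel
y = zeros(Float32, W, H, C, N)
for c in 1:C
    y[:, :, c, :] .= labels .== c       # one-hot targets
end

# Log-softmax over the class dimension (dims=3), numerically stable.
function logsoftmax3(x)
    m = maximum(x; dims=3)
    x .- m .- log.(sum(exp.(x .- m); dims=3))
end

# Mean over pixels of -sum(y .* log p): the quantity logitcrossentropy
# with dims=3 reduces with its default agg = mean.
xent(ypred, y) = mean(sum(-y .* logsoftmax3(ypred); dims=3))

ypred = randn(Float32, W, H, C, N)      # raw network logits
loss  = xent(ypred, y)
```

When the logits strongly favor the correct class at every pixel, the loss goes to zero, which is an easy way to check the class dimension is wired up correctly.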

You will find an implementation of U-Net here: https://github.com/DhairyaLGandhi/UNet.jl