I’m interested in implementing a pixel-wise loss function in Flux that upweights the pixels separating segments from each other, to force the network to learn those pixels preferentially. That is, I want to segment an input grayscale image into a binary mask, while making sure that the edges of the segments are as correct as possible. I’m struggling to see how to pass these weights (which aren’t learned; they are precomputed) to the loss functions.
An example is panel D below (reproduced from the original UNet paper):
These weights are specific to the training data, so I assume you would group them with the ground-truth binary mask (as an additional dimension) and then write a custom loss function that knows, say, that the mask is the first index in that dimension and the weights are the second?
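For concreteness, here is a rough sketch of the kind of thing I’m imagining (assuming Flux’s `logitbinarycrossentropy` with `agg = identity` to get the per-pixel losses; the function names and the normalization by `sum(w)` are my own choices, not anything official):

```julia
using Flux  # assumes Flux.jl's Losses API

# ŷ: raw network logits (H × W × 1 × batch)
# y: ground-truth binary mask, same size as ŷ
# w: precomputed per-pixel weight map, same size as ŷ
function weighted_bce(ŷ, y, w)
    # agg = identity returns the unreduced, per-pixel loss array
    # instead of the mean, so each pixel can be scaled by its weight
    per_pixel = Flux.logitbinarycrossentropy(ŷ, y; agg = identity)
    return sum(w .* per_pixel) / sum(w)
end
```

An alternative to stacking the mask and weights into one array might be to have the data iterator yield `(x, y, w)` triples and close over all three in the loss, which would avoid the loss function needing to know about index conventions, but I’m not sure which is more idiomatic.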
Am I on the right track with my thinking?