Hi,
I’m working on implementing a seq2seq model in Flux. Is there a standard method (or, failing that, a suitable approach) for ignoring masked tokens when computing the gradient? E.g. for sentences of differing lengths, ignoring everything after the EOS token.
Cheers
I’m not aware of any existing functionality. If you have some insight into how other frameworks handle this, we could workshop something.
It appears that PyTorch’s loss functions have an ignore_index argument, which lets you specify a target index to exclude from the loss, and therefore from the gradient computation, which is exactly what I want.
I guess it would be relatively straightforward to write my own loss function in Julia/Flux that does the same.
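A minimal sketch of what that could look like, assuming targets are integer class labels in 1:K and a reserved index (here 0, mirroring PyTorch’s ignore_index) marks padded positions; `masked_logitcrossentropy` and its keyword are my own names, not Flux API:

```julia
using Flux
using Flux: onehotbatch, logsoftmax

# Hypothetical helper, not part of Flux: logit cross-entropy that skips
# target positions equal to `ignore_index`, mirroring PyTorch's argument.
# ŷ is a K×N matrix of logits; targets is a length-N integer vector.
function masked_logitcrossentropy(ŷ, targets; ignore_index=0)
    mask = targets .!= ignore_index
    # onehotbatch errors on labels outside 1:K, so substitute a dummy class
    # at ignored positions; the mask zeroes them out afterwards anyway.
    safe = ifelse.(mask, targets, 1)
    y = onehotbatch(safe, 1:size(ŷ, 1))
    per_pos = -sum(y .* logsoftmax(ŷ); dims=1)   # 1×N per-position losses
    sum(per_pos .* mask') / max(sum(mask), 1)    # mean over kept positions
end
```

Because the ignored columns are multiplied by zero before the reduction, they contribute nothing to the loss, so the gradient at those positions is exactly zero.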
Yeah, one of the “problems” of Julia is that it’s often too straightforward to write your own little something once you’re a semi domain expert.
This means absolute beginners find the Julia ecosystem harder to use, because there are fewer canned (although trivial) functionalities.
Honestly this is so true it’s painful. TF/PyTorch are functionally languages unto themselves, and I often feel frozen out of writing custom algorithms because I simply don’t know TF/PyTorch well enough. I’m always left asking: am I writing this correctly?
With Julia/Flux, the code is the math, and everything just works. Love it!
If you do come up with something, feel free to open a PR and we’ll try to find a good home for it.
Cheers,
I am stuck with a somewhat similar problem; I hope someone can advise.
I have tried to mimic Torchvision’s cross-entropy, which accepts a mask of integers as input. The function also has an ignore_class argument, which is useful for excluding classes with little representation that just add noise to the process (typically class 255). Long story short, I came up with an algorithm that masks such pixels out of the loss calculation.
The problem occurs when the training loop is compiled: the loss function uses Flux.onehotbatch(), which breaks during compilation of the gradient. Inference, however, compiles and executes flawlessly.
My good friend Grok AI tells me that this is a known issue, and suggests switching the AD engine to Enzyme. With the new engine, compilation during training breaks for a different reason: an excess of allocations.
Can anyone please confirm the issue with onehotbatch() and Zygote, and perhaps suggest a bypass?
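I can’t speak to your exact setup, but a common workaround when an array construction like onehotbatch trips up the AD engine is to build the one-hot targets outside the differentiated path: the targets carry no gradient, so wrapping that step in ChainRulesCore.ignore_derivatives (or simply precomputing the one-hot matrix before calling Flux.gradient) should be safe. The function and keyword names below are a sketch under those assumptions, not your code:

```julia
using Flux
using ChainRulesCore: ignore_derivatives

# Sketch: build the mask and one-hot targets inside ignore_derivatives,
# so the AD engine never tries to differentiate through onehotbatch.
function masked_loss(model, x, targets; ignore_class=255, classes=0:20)
    ŷ = model(x)                                    # K×N logits
    y, mask = ignore_derivatives() do
        m = targets .!= ignore_class
        safe = ifelse.(m, targets, first(classes))  # dummy class for ignored pixels
        Flux.onehotbatch(safe, classes), m
    end
    per_pos = -sum(y .* Flux.logsoftmax(ŷ); dims=1)
    sum(per_pos .* mask') / max(sum(mask), 1)
end
```

Since the masked positions are multiplied by zero, their gradient comes out exactly zero, which you can check by differentiating with respect to the input.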
Thanks in advance.