Custom train function in SimpleChains.jl

Hi SimpleChains team!

I’m currently using SimpleChains.jl and want to write my own training function, based on the custom training-loop template from the Flux documentation:

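```julia
using Flux                     # implicit-parameter (`params`) training style
using Flux.Optimise: update!

# `m` (model), `loss`, `opt` and `training_set` are assumed to be defined already.
for d in training_set
    # Our super logic
    gs = gradient(params(m)) do
        l = loss(d...)
    end
    update!(opt, params(m), gs)
end
```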

Using SimpleChains, I want to add some simple logic that discards certain examples during training, so that no gradient is computed for them at all. Is there a specific method for this?
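One way to get this without touching SimpleChains internals is to filter the data before it reaches the training call, so discarded examples never take part in a forward or backward pass. Below is a minimal pure-Julia sketch, assuming one example per column of `X`; `select_examples` and the "drop label 0" rule are hypothetical, for illustration only.

```julia
"""
    select_examples(keep, X, y)

Return the columns of `X` (one example per column) and the matching entries
of `y` for which `keep(x, yi)` returns `true`.
"""
function select_examples(keep, X::AbstractMatrix, y::AbstractVector)
    size(X, 2) == length(y) ||
        throw(DimensionMismatch("X has $(size(X, 2)) columns but y has $(length(y)) entries"))
    mask = [keep(view(X, :, i), y[i]) for i in axes(X, 2)]   # one Bool per example
    return X[:, mask], y[mask]                               # logical indexing copies the kept data
end

# Hypothetical rule: drop every example whose label is 0.
X = Float32[1 2 3 4; 5 6 7 8]    # 2 features × 4 examples
y = [1, 0, 2, 0]
Xk, yk = select_examples((x, yi) -> yi != 0, X, y)
```

The filtered `Xk` and `yk` can then be used like any other data set, e.g. wrapped in a loss with `SimpleChains.add_loss` and passed to `SimpleChains.train_batched!`.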

I have already implemented this logic by modifying the code of `train_unbatched!`, but I would like to get the same behavior with `train_batched!`. However, `train_batched!` is considerably more complex, and I’m finding it hard to make the necessary changes.
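If the discard rule depends on the current model (for example, skipping examples with a very large error), a possible alternative to editing `train_batched!` is to re-filter between calls: run a forward pass, build a mask, and call `train_batched!` on the kept columns only. This is a sketch under assumptions: the toy data, the `cutoff` rule and the epoch loop are hypothetical, and I have not verified whether the ADAM state carries over between separate `train_batched!` calls.

```julia
using SimpleChains

# Toy regression data (illustrative only): 2 features, 1 output, 2_000 examples.
X = randn(Float32, 2, 2_000)
Y = sum(abs2, X; dims = 1) .+ 0.1f0 .* randn(Float32, 1, 2_000)   # 1×N target matrix

sc  = SimpleChain(static(2), TurboDense(tanh, 16), TurboDense(identity, 1))
p   = SimpleChains.init_params(sc)
G   = SimpleChains.alloc_threaded_grad(SimpleChains.add_loss(sc, SquaredLoss(Y)))
opt = SimpleChains.ADAM(1f-3)

cutoff = 3.0f0   # hypothetical rule: skip examples whose current error exceeds this
for epoch in 1:10
    pred = Array(sc(X, p))                   # forward pass only, no gradient computed
    keep = vec(abs.(pred .- Y) .<= cutoff)   # one Bool per example (column)
    count(keep) >= 64 || break               # stop if too few examples remain
    lossk = SimpleChains.add_loss(sc, SquaredLoss(Y[:, keep]))
    # One minibatch pass over the kept columns; discarded columns get no gradient.
    SimpleChains.train_batched!(G, p, lossk, X[:, keep], opt, 1; batchsize = 64)
end
```

The forward pass used for filtering costs much less than a gradient evaluation, and the batching logic inside `train_batched!` stays untouched.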

I’ve also noticed that `train_batched!` gives significantly better results (lower loss, higher accuracy) than `train_unbatched!`. If writing an explicit training function isn’t possible, is there a way to improve `train_unbatched!` so that it performs comparably to `train_batched!`?
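My understanding, which is an assumption not checked against the SimpleChains source, is that `train_unbatched!` takes one optimizer step per iteration using the whole data set, while a batched pass takes many smaller steps on shuffled minibatches, which would explain the gap for a fixed iteration count. A hand-written loop can reproduce that structure and apply the discard rule at the same time. Here is a minimal pure-Julia sketch of the index bookkeeping; `filtered_minibatches` is a hypothetical helper.

```julia
using Random

"""
    filtered_minibatches(keep, N, batchsize; rng)

Shuffle the example indices `1:N`, drop those for which `keep(i)` is false,
and split the rest into minibatches of at most `batchsize` indices.
"""
function filtered_minibatches(keep, N::Integer, batchsize::Integer;
                              rng::AbstractRNG = Random.default_rng())
    batchsize > 0 || throw(ArgumentError("batchsize must be positive"))
    idx = filter(keep, randperm(rng, N))   # shuffled indices of the kept examples
    return [idx[i:min(i + batchsize - 1, length(idx))] for i in 1:batchsize:length(idx)]
end

batches = filtered_minibatches(iseven, 10, 3)   # kept indices 2, 4, 6, 8, 10 in random order
```

Each index vector `b` selects one minibatch, `X[:, b]`, plus its targets. How each step is then applied (for example one `SimpleChains.train_unbatched!` call per batch) is left open, since optimizer state between calls would need checking.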

Any guidance or suggestions would be greatly appreciated!
Thank you! 🙂