My Flux application runs painfully slowly.
My training set has about 250,000 samples with about 100 inputs each, and the model is set up as follows:
Chain(Dense(100,40,relu), Dense(40,40,relu), Dense(40,1,identity))
The mini-batches contain about 6000 samples each, and every mini-batch takes about 15 seconds to train with ADAM or RMSProp.
My colleague wrote a PyTorch version which runs nearly 100 times faster on the same machine.
There must be a bottleneck somewhere, but I can’t seem to find it.
I am using a cross-entropy loss function (slightly modified, but compatible with Zygote).
I timed the evaluation of the loss function over the entire dataset, and it took only about 0.3 seconds.
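To make the comparison concrete, here is a minimal sketch of how the forward pass and the gradient computation could be timed separately on a single mini-batch. It is not my actual code: the data is random, `logitbinarycrossentropy` is only a stand-in for my modified cross-entropy, and the `in => out` layer syntax and the explicit-gradient call assume a recent Flux release.

```julia
using Flux

# Same layer sizes as above, written with the `in => out` syntax (assumes a recent Flux release).
model = Chain(Dense(100 => 40, relu), Dense(40 => 40, relu), Dense(40 => 1))

# Random stand-in data for one mini-batch, laid out as features × samples.
# Float32 matches the default element type of the Flux layer parameters.
x = rand(Float32, 100, 6000)
y = Float32.(rand(Bool, 1, 6000))

# Stand-in loss (assumption): the modified cross-entropy from the post is not shown here.
loss(m, x, y) = Flux.Losses.logitbinarycrossentropy(m(x), y)

# Call everything once first so that compilation time is excluded from the timings.
loss(model, x, y)
grads = Flux.gradient(m -> loss(m, x, y), model)

# Time the forward pass alone, then the forward pass plus the Zygote backward pass.
t_forward  = @elapsed loss(model, x, y)
t_gradient = @elapsed Flux.gradient(m -> loss(m, x, y), model)

println("forward: ", t_forward, " s, forward + backward: ", t_gradient, " s")
```

If the gradient time is far larger than the forward time, the slowdown would be in differentiation rather than in evaluating the loss itself.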
Does anyone have any idea why it should be running so slowly?
Thanks for any hints!