Let me try removing PermutedDimsArray and see if that solves the problem. Thank you!!!
UPDATE: I managed to avoid creating a PermutedDimsArray inside the model by permuting the data beforehand, outside the model. The code now runs on the GPU and is about 10x faster than on the CPU. I could probably optimize it further for the GPU, but for now this is good enough. Thank you all for the help!