Thank you for your response. Unfortunately, I cannot go into the details of the data, but I can try to give you a bit more clarity. As mentioned in my original post, "How to take full advantage of GPU Parallelism on Nested Sequential Data in Flux", I have data in the form of sequences of sequences that I would like to analyze. I could probably just concatenate the sub-sequences and process them as a single sequence with one recurrent model. However, I believe the model structure I have chosen will follow the structure of my data much better.
I appreciate your advice to use a more common approach, but my main question now is simply this: how can I process 3D data one 2D slice at a time in a Flux model chain, in a way that still allows the gradient to be computed and processes the columns of each slice in parallel on a GPU?
I believe the MWE in my first post in this thread clearly shows the manipulation I am trying to perform on the 3D data. I would greatly appreciate some direction on how to implement it so that the gradient can be computed and the computation can take advantage of the GPU. From there, I should be able to apply it to my full model chain!
Thank you for your help,
Jack