If you have a grasp of how the indexing and contractions for this N-D dense operation should go, Tullio.jl (mcabbott/Tullio.jl on GitHub) is worth a try.
Edit: looking at the implementations of Flux.Dense and torch.nn.Linear, there doesn’t seem to be any reason why you couldn’t pass a higher-dimensional input to the former as well. Have you tried this? The only hurdle I can see is the dimension ordering (row- vs column-wise), though it’s hard to tell since PyTorch may be broadcasting the matmul implicitly. I know you don’t have an easy MWE, but one straightforward way to test this would be
a) to look at the input and output dims before/after/in-between layers in your PyTorch model,
b) replicate and test those layers and inputs (i.e. same shapes and params) in isolation, then
c) repeat with equivalent layers and inputs in Flux (accounting for column-major and batch-last).
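
To make the layout question in step (c) concrete, here is a minimal NumPy sketch that only emulates the two conventions; it does not call PyTorch or Flux. It assumes that torch.nn.Linear contracts over the last axis of its input, and that Flux.Dense contracts over the first axis after flattening the remaining ones (which is how I read its implementation). The sizes and the helper names `linear_features_last` and `dense_features_first` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: batch of 2, sequence of 3, 4 input and 5 output features.
batch, seq, n_in, n_out = 2, 3, 4, 5
W = rng.standard_normal((n_out, n_in))  # weight stored as (out, in)
bias = rng.standard_normal(n_out)


def linear_features_last(x, W, bias):
    """PyTorch-style convention (assumed): contract over the LAST axis of x."""
    return np.einsum("oi,...i->...o", W, x) + bias


def dense_features_first(x, W, bias):
    """Flux-style convention (assumed): contract over the FIRST axis of x.

    Flatten all trailing axes into one, do a single matmul, restore them.
    """
    flat = x.reshape(x.shape[0], -1)      # (in, prod(rest))
    y = W @ flat + bias[:, None]          # (out, prod(rest))
    return y.reshape(W.shape[0], *x.shape[1:])


x_last = rng.standard_normal((batch, seq, n_in))  # (batch, seq, features)
# Reversing the axis order gives the features-first, batch-last layout.
x_first = np.transpose(x_last)                    # (features, seq, batch)

y_last = linear_features_last(x_last, W, bias)    # (batch, seq, out)
y_first = dense_features_first(x_first, W, bias)  # (out, seq, batch)

print(y_last.shape, y_first.shape)
# The two outputs hold the same numbers with the axes reversed.
print(np.allclose(y_last, np.transpose(y_first)))
```

If the real layers behave like this, copying the PyTorch weights into the Flux layer and reversing the input dimensions (e.g. with `permutedims` on the Julia side) should give matching outputs up to floating-point error. If they don't match, comparing the intermediate shapes from step (a) should show where the layouts diverge.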