Converting PyTorch to Flux while keeping performance

Yes, but note here how you’re calling sum(xs) and not sum(f, xs).
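
To make the distinction concrete, here is a minimal sketch in plain Julia. The function f and the array xs are stand-ins, not taken from the code in this thread. Both forms return the same value, but sum(f, xs) applies f inside the reduction, while sum(xs) on a precomputed array goes through a separate broadcast first. Which one is faster, especially under automatic differentiation, is worth benchmarking on the real model rather than assuming.

```julia
# Hypothetical stand-ins for the function and data under discussion.
f(x) = x^2
xs = Float32[1, 2, 3, 4]

# Two-argument form: f is applied inside the reduction, no temporary array.
s_mapped = sum(f, xs)

# One-argument form: broadcast first (allocates f.(xs)), then reduce.
ys = f.(xs)
s_plain = sum(ys)

@assert s_mapped == s_plain == 30.0f0
```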

You may need to load LoopVectorization (using LoopVectorization) for Tullio to generate a fully optimized kernel. More importantly, I would extract deformation_indexed[:, 1, :] into its own local variable, which could save a lot of compute and memory overhead.
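
A rough sketch of what hoisting the slice could look like, in plain Julia without Tullio. The array shape and the loop body are assumptions made up for illustration, since the original kernel is not shown here. The point is only that an expression like d[:, 1, :] allocates a fresh copy every time it is evaluated, whereas a local variable (here a view) is created once and reused.

```julia
# Assumed shape; the real deformation_indexed is not shown in the thread.
deformation_indexed = rand(Float32, 8, 3, 16)

# Slicing inside the loop allocates a fresh 8×16 copy on every iteration.
function total_repeated(d, n)
    acc = 0.0f0
    for _ in 1:n
        acc += sum(d[:, 1, :])
    end
    return acc
end

# Hoisted: take the slice once (as a non-copying view) and reuse it.
function total_hoisted(d, n)
    d1 = @view d[:, 1, :]
    acc = 0.0f0
    for _ in 1:n
        acc += sum(d1)
    end
    return acc
end

@assert total_repeated(deformation_indexed, 10) ≈ total_hoisted(deformation_indexed, 10)
```

Whether a view or a one-time copy is the better choice depends on the access pattern, so it is worth timing both on the actual kernel.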

Also, what is register? It seems like there is more code here that may affect performance (e.g. if register is a mutable struct), so an MWE would be much appreciated.
