In Flux.jl, I wanted to modify the behavior of `BatchNorm`, so I changed the code here. The change took effect when the model ran on the CPU, but I noticed that it had no effect on the GPU.
After some investigation, I found that once the model `m` is transferred to the GPU with `m = m |> gpu`, the linked code above is no longer executed.
What is the general procedure for making changes to the Flux.jl source code take effect on the GPU as well? I think the code actually executed on the GPU might be this, but I'm not sure how to modify it because it eventually calls into a C library through `@ccall`.
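For historical reasons, the CUDA path for Flux's batchnorm lives in its own file. That file defines `BatchNorm` methods specialized for `CuArray` inputs, and Julia's multiple dispatch picks those more specific methods over the generic one, so your edit is never reached on the GPU. You'd therefore have to either remove that code or change those methods so that they run your version. In the medium term, we plan to clean up the API layering around norm layers so that all batchnorm layer methods can live in one place.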
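Here is a minimal, self-contained sketch of that dispatch behavior. It uses hypothetical stand-in types rather than Flux or CUDA.jl: `MyNorm` plays the role of `BatchNorm`, `DeviceArray` plays the role of `CuArray`, and `my_normalize` stands for your modified logic. It shows why an edit to the generic method is bypassed for device inputs, how Base's `which` reports the method (and its file and line) that dispatch will choose, and the "change those methods" option of redefining the specialized method so that it calls your logic. On a real model, calling `@which` (from InteractiveUtils) on the `BatchNorm` layer with a GPU input should be the analogous way to find the method that actually runs.

```julia
# Hypothetical stand-ins, not Flux or CUDA.jl APIs:
#   MyNorm      ~ Flux.BatchNorm
#   DeviceArray ~ CUDA.CuArray
struct DeviceArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::DeviceArray) = size(a.data)
Base.getindex(a::DeviceArray, i::Int...) = a.data[i...]

struct MyNorm
    eps::Float64
end

# Your modified normalization logic, written once (hypothetical helper).
function my_normalize(n::MyNorm, x::AbstractArray)
    μ = sum(x) / length(x)
    σ² = sum(abs2, x .- μ) / length(x)
    return (x .- μ) ./ sqrt(σ² + n.eps)
end

# Generic method: the analogue of the code you edited in Flux's main source.
(n::MyNorm)(x::AbstractArray) = my_normalize(n, x)

# More specific method: the analogue of the separate CUDA file. Dispatch
# prefers it for DeviceArray inputs, so the generic method never runs there.
(n::MyNorm)(x::DeviceArray) = zeros(size(x))  # placeholder for a library (e.g. cuDNN) call

bn = MyNorm(1e-5)
x = [1.0, 2.0, 3.0, 4.0]
xd = DeviceArray(x)

cpu_out = bn(x)
gpu_out_before = bn(xd)  # your edit is bypassed: this returns the placeholder zeros

# Ask dispatch which method it will pick for a device input, and where it is defined.
m = which(bn, (typeof(xd),))
println("device input dispatches to: ", m.sig, " at ", m.file, ":", m.line)

# "Change those methods": redefine the specialized method to call your logic.
(n::MyNorm)(x::DeviceArray) = my_normalize(n, x.data)
gpu_out_after = bn(xd)
```

Deleting the specialized method (the "remove this code" option) would have a similar effect in this toy, because the generic method would then handle `DeviceArray` inputs too. In real Flux, though, the generic path on a `CuArray` may not be as fast as the cuDNN-backed one.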