Thanks. This is an issue because `Flux.BatchNorm` doesn't support nested differentiation on GPU. I've opened FluxML/Flux.jl#2154, "`BatchNorm` is not twice-differentiable", to track this.
In the meantime, one thing you can try is copying Flux's `BatchNorm` layer and giving it a different name, so that you don't hit the CUDNN-specific dispatches in Flux.jl's `cudnn.jl`. I'm unsure whether that's sufficient to make things twice-differentiable on GPU, but it should get you further.
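If it helps, here's a rough sketch of that idea, using the hypothetical name `MyBatchNorm`. Since it's a distinct type, Flux's CUDNN-specific methods for `Flux.BatchNorm` never apply, and the same generic array code runs on CPU and GPU. Note this only reimplements the stateless forward pass (no running-statistics tracking or train/test modes), so treat it as a starting point rather than a drop-in replacement:

```julia
using Flux
using Statistics: mean

# Hypothetical stand-in for Flux.BatchNorm under a different name, so the
# CUDNN dispatches for Flux.BatchNorm are never hit.
struct MyBatchNorm{A<:AbstractVector,N<:Real}
    γ::A   # learnable per-channel scale
    β::A   # learnable per-channel shift
    ϵ::N   # numerical-stability constant
end

MyBatchNorm(ch::Integer; ϵ = 1f-5) = MyBatchNorm(ones(Float32, ch), zeros(Float32, ch), ϵ)

Flux.@functor MyBatchNorm (γ, β)   # only γ and β are trainable

function (m::MyBatchNorm)(x::AbstractArray{T,N}) where {T,N}
    dims = ((1:N-2)..., N)                  # reduce over all dims except the channel dim
    μ  = mean(x; dims)
    σ² = mean(abs2.(x .- μ); dims)          # biased batch variance, plain broadcasts only
    shape = ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N)  # broadcast shape for γ, β
    γ, β = reshape(m.γ, shape), reshape(m.β, shape)
    return γ .* (x .- μ) ./ sqrt.(σ² .+ m.ϵ) .+ β
end
```

Then swap `BatchNorm(ch)` for `MyBatchNorm(ch)` in your model and retry the nested `gradient` call; if second derivatives still fail, the remaining blocker is in the AD rules rather than the CUDNN dispatch.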