Compilation error in Zygote, Flux and CUDA interaction

I’m trying to implement a conditional adaptation of the Wasserstein generative adversarial network with gradient penalty (WGAN-GP) training algorithm, but my function to calculate the gradient penalty isn’t working on the GPU, even though it works on the CPU. I get the following error and stacktrace when taking the gradient (the trace points to backpropagation through a BatchNorm layer):

ERROR: Compiling Tuple{NNlibCUDA.var"##cudnnBNBackward!#71", Nothing, Float32,
Int64, Int64, Float32, Float32, Bool, Bool, typeof(NNlibCUDA.cudnnBNBackward!),
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1,
CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer},
CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer},
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}: try/catch is not supported.    
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:121
  [3] #Primal#23
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\emit.jl:101
  [6] #s2948#1074
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:28 [inlined]
  [7] var"#s2948#1074"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote .\none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:582
  [9] _pullback
    @ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:124 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::NNlibCUDA.var"#cudnnBNBackward!##kw", ::NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}, ::typeof(NNlibCUDA.cudnnBNBackward!), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [11] _pullback
    @ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:115 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::NNlibCUDA.var"##∇batchnorm#70", ::Bool, ::Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, ::typeof(NNlibCUDA.∇batchnorm), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [13] _pullback
    @ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:109 [inlined]
 [14] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:17 [inlined]
 [15] _pullback(ctx::Zygote.Context{false}, f::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}, args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [16] _pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:206 [inlined]
 [17] _pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:232 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#kw_zpullback#45"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}}, args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [19] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:9 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [21] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:6 [inlined]
 [22] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [23] macro expansion
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [24] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [25] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(_applychain)), args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [26] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
 [27] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [28] macro expansion
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [29] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [30] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(_applychain)), args::Transpose{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [31] _pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
 [32] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::Transpose{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [33] _pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:27 [inlined]
 [34] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [35] _pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45 [inlined]
 [36] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#60#61"{typeof(∂(λ))}, args::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [37] _pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97 [inlined]
 [38] _pullback(::Zygote.Context{false}, ::typeof(Zygote.gradient), ::var"#391#393"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, var"#placeHolderLoss#392"}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [39] _pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:27 [inlined]
 [40] _pullback(::Zygote.Context{false}, ::typeof(gpTerm), ::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), 
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, ::Array{Float32, 4}, ::Array{Float32, 4}, ::Array{Float32, 4}, ::Array{Float32, 4})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [41] _pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:57 [inlined]
 [42] _pullback(::Zygote.Context{false}, ::var"#wganGPloss#396"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}}, ::Vector{Float32}, ::Vector{Float32})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [43] _pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:63 [inlined]
 [44] _pullback(ctx::Zygote.Context{false}, f::var"#394#397"{var"#wganGPloss#396"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 
4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 
2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}, args::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [45] pullback(f::Function, cx::Zygote.Context{false}, args::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), 
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:44
 [46] pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:42 [inlined]
 [47] withgradient(f::Function, args::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
     @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:132
 [48] wganGPepoch(metaData::GANmetaData, goal::Symbol, nCritic::Int64)
    @ Main c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:62
 [49] top-level scope
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:107

Here is an MWE:

using Flux, LinearAlgebra, MLUtils, Statistics
# type-piracy hack: treat `nothing` gradients as having zero norm (false == 0)
LinearAlgebra.norm(::Nothing, p::Real = 2) = false

function gpTermMWE()
  discriminator = Chain(BatchNorm(7)) |> gpu # toy model
  # toy data
  realTopology = rand(Float32, (51, 141, 1, 64))
  fakeTopology = rand(Float32, (51, 141, 1, 64))
  conditions = rand(Float32, (51, 141, 6, 64))
  # interpolate fake and true topologies
  ϵ = reshape(rand(Float32, size(realTopology, 4)), (1, 1, 1, size(realTopology, 4)))
  interpTopos = @. ϵ * realTopology + (1 - ϵ) * fakeTopology
  # prepare discriminator input with interpolated topologies
  discInputInterpGPU = cat(conditions, interpTopos; dims = 3) |> gpu
  function wganGPloss(discOutReal, discOutFake)
    # gradient penalty term: norm of the critic's gradient with respect to
    # the interpolated input (not the model parameters)
    gradNorm = norm(
      Flux.gradient(
        input -> discriminator(input) |> cpu |> mean,
        discInputInterpGPU
      )[1]
    )
    println("gpNorm: ", gradNorm, "  gpTerm: ", (gradNorm - 1) ^ 2)
    return mean(discOutFake) - mean(discOutReal) + 10 * (gradNorm - 1) ^ 2
  end
  # discriminator inputs
  discInputReal = cat(conditions, realTopology; dims = 3) |> gpu
  discInputFake = cat(conditions, fakeTopology; dims = 3) |> gpu
  # gradient of the WGAN-GP loss with respect to the discriminator itself
  discGrad = Flux.gradient(
    discriminator -> wganGPloss(
      discriminator(discInputReal) |> cpu,
      discriminator(discInputFake) |> cpu
    ),
    discriminator
  )
  println("discGradNorm: ", norm(discGrad[1]))
end

gpTermMWE()

Note that the gradient inside the wganGPloss() function is taken with respect to the input, not the model parameters. After trying many things, I can’t find what’s wrong.
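For reference, the penalty term wganGPloss() computes is the standard WGAN-GP critic loss from Gulrajani et al. (2017), with λ = 10 as in the code:

L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right], \qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]

The inner Flux.gradient call evaluates the ∇_x̂ D(x̂) factor, which is why a second (nested) differentiation happens when the outer gradient is taken.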

I have a decent idea of what this is, but not the time to make an MWE. Try reducing your model down to a single BatchNorm layer and adding the original gradient penalty (you won’t get the error without the penalty, regardless of which loss function/model you use).

Based on your comment, I edited the post a bit and changed the MWE; it now reproduces the problem with a single BatchNorm layer.

Thanks. This is an issue because Flux.BatchNorm doesn’t support nested differentiation on the GPU. I’ve opened FluxML/Flux.jl#2154 (“`BatchNorm` is not twice-differentiable”) to track this.
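In other words, the failure boils down to taking a gradient of a gradient through a GPU BatchNorm. A distilled sketch of that pattern (my own reconstruction of your MWE, untested as written):

using Flux, CUDA

bn = BatchNorm(3) |> gpu
x = CUDA.rand(Float32, 8, 8, 3, 4)

# outer gradient w.r.t. the model of a loss that itself takes an inner
# gradient w.r.t. the input; the second-order pass tries to differentiate
# through cudnnBNBackward!, whose try/catch Zygote cannot compile
Flux.gradient(bn) do m
  g = Flux.gradient(x -> sum(m(x)), x)[1]
  sum(abs2, g)
end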

In the meantime, one thing you can try is copying Flux’s BatchNorm layer and calling it something else, so that you don’t hit the dispatches in Flux.jl’s src/cuda/cudnn.jl. I’m unsure whether that’s sufficient to make things twice-differentiable on GPU, but it should get you farther.
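A rough sketch of that workaround, assuming a made-up name MyBatchNorm and handling only the training-mode, affine, 4-D case (running statistics and the activation function are omitted):

using Flux, Statistics

# A renamed BatchNorm look-alike: because it is not a Flux.BatchNorm, the
# CuArray dispatch in src/cuda/cudnn.jl is never hit, so the generic
# broadcasting code path is used instead.
struct MyBatchNorm{L}
  bn::L  # wrap a regular Flux.BatchNorm to reuse its parameters
end
MyBatchNorm(chs::Int) = MyBatchNorm(Flux.BatchNorm(chs))
Flux.@functor MyBatchNorm

function (m::MyBatchNorm)(x::AbstractArray{T,4}) where {T}
  bn = m.bn
  # batch statistics over the width, height and batch dimensions
  μ = mean(x; dims = (1, 2, 4))
  σ² = mean(abs2.(x .- μ); dims = (1, 2, 4))
  γ = reshape(bn.γ, 1, 1, :, 1)
  β = reshape(bn.β, 1, 1, :, 1)
  return γ .* (x .- μ) ./ sqrt.(σ² .+ bn.ϵ) .+ β
end

In the MWE above, the toy model would then be Chain(MyBatchNorm(7)) |> gpu.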