I’m trying to implement a conditional adaptation of the Wasserstein generative adversarial network with gradient penalty (WGAN-GP) training algorithm. However, my function that calculates the gradient penalty doesn’t work on the GPU, although it works on the CPU. I get the following error and stack trace when taking the gradient (it points to backpropagation through the BatchNorm layer):
ERROR: Compiling Tuple{NNlibCUDA.var"##cudnnBNBackward!#71", Nothing, Float32,
Int64, Int64, Float32, Float32, Bool, Bool, typeof(NNlibCUDA.cudnnBNBackward!),
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1,
CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer},
CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer},
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:35
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:121
[3] #Primal#23
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:205 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\reverse.jl:330
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\emit.jl:101
[6] #s2948#1074
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:28 [inlined]
[7] var"#s2948#1074"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote .\none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core .\boot.jl:582
[9] _pullback
@ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:124 [inlined]
[10] _pullback(::Zygote.Context{false}, ::NNlibCUDA.var"#cudnnBNBackward!##kw", ::NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}, ::typeof(NNlibCUDA.cudnnBNBackward!), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[11] _pullback
@ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:115 [inlined]
[12] _pullback(::Zygote.Context{false}, ::NNlibCUDA.var"##∇batchnorm#70", ::Bool, ::Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, ::typeof(NNlibCUDA.∇batchnorm), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[13] _pullback
@ C:\Users\kaoid\.julia\packages\NNlibCUDA\kCpTE\src\cudnn\batchnorm.jl:109 [inlined]
[14] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:17 [inlined]
[15] _pullback(ctx::Zygote.Context{false}, f::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}, args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[16] _pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:206 [inlined]
[17] _pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:232 [inlined]
[18] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#kw_zpullback#45"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}}, args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[19] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:9 [inlined]
[20] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[21] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\cuda\cudnn.jl:6 [inlined]
[22] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[23] macro expansion
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[24] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[25] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(_applychain)), args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[26] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
[27] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[28] macro expansion
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[29] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[30] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(_applychain)), args::Transpose{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[31] _pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
[32] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::Transpose{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[33] _pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:27 [inlined]
[34] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(λ)), args::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[35] _pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45 [inlined]
[36] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#60#61"{typeof(∂(λ))}, args::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[37] _pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97 [inlined]
[38] _pullback(::Zygote.Context{false}, ::typeof(Zygote.gradient), ::var"#391#393"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, var"#placeHolderLoss#392"}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[39] _pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:27 [inlined]
[40] _pullback(::Zygote.Context{false}, ::typeof(gpTerm), ::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity),
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, ::Array{Float32, 4}, ::Array{Float32, 4}, ::Array{Float32, 4}, ::Array{Float32, 4})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[41] _pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:57 [inlined]
[42] _pullback(::Zygote.Context{false}, ::var"#wganGPloss#396"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}}, ::Vector{Float32}, ::Vector{Float32})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[43] _pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:63 [inlined]
[44] _pullback(ctx::Zygote.Context{false}, f::var"#394#397"{var"#wganGPloss#396"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32,
4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32,
2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}, args::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[45] pullback(f::Function, cx::Zygote.Context{false}, args::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity),
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:44
[46] pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:42 [inlined]
.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(leakyrelu), typeof(flatten)}}, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:132
[48] wganGPepoch(metaData::GANmetaData, goal::Symbol, nCritic::Int64)
@ Main c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:62
[49] top-level scope
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\wganGPmwe.jl:107
Here is an MWE:
using Flux, LinearAlgebra, MLUtils, Statistics
"""
    LinearAlgebra.norm(::Nothing, p::Real=2)

Return `false` (which behaves as a numeric zero) as the norm of `nothing`,
whatever the order `p` is.

NOTE(review): this adds a method to a `LinearAlgebra` function for a `Base`
type, i.e. type piracy; acceptable in a throwaway reproducer, but prefer a
locally owned helper in real code.
"""
function LinearAlgebra.norm(::Nothing, p::Real=2)
    return false
end
"""
    gpTermMWE()

Minimal reproducer of one WGAN-GP critic update on toy data.

Builds a toy critic (`Chain(BatchNorm(7))`), random "real"/"fake" topologies and
conditioning channels, interpolates real and fake samples, and differentiates the
WGAN-GP critic loss with respect to the critic. Prints the mean input-gradient
norm, the gradient-penalty term, and the norm of the resulting critic gradient.
Returns `nothing`.

The gradient penalty needs a gradient with respect to the *input* taken inside a
gradient with respect to the *model*, i.e. second-order differentiation.

NOTE: with the critic on the GPU, the reported stack trace shows Zygote trying to
differentiate `NNlibCUDA.cudnnBNBackward!` (the BatchNorm backward pass), which it
rejects with "try/catch is not supported". That failure is expected to persist
with this toy model. The WGAN-GP paper advises against batch normalisation in the
critic; swapping `BatchNorm` for a per-sample normalisation layer is the usual
suggestion — TODO confirm it avoids the cuDNN code path.
"""
function gpTermMWE()
    discriminator = Chain(BatchNorm(7)) |> gpu # toy model

    # toy data, laid out as (height, width, channels, batch)
    realTopology = rand(Float32, (51, 141, 1, 64))
    fakeTopology = rand(Float32, (51, 141, 1, 64))
    conditions = rand(Float32, (51, 141, 6, 64))

    # interpolate fake and true topologies with one random weight per sample
    batchSize = size(realTopology, 4)
    ϵ = reshape(rand(Float32, batchSize), (1, 1, 1, batchSize))
    interpTopos = @. ϵ * realTopology + (1 - ϵ) * fakeTopology

    # prepare discriminator input with interpolated topologies
    discInputInterpGPU = cat(conditions, interpTopos; dims = 3) |> gpu

    # The critic is an explicit argument: `Flux.gradient` only returns gradients
    # for its explicit arguments, so a critic captured from the enclosing scope
    # would silently drop the penalty's contribution to the critic gradient.
    function wganGPloss(critic, discOutReal, discOutFake)
        # `sum` (not `mean`) keeps each sample's input gradient unscaled by 1/batch.
        inputGrad = Flux.gradient(x -> sum(critic(x)), discInputInterpGPU)[1]
        # Per-sample L2 norm over all non-batch dims. The small offset keeps the
        # derivative of `sqrt` finite when a sample's gradient is exactly zero.
        # `nothing` (output independent of the input) is treated as a zero gradient.
        sampleNorms = if isnothing(inputGrad)
            zeros(Float32, 1)
        else
            sqrt.(sum(abs2, inputGrad; dims = (1, 2, 3)) .+ 1.0f-12)
        end
        gradNorm = mean(sampleNorms)
        gpTerm = mean((sampleNorms .- 1) .^ 2)
        println("gpNorm: ", gradNorm, " gpTerm: ", gpTerm)
        # Wasserstein critic loss plus the penalty weighted by 10.
        return mean(discOutFake) - mean(discOutReal) + 10 * gpTerm
    end

    # discriminator inputs
    discInputReal = cat(conditions, realTopology; dims = 3) |> gpu
    discInputFake = cat(conditions, fakeTopology; dims = 3) |> gpu

    discGrad = Flux.gradient(
        critic -> wganGPloss(
            critic,
            critic(discInputReal) |> cpu,
            critic(discInputFake) |> cpu,
        ),
        discriminator,
    )
    # `norm` over the nested gradient relies on a `norm(::Nothing)` method being
    # defined, because structural gradients contain `nothing` entries.
    println("discGradNorm: ", norm(discGrad[1]))
    return nothing
end
# Run the reproducer.
gpTermMWE()
Note that the gradient inside the wganGPloss() function is taken with respect to the input, not the model parameters. After trying many things, I still can’t find what’s wrong.