Zygote error in backprop through NN

Following up on my last post about implementing WGAN-GP, I’m now running into another problem with gradients. Replacing the discriminator’s BatchNorm() with Metalhead’s ChannelLayerNorm() fixes the inner gradient, but the outer gradient now fails with “Can’t differentiate foreigncall expression”. This is the stacktrace (it points to the (m::ChannelLayerNorm)(x) definition):

ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_svec_ref), Any, svec(Any, Int64), 0, :(:ccall), %15, %13, %12)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace: 
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] Pullback
    @ .\essentials.jl:612 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\tools\builtins.jl:12 [inlined] 
  [5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [6] Pullback
    @ .\reflection.jl:792 [inlined]
  [7] (::typeof(∂(fieldcount)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\kaoid\.julia\packages\ChainRulesCore\C73ay\src\tangent_types\tangent.jl:220 [inlined]
  [9] (::typeof(∂(canonicalize)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [10] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:116 [inlined]
 [11] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:184 [inlined]
 [12] (::typeof(∂(_project)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [13] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:234 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\kaoid\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\stateless.jl:55 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [19] Pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:10 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [21] macro expansion
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [22] Pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [24] Pullback
    @ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [26] Pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:25 [inlined]
 [27] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [28] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45 [inlined]
 [29] (::typeof(∂(λ)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [30] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97 [inlined]
 [31] (::typeof(∂(gradient)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [32] Pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:24 [inlined]
 [33] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [34] Pullback
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:35 [inlined]
 [35] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [36] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
 [37] gradient(f::Function, args::Chain{Tuple{ChannelLayerNorm{Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}, Float32}}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97
 [38] gpTermMWE()
    @ Main c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:34
 [39] top-level scope
    @ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:44

And this is the MWE:

using Flux, LinearAlgebra, MLUtils, Statistics
# treat `nothing` gradients as having zero norm (`false` acts as a strong zero)
LinearAlgebra.norm(::Nothing, p::Real=2) = false

struct ChannelLayerNorm{D, T}
  diag::D
  ϵ::T
end
Flux.@functor ChannelLayerNorm
function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)
  diag = Flux.Scale(1, 1, sz, λ)
  return ChannelLayerNorm(diag, ϵ)
end
(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))

function gpTermMWE()
  discriminator = Chain(ChannelLayerNorm(7)) |> gpu # toy model
  # toy data
  realTopology = rand(Float32, (51, 141, 1, 64))
  fakeTopology = rand(Float32, (51, 141, 1, 64))
  conditions = rand(Float32, (51, 141, 6, 64))
  # interpolate fake and true topologies
  ϵ = reshape(rand(Float32, size(realTopology, 4)), (1, 1, 1, size(realTopology, 4)))
  interpTopos = @. ϵ * realTopology + (1 - ϵ) * fakeTopology
  # prepare discriminator input with interpolated topologies
  discInputInterpGPU = cat(conditions, interpTopos; dims = 3) |> gpu
  function wganGPloss(discOutReal, discOutFake)
    # inner gradient: gradient of the mean discriminator output w.r.t. the interpolated input
    gradNorm = norm(
      Flux.gradient(
        x -> discriminator(x) |> cpu |> mean,
        discInputInterpGPU
      )[1]
    )
    println("gpNorm: ", gradNorm, "  gpTerm: ", (gradNorm - 1) ^ 2)
    return mean(discOutFake) - mean(discOutReal) + 10 * (gradNorm - 1) ^ 2
  end
  # discriminator inputs
  discInputReal = cat(conditions, realTopology; dims = 3) |> gpu
  discInputFake = cat(conditions, fakeTopology; dims = 3) |> gpu
  discGrad = Flux.gradient(
    discriminator -> wganGPloss(
      discriminator(discInputReal) |> cpu,
      discriminator(discInputFake) |> cpu
    ),
    discriminator
  )
  println("discGradNorm: ", norm(discGrad[1]))
end

gpTermMWE()
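
One side note on the MWE itself: it takes a single norm over the entire gradient array, while the WGAN-GP paper penalises the gradient norm per sample and averages over the batch. A plain-Julia sketch of the per-sample version, using a random stand-in `g` for the inner gradient (since that call is what currently errors):

```julia
using LinearAlgebra, Statistics

# Stand-in for the inner gradient w.r.t. the interpolated input (H × W × C × N);
# in the real code this would come from the Flux.gradient call above.
g = randn(Float32, 8, 8, 7, 4)

# One gradient norm per batch element (last dimension)
perSampleNorms = [norm(view(g, :, :, :, n)) for n in 1:size(g, 4)]

# Penalty with λ = 10: batch mean of (‖∇‖ − 1)²
gpTerm = 10 * mean((perSampleNorms .- 1) .^ 2)
```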

Any suggestions?

This looks similar to Taking nested gradient for implementing Wasserstein GAN with gradient penalty (WGAN-GP) on GPU · Issue #1262 · FluxML/Zygote.jl · GitHub. You could give the dev branches mentioned in the last comment a try (they may need rebasing first). If they work, drop a mention on the GH thread and we’ll go from there.

In a new folder, I created a fresh environment and used Pkg to add https://github.com/mcabbott/Zygote.jl#for1262 and https://github.com/mcabbott/ChainRulesCore.jl#nothunk , then added Flux and MLUtils. Running the MWE produced a lot of output in the REPL, followed by the same “Can’t differentiate foreigncall expression” error, but this time it fails already at the inner gradient call. Here’s the stacktrace:

ERROR: LoadError: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] Pullback
    @ .\iddict.jl:114 [inlined]
  [3] (::typeof(∂(pop!)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
  [4] Pullback
    @ .\iddict.jl:131 [inlined]
  [5] (::typeof(∂(delete!)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\lib\base.jl:45 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\kaoid\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [10] Pullback
    @ C:\Users\kaoid\.julia\packages\Functors\V2McK\src\functor.jl:46 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [12] Pullback
    @ C:\Users\kaoid\.julia\packages\Functors\V2McK\src\functor.jl:44 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [14] Pullback
    @ C:\Users\kaoid\.julia\packages\Flux\FKl3M\src\functor.jl:154 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [16] Pullback
    @ .\operators.jl:911 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [18] Pullback
    @ C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:25 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [20] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface.jl:41 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [22] Pullback
    @ C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface.jl:76 [inlined]
 [23] (::typeof(∂(gradient)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [24] Pullback
    @ C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:24 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [26] Pullback
    @ C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:35 [inlined]
 [27] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface2.jl:0
 [28] (::Zygote.var"#70#71"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface.jl:41
 [29] gradient(f::Function, args::Chain{Tuple{ChannelLayerNorm{Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}, Float32}}})
    @ Zygote C:\Users\kaoid\.julia\packages\Zygote\5GQsG\src\compiler\interface.jl:76
 [30] gpTermMWE()
    @ Main C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:34
 [31] top-level scope
    @ C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:44
 [32] include(fname::String)
    @ Base.MainInclude .\client.jl:476
 [33] top-level scope
    @ REPL[2]:1
in expression starting at C:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\gpTermMWEenv\mwe.jl:44

Would Flux.Optimise.ClipNorm() (docs) have the same effect as the original gradient penalty term in WGAN-GP? With it, I’m no longer getting AD problems (although I ran into another issue).
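
For reference, ClipNorm rescales any gradient whose norm exceeds the threshold. A minimal plain-Julia sketch of that rule (my own re-implementation for illustration, not Flux’s code):

```julia
using LinearAlgebra

# Rescale `g` so its 2-norm does not exceed `thresh`; gradients already
# within the threshold pass through unchanged.
clipnorm(g, thresh) = (n = norm(g); n > thresh ? g .* (thresh / n) : g)

g = Float32[3, 4]                            # ‖g‖₂ = 5
clipped = clipnorm(g, 1.0f0)                 # rescaled to unit norm
small = clipnorm(Float32[0.1, 0.1], 1.0f0)   # left unchanged
```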

Good question. I haven’t seen this done or worked through the derivation myself, but if it behaves the way weight decay is equivalent to L2 regularization, then I would assume so?
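
That weight-decay equivalence at least is easy to check numerically: the gradient of the L2 penalty (λ/2)·‖w‖² is λ·w, which is exactly the term weight decay adds to each parameter gradient. A quick plain-Julia check with a central finite difference:

```julia
λ = 0.1
w = [1.0, -2.0, 3.0]

# L2 penalty (λ/2)·‖w‖²; its analytic gradient is λ·w
l2(v) = (λ / 2) * sum(abs2, v)

# Central-difference gradient of the penalty
h = 1e-6
fdgrad = map(eachindex(w)) do i
  e = [j == i ? 1.0 : 0.0 for j in eachindex(w)]
  (l2(w .+ h .* e) - l2(w .- h .* e)) / (2h)
end

isapprox(fdgrad, λ .* w; atol = 1e-6)  # the weight-decay term matches the L2 gradient
```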