Following my last post about implementing WGAN-GP, now I’m running into another problem with gradients. Changing the discriminator’s BatchNorm()
to Metalhead’s ChannelLayerNorm()
solves the issue with the inner gradient. But now I’m getting a “Can’t differentiate foreigncall expression” on the outer gradient. This is the stacktrace (points to (m::ChannelLayerNorm)(x)
definition):
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_svec_ref), Any, svec(Any, Int64), 0, :(:ccall), %15, %13, %12)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:35
[2] Pullback
@ .\essentials.jl:612 [inlined]
[3] (::typeof(∂(getindex)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[4] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\tools\builtins.jl:12 [inlined]
[5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[6] Pullback
@ .\reflection.jl:792 [inlined]
[7] (::typeof(∂(fieldcount)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[8] Pullback
@ C:\Users\kaoid\.julia\packages\ChainRulesCore\C73ay\src\tangent_types\tangent.jl:220 [inlined]
[9] (::typeof(∂(canonicalize)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[10] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:116 [inlined]
[11] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\chainrules.jl:184 [inlined]
[12] (::typeof(∂(_project)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[13] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:234 [inlined]
[14] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[15] Pullback
@ C:\Users\kaoid\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
[16] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[17] Pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\stateless.jl:55 [inlined]
[18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[19] Pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:10 [inlined]
[20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[21] macro expansion
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[22] Pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:53 [inlined]
[23] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[24] Pullback
@ C:\Users\kaoid\.julia\packages\Flux\ZdbJr\src\layers\basic.jl:51 [inlined]
[25] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[26] Pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:25 [inlined]
[27] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[28] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45 [inlined]
[29] (::typeof(∂(λ)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[30] Pullback
@ C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97 [inlined]
[31] (::typeof(∂(gradient)))(Δ::Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[32] Pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:24 [inlined]
[33] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[34] Pullback
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:35 [inlined]
[35] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[36] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
[37] gradient(f::Function, args::Chain{Tuple{ChannelLayerNorm{Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}, Float32}}})
@ Zygote C:\Users\kaoid\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:97
[38] gpTermMWE()
@ Main c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:34
[39] top-level scope
@ c:\Users\kaoid\My Drive\Estudo\Poli\Pesquisa\Programas\QuickTO\QuickTO\mwe.jl:44
And this is the MWE:
using Flux, LinearAlgebra, MLUtils, Statistics
# Workaround so `norm` tolerates `nothing` gradients (treated as a zero norm).
# NOTE(review): this is type piracy — it extends `LinearAlgebra.norm` on a
# type the author does not own; fine for an MWE, avoid in library code.
function LinearAlgebra.norm(::Nothing, p::Real = 2)
    return false
end
"""
    ChannelLayerNorm{D, T}

Channel-wise layer-norm container: `diag` holds the learnable scale/bias
layer and `ϵ` the numerical-stability constant used during normalisation.
"""
struct ChannelLayerNorm{D, T}
    diag::D
    ϵ::T
end
# Register the struct with Flux so its fields are treated as trainable
# parameters and are moved by `gpu`/`cpu`.
Flux.@functor ChannelLayerNorm
"""
    ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)

Build a `ChannelLayerNorm` over `sz` channels, wrapping a `Flux.Scale`
layer (with activation `λ`) as the learnable affine transform.
"""
function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)
    affine = Flux.Scale(1, 1, sz, λ)
    return ChannelLayerNorm(affine, ϵ)
end
# Normalise over the channel dimension (`ndims(x) - 1`), then apply the
# learnable scale/bias. NOTE(review): the outer-gradient stacktrace points at
# this method; the second-order failure appears to arise while differentiating
# through `Flux.normalise` — confirm against the Zygote limitations docs.
function (m::ChannelLayerNorm)(x)
    normed = Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ)
    return m.diag(normed)
end
# Minimal reproduction of the WGAN-GP failure: the discriminator loss itself
# contains a nested `Flux.gradient` (the gradient penalty), so the outer
# `Flux.gradient` call performs second-order differentiation through Zygote.
function gpTermMWE()
discriminator = Chain(ChannelLayerNorm(7)) |> gpu # toy model
# toy data
realTopology = rand(Float32, (51, 141, 1, 64))
fakeTopology = rand(Float32, (51, 141, 1, 64))
conditions = rand(Float32, (51, 141, 6, 64))
# interpolate fake and true topologies
ϵ = reshape(rand(Float32, size(realTopology, 4)), (1, 1, 1, size(realTopology, 4)))
interpTopos = @. ϵ * realTopology + (1 - ϵ) * fakeTopology
# prepare discriminator input with interpolated topologies
discInputInterpGPU = cat(conditions, interpTopos; dims = 3) |> gpu
# WGAN-GP loss: Wasserstein estimate plus the 10·(‖∇D(x̂)‖ − 1)² penalty.
# NOTE(review): the anonymous-function argument below shadows the outer
# `discInputInterpGPU` binding — the inner gradient is w.r.t. the input.
function wganGPloss(discOutReal, discOutFake)
gradNorm = norm(
Flux.gradient(discInputInterpGPU -> discriminator(discInputInterpGPU) |> cpu |> mean,
discInputInterpGPU
)[1])
println("gpNorm: ", gradNorm, " gpTerm: ", (gradNorm - 1) ^ 2)
return mean(discOutFake) - mean(discOutReal) + 10 * (gradNorm - 1) ^ 2
end
# discriminator inputs
discInputReal = cat(conditions, realTopology; dims = 3) |> gpu
discInputFake = cat(conditions, fakeTopology; dims = 3) |> gpu
# outer gradient: differentiate the nested-gradient loss w.r.t. the
# discriminator itself — this is the call whose pullback raises the
# "Can't differentiate foreigncall" error in the stacktrace above.
discGrad = Flux.gradient(
discriminator -> wganGPloss(
discriminator(discInputReal) |> cpu,
discriminator(discInputFake) |> cpu
),
discriminator
)
println("discGradNorm: ", norm(discGrad[1]))
end
gpTermMWE() # run the MWE at load time to reproduce the error
Any suggestions?