Taking the derivative of a scalar loss that involves a gradient inside errors on GPU only

I am working on a meta-learning project and experimenting with MAML. The first-order approximation works fine, since it can be implemented without differentiating through gradient descent. When doing “full” MAML, however, I run into an error on GPU. Below is an MWE:

using Flux, Zygote, LinearAlgebra

#device = Flux.gpu # Returns an error
device = Flux.cpu  # Does not error

xs = rand(10,100) |> device;
ys = rand(10,100) |> device;

f = Chain(Dense(10,10)) |> device;

# Inner "adaptation" gradient, taken w.r.t. the model parameters
function inner_loop(f, x, y)
    gs_inner = gradient(Flux.params(f)) do
        return sum(f(x))
    end
end

# Outer gradient: differentiate a scalar built from the inner gradients
gs = gradient(Flux.params(f)) do
    gs_inner = inner_loop(f, xs, ys)
    return sum(LinearAlgebra.norm, gs_inner)
end

gs.grads[Flux.params(f)[1]] # nothing
gs.grads[Flux.params(f)[2]] # nothing

If the device is Flux.cpu, this seems to work fine. If the device is Flux.gpu, I encounter the following error:

ERROR: this intrinsic must be compiled to be called
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:9
  [3] _pullback
    @ ./atomics.jl:358 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::typeof(getindex), args::Base.Threads.Atomic{Int64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [5] _pullback (repeats 2 times)
    @ ~/.julia/packages/CUDA/Uurn4/lib/utils/threading.jl:46 [inlined]
  [6] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:7 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(CUDA.device_properties), args::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:51 [inlined]
  [9] _pullback(::Zygote.Context, ::CUDA.var"##CUDACompilerTarget#238", ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.CUDACompilerTarget), ::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:51 [inlined]
 [11] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:294 [inlined]
 [12] _pullback(::Zygote.Context, ::CUDA.var"##cufunction#253", ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#17", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:291 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#17", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [15] macro expansion
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:102 [inlined]
 [16] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/gpuarrays.jl:17 [inlined]
 [17] _pullback(::Zygote.Context, ::CUDA.var"##launch_heuristic#282", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#17", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [18] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [19] adjoint
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:200 [inlined]
 [20] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Int64, Int64, typeof(GPUArrays.launch_heuristic), CUDA.CuArrayBackend, GPUArrays.var"#broadcast_kernel#17"}, ::Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64})
    @ Zygote ./none:0
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [22] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/gpuarrays.jl:17 [inlined]
 [23] _pullback(::Zygote.Context, ::GPUArrays.var"#launch_heuristic##kw", ::NamedTuple{(:elements, :elements_per_thread), Tuple{Int64, Int64}}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#17", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [24] _pullback
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:73 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(GPUArrays._copyto!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [26] _pullback
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:51 [inlined]
 [27] _pullback
    @ ./broadcast.jl:891 [inlined]
 [28] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.materialize!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [29] _pullback
    @ ./broadcast.jl:887 [inlined]
 [30] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/broadcast.jl:280 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::Zygote.var"#1233#1240"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [32] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [33] _pullback(ctx::Zygote.Context, f::Zygote.var"#4049#back#1241"{Zygote.var"#1233#1240"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [34] _pullback
    @ ./REPL[90]:3 [inlined]
 [35] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [36] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:357 [inlined]
 [37] _pullback(ctx::Zygote.Context, f::Zygote.var"#87#88"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(λ)), Zygote.Context}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [38] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76 [inlined]
 [39] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#109#110"{Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [40] _pullback
    @ ./REPL[90]:2 [inlined]
 [41] _pullback(::Zygote.Context, ::typeof(inner_loop), ::Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [42] _pullback
    @ ./REPL[91]:2 [inlined]
 [43] _pullback(::Zygote.Context, ::var"#113#114")
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [44] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:352
 [45] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:75
 [46] top-level scope
    @ REPL[91]:1
 [47] top-level scope
    @ ~/.julia/packages/CUDA/Uurn4/src/initialization.jl:52

I understand that there can be issues with nested calls to Zygote, particularly because the pullbacks of the functions involved need to be differentiable themselves. I do not see why this should work on CPU and not on GPU, and would appreciate any pointers towards fixing this problem.
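
For reference, here is the kind of nested call I mean, stripped down to a scalar function with no model and no GPU involved. This is only an illustrative sketch, and it runs fine on CPU:

using Zygote

# Second derivative of x^3 at x = 2 via a nested gradient call.
# The outer gradient has to differentiate through the pullback produced by
# the inner gradient, which is why that pullback must itself be differentiable.
g(x) = x^3
ddg(x) = gradient(y -> gradient(g, y)[1], x)[1]
ddg(2.0)  # 12.0, i.e. 6x at x = 2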


Do you mind filing an issue for this? It looks like Zygote is trying to differentiate through its own broadcasting code (Zygote.jl/broadcast.jl at master · FluxML/Zygote.jl · GitHub), which means that code isn’t second-derivative safe. We’d probably have to add a rule for it.
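
For anyone wondering what “add a rule” means concretely: roughly, telling the AD system not to trace into internals like the GPU launch machinery seen in the stack trace. Purely as a hypothetical sketch (not a tested fix, and the right place for such a rule is probably inside Zygote/GPUArrays/CUDA.jl rather than user code), it would look something like:

using ChainRulesCore, CUDA

# Hypothetical: mark an internal CUDA function (picked from the stack trace above)
# as non-differentiable, so an outer Zygote pass does not try to build a pullback
# for it. Whether this alone resolves the error is untested.
ChainRulesCore.@non_differentiable CUDA.device_properties(::Any)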


I am having the exact same problem implementing a Wasserstein GAN with gradient penalty (WGAN-GP), which involves nested gradient calls, using Flux/Zygote on GPU. It works fine on CPU.
Someone please fix this ASAP!

What’s extremely helpful is if you can boil examples down to the minimal thing that produces the same error, ideally not involving Flux at all, just some gradient calls on some functions. Then open an issue on Zygote.jl with these; they can also serve as tests for anyone trying to make this work.
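
For example, something in the spirit of the following (untested) sketch, which nests one gradient call inside another over a plain CuArray, might already reproduce the error without any Flux layers:

using CUDA, Zygote

x = CUDA.rand(Float32, 16)

# Outer gradient of a scalar that is itself built from an inner gradient.
gs = gradient(x) do x
    inner, = gradient(y -> sum(abs2, y), x)
    return sum(inner)
end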


Note also that most if not all of the building blocks for WGAN-GP already exist and are doubly-differentiable. I want to say someone even got a version working.

Flux is developed by a small team of people, in their spare time, for free. If you want urgency, the way to go is to sponsor a dedicated dev or roll up your sleeves and try contributing yourself.


Here is my WGAN-GP implementation. It is almost line-for-line identical to the pseudocode of Algorithm 1 in the original paper. @ToucheSir, can you point me to an existing implementation in Julia? I couldn’t find one.

using Flux
using Flux: update!
using Zygote
using StatsBase

"""
WGAN with gradient penalty. See algorithm 1 in https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf. The following code is almost line by line identical.
"""
function train_WGAN_GP(𝐺, 𝐷, 𝐗::Array{Float32, N}, latent_size, num_iters, device_fn; m=32, λ=10f0, ncritic=5, α=0.0001, β₁=0, β₂=0.9) where N
    n = size(𝐗)[end]    # length of dataset
    𝐺, 𝐷 = device_fn(deepcopy(𝐺)), device_fn(deepcopy(𝐷))
    θ, 𝑤 = params(𝐺), params(𝐷)
    adamθ, adam𝑤 = ADAM(α, (β₁, β₂)), ADAM(α, (β₁, β₂)) 

    for iter in 1:num_iters
        for t in 1:ncritic
            𝐱, 𝐳, 𝛜 = 𝐗[repeat([:], N-1)..., rand(1:n, m)], randn(Float32, latent_size..., m), rand(Float32, repeat([1], N-1)..., m) # Sample batch of real data x, latent variables z, random numbers ϵ ∼ U[0, 1].
            𝐱, 𝐳, 𝛜 = device_fn(𝐱), device_fn(𝐳), device_fn(𝛜)
            𝐱̃ = 𝐺(𝐳)
            𝐱̂ = 𝛜 .* 𝐱 + (1f0 .- 𝛜) .* 𝐱̃
            ∇𝑤L = gradient(𝑤) do
                ∇𝐱̂𝐷, = gradient(𝐱̂ ->  sum(𝐷(𝐱̂)), 𝐱̂)
                L = mean(𝐷(𝐱̃)) - mean(𝐷(𝐱)) + λ * mean((sqrt.(sum(∇𝐱̂𝐷.^2, dims=1) .+ 1f-12) .- 1f0).^2)
            end
            update!(adam𝑤, 𝑤, ∇𝑤L)
        end

        𝐳 = device_fn(randn(Float32, latent_size..., m))
        ∇θ𝐷 = gradient(θ) do
            -mean(𝐷(𝐺(𝐳)))
        end
        update!(adamθ, θ, ∇θ𝐷)
    end

    return 𝐺, 𝐷
end

𝐗 = rand(Float32, 50, 10000)  # dummy data
z = 16                        # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, 50))   # Generator
𝐷 = Chain(Dense(50, 32, leakyrelu), Dense(32, 1))   # Critic

𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, cpu) # works
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, gpu) # fails

On GPU, this code fails with the error LoadError: this intrinsic must be compiled to be called on the line ∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂).

Here is an isolated snippet that fails (only) on GPU:


using Flux
using Flux: update!
using Zygote
using StatsBase

function run_isolated_code_on(device_fn)
    D = Chain(Dense(50, 32, leakyrelu), Dense(32, 1)) |> device_fn  # Critic
    w = params(D)
    x = rand(Float32, 50, 32) |> device_fn                          # Dummy minibatch
    ∇wL = gradient(w) do
        ∇xD, = gradient(x ->  sum(D(x)), x)
        L = mean((sqrt.(sum(∇xD.^2, dims=1) .+ 1f-12) .- 1f0).^2)   # gradient penalty
    end
end

run_isolated_code_on(cpu)  # works
run_isolated_code_on(gpu)  # fails

I opened an issue on Zygote.jl: Taking nested gradient for implementing Wasserstein GAN with gradient penalty (WGAN-GP) on GPU · Issue #1262 · FluxML/Zygote.jl (github.com)
