Taking the derivative of a scalar loss that contains an inner gradient errors on GPU only

I am working on a meta-learning project and experimenting with MAML. I am able to get a first-order approximation working, as it can be implemented without differentiating through gradient descent. When doing “full” MAML, however, I run into an error on GPU. Below is an MWE:

# Packages the snippet relies on. Flux supplies `Chain`, `Dense`, `gradient`,
# `Flux.params` and the `Flux.cpu` / `Flux.gpu` helpers; LinearAlgebra supplies the
# `LinearAlgebra.norm` used in the outer loss further down.
using Flux
using LinearAlgebra

# Device toggle: swap which of the next two lines is commented out.
#device = Flux.gpu # Returns an error
device = Flux.cpu # Does not error

# Random 10×100 arrays moved to the selected device.
# NOTE: `ys` is handed to `inner_loop` below, but no loss in this snippet reads it.
xs = rand(10, 100) |> device;
ys = rand(10, 100) |> device;

# Model under test: a `Chain` holding a single `Dense(10, 10)` layer.
f = Chain(Dense(10, 10)) |> device;

function inner_loop(f, x, y) # gradient of sum(f(x)) w.r.t. the parameters of `f`; `y` is accepted but never used
    gs_inner = gradient(Flux.params(f)) do # differentiate the closure below w.r.t. Flux.params(f)
        return sum(f(x)) # scalar inner loss
    end
end # no explicit `return`: the assignment is the last expression, so the result of `gradient` is returned

gs = gradient(Flux.params(f)) do # outer gradient, again w.r.t. the parameters of `f`
   gs_inner = inner_loop(f, xs, ys) # nested `gradient` call that the outer `gradient` has to differentiate through
   return  sum(LinearAlgebra.norm, gs_inner ) # scalar outer loss: `norm` applied to each element iterated from gs_inner, then summed
end

gs.grads[Flux.params(f)[1]] # => nothing, per the author (presumably observed on the CPU run; confirm)
gs.grads[Flux.params(f)[2]] # => nothing, per the author (presumably observed on the CPU run; confirm)

If the device is Flux.cpu, this seems to work fine. If the device is Flux.gpu, I encounter the following error:

ERROR: this intrinsic must be compiled to be called
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:9
  [3] _pullback
    @ ./atomics.jl:358 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::typeof(getindex), args::Base.Threads.Atomic{Int64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [5] _pullback (repeats 2 times)
    @ ~/.julia/packages/CUDA/Uurn4/lib/utils/threading.jl:46 [inlined]
  [6] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:7 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(CUDA.device_properties), args::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:51 [inlined]
  [9] _pullback(::Zygote.Context, ::CUDA.var"##CUDACompilerTarget#238", ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.CUDACompilerTarget), ::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/gpucompiler.jl:51 [inlined]
 [11] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:294 [inlined]
 [12] _pullback(::Zygote.Context, ::CUDA.var"##cufunction#253", ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#17", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:291 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#17", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [15] macro expansion
    @ ~/.julia/packages/CUDA/Uurn4/src/compiler/execution.jl:102 [inlined]
 [16] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/gpuarrays.jl:17 [inlined]
 [17] _pullback(::Zygote.Context, ::CUDA.var"##launch_heuristic#282", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#17", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [18] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [19] adjoint
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:200 [inlined]
 [20] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Int64, Int64, typeof(GPUArrays.launch_heuristic), CUDA.CuArrayBackend, GPUArrays.var"#broadcast_kernel#17"}, ::Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64})
    @ Zygote ./none:0
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [22] _pullback
    @ ~/.julia/packages/CUDA/Uurn4/src/gpuarrays.jl:17 [inlined]
 [23] _pullback(::Zygote.Context, ::GPUArrays.var"#launch_heuristic##kw", ::NamedTuple{(:elements, :elements_per_thread), Tuple{Int64, Int64}}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#17", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [24] _pullback
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:73 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(GPUArrays._copyto!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [26] _pullback
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:51 [inlined]
 [27] _pullback
    @ ./broadcast.jl:891 [inlined]
 [28] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.materialize!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [29] _pullback
    @ ./broadcast.jl:887 [inlined]
 [30] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/broadcast.jl:280 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::Zygote.var"#1233#1240"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [32] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [33] _pullback(ctx::Zygote.Context, f::Zygote.var"#4049#back#1241"{Zygote.var"#1233#1240"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [34] _pullback
    @ ./REPL[90]:3 [inlined]
 [35] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [36] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:357 [inlined]
 [37] _pullback(ctx::Zygote.Context, f::Zygote.var"#87#88"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(λ)), Zygote.Context}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [38] _pullback
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76 [inlined]
 [39] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#109#110"{Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [40] _pullback
    @ ./REPL[90]:2 [inlined]
 [41] _pullback(::Zygote.Context, ::typeof(inner_loop), ::Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [42] _pullback
    @ ./REPL[91]:2 [inlined]
 [43] _pullback(::Zygote.Context, ::var"#113#114")
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [44] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:352
 [45] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:75
 [46] top-level scope
    @ REPL[91]:1
 [47] top-level scope
    @ ~/.julia/packages/CUDA/Uurn4/src/initialization.jl:52

I understand that there can be issues with nested calls to Zygote, particularly because the function’s pullbacks need to be differentiable themselves. However, I do not see why this should work on CPU and not on GPU, and I would appreciate any pointers towards fixing this problem.

1 Like

Do you mind filing an issue for this? It looks like Zygote is trying to differentiate through its own broadcasting code (Zygote.jl/broadcast.jl at master · FluxML/Zygote.jl · GitHub), which means that code isn’t second-derivative safe. We’d probably have to add a rule for it.