Flux, CUDA, Zygote: InvalidIRError: compiling kernel getindex_kernel(CUDA.CuKernelContext, CuDeviceArray

I have a Flux model composed of convolution and deconvolution layers with a custom loss function. When I try to train the model I get:

ERROR: InvalidIRError: compiling kernel getindex_kernel(CUDA.CuKernelContext, CuDeviceArray{Complex{Float32},4,1}, CuDeviceArray{Complex{Float32},4,1}, Tuple{Int64}, CuDeviceArray{Float32,4,1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to #sprint#355(context, sizehint::Integer, ::typeof(sprint), f::Function, args...) in Base at strings/io.jl:100)

Here is the MWE:

using Flux
using CUDA
using Zygote  # needed for the Zygote.@adjoint definitions below
using Flux: glorot_uniform
using Statistics: mean


CUDA.allowscalar(false); # disallowing scalar operations on GPU


mutable struct Enc
    rConv::Chain
    iConv::Chain

    function Enc(filter, stride, in, out, pad )
        realConv = Chain(Conv(filter, in=>out, leakyrelu, init=glorot_uniform, stride=stride, pad=pad),
                         BatchNorm(out, relu))
        imgConv  = Chain(Conv(filter, in=>out, leakyrelu, init=glorot_uniform, stride=stride, pad=pad),
                         BatchNorm(out, relu))
        new(realConv, imgConv)
    end

    function Enc(rConv::Chain, iConv::Chain)
        new(rConv, iConv)
    end
end 
Flux.@functor Enc

function (enc::Enc)(x)
    rC = enc.rConv(real(x)) 
    iC = enc.iConv(imag(x))
    rC = rC - iC
    iC = rC + iC
    complex.(rC, iC) 
end

function multistft(spectrogram::CuArray{T, 4},
                    framelen::Int=1024,
                    hopsize::Int=div(framelen, 2)) where T <: Complex

    freqbins, numframes, channels, samples = size(spectrogram)
    expectedlen = framelen + (numframes - 1) * hopsize

    spectrogram = isodd(numframes) ? hcat(spectrogram, CUDA.zeros(eltype(spectrogram), size(spectrogram, 1), 1, channels, samples)) : spectrogram
    numframes   = isodd(numframes) ? numframes + 1 : numframes  # the number of frames may change here; it should not affect the original framelen!
 
    # window  = hanningTensor(framelen, numframes, channels, samples)
    window  = CUDA.ones(Float32, (framelen, numframes, channels, samples)) .* CUDA.CuArray(Float32.(.5 .* (1 .- cos.(2 .* pi .* collect(0:framelen - 1)/(framelen - 1)))))
    windows = CUDA.fill(Float32(1.0e-8), framelen, numframes, channels, samples) .+ (window.^2)
    
    odds   = Flux.flatten(windows[:, 1:2:end, :, :]);
    evens  = Flux.flatten(windows[:, 2:2:end, :, :]);
    winsum = vcat(odds, CUDA.zeros(Float32, hopsize, samples)) .+ vcat(CUDA.zeros(Float32, hopsize, samples), evens);

    wr_odd  = window[:, 1:2:end, :, :] .* CUDA.CUFFT.irfft(spectrogram[:, 1:2:end, :, :], framelen, 1);
    wr_even = window[:, 2:2:end, :, :] .* CUDA.CUFFT.irfft(spectrogram[:, 2:2:end, :, :], framelen, 1);
    
    reconstructed = vcat(Flux.flatten(wr_odd), CUDA.zeros(Float32, hopsize, samples)) .+ vcat(CUDA.zeros(Float32, hopsize, samples), Flux.flatten(wr_even))

    return (reconstructed ./ winsum)
end
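
For context, multistft can be exercised on its own; here is a minimal, hypothetical sanity check (it assumes a CUDA-capable GPU and mirrors the shapes of the dummy data used further down):

spec = CUDA.rand(ComplexF32, 513, 321, 1, 1)  # (freqbins, frames, channels, batch), 513 = 1024 ÷ 2 + 1
wav  = multistft(spec)                        # overlap-add reconstruction: a real (time, batch) CuArray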


# this loss is user-defined
function wsdrLoss(x, ŷ, y; ϵ=1e-8)

    x = x |> multistft
    ŷ = ŷ |> multistft
    y = y |> multistft
    
    z = x .- y
    ẑ = x .- ŷ

    nd  = sum(y.^2; dims=1)[:]
    dom = sum(z.^2; dims=1)[:]

    ϵ_array = CUDA.fill(Float32(ϵ), size(nd))
    aux = nd ./ (nd .+ dom .+ ϵ_array)
    wSDR = aux .* sdr(ŷ, y) .+ (1 .- aux) .* sdr(ẑ, z) 
    CUDA.mean(wSDR)
end

multiNorm(A; dims) = CUDA.sqrt.(sum(real(A .* conj(A)), dims=dims))

function sdr(ypred, ygold; ϵ=1e-8)
    num = sum(ygold .*  ypred, dims=1)
    den = multiNorm(ygold, dims=1) .* multiNorm(ypred, dims=1)
    ϵ_array = CUDA.fill(Float32(ϵ), size(den))
    -(num ./ (den  .+ ϵ_array))
end 
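
The loss and its helpers can likewise be checked in isolation on random spectrograms of the same shape (a hedged sketch with made-up variable names; it should simply return a scalar):

noisy = CUDA.rand(ComplexF32, 513, 321, 1, 1)  # mixture spectrogram (x)
pred  = CUDA.rand(ComplexF32, 513, 321, 1, 1)  # model prediction (ŷ)
clean = CUDA.rand(ComplexF32, 513, 321, 1, 1)  # ground truth (y)
wsdrLoss(noisy, pred, clean)                   # scalar weighted-SDR loss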

To take gradients with Zygote, I needed to add:

Zygote.@adjoint CUDA.ones(x...) = CUDA.ones(x...), _ -> map(_ -> nothing, x)

Zygote.@adjoint CUDA.zeros(x...) = CUDA.zeros(x...), _ -> map(_ -> nothing, x)

Zygote.@adjoint CUDA.fill(x::Real, dims...) = CUDA.fill(x, dims...), Δ->(sum(Δ), map(_->nothing, dims)...)
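
These adjoints simply declare the array constructors as constants for differentiation, so their arguments receive no gradient. A quick, hypothetical check of what they do:

val, back = Zygote.pullback(CUDA.zeros, 3, 3)
back(CUDA.ones(Float32, 3, 3))  # -> (nothing, nothing): the size arguments get no gradient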

Then I create a dummy model with dummy input and output:

x = CUDA.rand(ComplexF32, 513, 321, 1, 1); # input
y = CUDA.rand(ComplexF32, 513, 321, 1, 1); # output

# creating a dummy model on gpu
encoder = Chain(Enc((1, 1), (1, 1), 1, 1, (0, 0))) |> gpu

#  ŷ = encoder(x);
# the loss function accepts 3 arguments: the input, the prediction, and the ground truth.

# to train/update the model 
θ = params(encoder)
opt = ADAM(0.01)
∇ = gradient(wsdrLoss, x, encoder(x), y)[1]
Flux.update!(opt, θ, ∇)

Then I get:

ERROR: InvalidIRError: compiling kernel getindex_kernel(CUDA.CuKernelContext, CuDeviceArray{Complex{Float32},4,1}, CuDeviceArray{Complex{Float32},4,1}, Tuple{Int64}, CuDeviceArray{Float32,4,1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to #sprint#355(context, sizehint::Integer, ::typeof(sprint), f::Function, args...) in Base at strings/io.jl:100)
Stacktrace:
 [1] #repr#356 at strings/io.jl:227
 [2] limitrepr at strings/io.jl:229
 [3] to_index at indices.jl:297
 [4] to_index at indices.jl:274
 [5] to_indices at indices.jl:325
 [6] to_indices at indices.jl:322
 [7] getindex at abstractarray.jl:1060
 [8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported dynamic function invocation (call to print)
Stacktrace:
 [1] print_to_string at strings/io.jl:135
 [2] string at strings/io.jl:174
 [3] to_index at indices.jl:297
 [4] to_index at indices.jl:274
 [5] to_indices at indices.jl:325
 [6] to_indices at indices.jl:322
 [7] getindex at abstractarray.jl:1060
 [8] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [9] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_grow_end)
Stacktrace:
 [1] _growend! at array.jl:892
 [2] resize! at array.jl:1085
 [3] print_to_string at strings/io.jl:137
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_del_end)
Stacktrace:
 [1] _deleteend! at array.jl:901
 [2] resize! at array.jl:1090
 [3] print_to_string at strings/io.jl:137
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_array_to_string)
Stacktrace:
 [1] String at strings/string.jl:39
 [2] print_to_string at strings/io.jl:137
 [3] string at strings/io.jl:174
 [4] to_index at indices.jl:297
 [5] to_index at indices.jl:274
 [6] to_indices at indices.jl:325
 [7] to_indices at indices.jl:322
 [8] getindex at abstractarray.jl:1060
 [9] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_alloc_string)
Stacktrace:
 [1] _string_n at strings/string.jl:60
 [2] StringVector at iobuffer.jl:31
 [3] #IOBuffer#331 at iobuffer.jl:114
 [4] print_to_string at strings/io.jl:133
 [5] string at strings/io.jl:174
 [6] to_index at indices.jl:297
 [7] to_index at indices.jl:274
 [8] to_indices at indices.jl:325
 [9] to_indices at indices.jl:322
 [10] getindex at abstractarray.jl:1060
 [11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to jl_string_to_array)
Stacktrace:
 [1] unsafe_wrap at strings/string.jl:71
 [2] StringVector at iobuffer.jl:31
 [3] #IOBuffer#331 at iobuffer.jl:114
 [4] print_to_string at strings/io.jl:133
 [5] string at strings/io.jl:174
 [6] to_index at indices.jl:297
 [7] to_index at indices.jl:274
 [8] to_indices at indices.jl:325
 [9] to_indices at indices.jl:322
 [10] getindex at abstractarray.jl:1060
 [11] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [12] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Reason: unsupported call through a literal pointer (call to __memset_avx2_unaligned_erms)
Stacktrace:
 [1] fill! at array.jl:428
 [2] #IOBuffer#331 at iobuffer.jl:121
 [3] print_to_string at strings/io.jl:133
 [4] string at strings/io.jl:174
 [5] to_index at indices.jl:297
 [6] to_index at indices.jl:274
 [7] to_indices at indices.jl:325
 [8] to_indices at indices.jl:322
 [9] getindex at abstractarray.jl:1060
 [10] macro expansion at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:145
 [11] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139
Stacktrace:
 [1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /opt/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
 [2] macro expansion at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
 [3] macro expansion at /opt/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /opt/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{typeof(GPUArrays.getindex_kernel),Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] getindex_kernel at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:139 [inlined]
 [11] cached_compilation at /opt/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::typeof(GPUArrays.getindex_kernel), ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Complex{Float32},4,1},CuDeviceArray{Complex{Float32},4,1},Tuple{Int64},CuDeviceArray{Float32,4,1}}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
 [13] cufunction at /opt/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
 [14] #launch_heuristic#853 at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
 [15] launch_heuristic at /opt/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
 [16] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}; target::CuArray{Complex{Float32},4}, total_threads::Nothing, threads::Nothing, blocks::Nothing, name::Nothing) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:61
 [17] gpu_call(::typeof(GPUArrays.getindex_kernel), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{Float32},4}, ::Tuple{Int64}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/device/execution.jl:46
 [18] _getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:135
 [19] getindex(::CuArray{Complex{Float32},4}, ::CuArray{Float32,4}) at /opt/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:125
 [20] update!(::ADAM, ::Params, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Flux/q3zeA/src/optimise/train.jl:28
 [21] top-level scope at REPL[32]:1

Any suggestions?

B.R.

Since you are using Flux.params, it’s typically easiest to work with implicit parameters instead, i.e.

∇ = gradient(() -> wsdrLoss(x, encoder(x), y), θ)
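
For completeness, a hedged sketch of the full update step in that style, reusing θ, opt, x, y, and encoder from your post (the returned gradient is keyed by the parameters in θ, which is what Flux.update! expects):

∇ = gradient(() -> wsdrLoss(x, encoder(x), y), θ)
Flux.update!(opt, θ, ∇)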

∇ = gradient(() -> wsdrLoss(x, encoder(x), y), θ) still gives an error:

ERROR: MethodError: no method matching plan_brfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64)
Closest candidates are:
  plan_brfft(::CuArray{T,N}, ::Integer, ::Any) where {T<:Union{Complex{Float32}, Complex{Float64}}, N} at /opt/.julia/packages/CUDA/YeS8q/lib/cufft/fft.jl:306
  plan_brfft(::AbstractArray, ::Integer; kws...) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:285
Stacktrace:
 [1] plan_irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64; kws::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:334
 [2] plan_irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:334
 [3] irfft(::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/AbstractFFTs/mhQvY/src/definitions.jl:284
 [4] adjoint at /opt/.julia/packages/Zygote/xBjHw/src/lib/array.jl:929 [inlined]
 [5] _pullback(::Zygote.Context, ::typeof(AbstractFFTs.irfft), ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47
 [6] multistft at ./REPL[13]:19 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(multistft), ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::Int64, ::Int64) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [8] multistft at ./REPL[13]:5 [inlined] (repeats 2 times)
 [9] |> at ./operators.jl:834 [inlined]
 [10] #wsdrLoss#3 at ./REPL[14]:5 [inlined]
 [11] _pullback(::Zygote.Context, ::var"##wsdrLoss#3", ::Float64, ::typeof(wsdrLoss), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [12] wsdrLoss at ./REPL[14]:4 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::CuArray{Complex{Float32},4}, ::CuArray{Complex{ForwardDiff.Dual{Nothing,Float32,2}},4}, ::CuArray{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [14] #24 at ./REPL[29]:1 [inlined]
 [15] _pullback(::Zygote.Context, ::var"#24#25") at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [16] pullback(::Function, ::Params) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:172
 [17] gradient(::Function, ::Params) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:53
 [18] top-level scope at REPL[29]:1

Since you are using Flux.params, it’s typically easiest to work with implicit parameters instead, i.e.

Is there any other way?

I haven’t figured out why ForwardDiff is being used here, but I think it doesn’t work with FFT; see e.g. yesterday’s thread: ForwardDiff and Zygote cannot automatically differentiate (AD) function from C^n to R that uses FFT


I think it doesn’t work with FFT: ForwardDiff and Zygote cannot automatically differentiate (AD) function from C^n to R that uses FFT

Thank you for the post. Since I am new to Zygote and the related packages, I really don’t know what I should try first. For my case, do I need to define gradient rules manually? If so, that is beyond my current knowledge and I will run into a serious problem.

B.R.