Flux.withgradient() doesn't work in a CUDA kernel

Hi all,

I’m porting some work on approximations of the Hilbert-Schmidt Independence Criterion (HSIC) from Python to Julia. Part of the code optimises which samples should be taken and what the widths of the radial basis function (RBF) kernels should be in order to maximise HSIC’s test power. Concretely, this means maximising an objective function using gradients taken with respect to the samples and the widths.

Now, to speed up my code, I’ve decided to use CUDA.jl, since I have access to a GPU. Nothing particularly fancy - I’ve just been using CuArrays and broadcasting operations where I can to handle the larger data sets I’m working with. However, this breaks when I use Flux.withgradient() to get the gradient of my objective function.

I can’t put more than two links in my post, so here’s the Python file all of the code comes from: [link]. The specific line numbers are given in the docstrings of my Julia functions below.

Here’s my Julia code for the main optimisation routine:

function generic_optimize_locs_widths(X::AbstractArray, Y::AbstractArray,
        V0::AbstractArray, W0::AbstractArray, gwidthx0::Float64,
        gwidthy0::Float64, func_obj::Function;
        max_iter::Int64=400, V_step::Float64=1.0, W_step::Float64=1.0,
        gwidthx_step::Float64=1.0, gwidthy_step::Float64=1.0,
        batch_proportion::Float64=1.0, tol_fun::Float64=1e-3, 
        step_pow::Float64=0.5, reg::Float64=1e-5,
        gwidthx_lb::Float64=1e-3, gwidthx_ub::Float64=1e6,
        gwidthy_lb::Float64=1e-3, gwidthy_ub::Float64=1e6)

    """
    https://github.com/wittawatj/fsic-test/blob/master/fsic/indtest.py#L553-L723
    """

    if size(V0, 1) != size(W0, 1) 
        error("V0 and W0 must have the same number of rows J.")
    end

    constrain(var::Float64, lb::Float64, ub::Float64) = max(min(var, ub), lb)

    it = 1.0

    # heuristic to prevent step sizes from being too large
    max_gwidthx_step = minimum(std(X, dims=1)) / 2.0
    max_gwidthy_step = minimum(std(Y, dims=1)) / 2.0
    old_S = 0
    S = 0
    Vth = V0
    Wth = W0
    gwidthx_th = sqrt(gwidthx0)
    gwidthy_th = sqrt(gwidthy0)

    n = size(Y, 1)
    J = size(V0, 1)

    for t=1:max_iter
        # stochastic gradient ascent
        ind = sample(1:n, min(floor(Int64, batch_proportion * n), n);
            ordered=true, replace=false)
        try
            # Represent this as a function so I can get its gradient for later
            s(nt::NamedTuple) = func_obj(X[ind, :], Y[ind, :], nt.V, nt.W, nt.gwidthx,
                nt.gwidthy, reg, n, J)
            # @time (
            params = (V = Vth, W = Wth,
                gwidthx = gwidthx_th^2, gwidthy = gwidthy_th^2)
            S, gradient = Flux.withgradient(s, params)
            println("calculated gradient!")
            g = gradient[1]

            # the gradient NamedTuple mirrors the field names of `params`
            g_V = g.V;              g_W = g.W
            g_gwidthx = g.gwidthx;  g_gwidthy = g.gwidthy

            # updates
            Vth .+= (V_step / it^step_pow / sqrt(mapreduce(x -> x^2, +, g_V))) .* g_V
            Wth .+= (W_step / it^step_pow / sqrt(mapreduce(x -> x^2, +, g_W))) .* g_W
            it += 1
            gwidthx_th = constrain(
                gwidthx_th + gwidthx_step * sign(g_gwidthx) * 
                    min(abs(g_gwidthx), max_gwidthx_step) / it^step_pow,
                    sqrt(gwidthx_lb), sqrt(gwidthx_ub)
            )
            gwidthy_th = constrain(
                gwidthy_th + gwidthy_step * sign(g_gwidthy) * 
                    min(abs(g_gwidthy), max_gwidthy_step) / it^step_pow,
                    sqrt(gwidthy_lb), sqrt(gwidthy_ub)
            )

            if t >= 4 && abs(old_S - S) <= tol_fun break end
            old_S = S
            # )
        catch e
            println("Exception occurred during gradient ascent. Stop optimization.")
            println("Return the value from the previous iteration.")
            # rethrow so I can see the full error while debugging
            throw(e)
        end

    end

    # The final return goes after the loop; fall back to the initial values
    # if no iteration completed successfully.
    if it >= 2  return (Vth, Wth, gwidthx_th, gwidthy_th)
    else        return (V0, W0, gwidthx0, gwidthy0) # Probably an error occurred in the first iteration.
    end
end
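As a side note, the reason I pass the parameters as a single NamedTuple is so that the gradient comes back with matching field names. On the CPU this pattern works fine for me; here’s a toy example (not my real objective, just to show the shape of what Flux.withgradient() returns):

f(nt) = sum(nt.V .^ 2) + 3 * nt.gwidthx
val, grads = Flux.withgradient(f, (V = ones(2, 2), gwidthx = 2.0))
grads[1].V        # 2×2 matrix filled with 2.0
grads[1].gwidthx  # 3.0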

Here’s my code for the objective function:

function func_obj(Xth::AbstractArray, Yth::AbstractArray, Vth::AbstractArray,
        Wth::AbstractArray, gwidthx_th::Float64, gwidthy_th::Float64,
        regth::Float64, n::Int64, J::Int64)
    """
    https://github.com/wittawatj/fsic-test/blob/master/fsic/indtest.py#L242-L277
    """
    diag_regth = regth .* Matrix(1.0I, J, J)
    Kth = rbf_dot(Xth, Vth, gwidthx_th)
    Lth = rbf_dot(Yth, Wth, gwidthy_th)

    mean_k = mean(Kth, dims=1)
    mean_l = mean(Lth, dims=1)
    KLth = Kth .* Lth
    u = mean(KLth, dims=1) .- mean_k .* mean_l

    Kth_norm = Kth .- mean_k
    Lth_norm = Lth .- mean_l
    # Gam is n x J
    Gam = (Kth_norm .* Lth_norm .- u) .- mean(Kth_norm .* Lth_norm .- u, dims=1)
    Sig = Gam' * Gam ./ Float64(n)
    dot(inv(Sig .+ diag_regth) * u', u) / Float64(n)
end
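To make that last line easier to read: with u the 1×J row vector computed above and Sig the J×J matrix Gam' * Gam / n, func_obj returns the quadratic form

    u * inv(Sig + reg * I) * u' / n

which is the regularised statistic computed at the linked lines of indtest.py (the dot(...) call is just this product, since Sig is symmetric).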

Note: the function rbf_dot(A, B, sigma) calculates the RBF kernel matrix between data sources A and B with a kernel width of sigma. Here’s the code for that:

function rbf_dot(pattern1::AbstractArray, pattern2::AbstractArray, deg::Float64)
    """
    Constructs the RBF kernel matrix between two input sources of data,
    `pattern1` and `pattern2`, using a given Gaussian kernel width.

    ## Parameters
    - pattern1: the first input source of data
    - pattern2: the second input source of data
    - deg: the width of the Gaussian kernel to calculate

    ## Returns
    A kernel matrix H whose (i, j) entry is the Gaussian similarity between
    the i-th point of pattern1 and the j-th point of pattern2.

    https://github.com/amber0309/HSIC/blob/master/HSIC.py
    """
    # Get sum of Hadamard product along the rows
    G = sum(pattern1 .* pattern1, dims=2)
    H = sum(pattern2 .* pattern2, dims=2)

    # Pairwise squared Euclidean distances between the rows of pattern1 and pattern2
    H = (G .+ H') .- (2 .* (pattern1 * pattern2'))

    # Apply Gaussian distance calculation
    map(exp, -H ./ (2 * deg^2))
end
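As a quick sanity check on rbf_dot (a toy CPU example, not part of the MWE), the (i, j) entry of the result is exp(-||a_i - b_j||^2 / (2 * deg^2)):

A = randn(4, 2); B = randn(3, 2)
K = rbf_dot(A, B, 1.0)
size(K)                                             # (4, 3)
K[1, 1] ≈ exp(-sum(abs2, A[1, :] .- B[1, :]) / 2)   # true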

If you want to run this as an MWE, here are the packages you’ll need:

using CUDA, Flux, LinearAlgebra, Statistics
import StatsBase: sample

Also, here’s a quick and easy way to get data for an MWE:

# feel free to change these dimensions if you'd like,
# but make sure size(X, 1) == size(Y, 1)
X = CUDA.randn(100, 1) # drop the CUDA. prefix to run on the CPU, which works without the error
Y = CUDA.randn(100, 3) # drop the CUDA. prefix to run on the CPU, which works without the error
# also feel free to change the number of samples
n = size(X, 1)
V0 = X[sample(1:n, n ÷ 5; ordered=true, replace=false), :] # integer division so `sample` gets an Int count
W0 = Y[sample(1:n, n ÷ 5; ordered=true, replace=false), :]
# welcome to change the widths if you like
gwidthx0 = 1.5
gwidthy0 = 1.5
# the parameter `func_obj` in `generic_optimize_locs_widths()` is the `func_obj()` function I mentioned in this post.
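Then, to reproduce the error, it should be enough to call the optimiser directly with its default keyword arguments:

generic_optimize_locs_widths(X, Y, V0, W0, gwidthx0, gwidthy0, func_obj)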

Here’s the stack trace I get:

ERROR: LoadError: GPU compilation of kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .f is of type Zygote.var"#814#818"{Zygote.Context, typeof(exp)} which is not isbits.
    .cx is of type Zygote.Context which is not isbits.
      .cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.


Stacktrace:
  [1] generic_optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, V0::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, W0::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, gwidthx0::Float64, gwidthy0::Float64, func_obj::typeof(LowRankModels.func_obj); max_iter::Int64, V_step::Float64, W_step::Float64, gwidthx_step::Float64, gwidthy_step::Float64, batch_proportion::Float64, tol_fun::Float64, step_pow::Float64, reg::Float64, gwidthx_lb::Float64, gwidthx_ub::Float64, gwidthy_lb::Float64, gwidthy_ub::Float64)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:729
  [2] generic_optimize_locs_widths
    @ ~/LowRankModels.jl/src/hsic.jl:660 [inlined]
  [3] optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}; n_test_locs::Int64, max_iter::Int64, V_step::Float64, W_step::Float64, gwidthx_step::Float64, gwidthy_step::Float64, batch_proportion::Float64, tol_fun::Float64, step_pow::Float64, seed::Int64, reg::Float64, gwidthx_lb::Nothing, gwidthy_lb::Nothing, gwidthx_ub::Nothing, gwidthy_ub::Nothing)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:452
  [4] optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:421
  [5] get_nfsic(Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:269
  [6] get_independence_criterion
    @ ~/LowRankModels.jl/src/hsic.jl:38 [inlined]
  [7] get_independence_criterion(Y::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, X::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, ic::DataType)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:47
  [8] HSICReg(scale::Float64, s::Vector{Float64}, X::Vector{Float64}, ic::DataType)
    @ LowRankModels ~/LowRankModels.jl/src/regularizers.jl:613
  [9] test_pca(test_reg::String)
    @ Main ~/LowRankModels.jl/test/fair_glrms/experiments/test_pca.jl:88
 [10] top-level scope
    @ ~/LowRankModels.jl/test/fair_glrms/experiments/test_independence.jl:4
in expression starting at /home/remote/u6390710/LowRankModels.jl/test/fair_glrms/experiments/test_independence.jl:4

caused by: GPU compilation of kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .f is of type Zygote.var"#814#818"{Zygote.Context, typeof(exp)} which is not isbits.
    .cx is of type Zygote.Context which is not isbits.
      .cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/S3TWf/src/validation.jl:88
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/S3TWf/src/driver.jl:154 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/S3TWf/src/driver.jl:152 [inlined]
  [5] emit_julia(job::GPUCompiler.CompilerJob; validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/S3TWf/src/utils.jl:83
  [6] emit_julia
    @ ~/.julia/packages/GPUCompiler/S3TWf/src/utils.jl:77 [inlined]
  [7] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.ThreadSafeContext)
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:353
  [8] #228
    @ ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:348 [inlined]
  [9] LLVM.ThreadSafeContext(f::CUDA.var"#228#229"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#17", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}})
    @ LLVM ~/.julia/packages/LLVM/HykgZ/src/executionengine/ts_module.jl:14
 [10] JuliaContext(f::CUDA.var"#228#229"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#17", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/S3TWf/src/driver.jl:74
 [11] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:347
 [12] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/S3TWf/src/cache.jl:90
 [13] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, always_inline::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:300
 [14] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float64, Zygote.ZBack{ChainRules.var"#exp_pullback#1305"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}})
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:293
 [15] macro expansion
    @ ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:102 [inlined]
 [16] #launch_heuristic#252
    @ ~/.julia/packages/CUDA/BbliS/src/gpuarrays.jl:17 [inlined]
 [17] launch_heuristic
    @ ~/.julia/packages/CUDA/BbliS/src/gpuarrays.jl:15 [inlined]
 [18] _copyto!
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:73 [inlined]
 [19] copyto!
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:56 [inlined]
 [20] copy
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:47 [inlined]
 [21] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, Zygote.var"#814#818"{Zygote.Context, typeof(exp)}, Tuple{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}})
    @ Base.Broadcast ./broadcast.jl:873
 [22] map(::Function, ::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
    @ GPUArrays ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:92
 [23] ∇map
    @ ~/.julia/packages/Zygote/IoW2g/src/lib/array.jl:195 [inlined]
 [24] adjoint
    @ ~/.julia/packages/Zygote/IoW2g/src/lib/array.jl:221 [inlined]
 [25] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [26] _pullback
    @ ~/LowRankModels.jl/src/hsic.jl:89 [inlined]
 [27] _pullback(::Zygote.Context, ::typeof(LowRankModels.rbf_dot), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [28] _pullback
    @ ~/LowRankModels.jl/src/hsic.jl:749 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(LowRankModels.func_obj), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float64, ::Float64, ::Float64, ::Int64, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [30] _pullback
    @ ~/LowRankModels.jl/src/hsic.jl:696 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::LowRankModels.var"#s#60"{Float64, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, typeof(LowRankModels.func_obj), Int64, Int64, Vector{Int64}}, args::NamedTuple{(:V, :W, :gwidthx, :gwidthy), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [32] _pullback(f::Function, args::NamedTuple{(:V, :W, :gwidthx, :gwidthy), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:34
 [33] pullback(f::Function, args::NamedTuple{(:V, :W, :gwidthx, :gwidthy), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:40
 [34] withgradient(f::Function, args::NamedTuple{(:V, :W, :gwidthx, :gwidthy), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:102
 [35] generic_optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, V0::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, W0::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, gwidthx0::Float64, gwidthy0::Float64, func_obj::typeof(LowRankModels.func_obj); max_iter::Int64, V_step::Float64, W_step::Float64, gwidthx_step::Float64, gwidthy_step::Float64, batch_proportion::Float64, tol_fun::Float64, step_pow::Float64, reg::Float64, gwidthx_lb::Float64, gwidthx_ub::Float64, gwidthy_lb::Float64, gwidthy_ub::Float64)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:701
 [36] generic_optimize_locs_widths
    @ ~/LowRankModels.jl/src/hsic.jl:660 [inlined]
 [37] optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}; n_test_locs::Int64, max_iter::Int64, V_step::Float64, W_step::Float64, gwidthx_step::Float64, gwidthy_step::Float64, batch_proportion::Float64, tol_fun::Float64, step_pow::Float64, seed::Int64, reg::Float64, gwidthx_lb::Nothing, gwidthy_lb::Nothing, gwidthx_ub::Nothing, gwidthy_ub::Nothing)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:452
 [38] optimize_locs_widths(X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:421
 [39] get_nfsic(Y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, X::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:269
 [40] get_independence_criterion
    @ ~/LowRankModels.jl/src/hsic.jl:38 [inlined]
 [41] get_independence_criterion(Y::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, X::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, ic::DataType)
    @ LowRankModels ~/LowRankModels.jl/src/hsic.jl:47
 [42] HSICReg(scale::Float64, s::Vector{Float64}, X::Vector{Float64}, ic::DataType)
    @ LowRankModels ~/LowRankModels.jl/src/regularizers.jl:613
 [43] test_pca(test_reg::String)
    @ Main ~/LowRankModels.jl/test/fair_glrms/experiments/test_pca.jl:88
 [44] top-level scope
    @ ~/LowRankModels.jl/test/fair_glrms/experiments/test_independence.jl:4

Note: there is a similar report in a GitHub issue on Flux.jl, but I didn’t see any resolution there, nor any documentation on how to get Flux and CUDA to play nicely together.

My question is: can I have my cake (CUDA) and eat it too (Flux)? Or do I have to either run the automatic differentiation on the CPU, or compute the gradient of my objective function analytically?
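For reference, the CPU fallback I have in mind would be something like this inside the optimisation loop (an untested sketch: it copies the minibatch and parameters to the host just for the gradient call, then moves the array gradients back to the GPU for the updates):

# hypothetical workaround: differentiate on host copies of the data
s_cpu(nt) = func_obj(Array(X[ind, :]), Array(Y[ind, :]), nt.V, nt.W,
    nt.gwidthx, nt.gwidthy, reg, n, J)
params_cpu = (V = Array(Vth), W = Array(Wth),
    gwidthx = gwidthx_th^2, gwidthy = gwidthy_th^2)
S, gradient = Flux.withgradient(s_cpu, params_cpu)
g = gradient[1]
g_V = cu(g.V); g_W = cu(g.W)   # back onto the GPU for the update step

Paying the host/device transfer cost every iteration seems to defeat the point of using the GPU, though, which is why I’m asking.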

Thanks heaps for any help. Also, please let me know if there’s any forum etiquette I’ve missed - this is my first post on Discourse, so I’d appreciate any tips on making the post easier to read.