How to calculate the Hessian vector product in a Flux model?

Hello.
I implemented the Hessian vector product following this link: Hessian Vector Products on GPU using ForwardDiff, Zygote, and Flux.
The implementation works correctly on both CPU and GPU for a Flux model without convolutional layers. However, when a convolutional layer is included, it works on the CPU but throws an error on the GPU. Could someone please advise me on how to resolve this? Thank you in advance.
Here is the code for calculating the Hessian vector product:

using Flux
using ForwardDiff: partials, Dual
using Zygote: pullback
using LinearAlgebra

# Lazy Hessian-vector product operator: stores the loss closure, the flattened
# parameters, and a cache of Dual numbers so repeated products avoid reallocating.
mutable struct HvpOperator{F, T, I}
    f::F
    x::AbstractArray{T, 1}
    dualCache1::AbstractArray{Dual{Nothing, T, 1}}
    size::I
    nProd::I
end

function HvpOperator(f, x::AbstractVector)
    dualCache1 = Dual.(x, similar(x))
    return HvpOperator(f, x, dualCache1, size(x, 1), 0)
end

Base.eltype(op::HvpOperator{F, T, I}) where{F, T, I} = T
Base.size(op::HvpOperator) = (op.size, op.size)

function LinearAlgebra.mul!(result::AbstractVector, op::HvpOperator, v::AbstractVector)
    op.nProd += 1

    # Forward-over-reverse: seed the duals with v as the tangent, take the
    # reverse-mode gradient of f, then read off the tangent part, which is H*v.
    op.dualCache1 .= Dual.(op.x, v)
    val, back = pullback(op.f, op.dualCache1)

    result .= partials.(back(one(val))[1], 1)
end

function Hvp(v::AbstractVector, model, loss, input, label)
    # Flatten the model parameters and close over the loss, input, and label.
    ps, re = Flux.destructure(model)
    f(θ) = loss(re(θ), input, label)

    Hop = HvpOperator(f, ps)

    res = similar(ps)

    return LinearAlgebra.mul!(res, Hop, v)
end;
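
For reference, a model without convolutional layers runs fine on both CPU and GPU with this setup. A minimal sketch (the layer sizes, loss, and random data here are only illustrative):

using Flux, CUDA

# Tiny MLP and random data, purely to illustrate the working (no-Conv) case.
mlp = Chain(Dense(4, 8, relu), Dense(8, 1)) |> gpu
x = rand(Float32, 4, 16) |> gpu
y = rand(Float32, 1, 16) |> gpu

mse_loss(m, input, label) = Flux.Losses.mse(m(input), label)

ps, _ = Flux.destructure(mlp)
v = copy(ps)

Hv = Hvp(v, mlp, mse_loss, x, y)  # vector of length(ps), equal to H * v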

The following example illustrates the error that occurs when a convolutional layer is included:

using Flux, CUDA
using MLDatasets

mnist_data = MNIST(:train)
mnist_input = mnist_data.features
mnist_input = reshape(mnist_input, (28, 28, 1, 60000))
mnist_label = onehotbatch(mnist_data.targets, 0:9)

train_data = mnist_input[:, :, :, 1:500] |> gpu
train_label = mnist_label[:, 1:500] |> gpu

model = Chain(Conv((5, 5), 1 => 3),
              MaxPool((2, 2)),
              Flux.flatten,
              Dense(432, 120, relu),
              Dense(120, 80, relu),
              Dense(80, 10)) |> gpu

function MSE(model, input, label)
    return Flux.Losses.mse(model(input), label)
end

ps, _ = Flux.destructure(model)
v = copy(ps)

CUDA.allowscalar(false)

result = Hvp(v, model, MSE, train_data, train_label)

Running this code produces the following error message:

TaskFailedException

    nested task error: Scalar indexing is disallowed.
    Invocation of getindex resulted in scalar indexing of a GPU array.
    This is typically caused by calling an iterating implementation of a method.
    Such implementations *do not* execute on the GPU, but very slowly on the CPU,
    and therefore are only permitted from the REPL for prototyping purposes.
    If you did intend to index this array, annotate the caller with @allowscalar.
    Stacktrace:
     [1] error(s::String)
       @ Base .\error.jl:35
     [2] assertscalar(op::String)
       @ GPUArraysCore C:\Users\Taizo\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
     [3] getindex(::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64, ::Int64, ::Int64, ::Vararg{Int64})
       @ GPUArrays C:\Users\Taizo\.julia\packages\GPUArrays\t0LfC\src\host\indexing.jl:9
     [4] getindex
       @ .\subarray.jl:282 [inlined]
     [5] conv_direct!(y::SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}, ::Val{(5, 5, 1)}, ::Val{3}, ::Val{(0, 0, 0, 0, 0, 0)}, ::Val{(1, 1, 1)}, ::Val{(1, 1, 1)}, fk::Val{false}; alpha::Dual{Nothing, Float32, 1}, beta::Bool)
       @ NNlib C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\impl\conv_direct.jl:104
     [6] conv_direct!(y::SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; alpha::Dual{Nothing, Float32, 1}, beta::Bool)
       @ NNlib C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\impl\conv_direct.jl:50
     [7] conv_direct!
       @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\impl\conv_direct.jl:47 [inlined]
     [8] (::NNlib.var"#314#318"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, SubArray{Dual{Nothing, Float32, 1}, 5, CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib .\threadingconstructs.jl:258

Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base .\task.jl:436
  [2] macro expansion
    @ .\task.jl:455 [inlined]
  [3] conv!(out::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, in1::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, in2::CuArray{Dual{Nothing, Float32, 1}, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:205
  [4] conv!
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:185 [inlined]
  [5] #conv!#264
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:145 [inlined]
  [6] conv!
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:140 [inlined]
  [7] conv(x::CuArray{Dual{Nothing, Float32, 1}, 4, CUDA.Mem.DeviceBuffer}, w::CuArray{Dual{Nothing, Float32, 1}, 4, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:88
  [8] conv
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:83 [inlined]
  [9] #rrule#379
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:355 [inlined]
 [10] rrule
    @ C:\Users\Taizo\.julia\packages\NNlib\Jmwx0\src\conv.jl:345 [inlined]
 [11] rrule
    @ C:\Users\Taizo\.julia\packages\ChainRulesCore\0t04l\src\rules.jl:134 [inlined]
 [12] chain_rrule
    @ C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\chainrules.jl:223 [inlined]
 [13] macro expansion
    @ C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:0 [inlined]
 [14] _pullback
    @ C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:9 [inlined]
 [15] _pullback
    @ C:\Users\Taizo\.julia\packages\Flux\EHgZm\src\layers\conv.jl:202 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::Conv{2, 4, typeof(identity), CuArray{Dual{Nothing, Float32, 1}, 4, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, args::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:0
 [17] macro expansion
    @ C:\Users\Taizo\.julia\packages\Flux\EHgZm\src\layers\basic.jl:53 [inlined]
 [18] _pullback
    @ C:\Users\Taizo\.julia\packages\Flux\EHgZm\src\layers\basic.jl:53 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), ::Tuple{Conv{2, 4, typeof(identity), CuArray{Dual{Nothing, Float32, 1}, 4, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(relu), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(relu), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:0
 [20] _pullback
    @ C:\Users\Taizo\.julia\packages\Flux\EHgZm\src\layers\basic.jl:51 [inlined]
 [21] _pullback
    @ .\In[133]:19 [inlined]
 [22] _pullback(::Zygote.Context{false}, ::typeof(MSE), ::Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Dual{Nothing, Float32, 1}, 4, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(relu), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(relu), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}}}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}})
    @ Zygote C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:0
 [23] _pullback
    @ .\In[48]:32 [inlined]
 [24] _pullback(ctx::Zygote.Context{false}, f::var"#f#37"{typeof(MSE), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}}, Optimisers.Restructure{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, Tuple{}, Tuple{}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}}, args::CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface2.jl:0
 [25] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface.jl:44
 [26] pullback
    @ C:\Users\Taizo\.julia\packages\Zygote\TSj5C\src\compiler\interface.jl:42 [inlined]
 [27] mul!(result::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, op::HvpOperator{var"#f#37"{typeof(MSE), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}}, Optimisers.Restructure{Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, Tuple{}, Tuple{}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}}, Float32, Int64}, v::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Main .\In[48]:25
 [28] Hvp(v::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, model::Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(relu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, loss::typeof(MSE), input::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, label::OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}})
    @ Main .\In[48]:38
 [29] top-level scope
    @ In[133]:27

I realize some parts may be hard to follow, but I would appreciate any advice you can provide. Thank you in advance.

There is no GPU path for convolutions in Flux that supports ForwardDiff Dual numbers, so the call hits the CPU-only conv_direct! fallback, which does scalar indexing on the GPU array and fails. Unfortunately, your options are either to define an AD rule for NNlib.conv (which Flux's Conv layer uses internally) so that it handles Dual numbers via ForwardDiff's rule system, or to calculate the HVP using nested reverse-mode AD.
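
For the second option, a minimal sketch of the nested reverse-mode approach, assuming a destructure-based loss closure like yours (untested, and Zygote-over-Zygote can be fragile, so treat it only as a starting point):

using Flux, Zygote, LinearAlgebra

# Hv = ∇_θ ( ∇_θ f(θ) ⋅ v ): the inner gradient is contracted with v, and the
# outer gradient differentiates that scalar again. Untested sketch.
function hvp_nested_reverse(loss, model, input, label, v)
    ps, re = Flux.destructure(model)
    f(θ) = loss(re(θ), input, label)
    grad(θ) = Zygote.gradient(f, θ)[1]
    return Zygote.gradient(θ -> dot(grad(θ), v), ps)[1]
end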