Hessian Vector Products on GPU using ForwardDiff, Zygote, and Flux

Hello,

I am trying to compute the Hessian-vector product of a Flux model (with some vector) on the GPU. This gives me errors that I think are related specifically to ForwardDiff. Any help would be greatly appreciated!

The following defines an Hvp operator which works fine on the CPU:

using ForwardDiff: partials, Dual
using Zygote: pullback
using LinearAlgebra

mutable struct HvpOperator{F, T, I}
	f::F
	x::AbstractArray{T, 1}
	dualCache1::AbstractArray{Dual{Nothing, T, 1}}
	size::I
	nProd::I
end

function HvpOperator(f, x::AbstractVector)
	dualCache1 = Dual.(x, similar(x))
	return HvpOperator(f, x, dualCache1, size(x, 1), 0)
end

Base.eltype(op::HvpOperator{F, T, I}) where{F, T, I} = T
Base.size(op::HvpOperator) = (op.size, op.size)

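function LinearAlgebra.mul!(result::AbstractVector, op::HvpOperator, v::AbstractVector)
	op.nProd += 1

	# Seed x with v as the perturbation direction, then differentiate in reverse mode
	op.dualCache1 .= Dual.(op.x, v)
	val, back = pullback(op.f, op.dualCache1)

	# The dual part of the gradient is the Hessian-vector product H*v
	result .= partials.(back(one(val))[1], 1)
end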
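For reference, the operator relies on the identity H(x)·v = d/dε ∇f(x + εv) at ε = 0: evaluate the gradient at a dual-number point and read off the perturbation part. Below is a minimal CPU sketch of that identity on a function with a known Hessian. The test function `f` is hypothetical, and `ForwardDiff.gradient` stands in for the Zygote pullback (forward-over-forward rather than forward-over-reverse) purely to keep the example to one dependency.

```julia
using ForwardDiff
using ForwardDiff: Dual, partials

# Hypothetical test function with a known Hessian: H(x) = Diagonal(6 .* x)
f(x) = sum(x .* x .* x)

x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, -1.0]

# Seed each coordinate of x with the matching entry of v as its perturbation
xdual = Dual.(x, v)

# Gradient evaluated at the dual point (assumption: ForwardDiff.gradient is
# used here instead of Zygote's pullback, which the operator above uses)
g = ForwardDiff.gradient(f, xdual)

# The perturbation part of each gradient entry is one entry of H*v
hv = partials.(g, 1)

# Compare against the analytic product 6 .* x .* v
println(hv ≈ 6 .* x .* v)
```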

Please ignore the Nothing tags for now; I know I shouldn’t be using them. The following code is an MWE that uses the Hvp operator on the GPU:

using Flux, CUDA
data = randn(10, 4) |> gpu
model = Dense(10, 1, σ) |> gpu

ps, re = Flux.destructure(model)
f(θ) = re(θ)(data)

Hop = HvpOperator(f, ps)

v, res = similar(ps), similar(ps)

LinearAlgebra.mul!(res, Hop, v)

This results in the following error:

ERROR: LoadError: MethodError: no method matching cudnnDataType(::Type{Dual{Nothing, Float32, 1}})
Closest candidates are:
  cudnnDataType(::Type{Float16}) at /home/rs-coop/.julia/packages/CUDA/nYggH/lib/cudnn/util.jl:7
  cudnnDataType(::Type{Float32}) at /home/rs-coop/.julia/packages/CUDA/nYggH/lib/cudnn/util.jl:8
  cudnnDataType(::Type{Float64}) at /home/rs-coop/.julia/packages/CUDA/nYggH/lib/cudnn/util.jl:9
  ...
Stacktrace:
  [1] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}; format::CUDA.CUDNN.cudnnTensorFormat_t, dims::Vector{Int32})
    @ CUDA.CUDNN ~/.julia/packages/CUDA/nYggH/lib/cudnn/tensor.jl:9
  [2] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer})
    @ CUDA.CUDNN ~/.julia/packages/CUDA/nYggH/lib/cudnn/tensor.jl:8
  [3] cudnnActivationForward!(y::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, x::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}; o::Base.Iterators.Pairs{Symbol, CUDA.CUDNN.cudnnActivationMode_t, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{CUDA.CUDNN.cudnnActivationMode_t}}})
    @ CUDA.CUDNN ~/.julia/packages/CUDA/nYggH/lib/cudnn/activation.jl:22
  [4] (::NNlibCUDA.var"#65#69")(src::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, dst::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer})
    @ NNlibCUDA ~/.julia/packages/NNlibCUDA/IeeBk/src/cudnn/activations.jl:11
  [5] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(σ), Tuple{CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}}})
    @ NNlibCUDA ~/.julia/packages/NNlibCUDA/IeeBk/src/cudnn/activations.jl:30
  [6] rrule(#unused#::typeof(Base.Broadcast.broadcasted), #unused#::typeof(σ), x::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer})
    @ NNlib ~/.julia/packages/NNlib/y5z4i/src/activations.jl:813
  [7] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Function, ::Function, ::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/oBjCg/src/rules.jl:134
  [8] chain_rrule
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:216 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::typeof(σ), ::CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:9
 [11] _pullback
    @ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:158 [inlined]
 [12] _pullback(ctx::Zygote.Context, f::Dense{typeof(σ), CuArray{Dual{Nothing, Float32, 1}, 2, CUDA.Mem.DeviceBuffer}, CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/Documents/research/masters-thesis/CubicNewton/experiments/test.jl:43 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::typeof(f), args::CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [15] _pullback(f::Function, args::CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:34
 [16] pullback(f::Function, args::CuArray{Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:40
 [17] mul!(result::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, op::HvpOperator{typeof(f), Float32, Int64}, v::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Main ~/Documents/research/masters-thesis/CubicNewton/experiments/test.jl:34
 [18] top-level scope
    @ ~/Documents/research/masters-thesis/CubicNewton/experiments/test.jl:49
 [19] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [20] top-level scope
    @ REPL[14]:1
 [21] top-level scope
    @ ~/.julia/packages/CUDA/nYggH/src/initialization.jl:52

I know there is quite a lot of information here, so if I can clarify or add anything let me know. Thanks again for the help!

You probably need to strictly type the Duals. Did you try the SparseDiffTools implementations?

@ChrisRackauckas It seems like those SparseDiffTools functions are only exported for a particular, fairly old version of Zygote (see SparseDiffTools.jl on the master branch of JuliaDiff/SparseDiffTools.jl on GitHub). Is there a reason for this, and/or is it tied to a current SparseDiffTools or Zygote issue?

No, they support the most recent version.

My operator is heavily based on SparseDiffTools, but I haven’t tried the implementation there. When you say strictly typing the Duals, do you mean writing Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1} as opposed to Dual{Nothing, T, 1}? In other words, just specifying the tag type instead of Nothing? I am also not sure I understand the idea behind the tag. What should I be putting there?

The full type specification can improve inference for Duals.
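As a rough illustration of what a fully specified Dual type looks like, here is a minimal sketch. The marker struct `MyHvpTag` is a hypothetical name introduced only for this example (it plays the role that DeivVecTag plays in the question above); `Tag`, `Dual`, `value`, and `partials` are ForwardDiff names.

```julia
using ForwardDiff
using ForwardDiff: Dual, Tag, value, partials

# Hypothetical marker type; it only exists to give this use of Duals its own tag
struct MyHvpTag end

x = [1.0, 2.0, 3.0]
v = [0.1, 0.2, 0.3]

# Build the tag type once, from the marker and the element type of x
TagT = typeof(Tag(MyHvpTag(), eltype(x)))

# Every element now has the concrete type Dual{TagT, Float64, 1}
duals = Dual{TagT}.(x, v)

println(eltype(duals))
println(isconcretetype(eltype(duals)))
```

The tag distinguishes these perturbations from those introduced by any other differentiation pass, and a concrete element type gives the compiler something it can specialize on.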

Though reading this again, the issue is a cuDNN one: cuDNN is not compatible with dual numbers (the error shows cudnnDataType is only defined for Float16, Float32, and Float64). I think you’d have to use a different form of forward differentiation to get around that, or a different dense kernel.
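One way to see the distinction, as a small sketch: a Dual is a plain bits type, so it can be stored in a GPU array and processed by generically compiled broadcast kernels, whereas a vendor library such as cuDNN only accepts the fixed set of element types listed in the error message.

```julia
using ForwardDiff: Dual

# The element type that appears in the stack trace above
D = Dual{Nothing, Float32, 1}

# A bits type has a fixed, pointer-free memory layout
println(isbitstype(D))

# One Float32 value plus one Float32 partial
println(sizeof(D))
```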

To clarify, ForwardDiff (or Duals in particular) are compatible with CUDA arrays, but not with cuDNN? Without asking too much of you, can you comment on possible alternatives? I think Zygote has some hidden functionality for forward mode AD, and there is the fabled Diffractor.jl?

You’d need to check for an frule on this op.
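For readers unfamiliar with the term, here is a minimal sketch of what checking for and defining an frule looks like in ChainRulesCore. The function `myop` is a hypothetical stand-in, not the cuDNN sigmoid. Note also that ForwardDiff itself works through methods defined on Dual rather than by reading ChainRules rules, so this illustrates the lookup only and is not a fix for the error above.

```julia
using ChainRulesCore

# Hypothetical operation standing in for a kernel without a forward rule
myop(x) = 1 / (1 + exp(-x))

# With no rule defined, the generic fallback returns `nothing`
before = frule((NoTangent(), 1.0), myop, 0.0)

# A forward rule returns the primal value together with the pushed-forward tangent
function ChainRulesCore.frule((_, Δx), ::typeof(myop), x)
    y = myop(x)
    return y, y * (1 - y) * Δx
end

after = frule((NoTangent(), 1.0), myop, 0.0)

println(before === nothing)
println(after)
```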

Ok, I think I understand the issue.

When I remove the sigmoid activation from my model, the code runs fine on the GPU, since the cuDNN sigmoid is no longer called. The HVP is trivially zero in this case (the model is then linear in its parameters), but it still works.

If I wanted to use a sigmoid then it seems I have two options:

  • Define an frule for the cuDNN sigmoid
  • Make my own sigmoid function which can then be differentiated using lower level frules even on the GPU

Is that correct?

I haven’t dug into the details of your case, but it certainly sounds like that’s the right track. I’m not sure it’s using the fused sigmoid rule, but it could be, in which case that would need an frule.

Things seem to be working by defining my own activation functions. Thanks for the help here!
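The thread does not show the replacement activation, so the following is only a sketch of what such a function might look like. The name `mysigmoid` is hypothetical. It is written with generic arithmetic so that it accepts Dual inputs, and the check runs on the CPU with a single dual number.

```julia
using ForwardDiff: Dual, value, partials

# Hypothetical hand-written sigmoid; one(x) keeps the element type (e.g. Float32)
mysigmoid(x) = one(x) / (one(x) + exp(-x))

# Evaluate at 0 with a unit perturbation
d = Dual(0.0, 1.0)
y = mysigmoid(d)

# Sigmoid value at 0, and its derivative s * (1 - s) at 0
println(value(y))
println(partials(y, 1))
```

Passing such a function as the activation, for example Dense(10, 1, mysigmoid), would presumably avoid the cuDNN path, because the specialized kernel in the stack trace is selected for typeof(σ) specifically.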

Open an issue so that future users don’t have to work around it.

In which repository: ForwardDiff, Flux, or NNlibCUDA? It seems like one of the last two, maybe Flux, because they are trying to be AD agnostic.

I think it’s an NNlibCUDA issue.