CUDA atomic add for ForwardDiff duals?

CUDA.jl only supports atomic operations on primitive types, because we don’t really have support for reinterpreting aggregate types (like Dual) as a single primitive value that the hardware atomics can operate on, and it isn’t legal to silently split the atomic operation into separate operations on the fields of the Dual. You can add the necessary definitions yourself, of course:

using CUDA

# minimal Dual type
struct Dual{T}
    real::T
    dual::T
end
Base.one(::Type{Dual{T}}) where {T} = Dual{T}(1, 0)
Base.:(+)(x::Dual{T}, y::Dual{T}) where T = Dual{T}(x.real + y.real, x.dual + y.dual)

# approach 1: convince CUDA you can do an atomic addition of duals separately.
#             this is not technically correct as the addition isn't truly atomic,
#             but might suffice depending on the use case (e.g. global accumulation).
#             this will only work for element types that are supported by atomic_add!
@inline function CUDA.atomic_arrayset(A::AbstractArray{Dual{T}}, I::Integer, op::typeof(+),
                                      val::Dual{T}) where {T}
    real_ptr = pointer(reinterpret(T, A), (I-1)*2+1)
    CUDA.atomic_add!(real_ptr, val.real)
    dual_ptr = pointer(reinterpret(T, A), (I-1)*2+2)
    CUDA.atomic_add!(dual_ptr, val.dual)
end

# approach 2: convince CUDA you can do an atomic compare-and-swap of duals by
#             reinterpreting them as a single wider primitive type.
#             this is better, because it ensures the whole update happens atomically,
#             but relies on the atomic fallback mechanism (i.e. not an atomic addition, but
#             a compare-and-swap loop) and is more limited because it requires the widened
#             type to be supported by the hardware's CAS (only 64 bits, and that on sm_60+;
#             128-bit CAS isn't supported)
using Core: LLVMPtr
widen(::Type{Float32}) = Float64
@inline function CUDA.atomic_cas!(ptr::LLVMPtr{Dual{T},A}, cmp::Dual{T}, val::Dual{T}) where {T,A}
    U = widen(T)
    wide_ptr = reinterpret(LLVMPtr{U,A}, ptr)

    # XXX: this is JuliaLang/julia#43065 (type punning of aggregates)
    cmp_ref = Ref(cmp)
    wide_cmp = GC.@preserve cmp_ref begin
        cmp_ptr = Base.unsafe_convert(Ptr{Dual{T}}, cmp_ref)
        unsafe_load(reinterpret(Ptr{U}, cmp_ptr))
    end

    val_ref = Ref(val)
    wide_val = GC.@preserve val_ref begin
        val_ptr = Base.unsafe_convert(Ptr{Dual{T}}, val_ref)
        unsafe_load(reinterpret(Ptr{U}, val_ptr))
    end

    wide_ret = CUDA.atomic_cas!(wide_ptr, wide_cmp, wide_val)

    wide_ret_ref = Ref(wide_ret)
    GC.@preserve wide_ret_ref begin
        wide_ret_ptr = Base.unsafe_convert(Ptr{U}, wide_ret_ref)
        unsafe_load(reinterpret(Ptr{Dual{T}}, wide_ret_ptr))
    end
end
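
# a rough sketch of what the atomic fallback mechanism does with the atomic_cas! above:
# keep trying to swap in `old + val` until no other thread raced the update.
# cas_add! is just an illustrative name for this sketch, not a CUDA.jl API.
function cas_add!(ptr::LLVMPtr{Dual{T},A}, val::Dual{T}) where {T,A}
    old = unsafe_load(ptr)  # plain read of the current value
    while true
        # try to swap in `old + val`; atomic_cas! returns what was actually stored
        ret = CUDA.atomic_cas!(ptr, old, old + val)
        ret === old && return old  # nobody raced us, the addition took effect
        old = ret                  # another thread won, retry with the value it left behind
    end
end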

function kernel(A)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(A)
        CUDA.@atomic A[i] += one(eltype(A))
    end
    return
end

function main()
    @show A = cu([1,2])
    @cuda threads=length(A) kernel(A)
    @show A

    @show B = cu([Dual{Float32}(1,2),Dual{Float32}(3,4)])
    @cuda threads=length(B) kernel(B)
    @show B
end
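
With these definitions, main() should leave every element incremented by one(eltype): A goes from [1, 2] to [2, 3], and each Dual in B gets +1 on its real part and +0 on its dual part, since one(Dual{T}) is Dual{T}(1, 0).

To address the title more directly, the approach-1 trick could be extended to an actual ForwardDiff.Dual. The following is an untested sketch that assumes a ForwardDiff.Dual{Tag,V,N} is laid out as its value followed by its N partials (i.e. N+1 consecutive values of type V), and that overloading the same CUDA.atomic_arrayset hook as above is still the entry point used by CUDA.@atomic:

using CUDA, ForwardDiff

# like approach 1: one hardware atomic per field, so the update of the whole Dual
# is not a single atomic operation, and V must be supported by atomic_add!
@inline function CUDA.atomic_arrayset(A::AbstractArray{ForwardDiff.Dual{Tag,V,N}}, I::Integer,
                                      op::typeof(+), val::ForwardDiff.Dual{Tag,V,N}) where {Tag,V,N}
    # view the array as N+1 consecutive values of type V per element (value, then partials)
    raw = reinterpret(V, A)
    base = (I-1)*(N+1)
    CUDA.atomic_add!(pointer(raw, base+1), ForwardDiff.value(val))
    for j in 1:N
        CUDA.atomic_add!(pointer(raw, base+1+j), ForwardDiff.partials(val, j))
    end
    return
end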