CUDA.jl only supports atomic operations on primitive types, because we don't really have support for interpreting aggregate types (like Dual) as a single bitstype, and it isn't legal to simply split the atomic operation into separate operations on the fields of the Dual. You can add the necessary definitions yourself, of course:
using CUDA
# minimal Dual type
struct Dual{T}
    real::T
    dual::T
end

Base.one(::Type{Dual{T}}) where {T} = Dual{T}(1, 0)
Base.:(+)(x::Dual{T}, y::Dual{T}) where {T} = Dual{T}(x.real + y.real, x.dual + y.dual)
# approach 1: convince CUDA.jl that you can atomically add Duals by performing the
# addition separately on each field. this is not technically correct, as the addition
# isn't truly atomic, but it might suffice depending on the use case (e.g. global
# accumulation). this will only work for element types that are supported by atomic_add!
@inline function CUDA.atomic_arrayset(A::AbstractArray{Dual{T}}, I::Integer, op::typeof(+),
                                      val::Dual{T}) where {T}
    real_ptr = pointer(reinterpret(T, A), (I-1)*2+1)
    CUDA.atomic_add!(real_ptr, val.real)
    dual_ptr = pointer(reinterpret(T, A), (I-1)*2+2)
    CUDA.atomic_add!(dual_ptr, val.dual)
end
# approach 2: convince CUDA.jl that you can do an atomic exchange of Duals by casting
# them to a wider bitstype. this is better, because it ensures the operation actually
# happens atomically, but it relies on the atomic fallback mechanism (i.e. not an atomic
# addition, but a compare-and-swap loop) and is more limited because the widened type
# needs hardware CAS support (64-bit CAS is only available on sm_60 and up;
# 128-bit CAS isn't supported)
using Core: LLVMPtr

widen(::Type{Float32}) = Float64

@inline function CUDA.atomic_cas!(ptr::LLVMPtr{Dual{T},A}, cmp::Dual{T}, val::Dual{T}) where {T,A}
    U = widen(T)
    wide_ptr = reinterpret(LLVMPtr{U,A}, ptr)

    # reinterpret the Dual arguments as the widened type
    # XXX: this is JuliaLang/julia#43065 (type punning of aggregates)
    cmp_ref = Ref(cmp)
    wide_cmp = GC.@preserve cmp_ref begin
        cmp_ptr = Base.unsafe_convert(Ptr{Dual{T}}, cmp_ref)
        unsafe_load(reinterpret(Ptr{U}, cmp_ptr))
    end
    val_ref = Ref(val)
    wide_val = GC.@preserve val_ref begin
        val_ptr = Base.unsafe_convert(Ptr{Dual{T}}, val_ref)
        unsafe_load(reinterpret(Ptr{U}, val_ptr))
    end

    # perform the CAS on the widened type
    wide_ret = CUDA.atomic_cas!(wide_ptr, wide_cmp, wide_val)

    # reinterpret the result back to a Dual
    wide_ret_ref = Ref(wide_ret)
    GC.@preserve wide_ret_ref begin
        wide_ret_ptr = Base.unsafe_convert(Ptr{U}, wide_ret_ref)
        unsafe_load(reinterpret(Ptr{Dual{T}}, wide_ret_ptr))
    end
end
function kernel(A)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(A)
        CUDA.@atomic A[i] += one(eltype(A))
    end
    return
end
function main()
    @show A = cu([1,2])
    @cuda threads=length(A) kernel(A)
    @show A

    @show B = cu([Dual{Float32}(1,2), Dual{Float32}(3,4)])
    @cuda threads=length(B) kernel(B)
    @show B
end
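The kernel just adds one(eltype(A)) to every element, so the expected per-element effect can be sanity-checked on the CPU with the same Dual definitions, independent of which of the two atomic approaches you load:

# CPU-side check of the update the kernel performs on each element
Dual{Float32}(1, 2) + one(Dual{Float32})   # Dual{Float32}(2.0f0, 2.0f0)
Dual{Float32}(3, 4) + one(Dual{Float32})   # Dual{Float32}(4.0f0, 4.0f0)

So after running main(), B should hold Dual{Float32}(2, 2) and Dual{Float32}(4, 4), while A simply becomes [2, 3].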