Custom backpropagation rule on GPU

Hello, I am trying to define a custom ChainRules rrule so that Enzyme differentiates selected functions while Zygote handles the rest of the backpropagation. Launching the kernel and taking derivatives with Enzyme works. However, taking the Jacobian with Zygote once this custom rule is defined does not.
I inspected the kernel with @device_code_warntype, but everything seems fine.

imports and test data

using ChainRulesCore,Zygote,CUDA,Enzyme
Nx, Ny, Nz = 8, 8, 8
oneSidePad = 1
totalPad=oneSidePad*2
A = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dA= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

Aout = CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dAout= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

p = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dp= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

kernel and its derivative definition

function testKern(A, p, Aout)
    # adding one because of padding
    x = (threadIdx().x + ((blockIdx().x - 1) * CUDA.blockDim_x())) + 1
    y = (threadIdx().y + ((blockIdx().y - 1) * CUDA.blockDim_y())) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * CUDA.blockDim_z())) + 1
    Aout[x, y, z] = A[x, y, z] *p[x, y, z] *p[x, y, z] *p[x, y, z] 
    
    return nothing
end

function testKernDeff(A, dA, p, dp, Aout, dAout)
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout))
    return nothing
end

Running the kernel, everything seems fine

threads = (4, 4, 4)
blocks = (2, 2, 2)
@cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, dAout)
@device_code_warntype @cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, dAout)
@cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, dAout)
maximum(dp)# 4
maximum(dA)# 2
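
The values 4 and 2 follow from the kernel's derivative, assuming Enzyme's reverse mode accumulates into the Duplicated shadows rather than overwriting them. A CPU-only sketch of the per-element arithmetic (the scalar variables are illustrative and only mirror the names above):

```julia
# Per-element arithmetic of the reverse pass for Aout = A * p^3.
# Assumption: Enzyme adds gradients into the shadows (+=) and clears the output shadow.
A, p = 1.0f0, 1.0f0                    # test data: all ones
dA, dp, dAout = 1.0f0, 1.0f0, 1.0f0    # shadows seeded with ones, as in the test data

dA += dAout * p^3            # ∂Aout/∂A = p^3
dp += dAout * 3 * A * p^2    # ∂Aout/∂p = 3*A*p^2
dAout = 0.0f0                # the output shadow is consumed by the reverse pass

println((dA, dp))
```

With zero-initialised dA and dp, the same arithmetic would give 1 and 3, the actual partial derivatives at A = p = 1.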

ChainRules rrule definition

function ChainRulesCore.rrule(::typeof(testKern), A, p,Aout)

    function call_test_kernel1_pullback(dAout)
        # Allocate shadow memory.
        threads = (4, 4, 4)
        blocks = (2, 2, 2)
        dp = CUDA.ones(size(p))
        dA = CUDA.ones(size(A))
        @cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, collect(dAout))

        f̄ = NoTangent()
        x̄ = dA
        ȳ = dp
        
        return f̄, x̄, ȳ
    end
    
    return Aout, call_test_kernel1_pullback
end

testing whether rrule compiles

Zygote.jacobian(testKern,A, p,Aout )

gives the error

KernelError: kernel returns a value of type `Union{}`

Thanks for your help!

If you put @device_code_warntype before the jacobian, you see:

PTX CompilerJob of kernel #testKernDeff(CuDeviceArray{Float32, 3, 1}, CuDeviceArray{Float32, 3, 1}, CuDeviceArray{Float32, 3, 1}, CuDeviceArray{Float32, 3, 1}, CuDeviceArray{Float32, 3, 1}, Array{Float32, 3}) for sm_75, always_inline=false

Compare that to the direct invocation of testKernDeff; there’s an Array in here, because collect always returns a CPU array.
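
A quick way to see the difference, as a small sketch that assumes a working CUDA device (the variable names are illustrative):

```julia
using CUDA

d = CUDA.ones(2, 2, 2)   # lives in GPU memory
h = collect(d)           # collect copies the data to a host Array
back = CuArray(h)        # upload the host copy to the GPU again

println(typeof(d))
println(typeof(h))
println(typeof(back))
```

@cuda converts CuArray arguments to CuDeviceArray, but a host Array is left unchanged, which is why it shows up as the last argument in the kernel signature above.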


Thanks, that was it. With some small further modifications it now works; the reference code is below.

using ChainRulesCore,Zygote,CUDA,Enzyme

Nx, Ny, Nz = 8, 8, 8
oneSidePad = 1
totalPad=oneSidePad*2
A = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dA= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

Aoutout = CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dAoutout= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

p = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dp= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

threads = (4, 4, 4)
blocks = (2, 2, 2)

function testKern(A, p, Aout)
    # adding one because of padding
    x = (threadIdx().x + ((blockIdx().x - 1) * CUDA.blockDim_x())) + 1
    y = (threadIdx().y + ((blockIdx().y - 1) * CUDA.blockDim_y())) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * CUDA.blockDim_z())) + 1
    Aout[x, y, z] = A[x, y, z] *p[x, y, z] *p[x, y, z] *p[x, y, z] 
    
    return nothing
end

function testKernDeff(A, dA, p, dp, Aout, dAout)
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout))
    return nothing
end

function calltestKern(A, p)
    Aout = CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
    @cuda threads = threads blocks = blocks testKern( A, p,  Aout)
    return Aout
end

aa=calltestKern(A, p)
maximum(aa)

# rrule for ChainRules.
function ChainRulesCore.rrule(::typeof(calltestKern), A, p)
    # Run the forward kernel so the primal output is the actual result.
    Aout = calltestKern(A, p)
    function call_test_kernel1_pullback(dAout)
        threads = (4, 4, 4)
        blocks = (2, 2, 2)
        # Allocate zeroed shadow memory; Enzyme accumulates gradients into it.
        dp = CUDA.zeros(size(p))
        dA = CUDA.zeros(size(A))
        # The kernel needs device arrays, so move the collected cotangent back to the GPU.
        @cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, CuArray(collect(dAout)))

        f̄ = NoTangent()
        x̄ = dA
        ȳ = dp

        return f̄, x̄, ȳ
    end
    return Aout, call_test_kernel1_pullback
end

ress=Zygote.jacobian(calltestKern,A, p )
typeof(ress)
maximum(ress[1])
maximum(ress[2])
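
One way to check the GPU Jacobian is to compare it against a CPU reference with the same math. This is a minimal sketch, assuming the kernel computes Aout = A .* p.^3 elementwise; `cpu_ref` and `cpu_ref_pullback` are hypothetical names that are not part of the code above, and the padding is left out:

```julia
using ChainRulesCore, Zygote, LinearAlgebra

# Hypothetical CPU analogue of calltestKern: Aout = A .* p.^3, without padding.
cpu_ref(A, p) = A .* p .^ 3

function ChainRulesCore.rrule(::typeof(cpu_ref), A, p)
    Aout = cpu_ref(A, p)
    function cpu_ref_pullback(dAout)
        d = unthunk(dAout)           # materialise the cotangent if it is a thunk
        dA = d .* p .^ 3             # ∂Aout/∂A = p^3
        dp = d .* 3 .* A .* p .^ 2   # ∂Aout/∂p = 3*A*p^2
        return NoTangent(), dA, dp
    end
    return Aout, cpu_ref_pullback
end

A = ones(Float32, 2, 2, 2)
p = fill(2.0f0, 2, 2, 2)
JA, Jp = Zygote.jacobian(cpu_ref, A, p)

println(size(JA))
println(maximum(JA), " ", maximum(Jp))
```

Because the operation is elementwise, both Jacobians are diagonal, with p^3 on the diagonal for A and 3*A*p^2 for p. The GPU result should show the same values at the interior (non-padded) indices.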