Hello, I am trying to define a custom ChainRules `rrule` so that I can use Enzyme on selected functions, while backpropagation is otherwise done with Zygote. The kernel runs, and taking derivatives with Enzyme works. However, computing a Jacobian with Zygote once this custom rule is defined does not.
I inspected the kernel with `@device_code_warntype`, and everything seems fine.
Imports and test data:
```julia
using ChainRulesCore, Zygote, CUDA, Enzyme

# 8×8×8 interior plus one padding cell on each side of every axis
Nx, Ny, Nz = 8, 8, 8
oneSidePad = 1
totalPad = oneSidePad * 2

A     = CUDA.ones(Nx + totalPad, Ny + totalPad, Nz + totalPad)
dA    = CUDA.ones(Nx + totalPad, Ny + totalPad, Nz + totalPad)
Aout  = CUDA.zeros(Nx + totalPad, Ny + totalPad, Nz + totalPad)
dAout = CUDA.ones(Nx + totalPad, Ny + totalPad, Nz + totalPad)
p     = CUDA.ones(Nx + totalPad, Ny + totalPad, Nz + totalPad)
dp    = CUDA.ones(Nx + totalPad, Ny + totalPad, Nz + totalPad)
```
Kernel and its derivative definition:
```julia
function testKern(A, p, Aout)
    # adding one because of padding
    x = (threadIdx().x + ((blockIdx().x - 1) * CUDA.blockDim_x())) + 1
    y = (threadIdx().y + ((blockIdx().y - 1) * CUDA.blockDim_y())) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * CUDA.blockDim_z())) + 1
    Aout[x, y, z] = A[x, y, z] * p[x, y, z] * p[x, y, z] * p[x, y, z]
    return nothing
end

function testKernDeff(A, dA, p, dp, Aout, dAout)
    # Reverse-mode differentiation of testKern inside the kernel;
    # gradients are accumulated into the shadow arrays dA and dp.
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout))
    return nothing
end
```
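The kernel adds one to each global thread index to skip the single padding cell. Replaying that arithmetic on the CPU for one axis, assuming the launch configuration used below (`threads = (4, 4, 4)`, `blocks = (2, 2, 2)`), shows that the grid covers exactly the interior indices `2:9` of each padded axis of length `Nx + totalPad = 10`:

```julia
# Replay the kernel's x-index arithmetic for every (block, thread) pair along one axis.
# Assumes the launch configuration from the post: 4 threads and 2 blocks per dimension.
threads_x, blocks_x = 4, 2
xs = [(t + (b - 1) * threads_x) + 1 for b in 1:blocks_x for t in 1:threads_x]
xs == collect(2:9)   # true: interior of the padded axis 1:10
```

The same holds for `y` and `z`, so the launch configuration has to change together with `Nx`, `Ny`, `Nz`.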
Running the kernel, everything seems fine:
```julia
threads = (4, 4, 4)
blocks = (2, 2, 2)
@cuda threads = threads blocks = blocks testKernDeff(A, dA, p, dp, Aout, dAout)
# @device_code_warntype also evaluates (launches) the expression it wraps
@device_code_warntype @cuda threads = threads blocks = blocks testKernDeff(A, dA, p, dp, Aout, dAout)
@cuda threads = threads blocks = blocks testKernDeff(A, dA, p, dp, Aout, dAout)
maximum(dp) # 4
maximum(dA) # 2
```
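As a sanity check on these values: for `Aout = A * p^3` the local derivatives are `∂Aout/∂A = p^3` and `∂Aout/∂p = 3 * A * p^2`. Enzyme accumulates into the shadows rather than overwriting them, and `dA`, `dp` start at ones, so with `A = p = dAout = 1` one would expect `dA = 1 + 1 = 2` and `dp = 1 + 3 = 4`. A CPU-only sketch of that arithmetic with plain Julia arrays (no CUDA or Enzyme; the function name `reference_pullback!` is hypothetical):

```julia
# CPU reference for the per-element reverse pass of Aout = A * p^3.
# Assumption: like Enzyme with Duplicated arguments, it accumulates (+=) into the shadows.
function reference_pullback!(dA, dp, A, p, dAout)
    for i in eachindex(A, p, dAout, dA, dp)
        dA[i] += dAout[i] * p[i]^3             # ∂(A p³)/∂A = p³
        dp[i] += dAout[i] * 3 * A[i] * p[i]^2  # ∂(A p³)/∂p = 3 A p²
    end
    return nothing
end

A = ones(Float32, 10, 10, 10)
p = ones(Float32, 10, 10, 10)
dA = ones(Float32, 10, 10, 10)     # shadows start at ones, as in the GPU test data
dp = ones(Float32, 10, 10, 10)
dAout = ones(Float32, 10, 10, 10)
reference_pullback!(dA, dp, A, p, dAout)
maximum(dA), maximum(dp)   # (2.0f0, 4.0f0)
```

The kernel is launched three times above, yet the values stay at 2 and 4. That is consistent with the first reverse pass consuming (zeroing) `dAout`, which is how Enzyme normally treats the shadow of memory that the primal overwrites.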
ChainRules `rrule` definition:
```julia
function ChainRulesCore.rrule(::typeof(testKern), A, p, Aout)
    function call_test_kernel1_pullback(dAout)
        # Allocate shadow memory.
        threads = (4, 4, 4)
        blocks = (2, 2, 2)
        dp = CUDA.ones(size(p))
        dA = CUDA.ones(size(A))
        @cuda threads = threads blocks = blocks testKernDeff(A, dA, p, dp, Aout, collect(dAout))
        f̄ = NoTangent()
        x̄ = dA
        ȳ = dp
        return f̄, x̄, ȳ
    end
    return Aout, call_test_kernel1_pullback
end
```
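A few things in the `rrule` above may be worth ruling out, though it is not certain which of them (if any) produces the `Union{}` error. The pullback returns three tangents while the rule covers four positions (`testKern` itself plus `A`, `p`, `Aout`). `collect(dAout)` copies the cotangent to a CPU `Array`, which a CUDA kernel cannot index. The primal step returns `Aout` without launching the kernel. Finally, `dA` and `dp` start from ones, so the accumulated result is offset by one. The sketch below assumes CUDA.jl, Zygote, ChainRulesCore and the same Enzyme version as in the post (newer Enzyme releases expect an explicit mode argument, e.g. `Reverse`, in `autodiff_deferred`). It wraps the launch in a host function that returns its output (the name `apply_test_kernel` is hypothetical) and puts the `rrule` on that wrapper instead of on the kernel. It repeats the two kernels so that it runs on its own, and it keeps the post's hardcoded 10×10×10 padded size and launch configuration.

```julia
using ChainRulesCore, Zygote, CUDA, Enzyme

function testKern(A, p, Aout)
    x = (threadIdx().x + ((blockIdx().x - 1) * blockDim().x)) + 1  # +1 skips the padding
    y = (threadIdx().y + ((blockIdx().y - 1) * blockDim().y)) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * blockDim().z)) + 1
    Aout[x, y, z] = A[x, y, z] * p[x, y, z] * p[x, y, z] * p[x, y, z]
    return nothing
end

function testKernDeff(A, dA, p, dp, Aout, dAout)
    # Same Enzyme call as in the post (older, mode-less autodiff_deferred signature).
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout))
    return nothing
end

# Hypothetical host-side wrapper: allocates the output, launches the kernel, returns the result.
function apply_test_kernel(A::CuArray, p::CuArray)
    Aout = zero(A)  # padding cells stay zero
    @cuda threads=(4, 4, 4) blocks=(2, 2, 2) testKern(A, p, Aout)
    return Aout
end

function ChainRulesCore.rrule(::typeof(apply_test_kernel), A::CuArray, p::CuArray)
    Aout = apply_test_kernel(A, p)
    function apply_test_kernel_pullback(ΔAout)
        # Copy the cotangent onto the GPU (no `collect`); the copy is writable,
        # which matters because the reverse pass zeroes the output shadow.
        dAout = CuArray(unthunk(ΔAout))
        # Zero-initialised shadows, so Enzyme's accumulation yields the plain gradient.
        dA = zero(A)
        dp = zero(p)
        # Scratch primal output, so the Aout returned above is not overwritten.
        @cuda threads=(4, 4, 4) blocks=(2, 2, 2) testKernDeff(A, dA, p, dp, similar(Aout), dAout)
        # One tangent per position of the primal call: (function, A, p).
        return NoTangent(), dA, dp
    end
    return Aout, apply_test_kernel_pullback
end

A = CUDA.ones(10, 10, 10)
p = CUDA.ones(10, 10, 10)
# Scalar loss first: one pullback call, one kernel launch.
gA, gp = Zygote.gradient((A, p) -> sum(apply_test_kernel(A, p)), A, p)
```

With this wrapper, `Zygote.jacobian(apply_test_kernel, A, p)` is the analogue of the original call, although it runs the pullback (and hence a kernel launch) once per output element, 1000 times here.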
Testing whether the `rrule` works:
```julia
Zygote.jacobian(testKern, A, p, Aout)
```
gives the error
```
KernelError: kernel returns a value of type `Union{}`
```
Thanks for any help!