Thank you all for the help! To describe the problem in more detail: I want to automate the definitions of some functions for automatic differentiation (there will be a limited number of them, around 6-7, but the code is very repetitive). All of the information (conff_params below) is constant, manually provided in the code, and known at compile time. The automation needs to:
- define the set of variables and, for each, whether it is a parameter, state, input, or output; additionally I pass the kernel function (testKern below), which takes the same arguments and is the function I want to differentiate (conff_params below)
- create a struct that subtypes Lux.AbstractExplicitLayer (KernelAstr below)
- define a function that calls Enzyme.autodiff_deferred (testKernDeff below)
- define the ChainRulesCore.rrule (ChainRulesCore.rrule(::typeof(calltestKern), ...) below)
- define Lux.initialparameters (Lux.initialparameters below)
- define Lux.initialstates (Lux.initialstates below)
- define the apply function of the layer ((l::KernelAstr)(x, ps, st::NamedTuple) below)
- define the call function that launches the kernel with the @cuda macro (calltestKern below)
All of these steps can be derived from the information supplied in the variable list (conff_params below); a rough sketch of what I mean follows the dictionary definition.
How the variables are defined:
@enum paramType Eparameter=1 Estate Einput Eoutput

# Each entry is (role, initializer closure, Enzyme activity annotation).
conff_params = Dict(
    "A"    => (Einput,     (rng, l) -> (),                                                     Duplicated),
    "p"    => (Eparameter, (rng, l) -> CuArray(rand(rng, Float32, l.confA, l.confA, l.confA)), Duplicated),
    "Aout" => (Eoutput,    (rng, l) -> (),                                                     Duplicated),
    "Nxa"  => (Estate,     (rng, l) -> Nx,                                                     Const),
    "Nya"  => (Estate,     (rng, l) -> Ny,                                                     Const),
    "Nza"  => (Estate,     (rng, l) -> Nz,                                                     Const)
)
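To make the intent concrete, here is a rough sketch (not a working solution) of the kind of code generation I have in mind, shown for just one of the steps, the Enzyme wrapper. The names ordered_names, enzyme_arg, and testKernDeffGenerated are hypothetical and introduced only for this sketch; the real version would have to cover the rrule, the Lux methods, etc. in the same way.
# Hypothetical sketch: generate an Enzyme wrapper from conff_params with @eval.
# Assumes conff_params, testKern, Enzyme, Duplicated and Const from the code above.

# Fixed ordering so the generated function matches testKern's argument list.
ordered_names = ["A", "p", "Aout", "Nxa", "Nya", "Nza"]

argsyms = Symbol.(ordered_names)   # :A, :p, :Aout, :Nxa, :Nya, :Nza

# Shadow arguments (dA, dp, dAout) are only needed for Duplicated variables.
shadow_syms = [Symbol("d", n) for n in ordered_names if conff_params[n][3] == Duplicated]

# Build the annotation expression for each argument, e.g. Duplicated(A, dA) or Const(Nxa).
function enzyme_arg(name)
    _, _, activity = conff_params[name]
    s = Symbol(name)
    return activity == Duplicated ? :(Duplicated($s, $(Symbol("d", name)))) : :(Const($s))
end

enzyme_exprs = enzyme_arg.(ordered_names)

# Generate the differentiating wrapper at top level, analogous to testKernDeff below.
@eval function testKernDeffGenerated($(argsyms...), $(shadow_syms...))
    Enzyme.autodiff_deferred(testKern, Const, $(enzyme_exprs...))
    return nothing
end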
Example code before metaprogramming:
#lux layers from http://lux.csail.mit.edu/dev/manual/interface/
struct KernelAstr <: Lux.AbstractExplicitLayer
    confA::Int
    Nxa::Int
    Nya::Int
    Nza::Int
end
function testKern(A, p, Aout, Nxa, Nya, Nza)
    # adding one because of padding
    x = (threadIdx().x + ((blockIdx().x - 1) * CUDA.blockDim_x())) + 1
    y = (threadIdx().y + ((blockIdx().y - 1) * CUDA.blockDim_y())) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * CUDA.blockDim_z())) + 1
    Aout[x, y, z] = A[x, y, z] * p[x, y, z] * p[x, y, z] * p[x, y, z]
    return nothing
end
function testKernDeff(A, dA, p, dp, Aout, dAout, Nxa, Nya, Nza)
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout), Const(Nxa), Const(Nya), Const(Nza))
    return nothing
end
function calltestKern(A, p, Nxa, Nya, Nza)
    # threads, blocks and totalPad are globals defined elsewhere in the script
    Aout = CUDA.zeros(Nxa + totalPad, Nya + totalPad, Nza + totalPad)
    @cuda threads = threads blocks = blocks testKern(A, p, Aout, Nxa, Nya, Nza)
    return Aout
end
aa=calltestKern(A, p,Nx,Ny,Nz)
maximum(aa)
# rrule for ChainRules.
function ChainRulesCore.rrule(::typeof(calltestKern), A, p, Nxa, Nya, Nza)
    Aout = calltestKern(A, p, Nxa, Nya, Nza)
    function call_test_kernel1_pullback(dAout)
        # Allocate shadow memory; Enzyme accumulates gradients into it,
        # so it has to start at zero.
        threads = (4, 4, 4)
        blocks = (2, 2, 2)
        dp = CUDA.zeros(size(p))
        dA = CUDA.zeros(size(A))
        @cuda threads = threads blocks = blocks testKernDeff(A, dA, p, dp, Aout, CuArray(collect(dAout)), Nxa, Nya, Nza)
        f̄ = NoTangent()
        x̄ = dA
        ȳ = dp
        return f̄, x̄, ȳ, NoTangent(), NoTangent(), NoTangent()
    end
    return Aout, call_test_kernel1_pullback
end
#first testing
# ress=Zygote.jacobian(calltestKern,A, p ,Nx,Ny,Nz)
# typeof(ress)
# maximum(ress[1])
# maximum(ress[2])
function KernelA(confA::Int, Nxa, Nya, Nza)
    return KernelAstr(confA, Nxa, Nya, Nza)
end
function Lux.initialparameters(rng::AbstractRNG, l::KernelAstr)
    return (paramsA = CuArray(rand(rng, Float32, l.confA, l.confA, l.confA)),
            paramsB = CuArray(rand(rng, Float32, l.Nxa, l.Nya, l.Nza)))
end
Lux.initialstates(::AbstractRNG, ::KernelAstr) = (Nxa=Nx,Nya=Ny,Nza=Nz)
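The initializer closures stored in conff_params are meant to replace hand-written methods like the two above. A minimal sketch of that idea for the Eparameter entries (generated_initialparameters is a hypothetical name; Lux.initialstates could be generated analogously from the Estate entries):
# Hypothetical sketch: build the parameter NamedTuple from the closures in conff_params.
function generated_initialparameters(rng::AbstractRNG, l::KernelAstr)
    entries = [(Symbol(k), v[2]) for (k, v) in conff_params if v[1] == Eparameter]
    names   = Tuple(first.(entries))
    vals    = Tuple(init(rng, l) for (_, init) in entries)
    return NamedTuple{names}(vals)
end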
# # But it is still recommended to define these
# Lux.parameterlength(l::KernelAstr) = l.out_dims * l.in_dims + l.out_dims
# Lux.statelength(::KernelAstr) = 0
function (l::KernelAstr)(x, ps, st::NamedTuple)
    return calltestKern(x, ps.paramsA, st.Nxa, st.Nya, st.Nza), st
end
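For completeness, this is roughly how I then use the layer (a sketch that assumes Random, Lux and CUDA are loaded and the globals A, Nx, Ny, Nz come from the surrounding script):
# Rough usage sketch; A, Nx, Ny, Nz are defined as above.
rng = Random.default_rng()
l = KernelA(3, Nx, Ny, Nz)        # confA = 3 is just an example value
ps, st = Lux.setup(rng, l)        # calls Lux.initialparameters / Lux.initialstates
y, st = Lux.apply(l, A, ps, st)   # dispatches to (l::KernelAstr)(x, ps, st)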