Would this be of any help?
using Flux
using Zygote
using Functors
"""
    Prelu{T<:AbstractVector}

Parametric leaky-ReLU layer. `slope` holds the negative-side slope(s); when applied to
an array, one slope is used per channel (second-to-last dimension).
"""
struct Prelu{T<:AbstractVector}
    slope::T
end

"""
    Prelu(dim::Int = 1; init::Real = 0.25f0)

Create a `Prelu` with `dim` slopes, every one initialised to `float(init)`
(so the default `0.25f0` yields a `Float32` slope vector).
"""
function Prelu(dim::Int = 1; init::Real = 0.25f0)
    initial_slope = float(init)
    slopes = fill(initial_slope, dim)
    return Prelu(slopes)
end
@functor Prelu

# Forward pass for a vector input: the slope vector is broadcast against `x` directly.
(p::Prelu)(x::AbstractVector) = leakyrelu.(x, p.slope)

# Forward pass for higher-rank input: reshape the slopes so they lie along the
# second-to-last (channel) dimension of `x`, giving each channel its own slope.
# Dispatching on the rank `N` (instead of checking `x isa AbstractVector`) and using
# `Val(N - 2)` keeps the number of leading singleton dims known at compile time, so
# the reshape and broadcast are type-stable.
function (p::Prelu)(x::AbstractArray{<:Any,N}) where {N}
    lead = ntuple(_ -> 1, Val(N - 2))
    a = reshape(p.slope, lead..., :)
    return leakyrelu.(x, a)
end
"""
    ConvBlock(cnn, act)

A layer `cnn` followed by an activation `act`; calling the block computes
`act(cnn(x))`.

The fields are type-parameterised (`C`, `A`) rather than left untyped (`Any`), so the
compiler knows the concrete layer and activation types and the forward pass is
type-stable.
"""
struct ConvBlock{C,A}
    cnn::C
    act::A
end
@functor ConvBlock

"""
    ConvBlock(kernel_size, in_channels, out_channels, act, s, p; leak = 0.2f0)

Build a square `kernel_size × kernel_size` convolution mapping
`in_channels => out_channels` with stride `s`, padding `p` and a bias, optionally
followed by a leaky ReLU.

# Arguments
- `act`: when `true`, apply `leakyrelu` with negative slope `leak` after the
  convolution; when `false`, the activation is `identity`.

# Keywords
- `leak`: negative slope of the leaky ReLU. The default is the `Float32` literal
  `0.2f0` because a `Float64` literal can promote Float32 activations to Float64.
"""
function ConvBlock(
    kernel_size::Int,
    in_channels::Int,
    out_channels::Int,
    act::Bool,
    s::Int,
    p::Int;
    leak::Real = 0.2f0,
)
    conv = Conv(
        (kernel_size, kernel_size),
        in_channels => out_channels;
        stride = s,
        pad = p,
        bias = true,
    )
    activation = act ? (x -> leakyrelu.(x, leak)) : identity
    return ConvBlock(conv, activation)
end
# Forward pass: feed `x` through the convolution, then through the activation.
(net::ConvBlock)(x) = x |> net.cnn |> net.act
# Demo hyper-parameters.
in_channels = 4
channels = 1  # NOTE(review): not used below; kept in case other code reads it
kernel_size = 3

# Five stride-1, pad-1 conv blocks that keep `in_channels` channels; only the first
# three (i <= 3) apply the leaky-ReLU activation.
net = Chain(
    [ConvBlock(kernel_size, in_channels, in_channels, i <= 3, 1, 1) for i in 1:5]...,
)

# Use a Float32 input to match Flux's default Float32 parameters (a Float64 input
# triggers conversion or promotion); the channel dimension reuses `in_channels`
# instead of a duplicated literal.
net(randn(Float32, 24, 24, in_channels, 5))