Need help translating this PyTorch snippet to Flux

Would this be of any help?

using Flux
using Zygote
using Functors


"""
    Prelu(slope::AbstractVector)
    Prelu(dim::Int = 1; init::Real = 0.25f0)

Parametric-ReLU layer storing a vector of negative-side slopes, one per channel.

The second form allocates `dim` slopes, every entry set to `float(init)`
(so the default `0.25f0` yields a `Vector{Float32}`).
"""
struct Prelu{T<:AbstractVector}
    slope::T
end

function Prelu(dim::Int = 1; init::Real = 0.25f0)
    start = float(init)
    slopes = fill(start, dim)
    return Prelu(slopes)
end

# Register `slope` as a child so Flux/Functors can traverse it
# (e.g. for parameter collection and device movement).
@functor Prelu

"""
    (layer::Prelu)(x::AbstractArray)

Apply `leakyrelu` elementwise to `x`, using `layer.slope` as the per-channel slope.

The channel axis is the second-to-last axis of `x`, or the only axis when `x` is a
vector. The slope vector is reshaped with leading singleton axes so that it
broadcasts along that channel axis.
"""
function (layer::Prelu)(x::AbstractArray{<:Any,N}) where {N}
    # A vector (N == 1) has no leading axes; otherwise skip all but the last two axes.
    lead = N == 1 ? 0 : N - 2
    shaped = reshape(layer.slope, ntuple(_ -> 1, lead)..., :)
    return leakyrelu.(x, shaped)
end

"""
    ConvBlock(cnn, act)

A convolution layer `cnn` followed by an activation function `act`.
Calling the block on `x` computes `act(cnn(x))`.

The field types are type parameters. This lets the compiler see the concrete layer
and activation types. Untyped fields would be stored as `Any` and force dynamic
dispatch on every forward pass.
"""
struct ConvBlock{C,A}
    cnn::C
    act::A
end

# Register both fields as children so Flux/Functors can traverse the block
# (e.g. to collect the convolution's parameters or move it to another device).
@functor ConvBlock


"""
    ConvBlock(kernel_size, in_channels, out_channels, act, s, p; slope = 0.2)

Build a `ConvBlock` with the following parts:

- a square 2-D convolution (`kernel_size × kernel_size`) mapping
  `in_channels => out_channels`, with stride `s`, padding `p` and a bias term;
- when `act` is `true`, an activation that applies `leakyrelu` with negative slope
  `slope` elementwise (broadcast) to the convolution output; otherwise `identity`.

# Keywords
- `slope::Real = 0.2`: negative-side slope of the leaky ReLU. Only used when `act`
  is `true`. The default reproduces the previously hard-coded value.
"""
function ConvBlock(
    kernel_size::Integer,
    in_channels::Integer,
    out_channels::Integer,
    act::Bool,
    s::Integer,
    p::Integer;
    slope::Real = 0.2,
)
    conv = Conv(
        (kernel_size, kernel_size), in_channels => out_channels;
        stride = s, pad = p, bias = true,
    )
    # The activation receives the whole conv output array, so it must broadcast.
    activation = act ? (y -> leakyrelu.(y, slope)) : identity
    return ConvBlock(conv, activation)
end

# Forward pass: run the convolution, then the stored activation, on input `x`.
(block::ConvBlock)(x) = block.act(block.cnn(x))


in_channels = 4
channels = 1  # NOTE(review): not used below — presumably left over from the PyTorch code
kernel_size = 3

# Five blocks that keep the channel count. With kernel 3, stride 1 and padding 1,
# the spatial size is preserved. Only the first three blocks apply the leaky ReLU.
net = Chain(
    (ConvBlock(kernel_size, in_channels, in_channels, i <= 3, 1, 1) for i in 1:5)...,
)

# Flux's Conv layers are initialised with Float32 weights, so use Float32 input.
# Float64 input would trigger promotion (and an eltype-mismatch warning in recent
# Flux versions). Layout is WHCN: 24×24 pixels, `in_channels` channels, batch of 5.
net(randn(Float32, 24, 24, in_channels, 5))
1 Like