I'm attempting to reimplement SRGAN from PyTorch in Flux; any help is appreciated.

PyTorch:

class ConvBlock(nn.Module):
    """Convolution followed by optional batch norm and optional activation.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        discriminator: if True the activation is ``LeakyReLU(0.2)``,
            otherwise a per-channel ``PReLU``.
        use_act: whether ``forward`` applies the activation.
        use_bn: whether a ``BatchNorm2d`` follows the convolution; the
            convolution only carries a bias when this is False.
        **kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        # The activation module is always created, even when use_act is False.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        normalized = self.bn(self.cnn(x))
        if not self.use_act:
            return normalized
        return self.act(normalized)


class UpsampleBlock(nn.Module):
    """3x3 convolution, pixel shuffle and per-channel PReLU.

    The convolution expands ``in_c`` channels to ``in_c * scale_factor ** 2``;
    the pixel shuffle folds them back to ``in_c`` channels while multiplying
    height and width by ``scale_factor``.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_channels = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_channels, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.ps(x)
        return self.act(x)


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s with an identity skip connection.

    The first block applies its activation, the second does not; the block
    input is added to the second block's output.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_cfg)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **conv_cfg)

    def forward(self, x):
        return self.block2(self.block1(x)) + x


class Generator(nn.Module):
    """SRGAN generator.

    Pipeline: 9x9 initial ``ConvBlock`` (no batch norm) -> ``num_blocks``
    residual blocks -> 3x3 ``ConvBlock`` without activation, added to the
    initial features -> two x2 ``UpsampleBlock``s -> 9x9 convolution -> tanh.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )
        self.residuals = nn.Sequential(
            *(ResidualBlock(num_channels) for _ in range(num_blocks))
        )
        self.convblock = ConvBlock(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False
        )
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels, 2),
            UpsampleBlock(num_channels, 2),
        )
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        skip = self.initial(x)
        # Long skip connection around the whole residual trunk.
        trunk = self.convblock(self.residuals(skip)) + skip
        upscaled = self.upsamples(trunk)
        return torch.tanh(self.final(upscaled))


class Discriminator(nn.Module):
    """SRGAN discriminator.

    A stack of 3x3 ``ConvBlock``s (LeakyReLU activation, stride alternating
    1, 2, 1, 2, ..., batch norm on every block except the first) followed by
    adaptive average pooling to 6x6 and a two-layer linear classifier that
    outputs a single value per sample.

    Args:
        in_channels: number of channels of the input image.
        features: output channel count of each convolutional block, in order.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        channels = in_channels
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Sized from the last block's channel count instead of a hard-coded
            # 512, so a custom `features` sequence produces a matching classifier.
            nn.Linear(channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)

Julia:

using Flux
using Zygote
using CUDA
using CSV
using DataFrames
using Images
using MLDatasets
using BSON: @save, @load
using ImageView
using Functors
CUDA.allowscalar(false)

"""
    Prelu(dim::Int=1; init::Real=0.25f0)

Parametric leaky ReLU holding a vector of `dim` slopes, each initialised to
`init`. Calling the layer on an array applies `leakyrelu` elementwise, with the
slopes broadcast along the second-to-last dimension (or directly along a
vector input).
"""
struct Prelu{T<:AbstractVector}
    slope::T
end

Prelu(dim::Int=1; init::Real=0.25f0) = Prelu(fill(float(init), dim))

@functor Prelu

function (p::Prelu)(x::AbstractArray)
    # Number of singleton dimensions to put in front of the slope vector so
    # that it lines up with the channel dimension of `x`.
    nleading = x isa AbstractVector ? 0 : ndims(x) - 2
    slopes = reshape(p.slope, ntuple(_ -> 1, nleading)..., :)
    return leakyrelu.(x, slopes)
end

"""
    ConvBlock(kernel_size, in_channels, out_channels,
              discriminator=false, act=true, use_batchnorm=true, s=1, p=1)

Square convolution followed by an optional `BatchNorm` and an optional
activation.

# Arguments
- `kernel_size`: side length of the square convolution kernel.
- `in_channels`, `out_channels`: channel counts of the convolution.
- `discriminator`: if `true` the activation is `leakyrelu` with slope 0.2,
  otherwise a per-channel `Prelu`.
- `act`: whether an activation is applied at all; when `false` the `act`
  field holds `identity`.
- `use_batchnorm`: whether a `BatchNorm` follows the convolution; the
  convolution only has a bias when this is `false`.
- `s`, `p`: stride and padding of the convolution.
"""
struct ConvBlock
    cnn
    bn
    act
end

@functor ConvBlock

function ConvBlock(
    kernel_size::Int,
    in_channels::Int,
    out_channels::Int,
    discriminator::Bool = false,
    act::Bool = true,
    use_batchnorm::Bool = true,
    s::Int = 1,
    p::Int = 1,
)
    cnn = Conv(
        (kernel_size, kernel_size), in_channels => out_channels;
        stride = s, pad = p, bias = !use_batchnorm,
    )
    bn = use_batchnorm ? Flux.BatchNorm(out_channels) : identity
    # The `act` flag used to be ignored here while the forward pass tested the
    # activation *function* as a Bool (a TypeError). Encoding "no activation"
    # as `identity` keeps the struct layout and makes the forward pass uniform.
    activation = if !act
        identity
    elseif discriminator
        # Float32 slope literal so Float32 inputs are never promoted.
        x -> leakyrelu.(x, 0.2f0)
    else
        Prelu(out_channels)
    end
    return ConvBlock(cnn, bn, activation)
end

(net::ConvBlock)(x) = net.act(net.bn(net.cnn(x)))

"""
    UpsampleBlock(in_c::Int, scale_factor::Int)

3x3 convolution expanding `in_c` channels to `in_c * scale_factor^2`, followed
by `PixelShuffle(scale_factor)` and a `Prelu` with `in_c` slopes.
"""
struct UpsampleBlock
    conv
    ps
    act
end

@functor UpsampleBlock

function UpsampleBlock(in_c::Int, scale_factor::Int)
    expanded = in_c * scale_factor^2
    conv = Conv((3, 3), in_c => expanded; stride = 1, pad = 1)
    return UpsampleBlock(conv, PixelShuffle(scale_factor), Prelu(in_c))
end

# Forward pass: convolution, then pixel shuffle, then activation.
(net::UpsampleBlock)(x) = x |> net.conv |> net.ps |> net.act


"""
    ResidualBlock(in_channels::Int)

Two 3x3 `ConvBlock`s wrapped in an identity skip connection. Matching the
PyTorch reference, both blocks use batch norm; the first is built with
`act = true` and the second with `act = false`.
"""
struct ResidualBlock
    block1
    block2
end

@functor ResidualBlock

function ResidualBlock(in_channels::Int)
    # Positional flags after the channel counts are
    # (discriminator, act, use_batchnorm, stride, pad).
    # Previously block1 had batch norm disabled and block2 had its activation
    # enabled, which is the opposite of the reference implementation.
    return ResidualBlock(
        ConvBlock(3, in_channels, in_channels, false, true, true, 1, 1),
        ConvBlock(3, in_channels, in_channels, false, false, true, 1, 1),
    )
end

function (net::ResidualBlock)(x)
    out = net.block2(net.block1(x))
    return out + x
end

"""
    Generator(in_channels=3, num_blocks=16, num_channels=64)

SRGAN generator: a 9x9 initial `ConvBlock` without batch norm, `num_blocks`
residual blocks, a 3x3 `ConvBlock` whose output is added to the initial
features, two `UpsampleBlock`s with scale factor 2, and a final 9x9
convolution followed by `tanh`.

Note the positional order `(in_channels, num_blocks, num_channels)`.
"""
struct Generator
    initial
    residuals
    convblock
    upsamples
    final
end

@functor Generator

function Generator(
    in_channels::Int = 3,
    num_blocks::Int = 16,
    num_channels::Int = 64,
)
    return Generator(
        ConvBlock(9, in_channels, num_channels, false, true, false, 1, 4),
        # Splatting into `Chain` is the counterpart of `nn.Sequential(*[...])`.
        Chain([ResidualBlock(num_channels) for _ in 1:num_blocks]...),
        # This block runs on the residual trunk, so it maps num_channels to
        # num_channels, and a 3x3 kernel needs pad = 1 to keep the spatial size
        # (required for the skip addition in the forward pass).
        ConvBlock(3, num_channels, num_channels, false, false, true, 1, 1),
        Chain(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2)),
        Conv((9, 9), num_channels => in_channels; stride = 1, pad = 4),
    )
end

function (net::Generator)(x)
    skip = net.initial(x)
    out = net.residuals(skip)
    out = net.convblock(out) + skip
    out = net.upsamples(out)
    # `tanh` is a scalar function; it has to be broadcast over the array.
    return tanh.(net.final(out))
end

"""
    Discriminator(in_channels=3)

SRGAN discriminator: eight 3x3 `ConvBlock`s with `leakyrelu` activation and
output channels 64, 64, 128, 128, 256, 256, 512, 512 (stride alternating
1, 2, 1, 2, ...; batch norm on every block except the first), followed by
adaptive mean pooling to 6x6, flattening and two dense layers producing a
single output per sample.
"""
struct Discriminator
    blocks
    classifier
end

@functor Discriminator

function Discriminator(in_channels::Int = 3)
    features = (64, 64, 128, 128, 256, 256, 512, 512)

    # The layers are built before the struct constructor is called; statements
    # cannot appear inside a call's argument list.
    blocks = ConvBlock[]
    channels = in_channels
    for (idx, feature) in enumerate(features)
        # `enumerate` is 1-based: the first block is idx == 1, and the stride
        # must alternate 1, 2, 1, 2, ... starting with 1.
        use_batchnorm = idx != 1
        conv_stride = 1 + (idx - 1) % 2
        # `push!` adds one element; `append!` would try to iterate the block.
        push!(
            blocks,
            ConvBlock(3, channels, feature, true, true, use_batchnorm, conv_stride, 1),
        )
        channels = feature
    end

    classifier = Chain(
        Flux.AdaptiveMeanPool((6, 6)),
        Flux.flatten,  # a function, not a layer type
        Flux.Dense(channels * 6 * 6, 1024),
        x -> leakyrelu.(x, 0.2f0),
        Flux.Dense(1024, 1),
    )
    return Discriminator(Chain(blocks...), classifier)
end

function (net::Discriminator)(x)
    # Previously this read an undefined variable `initial` instead of `x`.
    features = net.blocks(x)
    return net.classifier(features)
end


# Build the generator and discriminator, move them to the GPU and print the
# discriminator. The hyperparameter arguments are accepted but not used yet.
function main(num_epochs::Int, batch_size::Int, shuffle::Bool, λ::Float64)
    gen = gpu(Generator(3, 16, 64))
    disc = gpu(Discriminator(3))
    # Parameter collection, currently disabled:
    # gen_trainable_params = Flux.params(gen.initial, gen.residuals,gen.upsamples,gen.final)
    # disc_trainable_params = Flux.params(disc.blocks, disc.classifier)
    return print(disc)
end

main(10, 128, true, 0.001)
1 Like

I think you want x -> leakyrelu.(x, 0.2) here.

I don’t think so, but it’s only a few lines:

# Parametric leaky ReLU: one slope per channel, `dim` slopes in total, each
# initialised to `init`.
struct Prelu{T<:AbstractVector}
    slope::T
end

Prelu(dim::Int=1; init::Real=0.25f0) = Prelu(fill(float(init), dim))

@functor Prelu

function (p::Prelu)(x::AbstractArray)
    # The channel dimension is second-to-last unless `x` is a vector, so pad
    # the slope vector with leading singleton dimensions to line it up.
    nleading = x isa AbstractVector ? 0 : ndims(x) - 2
    slopes = reshape(p.slope, ntuple(_ -> 1, nleading)..., :)
    return leakyrelu.(x, slopes)
end

# Compact display: the slope count is omitted when there is a single slope.
function Base.show(io::IO, p::Prelu)
    n = length(p.slope)
    print(io, "Prelu(", n == 1 ? "" : n, ")")
end
3 Likes

thanks