I have already made some progress.
PyTorch code:
class ResidualBlock(nn.Module):
    """Dense stack of five ConvBlocks followed by a scaled residual connection.

    Block ``k`` (0-based) is fed the input concatenated along dim 1 with the
    outputs of every earlier block, so it receives ``in_channels + channels * k``
    channels. Blocks 0-3 emit ``channels`` channels with ``use_act=True``; block 4
    emits ``in_channels`` channels with ``use_act=False``, and its output, scaled
    by ``residual_beta``, is added to the input.

    Args:
        in_channels: number of channels of the input (and of the output).
        channels: channels produced by each of the first four blocks.
        residual_beta: multiplier applied to the last block's output before the
            skip connection is added.
    """

    def __init__(self, in_channels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        final = 4  # index of the last block, which maps back to in_channels
        self.blocks = nn.ModuleList(
            [
                ConvBlock(
                    in_channels + channels * k,
                    in_channels if k == final else channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    use_act=k != final,
                )
                for k in range(final + 1)
            ]
        )

    def forward(self, x):
        dense_input = x
        for block in self.blocks:
            block_out = block(dense_input)
            # Grow the running input along dim 1 with this block's output.
            dense_input = torch.cat((dense_input, block_out), dim=1)
        return self.residual_beta * block_out + x
Julia code:
"""
    ResidualBlock{T,B}

Dense residual block.

# Fields
- `residual_beta::T`: scale applied to the last block's output before the input is
  added back (the skip connection).
- `blocks::B`: collection of layers applied in order by the functor method.

The fields are type parameters rather than untyped (`Any`) so that the forward pass
can be specialized on the concrete block types.
"""
struct ResidualBlock{T,B}
    residual_beta::T
    blocks::B
end
@functor ResidualBlock
"""
    ResidualBlock(in_channels, channels=32, residual_beta=0.2f0)

Build a residual block of five densely connected `ConvBlock`s (Julia port of the
PyTorch `ResidualBlock`).

Block `i` (1-based) receives the input concatenated with the outputs of blocks
`1:i-1`, i.e. `in_channels + channels * (i - 1)` channels. Blocks 1-4 produce
`channels` channels with activation enabled. Block 5 maps back to `in_channels`
without activation so that its output can be added to the input.

# Arguments
- `in_channels::Integer`: channels of the input (and of the output).
- `channels::Integer`: channels produced by each of the first four blocks.
- `residual_beta::Real`: residual scale; stored as `Float32`.
"""
function ResidualBlock(
    in_channels::Integer,
    channels::Integer=32,
    residual_beta::Real=0.2f0,
)
    nblocks = 5
    # NOTE(review): the positional order (kernel, in, out, use_act, stride, pad) follows
    # the draft's call; confirm it matches your Julia `ConvBlock` definition.
    # Python's 0-based `i` becomes `i - 1` here, and Python's `i <= 3` becomes `i < nblocks`.
    blocks = [
        ConvBlock(
            3,                                     # kernel size
            in_channels + channels * (i - 1),      # input grows by `channels` per block
            i < nblocks ? channels : in_channels,  # last block maps back to in_channels
            i < nblocks,                           # use_act: disabled on the last block
            1,                                     # stride
            1,                                     # padding
        ) for i in 1:nblocks
    ]
    return ResidualBlock(Float32(residual_beta), blocks)
end
"""
    (net::ResidualBlock)(x)

Apply the dense residual block to `x`.

Each block is applied to `x` concatenated (along the channel dimension) with the
outputs of all preceding blocks. The result is `residual_beta .* last_output .+ x`.

Assumes Flux's `(spatial..., channels, batch)` array layout, so the channel dimension
is `ndims(x) - 1` (dim 3 for 2-D images). This corresponds to `dim=1` in PyTorch's
NCHW layout.
"""
function (net::ResidualBlock)(x)
    channel_dim = ndims(x) - 1
    features = x
    # `out` must be bound before the loop: a `for` body has its own local scope, so a
    # variable first assigned inside it is not visible at the `return` below.
    out = x
    for block in net.blocks
        out = block(features)
        # `cat` takes the arrays as separate arguments, not wrapped in a vector.
        features = cat(features, out; dims=channel_dim)
    end
    return net.residual_beta .* out .+ x
end