I am trying to implement an SRGAN (super-resolution generative adversarial network) in Flux (I am happy to share once it works… :-))
The generator network is composed of 5 residual blocks (function `resblock`
below).
Unfortunately, the time of the first call to Zygote.pullback
increases very quickly
with the number of residual blocks, so I am not able to use the 5
residual blocks as intended. I already commented out all element-wise
broadcasting (the activation functions), but the issue is still present.
number res. block | pullback time (1st call) | pullback time (2nd call) |
---|---|---|
0 | 58.2 s | 10.5 s |
1 | 73.33 s | 21.2 s |
2 | 94.13 s | 39.08 s |
3 | 3264 s (54.4 min) | 3240 s (54.0 min) |
I am using Julia 1.7.0-rc1 and Flux 0.12.6 and Zygote 0.6.21 on Linux (Intel i9-10900X, NVIDIA GeForce RTX 3080).
The issue is also present in Julia 1.6.1, although Julia 1.7.0-rc1 is a bit faster.
If I start Julia with the option `-O1`, the compile times become manageable (e.g. 20 seconds for 3 residual blocks), but runtime performance in general will probably suffer.
Any help would be greatly appreciated.
using Flux
using Dates
# Model hyper-parameters. Declared `const` so these globals are type-stable
# when referenced from functions (non-const globals are `Any`-typed in Julia).
const channels = 64      # feature channels in the generator trunk
const in_channels = 3    # RGB input
# Channel progression through the discriminator's convolutional blocks.
const num_channels = [3,64,64,128,128,256,256,512,512]
# discriminator
"""
    block((in_channels, out_channels); stride = 1, use_batch_norm = true)

Build one discriminator block: a 3x3 convolution (optionally followed by a
`BatchNorm`), returned as a vector of layers so that several blocks can be
concatenated and splatted into a `Chain`. The first positional argument is a
`Pair` (e.g. `3 => 64`), destructured into input/output channel counts.
"""
function block((in_channels,out_channels); stride = 1, use_batch_norm = true)
    # The layers are heterogeneous (Conv, BatchNorm, closures), so the
    # container is explicitly `Any[]` rather than an implicitly-typed `[]`.
    layers = Any[]
    push!(layers, Conv((3,3),in_channels => out_channels,pad = 1,stride = stride))
    if use_batch_norm
        push!(layers, BatchNorm(out_channels))
    end
    # Activation left commented out while debugging pullback compile times.
    #push!(layers, x -> leakyrelu.(x,0.2f0))
    return layers
end
# Discriminator: a stack of (Conv [+ BatchNorm]) blocks — alternating stride 1
# and stride 2, no BatchNorm on the very first block — followed by a global
# mean pool and a 1x1-conv classification head. Moved to the GPU.
discriminator = Chain(
    mapreduce(vcat, 1:length(num_channels)-1) do i
        block(num_channels[i] => num_channels[i+1];
              stride = 1 + (i + 1) % 2,
              use_batch_norm = i != 1)
    end...,
    AdaptiveMeanPool((1, 1)),
    Conv((1, 1), num_channels[end] => 1024),
    # Activation left commented out while debugging pullback compile times.
    #x -> leakyrelu.(x,0.2f0),
    Conv((1, 1), 1024 => 1),
) |> gpu
"""
    resblock(channels)

One SRGAN residual block: two 3x3 `Conv` + `BatchNorm` pairs wrapped in an
identity skip connection, i.e. `output = body(x) + x`. Channel count is
preserved.
"""
function resblock(channels)
    body = Chain(
        Conv((3, 3), channels => channels, pad = 1),
        BatchNorm(channels),
        # Activation left commented out while debugging pullback compile times.
        #Prelu(),
        Conv((3, 3), channels => channels, pad = 1),
        BatchNorm(channels),
    )
    return SkipConnection(body, +)
end
"""
    upsample(in_channels, up_scale)

One upsampling stage: a 3x3 convolution expanding the channels by a factor of
`up_scale^2`, followed by a `PixelShuffle` trading those channels for an
`up_scale`-times larger spatial resolution. Returned as a vector of layers so
stages can be splatted into a `Chain`.
"""
function upsample(in_channels, up_scale)
    conv = Conv((3, 3), in_channels => in_channels * up_scale^2, pad = 1)
    shuffle = PixelShuffle(up_scale)
    # Activation left commented out while debugging pullback compile times.
    #Prelu(),
    return [conv, shuffle]
end
# SRGAN generator: a 9x9 input conv, a residual trunk wrapped in one long skip
# connection, two 2x pixel-shuffle upsampling stages (4x upscaling total), and
# a 9x9 sigmoid output conv producing a 3-channel image. Moved to the GPU.
generator = Chain(
Conv((9,9),3 => channels, pad = 4),
#Prelu(),
# Long skip connection around the whole residual trunk.
SkipConnection(Chain(
# test with different number of residual blocks
# (uncommenting more blocks blows up Zygote.pullback compile time — see the
# timing table above)
resblock(channels),
#resblock(channels),
#resblock(channels),
#resblock(channels),
#resblock(channels),
Conv((3,3),channels => channels, pad=1),
BatchNorm(channels)),+),
# Two 2x upsampling stages: 22x22 input -> 88x88 output.
upsample(channels, 2)...,
upsample(channels, 2)...,
# Final conv with sigmoid activation maps features back to 3 image channels.
Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;
# Dummy image batches in WHCN layout: 88x88 high-res and 22x22 low-res,
# 3 channels, batch size 32.
hr_images = randn(Float32,88,88,3,32)
lr_images = randn(Float32,22,22,3,32)
# Move the data to the GPU to match the models.
hr_images = gpu(hr_images)
lr_images = gpu(lr_images)
# Check the forward pass. NOTE(review): hr_images is unused below —
# presumably it would enter the full SRGAN loss; verify in the complete code.
@show sum(discriminator(generator(lr_images)))
params_g = Flux.params(generator)
@info "generator $(Dates.now())"
# Taking gradient of generator: time the first call to Flux.pullback, which
# includes Zygote's compilation of the backward pass (the slow step reported
# in the table above).
loss_g, back = @time Flux.pullback(params_g) do
sum(discriminator(generator(lr_images)))
end