I am trying to implement an SRGAN (super-resolution generative adversarial network) in Flux (I am happy to share once it works… :-))
The generator network is composed of 5 residual blocks (function `resblock` in the code below).
Unfortunately, the time of the first call to `Zygote.pullback` increases very quickly with the number of residual blocks, so I am not able to use all 5 residual blocks as intended. I have already commented out all element-wise broadcasting (the activation functions), but the issue is still present.
| number of res. blocks | pullback time (1st call) | pullback time (2nd call) |
|---|---|---|
| 0 | 58.2 s | 10.5 s |
| 1 | 73.33 s | 21.2 s |
| 2 | 94.13 s | 39.08 s |
| 3 | 3264 s (54.4 min) | 3240 s (54.0 min) |
I am using Julia 1.7.0-rc1 with Flux 0.12.6 and Zygote 0.6.21 on Linux (Intel i9-10900X, NVIDIA GeForce RTX 3080).
The issue is also present in Julia 1.6.1, although Julia 1.7.0-rc1 is a bit faster.
If I start Julia with the option `-O1`, the compile times become manageable (e.g. 20 seconds for 3 residual blocks), but runtime performance in general will probably suffer.
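A possible middle ground I have been considering (untested for this case, and I am not sure how much it affects Zygote's generated adjoint code): `Base.Experimental.@optlevel` (available since Julia 1.5) lowers the optimization level only for code defined inside one module, so the model definitions could live in an `-O1` module while the rest of the session keeps the default `-O2`. The module and function names below are hypothetical placeholders:

```julia
# Sketch: per-module optimization level instead of a global -O1.
# Assumption: Base.Experimental.@optlevel applies to code *defined* in
# this module; whether that extends to Zygote-generated pullbacks is unclear.
module LowOptModel

Base.Experimental.@optlevel 1  # compile this module's code at -O1

# Hypothetical stand-in for the model-building code; the real generator /
# discriminator definitions would go here.
slow_to_optimize(x) = sum(abs2, x)

end # module

@show LowOptModel.slow_to_optimize([1.0, 2.0])
```

The rest of the program, defined outside `LowOptModel`, is still compiled at the default optimization level.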
Any help would be greatly appreciated.
```julia
using Flux
using Dates

channels = 64
in_channels = 3
num_channels = [3,64,64,128,128,256,256,512,512]

# discriminator
function block((in_channels, out_channels); stride = 1, use_batch_norm = true)
    layers = []
    push!(layers, Conv((3,3), in_channels => out_channels, pad = 1, stride = stride))
    if use_batch_norm
        push!(layers, BatchNorm(out_channels))
    end
    #push!(layers, x -> leakyrelu.(x, 0.2f0))
    return layers
end

discriminator = Chain(
    reduce(vcat, [block(num_channels[i] => num_channels[i+1];
                        stride = 1 + (i+1) % 2,
                        use_batch_norm = i != 1) for i = 1:length(num_channels)-1])...,
    AdaptiveMeanPool((1,1)),
    Conv((1,1), num_channels[end] => 1024),
    #x -> leakyrelu.(x, 0.2f0),
    Conv((1,1), 1024 => 1),
) |> gpu

function resblock(channels)
    return SkipConnection(Chain(
        Conv((3,3), channels => channels, pad = 1),
        BatchNorm(channels),
        #Prelu(),
        Conv((3,3), channels => channels, pad = 1),
        BatchNorm(channels),
    ), +)
end

function upsample(in_channels, up_scale)
    return [
        Conv((3,3), in_channels => in_channels * up_scale^2, pad = 1),
        PixelShuffle(up_scale),
        #Prelu(),
    ]
end

generator = Chain(
    Conv((9,9), 3 => channels, pad = 4),
    #Prelu(),
    SkipConnection(Chain(
        # test with different number of residual blocks
        resblock(channels),
        #resblock(channels),
        #resblock(channels),
        #resblock(channels),
        #resblock(channels),
        Conv((3,3), channels => channels, pad = 1),
        BatchNorm(channels)), +),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9), channels => 3, σ, pad = 4),
) |> gpu;

hr_images = randn(Float32, 88, 88, 3, 32)
lr_images = randn(Float32, 22, 22, 3, 32)

hr_images = gpu(hr_images)
lr_images = gpu(lr_images)

# check forward pass
@show sum(discriminator(generator(lr_images)))

params_g = Flux.params(generator)

@info "generator $(Dates.now())"

# taking gradient of the generator
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
```