Yes, I did make it type-stable with the help of @code_warntype
. As for the rest, I've no idea.
There are two big factors here that can cause runtime variability: GC and threading. I suspect the former is at play, because your model has a metric ton of Flux layers which all allocate basically all of their outputs and intermediates. While benchmarking each statement separately gives the GC more opportunities to clean up between statements, benchmarking the whole function does not.
Also remember that @btime
reports the minimum execution time across all runs. In this case, that probably means the run sampled for each line with the least GC intervention or background system noise. The chances of everything going perfectly when all statements are combined are thus vanishingly small, especially when you consider that each statement increases the chance of slowdown for subsequent ones by allocating large arrays through the GC.
There are a couple of things you can do to visualize this variability. The first is to run out = m(x);
more than once before resetting the timer. You can do this with a plain loop or a BenchmarkTools macro as you were doing before. This allows collection of more samples for a smoother, more stable average, and I suspect that average will show individual components taking more time than @btime
shows. The second thing is to use @benchmark
instead of @btime
to look at individual lines. This is a bit more code since @benchmark
doesn't return the output of the benchmarked expression, but it'll give you a proper distribution of all runtimes with allocation and GC statistics.
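For example, a minimal sketch of both ideas (assuming m and x are already defined as in your script; this only illustrates the benchmarking pattern):
using BenchmarkTools
# Run the whole model a few times first so any averages are based on more than one sample.
for _ in 1:10
    m(x)
end
# Unlike @btime, @benchmark records the full distribution of runtimes,
# along with allocation counts and the fraction of time spent in GC.
b = @benchmark $m($x)
display(b)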
I really don't understand most of the code, and it's clear you're doing and timing a lot more than the CoarseDecoder
method in your original post. But I had a thought about you saying you used @code_warntype
to make sure methods are type-stable.
Both m
and x
are non-const
global variables, so that could introduce type instabilities that @code_warntype m(x)
does not catch. For a simple example, x = 1; @code_warntype x+x
will report type stability, but x+x
in practice does not know the type of x
and needs to do dynamic dispatch. This is visible in increases of @time
and @btime
, which is why the looped @btime
has $
-interpolation like in the rest of your code. This effect might actually be small compared to the heavy work your code does, but it's worth checking because I've seen it really drag down performance in some cases.
The quick fix is to make and work on m
and x
in a function, or to make them const
; just doing @btime $m($x)
may remove the type instability in the benchmark, but your original code m(x)
would still suffer from type instability. But beware that const
variables will not be properly replaced by each include("slbr.jl")
, and that will especially hurt if they hold large data you want to discard.
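One way to apply the quick fix, reusing the constructor calls from your script (just a sketch; it assumes the definitions in slbr.jl are already loaded):
# Pass the model and input as function arguments; inside a function their
# types are concrete, so the call model(input) no longer needs dynamic dispatch.
run_model(model, input) = model(input)

m = SLBR(; shared_depth=2, blocks=[3 for i in 1:5], long_skip=false, k_center=2) |> gpu
x = rand(Float32, 720, 720, 3, 1) |> gpu

run_model(m, x)        # first call compiles
@time run_model(m, x)  # subsequent calls measure only runtime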
Thank you @Benny! I think this possibility can be put aside for a while, because at the very beginning of this performance experiment @btime required me to use $name when passing in parameters, and I took that to mean it is type-stable. What you said about type stability and multiple dispatch goes deeper than I understand.
@ToucheSir
Thank you for your guidance. I finally have a clear understanding of the current performance bottleneck.
But I still don't know how to deal with the performance issues caused by GC.
There are some other questions.
- If a function passes the type-stability test, won't it no longer incur latency from dynamic dispatch?
- Julia is compiled before execution, so shouldn't GC happen at a fixed point in time?
- GC is controlled by the runtime and we can't do anything about it, so what becomes of the big performance gap, and of the performance advantage Julia is proud of?
- I believe so in most cases, but I don't think that's where your overhead is coming from.
- No, at least not in the way you'd expect something like RAII to work.
- Generally the advice is to allocate less. Unfortunately this is not possible for Flux because of technical limitations, so we're stuck with heavy reliance on the GC (for now). It's no coincidence that https://github.com/JuliaCI/GCBenchmarks/blob/main/benches/flux/flux_multithreaded_training.jl exists.
A bit more on the last point: generally speaking, you should not expect Flux to be significantly faster than e.g. PyTorch for "standard" deep learning models, because most of the hot code runs in highly optimized, third-party routines that both libraries call into. PyTorch's runtime also has a specialized allocator and memory-reclamation strategy that is better suited to the kind of consistent but heavy allocation patterns found in big NN models. Julia's GC is not so specialized and is also (currently) lower throughput, so unfortunately it can become a bottleneck for certain Flux models. In addition to improving the GC, there is pending compiler work for optimizing more array allocations (e.g. 1, 2, 3) so that they're cleaned up earlier and more consistently.
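If you want to see how much of each forward pass is spent in GC on your machine, here is a minimal sketch (using the m and x from your script):
# @timed returns the result together with elapsed time, bytes allocated,
# and the time spent in garbage collection for this one call.
stats = @timed m(x)
println("GC fraction: ", round(100 * stats.gctime / stats.time; digits=1), "%")

# Forcing a collection right before timing can also reduce run-to-run variability.
GC.gc()
@time m(x)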
Thank you, it seems I have to go back to PyTorch.
It seems there is no such thing as a free lunch. The simple elegance of Julia still lacks power here. I hope that Julia will soon be able to lift itself up and use its simplicity and elegance to make the underlying support more efficient.
Help please!
I donโt even know what the problem is.
All I know is that the pytorch
version only takes 60ms, while my rewritten version takes 1200ms.
The source code is above.
julia -t auto slbr.jl
"==================================" = "=================================="
1.278358 seconds (57.91 k allocations: 3.173 MiB, 2.60% gc time)
────────────────────────────────────────────────────────────────────
Time Allocations
─────────────────────── ────────────────────────
Tot / % measured: 2.83s / 0.0% 21.7MiB / 0.0%
Section ncalls time %tot avg alloc %tot avg
────────────────────────────────────────────────────────────────────
────────────────────────────────────────────────────────────────────
It's hard to say anything here without seeing that PyTorch code. I would highly recommend posting it alongside a version of slbr.jl
with all timing-related code removed (it's already way too large for an MWE, so giving people less to read through will increase your chances of getting help).
Thank you!
This slbr.jl
is my rewrite of the class SLBR
from SLBR
, whose source code is open on GitHub, file/networks/resunet.py.
I call the model as follows (Python):
videoCapture = cv2.VideoCapture(fvideo)
success, frame = videoCapture.read()
while success:
    im = preprocess(cv2.resize(frame, (720, 720))).to("cuda").unsqueeze(dim=0)
    t0 = time.time()
    imoutput, immask_all, imwatermark = model(im)
    print("time consume ", time.time() - t0)
    success, frame = videoCapture.read()
And, the output(recently) is:
time consume 0.2411808967590332
time consume 0.08216118812561035
time consume 0.08243012428283691
time consume 0.0825655460357666
time consume 0.08230853080749512
time consume 0.08225369453430176
I didn't think people would be interested in such a tedious job, so I didn't post it before.
Both run on the same machine with an Intel® Core™ i5-9400F CPU @ 2.90GHz × 6
and 15.6 GiB memory and an NVIDIA Corporation GP104 [GeForce GTX 1070].
The 60 ms mentioned before refers to the most time-consuming step (ims, mask, wm = m.coarse_decoder(im, nothing, mask0, unshared_before_pool)
in slbr.jl
), as measured in the PyTorch version.
The full code is as follows:
using Flux: output_size
using Base: _before_colon, concatenate_setindex!, copymutable
using Flux, BSON
using Revise
using BenchmarkTools
using InteractiveUtils
using TimerOutputs
const to = TimerOutput()
xshape(x) = eltype(x) <: AbstractArray ? (length(x), xshape(x[1])...) : [length(x)]
tovec(x) = eltype(x) <: AbstractArray ? tovec([xx_ for x_ in x for xx_ in x_]) : [x...]
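# DownConv: encoder block; 3x3 conv + norm + activation, then several optionally residual convs, with an optional 2x2 max pool. Returns (pooled output, pre-pool features).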
struct DownConv{C1<:Conv,C2<:Conv, BI<:Union{BatchNorm,InstanceNorm},M<:MaxPool,F<:Function}
conv1::C1
conv2::Vector{C2}
act::F
norm1::BI
bn::Vector{BI}
residual::Bool
pooling:: Bool
pool::M
end
function DownConv(in_channels, out_channels, blocks; pooling=true, norm=Flux.BatchNorm, act=Flux.relu, residual=true, dilations=[])
# norm = norm == "bn" ? Flux.BatchNorm : norm == "in" ? Flux.InstanceNorm : error("Unknown type:\t$norm")
conv1 = Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
dilations = length(dilations) == 0 ? [1 for i in 1:blocks] : dilations
conv2 =[Conv((3,3), out_channels=>out_channels; dilation=dilations[i], pad=dilations[i])
for i in 1:blocks ]
bn = fill(norm(out_channels), blocks)
pool = Flux.MaxPool((2, 2), stride=2)
# @show typeof(conv1), eltype(conv2), typeof(norm1), typeof(pool), typeof(act)
DownConv{typeof(conv1), eltype(conv2), typeof(norm1), typeof(pool), typeof(act)}(conv1, conv2, act, norm1, bn, residual, pooling, pool) #
end
function (m::DownConv)(x::AbstractArray{Float32, 4})
x1 = m.act.(m.norm1(m.conv1(x)))
for (idx, conv) in enumerate(m.conv2)
x2 = conv(x1)
x2 = m.bn[idx](x2)
x1 = m.residual ? x1+x2 : x2
x1 = m.act.(x1)
end
before_pool = deepcopy(x1)
x1 = m.pooling ? m.pool(x1) : x1
x1, before_pool
end
Flux.@functor DownConv
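# UpConv: decoder block; 2x bilinear upsample + conv, optional concatenation with skip features (and mask), then several residual convs. OutFuse{true} additionally returns the fused features.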
struct UpConv{C<:Conv,BI<:Union{BatchNorm,InstanceNorm},F<:Function}
up_conv::C
conv1::C
conv2::Vector{C}
bn::Vector{BI}
norm0::BI
norm1::BI
act::F
concat::Bool
use_mask::Bool
residual::Bool
end
function UpConv(in_channels, out_channels, blocks; residual=true, norm=Flux.BatchNorm, act=Flux.relu, concat=true, use_att=false, use_mask=false, dilations=[], out_fuse=false)
up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
norm0 = norm(out_channels)
if length(dilations)==0
dilations = [1 for _ in 1:blocks]
end
if concat
conv1 = Conv((3,3), (2*out_channels+(use_mask ? 1 : 0))=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
else
conv1 = Conv((3,3), out_channels=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
end
conv2 = [ Conv((3,3), out_channels=>out_channels; dilation = dilations[i], pad=dilations[i], bias=true) for i in 1:blocks ]
bn = [norm(out_channels) for _ in 1:blocks]
UpConv{typeof(up_conv), typeof(norm0), typeof(act)}(up_conv, conv1, conv2, bn, norm0, norm1, act, concat, use_mask, residual)
end
struct OutFuse{v}
end
OutFuse(x) = OutFuse{x}()
function (m::UpConv)(::OutFuse{true}, from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing, se=nothing)#::Tuple{CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}
from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, (2.0f0,2.0f0)))))
x1 = m.concat ? (m.use_mask ? cat(from_up, from_down, mask; dims=Val(3)) : cat(from_up, from_down; dims=Val(3))) : (prod(size(from_down))==0 ? from_up + from_down : from_up)
xfuse = x1 = m.act.(m.norm1(m.conv1(x1)))
for (idx, conv) in enumerate(m.conv2)
x2 = m.bn[idx](conv(x1))
if !(se===nothing) && idx == length(m.conv2)
x2 = se(x2)
end
x1 = m.residual ? x1+x2 : x2
x1 = m.act.(x1)
end
x1, xfuse
end
function (m::UpConv)(::OutFuse{false}, from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing, se=nothing)#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, (2.0f0,2.0f0)))))
x1 = m.concat ? (m.use_mask ? cat(from_up, from_down, mask; dims=Val(3)) : cat(from_up, from_down; dims=Val(3))) : (prod(size(from_down))==0 ? from_up + from_down : from_up)
x1 = m.act.(m.norm1(m.conv1(x1)))
for (idx, conv) in enumerate(m.conv2)
x2 = m.bn[idx](conv(x1))
if !(se===nothing) && idx == length(m.conv2)
x2 = se(x2)
end
x1 = m.residual ? x1+x2 : x2
x1 = m.act.(x1)
end
x1
end
Flux.@functor UpConv
struct CFFBlock{C1<:Conv, C2<:Conv, D1<:DownConv, D2<:DownConv, D3<:DownConv, CH1<:Chain, CH2<:Chain}
up32::C1
up31::C2
down1::D1
down2::D2
down3::D3
conv22::CH1
conv33::CH2
end
function CFFBlock(; down=DownConv, up=UpConv, ngf::Int=32)
p = [
Conv((3,3), ngf*4 => ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf*4=>ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
DownConv(ngf, ngf, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
DownConv(ngf, ngf*2, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
DownConv(ngf*2, ngf*4, 3; pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[1,2,5]),
Chain(
Conv((3,3), ngf*2=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
),
Chain(
Conv((3,3), ngf*4=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf*2=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
)
]
CFFBlock{typeof.(p)...}(p...)
end
function (m::CFFBlock)(x1::AbstractArray{Float32, 4}, x2::AbstractArray{Float32, 4}, x3::AbstractArray{Float32, 4})
x32 = Flux.upsample_bilinear(x3; size=(size(x2)[2], size(x2)[1]))
x32 = m.up32(x32)
x31 = Flux.upsample_bilinear(x3; size=(size(x1)[2], size(x1)[1]))
x31 = m.up31(x31)
# cross-connection
x, d1 = m.down1(x1 + x31)
x, d2 = m.down2(x + m.conv22(x2) + x32)
d3, _ = m.down3(x + m.conv33(x3))
d1,d2,d3
end
Flux.@functor CFFBlock
struct ECABlocks{C<:Conv,AMP<:Flux.AdaptiveMeanPool}
conv::C
avg_pool::AMP
end
function ECABlocks(channel, k_size=3)
conv = Conv((k_size,), 1=>1, sigmoid; pad=(k_size-1)÷2, bias=false)
avg_pool = AdaptiveMeanPool((1,1))
ECABlocks{typeof(conv), typeof(avg_pool)}(conv, avg_pool)
end
function (m::ECABlocks)(x::AbstractArray{Float32, 4})
# h, w, c, b = size(x)
y = m.avg_pool(x)
y = Flux.unsqueeze(permutedims(m.conv(permutedims(reshape(y, size(y)[2:end]), (2,1,3))), (2, 1, 3)), dims=1)
# y = Flux.sigmoid(y)  # the sigmoid can instead be folded into the conv above
x .* y
end
Flux.@functor ECABlocks
struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
up_conv::C
bn::Vector{BI}
norm0::BI
norm1::BI
conv1::C
conv2::Vector{CH}
conv3::Vector{C}
act::F
concat::Bool
residual::Bool
end
function MBEBlock(in_channels=512, out_channels=3; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true, is_final=true)
up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
conv1 = Conv((3,3), (concat ? 2*out_channels : out_channels)=>out_channels; pad=1, bias=true)
conv2 = Vector{Chain}()
conv3 = Vector{Conv}()
for i in 1:blocks
push!(conv2, Chain(
Conv((5,5), (out_channels ÷ 2 + 1)=>(out_channels ÷ 4), Flux.relu; stride=1, pad=2, bias=true),
Conv((5,5), (out_channels ÷ 4)=>1, Flux.sigmoid_fast; stride=1, pad=2, bias=true)
))
push!(conv3, Conv((3,3), (out_channels ÷ 2)=>out_channels; pad=1, bias=true))
end
bn = [norm(out_channels) for i in 1:blocks]
MBEBlock{typeof(up_conv), eltype(bn), eltype(conv2), typeof(act)}(up_conv, bn, norm(out_channels), norm(out_channels), conv1, conv2, conv3, act, concat, residual)
end
function (m::MBEBlock)(from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing)
from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, (2.0f0, 2.0f0)))))
if m.concat
x1 = cat(from_up, from_down; dims=Val(3))
elseif !(from_down === nothing)
x1 = from_up + from_down
else
x1 = from_up
end
x1 = m.act.(m.norm1(m.conv1(x1)))
#residual structure
H, W, C, _ = size(x1)
for (idx, (conv1, conv2)) in enumerate(zip(m.conv2, m.conv3))
#@show size(x1)
#@show size(x1[:,:,1:(C ÷ 2), :])
#@show size(mask)
mask = conv1(cat(view(x1, :,:,1:(C ÷ 2), :), mask; dims=Val(3)))
x2_actv = view(x1, :,:,(C ÷ 2+1) : size(x1)[3], :) .* mask
x2 = conv2(view(x1, :,:,(C ÷ 2+1) : size(x1)[3], :) + x2_actv)
x2 = m.bn[idx](x2)
x1 = m.residual ? x2+x1 : x1
x1 = m.act.(x1)
end
x1
end
Flux.@functor MBEBlock
struct SelfAttentionSimple{C<:Conv, CH<:Chain, AA<:AbstractArray{Float32, 4}}
k_center::Int32
q_conv::C
k_conv::C
v_conv::C
sim_func::C
out_conv::CH
min_area::Float32
threshold::Float32
k_weight::AA
end
function SelfAttentionSimple(in_channel, k_center)
conv1 = Conv((1, 1), in_channel=>in_channel)
conv2 = Conv((1, 1), in_channel=>in_channel*k_center)
conv3 = Conv((1, 1), in_channel=>in_channel*k_center)
conv4 = Conv((1,1), (2*in_channel) => 1; stride=1, pad=0, bias=true)
ch = Chain(
Conv((3,3), in_channel=>(in_channel÷8), Flux.relu; stride=1, pad=1),
Conv((3,3), (in_channel÷8)=>1; stride=1, pad=1)
)
a = fill(1.0f0, (1, 1, k_center, 1))
SelfAttentionSimple{typeof(conv1), typeof(ch), typeof(a)}(Int32(k_center),
conv1, conv2, conv3, conv4, ch,
100.0f0, 0.5f0,
a
)
end
function compute_attention(m::SelfAttentionSimple, query::T, key::T, mask::T, eps=1) where{T}#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
h,w,c,b = size(query)
#@show size(query)
# @btime $m.q_conv($query)
query = m.q_conv(query)
key_in = key
# @btime $m.k_conv($key_in)
key = m.k_conv(key_in)
# keys = [view(key, :,:,i:(i+c-1),:) for i in 1:c:size(key)[3]]
keys = Vector{T}([key[:,:,i:(i+c-1),:] for i in 1:c:size(key)[3]])
# @btime eltype($mask).($mask .> $m.threshold)
importance_map = eltype(mask).(mask .> m.threshold)
# @btime sum($importance_map, dims=[1,2])
s_area = sum(importance_map, dims=[1,2])
# mask = s_area .>= m.min_area
# s_area = s_area .* mask + m.min_area .* .!mask
# @btime clamp_lo(one(eltype($s_area))*$m.min_area).($s_area)
clamp!(s_area, m.min_area, 1.0f7) #clamp_lo(one(eltype(s_area))*m.min_area).(s_area)
s_area = s_area[:, :, 1:1, :]#view(s_area, :, :, 1:1, :)
if m.k_center != 2
# @btime [sum(k .*$importance_map, dims=[1, 2]) ./ $s_area for k in $keys]
ks = Vector{T}()
for k in keys
push!(ks, sum(k .*importance_map, dims=(Val(1), Val(2))) ./ s_area)
end
keys = ks
# keys = [sum(keys[1] .*importance_map, dims=[1, 2]) ./ s_area, sum(keys[2] .*importance_map, dims=[1, 2]) ./ s_area]
else
# @btime [sum($keys[1] .* $importance_map, dims=[1,2]) ./$s_area,
# sum($keys[2] .*(one(eltype($importance_map)) .- $importance_map), dims=[1,2]) ./ (size($keys[2])[1] * size($keys[2])[2] .- $s_area .+ $eps),
# ]
keys = Vector{T}([sum(keys[1] .* importance_map, dims=[1,2]) ./s_area,
sum(keys[2] .*(one(eltype(importance_map)) .- importance_map), dims=[1,2]) ./ (size(keys[2])[1] * size(keys[2])[2] .+ eps .- s_area)
])
end
f_query = query
# f_key = [repeat(reshape(k, (1,1,c,b)), size(f_query)[1:2]..., 1, 1) for k in keys]
f_key = [Flux.upsample_nearest(reshape(k, (1, 1, c, b)), size=(size(f_query)[1], size(f_query)[2])) for k in keys]
attention_scores = Vector{T}()
for k in f_key
# @btime Flux.tanh_fast.(cat($f_query, $k, dims=3))
combine_qk = Flux.tanh_fast.(cat(f_query, k, dims=Val(3)))
# combine_qk = typeof(f_query)(undef, size(f_query)[1], size(f_query)[2], size(f_query)[3]*2, size(f_query)[4])
# combine_qk[:,:,1:size(f_query)[3],:] = f_query
# combine_qk[:,:,(size(f_query)[3]+1):end,:] = k
# @btime $m.sim_func($combine_qk)
sk = m.sim_func(combine_qk)
push!(attention_scores, sk)
end
# @btime cat($attention_scores...; dims=3)
# s = cat(attention_scores...; dims=Val(3))
s = reshape(mapreduce(Flux.flatten, vcat, attention_scores), size(attention_scores[1])[1], size(attention_scores[1])[2], sum([size(x)[3] for x in attention_scores]), size(attention_scores[1])[4])
# @btime permutedims($s, (3,1,2,4))
s = permutedims(s, (3,1,2,4))
# @btime $m.v_conv($key_in)
v = m.v_conv(key_in)
if m.k_center == 2
# @btime sum($v[:, :, 1:$c-1, :] .* $importance_map, dims=[1,2]) ./ $s_area
v_fg = sum(view(v, :, :, 1:c-1, :), dims=[1,2]) .* sum(importance_map, dims=[1,2]) ./ s_area
# @btime sum($v[:,:,$c:end, :] .* (1 .- $importance_map);dims=[1,2]) ./ (size($v)[1] * size($v)[2] .- $s_area .+ $eps)
v_bg = sum(view(v, :,:,c:size(v)[3], :), dims=[1,2]) .* sum((1 .- importance_map);dims=[1,2]) ./ (size(v)[1] * size(v)[2] .- s_area .+ eps)
v = cat(v_fg, v_bg; dims=Val(3))
else
# @btime sum($v .* $importance_map, dims=[1,2]) ./ $s_area
v = sum(v, dims=[1,2]) .* sum(importance_map, dims=[1,2]) ./ s_area
end
v = reshape(v, (c, m.k_center, b))
#@show size(s), (m.k_center, h*w, b)
#@show size(v)
# @btime permutedims(reshape(Flux.batched_mul($v, reshape($s, ($m.k_center, $h*$w, $b))), ($c,$h,$w,$b)), (2, 3, 1,4))
attn = permutedims(reshape(Flux.batched_mul(v, reshape(s, (m.k_center, h*w, b))), (c,h,w,b)), (2, 3, 1,4))
#@show size(query)
#@show size(attn)
# @btime $m.out_conv($attn+$query)
m.out_conv(attn + query)
end
function (m::SelfAttentionSimple)(xin::AbstractArray{Float32, 4}, xout::AbstractArray{Float32, 4}, xmask::AbstractArray{Float32, 4})
h, w, c, b_num= size(xin)
attention_score = compute_attention(m, xin, xout, xmask)
attention_score = reshape(attention_score, (h, w, 1, b_num))
return xout, Flux.sigmoid.(attention_score)
end
Flux.@functor SelfAttentionSimple
struct SMRBlock{C<:Conv, U<:UpConv, S<:SelfAttentionSimple}
upconv::U
primary_mask::C
self_calibrated::S
end
function SMRBlock(ins, outs, k_center; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true)
conv1 = UpConv(ins, outs, blocks; residual=residual, concat=concat, norm=norm, act=act)
conv2 = Conv((1,1), outs=>1, Flux.sigmoid_fast; stride=1, pad=0, bias=true)
sa = SelfAttentionSimple(outs, k_center)
SMRBlock{typeof(conv2), typeof(conv1), typeof(sa)}( conv1, conv2, sa)
end
function (m::SMRBlock)(input::AbstractArray{Float32, 4}, encoder_outs=nothing)
# @btime $m.upconv(OutFuse(true), $input, $encoder_outs)
mask_x = m.upconv(OutFuse(false), input, (encoder_outs===nothing ? typeof(input)(undef, zeros(Int64, length(size(input)))) : encoder_outs))
# @btime $m.primary_mask($mask_x)
primary_mask = m.primary_mask(mask_x)
# # @btime $m.self_calibrated($mask_x, $mask_x, $primary_mask)
mask_x, self_calibrated_mask = m.self_calibrated(mask_x, mask_x, primary_mask)
return Dict(
"feats"=>[mask_x],
"attn_maps"=>[primary_mask, self_calibrated_mask]
)
end
Flux.@functor SMRBlock
struct CoarseEncoder{D<:DownConv}
down_convs::Vector{D}
end
function CoarseEncoder(in_channels::Int=3, depth::Int=3; blocks=1, start_filters=32, residual=true, norm=Flux.BatchNorm, act=Flux.relu)
down_convs = []
outs = nothing
if isa(blocks, AbstractArray)
blocks = blocks[1]  # Julia arrays are 1-indexed (blocks[0] was a leftover from the Python original)
end
for i in 1:depth
ins = i==1 ? in_channels : outs
outs = start_filters*(2^(i-1))
pooling = true
# #@show ins, depth
push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act))
end
CoarseEncoder{DownConv}(down_convs)
end
function (m::CoarseEncoder)(x::AbstractArray{Float32, 4})
nx = x
encoder_outs = Vector{typeof(x)}()
for d_conv in m.down_convs
nx, before_pool = d_conv(nx)
push!(encoder_outs, before_pool)
end
nx, encoder_outs
end
Flux.@functor CoarseEncoder
struct SharedBottleNeck{U<:UpConv, D<:DownConv, ECA1<:ECABlocks, ECA2<:ECABlocks}
up_convs::Vector{U}
down_convs::Vector{D}
up_im_atts::Vector{ECA1}
up_mask_atts::Vector{ECA2}
end
function SharedBottleNeck(in_channels=512, depth=5, shared_depth=2; start_filters=32, blocks=1, residual=true, concat=true, norm=Flux.BatchNorm, act=Flux.relu, dilations=[1,2,5])
start_depth = depth - shared_depth
max_filters = 512
down_convs = Vector{DownConv}()
up_convs = Vector{UpConv}()
up_im_atts = Vector{ECABlocks}()
up_mask_atts = Vector{ECABlocks}()
outs = 0
for i in start_depth:depth-1
# #@show i, start_depth
ins = i == start_depth ? in_channels : outs
outs = min(ins*2, max_filters)
# encoder convs
pooling = i<depth-1 ? true : false
push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act, dilations=dilations))
# decoder convs
if i < depth - 1
up_conv = UpConv(min(outs*2, max_filters), outs, blocks, residual=residual, concat=concat, norm=norm, act=Flux.relu, dilations=dilations)
push!(up_convs, up_conv)
push!(up_im_atts, ECABlocks(outs))
push!(up_mask_atts, ECABlocks(outs))
end
end
# @show eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)
SharedBottleNeck{eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)}(up_convs, down_convs, up_im_atts, up_mask_atts)
end
function (m::SharedBottleNeck)(input::AbstractArray{Float32, 4})
# encoder convs
im_encoder_outs = Vector{typeof(input)}()
mask_encoder_outs = Vector{typeof(input)}()
x = input
for (i, d_conv) in enumerate(m.down_convs)
x, before_pool = d_conv(x)
push!(im_encoder_outs, before_pool)
push!(mask_encoder_outs, before_pool)
end
x_im = x
x_mask = x
#@show size(x_mask)
# Decoder convs
x = x_im
for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_im_atts))) in enumerate(zip(m.up_convs, m.up_im_atts))
before_pool = im_encoder_outs === nothing ? typeof(x)(undef, zeros(Int64, length(size(x)))) : im_encoder_outs[end-i]
x = up_conv(OutFuse(false), x, before_pool, se=attn)
end
x_im = x
x = x_mask
for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_mask_atts))) in enumerate(zip(m.up_convs, m.up_mask_atts))
before_pool = mask_encoder_outs === nothing ? typeof(x)(undef, zeros(Int64, length(size(x)))) : mask_encoder_outs[end-i]
x = up_conv(OutFuse(false), x, before_pool, se=attn)
end
x_mask = x
x_im, x_mask
end
Flux.@functor SharedBottleNeck
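# CoarseDecoder: alternates SMRBlock (mask prediction) and MBEBlock (background reconstruction) stages, with ECA attention applied to the skip connections when use_att is set.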
struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
up_convs_bg::Vector{MBE}
up_convs_mask::Vector{SMR}
atts_mask::Vector{ECA}
atts_bg::Vector{ECA}
conv_final_bg::C
use_att::Bool
end
function CoarseDecoder(in_channels=512, out_channels=3, k_center=2; norm=Flux.BatchNorm, act=Flux.relu, depth=5, blocks=1, residual=true, concat=true, use_att=false)
up_convs_bg = Vector{MBEBlock}()
up_convs_mask = Vector{SMRBlock}()
atts_bg = Vector{ECABlocks}()
atts_mask = Vector{ECABlocks}()
outs = in_channels
for i in 1:depth
ins = outs
outs = ins ÷ 2
# background reconstruction branch
up_conv = MBEBlock(ins, outs; blocks=blocks, residual=residual, concat=concat, norm=norm, act=act)
push!(up_convs_bg, up_conv)
if use_att
push!(atts_bg, ECABlocks(outs))
end
#mask prediction branch
up_conv = SMRBlock(ins, outs, k_center; norm=norm, act=act, blocks=blocks, residual=residual, concat=concat)
push!(up_convs_mask, up_conv)
if use_att
push!(atts_mask, ECABlocks(outs))
end
end
conv_final_bg = Conv((1,1), outs=>out_channels, stride=1, pad=0, bias=true)
CoarseDecoder{typeof(conv_final_bg),eltype(up_convs_bg),eltype(up_convs_mask), eltype(atts_mask)}(up_convs_bg, up_convs_mask, atts_mask, atts_bg,
conv_final_bg,
use_att
)
end
function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
bg_x = bg
mask_x = mask
mask_outs = Vector{T}()
bg_outs = Vector{T}()
for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
# @btime before_pool = $encoder_outs[end-($i-1)]
before_pool = encoder_outs[end-(i-1)] #encoder_outs===nothing ? nothing : encoder_outs[end-(i-1)]
if m.use_att
# @btime $m.atts_mask[$i]($before_pool)
mask_before_pool = m.atts_mask[i](before_pool)
# @btime $m.atts_bg[$i]($before_pool)
bg_before_pool = m.atts_bg[i](before_pool)
end
# @btime $up_mask($mask_x,$mask_before_pool)
# @code_warntype up_mask(mask_x, mask_before_pool)
smr_outs = up_mask(mask_x, mask_before_pool)
# @btime mask_x = $smr_outs["feats"][1]
mask_x = smr_outs["feats"][1]
# @btime primary_map, self_calibrated_map = $smr_outs["attn_maps"]
primary_map, self_calibrated_map = smr_outs["attn_maps"]
# @btime push!($mask_outs, $primary_map)
push!(mask_outs, primary_map)
# @btime push!($mask_outs, $self_calibrated_map)
push!(mask_outs, self_calibrated_map)
# @btime $up_bg($bg_x, $bg_before_pool, $self_calibrated_map)
# @show typeof(bg_x), typeof(bg_before_pool), typeof(self_calibrated_map)
bg_x = up_bg(bg_x, bg_before_pool; mask = self_calibrated_map) # there may be a problem here
# @btime push!($bg_outs, $bg_x)
push!(bg_outs, bg_x)
end
if m.conv_final_bg !== nothing
# @btime $m.conv_final_bg($bg_x)
bg_x = m.conv_final_bg(bg_x)
# @btime push!($bg_outs, $bg_x)
push!(bg_outs, bg_x)
end
#@show length(bg_outs)
#@show length(mask_outs)
return bg_outs, mask_outs, nothing
end
Flux.@functor CoarseDecoder
struct Refinement{C<:Conv, CH1<:Chain, CH2<:Chain, CH3<:Chain, CH4<:Chain,D1<:DownConv,D2<:DownConv,D3<:DownConv,CFF<:CFFBlock}
conv_in::CH1
dec_conv2::C
dec_conv3::CH2
dec_conv4::CH3
down1::D1
down2::D2
down3::D3
cff_blocks::Vector{CFF}
out_conv::CH4
n_skips::Int64
end
function Refinement(;in_channels=3, out_channels=3, shared_depth=2, down=DownConv, up=UpConv, ngf=32, n_cff=3, n_skips=3)
conv_in = Chain(
Conv((3,3), in_channels=>ngf; stride=1, pad=1, bias=true),
Flux.InstanceNorm(ngf),
x->Flux.leakyrelu.(x, eltype(x)(0.2))
)
dec2 = Conv((1,1), ngf=>ngf; stride=1, pad=0, bias=true)
dec3 = Chain(
Conv((1,1), (ngf*2)=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
)
dec4 = Chain(
Conv((1,1), (ngf*4)=>(ngf*2), x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
Conv((3,3), (ngf*2)=>(ngf*2), x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
)
down1 = down(ngf, ngf, 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[])
down2 = down(ngf, (ngf*2), 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[])
down3 = down((ngf*2), (ngf*4), 3, pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[1,2,5])
cffs = [CFFBlock(;ngf=ngf) for i in 1:n_cff]
out_conv = Chain(
Conv((3,3), (ngf+ngf*2+ngf*4)=>ngf; stride=1, pad=1, bias=true),
Flux.InstanceNorm(ngf),
x->Flux.leakyrelu.(x, eltype(x)(0.2)),
Conv((1,1), ngf=>out_channels; stride=1, pad=0)
)
Refinement{typeof(dec2), typeof(conv_in), typeof(dec3), typeof(dec4), typeof(out_conv), typeof(down1), typeof(down2), typeof(down3), eltype(cffs)}(
conv_in, dec2, dec3, dec4, down1, down2, down3, cffs, out_conv, n_skips )
end
function (m::Refinement)(input::AbstractArray{Float32, 4}, coarse_bg::AbstractArray{Float32, 4}, mask::AbstractArray{Float32, 4}, encoder_outs, decoder_outs::Vector{T} where T<:AbstractArray{Float32, 4})
xin = cat(coarse_bg, mask, dims=Val(3))
# @btime $m.conv_in($xin)
x = m.conv_in(xin)
# @btime $m.dec_conv2($decoder_outs[1])
m.n_skips < 1 && (x += m.dec_conv2(decoder_outs[1]))
# @btime $m.down1($x)
x,d1 = m.down1(x)
# @btime $m.dec_conv3($decoder_outs[2])
m.n_skips < 2 && (x += m.dec_conv3(decoder_outs[2]))
# @btime $m.down2($x)
x,d2 = m.down2(x)
# @btime $m.dec_conv4($decoder_outs[3])
m.n_skips < 3 && (x += m.dec_conv4(decoder_outs[3]))
# @btime $m.down3($x)
x,d3 = m.down3(x)
for block in m.cff_blocks
# @btime $block($xs)
d1,d2,d3 = block(d1,d2,d3)
end
# @btime [Flux.upsample_bilinear(x_hr; size=(size($coarse_bg)[2], size($coarse_bg)[1])) for x_hr in $xs]
xs = [Flux.upsample_bilinear(x_hr; size=(size(coarse_bg)[2], size(coarse_bg)[1])) for x_hr in (d1,d2,d3)]
# @btime $m.out_conv(cat($xs..., dims=3))
xct = Base.cat_t(eltype(xs[1]), xs...; dims=3)
im = m.out_conv(xct)
end
Flux.@functor Refinement
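# SLBR: full model; coarse encoder -> shared bottleneck -> coarse decoder (background and mask branches) -> refinement stage.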
struct SLBR{CE<:CoarseEncoder, SB<:SharedBottleNeck, CD<:CoarseDecoder, RF<:Refinement}
encoder::CE
shared_decoder::SB
coarse_decoder::CD
refinement::RF
long_skip::Bool
end
function SLBR(; in_channels=3, depth=5, shared_depth=2, blocks=[1 for i in 1:5], out_channels_image=3, out_channels_mask=1, start_filters=32, residual=true, concat=true, long_skip=false, k_refine=3, n_skips=3, k_center=3)
encoder = CoarseEncoder(in_channels, depth-shared_depth; blocks=blocks[1], start_filters=start_filters, residual=residual, norm=Flux.BatchNorm, act=Flux.relu)
shared_decoder = SharedBottleNeck(start_filters*2^(depth-shared_depth-1), depth, shared_depth; blocks=blocks[5], residual=residual, concat=concat, norm=Flux.InstanceNorm)
coarse_decoder = CoarseDecoder(start_filters*2^(depth-shared_depth), out_channels_image, k_center; depth=depth-shared_depth, blocks=blocks[2], residual=residual, concat=concat, norm=Flux.BatchNorm, use_att=true)
refinement = Refinement(; in_channels=4, out_channels=3, shared_depth=1, n_cff=k_refine, n_skips=n_skips)
SLBR{typeof(encoder), typeof(shared_decoder), typeof(coarse_decoder), typeof(refinement)}(encoder, shared_decoder, coarse_decoder, refinement, long_skip)
end
function (m::SLBR)(synthesized::AbstractArray{Float32, 4})
# @btime $m.encoder($synthesized) # (type stability) -> 970.384 μs (2239 allocations: 109.50 KiB)
# @code_warntype m.encoder(synthesized)
image_code, before_pool = m.encoder(synthesized)
unshared_before_pool = before_pool
# @code_warntype m.shared_decoder(image_code) # 2.848 ms (7389 allocations: 467.92 KiB) (type stability) -> 2.858 ms (7309 allocations: 464.67 KiB)
# @btime $m.shared_decoder($image_code)
im, mask0 = m.shared_decoder(image_code)
# @btime $m.coarse_decoder($im, nothing, $mask0, $unshared_before_pool) # 233.433 ms (24298 allocations: 1.17 MiB) (type stability) -> 234.892 ms (23475 allocations: 1.15 MiB)
# @btime $m.coarse_decoder($im, nothing, $mask0, $unshared_before_pool)
ims, mask, wm = m.coarse_decoder(im, nothing, mask0, unshared_before_pool)
im = ims[end]
reconstructed_image = Flux.tanh_fast.(im)
if m.long_skip
reconstructed_image = reconstructed_image + synthesized
reconstructed_image = clamp.(reconstructed_image, zero(eltype(reconstructed_image)), one(eltype(reconstructed_image)))
end
reconstructed_mask = mask[end]
reconstruct_wm = wm
dec_feats = reverse(ims[1:end-1])
#@show eltype(reconstructed_image)
#@show eltype(reconstructed_mask)
#@show eltype(synthesized)
coarser = reconstructed_image .* reconstructed_mask + (one(eltype(reconstructed_mask)) .- reconstructed_mask) .* synthesized
# @btime m.refinement($synthesized, $coarser, $reconstructed_mask, nothing, $dec_feats) # 664.924 ms (24094 allocations: 1.45 MiB)
# @code_warntype m.refinement(synthesized, coarser, reconstructed_mask, nothing, dec_feats)
refine_bg = m.refinement(synthesized, coarser, reconstructed_mask, nothing, dec_feats)
refine_bg = clamp.(Flux.tanh_fast.(refine_bg ) + synthesized, zero(eltype(refine_bg)), one(eltype(refine_bg)))
return [refine_bg, reconstructed_image], mask, [reconstruct_wm]
end
Flux.@functor SLBR
m = SLBR(; shared_depth=2, blocks=[3 for i in 1:5], long_skip=false, k_center=2) |> gpu
@show "=================================="
# BSON.@save "model.bson" m
const x = rand(Float32, 720,720, 3, 1) |> gpu
# @btime m(x)
# @code_warntype m(x)
out = m(x);
reset_timer!(to)
# for i in 1:10
# out = m(x);
# end
m(x)
@time m(x)
# @show length(out)
# @show size(out[1][1])
# @show size(out[2][end])
show(to)
Is the difference (0.08 s vs 1.2 s) understandable, supposing that my implementation is all right?
@ToucheSir
Don't you think it's a "good" problem for Flux
?
I checked my code again and again, but I can't find anything to improve.
Could you run it on your PC and take a look?
──────────────────────────────────────────────────────────────────────
Time Allocations
─────────────────────── ────────────────────────
Tot / % measured: 998ms / 32.7% 19.4MiB / 43.8%
Section ncalls time %tot avg alloc %tot avg
──────────────────────────────────────────────────────────────────────
3.0 1 237ms 72.7% 237ms 4.61MiB 54.1% 4.61MiB
3.4 3 181ms 55.5% 60.2ms 2.41MiB 28.3% 822KiB
3.3.4 9 5.42ms 1.7% 603μs 156KiB 1.8% 17.3KiB
3.3.6 9 2.49ms 0.8% 277μs 73.3KiB 0.8% 8.14KiB
3.3.1 3 2.13ms 0.7% 708μs 69.6KiB 0.8% 23.2KiB
3.3.3 3 1.64ms 0.5% 548μs 65.1KiB 0.7% 21.7KiB
3.3.2 3 337μs 0.1% 112μs 16.4KiB 0.2% 5.47KiB
3.3.5 9 286μs 0.1% 31.8μs 25.2KiB 0.3% 2.80KiB
3.3 3 12.7ms 3.9% 4.22ms 557KiB 6.4% 186KiB
3.3.4 9 3.59ms 1.1% 399μs 156KiB 1.8% 17.3KiB
3.3.6 9 2.57ms 0.8% 285μs 73.3KiB 0.8% 8.14KiB
3.3.1 3 1.70ms 0.5% 568μs 69.6KiB 0.8% 23.2KiB
3.3.5 9 1.17ms 0.4% 130μs 25.2KiB 0.3% 2.80KiB
3.3.3 3 1.11ms 0.3% 371μs 65.1KiB 0.7% 21.7KiB
3.1 3 1.09ms 0.3% 364μs 36.8KiB 0.4% 12.3KiB
3.2 3 831μs 0.3% 277μs 36.8KiB 0.4% 12.3KiB
3.3.2 3 296μs 0.1% 98.7μs 16.4KiB 0.2% 5.47KiB
3.5 1 55.9μs 0.0% 55.9μs 5.38KiB 0.1% 5.38KiB
4.0 1 55.8ms 17.1% 55.8ms 1.39MiB 16.3% 1.39MiB
3.4 3 11.6ms 3.6% 3.86ms 967KiB 11.1% 322KiB
3.3.4 9 1.77ms 0.5% 196μs 156KiB 1.8% 17.3KiB
3.3.6 9 1.46ms 0.4% 163μs 73.3KiB 0.8% 8.14KiB
3.3.1 3 828μs 0.3% 276μs 69.6KiB 0.8% 23.2KiB
3.3.3 3 820μs 0.3% 273μs 65.1KiB 0.7% 21.7KiB
3.3.5 9 210μs 0.1% 23.3μs 25.2KiB 0.3% 2.80KiB
3.3.2 3 184μs 0.1% 61.3μs 16.4KiB 0.2% 5.47KiB
3.3 3 8.93ms 2.7% 2.98ms 557KiB 6.4% 186KiB
2.0 1 4.12ms 1.3% 4.12ms 459KiB 5.3% 459KiB
3.3.4 9 2.18ms 0.7% 243μs 156KiB 1.8% 17.3KiB
1.0 1 2.00ms 0.6% 2.00ms 108KiB 1.2% 108KiB
3.2 3 1.29ms 0.4% 432μs 36.8KiB 0.4% 12.3KiB
3.3.6 9 980μs 0.3% 109μs 73.3KiB 0.8% 8.14KiB
3.3.1 3 777μs 0.2% 259μs 69.6KiB 0.8% 23.2KiB
3.3.3 3 687μs 0.2% 229μs 65.1KiB 0.7% 21.7KiB
3.1 3 336μs 0.1% 112μs 36.8KiB 0.4% 12.3KiB
3.3.5 9 172μs 0.1% 19.1μs 25.2KiB 0.3% 2.80KiB
3.3.2 3 127μs 0.0% 42.2μs 16.4KiB 0.2% 5.47KiB
3.5 1 53.6μs 0.0% 53.6μs 5.38KiB 0.1% 5.38KiB
──────────────────────────────────────────────────────────────────────