The whole code can be executed as follows:
using Flux, BSON
using CUDA # assumed addition: CUDA.jl must be loaded for `gpu` below to actually move the model and data to the device
using Revise
using BenchmarkTools
using InteractiveUtils
using TimerOutputs
const to = TimerOutput()
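# Profiling setup: every @timeit label below accumulates into the global
# TimerOutput `to`; reset_timer!(to) clears it and show(to) prints the
# aggregated profile at the end of the script.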
xshape(x) = eltype(x) <: AbstractArray ? (length(x), xshape(x[1])...) : (length(x),) # return a tuple in both branches
tovec(x) = eltype(x) <: AbstractArray ? tovec([xx_ for x_ in x for xx_ in x_]) : [x...]
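# (xshape and tovec are small introspection helpers; the model below does not use them.)
# DownConv: a conv+norm+activation stem followed by `blocks` optionally-dilated
# convs with residual connections. The forward pass returns both the (optionally
# max-pooled) output and the pre-pool feature map, which serves as the U-Net
# skip connection.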
struct DownConv{C1<:Conv,C2<:Conv, BI<:Union{BatchNorm,InstanceNorm},M<:MaxPool,F<:Function}
conv1::C1
conv2::Vector{C2}
act::F
norm1::BI
bn::Vector{BI}
residual::Bool
pooling::Bool
pool::M
end
function DownConv(in_channels, out_channels, blocks; pooling=true, norm=Flux.BatchNorm, act=Flux.relu, residual=true, dilations=[])
# norm = norm == "bn" ? Flux.BatchNorm : norm == "in" ? Flux.InstanceNorm : error("Unknown type:\t$norm")
conv1 = Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
dilations = isempty(dilations) ? [1 for _ in 1:blocks] : dilations
conv2 = [Conv((3,3), out_channels=>out_channels; dilation=dilations[i], pad=dilations[i]) for i in 1:blocks]
bn = [norm(out_channels) for _ in 1:blocks] # not fill(): fill would alias a single norm layer (shared parameters) across all blocks
pool = Flux.MaxPool((2, 2), stride=2)
DownConv{typeof(conv1), eltype(conv2), typeof(norm1), typeof(pool), typeof(act)}(conv1, conv2, act, norm1, bn, residual, pooling, pool)
end
function (m::DownConv)(x)
x1 = m.act.(m.norm1(m.conv1(x)))
for (idx, conv) in enumerate(m.conv2)
x2 = conv(x1)
x2 = m.bn[idx](x2)
x1 = m.residual ? x1+x2 : x2
x1 = m.act.(x1)
end
before_pool = x1 # x1 is only rebound below, never mutated, so an expensive deepcopy is unnecessary
x1 = m.pooling ? m.pool(x1) : x1
x1, before_pool
end
Flux.@functor DownConv
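# UpConv: bilinear 2x upsampling + conv + norm + act, fused with the matching
# encoder feature map (channel concatenation, optionally with an extra mask
# channel, or element-wise addition), followed by `blocks` residual convs; the
# optional `se` argument applies channel attention inside the last block.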
struct UpConv{C<:Conv,BI<:Union{BatchNorm,InstanceNorm},F<:Function}
up_conv::C
conv1::C
conv2::Vector{C}
bn::Vector{BI}
norm0::BI
norm1::BI
act::F
concat::Bool
use_mask::Bool
residual::Bool
end
function UpConv(in_channels, out_channels, blocks; residual=true, norm=Flux.BatchNorm, act=Flux.relu, concat=true, use_att=false, use_mask=false, dilations=[], out_fuse=false)
up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
norm0 = norm(out_channels)
if length(dilations)==0
dilations = [1 for _ in 1:blocks]
end
if concat
conv1 = Conv((3,3), (2*out_channels+(use_mask ? 1 : 0))=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
else
conv1 = Conv((3,3), out_channels=>out_channels; pad=1, bias=true)
norm1 = norm(out_channels)
end
conv2 = [ Conv((3,3), out_channels=>out_channels; dilation = dilations[i], pad=dilations[i], bias=true) for i in 1:blocks ]
bn = [norm(out_channels) for _ in 1:blocks]
UpConv{typeof(up_conv), typeof(norm0), typeof(act)}(up_conv, conv1, conv2, bn, norm0, norm1, act, concat, use_mask, residual)
end
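# OutFuse is a zero-field value type used purely for dispatch: OutFuse(true) isa
# OutFuse{true}, so the two UpConv methods below choose at compile time whether
# the intermediate fused feature map is returned along with the output, e.g.
#   x, xfuse = m(OutFuse(true),  from_up, from_down)  # mask branch (SMRBlock)
#   x        = m(OutFuse(false), from_up, from_down)  # bottleneck decoder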
struct OutFuse{v}
end
OutFuse(x) = OutFuse{x}()
function (m::UpConv)(::OutFuse{true}, from_up, from_down; mask=nothing, se=nothing)#::Tuple{CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}
from_up = @timeit to "4.1.1" m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))
x1 = @timeit to "4.1.2" m.concat ? (m.use_mask ? cat(from_up, from_down, mask; dims=Val(3)) : cat(from_up, from_down; dims=Val(3))) : (!(from_down===nothing) ? from_up + from_down : from_up)
xfuse = x1 = @timeit to "4.1.3" m.act.(m.norm1(m.conv1(x1)))
for (idx, conv) in enumerate(m.conv2)
x2 = @timeit to "4.1.4" conv(x1)
x2 = @timeit to "4.1.5" m.bn[idx](x2)
if !(se===nothing) && idx == length(m.conv2)
x2 = se(x2)
end
if m.residual
x2 = x2 + x1
end
x2 = @timeit to "4.1.6" m.act.(x2)
x1 = x2
end
x1, xfuse
end
function (m::UpConv)(::OutFuse{false}, from_up, from_down; mask=nothing, se=nothing)#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))
x1 = m.concat ?
(m.use_mask ? Base.cat(from_up, from_down, mask; dims=Val(3)) : Base.cat(from_up, from_down; dims=Val(3))) :
(!(from_down===nothing) ? from_up + from_down : from_up)
x1 = m.act.(m.norm1(m.conv1(x1)))
for (idx, conv) in enumerate(m.conv2)
x2 = conv(x1)
x2 = m.bn[idx](x2)
if !(se===nothing) && idx == length(m.conv2)
x2 = se(x2)
end
if m.residual
x2 = x2 + x1
end
x2 = m.act.(x2)
x1 = x2
end
x1
end
Flux.@functor UpConv
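# CFFBlock: cross-feature fusion across three scales. The coarsest stream x3 is
# upsampled and projected onto the two finer resolutions, and the three streams
# are then re-encoded by a small DownConv pyramid with cross-connections.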
struct CFFBlock{C1<:Conv, C2<:Conv, D1<:DownConv, D2<:DownConv, D3<:DownConv, CH1<:Chain, CH2<:Chain}
up32::C1
up31::C2
down1::D1
down2::D2
down3::D3
conv22::CH1
conv33::CH2
end
function CFFBlock(; down=DownConv, up=UpConv, ngf::Int=32)
p = [
Conv((3,3), ngf*4 => ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf*4=>ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
DownConv(ngf, ngf, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
DownConv(ngf, ngf*2, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
DownConv(ngf*2, ngf*4, 3; pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[1,2,5]),
Chain(
Conv((3,3), ngf*2=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
),
Chain(
Conv((3,3), ngf*4=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
Conv((3,3), ngf*2=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
)
]
CFFBlock(p...)
end
function (m::CFFBlock)(inputs)
x1, x2, x3 = inputs
# @show size(x1)
# @show size(x2)
# @show size(x3)
#@show (size(x2)[2], size(x2)[1])
#@show (size(x1)[2], size(x1)[1])
x32 = Flux.upsample_bilinear(x3; size=size(x2)[1:2]) # Flux arrays are WHCN: pass dims 1:2 as-is (swapping them only works for square inputs)
x32 = m.up32(x32)
x31 = Flux.upsample_bilinear(x3; size=size(x1)[1:2])
x31 = m.up31(x31)
# cross-connection
x, d1 = m.down1(x1 + x31)
x, d2 = m.down2(x + m.conv22(x2) + x32)
#@show size(x)
#@show size(x3)
#@show size(m.conv33(x3))
d3, _ = m.down3(x + m.conv33(x3))
d1,d2,d3
end
Flux.@functor CFFBlock
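# ECABlocks: Efficient Channel Attention. Adaptive average pooling reduces each
# channel to a scalar, a 1-D sigmoid conv slides across the channel axis to
# produce per-channel gates, and the input is rescaled by those gates.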
struct ECABlocks{C<:Conv,AMP<:Flux.AdaptiveMeanPool}
conv::C
avg_pool::AMP
end
ECABlocks(channel, k_size=3) = ECABlocks(
Conv((k_size,), 1=>1, sigmoid; pad=(k_size-1)÷2, bias=false),
AdaptiveMeanPool((1,1))
)
function (m::ECABlocks)(x)
# h, w, c, b = size(x)
y = m.avg_pool(x)
y = Flux.unsqueeze(permutedims(m.conv(permutedims(reshape(y, size(y)[2:end]), (2,1,3))), (2, 1, 3)), dims=1)
# y = Flux.sigmoid(y) # this sigmoid can be folded into the conv above instead (as done here)
x .* y
end
Flux.@functor ECABlocks
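# MBEBlock: mask-guided background enhancement decoder block. In each iteration
# the first half of the channels plus the current mask drive a small gating
# network (the Chains in conv2) that predicts a new spatial mask, which
# modulates the second half of the channels before conv3 maps them back to the
# full width.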
struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
up_conv::C
bn::Vector{BI}
norm0::BI
norm1::BI
conv1::C
conv2::Vector{CH}
conv3::Vector{C}
act::F
concat::Bool
residual::Bool
end
function MBEBlock(in_channels=512, out_channels=3; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true, is_final=true)
up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
conv1 = Conv((3,3), (concat ? 2*out_channels : out_channels)=>out_channels; pad=1, bias=true)
conv2 = Vector{Chain}()
conv3 = Vector{Conv}()
for i in 1:blocks
push!(conv2, Chain(
Conv((5,5), (out_channels ÷ 2 + 1)=>(out_channels ÷ 4), Flux.relu; stride=1, pad=2, bias=true),
Conv((5,5), (out_channels ÷ 4)=>1, Flux.sigmoid_fast; stride=1, pad=2, bias=true)
))
push!(conv3, Conv((3,3), (out_channels ÷ 2)=>out_channels; pad=1, bias=true))
end
bn = [norm(out_channels) for i in 1:blocks]
MBEBlock(up_conv, bn, norm(out_channels), norm(out_channels), conv1, conv2, conv3, act, concat, residual)
end
function (m::MBEBlock)(from_up, from_down, mask=nothing)
from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))
if m.concat
x1 = cat(from_up, from_down; dims=Val(3))
elseif !(from_down === nothing)
x1 = from_up + from_down
else
x1 = from_up
end
x1 = m.act.(m.norm1(m.conv1(x1)))
# residual structure
H, W, C, _ = size(x1)
for (idx, (conv1, conv2)) in enumerate(zip(m.conv2, m.conv3))
#@show size(x1)
#@show size(x1[:,:,1:(C ÷ 2), :])
#@show size(mask)
mask = conv1(cat(x1[:,:,1:(C ÷ 2), :], mask; dims=Val(3)))
x2_actv = x1[:,:,(C ÷ 2 + 1):end, :] .* mask
x2 = conv2(x1[:,:,(C ÷ 2 + 1):end, :] + x2_actv)
x2 = m.bn[idx](x2)
if m.residual
x2 = x2 + x1 # was `x = x2 + x1`, which silently discarded the residual connection
end
x2 = m.act.(x2)
x1 = x2
end
x1
end
Flux.@functor MBEBlock
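# SelfAttentionSimple: mask-guided attention with k_center key/value prototypes.
# Keys and values are average-pooled over the thresholded mask region (and, for
# k_center == 2, over the background as well); each pixel's query is scored
# against the prototypes by sim_func, the scores recombine the pooled values
# into a per-pixel refinement, and out_conv turns that into an attention map.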
struct SelfAttentionSimple{C<:Conv, CH<:Chain, AA<:AbstractArray{Float32, 4}}
k_center::Int32
q_conv::C
k_conv::C
v_conv::C
sim_func::C
out_conv::CH
min_area::Float32
threshold::Float32
k_weight::AA
end
SelfAttentionSimple(in_channel, k_center) = SelfAttentionSimple(Int32(k_center),
Conv((1, 1), in_channel=>in_channel),
Conv((1, 1), in_channel=>in_channel*k_center),
Conv((1, 1), in_channel=>in_channel*k_center),
Conv((1,1), (2*in_channel) => 1; stride=1, pad=0, bias=true),
Chain(
Conv((3,3), in_channel=>(in_channel÷8), Flux.relu; stride=1, pad=1),
Conv((3,3), (in_channel÷8)=>1; stride=1, pad=1)
),
100.0f0, 0.5f0,
fill(1.0f0, (1, 1, k_center, 1))
)
function compute_attention(m::SelfAttentionSimple, query::T, key::T, mask::T, eps=1) where{T}#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
h,w,c,b = size(query)
#@show size(query)
# @btime $m.q_conv($query)
query = @timeit to "cmp_att_1" m.q_conv(query)
key_in = key
# @btime $m.k_conv($key_in)
key = @timeit to "cmp_att_2" m.k_conv(key_in)
# keys = [view(key, :,:,i:(i+c-1),:) for i in 1:c:size(key)[3]]
keys = @timeit to "cmp_att_3" Vector{T}([key[:,:,i:(i+c-1),:] for i in 1:c:size(key)[3]])
# @btime eltype($mask).($mask .> $m.threshold)
importance_map = @timeit to "cmp_att_4" eltype(mask).(mask .> m.threshold)
# @btime sum($importance_map, dims=[1,2])
s_area = @timeit to "cmp_att_5" sum(importance_map, dims=[1,2])
# mask = s_area .>= m.min_area
# s_area = s_area .* mask + m.min_area .* .!mask
# @btime clamp_lo(one(eltype($s_area))*$m.min_area).($s_area)
@timeit to "cmp_att_6" clamp!(s_area, m.min_area, 1.0f7) #clamp_lo(one(eltype(s_area))*m.min_area).(s_area)
s_area = @timeit to "cmp_att_7" s_area[:, :, 1:1, :]#view(s_area, :, :, 1:1, :)
@timeit to "cmp_att_8" if m.k_center != 2
# @btime [sum(k .*$importance_map, dims=[1, 2]) ./ $s_area for k in $keys]
ks = Vector{T}()
for k in keys
push!(ks, sum(k .* importance_map, dims=(1, 2)) ./ s_area) # dims must be plain integers; Val(1)/Val(2) are not accepted by sum
end
keys = ks
# keys = [sum(keys[1] .*importance_map, dims=[1, 2]) ./ s_area, sum(keys[2] .*importance_map, dims=[1, 2]) ./ s_area]
else
# @btime [sum($keys[1] .* $importance_map, dims=[1,2]) ./$s_area,
# sum($keys[2] .*(one(eltype($importance_map)) .- $importance_map), dims=[1,2]) ./ (size($keys[2])[1] * size($keys[2])[2] .- $s_area .+ $eps),
# ]
keys = Vector{T}([sum(keys[1] .* importance_map, dims=[1,2]) ./s_area,
sum(keys[2] .*(one(eltype(importance_map)) .- importance_map), dims=[1,2]) ./ (size(keys[2])[1] * size(keys[2])[2] .+ eps .- s_area)
])
end
f_query = query
f_key = @timeit to "cmp_att_9" [reshape(k, (1, 1, c, b)) .* fill!(typeof(f_query)(undef, size(f_query)), 1) for k in keys]
# f_key = [Flux.upsample_nearest(reshape(k, (1, 1, c, b)), size=(size(f_query)[1], size(f_query)[2])) for k in keys]
# @btime [reshape(k, (1, 1, $c, $b)) for k in $keys]
# f_key = [reshape(k, (1, 1, c, b)) for k in keys]
# @btime [cat([k for i in 1:size($f_query)[1]]..., dims=1) for k in $f_key]
# f_key = [cat([k for i in 1:size(f_query)[1]]..., dims=1) for k in f_key]
# @btime [cat([k for i in 1:size($f_query)[2]]..., dims=2) for k in $f_key]
# f_key = [cat([k for i in 1:size(f_query)[2]]..., dims=2) for k in f_key]
attention_scores = Vector{T}()
for k in f_key
# @btime Flux.tanh_fast.(cat($f_query, $k, dims=3))
combine_qk = @timeit to "cmp_att_10" Flux.tanh_fast.(cat(f_query, k, dims=Val(3)))
# @btime $m.sim_func($combine_qk)
sk = @timeit to "cmp_att_11" m.sim_func(combine_qk)
@timeit to "cmp_att_12" push!(attention_scores, sk)
end
# @btime cat($attention_scores...; dims=3)
# s = cat(attention_scores...; dims=Val(3))
s = @timeit to "cmp_att_13" reshape(mapreduce(Flux.flatten, vcat, attention_scores), size(attention_scores[1])[1], size(attention_scores[1])[2], sum([size(x)[3] for x in attention_scores]), size(attention_scores[1])[4])
# @btime permutedims($s, (3,1,2,4))
s = @timeit to "cmp_att_14" permutedims(s, (3,1,2,4))
# @btime $m.v_conv($key_in)
v = @timeit to "cmp_att_15" m.v_conv(key_in)
if m.k_center == 2
# @btime sum($v[:, :, 1:$c, :] .* $importance_map, dims=[1,2]) ./ $s_area
v_fg = @timeit to "cmp_att_16" sum(v[:, :, 1:c, :] .* importance_map, dims=[1,2]) ./ s_area # split v into its two c-channel halves; 1:c-1 / c:end was off by one
# @btime sum($v[:,:,($c+1):end, :] .* (1 .- $importance_map); dims=[1,2]) ./ (size($v)[1] * size($v)[2] .- $s_area .+ $eps)
v_bg = @timeit to "cmp_att_17" sum(v[:,:,(c+1):end, :] .* (1 .- importance_map); dims=[1,2]) ./ (size(v)[1] * size(v)[2] .- s_area .+ eps)
v = @timeit to "cmp_att_18" cat(v_fg, v_bg; dims=Val(3))
else
# @btime sum($v .* $importance_map, dims=[1,2]) ./ $s_area
v = @timeit to "cmp_att_19" sum(v .* importance_map, dims=[1,2]) ./ s_area
end
v = @timeit to "cmp_att_20" reshape(v, (c, m.k_center, b))
#@show size(s), (m.k_center, h*w, b)
#@show size(v)
# @btime permutedims(reshape(Flux.batched_mul($v, reshape($s, ($m.k_center, $h*$w, $b))), ($c,$h,$w,$b)), (2, 3, 1,4))
attn = @timeit to "cmp_att_21" permutedims(reshape(Flux.batched_mul(v, reshape(s, (m.k_center, h*w, b))), (c,h,w,b)), (2, 3, 1,4))
#@show size(query)
#@show size(attn)
# @btime $m.out_conv($attn+$query)
@timeit to "cmp_att_22" m.out_conv(attn + query)
end
function (m::SelfAttentionSimple)(xin, xout, xmask)
h, w, c, b_num= size(xin)
attention_score = compute_attention(m, xin, xout, xmask)
attention_score = reshape(attention_score, (h, w, 1, b_num))
return xout, Flux.sigmoid.(attention_score)
end
Flux.@functor SelfAttentionSimple
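# SMRBlock: self-calibrated mask refinement. An UpConv produces mask features,
# a 1x1 sigmoid conv predicts the primary mask, and SelfAttentionSimple refines
# it into a self-calibrated mask; both maps are returned for supervision.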
struct SMRBlock{C<:Conv, U<:UpConv, S<:SelfAttentionSimple}
upconv::U
primary_mask::C
self_calibrated::S
end
SMRBlock(ins, outs, k_center; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true) = SMRBlock(
UpConv(ins, outs, blocks; residual=residual, concat=concat, norm=norm, act=act),
Conv((1,1), outs=>1, Flux.sigmoid_fast; stride=1, pad=0, bias=true),
SelfAttentionSimple(outs, k_center)
)
function (m::SMRBlock)(input, encoder_outs=nothing)
# @btime $m.upconv(OutFuse(true), $input, $encoder_outs)
mask_x, fuse_x = @timeit to "4.1" m.upconv(OutFuse(true), input, encoder_outs)
# @btime $m.primary_mask($mask_x)
primary_mask = @timeit to "4.2" m.primary_mask(mask_x)
# # @btime $m.self_calibrated($mask_x, $mask_x, $primary_mask)
mask_x, self_calibrated_mask = @timeit to "4.3" m.self_calibrated(mask_x, mask_x, primary_mask)
return Dict(
"feats"=>[mask_x],
"attn_maps"=>[primary_mask, self_calibrated_mask]
)
end
Flux.@functor SMRBlock
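# CoarseEncoder: a plain stack of DownConv stages; returns the bottleneck
# features together with the per-stage pre-pool maps used as skip connections.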
struct CoarseEncoder{D<:DownConv}
down_convs::Vector{D}
end
function CoarseEncoder(in_channels::Int=3, depth::Int=3; blocks=1, start_filters=32, residual=true, norm=Flux.BatchNorm, act=Flux.relu)
down_convs = []
outs = nothing
if isa(blocks, Tuple)
blocks = blocks[1] # Julia is 1-indexed; blocks[0] would throw a BoundsError
end
for i in 1:depth
ins = i==1 ? in_channels : outs
outs = start_filters*(2^(i-1))
pooling = true
# #@show ins, depth
push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act))
end
CoarseEncoder(identity.(down_convs)) # identity.() narrows the Vector{Any} to a concrete eltype, keeping the struct type-stable
end
function (m::CoarseEncoder)(x)
nx = x
encoder_outs = Vector{typeof(x)}()
for d_conv in m.down_convs
nx, before_pool = d_conv(nx)
push!(encoder_outs, before_pool)
end
nx, encoder_outs
end
Flux.@functor CoarseEncoder
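# SharedBottleNeck: the deepest encoder/decoder stages, shared between the image
# and mask branches; the decoder UpConvs are reused by both branches while each
# branch has its own ECA attention modules.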
struct SharedBottleNeck{U<:UpConv, D<:DownConv, ECA1<:ECABlocks, ECA2<:ECABlocks}
up_convs::Vector{U}
down_convs::Vector{D}
up_im_atts::Vector{ECA1}
up_mask_atts::Vector{ECA2}
end
function SharedBottleNeck(in_channels=512, depth=5, shared_depth=2; start_filters=32, blocks=1, residual=true, concat=true, norm=Flux.BatchNorm, act=Flux.relu, dilations=[1,2,5])
start_depth = depth - shared_depth
max_filters = 512
down_convs = Vector{DownConv}()
up_convs = Vector{UpConv}()
up_im_atts = Vector{ECABlocks}()
up_mask_atts = Vector{ECABlocks}()
outs = 0
for i in start_depth:depth-1
# #@show i, start_depth
ins = i == start_depth ? in_channels : outs
outs = min(ins*2, max_filters)
# encoder convs
pooling = i < depth-1
push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act, dilations=dilations))
# decoder convs
if i < depth - 1
up_conv = UpConv(min(outs*2, max_filters), outs, blocks, residual=residual, concat=concat, norm=norm, act=Flux.relu, dilations=dilations)
push!(up_convs, up_conv)
push!(up_im_atts, ECABlocks(outs))
push!(up_mask_atts, ECABlocks(outs))
end
end
@show eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)
SharedBottleNeck{eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)}(up_convs, down_convs, up_im_atts, up_mask_atts)
end
function (m::SharedBottleNeck)(input)
# encoder convs
im_encoder_outs = Vector{typeof(input)}()
mask_encoder_outs = Vector{typeof(input)}()
x = input
for (i, d_conv) in enumerate(m.down_convs)
x, before_pool = d_conv(x)
push!(im_encoder_outs, before_pool)
push!(mask_encoder_outs, before_pool)
end
x_im = x
x_mask = x
#@show size(x_mask)
# Decoder convs
x = x_im
for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_im_atts))) in enumerate(zip(m.up_convs, m.up_im_atts))
before_pool = im_encoder_outs === nothing ? nothing : im_encoder_outs[end-i]
x = up_conv(OutFuse(false), x, before_pool, se=attn)
end
x_im = x
x = x_mask
for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_mask_atts))) in enumerate(zip(m.up_convs, m.up_mask_atts))
before_pool = mask_encoder_outs === nothing ? nothing : mask_encoder_outs[end-i]
x = up_conv(OutFuse(false), x, before_pool, se=attn)
end
x_mask = x
x_im, x_mask
end
Flux.@functor SharedBottleNeck
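# CoarseDecoder: two parallel decoders over the shared skip connections;
# MBEBlocks reconstruct the background while SMRBlocks predict and refine the
# watermark masks, with optional ECA attention on the skips.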
struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
up_convs_bg::Vector{MBE}
up_convs_mask::Vector{SMR}
atts_mask::Vector{ECA}
atts_bg::Vector{ECA}
conv_final_bg::C
use_att::Bool
end
function CoarseDecoder(in_channels=512, out_channels=3, k_center=2; norm=Flux.BatchNorm, act=Flux.relu, depth=5, blocks=1, residual=true, concat=true, use_att=false)
up_convs_bg = Vector{MBEBlock}()
up_convs_mask = Vector{SMRBlock}()
atts_bg = Vector{ECABlocks}()
atts_mask = Vector{ECABlocks}()
outs = in_channels
for i in 1:depth
ins = outs
outs = ins ÷ 2
# background reconstruction branch
up_conv = MBEBlock(ins, outs; blocks=blocks, residual=residual, concat=concat, norm=norm, act=act)
push!(up_convs_bg, up_conv)
if use_att
push!(atts_bg, ECABlocks(outs))
end
#mask prediction branch
up_conv = SMRBlock(ins, outs, k_center; norm=norm, act=act, blocks=blocks, residual=residual, concat=concat)
push!(up_convs_mask, up_conv)
if use_att
push!(atts_mask, ECABlocks(outs))
end
end
conv_final_bg = Conv((1,1), outs=>out_channels, stride=1, pad=0, bias=true)
CoarseDecoder{typeof(conv_final_bg),eltype(up_convs_bg),eltype(up_convs_mask), eltype(atts_mask)}(up_convs_bg, up_convs_mask, atts_mask, atts_bg,
conv_final_bg,
use_att
)
end
function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
bg_x = bg
mask_x = mask
mask_outs = Vector{T}()
bg_outs = Vector{T}()
for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
# @btime before_pool = $encoder_outs[end-($i-1)]
before_pool = @timeit to "encoder_outs1" encoder_outs[end-(i-1)] #encoder_outs===nothing ? nothing : encoder_outs[end-(i-1)]
if m.use_att
# @btime $m.atts_mask[$i]($before_pool)
mask_before_pool = @timeit to "atts_mask2" m.atts_mask[i](before_pool)
# @btime $m.atts_bg[$i]($before_pool)
bg_before_pool = @timeit to "atts_bg3" m.atts_bg[i](before_pool)
else
mask_before_pool = before_pool # without attention, fall back to the raw skip connection
bg_before_pool = before_pool # (previously these were left undefined when use_att == false)
end
# @btime $up_mask($mask_x,$mask_before_pool)
# @code_warntype up_mask(mask_x, mask_before_pool)
smr_outs = @timeit to "up_mask4" up_mask(mask_x, mask_before_pool)
# @btime mask_x = $smr_outs["feats"][1]
mask_x = @timeit to "smr_outs5" smr_outs["feats"][1]
# @btime primary_map, self_calibrated_map = $smr_outs["attn_maps"]
primary_map, self_calibrated_map = @timeit to "smr_outs6" smr_outs["attn_maps"]
# @btime push!($mask_outs, $primary_map)
@timeit to "push7" push!(mask_outs, primary_map)
# @btime push!($mask_outs, $self_calibrated_map)
@timeit to "push8" push!(mask_outs, self_calibrated_map)
# @btime $up_bg($bg_x, $bg_before_pool, $self_calibrated_map)
bg_x = @timeit to "up_bg9" up_bg(bg_x, bg_before_pool, self_calibrated_map) # ่ฟ้ๅฏ่ฝๆ้ฎ้ข
# @btime push!($bg_outs, $bg_x)
@timeit to "push10" push!(bg_outs, bg_x)
end
if m.conv_final_bg !== nothing
# @btime $m.conv_final_bg($bg_x)
bg_x = @timeit to "conv_final_bg11" m.conv_final_bg(bg_x)
# @btime push!($bg_outs, $bg_x)
@timeit to "push12" push!(bg_outs, bg_x)
end
#@show length(bg_outs)
#@show length(mask_outs)
return bg_outs, mask_outs, nothing
end
Flux.@functor CoarseDecoder
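# Refinement: consumes the coarse composite plus the predicted mask, optionally
# adds decoder skip features (dec_conv2/3/4, gated by n_skips), encodes three
# scales with a DownConv pyramid, fuses them through n_cff CFFBlocks, and maps
# the upsampled concatenation to the final residual image.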
struct Refinement{C<:Conv, CH1<:Chain, CH2<:Chain, CH3<:Chain, CH4<:Chain,D1<:DownConv,D2<:DownConv,D3<:DownConv,CFF<:CFFBlock}
conv_in::CH1
dec_conv2::C
dec_conv3::CH2
dec_conv4::CH3
down1::D1
down2::D2
down3::D3
cff_blocks::Vector{CFF}
out_conv::CH4
n_skips::Int
end
Refinement(;in_channels=3, out_channels=3, shared_depth=2, down=DownConv, up=UpConv, ngf=32, n_cff=3, n_skips=3) = Refinement(
Chain(
Conv((3,3), in_channels=>ngf; stride=1, pad=1, bias=true),
Flux.InstanceNorm(ngf),
x->Flux.leakyrelu.(x, eltype(x)(0.2))
),
Conv((1,1), ngf=>ngf; stride=1, pad=0, bias=true),
Chain(
Conv((1,1), ngf*2=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
),
Chain(
Conv((1,1), ngf*4=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
Conv((3,3), ngf*2=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
),
down(ngf, ngf, 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[]),
down(ngf, ngf*2, 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[]),
down(ngf*2, ngf*4, 3, pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[1,2,5]),
[CFFBlock(;ngf=ngf) for i in 1:n_cff],
Chain(
Conv((3,3), (ngf+ngf*2+ngf*4)=>ngf; stride=1, pad=1, bias=true),
Flux.InstanceNorm(ngf),
x->Flux.leakyrelu.(x, eltype(x)(0.2)),
Conv((1,1), ngf=>out_channels; stride=1, pad=0)
),
n_skips
)
function (m::Refinement)(input, coarse_bg, mask, encoder_outs, decoder_outs)
xin = cat(coarse_bg, mask, dims=Val(3))
# @btime $m.conv_in($xin)
x = m.conv_in(xin)
# @btime $m.dec_conv2($decoder_outs[1])
m.n_skips >= 1 && (x += m.dec_conv2(decoder_outs[1])) # conditions were inverted (`< 1` etc.), leaving the skip convs dead for the default n_skips=3
# @btime $m.down1($x)
x, d1 = m.down1(x)
# @btime $m.dec_conv3($decoder_outs[2])
m.n_skips >= 2 && (x += m.dec_conv3(decoder_outs[2]))
# @btime $m.down2($x)
x, d2 = m.down2(x)
# @btime $m.dec_conv4($decoder_outs[3])
m.n_skips >= 3 && (x += m.dec_conv4(decoder_outs[3]))
# @btime $m.down3($x)
x,d3 = m.down3(x)
xs = [d1, d2, d3]
for block in m.cff_blocks
# @btime $block($xs)
xs = block(xs)
end
# @btime [Flux.upsample_bilinear(x_hr; size=(size($coarse_bg)[2], size($coarse_bg)[1])) for x_hr in $xs]
xs = [Flux.upsample_bilinear(x_hr; size=size(coarse_bg)[1:2]) for x_hr in xs] # WHCN: use dims 1:2 directly
# @btime $m.out_conv(cat($xs..., dims=3))
im = m.out_conv(cat(xs..., dims=Val(3)))
im
end
Flux.@functor Refinement
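# SLBR: the full pipeline. CoarseEncoder -> SharedBottleNeck -> CoarseDecoder
# produces a coarse background and mask; the mask composites the coarse output
# with the input, and Refinement predicts the final background from it.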
struct SLBR
encoder::CoarseEncoder
shared_decoder::SharedBottleNeck
coarse_decoder::CoarseDecoder
refinement::Refinement
long_skip::Bool
end
SLBR(; in_channels=3, depth=5, shared_depth=2, blocks=[1 for i in 1:5], out_channels_image=3, out_channels_mask=1, start_filters=32, residual=true, concat=true, long_skip=false, k_refine=3, n_skips=3, k_center=3) = SLBR(
CoarseEncoder(in_channels, depth-shared_depth; blocks=blocks[1], start_filters=start_filters, residual=residual, norm=Flux.BatchNorm, act=Flux.relu),
SharedBottleNeck(start_filters*2^(depth-shared_depth-1), depth, shared_depth; blocks=blocks[5], residual=residual, concat=concat, norm=Flux.InstanceNorm),
CoarseDecoder(start_filters*2^(depth-shared_depth), out_channels_image, k_center; depth=depth-shared_depth, blocks=blocks[2], residual=residual, concat=concat, norm=Flux.BatchNorm, use_att=true),
Refinement(; in_channels=4, out_channels=3, shared_depth=1, n_cff=k_refine, n_skips=n_skips),
long_skip
)
function (m::SLBR)(synthesized)
# @btime $m.encoder($synthesized) # (type stablity)-> 970.384 ฮผs (2239 allocations: 109.50 KiB)
image_code, before_pool = m.encoder(synthesized)
unshared_before_pool = before_pool
# @code_warntype m.shared_decoder(image_code) # 2.848 ms (7389 allocations: 467.92 KiB) (type stablity)-> 2.858 ms (7309 allocations: 464.67 KiB)
# @btime $m.shared_decoder($image_code)
im, mask = m.shared_decoder(image_code)
# @btime $m.coarse_decoder($im, nothing, $mask, $unshared_before_pool) # 233.433 ms (24298 allocations: 1.17 MiB) (type stablity)->234.892 ms (23475 allocations: 1.15 MiB)
# @btime $m.coarse_decoder($im, nothing, $mask, $unshared_before_pool)
ims, mask, wm = m.coarse_decoder(im, nothing, mask, unshared_before_pool)
im = ims[end]
reconstructed_image = Flux.tanh_fast.(im)
if m.long_skip
reconstructed_image = reconstructed_image + synthesized
reconstructed_image = clamp.(reconstructed_image, zero(eltype(reconstructed_image)), one(eltype(reconstructed_image)))
end
reconstructed_mask = mask[end]
reconstruct_wm = wm
dec_feats = reverse(ims[1:end-1])
#@show eltype(reconstructed_image)
#@show eltype(reconstructed_mask)
#@show eltype(synthesized)
coarser = reconstructed_image .* reconstructed_mask + (one(eltype(reconstructed_mask)) .- reconstructed_mask) .* synthesized
# @btime m.refinement($synthesized, $coarser, $reconstructed_mask, nothing, $dec_feats) # 664.924 ms (24094 allocations: 1.45 MiB)
refine_bg = m.refinement(synthesized, coarser, reconstructed_mask, nothing, dec_feats)
refine_bg = clamp.(Flux.tanh_fast.(refine_bg) + synthesized, zero(eltype(refine_bg)), one(eltype(refine_bg)))
return [refine_bg, reconstructed_image], mask, [reconstruct_wm]
end
Flux.@functor SLBR
m = SLBR(; shared_depth=2, blocks=[3 for i in 1:5], long_skip=true, k_center=2) |> gpu
# BSON.@save "model.bson" m
x = rand(Float32, 720,720, 3, 1) |> gpu
# @btime m(x)
out = m(x);
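# The first call triggers compilation; resetting the timer means the second
# call's profile reflects steady-state execution only.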
reset_timer!(to)
out = m(x);
@show length(out)
@show size(out[1][1])
@show size(out[2][end])
show(to)