There is a huge difference in time consumption between benchmarking a function statement by statement and benchmarking the entire function at once

First, I measured the time consumed by each statement with @btime:

function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
    bg_x = bg
    mask_x = mask
    mask_outs = Vector{T}()
    bg_outs = Vector{T}()
    for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
        @btime before_pool = $encoder_outs[end-($i-1)]
        before_pool = encoder_outs[end-(i-1)] #encoder_outs===nothing ? nothing : encoder_outs[end-(i-1)]
        if m.use_att
            @btime $m.atts_mask[$i]($before_pool)
            mask_before_pool = m.atts_mask[i](before_pool)
            @btime $m.atts_bg[$i]($before_pool)
            bg_before_pool = m.atts_bg[i](before_pool)
        end
        @btime $up_mask($mask_x,$mask_before_pool)
        # @code_warntype up_mask(mask_x, mask_before_pool)
        smr_outs = up_mask(mask_x, mask_before_pool)
        @btime mask_x = $smr_outs["feats"][1]
        mask_x = smr_outs["feats"][1]
        @btime primary_map, self_calibrated_map = $smr_outs["attn_maps"]
        primary_map, self_calibrated_map = smr_outs["attn_maps"]
        @btime push!($mask_outs, $primary_map)
        push!(mask_outs, primary_map)
        @btime push!($mask_outs, $self_calibrated_map)
        push!(mask_outs, self_calibrated_map)
        @btime $up_bg($bg_x, $bg_before_pool, $self_calibrated_map)
        bg_x = up_bg(bg_x, bg_before_pool, self_calibrated_map) # there may be a problem here
        @btime push!($bg_outs, $bg_x)
        push!(bg_outs, bg_x)
    end
    if m.conv_final_bg !== nothing
        @btime $m.conv_final_bg($bg_x)
        bg_x = m.conv_final_bg(bg_x)
        @btime push!($bg_outs, $bg_x)
        push!(bg_outs, bg_x)
    end
    #@show length(bg_outs)
    #@show length(mask_outs)
    return bg_outs, mask_outs, nothing
end

Then I removed each line starting with @btime and benchmarked the whole function call with @btime.

Output of the 1st experiment:

  2.452 ns (0 allocations: 0 bytes)
  68.748 ฮผs (279 allocations: 16.20 KiB)
  69.494 ฮผs (279 allocations: 16.20 KiB)
  1.265 ms (4283 allocations: 254.20 KiB)
  18.716 ns (0 allocations: 0 bytes)
  19.152 ns (0 allocations: 0 bytes)
  7.785 ns (0 allocations: 0 bytes)
  7.819 ns (0 allocations: 0 bytes)
  1.091 ms (3489 allocations: 191.00 KiB)
  7.784 ns (0 allocations: 0 bytes)
  2.531 ns (0 allocations: 0 bytes)
  69.095 ฮผs (282 allocations: 16.30 KiB)
  69.612 ฮผs (282 allocations: 16.30 KiB)
  1.354 ms (4278 allocations: 254.02 KiB)
  18.726 ns (0 allocations: 0 bytes)
  18.546 ns (0 allocations: 0 bytes)
  8.783 ns (0 allocations: 0 bytes)
  7.671 ns (0 allocations: 0 bytes)
  1.143 ms (3515 allocations: 191.59 KiB)
  7.824 ns (0 allocations: 0 bytes)
  2.282 ns (0 allocations: 0 bytes)
  73.834 ฮผs (281 allocations: 16.30 KiB)
  73.659 ฮผs (281 allocations: 16.30 KiB)
  1.385 ms (4346 allocations: 256.08 KiB)
  18.680 ns (0 allocations: 0 bytes)
  18.621 ns (0 allocations: 0 bytes)
  7.814 ns (0 allocations: 0 bytes)
  8.580 ns (0 allocations: 0 bytes)
  1.190 ms (3584 allocations: 193.45 KiB)
  8.609 ns (0 allocations: 0 bytes)
  28.334 ฮผs (113 allocations: 6.34 KiB)
  8.843 ns (0 allocations: 0 bytes)

Output of the 2nd experiment:

209.548 ms (26226 allocations: 1.44 MiB)

The sum of the 1st experiment’s times is about 9 ms, just 1/22 of the 2nd experiment!
What’s wrong?

Iโ€™m confused. Can you post the entire, exact code youโ€™re benchmarking? As written, I canโ€™t tell if youโ€™re timing the cost of a single iteration in one place but the sum over multiple iterations in another place.

Thanks!
The entire code is too large to post.
As I wrote, I removed each line starting with @btime and then benchmarked the whole function call with @btime,
so I think there is no uncertainty here. Right?

OK, let’s denote the instance of CoarseDecoder as m. The 2nd experiment was done as follows:

  • remove each line in the function that starts with @btime;

  • call the function as @btime $m($param1, $param2, ...)

In the test with @btime, do you execute both the line with @btime and the line without it? If so, each statement is probably executed twice, no? Because the line with @btime also executes the action.

In the 2nd experiment, I removed each line starting with @btime and called the function with @btime.

The results look unbelievable, right? That’s why I posted them here.

Again, the output shows it; otherwise the output would not be so clean. Everything is on this one line:

209.548 ms (26226 allocations: 1.44 MiB)

No, I mean: in the test with @btime on the individual lines, did you execute both the line with @btime and the one without it, as in the code you posted?

Yes, I did. The lines with @btime produce the timing output, and the lines without @btime provide the values for the statements that follow.

But then the two versions do different things. In the run with @btime on the individual lines:

        @btime push!($mask_outs, $primary_map)
        push!(mask_outs, primary_map)

will push primary_map into mask_outs twice. This will mess up the algorithm, no?
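In fact it is worse than running each push! twice: @btime evaluates its expression once per sample, thousands of times by default. A Base-only sketch of that effect (the 10_000 loop stands in for BenchmarkTools’ sampling, and mask_outs and primary_map are stand-in values here, not the model’s):

```julia
# Imitate what `@btime push!($mask_outs, $primary_map)` does: the
# benchmarked expression is evaluated once per sample, not once total.
mask_outs = Int[]
primary_map = 1            # stand-in value, not the model's attention map
for _ in 1:10_000          # BenchmarkTools' default sample budget
    push!(mask_outs, primary_map)
end
length(mask_outs)          # 10_000 elements instead of the expected one
```

So the data that later statements see is very different between the two versions of the run.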


Yes, it does.
But that does not change the conclusion: the summed per-line time is much less than one full call.

Whatโ€™s more it takes only 60ms to execute in pytorch.

You might be interested in trying TimerOutputs for organizing a bunch of timers like this. Maybe the general performance tips are also useful.
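A minimal sketch of that pattern (the section names and the demo function are illustrative, not taken from the model):

```julia
using TimerOutputs

const to = TimerOutput()

# Wrap each statement of interest in @timeit; timings and allocations
# accumulate in `to` across calls and are reported per section.
function demo(n)
    x = @timeit to "setup" collect(1:n)
    s = @timeit to "sum"   sum(x)
    return s
end

demo(10^5)
show(to)   # prints a per-section time and allocation table
```

Unlike @btime on isolated lines, the sections are timed inside the compiled function, with the real data flowing through them.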


This result isnโ€™t actually strange for a couple reasons:

  1. push! increases the dataโ€™s size, and the @btime version runs that in a huge loop in the global scope. For example, your push!(bg_outs, bg_x) line would have a much larger bg_outs if you ran @btime push!(bg_outs, bg_x) beforehand. It should make sense that when your data is so different between the 2 versions, they wonโ€™t do the same things and wonโ€™t take the same time.

  2. To make multiple lines work together, the function may be doing some work between the lines. For example, poorly inferred variables (type instability) will require methods to be dynamically dispatched, which takes more time, and @btimeing an individual line in the global scope completely misses that. On the other hand, the compiler may also do some optimizations when multiple lines are put together, so it takes less time (see example below). You probably want to know how fast a line is inside your function, not by itself in the global scope, so you need a different tool, not @btime.
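One such tool is the stdlib sampling profiler, which attributes time to the lines inside the running function instead of timing statements in global scope. A rough sketch, where `work` is a hypothetical stand-in for the real model call:

```julia
using Profile

# Hypothetical stand-in for the call being investigated:
work(n) = sum(sqrt(i) for i in 1:n)

work(10)                     # warm up so compilation is not profiled
Profile.clear()
@profile work(10^7)
Profile.print(maxdepth=10)   # reports which lines the samples landed on
```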

Counterexample where function runs faster than its individual lines run separately
julia> function f(x::Int)
         if x >= 0 # maybe x
           x = x/3 # becomes Float64
         end
         x + 1
       end
f (generic function with 1 method)

julia> @btime f($3)
  5.915 ns (0 allocations: 0 bytes)
2.0

julia> @btime $3 >= 0
  1.508 ns (0 allocations: 0 bytes)
true

julia> @btime $3 / 3
  5.032 ns (0 allocations: 0 bytes)
1.0

julia> @btime $1.0 + 1
  1.508 ns (0 allocations: 0 bytes)
2.0

julia> 5.915, 1.508+5.032+1.508
(5.915, 8.048)

julia> @btime f($(-3))
  1.806 ns (0 allocations: 0 bytes)
-2

julia> @btime $(-3) >= 0
  1.508 ns (0 allocations: 0 bytes)
false

julia> @btime $(-3) + 1
  1.509 ns (0 allocations: 0 bytes)
-2

julia> 1.806, 1.508+1.509
(1.806, 3.017)

Thanks!
Regarding your 1st reason: the push! operations do not change the calculation flow, because the elements in those Vectors are never used as inputs.
Regarding your 2nd reason: I have made sure that the function here is type-stable.
And this example is very different from mine: in my case, the sum of the per-line times is much less than a single call without the inner @btime.
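For reference, type stability can be double-checked with Test.@inferred; a sketch on toy functions (illustrative only, not the model code):

```julia
using Test

# Toy stand-ins, not the model code:
unstable(x::Int) = x >= 0 ? x / 3 : x          # may return Float64 or Int
stable(x::Int)   = x >= 0 ? x / 3 : float(x)   # always returns Float64

@inferred stable(3)        # passes: the inferred return type is concrete
caught = try
    @inferred unstable(3)  # throws: inference gives Union{Float64, Int64}
    false
catch
    true
end
```

@code_warntype shows the same information interactively, which matches the commented-out @code_warntype line in the original function.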

Thanks, I will try now.

The result is:

 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
                                    Time                    Allocations      
                           โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€   โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
     Tot / % measured:          26.1s /   2.7%           2.84GiB /   0.3%    

 Section           ncalls     time    %tot     avg     alloc    %tot      avg
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 up_mask4               3    686ms   98.6%   229ms   9.26MiB   94.7%  3.09MiB
 up_bg9                 3   8.14ms    1.2%  2.71ms    441KiB    4.4%   147KiB
 atts_mask2             3    740ฮผs    0.1%   247ฮผs   44.0KiB    0.4%  14.7KiB
 atts_bg3               3    423ฮผs    0.1%   141ฮผs   44.0KiB    0.4%  14.7KiB
 conv_final_bg11        1    159ฮผs    0.0%   159ฮผs   4.67KiB    0.0%  4.67KiB
 encoder_outs1          3   7.05ฮผs    0.0%  2.35ฮผs     0.00B    0.0%    0.00B
 smr_outs5              3   5.77ฮผs    0.0%  1.92ฮผs     0.00B    0.0%    0.00B
 push7                  3   3.95ฮผs    0.0%  1.32ฮผs     80.0B    0.0%    26.7B
 push8                  3   3.71ฮผs    0.0%  1.24ฮผs     0.00B    0.0%    0.00B
 smr_outs6              3   3.51ฮผs    0.0%  1.17ฮผs     0.00B    0.0%    0.00B
 push10                 3   1.96ฮผs    0.0%   655ns     80.0B    0.0%    26.7B
 push12                 1    351ns    0.0%   351ns     0.00B    0.0%    0.00B
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€


In fact, I want to reduce the execution time of this function to less than 60 ms, i.e. a little faster than the PyTorch version.

This will be a problem if you want optimization tips. Is the struct CoarseDecoder ... end block short enough for us to take a look?

struct SMRBlock{C<:Conv, U<:UpConv, S<:SelfAttentionSimple}
    upconv::U
    primary_mask::C
    self_calibrated::S
end
struct ECABlocks{C<:Conv,AMP<:Flux.AdaptiveMeanPool}
    conv::C
    avg_pool::AMP
end
struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
    up_conv::C
    bn::Vector{BI}
    norm0::BI
    norm1::BI
    conv1::C
    conv2::Vector{CH}
    conv3::Vector{C}
    act::F
    concat::Bool
    residual::Bool
end
struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
    up_convs_bg::Vector{MBE}
    up_convs_mask::Vector{SMR}
    atts_mask::Vector{ECA}
    atts_bg::Vector{ECA}
    conv_final_bg::C
    use_att::Bool
end

The whole code can be executed as follows:

using Flux: output_size
using Base: _before_colon, concatenate_setindex!
using Flux, BSON
using Revise
using BenchmarkTools
using InteractiveUtils
using TimerOutputs
const to = TimerOutput()

xshape(x) = eltype(x) <: AbstractArray ? (length(x), xshape(x[1])...) : [length(x)]
tovec(x) = eltype(x) <: AbstractArray ? tovec([xx_ for x_ in x for xx_ in x_]) : [x...]

struct DownConv{C1<:Conv,C2<:Conv, BI<:Union{BatchNorm,InstanceNorm},M<:MaxPool,F<:Function}
    conv1::C1
    conv2::Vector{C2}
    act::F
    norm1::BI
    bn::Vector{BI}
    residual::Bool
    pooling:: Bool
    pool::M
end
function DownConv(in_channels, out_channels, blocks; pooling=true, norm=Flux.BatchNorm, act=Flux.relu, residual=true, dilations=[])
    # norm = norm == "bn" ? Flux.BatchNorm : norm == "in" ? Flux.InstanceNorm : error("Unknown type:\t$norm")
    
    conv1 = Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    norm1 = norm(out_channels)
    dilations = length(dilations) == 0 ? [1 for i in 1:blocks] : dilations
    conv2 =[Conv((3,3), out_channels=>out_channels; dilation=dilations[i], pad=dilations[i]) 
        for i in 1:blocks ]

    bn = [norm(out_channels) for _ in 1:blocks]  # fill() would alias one norm layer across all blocks
    
    pool = Flux.MaxPool((2, 2), stride=2)
    
    DownConv{typeof(conv1), eltype(conv2), typeof(norm1), typeof(pool), typeof(act)}(conv1, conv2, act, norm1, bn, residual, pooling, pool)
end

function (m::DownConv)(x)
    x1 = m.act.(m.norm1(m.conv1(x)))
    for (idx, conv) in enumerate(m.conv2)
        x2 = conv(x1)
        x2 = m.bn[idx](x2)
        x1 = m.residual ? x1+x2 : x2
        x1 = m.act.(x1)
    end
    before_pool = deepcopy(x1)
    x1 = m.pooling ? m.pool(x1) : x1

    x1, before_pool
end

Flux.@functor DownConv

struct UpConv{C<:Conv,BI<:Union{BatchNorm,InstanceNorm},F<:Function}
    up_conv::C
    conv1::C
    conv2::Vector{C}
    bn::Vector{BI}
    norm0::BI
    norm1::BI
    act::F
    concat::Bool
    use_mask::Bool
    residual::Bool
end
function UpConv(in_channels, out_channels, blocks; residual=true, norm=Flux.BatchNorm, act=Flux.relu, concat=true, use_att=false, use_mask=false, dilations=[], out_fuse=false)
    up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    norm0 = norm(out_channels)
    if length(dilations)==0
        dilations = [1 for _ in 1:blocks]
    end
    if concat
        conv1 = Conv((3,3), (2*out_channels+(use_mask ? 1 : 0))=>out_channels; pad=1, bias=true)
        norm1 = norm(out_channels)
    else
        conv1 = Conv((3,3), out_channels=>out_channels; pad=1, bias=true)
        norm1 = norm(out_channels)
    end
    conv2 = [ Conv((3,3), out_channels=>out_channels; dilation = dilations[i], pad=dilations[i], bias=true)     for i in 1:blocks ]
    bn = [norm(out_channels) for _ in 1:blocks]
    UpConv{typeof(up_conv), typeof(norm0), typeof(act)}(up_conv, conv1, conv2, bn, norm0, norm1, act, concat, use_mask, residual)
end
struct OutFuse{v}
end
OutFuse(x) = OutFuse{x}()
function (m::UpConv)(::OutFuse{true}, from_up, from_down; mask=nothing, se=nothing)#::Tuple{CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}
    from_up = @timeit to "4.1.1" m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))
    x1 = @timeit to "4.1.2" m.concat ? (m.use_mask ? cat(from_up, from_down, mask; dims=Val(3)) : cat(from_up, from_down; dims=Val(3))) : (!(from_down===nothing) ? from_up + from_down : from_up)
    xfuse = x1 = @timeit to "4.1.3" m.act.(m.norm1(m.conv1(x1)))
    for (idx, conv) in enumerate(m.conv2)
        x2 = @timeit to "4.1.4" conv(x1)
        x2 = @timeit to "4.1.5" m.bn[idx](x2)
        if !(se===nothing) && idx == length(m.conv2)
            x2 = se(x2)
        end
        if m.residual
            x2 = x2 + x1
        end
        x2 = @timeit to "4.1.6" m.act.(x2)
        x1 = x2
    end
    x1, xfuse
end

function (m::UpConv)(::OutFuse{false}, from_up, from_down; mask=nothing, se=nothing)#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
    from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))

    x1 = m.concat ? 
    (m.use_mask ? Base.cat(from_up, from_down, mask; dims=Val(3)) : Base.cat(from_up, from_down; dims=Val(3))) : 
    (!(from_down===nothing) ? from_up + from_down : from_up)

    x1 = m.act.(m.norm1(m.conv1(x1)))

    for (idx, conv) in enumerate(m.conv2)
        x2 = conv(x1)
        x2 = m.bn[idx](x2)
        if !(se===nothing) && idx == length(m.conv2)
            x2 = se(x2)
        end
        if m.residual
            x2 = x2 + x1
        end
        x2 = m.act.(x2)
        x1 = x2
    end
    x1
end


Flux.@functor UpConv

struct CFFBlock{C1<:Conv, C2<:Conv, D1<:DownConv, D2<:DownConv, D3<:DownConv, CH1<:Chain, CH2<:Chain}
    up32::C1
    up31::C2
    down1::D1
    down2::D2
    down3::D3
    conv22::CH1
    conv33::CH2
end

function CFFBlock(; down=DownConv, up=UpConv, ngf::Int=32) 
    p = [
        Conv((3,3), ngf*4 => ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
        Conv((3,3), ngf*4=>ngf*1, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
        DownConv(ngf, ngf, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
        DownConv(ngf, ngf*2, 3; pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[]),
        DownConv(ngf*2, ngf*4, 3; pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x,eltype(x)(0.01)), dilations=[1,2,5]),
        Chain(
            Conv((3,3), ngf*2=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
            Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
        ),
        Chain(
            Conv((3,3), ngf*4=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1),
            Conv((3,3), ngf*2=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1)
        )
    ]
    CFFBlock(p...)
end

function (m::CFFBlock)(inputs)
    x1, x2, x3 = inputs
    # @show size(x1)
    # @show size(x2)
    # @show size(x3)
    #@show (size(x2)[2], size(x2)[1])
    #@show (size(x1)[2], size(x1)[1])
    x32 = Flux.upsample_bilinear(x3; size=(size(x2)[2], size(x2)[1]))
    x32 = m.up32(x32)
    x31 = Flux.upsample_bilinear(x3; size=(size(x1)[2], size(x1)[1]))
    x31 = m.up31(x31)
    # cross-connection
    x, d1 = m.down1(x1 + x31)
    x, d2 = m.down2(x + m.conv22(x2) + x32)
    #@show size(x)
    #@show size(x3)
    #@show size(m.conv33(x3))
    d3, _ = m.down3(x + m.conv33(x3))
    d1,d2,d3
end

Flux.@functor CFFBlock

struct ECABlocks{C<:Conv,AMP<:Flux.AdaptiveMeanPool}
    conv::C
    avg_pool::AMP
end

ECABlocks(channel, k_size=3) = ECABlocks(
        Conv((k_size,), 1=>1, sigmoid; pad=(k_size-1)รท2, bias=false),
        AdaptiveMeanPool((1,1))
    )
function (m::ECABlocks)(x)
    # h, w, c, b = size(x)

    y = m.avg_pool(x)
    y = Flux.unsqueeze(permutedims(m.conv(permutedims(reshape(y, size(y)[2:end]), (2,1,3))), (2, 1, 3)), dims=1)
    # y = Flux.sigmoid(y)  # this can be folded into the conv's activation
    x .* y
end
Flux.@functor ECABlocks

struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
    up_conv::C
    bn::Vector{BI}
    norm0::BI
    norm1::BI
    conv1::C
    conv2::Vector{CH}
    conv3::Vector{C}
    act::F
    concat::Bool
    residual::Bool
end

function MBEBlock(in_channels=512, out_channels=3; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true, is_final=true)
    up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    conv1 = Conv((3,3), (concat ? 2*out_channels : out_channels)=>out_channels; pad=1, bias=true)
    conv2 = Vector{Chain}()
    conv3 = Vector{Conv}()
    for i in 1:blocks
        push!(conv2, Chain(
            Conv((5,5), (out_channels รท 2 + 1)=>(out_channels รท 4), Flux.relu; stride=1, pad=2, bias=true),
            Conv((5,5), (out_channels รท 4)=>1, Flux.sigmoid_fast; stride=1, pad=2, bias=true)
        ))
        push!(conv3, Conv((3,3), (out_channels รท 2)=>out_channels; pad=1, bias=true))
    end
    bn = [norm(out_channels) for i in 1:blocks]

    MBEBlock(up_conv, bn, norm(out_channels), norm(out_channels), conv1, conv2, conv3, act, concat, residual)
end

function (m::MBEBlock)(from_up, from_down, mask=nothing)
    from_up = m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, 2.0f0))))
    if m.concat
        x1 = cat(from_up, from_down; dims=Val(3))
    elseif !(from_down === nothing)
        x1 = from_up + from_down
    else
        x1 = from_up
    end
    x1 = m.act.(m.norm1(m.conv1(x1)))
    #residual structure
    H, W, C, _ = size(x1)
    for (idx, (conv1, conv2)) in enumerate(zip(m.conv2, m.conv3))
        #@show size(x1)
        #@show size(x1[:,:,1:(C รท 2 ), :])
        #@show size(mask)
        mask = conv1(cat(x1[:,:,1:(C รท 2), :], mask; dims=Val(3)))
        x2_actv = x1[:,:,(C รท 2+1) : end, :] .*mask
        x2 = conv2(x1[:,:,(C รท 2+1) : end, :] + x2_actv)
        x2 = m.bn[idx](x2)
        if m.residual
            x2 = x2 + x1
        end
        x2 = m.act.(x2)
        x1 = x2
    end
    x1
end
Flux.@functor MBEBlock

struct SelfAttentionSimple{C<:Conv, CH<:Chain, AA<:AbstractArray{Float32, 4}}
    k_center::Int32
    q_conv::C
    k_conv::C
    v_conv::C
    sim_func::C
    out_conv::CH
    min_area::Float32
    threshold::Float32
    k_weight::AA

end

SelfAttentionSimple(in_channel, k_center) = SelfAttentionSimple(Int32(k_center), 
    Conv((1, 1), in_channel=>in_channel),
    Conv((1, 1), in_channel=>in_channel*k_center),
    Conv((1, 1), in_channel=>in_channel*k_center),
    Conv((1,1), (2*in_channel) => 1; stride=1, pad=0, bias=true),
    Chain(
        Conv((3,3), in_channel=>(in_channelรท8), Flux.relu; stride=1, pad=1),
        Conv((3,3), (in_channelรท8)=>1; stride=1, pad=1)
    ),
    100.0f0, 0.5f0,
    fill(1.0f0, (1, 1, k_center, 1))
)


function compute_attention(m::SelfAttentionSimple, query::T, key::T, mask::T, eps=1) where{T}#::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
    h,w,c,b = size(query)
    #@show size(query)
    # @btime $m.q_conv($query)
    query = @timeit to "cmp_att_1" m.q_conv(query)
    key_in = key
    # @btime $m.k_conv($key_in)
    key = @timeit to "cmp_att_2" m.k_conv(key_in)
    # keys = [view(key, :,:,i:(i+c-1),:) for i in 1:c:size(key)[3]] 
    keys = @timeit to "cmp_att_3" Vector{T}([key[:,:,i:(i+c-1),:] for i in 1:c:size(key)[3]])
    # @btime eltype($mask).($mask .> $m.threshold)
    importance_map = @timeit to "cmp_att_4" eltype(mask).(mask .> m.threshold)
    # @btime sum($importance_map, dims=[1,2])
    s_area = @timeit to "cmp_att_5" sum(importance_map, dims=[1,2])
    # mask = s_area .>= m.min_area
    # s_area = s_area .* mask + m.min_area .* .!mask
    # @btime clamp_lo(one(eltype($s_area))*$m.min_area).($s_area)
    @timeit to "cmp_att_6" clamp!(s_area, m.min_area, 1.0f7) #clamp_lo(one(eltype(s_area))*m.min_area).(s_area)

    s_area = @timeit to "cmp_att_7" s_area[:, :, 1:1, :]#view(s_area, :, :, 1:1, :)

    @timeit to "cmp_att_8" if m.k_center != 2
        # @btime [sum(k .*$importance_map, dims=[1, 2]) ./ $s_area  for k in $keys]
        ks = Vector{T}() 
        for k in keys
            push!(ks, sum(k .* importance_map, dims=(1, 2)) ./ s_area)
        end
        keys = ks
        # keys = [sum(keys[1] .*importance_map, dims=[1, 2]) ./ s_area, sum(keys[2] .*importance_map, dims=[1, 2]) ./ s_area]
    else
        # @btime [sum($keys[1] .* $importance_map, dims=[1,2]) ./$s_area,
        # sum($keys[2] .*(one(eltype($importance_map)) .- $importance_map), dims=[1,2]) ./ (size($keys[2])[1] * size($keys[2])[2] .- $s_area .+ $eps),
        # ]
        keys = Vector{T}([sum(keys[1] .* importance_map, dims=[1,2]) ./s_area,
        sum(keys[2] .*(one(eltype(importance_map)) .- importance_map), dims=[1,2]) ./ (size(keys[2])[1] * size(keys[2])[2] .+ eps .- s_area)
        ])
    end

    f_query = query
    
    f_key =  @timeit to "cmp_att_9" [reshape(k, (1, 1, c, b)) .* fill!(typeof(f_query)(undef, size(f_query)), 1)  for k in keys]
    # f_key = [Flux.upsample_nearest(reshape(k, (1, 1, c, b)), size=(size(f_query)[1], size(f_query)[2])) for k in keys]
    # @btime [reshape(k, (1, 1, $c, $b)) for k in $keys]
    # f_key = [reshape(k, (1, 1, c, b)) for k in keys]
    # @btime [cat([k for i in 1:size($f_query)[1]]..., dims=1) for k in $f_key]
    # f_key = [cat([k for i in 1:size(f_query)[1]]..., dims=1) for k in f_key]
    # @btime [cat([k for i in 1:size($f_query)[2]]..., dims=2) for k in $f_key]
    # f_key = [cat([k for i in 1:size(f_query)[2]]..., dims=2) for k in f_key]
    attention_scores = Vector{T}()
    for k in f_key
        # @btime Flux.tanh_fast.(cat($f_query, $k, dims=3))
        combine_qk = @timeit to "cmp_att_10" Flux.tanh_fast.(cat(f_query, k, dims=Val(3)))
        # @btime $m.sim_func($combine_qk)
        sk = @timeit to "cmp_att_11" m.sim_func(combine_qk)
        @timeit to "cmp_att_12" push!(attention_scores, sk)
    end
    # @btime cat($attention_scores...; dims=3)
    # s = cat(attention_scores...; dims=Val(3))
    s =  @timeit to "cmp_att_13" reshape(mapreduce(Flux.flatten, vcat, attention_scores), size(attention_scores[1])[1], size(attention_scores[1])[2], sum([size(x)[3] for x in attention_scores]), size(attention_scores[1])[4])
    # @btime permutedims($s, (3,1,2,4))
    s = @timeit to "cmp_att_14" permutedims(s, (3,1,2,4))
    # @btime $m.v_conv($key_in)
    v =  @timeit to "cmp_att_15" m.v_conv(key_in)
    if m.k_center == 2
        # @btime sum($v[:, :, 1:$c-1, :] .* $importance_map, dims=[1,2]) ./ $s_area
        v_fg =  @timeit to "cmp_att_16" sum(v[:, :, 1:c-1, :] .* importance_map, dims=[1,2]) ./ s_area
        # @btime sum($v[:,:,$c:end, :] .* (1 .- $importance_map);dims=[1,2]) ./ (size($v)[1] * size($v)[2] .- $s_area .+ $eps)
        v_bg =  @timeit to "cmp_att_17" sum(v[:,:,c:end, :] .* (1 .- importance_map);dims=[1,2]) ./ (size(v)[1] * size(v)[2] .- s_area .+ eps)
        v = @timeit to "cmp_att_18" cat(v_fg, v_bg; dims=Val(3))
    else
        # @btime sum($v .* $importance_map, dims=[1,2]) ./ $s_area
        v = @timeit to "cmp_att_19" sum(v .* importance_map, dims=[1,2]) ./ s_area
    end
    v = @timeit to "cmp_att_20" reshape(v, (c, m.k_center, b))
    #@show size(s), (m.k_center, h*w, b)
    #@show size(v)
    # @btime permutedims(reshape(Flux.batched_mul($v, reshape($s, ($m.k_center, $h*$w, $b))), ($c,$h,$w,$b)), (2, 3, 1,4))
    attn = @timeit to "cmp_att_21" permutedims(reshape(Flux.batched_mul(v, reshape(s, (m.k_center, h*w, b))), (c,h,w,b)), (2, 3, 1,4))
    #@show size(query)
    #@show size(attn)
    # @btime $m.out_conv($attn+$query)
    @timeit to "cmp_att_22" m.out_conv(attn + query)
end

function (m::SelfAttentionSimple)(xin, xout, xmask)
    h, w, c, b_num= size(xin)
    attention_score = compute_attention(m, xin, xout, xmask)
    attention_score = reshape(attention_score, (h, w, 1, b_num))
    return xout, Flux.sigmoid.(attention_score)
end
Flux.@functor SelfAttentionSimple

struct SMRBlock{C<:Conv, U<:UpConv, S<:SelfAttentionSimple}
    upconv::U
    primary_mask::C
    self_calibrated::S
end

SMRBlock(ins, outs, k_center; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true) = SMRBlock(
    UpConv(ins, outs, blocks; residual=residual, concat=concat, norm=norm, act=act),
    Conv((1,1), outs=>1, Flux.sigmoid_fast; stride=1, pad=0, bias=true),
    SelfAttentionSimple(outs, k_center)
)

function (m::SMRBlock)(input, encoder_outs=nothing)
    # @btime $m.upconv(OutFuse(true), $input, $encoder_outs)
    mask_x, fuse_x = @timeit to "4.1" m.upconv(OutFuse(true), input, encoder_outs)
    # @btime $m.primary_mask($mask_x)
    primary_mask = @timeit to "4.2" m.primary_mask(mask_x)
    # # @btime $m.self_calibrated($mask_x, $mask_x, $primary_mask)
    mask_x, self_calibrated_mask = @timeit to "4.3" m.self_calibrated(mask_x, mask_x, primary_mask)
    return Dict(
        "feats"=>[mask_x],
        "attn_maps"=>[primary_mask, self_calibrated_mask]
    )
end
Flux.@functor SMRBlock

struct CoarseEncoder{D<:DownConv}
    down_convs::Vector{D}
end
function CoarseEncoder(in_channels::Int=3, depth::Int=3; blocks=1, start_filters=32, residual=true, norm=Flux.BatchNorm, act=Flux.relu)
    down_convs = []
    outs = nothing
    if isa(blocks, Tuple)
        blocks = blocks[1]  # Julia arrays are 1-indexed
    end

    for i in 1:depth
        ins = i==1 ? in_channels : outs
        outs = start_filters*(2^(i-1))
        pooling = true
        # #@show ins, depth
        push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act))
    end
    CoarseEncoder{DownConv}(down_convs)
end

function (m::CoarseEncoder)(x)
    nx = x
    encoder_outs = Vector{typeof(x)}()
    for d_conv in m.down_convs
        nx, before_pool = d_conv(nx)
        push!(encoder_outs, before_pool)
    end
    nx, encoder_outs
end
Flux.@functor CoarseEncoder

struct SharedBottleNeck{U<:UpConv, D<:DownConv, ECA1<:ECABlocks, ECA2<:ECABlocks}
    up_convs::Vector{U}
    down_convs::Vector{D}
    up_im_atts::Vector{ECA1}
    up_mask_atts::Vector{ECA2}
end
function SharedBottleNeck(in_channels=512, depth=5, shared_depth=2; start_filters=32, blocks=1, residual=true, concat=true, norm=Flux.BatchNorm, act=Flux.relu, dilations=[1,2,5])
    start_depth = depth - shared_depth
    max_filters = 512
    down_convs = Vector{DownConv}()
    up_convs = Vector{UpConv}()
    up_im_atts = Vector{ECABlocks}()
    up_mask_atts = Vector{ECABlocks}()
    outs = 0
    for i in start_depth:depth-1
        # #@show i, start_depth
        ins = i == start_depth ? in_channels : outs 
        outs = min(ins*2, max_filters)
        # encoder convs
        pooling = i<depth-1 ? true : false
        push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act, dilations=dilations))
        # decoder convs
        if i < depth - 1
            up_conv = UpConv(min(outs*2, max_filters), outs, blocks, residual=residual, concat=concat, norm=norm, act=Flux.relu, dilations=dilations)
            push!(up_convs, up_conv)
            push!(up_im_atts, ECABlocks(outs))
            push!(up_mask_atts, ECABlocks(outs))
        end
    end
    @show eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)
    SharedBottleNeck{eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)}(up_convs, down_convs, up_im_atts, up_mask_atts)
end

function (m::SharedBottleNeck)(input)
    # encoder convs
    im_encoder_outs = Vector{typeof(input)}()
    mask_encoder_outs = Vector{typeof(input)}()
    x = input
    for (i, d_conv) in enumerate(m.down_convs)
        x, before_pool = d_conv(x)
        push!(im_encoder_outs, before_pool)
        push!(mask_encoder_outs, before_pool)
    end
    x_im = x

    x_mask = x
    #@show size(x_mask)
    # Decoder convs
    x = x_im
    for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_im_atts))) in enumerate(zip(m.up_convs, m.up_im_atts))
        before_pool = im_encoder_outs === nothing ? nothing : im_encoder_outs[end-i]
        x = up_conv(OutFuse(false), x, before_pool, se=attn)
    end
    x_im = x

    x = x_mask
    for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_mask_atts))) in enumerate(zip(m.up_convs, m.up_mask_atts))
        before_pool = mask_encoder_outs === nothing ? nothing : mask_encoder_outs[end-i]
        x = up_conv(OutFuse(false), x, before_pool, se=attn)
    end
    x_mask = x

    x_im, x_mask
end
Flux.@functor SharedBottleNeck

struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
    up_convs_bg::Vector{MBE}
    up_convs_mask::Vector{SMR}
    atts_mask::Vector{ECA}
    atts_bg::Vector{ECA}
    conv_final_bg::C
    use_att::Bool
end

function CoarseDecoder(in_channels=512, out_channels=3, k_center=2; norm=Flux.BatchNorm, act=Flux.relu, depth=5, blocks=1, residual=true, concat=true, use_att=false)
    up_convs_bg = Vector{MBEBlock}()
    up_convs_mask = Vector{SMRBlock}()
    atts_bg = Vector{ECABlocks}()
    atts_mask = Vector{ECABlocks}()
    outs = in_channels
    for i in 1:depth
        ins = outs 
        outs = ins รท 2
        # background reconstruction branch
        up_conv = MBEBlock(ins, outs; blocks=blocks, residual=residual, concat=concat, norm=norm, act=act)
        push!(up_convs_bg, up_conv)
        if use_att
            push!(atts_bg, ECABlocks(outs))
        end
        #mask prediction branch
        up_conv = SMRBlock(ins, outs, k_center; norm=norm, act=act, blocks=blocks, residual=residual, concat=concat)
        push!(up_convs_mask, up_conv)
        if use_att
            push!(atts_mask, ECABlocks(outs))
        end
    end
    conv_final_bg = Conv((1,1), outs=>out_channels, stride=1, pad=0, bias=true)
    CoarseDecoder{typeof(conv_final_bg),eltype(up_convs_bg),eltype(up_convs_mask), eltype(atts_mask)}(up_convs_bg, up_convs_mask, atts_mask, atts_bg, 
        conv_final_bg,
        use_att
    )
end

function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
    bg_x = bg
    mask_x = mask
    mask_outs = Vector{T}()
    bg_outs = Vector{T}()
    for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
        # @btime before_pool = $encoder_outs[end-($i-1)]
        before_pool = @timeit to "encoder_outs1" encoder_outs[end-(i-1)] #encoder_outs===nothing ? nothing : encoder_outs[end-(i-1)]
        if m.use_att
            # @btime $m.atts_mask[$i]($before_pool)
            mask_before_pool = @timeit to "atts_mask2" m.atts_mask[i](before_pool)
            # @btime $m.atts_bg[$i]($before_pool)
            bg_before_pool = @timeit to "atts_bg3" m.atts_bg[i](before_pool)
        end
        # @btime $up_mask($mask_x,$mask_before_pool)
        # @code_warntype up_mask(mask_x, mask_before_pool)
        smr_outs = @timeit to "up_mask4" up_mask(mask_x, mask_before_pool)
        # @btime mask_x = $smr_outs["feats"][1]
        mask_x = @timeit to "smr_outs5" smr_outs["feats"][1]
        # @btime primary_map, self_calibrated_map = $smr_outs["attn_maps"]
        primary_map, self_calibrated_map = @timeit to "smr_outs6" smr_outs["attn_maps"]
        # @btime push!($mask_outs, $primary_map)
        @timeit to "push7" push!(mask_outs, primary_map)
        # @btime push!($mask_outs, $self_calibrated_map)
        @timeit to "push8" push!(mask_outs, self_calibrated_map)
        # @btime $up_bg($bg_x, $bg_before_pool, $self_calibrated_map)
        bg_x = @timeit to "up_bg9" up_bg(bg_x, bg_before_pool, self_calibrated_map) # there may be a problem here
        # @btime push!($bg_outs, $bg_x)
        @timeit to "push10" push!(bg_outs, bg_x)
    end
    if m.conv_final_bg !== nothing
        # @btime $m.conv_final_bg($bg_x)
        bg_x = @timeit to "conv_final_bg11" m.conv_final_bg(bg_x)
        # @btime push!($bg_outs, $bg_x)
        @timeit to "push12" push!(bg_outs, bg_x)
    end
    #@show length(bg_outs)
    #@show length(mask_outs)
    return bg_outs, mask_outs, nothing
end
Flux.@functor CoarseDecoder

struct Refinement{C<:Conv, CH1<:Chain, CH2<:Chain, CH3<:Chain, CH4<:Chain,D1<:DownConv,D2<:DownConv,D3<:DownConv,CFF<:CFFBlock}
    conv_in::CH1
    dec_conv2::C
    dec_conv3::CH2
    dec_conv4::CH3
    down1::D1
    down2::D2
    down3::D3
    cff_blocks::Vector{CFF}
    out_conv::CH4
    n_skips::Int
end

Refinement(;in_channels=3, out_channels=3, shared_depth=2, down=DownConv, up=UpConv, ngf=32, n_cff=3, n_skips=3) = Refinement(
    Chain(
        Conv((3,3), in_channels=>ngf; stride=1, pad=1, bias=true),
        Flux.InstanceNorm(ngf),
        x->Flux.leakyrelu.(x, eltype(x)(0.2))
    ),
    Conv((1,1), ngf=>ngf; stride=1, pad=0, bias=true),
    Chain(
        Conv((1,1), ngf*2=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
        Conv((3,3), ngf=>ngf, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
    ),
    Chain(
        Conv((1,1), ngf*4=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=0, bias=true),
        Conv((3,3), ngf*2=>ngf*2, x->Flux.leakyrelu(x, eltype(x)(0.2)); stride=1, pad=1, bias=true)
    ),
    down(ngf, ngf, 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[]),
    down(ngf, ngf*2, 3, pooling=true, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[]),
    down(ngf*2, ngf*4, 3, pooling=false, norm=Flux.InstanceNorm, act=x->Flux.leakyrelu(x, eltype(x)(0.01)), dilations=[1,2,5]),
    [CFFBlock(;ngf=ngf) for i in 1:n_cff],
    Chain(
        Conv((3,3), (ngf+ngf*2+ngf*4)=>ngf; stride=1, pad=1, bias=true),
        Flux.InstanceNorm(ngf),
        x->Flux.leakyrelu.(x, eltype(x)(0.2)),
        Conv((1,1), ngf=>out_channels; stride=1, pad=0)
    ),
    n_skips
)

function (m::Refinement)(input, coarse_bg, mask, encoder_outs, decoder_outs) 
    xin = cat(coarse_bg, mask, dims=Val(3))
    # @btime $m.conv_in($xin)
    x = m.conv_in(xin)
    # @btime $m.dec_conv2($decoder_outs[1])
    m.n_skips < 1 && (x += m.dec_conv2(decoder_outs[1]))
    # @btime $m.down1($x)
    x,d1 = m.down1(x)
    # @btime $m.dec_conv3($decoder_outs[2])
    m.n_skips < 2 && (x += m.dec_conv3(decoder_outs[2]))
    # @btime $m.down2($x)
    x,d2 = m.down2(x)
    # @btime $m.dec_conv4($decoder_outs[3])
    m.n_skips < 3 && (x += m.dec_conv4(decoder_outs[3]))
    # @btime $m.down3($x)
    x,d3 = m.down3(x)

    xs = [d1, d2, d3]
    for block in m.cff_blocks
        # @btime $block($xs)
        xs = block(xs)
    end
    # @btime [Flux.upsample_bilinear(x_hr; size=(size($coarse_bg)[2], size($coarse_bg)[1])) for x_hr in $xs]
    xs = [Flux.upsample_bilinear(x_hr; size=(size(coarse_bg)[2], size(coarse_bg)[1])) for x_hr in xs]
    # @btime $m.out_conv(cat($xs..., dims=3))
    im = m.out_conv(cat(xs..., dims=Val(3)))
    im
end
Flux.@functor Refinement

struct SLBR
    encoder::CoarseEncoder
    shared_decoder::SharedBottleNeck
    coarse_decoder::CoarseDecoder
    refinement::Refinement
    long_skip::Bool
end

SLBR(; in_channels=3, depth=5, shared_depth=2, blocks=[1 for i in 1:5], out_channels_image=3, out_channels_mask=1, start_filters=32, residual=true, concat=true, long_skip=false, k_refine=3, n_skips=3, k_center=3) = SLBR(
    CoarseEncoder(in_channels, depth-shared_depth; blocks=blocks[1], start_filters=start_filters, residual=residual, norm=Flux.BatchNorm, act=Flux.relu),
    SharedBottleNeck(start_filters*2^(depth-shared_depth-1), depth, shared_depth; blocks=blocks[5], residual=residual, concat=concat, norm=Flux.InstanceNorm),
    CoarseDecoder(start_filters*2^(depth-shared_depth), out_channels_image, k_center; depth=depth-shared_depth, blocks=blocks[2], residual=residual, concat=concat, norm=Flux.BatchNorm, use_att=true),
    Refinement(; in_channels=4, out_channels=3, shared_depth=1, n_cff=k_refine, n_skips=n_skips),
    long_skip
)

function (m::SLBR)(synthesized)
    # @btime $m.encoder($synthesized) # (type stablity)-> 970.384 ฮผs (2239 allocations: 109.50 KiB)
    image_code, before_pool = m.encoder(synthesized)
    unshared_before_pool = before_pool

    # @code_warntype m.shared_decoder(image_code) # 2.848 ms (7389 allocations: 467.92 KiB) (type stablity)-> 2.858 ms (7309 allocations: 464.67 KiB)
    # @btime $m.shared_decoder($image_code)
    im, mask = m.shared_decoder(image_code)
    # @btime $m.coarse_decoder($im, nothing, $mask, $unshared_before_pool) # 233.433 ms (24298 allocations: 1.17 MiB) (type stablity)->234.892 ms (23475 allocations: 1.15 MiB)
    # @btime $m.coarse_decoder($im, nothing, $mask, $unshared_before_pool)
    ims, mask, wm = m.coarse_decoder(im, nothing, mask, unshared_before_pool)
    
    im = ims[end]

    reconstructed_image = Flux.tanh_fast.(im)

    if m.long_skip
        reconstructed_image = reconstructed_image + synthesized
        reconstructed_image = clamp.(reconstructed_image, zero(eltype(reconstructed_image)), one(eltype(reconstructed_image)))
    end
    reconstructed_mask = mask[end]
    reconstruct_wm = wm

    dec_feats = reverse(ims[1:end-1])
    #@show eltype(reconstructed_image)
    #@show eltype(reconstructed_mask)
    #@show eltype(synthesized)
    coarser = reconstructed_image .* reconstructed_mask + (one(eltype(reconstructed_mask)) .- reconstructed_mask) .* synthesized
    # @btime m.refinement($synthesized, $coarser, $reconstructed_mask, nothing, $dec_feats) # 664.924 ms (24094 allocations: 1.45 MiB)
    refine_bg = m.refinement(synthesized, coarser, reconstructed_mask, nothing, dec_feats)

    refine_bg = clamp.(Flux.tanh_fast.(refine_bg ) + synthesized, zero(eltype(refine_bg)), one(eltype(refine_bg)))

    return [refine_bg, reconstructed_image], mask, [reconstruct_wm]
end
Flux.@functor SLBR

m = SLBR(; shared_depth=2, blocks=[3 for i in 1:5], long_skip=true, k_center=2) |> gpu

# BSON.@save "model.bson" m

x = rand(Float32, 720,720, 3, 1) |> gpu

# @btime m(x) 
out = m(x);
reset_timer!(to)
out = m(x);
@show length(out)
@show size(out[1][1])
@show size(out[2][end])

show(to)

The output from TimerOutputs looks unstable:

julia> include("slbr.jl")
WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.
(eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)) = (UpConv, DownConv, ECABlocks, ECABlocks)
length(out) = 3
size((out[1])[1]) = (720, 720, 3, 1)
size((out[2])[end]) = (720, 720, 1, 1)
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
                                    Time                    Allocations      
                           โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€   โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
     Tot / % measured:          1.31s /  32.5%           4.16MiB /  35.8%    

 Section           ncalls     time    %tot     avg     alloc    %tot      avg
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 up_bg9                 3    224ms   52.3%  74.6ms    614KiB   40.2%   205KiB
 up_mask4               3    203ms   47.5%  67.7ms    809KiB   53.0%   270KiB
   4.3                  3    201ms   47.0%  67.1ms    595KiB   39.0%   198KiB
     cmp_att_8          3    197ms   46.0%  65.7ms   92.2KiB    6.0%  30.7KiB
     cmp_att_11         6    497ฮผs    0.1%  82.9ฮผs   39.6KiB    2.6%  6.60KiB
     cmp_att_22         3    497ฮผs    0.1%   166ฮผs   50.2KiB    3.3%  16.7KiB
     cmp_att_2          3    470ฮผs    0.1%   157ฮผs   19.4KiB    1.3%  6.46KiB
     cmp_att_15         3    399ฮผs    0.1%   133ฮผs   19.4KiB    1.3%  6.46KiB
     cmp_att_1          3    325ฮผs    0.1%   108ฮผs   19.3KiB    1.3%  6.43KiB
     cmp_att_3          3    320ฮผs    0.1%   107ฮผs   26.5KiB    1.7%  8.84KiB
     cmp_att_10         6    318ฮผs    0.1%  52.9ฮผs   78.8KiB    5.2%  13.1KiB
     cmp_att_16         3    278ฮผs    0.1%  92.6ฮผs   42.6KiB    2.8%  14.2KiB
     cmp_att_17         3    222ฮผs    0.1%  73.9ฮผs   49.8KiB    3.3%  16.6KiB
     cmp_att_21         3    160ฮผs    0.0%  53.3ฮผs   9.23KiB    0.6%  3.08KiB
     cmp_att_13         3    124ฮผs    0.0%  41.4ฮผs   20.7KiB    1.4%  6.90KiB
     cmp_att_9          3    119ฮผs    0.0%  39.6ฮผs   22.8KiB    1.5%  7.59KiB
     cmp_att_18         3   99.0ฮผs    0.0%  33.0ฮผs   29.5KiB    1.9%  9.84KiB
     cmp_att_5          3   98.4ฮผs    0.0%  32.8ฮผs   19.0KiB    1.2%  6.34KiB
     cmp_att_14         3   54.1ฮผs    0.0%  18.0ฮผs   8.33KiB    0.5%  2.78KiB
     cmp_att_4          3   46.3ฮผs    0.0%  15.4ฮผs   9.14KiB    0.6%  3.05KiB
     cmp_att_7          3   45.4ฮผs    0.0%  15.1ฮผs   12.9KiB    0.8%  4.31KiB
     cmp_att_6          3   39.9ฮผs    0.0%  13.3ฮผs   6.75KiB    0.4%  2.25KiB
     cmp_att_12         6   5.87ฮผs    0.0%   978ns      240B    0.0%    40.0B
     cmp_att_20         3   2.36ฮผs    0.0%   788ns      192B    0.0%    64.0B
   4.1                  3   1.84ms    0.4%   615ฮผs    190KiB   12.4%  63.3KiB
     4.1.4              9    536ฮผs    0.1%  59.6ฮผs   57.9KiB    3.8%  6.43KiB
     4.1.1              3    293ฮผs    0.1%  97.7ฮผs   29.5KiB    1.9%  9.84KiB
     4.1.5              9    291ฮผs    0.1%  32.4ฮผs   8.02KiB    0.5%     912B
     4.1.3              3    250ฮผs    0.1%  83.2ฮผs   23.5KiB    1.5%  7.82KiB
     4.1.6              9    136ฮผs    0.0%  15.1ฮผs   4.36KiB    0.3%     496B
     4.1.2              3   98.5ฮผs    0.0%  32.8ฮผs   30.1KiB    2.0%  10.0KiB
   4.2                  3    152ฮผs    0.0%  50.7ฮผs   19.8KiB    1.3%  6.60KiB
 atts_mask2             3    327ฮผs    0.1%   109ฮผs   48.8KiB    3.2%  16.3KiB
 atts_bg3               3    251ฮผs    0.1%  83.7ฮผs   48.8KiB    3.2%  16.3KiB
 conv_final_bg11        1   48.8ฮผs    0.0%  48.8ฮผs   6.39KiB    0.4%  6.39KiB
 smr_outs6              3   3.71ฮผs    0.0%  1.24ฮผs     0.00B    0.0%    0.00B
 encoder_outs1          3   3.39ฮผs    0.0%  1.13ฮผs     0.00B    0.0%    0.00B
 push8                  3   2.99ฮผs    0.0%   996ns     0.00B    0.0%    0.00B
 smr_outs5              3   2.36ฮผs    0.0%   785ns     0.00B    0.0%    0.00B
 push7                  3   1.85ฮผs    0.0%   617ns     80.0B    0.0%    26.7B
 push10                 3   1.27ฮผs    0.0%   425ns     80.0B    0.0%    26.7B
 push12                 1    201ns    0.0%   201ns     0.00B    0.0%    0.00B
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
julia> include("slbr.jl")
WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.
(eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)) = (UpConv, DownConv, ECABlocks, ECABlocks)
length(out) = 3
size((out[1])[1]) = (720, 720, 3, 1)
size((out[2])[end]) = (720, 720, 1, 1)
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
                                    Time                    Allocations      
                           โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€   โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
     Tot / % measured:          1.12s /  44.8%           4.14MiB /  36.0%    

 Section           ncalls     time    %tot     avg     alloc    %tot      avg
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 up_mask4               3    495ms   98.9%   165ms    844KiB   55.4%   281KiB
   4.1                  3    263ms   52.7%  87.8ms    230KiB   15.1%  76.6KiB
     4.1.4              9    262ms   52.5%  29.1ms   97.9KiB    6.4%  10.9KiB
     4.1.1              3    316ฮผs    0.1%   105ฮผs   29.5KiB    1.9%  9.84KiB
     4.1.3              3    261ฮผs    0.1%  87.0ฮผs   23.5KiB    1.5%  7.82KiB
     4.1.5              9    118ฮผs    0.0%  13.1ฮผs   8.02KiB    0.5%     912B
     4.1.2              3    104ฮผs    0.0%  34.7ฮผs   30.1KiB    2.0%  10.0KiB
     4.1.6              9   85.2ฮผs    0.0%  9.47ฮผs   4.36KiB    0.3%     496B
   4.3                  3    231ms   46.2%  77.0ms    590KiB   38.7%   197KiB
     cmp_att_15         3    227ms   45.5%  75.8ms   38.0KiB    2.5%  12.7KiB
     cmp_att_22         3    680ฮผs    0.1%   227ฮผs   50.2KiB    3.3%  16.7KiB
     cmp_att_11         6    408ฮผs    0.1%  68.1ฮผs   39.6KiB    2.6%  6.60KiB
     cmp_att_10         6    301ฮผs    0.1%  50.2ฮผs   78.8KiB    5.2%  13.1KiB
     cmp_att_8          3    276ฮผs    0.1%  91.9ฮผs   68.5KiB    4.5%  22.8KiB
     cmp_att_16         3    257ฮผs    0.1%  85.6ฮผs   42.6KiB    2.8%  14.2KiB
     cmp_att_17         3    255ฮผs    0.1%  84.9ฮผs   49.8KiB    3.3%  16.6KiB
     cmp_att_2          3    222ฮผs    0.0%  74.1ฮผs   19.4KiB    1.3%  6.46KiB
     cmp_att_1          3    202ฮผs    0.0%  67.3ฮผs   19.3KiB    1.3%  6.43KiB
     cmp_att_3          3    143ฮผs    0.0%  47.7ฮผs   26.5KiB    1.7%  8.84KiB
     cmp_att_18         3    134ฮผs    0.0%  44.8ฮผs   29.5KiB    1.9%  9.84KiB
     cmp_att_13         3    118ฮผs    0.0%  39.3ฮผs   20.7KiB    1.4%  6.90KiB
     cmp_att_5          3    107ฮผs    0.0%  35.8ฮผs   19.0KiB    1.2%  6.34KiB
     cmp_att_9          3    103ฮผs    0.0%  34.3ฮผs   22.8KiB    1.5%  7.59KiB
     cmp_att_21         3   94.2ฮผs    0.0%  31.4ฮผs   9.23KiB    0.6%  3.08KiB
     cmp_att_14         3   48.3ฮผs    0.0%  16.1ฮผs   8.33KiB    0.5%  2.78KiB
     cmp_att_7          3   47.3ฮผs    0.0%  15.8ฮผs   12.9KiB    0.8%  4.31KiB
     cmp_att_4          3   47.0ฮผs    0.0%  15.7ฮผs   9.14KiB    0.6%  3.05KiB
     cmp_att_6          3   43.6ฮผs    0.0%  14.5ฮผs   6.75KiB    0.4%  2.25KiB
     cmp_att_12         6   26.4ฮผs    0.0%  4.40ฮผs      240B    0.0%    40.0B
     cmp_att_20         3   2.70ฮผs    0.0%   899ns      192B    0.0%    64.0B
   4.2                  3    168ฮผs    0.0%  56.0ฮผs   19.8KiB    1.3%  6.60KiB
 up_bg9                 3   4.67ms    0.9%  1.56ms    576KiB   37.8%   192KiB
 atts_mask2             3    321ฮผs    0.1%   107ฮผs   48.8KiB    3.2%  16.3KiB
 atts_bg3               3    250ฮผs    0.0%  83.3ฮผs   48.8KiB    3.2%  16.3KiB
 conv_final_bg11        1   74.1ฮผs    0.0%  74.1ฮผs   6.39KiB    0.4%  6.39KiB
 encoder_outs1          3   4.27ฮผs    0.0%  1.42ฮผs     0.00B    0.0%    0.00B
 smr_outs5              3   2.71ฮผs    0.0%   902ns     0.00B    0.0%    0.00B
 smr_outs6              3   1.81ฮผs    0.0%   602ns     0.00B    0.0%    0.00B
 push8                  3   1.79ฮผs    0.0%   597ns     0.00B    0.0%    0.00B
 push7                  3   1.76ฮผs    0.0%   586ns     80.0B    0.0%    26.7B
 push10                 3   1.74ฮผs    0.0%   580ns     80.0B    0.0%    26.7B
 push12                 1    299ns    0.0%   299ns     0.00B    0.0%    0.00B
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
julia> include("slbr.jl")
WARNING: redefinition of constant to. This may fail, cause incorrect answers, or produce other errors.
(eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)) = (UpConv, DownConv, ECABlocks, ECABlocks)
length(out) = 3
size((out[1])[1]) = (720, 720, 3, 1)
size((out[2])[end]) = (720, 720, 1, 1)
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
                                    Time                    Allocations      
                           โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€   โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
     Tot / % measured:          1.41s /  35.4%           4.16MiB /  35.8%    

 Section           ncalls     time    %tot     avg     alloc    %tot      avg
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 up_bg9                 3    255ms   51.3%  85.0ms    616KiB   40.3%   205KiB
 up_mask4               3    242ms   48.6%  80.6ms    807KiB   52.9%   269KiB
   4.3                  3    240ms   48.3%  80.1ms    593KiB   38.9%   198KiB
     cmp_att_9          3    236ms   47.5%  78.8ms   44.5KiB    2.9%  14.8KiB
     cmp_att_11         6    596ฮผs    0.1%  99.4ฮผs   39.6KiB    2.6%  6.60KiB
     cmp_att_22         3    549ฮผs    0.1%   183ฮผs   50.2KiB    3.3%  16.7KiB
     cmp_att_10         6    383ฮผs    0.1%  63.8ฮผs   78.8KiB    5.2%  13.1KiB
     cmp_att_15         3    279ฮผs    0.1%  92.8ฮผs   19.4KiB    1.3%  6.46KiB
     cmp_att_8          3    256ฮผs    0.1%  85.4ฮผs   68.5KiB    4.5%  22.8KiB
     cmp_att_17         3    249ฮผs    0.0%  82.9ฮผs   49.8KiB    3.3%  16.6KiB
     cmp_att_16         3    235ฮผs    0.0%  78.2ฮผs   42.6KiB    2.8%  14.2KiB
     cmp_att_2          3    204ฮผs    0.0%  68.0ฮผs   19.4KiB    1.3%  6.46KiB
     cmp_att_1          3    196ฮผs    0.0%  65.3ฮผs   19.3KiB    1.3%  6.43KiB
     cmp_att_13         3    135ฮผs    0.0%  44.9ฮผs   20.7KiB    1.4%  6.90KiB
     cmp_att_3          3    121ฮผs    0.0%  40.2ฮผs   26.5KiB    1.7%  8.84KiB
     cmp_att_18         3    115ฮผs    0.0%  38.3ฮผs   29.5KiB    1.9%  9.84KiB
     cmp_att_5          3    101ฮผs    0.0%  33.5ฮผs   19.0KiB    1.2%  6.34KiB
     cmp_att_21         3   86.4ฮผs    0.0%  28.8ฮผs   9.23KiB    0.6%  3.08KiB
     cmp_att_14         3   56.6ฮผs    0.0%  18.9ฮผs   8.33KiB    0.5%  2.78KiB
     cmp_att_4          3   47.1ฮผs    0.0%  15.7ฮผs   9.14KiB    0.6%  3.05KiB
     cmp_att_7          3   45.0ฮผs    0.0%  15.0ฮผs   12.9KiB    0.8%  4.31KiB
     cmp_att_6          3   39.8ฮผs    0.0%  13.3ฮผs   6.75KiB    0.4%  2.25KiB
     cmp_att_12         6   4.62ฮผs    0.0%   770ns      240B    0.0%    40.0B
     cmp_att_20         3   2.62ฮผs    0.0%   872ns      192B    0.0%    64.0B
   4.1                  3   1.51ms    0.3%   504ฮผs    190KiB   12.4%  63.3KiB
     4.1.4              9    552ฮผs    0.1%  61.3ฮผs   57.9KiB    3.8%  6.43KiB
     4.1.1              3    313ฮผs    0.1%   104ฮผs   29.5KiB    1.9%  9.84KiB
     4.1.3              3    257ฮผs    0.1%  85.6ฮผs   23.5KiB    1.5%  7.82KiB
     4.1.2              3    115ฮผs    0.0%  38.4ฮผs   30.1KiB    2.0%  10.0KiB
     4.1.5              9   89.6ฮผs    0.0%  10.0ฮผs   8.02KiB    0.5%     912B
     4.1.6              9   69.3ฮผs    0.0%  7.70ฮผs   4.36KiB    0.3%     496B
   4.2                  3    147ฮผs    0.0%  49.0ฮผs   19.8KiB    1.3%  6.60KiB
 atts_mask2             3    335ฮผs    0.1%   112ฮผs   48.8KiB    3.2%  16.3KiB
 atts_bg3               3    263ฮผs    0.1%  87.6ฮผs   48.8KiB    3.2%  16.3KiB
 conv_final_bg11        1   60.8ฮผs    0.0%  60.8ฮผs   6.39KiB    0.4%  6.39KiB
 encoder_outs1          3   3.56ฮผs    0.0%  1.19ฮผs     0.00B    0.0%    0.00B
 push7                  3   3.25ฮผs    0.0%  1.08ฮผs     80.0B    0.0%    26.7B
 push8                  3   3.03ฮผs    0.0%  1.01ฮผs     0.00B    0.0%    0.00B
 smr_outs5              3   2.95ฮผs    0.0%   984ns     0.00B    0.0%    0.00B
 smr_outs6              3   2.24ฮผs    0.0%   746ns     0.00B    0.0%    0.00B
 push10                 3   1.54ฮผs    0.0%   514ns     80.0B    0.0%    26.7B
 push12                 1    223ns    0.0%   223ns     0.00B    0.0%    0.00B
 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
julia> 

To make the separate lines work together, the function may be doing some extra work between them.

That seems to be true: the output from TimerOutputs changes from run to run.
But why, and how can I fix it?
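One thing that might explain the wandering attribution (an assumption on my part, not verified against this model): GPU kernels in CUDA.jl launch asynchronously, so a `@timeit` section can close before its kernel has actually finished, and the cost lands in whichever later section happens to synchronize first. A minimal CPU-only sketch of the same effect, using a `Threads.@spawn` task as a stand-in for an asynchronous kernel launch:

```julia
using TimerOutputs

const to_demo = TimerOutput()  # separate timer, just for this sketch

# `launch` returns immediately, like an asynchronous GPU kernel launch;
# the real work only finishes when the task is fetched.
launch() = Threads.@spawn (sleep(0.2); 42)

t = @timeit to_demo "launch" launch()   # section closes before the work is done
v = @timeit to_demo "fetch" fetch(t)    # the ~0.2 s gets charged to this section

# Synchronizing inside the timed section attributes the cost correctly:
v2 = @timeit to_demo "launch+sync" fetch(launch())
```

If this is the cause, forcing synchronization inside each timed section (e.g. `CUDA.@sync` around each layer call, analogous to the fetch-inside-the-timer above) should make the per-section numbers stable, at the price of losing any kernel overlap.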

I personally do not know; your code is quite large and Iโ€™m not familiar with Flux. The reason I asked about your struct CoarseDecoder was to check the type annotations, and you do take care to annotate every field, so they have a good chance of being concretely typed. To verify that type inference is going your way, you could use @code_warntype on a test method and check whether all the variables are inferred. The TimerOutputs results do seem to vary at a glance, but hopefully they at least narrow things down to the few methods that take the most time; focus your efforts on those.
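To make the @code_warntype suggestion concrete, here is a minimal sketch of checking inference on a callable struct; the `Scale` type is hypothetical and just stands in for CoarseDecoder:

```julia
using InteractiveUtils  # provides @code_warntype
using Test              # provides @inferred

# Hypothetical callable struct with a concretely typed field,
# analogous in shape to (m::CoarseDecoder)(...)
struct Scale{T}
    w::T
end
(s::Scale)(x) = s.w .* x

s = Scale(2.0f0)
@code_warntype s(rand(Float32, 4))  # non-concrete results print as `Any` (red in the REPL)
@inferred s(rand(Float32, 4))       # throws if the return type is not concretely inferred
```

`@code_warntype` is a visual check, while `Test.@inferred` gives a pass/fail answer: it errors if the inferred return type is abstract, which makes it handy for catching regressions in a test suite.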