Help extending a model from Metalhead.jl

I implemented a self-supervised learning algorithm, MAE (Masked Auto Encoder), based on Flux.jl and Metalhead.jl. It seems to work, at least for the forward pass, but when I compute gradients, Zygote complains that an array is being mutated somewhere in my code.

My code is as follows:

The decoder and its constructor:

"""
    MaskDecoder(projection, mask_token, decoder)

Container for the three parts of the MAE decoder.

# Fields

- `projection`: layer applied to the encoder output before decoding.
- `mask_token`: array that is repeated to stand in for the masked patches.
- `decoder`: layer (or chain of layers) applied to the re-ordered token sequence.

Each field type is a type parameter, so an instance is concretely typed instead of
holding three `Any` fields.
"""
struct MaskDecoder{P,M,D}
    projection::P
    mask_token::M
    decoder::D
end

# Apply `@functor` to MaskDecoder (presumably the Functors.jl macro re-exported by Flux,
# which lets Flux traverse the struct's fields as model parameters — confirm).
@functor MaskDecoder
"""
    MaskDecoder(
        imsize::NTuple{N,Integer},
        patch_size::NTuple{N,Integer},
        encodplanes::Integer,
        pretrain::Bool=false,
        outchannels::Integer=3;
        embedplanes::Integer=512,
        depth::Integer=8, nheads::Integer=16
    )

 A decoder for Masked Auto Encoder

# Arguments

- `imsize`: the input image size.
- `patch_size`: the size of the patch.
- `encodplanes`: the number of channels of ViT encoder's output.
- `pretrain`: whether to load pretrained parameters. No pretrained weights exist for
  the decoder, so passing `true` throws an `ErrorException`.
- `outchannels`: the number of channels in the recovered image.
- `embedplanes`: the number of channels projected to decoder transformer.
- `depth`: the number of blocks in decoder transformer.
- `nheads`: number of attention heads in transformer.
"""
function MaskDecoder(
    imsize::NTuple{N,Integer},
    patch_size::NTuple{N,Integer},
    encodplanes::Integer,
    pretrain::Bool=false,
    outchannels::Integer=3;
    embedplanes::Integer=512,
    depth::Integer=8, nheads::Integer=16
) where {N}
    # Reject the unsupported option before any layer is allocated.
    if pretrain
        throw(ErrorException("MAE Decoder do not have the pretrain paramters."))
    end
    # One position per patch, plus one for the class token.
    npatches = prod(imsize .÷ patch_size) + 1
    return MaskDecoder(
        Dense(encodplanes, embedplanes),
        zeros32(embedplanes, 1, 1), # the <MSK> token
        Chain(
            ViPosEmbedding(embedplanes, npatches),
            transformer_encoder(embedplanes, depth, nheads),
            # each token is mapped to the pixels of one patch
            Dense(embedplanes, prod(patch_size) * outchannels),
        ),
    )
end

The decoder's forward pass:

# Forward pass of the decoder.
#
# `encode` is the encoder output (class token first, then the kept patch tokens) and
# `ids_restore` indexes the sequence "kept tokens followed by mask tokens" so that every
# token ends up at its original patch position. Returns the output of `md.decoder` on
# the class token followed by the re-ordered patch/mask tokens.
function (md::MaskDecoder)(encode, ids_restore)
    projected = md.projection(encode) # still starts with the CLS token
    _, nkept, batch = size(projected) # nkept = 1 + number of kept patch tokens
    # The positional-embedding table of the decoder has 1 + P columns (CLS + all patches).
    ntotal = size(md.decoder[1].vectors, 2)
    fillers = repeat(md.mask_token, 1, ntotal - nkept, batch)
    cls = projected[:, 1:1, :]
    # kept patch tokens (CLS removed) followed by the mask tokens
    unsorted = cat(projected[:, 2:end, :], fillers; dims=2)
    # <CLS> <tokens> <masks> -> <CLS> <mask and token in right position>
    sorted = cat(cls, unsorted[:, ids_restore]; dims=2)
    return md.decoder(sorted) # P*P*C,1+P,B
end

The MAE model definition:

It follows the conventions used in Metalhead.jl.

"""
    MAE(encoder, decoder, patch_size::NTuple{N})

Masked Auto Encoder: an encoder, a decoder and the patch size shared by both.

# Fields

- `encoder`: model that embeds the visible patches.
- `decoder`: model that reconstructs all patches from the encoder output.
- `patch_size`: size of one patch, one entry per spatial dimension (`N` entries).

`N` stays the first type parameter, so `MAE{N}` keeps its meaning; the remaining
parameters only record the concrete field types.
"""
struct MAE{N,E,D,T}
    encoder::E
    decoder::D
    patch_size::NTuple{N,T}
end

# Keep `MAE{N}(encoder, decoder, patch_size)` callable now that the struct carries
# additional type parameters.
function MAE{N}(encoder::E, decoder::D, patch_size::NTuple{N,T}) where {N,E,D,T}
    return MAE{N,E,D,T}(encoder, decoder, patch_size)
end

# Apply `@functor` to MAE (presumably the Functors.jl macro re-exported by Flux, which
# lets Flux traverse the struct's fields as model parameters — confirm).
@functor MAE

"""
    MAE(config::Symbol; imsize::Dims{N}, patch_size::Dims{N},
        inchannels=3, nclasses=1000, pretrain=false,
        decoder_embedplanes=512, decoder_depth=8, decoder_nheads=16)

Build a Masked Auto Encoder whose encoder is the ViT variant named by `config`.

# Keywords

- `imsize`: input image size; 2 entries select `ViT`, 3 entries select `ViT3D`.
- `patch_size`: size of one patch, same number of entries as `imsize`.
- `inchannels`: number of image channels; also the number of reconstructed channels.
- `nclasses`: forwarded to `ViT` (2D only).
- `pretrain`: forwarded to the encoder and to `MaskDecoder`.
- `decoder_embedplanes`, `decoder_depth`, `decoder_nheads`: forwarded to `MaskDecoder`
  as `embedplanes`, `depth` and `nheads`.

Throws an `ErrorException` when `imsize` is neither 2D nor 3D.
"""
function MAE(
    config::Symbol; imsize::Dims{N}, patch_size::Dims{N},
    inchannels::Integer=3, nclasses::Integer=1000,
    pretrain::Bool=false,
    decoder_embedplanes::Integer=512,
    decoder_depth::Integer=8,
    decoder_nheads::Integer=16
) where {N}
    # The dimensionality is checked first, so unsupported sizes fail before any model
    # is built.
    bone = if N === 2
        ViT(config; imsize, patch_size, pretrain, inchannels, nclasses)
    elseif N === 3
        ViT3D(config; imsize, patch_size, inchannels, pretrain)
    else
        throw(ErrorException("Only support 2D or 3D images."))
    end
    # The decoder consumes tokens with the encoder's embedding width.
    encodplanes = VIT_CONFIGS[config][:embedplanes]
    head = MaskDecoder(
        imsize, patch_size, encodplanes, pretrain, inchannels;
        embedplanes=decoder_embedplanes, depth=decoder_depth, nheads=decoder_nheads
    )
    return MAE(bone, head, patch_size)
end

These are the forward functions:

"""
    make_mask(B, N, ratio)

Draw an independent random ordering of `N` patches for each of `B` samples.

Returns `(keep, unseen, restore)`, all matrices of `CartesianIndex{2}` `(patch, sample)`:

- `keep`: the first `floor(Int, N * (1 - ratio))` patches of each random ordering.
- `unseen`: the remaining patches of each ordering.
- `restore`: for every original patch position, where that patch sits in the ordering.
"""
function make_mask(B, N, ratio)
    nkeep = floor(Int, N * (1 - ratio))
    scores = rand(Float32, N, B)
    grid = CartesianIndices(scores)
    # NOTE(review): everything below is index bookkeeping with no useful gradient. If
    # Zygote reports array mutation, these sorts are a likely candidate; consider
    # excluding the call from differentiation (e.g. `ChainRulesCore.@ignore_derivatives`)
    # — TODO confirm.
    order = grid[sortperm(scores; dims=1)]   # random per-column permutation
    inverse = grid[sortperm(order; dims=1)]  # position of each patch inside `order`
    return order[1:nkeep, :], order[(nkeep + 1):end, :], inverse
end

"""
    masked_embedding(m::ViT, img, mask_ratio)

Embed `img` with backbone layers `1:3`, keep the class token plus a random subset of
the patch tokens, and run them through backbone layers `4:5`.

# Arguments

- `m`: the ViT whose `backbone(m)` supplies the layers.
- `img`: input batch handed to the backbone.
- `mask_ratio`: forwarded to `make_mask` as the fraction of patches to hide.

# Returns

`(x, unseen, restore)`: the output of layers `4:5`, and the `unseen` / `restore` index
arrays produced by `make_mask`.
"""
function masked_embedding(m::ViT, img, mask_ratio)
    b = backbone(m)
    # 1. patch embeddings
    # 2. append class token <CLS> (size C,1+P,B)
    # 3. add position embeddings
    emb = b[1:3](img)
    _, N, B = size(emb)
    # make mask over the N - 1 patch tokens (the class token is never masked)
    keep, unseen, restore = make_mask(B, N - 1, mask_ratio)
    # `keep` numbers the patches 1:P, while in `emb` patch k sits at position k + 1
    # (position 1 is <CLS>). Index the patch tokens only, otherwise every kept token is
    # shifted by one and the last patch can never be selected.
    cls = emb[:, 1:1, :]
    patches = emb[:, 2:end, :]
    emb_keep = cat(cls, patches[:, keep]; dims=2)
    # 4. embeddings dropout, emb is size(C,n,B)
    # 5. transformer output
    x = b[4:5](emb_keep)
    return x, unseen, restore
end

"""
    patchify(img::AbstractArray{T,4}, patch_size::Dims{2})

Cut a `(W, H, C, B)` image batch into non-overlapping patches.

Returns an array of size `(u * v * C, w * h, B)` where `(u, v) = patch_size` and
`(w, h) = (W, H) .÷ patch_size`; each column holds one flattened patch.
"""
function patchify(img::AbstractArray{T,4}, patch_size::Dims{2}) where {T}
    width, height, nchannels, batch = size(img)
    pw, ph = patch_size
    gridw, gridh = div(width, pw), div(height, ph)
    # split every spatial axis into (offset inside the patch, patch number)
    split = reshape(img, pw, gridw, ph, gridh, nchannels, batch)
    # gather the within-patch axes and the channel in front: pw,ph,C,gridw,gridh,B
    grouped = permutedims(split, (1, 3, 5, 2, 4, 6))
    return reshape(grouped, pw * ph * nchannels, gridw * gridh, batch)
end

"""
    unpatchify(imref::AbstractArray, patchs::AbstractArray, patch_size::Dims{2})

Reassemble flattened 2D patches into an image batch.

Only the size `(W, H, C, B)` of `imref` is used. `patchs` is read as
`(p * q * C, w * h, B)` with `(p, q) = patch_size` and `(w, h) = (W, H) .÷ patch_size`;
the result has size `(p * w, q * h, C, B)`.
"""
function unpatchify(imref::AbstractArray, patchs::AbstractArray, patch_size::Dims{2})
    width, height, nchannels, batch = size(imref)
    pw, ph = patch_size
    gridw, gridh = div(width, pw), div(height, ph)
    # one (pw, ph, C) block per grid cell
    blocks = reshape(patchs, pw, ph, nchannels, gridw, gridh, batch)
    # put each within-patch axis next to its grid axis: pw,gridw,ph,gridh,C,B
    interleaved = permutedims(blocks, (1, 4, 2, 5, 3, 6))
    return reshape(interleaved, pw * gridw, ph * gridh, nchannels, batch)
end

"""
    patchify(img::AbstractArray{T,5}, patch_size::Dims{3})

Cut a `(W, H, D, C, B)` volume batch into non-overlapping patches.

Returns an array of size `(p * q * r * C, w * h * d, B)` where `(p, q, r) = patch_size`
and `(w, h, d) = (W, H, D) .÷ patch_size`; each column holds one flattened patch.
"""
function patchify(img::AbstractArray{T,5}, patch_size::Dims{3}) where {T}
    width, height, depth, nchannels, batch = size(img)
    pw, ph, pd = patch_size
    gridw, gridh, gridd = div(width, pw), div(height, ph), div(depth, pd)
    # split every spatial axis into (offset inside the patch, patch number)
    split = reshape(img, pw, gridw, ph, gridh, pd, gridd, nchannels, batch)
    # within-patch axes and channel first: pw,ph,pd,C,gridw,gridh,gridd,B
    grouped = permutedims(split, (1, 3, 5, 7, 2, 4, 6, 8))
    return reshape(grouped, pw * ph * pd * nchannels, gridw * gridh * gridd, batch)
end

"""
    unpatchify(imref::AbstractArray, patchs::AbstractArray, patch_size::Dims{3})

Reassemble flattened 3D patches into a volume batch.

Only the size `(W, H, D, C, B)` of `imref` is used. `patchs` is read as
`(p * q * r * C, w * h * d, B)` with `(p, q, r) = patch_size` and
`(w, h, d) = (W, H, D) .÷ patch_size`; the result has size `(p * w, q * h, r * d, C, B)`.
"""
function unpatchify(imref::AbstractArray, patchs::AbstractArray, patch_size::Dims{3})
    width, height, depth, nchannels, batch = size(imref)
    pw, ph, pd = patch_size
    gridw, gridh, gridd = div(width, pw), div(height, ph), div(depth, pd)
    # one (pw, ph, pd, C) block per grid cell
    blocks = reshape(patchs, pw, ph, pd, nchannels, gridw, gridh, gridd, batch)
    # put each within-patch axis next to its grid axis: pw,gridw,ph,gridh,pd,gridd,C,B
    interleaved = permutedims(blocks, (1, 5, 2, 6, 3, 7, 4, 8))
    return reshape(interleaved, pw * gridw, ph * gridh, pd * gridd, nchannels, batch)
end


# Training objective: encode a randomly masked `img`, decode every patch, and apply
# `criterion` to the predicted and the true pixels of the hidden patches only.
function (m::MAE)(img::AbstractArray, mask_ratio::Real, criterion::Base.Callable)
    tokens, hidden, inverse = masked_embedding(m.encoder, img, mask_ratio)
    # drop the class token from the decoder output: P^n*C,wh,B
    predicted = m.decoder(tokens, inverse)[:, 2:end, :]
    target = patchify(img, m.patch_size)
    return criterion(predicted[:, hidden], target[:, hidden])
end


# Reconstruct an image batch from a randomly masked version of `x`.
#
# With `origin=false` (default) every patch of the result comes from the decoder.
# With `origin=true` the patches of `x` are kept and only the hidden patches are
# replaced by the decoder's prediction. The result has the layout of `x`.
function (m::MAE)(x::AbstractArray, mask_ratio::Real; origin::Bool=false)
    encode, ids_unseen, ids_restore = masked_embedding(m.encoder, x, mask_ratio)
    decode = m.decoder(encode, ids_restore)[:, 2:end, :] # P^n*C,wh,B without CLS
    if origin
        # `patchify` builds its result with `permutedims`, so writing into it below
        # does not modify `x`.
        restore = patchify(x, m.patch_size)
        # Both sides are indexed by the same hidden positions, so the shapes agree.
        restore[:, ids_unseen] .= decode[:, ids_unseen]
    else
        restore = decode
    end
    return unpatchify(x, restore, m.patch_size)
end

The test code:

# loss function: mean squared error between prediction and target
f(x, y) = mean((x .- y) .^ 2)
# 2D model
m = MAE(:tiny; imsize=(224, 224), patch_size=(16, 16))
x = rand(Float32, 224, 224, 3, 16)
println(m(x, 0.75) |> size) # (224,224,3,16)
println(m(x, 0.75, f))      # 3.8901417

opt = Flux.setup(Flux.Adam(), m)
for i in 1:10
    img = rand(Float32, 224, 224, 3, 16)
    try
        # `withgradient` returns the loss together with the gradient, so the logging
        # can live outside of the differentiated closure.
        loss, grads = Flux.withgradient(m) do model
            model(img, 0.75, f)
        end
        @info "Iter $i" loss
        Flux.update!(opt, m, grads[1])
    catch e
        # Keep going, but log the backtrace: it points at the offending operation.
        @error "Iter $i failed" exception = (e, catch_backtrace())
    end
end

Is there any explicit or implicit array mutation in this code, and if so, which operations should I avoid? :sob: