Multiple Dispatch with Abstract Integer

Hey everyone,
I created a struct UNet with an outer constructor and made it callable, but when trying to actually create an instance I get a MethodError that I don’t understand: I am type annotating ::Integer and ::Vector{Integer} which – from my understanding – should allow instantiating with any kind of int (e.g. Int64, as the default integer on my machine). Code and error are below, and I would really appreciate the help :slight_smile:

EDIT: it works if I annotate with ::Int64 everywhere, but that kinda goes against “annotating with the most general type”. Also: Can somebody explain me why the other suggested method (the inner constructor) suggests type ::Any and not the types that I specified in the struct?

"""UNet with Timestep Conditioning."""
struct UNet
    in_conv::Conv
    downsample_blocks::Vector{DownSampleBlock}
    upsample_blocks::Vector{UpSampleBlock}
    out_conv::Conv
end

function UNet(
    in_channels::Integer,
    out_channels::Integer,
    base_channels::Integer,
    channel_multipliers::Vector{Integer},
    t_embed_dim::Integer
)
    in_conv = Conv((3,3), in_channels => base_channels, pad=SamePad())

    channels = base_channels .* channel_multipliers
    @assert channels[1] == base_channels

    downblocks = []
    for (in_ch, out_ch) in zip(channels[1:end-1], channels[2:end])
        push!(downblocks, DownSampleBlock(in_ch, out_ch, t_embed_dim))
    end

    channels = 2 .* reverse(channels)

    upblocks = []
    for (in_ch, out_ch) in zip(channels[begin:end-1], channels[begin+1:end])
        push!(upblocks, UpSampleBlock(in_ch, div(out_ch, 2), t_embed_dim))
    end

    out_conv = Conv((3,3), div(channels[end], 2) => out_channels, pad=SamePad())

    return UNet(in_conv, downblocks, upblocks, out_conv)
end

function (m::UNet)(x::Array{Float32, 4}, t::Array{Float32, 2})
    skips = []

    x = m.in_conv(x)
    push!(skips, x)

    for block in m.downsample_blocks
        x = block(x, t)
        push!(skips, x)
    end

    for (i, block) in enumerate(m.upsample_blocks)
        x = cat(x, reverse(skips)[i], dims=3)
        x = block(x, t)
    end
    m.out_conv(x)
end
@functor UNet

Construction & error message

julia> include("unet.jl")

julia> UNet(3, 3, 64, [1, 2, 4, 8], 128)
ERROR: MethodError: no method matching UNet(::Int64, ::Int64, ::Int64, ::Vector{Int64}, ::Int64)

Closest candidates are:
  UNet(::Any, ::Any, ::Any, ::Any)
   @ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:108
  UNet(::Integer, ::Integer, ::Integer, ::Vector{Integer}, ::Integer)
   @ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:114

Stacktrace:
 [1] top-level scope
   @ REPL[2]:1

You need to write channel_multipliers::Vector{<:Integer} which allows a vector of any type that is a subtype of integer.

1 Like

Welcome to Julia, liopeer!

For an explanation, see here:
(Frequently Asked Questions · The Julia Language)

1 Like

The keywords in the Docs you are looking for, to grasp this behaviour are: invariant, covariant and contravariant types.

1 Like

Thank you to all of you! I guess I did not realize that this does not apply to composite types, even though that actually makes sense, now that I think about it.