Hey everyone,
I created a struct UNet with an outer constructor and made it callable, but when trying to actually create an instance I get a MethodError that I don’t understand: I am type annotating ::Integer and ::Vector{Integer} which – from my understanding – should allow instantiating with any kind of int (e.g. Int64, as the default integer on my machine). Code and error are below, and I would really appreciate the help ![]()
EDIT: it works if I annotate with ::Int64 everywhere, but that kinda goes against “annotating with the most general type”. Also: Can somebody explain me why the other suggested method (the inner constructor) suggests type ::Any and not the types that I specified in the struct?
"""UNet with Timestep Conditioning."""
struct UNet
    in_conv::Conv
    downsample_blocks::Vector{DownSampleBlock}
    upsample_blocks::Vector{UpSampleBlock}
    out_conv::Conv
end
function UNet(
    in_channels::Integer,
    out_channels::Integer,
    base_channels::Integer,
    channel_multipliers::Vector{Integer},
    t_embed_dim::Integer
)
    in_conv = Conv((3,3), in_channels => base_channels, pad=SamePad())
    channels = base_channels .* channel_multipliers
    @assert channels[1] == base_channels
    downblocks = []
    for (in_ch, out_ch) in zip(channels[1:end-1], channels[2:end])
        push!(downblocks, DownSampleBlock(in_ch, out_ch, t_embed_dim))
    end
    channels = 2 .* reverse(channels)
    upblocks = []
    for (in_ch, out_ch) in zip(channels[begin:end-1], channels[begin+1:end])
        push!(upblocks, UpSampleBlock(in_ch, div(out_ch, 2), t_embed_dim))
    end
    out_conv = Conv((3,3), div(channels[end], 2) => out_channels, pad=SamePad())
    return UNet(in_conv, downblocks, upblocks, out_conv)
end
function (m::UNet)(x::Array{Float32, 4}, t::Array{Float32, 2})
    skips = []
    x = m.in_conv(x)
    push!(skips, x)
    for block in m.downsample_blocks
        x = block(x, t)
        push!(skips, x)
    end
    for (i, block) in enumerate(m.upsample_blocks)
        x = cat(x, reverse(skips)[i], dims=3)
        x = block(x, t)
    end
    m.out_conv(x)
end
@functor UNet
Construction & error message
julia> include("unet.jl")
julia> UNet(3, 3, 64, [1, 2, 4, 8], 128)
ERROR: MethodError: no method matching UNet(::Int64, ::Int64, ::Int64, ::Vector{Int64}, ::Int64)
Closest candidates are:
  UNet(::Any, ::Any, ::Any, ::Any)
   @ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:108
  UNet(::Integer, ::Integer, ::Integer, ::Vector{Integer}, ::Integer)
   @ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:114
Stacktrace:
 [1] top-level scope
   @ REPL[2]:1