Hey everyone,
I created a struct UNet
with an outer constructor and made it callable, but when trying to actually create an instance I get a MethodError
that I don’t understand: I am type annotating ::Integer and ::Vector{Integer} which – from my understanding – should allow instantiating with any kind of int (e.g. Int64, as the default integer on my machine). Code and error are below, and I would really appreciate the help
EDIT: it works if I annotate with ::Int64 everywhere, but that kinda goes against “annotating with the most general type”. Also: Can somebody explain me why the other suggested method (the inner constructor) suggests type ::Any and not the types that I specified in the struct?
"""UNet with Timestep Conditioning."""
struct UNet
in_conv::Conv
downsample_blocks::Vector{DownSampleBlock}
upsample_blocks::Vector{UpSampleBlock}
out_conv::Conv
end
function UNet(
in_channels::Integer,
out_channels::Integer,
base_channels::Integer,
channel_multipliers::Vector{Integer},
t_embed_dim::Integer
)
in_conv = Conv((3,3), in_channels => base_channels, pad=SamePad())
channels = base_channels .* channel_multipliers
@assert channels[1] == base_channels
downblocks = []
for (in_ch, out_ch) in zip(channels[1:end-1], channels[2:end])
push!(downblocks, DownSampleBlock(in_ch, out_ch, t_embed_dim))
end
channels = 2 .* reverse(channels)
upblocks = []
for (in_ch, out_ch) in zip(channels[begin:end-1], channels[begin+1:end])
push!(upblocks, UpSampleBlock(in_ch, div(out_ch, 2), t_embed_dim))
end
out_conv = Conv((3,3), div(channels[end], 2) => out_channels, pad=SamePad())
return UNet(in_conv, downblocks, upblocks, out_conv)
end
function (m::UNet)(x::Array{Float32, 4}, t::Array{Float32, 2})
skips = []
x = m.in_conv(x)
push!(skips, x)
for block in m.downsample_blocks
x = block(x, t)
push!(skips, x)
end
for (i, block) in enumerate(m.upsample_blocks)
x = cat(x, reverse(skips)[i], dims=3)
x = block(x, t)
end
m.out_conv(x)
end
@functor UNet
Construction & error message
julia> include("unet.jl")
julia> UNet(3, 3, 64, [1, 2, 4, 8], 128)
ERROR: MethodError: no method matching UNet(::Int64, ::Int64, ::Int64, ::Vector{Int64}, ::Int64)
Closest candidates are:
UNet(::Any, ::Any, ::Any, ::Any)
@ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:108
UNet(::Integer, ::Integer, ::Integer, ::Vector{Integer}, ::Integer)
@ Main ~/Documents/Development/mlblog/content/posts/2024-07-04_mnistjulia/unet.jl:114
Stacktrace:
[1] top-level scope
@ REPL[2]:1