I built the following model:
function my_model(n)
conv_1 = Conv((8,8), 3 => 32, stride=(4,4), relu)
conv_2 = Conv((4,4), 32 =>64, stride=(2,2), relu)
conv_3 = Conv((3,3), 64 =>64, stride=(1,1), relu)
model = Chain(
x -> x / 255,
conv_1,
conv_2,
conv_3,
x -> reshape(x, (:, 1)),
Dense(2304, 512, relu),
Dense(512, n),
)
return model
end
and tested it in the following way, where the data is stored in (width, height, # channels, # batches) order:
model = my_model(6)
test_input = rand(UInt8, (80, 80, 3, 1))
test_batch = rand(UInt8, (80, 80, 3, 32))
model(test_input)
model(test_batch)
When testing the model with test_batch I get a dimension mismatch error. (DimensionMismatch(“A has dimensions (512,2304) but B has dimensions (73728,1)”)). It seems as if the reshape command does not work for the batch in this way. Actually, I thought I can feed my network a single input as well as a batch when storing the data in WHCN order. Could anybody help me with this?