I built the following model:

    using Flux

    function my_model(n)
        conv_1 = Conv((8, 8), 3 => 32, stride = (4, 4), relu)
        conv_2 = Conv((4, 4), 32 => 64, stride = (2, 2), relu)
        conv_3 = Conv((3, 3), 64 => 64, stride = (1, 1), relu)
        model = Chain(
            x -> x / 255,
            conv_1,
            conv_2,
            conv_3,
            x -> reshape(x, (:, 1)),
            Dense(2304, 512, relu),
            Dense(512, n),
        )
        return model
    end

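As a sanity check on `Dense(2304, 512)`: assuming Flux's `Conv` defaults of no padding and dilation 1, each spatial side shrinks to floor((n - k) / s) + 1, so an 80×80 input becomes 19, then 8, then 6 pixels wide, and conv_3's 64 channels give 6 × 6 × 64 = 2304 features per sample. A small plain-Julia sketch of that arithmetic (`conv_out` and `flattened_features` are hypothetical helper names, not Flux API):

```julia
# Output length of an unpadded convolution along one spatial axis.
# Assumes Flux's Conv defaults: pad = 0, dilation = 1.
conv_out(n, k, s) = div(n - k, s) + 1

# Apply the three (kernel, stride) pairs from my_model to a w×h input and
# return the final spatial size plus the flattened feature count per sample.
function flattened_features(w, h; layers = [(8, 4), (4, 2), (3, 1)], channels = 64)
    for (k, s) in layers
        w = conv_out(w, k, s)
        h = conv_out(h, k, s)
    end
    return w, h, w * h * channels
end

println(flattened_features(80, 80))  # (6, 6, 2304)
```

So 2304 is the size of one flattened sample, not of a whole batch.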
I tested it as follows, with the data stored in (width, height, channels, batch) order, i.e. WHCN:

    model = my_model(6)
    test_input = rand(UInt8, (80, 80, 3, 1))
    test_batch = rand(UInt8, (80, 80, 3, 32))
    model(test_input)
    model(test_batch)

Testing the model with test_batch gives a dimension mismatch error: DimensionMismatch("A has dimensions (512,2304) but B has dimensions (73728,1)"). It seems that the reshape call does not handle a batch this way. I thought that, with the data stored in WHCN order, I could feed the network either a single input or a batch. Could anybody help me with this?
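The numbers in the error point at the cause: 73728 = 2304 × 32, so `reshape(x, (:, 1))` stacks all 32 samples into one column instead of producing one 2304-long column per sample. A sketch in plain Julia, with random arrays standing in for conv_3's output and a hypothetical helper `flatten_batch` that flattens everything except the last (batch) dimension:

```julia
# Flatten all dimensions except the last one (the batch dimension in WHCN order).
# Hypothetical helper name, used here for illustration only.
flatten_batch(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))

batch  = rand(Float32, 6, 6, 64, 32)  # stand-in for conv_3's output on a 32-sample batch
single = rand(Float32, 6, 6, 64, 1)   # stand-in for conv_3's output on one sample

println(size(reshape(batch, (:, 1))))  # (73728, 1): whole batch squeezed into one column
println(size(flatten_batch(batch)))    # (2304, 32): one column per sample
println(size(flatten_batch(single)))   # (2304, 1): single input still works
```

In the Chain, replacing `x -> reshape(x, (:, 1))` with `x -> reshape(x, :, size(x, 4))` should then work for both test_input and test_batch. Flux also provides `Flux.flatten`, documented as preserving the size of the last dimension, which is intended for this; it is worth checking its documentation for the Flux version you have installed.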