CNN for MNIST

The thing to remember about CNNs is that each input is 3D (instead of the “typical” 1D vector input): height, width and channel (in PyTorch the order is channel × height × width). Thus, when you batch many of them together, you get a 4D input (instead of the “typical” 2D matrix), ordered batch × channel × height × width.
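To make the shapes concrete, here is a minimal sketch assuming PyTorch's `nn.Conv2d`. The layer sizes (1 input channel, 8 output channels, 3×3 kernel) and the batch size of 16 are illustrative assumptions, not something from the original question; the random tensors just stand in for MNIST images.

```python
import torch
import torch.nn as nn

# Illustrative layer: 1 input channel (grayscale MNIST), 8 output channels.
# These sizes are assumptions chosen for the example.
conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)

# One MNIST-sized image: channel x height x width -> 3D
image = torch.randn(1, 28, 28)

# A batch of 16 such images: batch x channel x height x width -> 4D
batch = torch.randn(16, 1, 28, 28)

# A 3x3 kernel with no padding shrinks 28x28 to 26x26
out = conv(batch)

print(image.shape)  # torch.Size([1, 28, 28])
print(batch.shape)  # torch.Size([16, 1, 28, 28])
print(out.shape)    # torch.Size([16, 8, 26, 26])
```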

I think what might be tripping you up is that the MNIST dataset is implicitly single-channel (grayscale), so you’ve used unsqueeze to add that third dimension. Batching is what adds the fourth. You can of course test single images from either the test or train set; you just need to either batch them together or make a single image 4D (again with unsqueeze).
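A minimal sketch of both options, again assuming PyTorch. The `raw` tensor is a hypothetical stand-in for a single 28×28 MNIST image that has no channel dimension yet; the `Conv2d` sizes are the same illustrative assumptions as before. `unsqueeze(0)` inserts a new size-1 dimension at the front, and `torch.stack` joins a list of equally shaped 3D tensors along a new leading batch dimension.

```python
import torch
import torch.nn as nn

# Illustrative layer (sizes are assumptions, not from the original post)
conv = nn.Conv2d(1, 8, kernel_size=3)

# Hypothetical stand-in for one raw MNIST image: height x width (2D)
raw = torch.rand(28, 28)

# Option 1: make a single image 4D with unsqueeze
single = raw.unsqueeze(0)           # add channel dim: (28, 28) -> (1, 28, 28)
single_batch = single.unsqueeze(0)  # add batch dim:   (1, 28, 28) -> (1, 1, 28, 28)
single_out = conv(single_batch)
print(single_out.shape)  # torch.Size([1, 8, 26, 26])

# Option 2: batch several 3D images together into one 4D tensor
images = [torch.rand(1, 28, 28) for _ in range(4)]
stacked = torch.stack(images)       # (4, 1, 28, 28)
stacked_out = conv(stacked)
print(stacked_out.shape)  # torch.Size([4, 8, 26, 26])
```

Either way, the convolution sees a batch × channel × height × width tensor; a batch of one is still a batch.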
