Showing images from CIFAR-10 with Plots.jl

I want to show an image from CIFAR-10, loaded from the MATLAB files available here: CIFAR-10 and CIFAR-100 datasets

I had hoped to show the first image with the following code, but have not been able to make it work.

```julia
using MAT
using Flux
using Plots
using Images
using Colors

function loadfile(filename)
    # The CIFAR-10 .mat batches store the images under the "data" key:
    # a 10000×3072 UInt8 matrix with one flattened image per row.
    return matread(filename)["data"]
end

x = loadfile("../cifar10/data_batch_1.mat")[1, :] # load the array and get the first image
println(size(x))  # prints (3072,). Each image is 32x32 pixels and has 3 channels.
ximg = (reshape(x, (32, 32, 3)) .% Int) ./ 255 # The type is now 32×32×3 Array{Float64,3}. Hopefully perfect for showing

# Now, plotting the image is tricky.
colorview(RGB, ximg) # crashes with ERROR: LoadError: DimensionMismatch("indices Base.OneTo(32) are not consistent with color type ColorTypes.RGB{Float64}")
# Both of these plots are empty:
```

I expected this to be as easy as plot(colorview(RGB, ximg)), but I am clearly missing something.
I cannot use MLDatasets for this because it is a school assignment; I have to use the linked MATLAB files.

After the reshape, your array stores the color channel in its last dimension ("HWC": height, width, color). Images.jl, by contrast, stores each pixel's color as a single contiguous RGB{Float64} value, so colorview(RGB, A) expects the color channel to be the first dimension of A ("CHW": color, height, width). You can use permutedims to reorder the dimensions so the color channel comes first:

```julia
julia> cview_img = colorview(RGB, permutedims(ximg, (3, 1, 2)))
32×32 reinterpret(reshape, RGB{Float64}, ::Array{Float64, 3}) with eltype RGB{Float64}:
 RGB{Float64}(0.480731,0.472963,0.493649)   …  RGB{Float64}(0.0425475,0.802405,0.772114)
 RGB{Float64}(0.38644,0.356001,0.842026)       RGB{Float64}(0.193704,0.0518138,0.103578)
 RGB{Float64}(0.248983,0.646537,0.738121)      RGB{Float64}(0.287988,0.231072,0.980442)
 RGB{Float64}(0.2245,0.68703,0.123556)         RGB{Float64}(0.308916,0.916772,0.769513)
```
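One caveat about orientation: according to the format description on the CIFAR-10 page, each 3072-value row holds the red, green and blue channels in turn (1024 values each), and each channel is stored in row-major order. Because Julia's reshape is column-major, reshape(x, (32, 32, 3)) then indexes as [column, row, channel], so the (3, 1, 2) permutation may display the picture transposed. Below is a minimal sketch, assuming that layout, of a hypothetical helper cifar_to_rgb that uses the permutation (3, 2, 1) instead, putting the channel first and the image rows second. It assumes the Colors and ImageCore packages; the question's using Images should bring the same names into scope.

```julia
using Colors, ImageCore

# Hypothetical helper (not part of any package): convert one 3072-element
# CIFAR-10 row into a 32×32 matrix of RGB{Float64} pixels.
# Assumes the documented CIFAR layout: R, G, B planes of 1024 values each,
# every plane stored row by row.
function cifar_to_rgb(row::AbstractVector{<:Integer})
    length(row) == 3072 ||
        throw(ArgumentError("expected 3072 values, got $(length(row))"))
    # Column-major reshape: whc[col, row, channel]
    whc = reshape(row, 32, 32, 3)
    # Scale to [0, 1] (integer / Int gives Float64), then reorder to
    # [channel, row, col] so colorview sees the channel first.
    chw = permutedims(whc ./ 255, (3, 2, 1))
    # colorview wraps chw as 32×32 RGB pixels; collect copies the result
    # into a plain Matrix{RGB{Float64}}.
    return collect(colorview(RGB, chw))
end
```

With x from the question, img = cifar_to_rgb(x) should give a 32×32 Matrix{RGB{Float64}} with the rows running top to bottom, which Plots.jl should be able to display with plot(img).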

Note that cview_img is described as reinterpret(reshape, ...). This means cview_img is a view: it shares memory with the array returned by permutedims (a fresh copy, not ximg itself), and every element access has to translate an RGB index into positions in that underlying Float64 array. This indirection can slow down operations that assume they are working with plain, contiguous strided data. If that matters, img = collect(cview_img) copies the pixels into a new block of memory with the native Matrix{RGB{Float64}} layout, so indexing no longer needs to be redirected.
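The difference between the view and the collected copy can be checked on synthetic data. This sketch uses a constant 3×32×32 array as a stand-in for the permuted CIFAR image (the values are arbitrary) and assumes the Colors and ImageCore packages are installed.

```julia
using Colors, ImageCore

# Channel-first stand-in for the permuted image: every value is 0.5.
chw = fill(0.5, 3, 32, 32)

# colorview does not copy: cview_img reinterprets chw's memory as RGB pixels.
cview_img = colorview(RGB, chw)

# collect copies the pixels into an independent Matrix{RGB{Float64}}.
img = collect(cview_img)
@show img isa Matrix{RGB{Float64}}

# Changing the underlying array is visible through the view but not in the copy.
chw[1, 1, 1] = 0.0
@show cview_img[1, 1].r   # 0.0
@show img[1, 1].r         # 0.5
```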


That did it! I thought that reshaping would have the same result as permutedims, but apparently not.
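The difference shows up on a toy array using only Base Julia (the 2×2×3 array of consecutive integers is purely illustrative). permutedims moves every element to its new index, so each pixel's channels stay together. reshape keeps the elements in memory order and only changes how that sequence is split into dimensions.

```julia
# Toy 2×2 "image" with 3 channels, stored channel-last: hwc[h, w, c].
hwc = reshape(collect(1:12), 2, 2, 3)

# permutedims reorders the axes: element hwc[h, w, c] ends up at chw_perm[c, h, w].
chw_perm = permutedims(hwc, (3, 1, 2))

# reshape keeps memory order and just regroups it as 3×2×2.
chw_reshape = reshape(hwc, 3, 2, 2)

@show chw_perm[:, 1, 1]     # [1, 5, 9]: the three channels of pixel (1, 1)
@show chw_reshape[:, 1, 1]  # [1, 2, 3]: channel 1 of pixels (1, 1), (2, 1), (1, 2)
```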