Upsampling in Flux.jl

I am trying to implement a convolutional autoencoder with Flux.jl. I have a simple version working in Python (TensorFlow Keras API). However, the Flux.jl-based model does not seem to converge and gives poor results even on the training dataset.

Here is my code using MNIST:

using Flux, Flux.Data.MNIST
using Flux: @epochs, mse, throttle
using Base.Iterators: partition
using CuArrays
using Flux.Tracker: Tracker, TrackedArray, track, @grad  # Tracker itself is needed below for Tracker.data

# return a list of batches; every batch has the size (28,28,1,batch_size)
# The last batch can be smaller
function getdata(params...; batch_size = 64)
    imgs = MNIST.images(params...)
    @show length(imgs)

    # Partition into batches
    data = [reshape(cat(float.(batch)...; dims = 3), (28,28,1,:)) for batch in partition(imgs, batch_size)];
    data = [gpu(Float32.(d)) for d in data];
    return data
end
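
# With the default batch size this yields 938 batches (60000 = 937*64 + 32),
# e.g. size(data[1]) == (28, 28, 1, 64) and size(data[end]) == (28, 28, 1, 32).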

# https://github.com/FluxML/NNlib.jl/pull/95
function upsample(x)
    ratio = (2,2,1,1)
    y = similar(x, (size(x) .* ratio)...)
    for i in Iterators.product(Base.OneTo.(ratio)...)
        loc = map((i,r,s)->range(i, stop = s, step = r), i, ratio, size(y))
        @inbounds y[loc...] = x
    end
    y
end
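
# For intuition: upsample replicates every pixel into a 2×2 block (nearest-neighbour), e.g.
#   x = reshape(Float32[1 2; 3 4], (2, 2, 1, 1))
#   upsample(x)[:, :, 1, 1] == [1 1 2 2;
#                               1 1 2 2;
#                               3 3 4 4;
#                               3 3 4 4]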

model = Chain(Conv((3, 3), 1=>16, pad=(1,1), relu),  # 28×28×1  -> 28×28×16
              MaxPool((2,2)),                        # -> 14×14×16
              Conv((3, 3), 16=>8, pad=(1,1), relu),  # -> 14×14×8
              MaxPool((2,2)),                        # -> 7×7×8

              Conv((3, 3), 8=>8, pad=(1,1), relu),   # -> 7×7×8

              upsample,                              # -> 14×14×8
              Conv((3, 3), 8=>16, pad=(1,1), relu),  # -> 14×14×16
              upsample,                              # -> 28×28×16
              Conv((3, 3), 16=>1, pad=(1,1), relu)) |> gpu;  # -> 28×28×1


loss(x) = mse(model(x), x)

# get training data
data = getdata()
@show size(model(data[1]))
@show loss(data[1])

evalcb = throttle(() -> @show(loss(data[1])), 5)
opt = ADAM()

@epochs 50 Flux.train!(loss, params(model), zip(data), opt, cb = evalcb)
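
# Note: zip(data) yields 1-tuples; train! splats each element into the loss,
# so e.g. first(zip(data)) == (data[1],) and the loss is called as loss(data[1]).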

# get testing data
data_test = getdata(:test)

# compute the test MSE as a batch-size-weighted average over all test batches
testMSE = 0
for d in data_test
    global testMSE
    testMSE += size(d,4) * Tracker.data(loss(d))
end

testMSE /= sum(size.(data_test,4))

After 50 epochs, I get an MSE of 0.04334188. In TensorFlow/Keras I get an MSE of 0.004307 after just 5 epochs.

For reference, here is also the Python code:

import tensorflow as tf
import tensorflow.keras.layers as layers

import numpy as np

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

print("Number of training images ",x_train.shape[0])



model = tf.keras.models.Sequential([
    layers.Reshape((28, 28,1),input_shape=(28,28)),
    layers.Conv2D(filters=16,kernel_size=3,padding="same",activation='relu'),
    layers.MaxPooling2D(pool_size=2),
    layers.Conv2D(filters=8,kernel_size=3,padding="same",activation='relu'),
    layers.MaxPooling2D(pool_size=2),

    layers.Conv2D(filters=8,kernel_size=3,padding="same",activation='relu'),

    layers.UpSampling2D(size=(2,2)),
    layers.Conv2D(filters=16,kernel_size=3,padding="same",activation='relu'),
    layers.UpSampling2D(size=(2,2)),
    layers.Conv2D(filters=1,kernel_size=3,padding="same",activation='relu'),
    layers.Reshape((28, 28))
])

model.compile(optimizer='adam',
              loss='MSE')

model.fit(x_train, x_train, epochs=5, batch_size=64)
#model.evaluate(x_test, x_test)
print("MSE",np.mean((model.predict(x_test) - x_test)**2))

I got the upsampling function from a pending PR on NNlib.jl. The code looks correct to me, but could it be that Flux.Tracker is not able to compute its gradient properly?
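
One way to test this in isolation would be to push a tracked array through upsample and check whether the result is still tracked (a minimal sketch, assuming the Flux 0.9 Tracker API, where param and istracked live in Flux.Tracker):

using Flux.Tracker: param, istracked

x = param(rand(Float32, 4, 4, 1, 1))  # a TrackedArray
istracked(upsample(x))                # false would mean gradients cannot flow through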


I just checked that Flux.jl and TensorFlow use the same default parameter values (such as the learning rate) for the optimizer:

https://github.com/FluxML/Flux.jl/blob/master/src/optimise/optimisers.jl#L99
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer


I also tried without relu in the final layer, but the results are quite similar:

Flux MSE after 50 epochs: 0.063297756f0
TensorFlow MSE after 5 epochs: 0.005704676563238298

So there is still a factor of 10, even comparing 50 Flux epochs to just 5 TensorFlow epochs.
Here is the first test image for Flux:

[image: Flux reconstruction of the first test digit]

And for TensorFlow:

[image: TensorFlow reconstruction of the first test digit]


I am not sure if this is still relevant for you, but it might be for anyone who stumbles upon this. The reason your network is not really learning is that your upsampling does not preserve the array type: if the input is a TrackedArray, the output is just a regular Array. The culprit is the in-place assignment: similar allocates an untracked array, and writing tracked values into it discards the gradient tape. You can check this by removing the very last convolution layer; Flux should then complain that the loss is not a tracked scalar. In practice, only the last layer is being optimized, so it is not surprising that you get very poor results. You can try something like this instead:

function upsample(x)
  ratio = (2, 2, 1, 1)
  (h, w, c, n) = size(x)
  # broadcasting against this array of ones replicates every pixel
  # ratio[1] × ratio[2] times; only the untracked ones array is mutated,
  # so the gradient tape stays intact
  y = similar(x, (1, ratio[1], 1, ratio[2], 1, 1))
  fill!(y, 1)
  z = reshape(x, (h, 1, w, 1, c, n)) .* y
  reshape(permutedims(z, (2, 1, 4, 3, 5, 6)), size(x) .* ratio)
end
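
Since every step here is a reshape, a broadcast or a permutedims, all of which Tracker can differentiate, the output stays tracked. A quick check (a sketch under the same Flux 0.9 / Tracker assumptions as above):

using Flux.Tracker: param, istracked

x = param(rand(Float32, 4, 4, 1, 1))
istracked(upsample(x))  # true: the gradient now reaches the earlier layers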

Thanks a lot for your help! It does work on a CPU. But unfortunately, on a GPU, I now see the following error:

julia> include("mnist_cae_flux2.jl")
[ Info: Recompiling stale cache file /home/ulg/gher/abarth/.julia/compiled/v1.1/CuArrays/7YFE0.ji for CuArrays [3a865a2d-5b23-5a0f-bc46-62713ec82fae]
[ Info: Recompiling stale cache file /home/ulg/gher/abarth/.julia/compiled/v1.1/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
length(imgs) = 60000
size(model(data[1])) = (28, 28, 1, 64)
loss(data[1]) = 0.10397781f0 (tracked)
[ Info: Epoch 1
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/fLiQ1/src/indexing.jl:16
ERROR: LoadError: MethodError: no method matching culiteral_pow(::typeof(^), ::ForwardDiff.Dual{Nothing,Float32,1}, ::Val{2})
Closest candidates are:
  culiteral_pow(::typeof(^), ::Union{Float32, Float64}, ::Val{2}) at /home/ulg/gher/abarth/.julia/packages/CuArrays/wXQp8/src/broadcast.jl:46
  culiteral_pow(::typeof(^), ::Union{Float32, Float64}, ::Val{p}) where p at /home/ulg/gher/abarth/.julia/packages/CuArrays/wXQp8/src/broadcast.jl:48
  culiteral_pow(::typeof(^), ::Union{Float32, Float64}, ::Val{0}) at /home/ulg/gher/abarth/.julia/packages/CuArrays/wXQp8/src/broadcast.jl:44
  ...
Stacktrace:
 [1] (::getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)})(::Function, ::ForwardDiff.Dual{Nothing,Float32,1}, ::ForwardDiff.Dual{Nothing,Float32,1}, ::Val{2}) at ./broadcast.jl:298
 [2] partial(::getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)}, ::Float32, ::Int64, ::Function, ::Float32, ::Float32, ::Val{2}) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/array.jl:505
 [3] _broadcast_getindex at ./broadcast.jl:578 [inlined]
 [4] getindex at ./broadcast.jl:511 [inlined]
 [5] copy at ./broadcast.jl:787 [inlined]
 [6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayStyle{CuArray},Nothing,typeof(Tracker.partial),Tuple{Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)}},CuArray{Float32,4},Int64,Base.RefValue{typeof(^)},CuArray{Float32,4},CuArray{Float32,4},Base.RefValue{Val{2}}}}) at ./broadcast.jl:753
 [7] broadcast(::typeof(Tracker.partial), ::Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)}}, ::CuArray{Float32,4}, ::Int64, ::Vararg{Any,N} where N) at ./broadcast.jl:707
 [8] ∇broadcast(::typeof(Tracker.partial), ::Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)}}, ::CuArray{Float32,4}, ::Int64, ::Base.RefValue{typeof(^)}, ::TrackedArray{…,CuArray{Float32,4}}, ::CuArray{Float32,4}, ::Base.RefValue{Val{2}}) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/array.jl:509
 [9] copy(::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,NTuple{4,Base.OneTo{Int64}},typeof(Tracker.partial),Tuple{Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)}},CuArray{Float32,4},Int64,Base.RefValue{typeof(^)},TrackedArray{…,CuArray{Float32,4}},CuArray{Float32,4},Base.RefValue{Val{2}}}}) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/array.jl:540
 [10] materialize at ./broadcast.jl:753 [inlined]
 [11] #547 at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/array.jl:513 [inlined]
 [12] macro expansion at ./sysimg.jl:275 [inlined]
 [13] ntuple at ./sysimg.jl:271 [inlined]
 [14] (::getfield(Tracker, Symbol("#back#548")){4,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{…,CuArray{Float32,4}},CuArray{Float32,4},Base.RefValue{Val{2}}}})(::CuArray{Float32,4}) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/array.jl:513
 [15] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){4,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(-)}},typeof(CuArrays.culiteral_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{…,CuArray{Float32,4}},CuArray{Float32,4},Base.RefValue{Val{2}}}},Tuple{Nothing,Tracker.Tracked{CuArray{Float32,4}},Nothing,Nothing}}, ::CuArray{Float32,4}, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:35
 [16] back(::Tracker.Tracked{CuArray{Float32,4}}, ::CuArray{Float32,4}, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:58
 [17] foreach at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [18] back_(::Tracker.Call{getfield(Tracker, Symbol("##482#483")){TrackedArray{…,CuArray{Float32,4}}},Tuple{Tracker.Tracked{CuArray{Float32,4}}}}, ::Float32, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:38
 [19] back(::Tracker.Tracked{Float32}, ::Float32, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:58
 [20] #13 at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [21] foreach at ./abstractarray.jl:1867 [inlined]
 [22] back_(::Tracker.Call{getfield(Tracker, Symbol("##278#281")){Rational{Int64}},Tuple{Tracker.Tracked{Float32},Nothing}}, ::Float32, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:38
 [23] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:58
 [24] #back!#15 at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:77 [inlined]
 [25] #back! at ./none:0 [inlined]
 [26] #back!#32 at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/real.jl:16 [inlined]
 [27] back!(::Tracker.TrackedReal{Float32}) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/lib/real.jl:14
 [28] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{CuArray{Float32,4}}}, ::Tracker.Params) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:4
 [29] #gradient#24(::Bool, ::Function, ::Function, ::Tracker.Params) at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:164
 [30] gradient at /home/ulg/gher/abarth/.julia/packages/Tracker/SAr25/src/back.jl:164 [inlined]
 [31] macro expansion at /home/ulg/gher/abarth/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
 [32] macro expansion at /home/ulg/gher/abarth/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [33] #train!#12(::getfield(Flux, Symbol("#throttled#18")){getfield(Flux, Symbol("##throttled#10#14")){Bool,Bool,getfield(Main, Symbol("##12#13")),Int64}}, ::Function, ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{CuArray{Float32,4},1}}}, ::ADAM) at /home/ulg/gher/abarth/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
 [34] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Flux, Symbol("#throttled#18")){getfield(Flux, Symbol("##throttled#10#14")){Bool,Bool,getfield(Main, Symbol("##12#13")),Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{CuArray{Float32,4},1}}}, ::ADAM) at ./none:0
 [35] top-level scope at /home/ulg/gher/abarth/.julia/packages/Flux/dkJUV/src/optimise/train.jl:106
 [36] top-level scope at /home/ulg/gher/abarth/.julia/packages/Juno/oLB1d/src/progress.jl:134
 [37] top-level scope at util.jl:156
 [38] include at ./boot.jl:326 [inlined]
 [39] include_relative(::Module, ::String) at ./loading.jl:1038
 [40] include(::Module, ::String) at ./sysimg.jl:29
 [41] include(::String) at ./client.jl:403
 [42] top-level scope at none:0
in expression starting at /home/users/a/b/abarth/projects/Julia/share/mnist_cae_flux2.jl:89

It could be this issue, which is fixed in CuArrays but not yet released:
https://github.com/JuliaGPU/CuArrays.jl/issues/378


If I use the git master version of CuArrays, it works! Thank you very much.

Here are the run times for 5 epochs on the same machine:

Flux.jl: 29.8 seconds (smallest run time out of 3 runs)
TensorFlow (Python): 19.7 seconds
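
The Flux number was measured along these lines (a minimal sketch of the harness; the exact invocation is an assumption):

times = Float64[]
for i in 1:3
    t = @elapsed @epochs 5 Flux.train!(loss, params(model), zip(data), opt)
    push!(times, t)
end
@show minimum(times)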

Unfortunately, the run-time difference is too large for me to migrate to Flux.jl. However, if I manage to optimize the Flux.jl model, I will post my updates here.

Flux 0.9; Julia 1.2.0; TensorFlow 1.12; CuArrays 8e00af6


This variant of @vitskvara's solution is marginally faster (25.69 s compared to 29.8 s previously) because it avoids the call to permutedims:

function upsample(x)
  ratio = (2, 2, 1, 1)
  (h, w, c, n) = size(x)
  # with the replication dimensions placed first, the final reshape already
  # interleaves the copies in the right order, so no permutedims is needed
  y = similar(x, (ratio[1], 1, ratio[2], 1, 1, 1))
  fill!(y, 1)
  z = reshape(x, (1, h, 1, w, c, n)) .* y
  reshape(z, size(x) .* ratio)
end
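
Renaming the two implementations for comparison (upsample_permute for the earlier version, upsample_fast for this one; both names are only for this check), one can verify they produce identical output:

x = rand(Float32, 4, 4, 3, 2)
upsample_fast(x) == upsample_permute(x)  # true: same nearest-neighbour result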