How to drop the dropout layers in Flux.jl when assessing model performance



I was playing around with the Flux.jl framework and fitted a model with dropout at every layer.

using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using CuArrays  # GPU arrays for this (older) Flux setup; newer setups use CUDA.jl instead
# Classify MNIST digits with a convolutional network

imgs = MNIST.images()

labels = onehotbatch(MNIST.labels(), 0:9)

# Partition into batches of size 32
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
         for i in partition(1:60_000, 32)]

train = gpu.(train)

# Prepare the test set (all 10,000 images; the commented-out line uses only the first 1,000)
#tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tX = reshape(reduce(hcat, vec.(float.(MNIST.images(:test)))),28,28,1,10_000) |> gpu
tY = onehotbatch(MNIST.labels(:test), 0:9) |> gpu

trainX = reshape(reduce(hcat, vec.(float.(MNIST.images()))),28,28,1,60_000) |> gpu
trainY = onehotbatch(MNIST.labels(), 0:9) |> gpu

m = Chain(
    Conv((3, 3), 1=>32, relu),
    x -> maxpool(x, (2,2)),
    Conv((3, 3), 32=>16, relu),
    x -> maxpool(x, (2,2)),
    Conv((3, 3), 16=>10, relu),
    x -> reshape(x, :, size(x, 4)),
    Dense(90, 10), softmax) |> gpu
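The 4-D batch layout and the `90` in `Dense(90, 10)` both follow from shape bookkeeping. The sketch below redoes that arithmetic in plain Julia (no Flux needed), assuming unpadded stride-1 3×3 convolutions and 2×2 max-pooling with stride 2, as in the Chain above; `conv_out`, `pool_out` and `fake_imgs` are hypothetical names used only for illustration.

```julia
# Shape bookkeeping for the MNIST Chain above (plain Julia, no Flux required).

# Spatial size after an unpadded, stride-1 convolution with a k×k kernel.
conv_out(n, k) = n - k + 1

# Spatial size after k×k max-pooling with stride k and no padding (odd sizes are floored).
pool_out(n, k) = (n - k) ÷ k + 1

# One training batch: 32 random stand-ins for 28×28 grayscale images, stacked
# along dimension 4 the same way cat(float.(imgs[i])..., dims = 4) does.
fake_imgs = [rand(Float32, 28, 28) for _ in 1:32]
batch = cat(fake_imgs..., dims = 4)
println(size(batch))   # (28, 28, 1, 32): width, height, channels, batch

# Follow the spatial size through the three Conv/maxpool stages.
n = 28
n = pool_out(conv_out(n, 3), 2)   # Conv 1=>32, then maxpool: 28 -> 26 -> 13
n = pool_out(conv_out(n, 3), 2)   # Conv 32=>16, then maxpool: 13 -> 11 -> 5
n = conv_out(n, 3)                # Conv 16=>10: 5 -> 3

# reshape(x, :, size(x, 4)) flattens width × height × channels for each image.
flat_features = n * n * 10
println(flat_features)   # 90, matching Dense(90, 10)
```

If the input size or kernel sizes change, recomputing `flat_features` this way shows what the first argument of the final `Dense` has to be.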

I then trained it:


loss(x, y) = crossentropy(m(x), y)

accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

opt = ADAM()

evalcb = throttle(() -> @show(accuracy(tX, tY)), 5)
@time Flux.@epochs 40 Flux.train!(loss, params(m), train, opt, cb = evalcb)
accuracy(tX, tY)
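The `accuracy` function above compares, column by column, the index of the largest predicted probability with the index of the 1 in the one-hot target. A minimal plain-Julia sketch of that logic on a tiny hand-made example, assuming one column per sample; `onecold_sketch` and `accuracy_sketch` are hypothetical stand-ins, not Flux's `onecold` itself.

```julia
using Statistics: mean

# Label of the largest entry in each column (one column per sample).
onecold_sketch(M::AbstractMatrix, labels = 1:size(M, 1)) =
    [labels[argmax(col)] for col in eachcol(M)]

# Fraction of samples whose predicted class matches the true class.
accuracy_sketch(probs, onehot) = mean(onecold_sketch(probs) .== onecold_sketch(onehot))

# Three samples, three classes: rows are class scores, columns are samples.
probs = [0.7 0.1 0.2;
         0.2 0.8 0.3;
         0.1 0.1 0.5]
truth = [1 0 0;
         0 1 1;
         0 0 0]   # true classes: 1, 2, 2

println(onecold_sketch(probs))          # [1, 2, 3]
println(accuracy_sketch(probs, truth))  # 0.6666666666666666
```

Passing `0:2` as `labels` returns class labels instead of indices, analogous to the `0:9` labels used with `onehotbatch` in the question.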

So far so good. But m is defined with lots of dropout, and when I use the model for prediction I no longer want dropout to be applied. In other words, I want to run m with a dropout rate of 0: keep exactly the weights I have trained, but remove the dropout layers.
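Turning dropout off does not require touching the trained weights. With inverted dropout (the scheme Flux's dropout uses, as far as I know: surviving activations are rescaled by 1/(1 - p) during training), the expected output of the layer already equals its input, so at evaluation time the layer can simply behave as the identity. A minimal plain-Julia sketch; `dropout_sketch` and its `training` keyword are hypothetical names for illustration, not Flux's API.

```julia
using Random, Statistics

# Inverted dropout: in training mode, zero each activation with probability p
# and scale the survivors by 1/(1 - p) so the expected value is unchanged.
# In evaluation mode (training = false) the layer is just the identity.
function dropout_sketch(x::AbstractArray, p::Real; training::Bool, rng = Random.default_rng())
    training || return x
    keep = rand(rng, size(x)...) .> p      # true with probability 1 - p
    return x .* keep ./ (1 - p)
end

rng = MersenneTwister(42)
x = ones(100_000)

y_train = dropout_sketch(x, 0.5; training = true, rng = rng)
y_eval  = dropout_sketch(x, 0.5; training = false)

println(count(iszero, y_train) / length(x))  # close to 0.5: about half the units dropped
println(mean(y_train))                       # close to 1.0: rescaling preserves the mean
println(y_eval == x)                         # true: evaluation leaves activations untouched
```

Because of the rescaling, "dropout rate 0" and "no dropout layer" give the same forward pass, which is why evaluation can reuse the trained weights unchanged.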

How can I easily achieve that with the Flux framework? Is there a function such as m1 = remove_drops(m) that can do that?
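Because the trained weights live in the other layers, a `remove_drops` of the kind imagined here could rebuild the chain from every layer except the dropout ones. The sketch below uses hypothetical toy types (`ToyDense`, `ToyDropout`, `ToyChain`) so that it runs without Flux; doing the same with a real Flux `Chain` would assume the installed version exposes its layers through a `layers` field, which should be checked against that version's documentation.

```julia
using Random

# Toy stand-ins for Flux layers (hypothetical; not Flux's actual types).
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end
(d::ToyDense)(x) = d.W * x .+ d.b

struct ToyDropout
    p::Float64
end
# Training-time behaviour: inverted dropout (zero with prob. p, rescale survivors).
(d::ToyDropout)(x) = x .* (rand(length(x)) .> d.p) ./ (1 - d.p)

struct ToyChain{T<:Tuple}
    layers::T
end
ToyChain(layers...) = ToyChain(layers)
# Apply the layers left to right.
(c::ToyChain)(x) = foldl((acc, layer) -> layer(acc), c.layers; init = x)

# Rebuild the chain without its dropout layers; the remaining layer objects
# (and therefore the trained weights) are shared, not copied.
remove_drops(c::ToyChain) = ToyChain(Tuple(l for l in c.layers if !(l isa ToyDropout)))

Random.seed!(1)
m = ToyChain(ToyDense(randn(4, 3), zeros(4)), ToyDropout(0.5), ToyDense(randn(2, 4), zeros(2)))
m1 = remove_drops(m)

x = [1.0, 2.0, 3.0]
println(length(m1.layers))                      # 2
println(m1.layers[1] === m.layers[1])           # true: same trained layer object
println(m1(x) == m.layers[3](m.layers[1](x)))   # true: deterministic forward pass
```

The trade-off is that `m1` has a different structure from `m`; toggling a mode flag instead keeps a single model that can go back to training later.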

The REPL help for Flux.testmode! describes the relevant switch:

help?> Flux.testmode!
  testmode!(m, false)

  Put layers like Dropout and BatchNorm into testing mode (or back to training
  mode with false).
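Following the docstring above, the usual pattern is to keep the layers and flip them into testing mode before evaluating, then back with `false` before training again; with Flux this would be `Flux.testmode!(m)` before `accuracy(tX, tY)` and `Flux.testmode!(m, false)` afterwards. Depending on the version this may not even be needed: as I understand it, in Zygote-based Flux releases (0.10 onwards) dropout is by default active only inside gradient computations, so check the docstring of the installed version. The sketch below mimics the mode-flag pattern in plain Julia with a hypothetical mutable `FlaggedDropout` and a hypothetical `toy_testmode!`; it is not Flux's implementation.

```julia
using Random

# Hypothetical dropout layer with a mode flag, loosely modeled on the docstring above.
mutable struct FlaggedDropout
    p::Float64
    active::Bool   # true = training mode (drop units), false = testing mode (identity)
end
FlaggedDropout(p) = FlaggedDropout(p, true)

function (d::FlaggedDropout)(x)
    d.active || return x
    return x .* (rand(size(x)...) .> d.p) ./ (1 - d.p)
end

# Put every dropout layer into testing mode (mode = true) or back into training mode (false).
function toy_testmode!(layers, mode::Bool = true)
    for l in layers
        l isa FlaggedDropout && (l.active = !mode)
    end
    return layers
end

# Apply the layers left to right.
run_layers(layers, x) = foldl((acc, l) -> l(acc), layers; init = x)

Random.seed!(0)
W = randn(3, 3)
layers = (x -> W * x, FlaggedDropout(0.5), x -> max.(x, 0))

x = [1.0, -2.0, 0.5]
toy_testmode!(layers)              # evaluate with dropout disabled
y1 = run_layers(layers, x)
y2 = run_layers(layers, x)
println(y1 == y2)                  # true: evaluation is deterministic

toy_testmode!(layers, false)       # back to training mode before further training
println(layers[2].active)          # true
```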