Apply boolean mask in loss function

Is there a way to apply a boolean mask to multidimensional arrays in a loss function using Flux? I’ve been unable to achieve this on the GPU. I’m using Julia 1.3 and Flux v0.10.4.

Consider the following MWE (this is just the autoencoder from the model zoo — the only thing I’ve modified is the loss function):

# Encode MNIST images as compressed vectors that can later be decoded back into
# images.
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition
using Parameters: @with_kw
using CUDAapi
# Load GPU support only when `has_cuda()` reports that CUDA is available.
if has_cuda()
    @info "CUDA is on"
    import CuArrays
    # Disallow scalar indexing of GPU arrays; this is the setting behind the
    # "scalar getindex is disallowed" error reported further down.
    CuArrays.allowscalar(false)
end

# Settings for the autoencoder example. Every field has a default; `train`
# builds this from its keyword arguments via `Args(; kws...)`.
@with_kw mutable struct Args
    lr::Float64 = 1e-3		# Learning rate handed to the ADAM optimiser
    epochs::Int = 2		# Number of passes over the training data
    N::Int = 32			# Size of the encoding (output width of the encoder)
    batchsize::Int = 4	# Number of images per training batch
    sample_len::Int = 20 	# Number of random digits in the sample image
    throttle::Int = 5		# Timeout given to `throttle` for the loss-reporting callback
end

# Load the MNIST images and return them as a vector of training batches.
#
# Each batch is a float matrix holding one flattened image per column, with at
# most `args.batchsize` columns, and every batch is passed through `gpu`.
function get_processed_data(args)
    # Load the images and convert them to plain numeric arrays (channel view)
    imgs = channelview.(MNIST.images())
    # Flatten every image into a column and join `args.batchsize` of them per
    # batch. `reduce(hcat, ...)` avoids splatting a whole batch into `hcat`.
    train_data = [
        float(reduce(hcat, vec.(batch))) for batch in partition(imgs, args.batchsize)
    ]
    return gpu.(train_data)
end

# Mean squared error between `x` and `y`, evaluated only on the entries picked
# by a fresh random Boolean mask with the same size as `x`.
function custom_loss(x,y)
    # The mask is created with `rand` and then passed through `gpu`
    mask = gpu(rand(Bool, size(x)))
    # Logical indexing keeps the entries where `mask` is true; dividing by their
    # count turns the sum of squared differences into a mean.
    mean_sq_error = sum((x[mask] .- y[mask]) .^ 2)/length(x[mask])

    return mean_sq_error 
end

# Build the dense autoencoder, train it on the MNIST batches and return the
# trained model together with the settings that were used.
#
# Keyword arguments are forwarded to `Args`.
function train(; kws...)
    args = Args(; kws...)
    train_data = get_processed_data(args)

    @info("Constructing model......")
    # The encoder maps a 28^2-pixel input down to `args.N` values and the decoder
    # maps that code back to 28^2 values, so the code is a compressed (lossy)
    # representation of the input. Larger encoder/decoder networks can be
    # substituted here.
    m = Chain(
        gpu(Dense(28^2, args.N, leakyrelu)),
        gpu(Dense(args.N, 28^2, leakyrelu)),
    )

    @info("Training model.....")
    # Reconstruction loss: compare the model output against its own input
    loss(x) = custom_loss(m(x), x)
    opt = ADAM(args.lr)
    # Loss report on the first batch, rate-limited through `throttle`
    evalcb = throttle(args.throttle) do
        @show(loss(train_data[1]))
    end

    @epochs args.epochs Flux.train!(
        loss, params(m), zip(train_data), opt; cb = evalcb
    )

    return m, args
end

using Images

# Turn a flat vector of 28*28 pixel values into a 28×28 grayscale image,
# clamping every value to the range [0, 1] first.
function img(x::Vector)
    pixels = reshape(clamp.(x, 0, 1), 28, 28)
    return Gray.(pixels)
end

# Build a comparison image: `args.sample_len` randomly chosen MNIST digits,
# each stacked on top of the model's reconstruction of it, laid out side by side.
function sample(m, args)
    # Load the images and convert them to plain numeric arrays (channel view)
    imgs = channelview.(MNIST.images())
    # `args.sample_len` random digits (indices are drawn independently, so a
    # digit may appear more than once)
    before = [imgs[i] for i in rand(1:length(imgs), args.sample_len)]
    # Move the model to the CPU once instead of once per image
    cpu_model = cpu(m)
    # Reconstruct every digit and convert the output back into an image
    after = map(x -> img(cpu_model(float(vec(x)))), before)
    # Original above reconstruction for each digit, then all pairs side by side
    return hcat(vcat.(before, after)...)
end

# Switch to the script's own directory so the relative output path below
# resolves next to the script.
cd(@__DIR__)
# Train with the default settings
m, args= train()
# Sample output
@info("Saving image sample as sample_ae.png")
# NOTE(review): `save` is presumably made available by `using Images` — confirm
save("sample_ae.png", sample(m, args))

In particular, this is the loss function:

# Masked mean squared error: only the entries selected by a random Boolean
# mask contribute to the loss.
function custom_loss(x,y)
    mask = gpu(rand(Bool, size(x)))
    # Logical indexing with a Bool GPU array. In the trace below the error is
    # raised inside Zygote's pullback for this indexing, where a `view` is built
    # from the Bool CuArray mask and the mask is then iterated element by element.
    mean_sq_error = sum((x[mask] .- y[mask]) .^ 2)/length(x[mask])
    
    return mean_sq_error 
end

Running this, I receive the error:

ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at <home_dir>/.julia/packages/GPUArrays/OXvxB/src/host/indexing.jl:41
 [3] getindex at <home_dir>/.julia/packages/GPUArrays/OXvxB/src/host/indexing.jl:96 [inlined]
 [4] _getindex at ./abstractarray.jl:1004 [inlined]
 [5] getindex at ./abstractarray.jl:981 [inlined]
 [6] iterate at ./multidimensional.jl:593 [inlined]
 [7] iterate at ./multidimensional.jl:584 [inlined]
 [8] iterate at ./generator.jl:44 [inlined]
 [9] collect(::Base.Generator{Base.LogicalIndex{CartesianIndex{2},CuArrays.CuArray{Bool,2,Nothing}},Base.var"#409#410"}) at ./array.jl:622
 [10] collect at ./multidimensional.jl:571 [inlined]
 [11] ensure_indexable at ./multidimensional.jl:627 [inlined]
 [12] SubArray at ./subarray.jl:22 [inlined]
 [13] unsafe_view at ./subarray.jl:163 [inlined]
 [14] view at ./subarray.jl:158 [inlined]
 [15] _cuview(::CuArrays.CuArray{Float32,2,Nothing}, ::Tuple{CuArrays.CuArray{Bool,2,Nothing}}, ::CuArrays.NonContiguous) at <home_dir>/.julia/packages/CuArrays/l0gXB/src/subarray.jl:39
 [16] view at <home_dir>/.julia/packages/CuArrays/l0gXB/src/subarray.jl:24 [inlined]
 [17] (::Zygote.var"#1042#1044"{CuArrays.CuArray{Float32,2,Nothing},Tuple{CuArrays.CuArray{Bool,2,Nothing}}})(::CuArrays.CuArray{Float32,1,Nothing}) at <home_dir>/.julia/packages/Zygote/YeCEW/src/lib/array.jl:44
 [18] (::Zygote.var"#2740#back#1038"{Zygote.var"#1042#1044"{CuArrays.CuArray{Float32,2,Nothing},Tuple{CuArrays.CuArray{Bool,2,Nothing}}}})(::CuArrays.CuArray{Float32,1,Nothing}) at <home_dir>/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [19] custom_loss at <run_dir>/mwe_crash.jl:38 [inlined]
 [20] (::typeof(∂(custom_loss)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [21] loss at <run_dir>/mwe_crash.jl:62 [inlined]
 [22] (::typeof(∂(λ)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#172#173"{typeof(∂(λ)),Tuple{Tuple{Nothing}}})(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:171
 [24] #337#back at <home_dir>/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [25] #17 at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [26] (::typeof(∂(λ)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#51#52"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [28] gradient(::Function, ::Zygote.Params) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [29] macro expansion at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [30] macro expansion at <home_dir>/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [31] #train!#12(::Flux.var"#throttled#20"{Flux.var"#throttled#16#21"{Bool,Bool,var"#9#12"{Array{CuArrays.CuArray{Float32,2,Nothing},1},var"#loss#11"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}},Int64}}, ::typeof(Flux.Optimise.train!), ::var"#loss#11"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{CuArrays.CuArray{Float32,2,Nothing},1}}}, ::ADAM) at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [32] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Flux.var"#throttled#20"{Flux.var"#throttled#16#21"{Bool,Bool,var"#9#12"{Array{CuArrays.CuArray{Float32,2,Nothing},1},var"#loss#11"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}},Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{CuArrays.CuArray{Float32,2,Nothing},1}}}, ::ADAM) at ./none:0
 [33] macro expansion at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:122 [inlined]
 [34] macro expansion at <home_dir>/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [35] #train#8(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(train)) at <run_dir>/mwe_crash.jl:67
 [36] train() at <run_dir>/mwe_crash.jl:45
 [37] top-level scope at <run_dir>/mwe_crash.jl:89
 [38] include at ./boot.jl:328 [inlined]
 [39] include_relative(::Module, ::String) at ./loading.jl:1105
 [40] include(::Module, ::String) at ./Base.jl:31
 [41] include(::String) at ./client.jl:424
 [42] top-level scope at REPL[1]:1

So it looks like it’s falling back to a scalar implementation. Instead of explicitly masking, I also tried this variant of the loss function:

# Variant that multiplies by the mask instead of indexing with it.
function custom_loss(x,y)
    # In the trace below the "Mutating arrays is not supported" error is raised
    # while Zygote differentiates through this `rand` call (via `rand!`), not in
    # the masking arithmetic itself.
    # NOTE(review): presumably `.* mask` makes the mask part of the differentiated
    # computation; creating the mask outside the gradient call, or marking it
    # non-differentiable, would likely avoid this — confirm.
    mask = gpu(rand(Bool, size(x)))
    # NOTE(review): the denominator sums the masked values of `x`, whereas the
    # indexing variant divides by the number of selected entries
    # (`length(x[mask])`); `sum(mask)` would give that count.
    mean_sq_error = sum(((x .* mask) .- (y .* mask)) .^ 2)/sum(x .* mask)

    return mean_sq_error 
end

This resulted in the following error:

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#1050#1051")(::Nothing) at <home_dir>/.julia/packages/Zygote/YeCEW/src/lib/array.jl:61
 [3] (::Zygote.var"#2776#back#1052"{Zygote.var"#1050#1051"})(::Nothing) at <home_dir>/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] rand! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/Random/src/Random.jl:271 [inlined]
 [5] (::typeof(∂(rand!)))(::Array{Float32,2}) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [6] rand! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/Random/src/Random.jl:267 [inlined]
 [7] (::typeof(∂(rand!)))(::Array{Float32,2}) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [8] rand at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/Random/src/Random.jl:288 [inlined]
 [9] (::typeof(∂(rand)))(::Array{Float32,2}) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [10] rand at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/Random/src/Random.jl:289 [inlined]
 [11] (::typeof(∂(rand)))(::Array{Float32,2}) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [12] custom_loss at <run_dir>/mwe_crash.jl:36 [inlined]
 [13] (::typeof(∂(custom_loss)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [14] loss at <run_dir>/mwe_crash.jl:62 [inlined]
 [15] (::typeof(∂(λ)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#172#173"{typeof(∂(λ)),Tuple{Tuple{Nothing}}})(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:171
 [17] #337#back at <home_dir>/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [18] #17 at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [19] (::typeof(∂(λ)))(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#51#52"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [21] gradient(::Function, ::Zygote.Params) at <home_dir>/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [22] macro expansion at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [23] macro expansion at <home_dir>/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [24] #train!#12(::Flux.var"#throttled#20"{Flux.var"#throttled#16#21"{Bool,Bool,var"#25#28"{Array{CuArrays.CuArray{Float32,2,Nothing},1},var"#loss#27"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}},Int64}}, ::typeof(Flux.Optimise.train!), ::var"#loss#27"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{CuArrays.CuArray{Float32,2,Nothing},1}}}, ::ADAM) at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [25] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Flux.var"#throttled#20"{Flux.var"#throttled#16#21"{Bool,Bool,var"#25#28"{Array{CuArrays.CuArray{Float32,2,Nothing},1},var"#loss#27"{Chain{Tuple{Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}},Dense{typeof(leakyrelu),CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}}},Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{CuArrays.CuArray{Float32,2,Nothing},1}}}, ::ADAM) at ./none:0
 [26] macro expansion at <home_dir>/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:122 [inlined]
 [27] macro expansion at <home_dir>/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [28] #train#24(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(train)) at <run_dir>/mwe_crash.jl:67
 [29] train() at <run_dir>/mwe_crash.jl:45
 [30] top-level scope at <run_dir>/mwe_crash.jl:89
 [31] include at ./boot.jl:328 [inlined]
 [32] include_relative(::Module, ::String) at ./loading.jl:1105
 [33] include(::Module, ::String) at ./Base.jl:31
 [34] include(::String) at ./client.jl:424
 [35] top-level scope at REPL[1]:1

It doesn’t seem to me like I’m explicitly mutating any arrays here, so I’m guessing there’s something going on with temporaries.

Is this expected behavior? If so, is there a way that I can achieve what I’m trying to do?

Thanks!