Have no idea how to train neural networks using data sampled from distributions

Hi, I’m struggling to use Flux.jl with data sampled from distributions.

For example, I’m trying to train my model with the following steps:

Step 1) Construct an NN model.

```julia
module Agents

using Flux

export MLPAgent, get

## MLPAgent
mutable struct MLPAgent
    model

    # The inner constructor creates an uninitialised object and lets init! fill it in
    MLPAgent(args...; kwargs...) = init!(new(), args...; kwargs...)
end

"""
# Parameters
n_x: input size.
n_y: output size.
hidden_nodes: an array whose elements indicate the number of hidden layer nodes.
activation: activation function for all layers.
"""
function init!(agt::MLPAgent, n_x, n_y;
               hidden_nodes=[128], activation=σ)
    layer_nodes = [n_x, hidden_nodes...,  n_y]
    layers = _stack_network(layer_nodes, activation)
    model = Chain(layers...)
    agt.model = model
    return agt
end

# One Dense layer per consecutive pair of sizes,
# e.g. [3, 128, 3] -> Dense(3, 128, activation), Dense(128, 3, activation)
function _stack_network(layer_nodes, activation)
    layers = []
    for i in 1:length(layer_nodes)-1
        push!(layers, Dense(layer_nodes[i:i+1]..., activation))
    end
    return layers
end

# Forward pass: extends Base.get so that get(agt, x) evaluates the model
function Base.get(agt::MLPAgent, x)
    return agt.model(x)
end


end  # module
```

Step 2) The output of the model is used as the parameter of a multivariate normal distribution. Then, a point is sampled from the distribution. Finally, the logpdf of the distribution at that point is evaluated for the loss function (note: the loss shown below is probably meaningless; it is just for testing).
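For reference, MvNormal(Diagonal(v)) in Distributions.jl is a zero-mean normal whose covariance matrix is Diagonal(v), so the model output plays the role of the variances. Its log-density is a closed-form, differentiable function of v, with no sampling involved. A minimal stdlib-only sketch of that formula; `diag_gauss_logpdf` is a hypothetical helper written for illustration, not a Distributions.jl function:

```julia
# Log-density of a zero-mean Gaussian with diagonal covariance Diagonal(v),
# i.e. the quantity logpdf(MvNormal(Diagonal(v)), z) is expected to return.
# `diag_gauss_logpdf` is a hypothetical helper, not part of Distributions.jl.
function diag_gauss_logpdf(v::AbstractVector, z::AbstractVector)
    length(v) == length(z) || throw(DimensionMismatch("v and z must have the same length"))
    # Sum over dimensions of log N(z_i | 0, v_i) = -(log(2π v_i) + z_i^2 / v_i) / 2
    return -sum(log.(2π .* v) .+ z .^ 2 ./ v) / 2
end

# Standard normal in 3 dimensions, evaluated at the origin: -1.5 * log(2π)
lp = diag_gauss_logpdf(ones(3), zeros(3))
```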

```julia
using Agents
using Random
using Flux
using Flux: @epochs
using Flux.Data: DataLoader

using LinearAlgebra
using Distributions


function test_Distributions()
    Random.seed!(2021)
    η = 0.1
    opt = Descent(η)
    n_x, n_y = 3, 3
    agt = MLPAgent(n_x, n_y, hidden_nodes=[128])
    d = MvNormal(Diagonal(agt.model(x)))
    loss(x, y) = Flux.Losses.mse(logpdf(d, rand(d, 1)), y)
    ps = Flux.params(agt.model)
    X = rand(3, 6000)
    Y = rand(3, 6000)
    train_x, test_x = _partitionTrainTest(X)
    train_y, test_y = _partitionTrainTest(Y)
    data = DataLoader(train_x, train_y, batchsize = 128)
    cb = function ()
        @show(loss(test_x, test_y))
        loss(test_x, test_y) < 0.08 && Flux.stop()
    end
    @epochs 10 Flux.train!(loss, ps, data, opt, cb = cb)
end
```

However, it doesn’t seem to work; it fails with the error message below:

```
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#368#369")(::Nothing) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/array.jl:61
 [3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize! at ./broadcast.jl:845 [inlined]
 [6] materialize! at ./broadcast.jl:841 [inlined]
 [7] broadcast! at ./broadcast.jl:814 [inlined]
 [8] (::typeof(∂(broadcast!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [9] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/pdiagmat.jl:89 [inlined]
 [10] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [11] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/generics.jl:33 [inlined]
 [12] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [13] _rand! at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariate/mvnormal.jl:275 [inlined]
 [14] (::typeof(∂(_rand!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [15] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:70 [inlined]
 [16] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [17] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:69 [inlined]
 [18] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] loss at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:20 [inlined]
 [20] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
 [22] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}}})(::Float64) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [23] #15 at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
 [24] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
 [26] gradient(::Function, ::Zygote.Params) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [27] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
 [28] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::Descent; cb::var"#197#200"{var"#loss#199"{MvNormal{Float32,PDMats.PDiagMat{Float32,Array{Float32,1}},FillArrays.Zeros{Float32,1,Tuple{Base.OneTo{Int64}}}}},Array{Float64,2},Array{Float64,2}}) at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:80
 [30] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:115 [inlined]
 [31] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [32] test_Distributions() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:31
 [33] test_all() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:67
 [34] top-level scope at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
 [35] include(::String) at ./client.jl:457
 [36] top-level scope at REPL[41]:1
in expression starting at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
```

So my question is:
What’s the best way to do what I’m trying to do?

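Looks like the error comes from trying to take the gradient of rand(d, 1) in the loss function: according to the stack trace, rand fills its output array in place (via _rand! and unwhiten!), and Zygote cannot differentiate through array mutation. It doesn’t look like that gradient is needed, so try wrapping it in nograd or moving it outside the loss function.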
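A minimal sketch of the "move it outside the loss" option, assuming a recent Flux with the explicit `Flux.gradient(loss, model)` style. The small `Chain` stands in for `agt.model`, and `diag_gauss_logpdf` is a hypothetical hand-written log-density used instead of Distributions.jl so that only plain broadcasting is differentiated. The point is that the sample `z` is drawn before the gradient call, so Zygote sees it as a constant:

```julia
using Flux, Random

Random.seed!(2021)

# Zero-mean diagonal Gaussian log-density (hypothetical helper, not a Distributions.jl function)
diag_gauss_logpdf(v, z) = -sum(log.(2f0 * π .* v) .+ z .^ 2 ./ v) / 2

# Stand-in for agt.model; the sigmoid output keeps the variances positive
model = Chain(Dense(3 => 16, σ), Dense(16 => 3, σ))
x = rand(Float32, 3)

# Draw the sample *before* differentiating: this has the same distribution as
# rand(MvNormal(Diagonal(v))), but it happens outside the gradient call.
v = model(x)
z = sqrt.(v) .* randn(Float32, 3)

# Only the log-density evaluation is differentiated; z is a constant here
loss(m) = -diag_gauss_logpdf(m(x), z)
grads = Flux.gradient(loss, model)   # 1-tuple holding the gradient for `model`
```

Because `z` no longer depends on the parameters inside `loss`, no gradient flows through the sampling step; if the loss genuinely needs that gradient, sampling has to be rewritten as a differentiable function of the parameters instead.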

If you need to take the gradient of some part of it, then you could look up the reparametrization trick used in VAEs for inspiration.
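A minimal sketch of the reparameterization trick for the zero-mean diagonal Gaussian in the question, assuming a recent Flux with explicit gradients. Instead of calling rand on a distribution built from the model output, draw standard-normal noise ε (which does not depend on the parameters) and form z = sqrt.(v) .* ε, which is an ordinary differentiable function of the model output v. The `target` and the MSE loss are placeholders chosen only for illustration:

```julia
using Flux, Random

Random.seed!(2021)

model = Chain(Dense(3 => 16, σ), Dense(16 => 3, σ))   # outputs the variances v
x = rand(Float32, 3)
target = rand(Float32, 3)   # placeholder target, for illustration only

# Reparameterization: z = sqrt.(v) .* ε with ε ~ N(0, I).
# z has the same distribution as a draw from N(0, Diagonal(v)),
# but it is a plain (differentiable) function of v.
sample_z(v, ε) = sqrt.(v) .* ε

# Placeholder loss that depends on the *sample* itself
loss(m, ε) = Flux.Losses.mse(sample_z(m(x), ε), target)

# The noise is drawn outside the gradient call; gradients flow through sqrt.(v)
ε = randn(Float32, 3)
grads = Flux.gradient(m -> loss(m, ε), model)
```

In a training loop a fresh ε would be drawn for each step, which makes the gradient a stochastic (but unbiased) estimate with respect to the noise.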

Oh, I should have evaluated the loss at the point x, which is already an argument of the loss function, so I don’t need to sample a point inside the function at all!
What a stupid mistake.
As for the reparameterisation trick, it isn’t necessary for now, but I’ll take a look at what you suggested later. Thanks!
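One way to set this up without any sampling, assuming the goal is a negative log-likelihood in which the model maps x to the variances of a zero-mean diagonal Gaussian and the density is evaluated at the observed y (an assumption; the thread does not pin down the target). The sketch uses the explicit `Flux.setup` / `Flux.update!` training style of recent Flux versions rather than the implicit `Flux.params` / `train!` style from the question; `colwise_logpdf` and `nll` are hypothetical helpers, and the data is random dummy data as in the question:

```julia
using Flux, Random, Statistics

Random.seed!(2021)

# Zero-mean diagonal Gaussian log-density for each column of v and y
# (hypothetical helper, written with broadcasting so Zygote can differentiate it)
colwise_logpdf(v, y) = -sum(log.(2f0 * π .* v) .+ y .^ 2 ./ v; dims=1) ./ 2

# Negative log-likelihood: the density is evaluated at the data, nothing is sampled
nll(m, x, y) = -mean(colwise_logpdf(m(x), y))

model = Chain(Dense(3 => 128, σ), Dense(128 => 3, σ))   # sigmoid output: positive variances
X = rand(Float32, 3, 6000)   # dummy data, as in the question
Y = rand(Float32, 3, 6000)
loader = Flux.DataLoader((X, Y), batchsize=128, shuffle=true)
opt_state = Flux.setup(Descent(0.1), model)

for epoch in 1:10
    for (xb, yb) in loader
        grads = Flux.gradient(m -> nll(m, xb, yb), model)
        Flux.update!(opt_state, model, grads[1])
    end
end
```

Evaluating the density at observed data keeps every operation inside the gradient non-mutating, which is why the "Mutating arrays is not supported" error does not arise here.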