Hi, I’m struggling with the use of Flux.jl with sampled data from distributions.
For example, I’m trying to train my model as the following steps:
step 1) construct an NN model
module Agents
using Flux
export MLPAgent, get
## MLPAgent
mutable struct MLPAgent
    model
    MLPAgent(args...; kwargs...) = init!(new(), args...; kwargs...)
end
"""
# Parameters
n_x: input size.
n_y: output size.
hidden_nodes: an array whose elements indicate the number of hidden layer nodes.
activation: activation function for all layers.
"""
function init!(agt::MLPAgent, n_x, n_y;
               hidden_nodes=[128], activation=σ)
    layer_nodes = [n_x, hidden_nodes...,  n_y]
    layers = _stack_network(layer_nodes, activation)
    model = Chain(layers...)
    agt.model = model
    return agt
end
function _stack_network(layer_nodes, activation)
    layers = []
    for i in 1:length(layer_nodes)-1
        push!(layers, Dense(layer_nodes[i:i+1]..., activation))
    end
    return layers
end
function Base.get(agt::MLPAgent, x)
    return agt.model(x)
end
end  # module
step 2) The output of the model is used as the parameter of multivariate normal distribution. Then, a point is sampled from the distribution. Finally, logpdf of the distribution at the point is evaluated for the loss function (Note: the loss shown below may be meaningless, just for test).
using Agents
using Random
using Flux
using Flux: @epochs
using Flux.Data: DataLoader
using LinearAlgebra
using Distributions
function test_Distributions()
    Random.seed!(2021)
    η = 0.1
    opt = Descent(η)
    n_x, n_y = 3, 3
    agt = MLPAgent(n_x, n_y, hidden_nodes=[128])
    d = MvNormal(Diagonal(agt.model(x)))
    loss(x, y) = Flux.Losses.mse(logpdf(d, rand(d, 1)), y)
    ps = Flux.params(agt.model)
    X = rand(3, 6000)
    Y = rand(3, 6000)
    train_x, test_x = _partitionTrainTest(X)
    train_y, test_y = _partitionTrainTest(Y)
    data = DataLoader(train_x, train_y, batchsize = 128)
    cb = function ()
        @show(loss(test_x, test_y))
        loss(test_x, test_y) < 0.08 && Flux.stop()
    end
    @epochs 10 Flux.train!(loss, ps, data, opt, cb = cb)
end
However, it seems not work with the error message shown below:
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#368#369")(::Nothing) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/array.jl:61
 [3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize! at ./broadcast.jl:845 [inlined]
 [6] materialize! at ./broadcast.jl:841 [inlined]
 [7] broadcast! at ./broadcast.jl:814 [inlined]
 [8] (::typeof(∂(broadcast!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [9] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/pdiagmat.jl:89 [inlined]
 [10] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [11] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/generics.jl:33 [inlined]
 [12] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [13] _rand! at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariate/mvnormal.jl:275 [inlined]
 [14] (::typeof(∂(_rand!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [15] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:70 [inlined]
 [16] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [17] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:69 [inlined]
 [18] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] loss at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:20 [inlined]
 [20] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
 [22] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}}})(::Float64) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [23] #15 at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
 [24] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
 [26] gradient(::Function, ::Zygote.Params) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [27] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
 [28] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::Descent; cb::var"#197#200"{var"#loss#199"{MvNormal{Float32,PDMats.PDiagMat{Float32,Array{Float32,1}},FillArrays.Zeros{Float32,1,Tuple{Base.OneTo{Int64}}}}},Array{Float64,2},Array{Float64,2}}) at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:80
 [30] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:115 [inlined]
 [31] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [32] test_Distributions() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:31
 [33] test_all() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:67
 [34] top-level scope at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
 [35] include(::String) at ./client.jl:457
 [36] top-level scope at REPL[41]:1
in expression starting at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
So my question is:
What’s the best way to do what I’m trying to?
