Hi, I’m struggling to use Flux.jl with data sampled from distributions.
For example, I’m trying to train my model with the following steps:
step 1) Construct an NN model:
module Agents

using Flux

export MLPAgent, get

## MLPAgent
mutable struct MLPAgent
    model
    MLPAgent(args...; kwargs...) = init!(new(), args...; kwargs...)
end

"""
# Parameters
n_x: input size.
n_y: output size.
hidden_nodes: an array whose elements indicate the numbers of hidden-layer nodes.
activation: activation function applied to all layers.
"""
function init!(agt::MLPAgent, n_x, n_y;
               hidden_nodes=[128], activation=σ)
    layer_nodes = [n_x, hidden_nodes..., n_y]
    layers = _stack_network(layer_nodes, activation)
    model = Chain(layers...)
    agt.model = model
    return agt
end

# build a list of Dense layers connecting consecutive node counts
function _stack_network(layer_nodes, activation)
    layers = []
    for i in 1:length(layer_nodes)-1
        push!(layers, Dense(layer_nodes[i:i+1]..., activation))
    end
    return layers
end

# forward pass through the wrapped model
function Base.get(agt::MLPAgent, x)
    return agt.model(x)
end

end  # module
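For reference, the module can be exercised like this (just an illustrative sanity check, not part of the question):

using Agents
agt = MLPAgent(3, 2, hidden_nodes=[16, 16])
y = get(agt, rand(Float32, 3))  # forward pass through 3 → 16 → 16 → 2; returns a length-2 vector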
step 2) The output of the model is used as the parameters of a multivariate normal distribution. Then a point is sampled from the distribution, and the logpdf of the distribution at that point is evaluated in the loss function. (Note: the loss shown below may be meaningless; it’s just for testing.)
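To be concrete, here is a minimal sketch of this step in isolation (it runs fine as long as no gradient is taken):

using Flux, Distributions, LinearAlgebra
model = Chain(Dense(3, 3, σ))
x = rand(Float32, 3)
d = MvNormal(Diagonal(model(x)))  # zero-mean MvNormal; diagonal covariance from the model output (σ keeps it positive)
z = rand(d)                       # sample a point from d
logpdf(d, z)                      # evaluate the log-density at the sampled point

My full test code is below: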
using Agents
using Random
using Flux
using Flux: @epochs
using Flux.Data: DataLoader
using LinearAlgebra
using Distributions

function test_Distributions()
    Random.seed!(2021)
    η = 0.1
    opt = Descent(η)
    n_x, n_y = 3, 3
    agt = MLPAgent(n_x, n_y, hidden_nodes=[128])
    x = rand(Float32, n_x)  # a sample input; d is built once from it here
    d = MvNormal(Diagonal(agt.model(x)))  # zero-mean MvNormal with diagonal covariance
    loss(x, y) = Flux.Losses.mse(logpdf(d, rand(d, 1)), y)
    ps = Flux.params(agt.model)
    X = rand(3, 6000)
    Y = rand(3, 6000)
    train_x, test_x = _partitionTrainTest(X)
    train_y, test_y = _partitionTrainTest(Y)
    data = DataLoader(train_x, train_y, batchsize=128)
    cb = function ()
        @show(loss(test_x, test_y))
        loss(test_x, test_y) < 0.08 && Flux.stop()
    end
    @epochs 10 Flux.train!(loss, ps, data, opt, cb=cb)
end
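For completeness, _partitionTrainTest simply splits the columns of a data matrix into train and test sets; its exact definition shouldn’t matter for the error, but a version like this is assumed:

# split columns into train/test sets at a given ratio
function _partitionTrainTest(data; at=0.9)
    n = size(data, 2)
    idx = shuffle(1:n)  # shuffle is from Random
    n_train = round(Int, at * n)
    return data[:, idx[1:n_train]], data[:, idx[n_train+1:end]]
end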
However, it doesn’t seem to work; it fails with the error message shown below:
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#368#369")(::Nothing) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/array.jl:61
[3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] materialize! at ./broadcast.jl:848 [inlined]
[5] materialize! at ./broadcast.jl:845 [inlined]
[6] materialize! at ./broadcast.jl:841 [inlined]
[7] broadcast! at ./broadcast.jl:814 [inlined]
[8] (::typeof(∂(broadcast!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[9] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/pdiagmat.jl:89 [inlined]
[10] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[11] unwhiten! at /home/jinrae/.julia/packages/PDMats/G0Prn/src/generics.jl:33 [inlined]
[12] (::typeof(∂(unwhiten!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[13] _rand! at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariate/mvnormal.jl:275 [inlined]
[14] (::typeof(∂(_rand!)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[15] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:70 [inlined]
[16] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[17] rand at /home/jinrae/.julia/packages/Distributions/YaVqp/src/multivariates.jl:69 [inlined]
[18] (::typeof(∂(rand)))(::Array{Float64,2}) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[19] loss at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:20 [inlined]
[20] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[21] (::Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
[22] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(λ)),Tuple{Tuple{Nothing,Nothing}}}})(::Float64) at /home/jinrae/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[23] #15 at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
[24] (::typeof(∂(λ)))(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[25] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
[26] gradient(::Function, ::Zygote.Params) at /home/jinrae/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
[27] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
[28] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::Descent; cb::var"#197#200"{var"#loss#199"{MvNormal{Float32,PDMats.PDiagMat{Float32,Array{Float32,1}},FillArrays.Zeros{Float32,1,Tuple{Base.OneTo{Int64}}}}},Array{Float64,2},Array{Float64,2}}) at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:80
[30] macro expansion at /home/jinrae/.julia/packages/Flux/05b38/src/optimise/train.jl:115 [inlined]
[31] macro expansion at /home/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[32] test_Distributions() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:31
[33] test_all() at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:67
[34] top-level scope at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
[35] include(::String) at ./client.jl:457
[36] top-level scope at REPL[41]:1
in expression starting at /home/jinrae/.julia/dev/GliderPathPlanning/test/learning.jl:70
So my question is:
What’s the best way to do what I’m trying to do? From the stack trace, the failure seems to come from Zygote trying to differentiate through Distributions’ in-place sampling (unwhiten! inside rand of MvNormal), which mutates arrays.
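For instance, would a reparameterization-style workaround, where the sample is written as a deterministic function of the model output and externally drawn noise, be the idiomatic fix? A rough, untested sketch of what I have in mind (the loss_reparam name and the hand-written log-density are mine, just to illustrate the zero-mean diagonal case above):

# reparameterized loss: gradients flow through v via z = sqrt.(v) .* ε,
# while ε is plain noise that Zygote should treat as a constant
function loss_reparam(x, y)
    v = agt.model(x)                # diagonal variances (positive, thanks to σ)
    ε = randn(Float32, size(v))     # standard-normal noise, drawn outside Distributions
    z = sqrt.(v) .* ε               # reparameterized sample, z ~ N(0, Diagonal(v))
    logp = -0.5f0 .* (log.(2f0π .* v) .+ z .^ 2 ./ v)  # elementwise log-density of the diagonal Gaussian
    return Flux.Losses.mse(sum(logp), y)  # same placeholder loss shape as before
end

(If Zygote still objects to randn here, I suppose wrapping it in Zygote.ignore would make it explicit that the noise is a constant with respect to the parameters.)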