@JHall I am struggling with the same issue at the moment. Based on the examples provided in the Flux documentation (Link), I rewrote your code.

What did the trick for me is making sure that the loss function called during each iteration of train! takes the neural network itself as an argument, so that changes in the network's parameters affect the loss that is produced.

```
#First we import packages
using Flux, CUDA, Random, Plots, Distributions
#set model parameters
β = 0.9 #common discount rate
γ = 2.0 #CRRA (constant relative risk aversion) coefficient
rbar = 1.04 #common gross interest rate
#set AR(1) parameters
#standard deviations of the innovations
σ_r = 0.001 #idiosyncratic interest rate shocks
σ_p = 0.0001 #permanent component of idiosyncratic income
σ_q = 0.001 #other portion of idiosyncratic income
σ_δ = 0.001 #discount rate shocks
#AR(1) persistence coefficients
ρ_r = 0.2
ρ_p = 0.999
ρ_q = 0.9
ρ_δ = 0.2
#ergodic (unconditional) standard deviations of the AR(1) processes,
#used below in dr() to scale the network inputs to order one
σ_e_r = σ_r/sqrt(1-ρ_r^2)
σ_e_p = σ_p/sqrt(1-ρ_p^2)
σ_e_q = σ_q/sqrt(1-ρ_q^2)
σ_e_δ = σ_δ/sqrt(1-ρ_δ^2)
#set the bounds for the wage/wealth grid (w is treated as cash-in-hand below)
w_min = 0.1
w_max = 4.0
#The KKT complementarity conditions are folded into a single smooth equation
#via the Fischer-Burmeister function: fb(a, b) = 0  iff  a ≥ 0, b ≥ 0, a*b = 0.
fb(a, b) = a + b - sqrt(a^2 + b^2)
#now we use flux to construct the neural network
model = Chain(
Dense(5 => 32, relu), #input: the 5 normalized states (r, δ, q, p, w) — see dr()
Dense(32 => 32,relu),
Dense(32 => 32,relu),
Dense(32 => 2)) #the last layer we don't specify an activation function because it defaults to the identity map; row 1 feeds sigmoid (consumption share ξ), row 2 feeds exp (multiplier h)
"""
    dr(model, r, δ, q, p, w)

Decision rule: normalize the five state variables, feed them through the
neural network `model`, and return the tuple `(ξ, h)` where `ξ ∈ (0, 1)` is
the consumption share of cash-in-hand and `h > 0` is the Euler-equation
multiplier. Relies on the globals σ_e_*, w_min, w_max defined above.
"""
function dr(model, r, δ, q, p, w)
    #scale each AR(1) state by half its ergodic standard deviation so the
    #network sees inputs of order one
    r = r/(σ_e_r/2)
    δ = δ/(σ_e_δ/2)
    p = p/(σ_e_p/2)
    q = q/(σ_e_q/2)
    #map wealth from [w_min, w_max] onto [-1, 1]
    #BUGFIX: the original divided by (w_max - w_min)*2, which maps the grid
    #onto [-1, -0.5]; multiplying by 2 after the division gives the intended
    #[-1, 1] range
    w = (w .- w_min) ./ (w_max - w_min) .* 2.0 .- 1.0
    #stack the five states into a 5×n matrix (one column per sample)
    s = [r δ q p w]'
    #apply the neural net
    x = model(s)
    #sigmoid keeps the consumption share of cash-in-hand in (0, 1)
    ξ = sigmoid.(x[1, :])
    #exp keeps the multiplier strictly positive
    h = exp.(x[2, :])
    return (ξ, h)
end
#select number of grid points
N_wage = 100
#create a grid of wages (wealth levels) for plotting the policy
wealth = range(w_min,w_max,N_wage)
#evaluate the (untrained) policy with all AR(1) states at their mean of zero
ξ_vec, h_vec = dr(model, wealth*0,wealth*0,wealth*0,wealth*0,wealth)
plot(wealth,wealth.*ξ_vec);
plot!(wealth,wealth); #45-degree line: consuming all of wealth
display(plot!(title = "Consumption policy", xlabel = "Wealth", ylabel = "Consumption"));
"""
    Residuals(model, e_r, e_δ, e_q, e_p, r, δ, q, p, w)

Compute the Euler-equation residual `R1` and the Fischer-Burmeister
complementarity residual `R2` for a batch of current states
`(r, δ, q, p, w)` and next-period innovations `(e_r, e_δ, e_q, e_p)`.
Returns the tuple `(R1, R2)`. Relies on the globals β, γ, rbar, ρ_*.
"""
function Residuals(model, e_r, e_δ, e_q, e_p, r, δ, q, p, w)
    #policy today: consumption share ξ and multiplier h
    ξ, h = dr(model, r, δ, q, p, w)
    c = ξ .* w
    #AR(1) transitions to tomorrow's states
    r_prime = r * ρ_r + e_r
    δ_prime = δ * ρ_δ + e_δ
    p_prime = p * ρ_p + e_p
    q_prime = q * ρ_q + e_q
    #cash-in-hand tomorrow: income exp(p')exp(q') plus gross return on savings
    w_prime = exp.(p_prime) .* exp.(q_prime) .+ (w .- c) .* rbar .* exp.(r_prime)
    #policy tomorrow
    ξ_prime, h_prime = dr(model, r_prime, δ_prime, q_prime, p_prime, w_prime)
    c_prime = ξ_prime .* w_prime
    #Euler residual
    R1 = β * exp.(δ_prime .- δ) .* (c_prime ./ c) .^ (-γ) .* (rbar .* exp.(r_prime)) .- h
    #Fischer-Burmeister residual; `1 .- h` fuses the broadcast and avoids the
    #two temporary ones(n) vectors the original allocated
    R2 = fb.(1 .- h, 1 .- ξ)
    return (R1, R2)
end
###Initialize distributions for the draws function
#distributions for the current states, sampled from their ergodic laws
r_dist = Normal(0,σ_e_r)
δ_dist = Normal(0,σ_e_δ)
p_dist = Normal(0,σ_e_p)
q_dist = Normal(0,σ_e_q)
w_dist = Uniform(w_min,w_max) #wealth drawn uniformly on its grid bounds
#distributions for the innovations to the state tomorrow (per-period shocks)
er_dist = Normal(0,σ_r)
eδ_dist = Normal(0,σ_δ)
ep_dist = Normal(0,σ_p)
eq_dist = Normal(0,σ_q)
"""
    Draws(model, n = 100)

Monte Carlo objective: sample `n` current states from their ergodic
distributions and two independent sets of next-period innovations, evaluate
the residuals under each innovation set, and return the elementwise product
of the two Euler residuals plus the product of the two Fischer-Burmeister
residuals (the all-in-one expectation estimator). Relies on the global
distributions defined above.
"""
function Draws(model, n = 100)
    #sample the current states
    state_r = rand(r_dist, n)
    state_δ = rand(δ_dist, n)
    state_p = rand(p_dist, n)
    state_q = rand(q_dist, n)
    state_w = rand(w_dist, n)
    #first independent set of next-period innovations
    eps_r1 = rand(er_dist, n)
    eps_δ1 = rand(eδ_dist, n)
    eps_p1 = rand(ep_dist, n)
    eps_q1 = rand(eq_dist, n)
    #second independent set of next-period innovations
    eps_r2 = rand(er_dist, n)
    eps_δ2 = rand(eδ_dist, n)
    eps_p2 = rand(ep_dist, n)
    eps_q2 = rand(eq_dist, n)
    #residuals under each innovation set
    R1_a, R2_a = Residuals(model, eps_r1, eps_δ1, eps_q1, eps_p1,
                           state_r, state_δ, state_q, state_p, state_w)
    R1_b, R2_b = Residuals(model, eps_r2, eps_δ2, eps_q2, eps_p2,
                           state_r, state_δ, state_q, state_p, state_w)
    #cross-products remove the squared-noise bias of a single draw
    return R1_a .* R1_b .+ R2_a .* R2_b
end
#NOTE: the original stored `pars = Flux.params(model)` here, but that
#implicit-parameter handle is never used with the explicit Flux.setup/train!
#API below, so the dead assignment is removed.
#smoke-test the draws function
Draws(model, 10000)
opt = Flux.setup(Flux.ADAM(0.01), model);
#train! expects loss(model, x, y); x and y are ignored dummies because
#Draws generates a fresh Monte Carlo sample on every call
loss(nn, x, y) = sum(Draws(nn))
# testing the loss function
loss(model, 1, 1)
#create a grid of wages for the in-training policy plots
wealth = range(w_min,w_max,N_wage)
@time for i in 1:10000
# training the model; the (1,1) data tuple is a dummy — Draws resamples internally
Flux.train!(loss, model, [(1,1)], opt)
if i % 500 == 0
# print the loss every 500 iterations
println("Iter: ", i, " with loss ", loss(model, 1, 1))
# plot the current consumption policy (AR(1) states held at their mean of zero)
ξ_vec, h_vec = dr(model, wealth*0,wealth*0,wealth*0,wealth*0,wealth)
plot(wealth,wealth.*ξ_vec)
plot!(wealth,wealth)
display(plot!(title = "Consumption policy", xlabel = "Wealth", ylabel = "Consumption"))
end
end
# evaluating the loss function after training
loss(model, 1, 1)
# plot the computed consumption policy (AR(1) states held at their mean of zero)
ξ_vec, h_vec = dr(model, wealth*0,wealth*0,wealth*0,wealth*0,wealth)
plot(wealth,wealth.*ξ_vec);
plot!(wealth,wealth); #45-degree line: consuming all of wealth
display(plot!(title = "Consumption policy", xlabel = "Wealth", ylabel = "Consumption"));
```

Compared to the Python version, this code seems very slow, which I attribute to my rather ad-hoc implementation. If you want to benchmark it and look for the bottlenecks, I am happy to learn from your results!