Hello people. I am trying to implement an actor critic network using Flux. However for some reason my network is not being updated. Here is a sample from the code.
state_dim1 = 14
output_dim = 1
function actor_model(state_dim)
return Chain(
Dense(state_dim, 100),
Dense(100, 200),
Dense(200, 150),
# Dense(150, 150),
Dense(150,1,tanh))
end
struct Join{T, F}
combine::F
paths::T
end
# allow Join(op, m1, m2, ...) as a constructor
Join(combine, paths...) = Join(combine, paths)
Flux.@functor Join
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)
function critic_model(state_dim,output_dim)
return Chain(
Join(vcat,
Chain(Dense(state_dim => 100, σ), Dense(100 => 64)), # branch 1
Dense(output_dim => 64, tanh) # branch 2
),
Dense(128,84,relu),Dense(84,10,relu),Dense(10,1,relu)
)
end
critic = gpu(critic_model(state_dim1,output_dim))
target_critic = gpu(critic_model(state_dim1,output_dim))
actor = gpu(actor_model(state_dim1))
target_actor = gpu(actor_model(state_dim1))
println(Flux.params(actor)[1,1][8:16],":Initial params actor")
# println(Flux.params(critic)[1,1][1:8],":Initial params critic")
s1 = [52.256588,3.001099,47.256588,1.0010991,52.815,2.0,52.565,2.0,52.315,2.0,52.065,2.0,51.815,2.0] |>gpu
a1 = [-0.43706045] |>gpu
r1 = [-5.0] |>gpu
s2 = [52.222652,3.0011091,47.222652,1.0011091,52.815,2.0,52.565,2.0,52.315,2.0,52.065,2.0,51.815,2.0] |>gpu
a2 = target_actor(s2)
next_val = target_critic((s2,a2))
y_expected = r1 .+ 0.2.*next_val
#### critic update
critic_loss(x,y) = Flux.mse(x, y)
prms_critic = Flux.params(critic)
opt = Flux.Adam()
data_critic = [(s1,a1,y_expected)] |> gpu
Flux.train!((x,y,z) -> critic_loss(critic((x,y)),z), prms_critic, data_critic, opt)
#### actor update
actor_loss(x,y) = -1*sum(critic((x,y)))
opt = Adam(0.5)
prms_actor = Flux.params(actor)
data_actor = [(s1,s1)] |> gpu
Flux.train!((x,y) -> actor_loss(x,actor(y)), prms_actor, data_actor, opt)
println(Flux.params(actor)[1,1][8:16])
# println(Flux.params(critic)[1,1][1:8])
Julia Version 1.6.7
I am not able to figure out the error in this case. Any help is appreciated. Thanks