How to use Flux to write a Multi-Head Output Network

I want to use Flux.jl to implement a recurrent SAC algorithm, but I am stuck writing the policy-network module, because it is a multi-head output network: it takes a state matrix as input and outputs both a mean and a log standard deviation. This is easy to do in PyTorch:

class PolicyNetwork(nn.Module):
    def __init__(self, n_states, n_actions, action_bounds, n_hidden_filters=256):
        super(PolicyNetwork, self).__init__()
        self.n_states = n_states
        self.n_hidden_filters = n_hidden_filters
        self.n_actions = n_actions
        self.action_bounds = action_bounds

        self.hidden1 = nn.Linear(in_features=self.n_states, out_features=self.n_hidden_filters)
        init_weight(self.hidden1)       # initialize this layer's weights
        self.hidden1.bias.data.zero_()  # zero the bias
        self.hidden2 = nn.Linear(in_features=self.n_hidden_filters, out_features=self.n_hidden_filters)
        init_weight(self.hidden2)
        self.hidden2.bias.data.zero_()

        self.mu = nn.Linear(in_features=self.n_hidden_filters, out_features=self.n_actions)
        init_weight(self.mu, initializer="xavier uniform")
        self.mu.bias.data.zero_()

        self.log_std = nn.Linear(in_features=self.n_hidden_filters, out_features=self.n_actions)
        init_weight(self.log_std, initializer="xavier uniform")
        self.log_std.bias.data.zero_()

    def forward(self, states):
        x = F.relu(self.hidden1(states))
        x = F.relu(self.hidden2(x))

        mu = self.mu(x)
        log_std = self.log_std(x)   # gives the log standard deviation
        return mu, log_std

But Julia has no classes, so I have no idea how to structure this. Could anybody give me a code example using Flux? Thanks very much!

I asked GPT about this, and it wrote the following code for me:

# Build the policy neural network
struct PolicyNetwork
    hidden1::Chain
    hidden2::Chain
    mu::Dense
    std::Dense
end

function PolicyNetwork(n_state::Int, n_action::Int, n_hidden::Int)
    hidden1 = Chain(
        Dense(n_state => n_hidden, init=Flux.glorot_uniform),
        BatchNorm(n_hidden)
        )
    hidden2 = Chain(
        Dense(n_hidden => n_hidden, init=Flux.glorot_uniform),
        BatchNorm(n_hidden)
    )
    mu = Dense(n_hidden => n_action, init=Flux.glorot_uniform)
    std = Dense(n_hidden => n_action, init=Flux.glorot_uniform)
    return PolicyNetwork(hidden1, hidden2, mu, std)
end

function (network::PolicyNetwork)(states)
    x = network.hidden1(states)
    x = network.hidden2(x)
    mu = network.mu(x)
    std = network.std(x)
    return mu, std
end

## Test the PolicyNetwork
n_state, n_action, n_hidden = 1, 2, 256
network = PolicyNetwork(n_state, n_action, n_hidden)
x = [1f0;;]   # a 1×1 state matrix
mu, std = network(x)

It runs, so maybe this is one way to do it. If you have a better method, I would appreciate it if you could tell me :heart_eyes:

Also, it is very easy to move this network to the GPU; just write it like this:

function PolicyNetwork(n_state::Int, n_action::Int, n_hidden::Int)
    hidden1 = Chain(
        Dense(n_state => n_hidden, init=Flux.glorot_uniform),
        BatchNorm(n_hidden)
        )|>gpu
    hidden2 = Chain(
        Dense(n_hidden => n_hidden, init=Flux.glorot_uniform),
        BatchNorm(n_hidden)
    )|>gpu
    mu = Dense(n_hidden => n_action, init=Flux.glorot_uniform)|>gpu
    std = Dense(n_hidden => n_action, init=Flux.glorot_uniform)|>gpu
    return PolicyNetwork(hidden1, hidden2, mu, std)
end

Everything else stays the same.

What AI doesn’t seem to know is that you need to tell Flux to look for parameters inside, by making the layer with a macro. Otherwise it will not be able to train:

julia> Flux.setup(Adam(), PolicyNetwork(2, 3, 4))
┌ Warning: setup found no trainable parameters in this model
└ @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/interface.jl:32
()

julia> Flux.@layer PolicyNetwork  # Defines methods for functor, show

julia> Flux.setup(Adam(), PolicyNetwork(2, 3, 4))  # now this sees parameters
(hidden1 = (layers = ((weight = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0;

After this, PolicyNetwork(2, 3, 4) |> gpu will move all the parameters, so your second definition should not be needed.
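To see the whole thing end to end, here is a minimal self-contained sketch of a training step with the `Flux.@layer` version. I have simplified the layers to plain `Dense` (no `BatchNorm`), and the squared-sum loss is just a placeholder to show that gradients flow, not the SAC objective:

```julia
using Flux

# Untyped fields for brevity; see the type-parameter note below for the fast version.
struct PolicyNetwork
    hidden1
    hidden2
    mu
    std
end

Flux.@layer PolicyNetwork   # register so Flux finds the parameters inside

function PolicyNetwork(n_state::Int, n_action::Int, n_hidden::Int)
    PolicyNetwork(
        Dense(n_state => n_hidden, relu),
        Dense(n_hidden => n_hidden, relu),
        Dense(n_hidden => n_action),
        Dense(n_hidden => n_action),
    )
end

function (net::PolicyNetwork)(x)
    h = net.hidden2(net.hidden1(x))
    return net.mu(h), net.std(h)
end

net = PolicyNetwork(2, 3, 4)
opt_state = Flux.setup(Adam(), net)   # now this finds the parameters

x = rand(Float32, 2, 8)               # batch of 8 states
grads = Flux.gradient(net) do m
    mu, log_std = m(x)
    sum(abs2, mu) + sum(abs2, log_std)   # placeholder loss
end
Flux.update!(opt_state, net, grads[1])
```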

Note also that adding type parameters might be a good idea, for performance:

struct PolicyNetwork{A,B,C,D}
    hidden1::A
    hidden2::B
    mu::C
    std::D
end
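As an alternative, if you don't particularly want a custom struct, Flux's built-in `Parallel` combinator can express a multi-head output directly, and then `@layer` and type parameters are taken care of for you. A sketch, with small illustrative dimensions:

```julia
using Flux

n_state, n_action, n_hidden = 2, 3, 16   # example sizes

# Parallel(tuple, f, g) feeds the same input to both heads
# and returns the pair (f(x), g(x)).
policy = Chain(
    Dense(n_state => n_hidden, relu),
    Dense(n_hidden => n_hidden, relu),
    Parallel(tuple,
        Dense(n_hidden => n_action),   # mu head
        Dense(n_hidden => n_action),   # log_std head
    ),
)

x = rand(Float32, n_state, 5)   # batch of 5 states
mu, log_std = policy(x)          # each is (n_action, 5) = (3, 5)
```

Since this is an ordinary `Chain`, `Flux.setup` sees all the parameters and `policy |> gpu` moves everything at once.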

Thank you very much, I was just wondering why it could not be trained :heart_eyes: