Initializing Flux weights the same as PyTorch?

I came up with this function to initialize the weights the same way PyTorch does:

function Linear(in, out, activation)
    Dense(in, out, activation,
          initW=(_dims...) -> Float32.((rand(out, in).-0.5).*(2/sqrt(in))),
          initb=(_dims...) -> Float32.((rand(out).-0.5).*(2/sqrt(in))))
end

At least, for PyTorch’s Linear layers that’s how it works. You can easily verify this by creating a PyTorch Linear layer and looking at the minimum and maximum weight and bias values.

1 Like