Initializing Flux weights the same as PyTorch?

I am trying to replicate some results from a PyTorch model. I believe Flux initializes layer weights quite differently from PyTorch. For example, Flux layers start with zero biases, while PyTorch layers initialize their biases to small nonzero values by default.

Does anyone know how I can initialize Flux weights the same way PyTorch does?

I suspect that as Julia and Flux grow in popularity, I won't be the only one wanting to do this.

The PyTorch layer docs generally specify how each parameter is initialized, and it's a single click to see the source, which shows exactly which functions are called. I believe Flux's utility functions (the "Utility Functions" page of the Flux docs) already implement almost all of torch.nn.init (PyTorch 1.7.0 documentation).

Dense(512, 128, initW=(dims...) -> Flux.kaiming_uniform(dims...; gain=sqrt(1/3)), initb=(dims...) -> Flux.kaiming_uniform(dims...; gain=sqrt(1/3)))

This seems to initialize the weights similarly to PyTorch. I don't know why PyTorch uses a gain of sqrt(1/3), but that's what the source seems to show.

See: pytorch/ at master · pytorch/pytorch · GitHub
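A worked derivation of where sqrt(1/3) likely comes from, assuming (as the PyTorch 1.7 source appears to do) that nn.Linear calls kaiming_uniform_ on its weight with a = sqrt(5) and the default leaky_relu gain, and assuming Flux.kaiming_uniform uses the same bound formula gain * sqrt(3 / fan_in):

```latex
\text{gain} = \sqrt{\frac{2}{1 + a^2}}
            = \sqrt{\frac{2}{1 + 5}}
            = \sqrt{\tfrac{1}{3}}, \qquad a = \sqrt{5}

\text{bound} = \text{gain} \cdot \sqrt{\frac{3}{\text{fan\_in}}}
             = \sqrt{\tfrac{1}{3}} \cdot \sqrt{\frac{3}{\text{fan\_in}}}
             = \frac{1}{\sqrt{\text{fan\_in}}}
```

So under those assumptions every weight lands in [-1/sqrt(fan_in), 1/sqrt(fan_in)], which is about ±0.0442 for fan_in = 512.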

In the future you can just do:

Dense(512, 128, initW=Flux.kaiming_uniform(gain=sqrt(1/3)), initb=Flux.kaiming_uniform(gain=sqrt(1/3)))

That will compile now, but it won't work correctly. You'll have to wait for my Flux PR to be merged before this shorter line works: FluxML/Flux.jl PR #1499, "Fix layer init functions kwargs getting overwritten" (by DevJac).

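Edit: Actually, this post isn't quite right. The initW is correct, but as far as I can tell there is no way to use Flux.kaiming_uniform to initialize the biases the same way as PyTorch. PyTorch draws the biases from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) using the weight's fan-in, whereas kaiming_uniform only sees the bias vector's own dimensions.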
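A minimal sketch of a workaround, assuming the older Flux `initb` keyword, which calls the init function with the layer's output size. `pytorch_bias_init` is a hypothetical helper, not part of Flux: it takes the weight's fan-in explicitly and draws the biases uniformly from ±1/sqrt(fan_in).

```julia
# Hypothetical helper, not part of Flux: builds a bias init function whose
# range comes from the *weight's* fan-in, the way PyTorch's nn.Linear does.
function pytorch_bias_init(fan_in::Integer)
    bound = 1 / sqrt(fan_in)
    # The returned closure accepts whatever dims it is called with, e.g. (128,),
    # and draws Float32 values uniformly between -bound and bound.
    return (dims...) -> Float32.((2 .* rand(dims...) .- 1) .* bound)
end

# Bias for a 512 => 128 layer: 128 values within roughly ±1/sqrt(512) ≈ ±0.0442
b = pytorch_bias_init(512)(128)
println(extrema(b))
```

With the older keywords this would be passed as `Dense(512, 128, initb=pytorch_bias_init(512))`. Later Flux releases changed these keyword names, so check the `Dense` docstring for the version you use.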


I came up with this function to initialize the weights the same way PyTorch does:

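using Flux
# PyTorch's nn.Linear default: weights and biases both ~ U(-1/sqrt(n_in), 1/sqrt(n_in)).
# Uses the older Flux initW/initb keywords; the closures ignore the dims Flux passes in.
function Linear(n_in, n_out, activation)
    Dense(n_in, n_out, activation,
          initW=(_dims...) -> Float32.((rand(n_out, n_in) .- 0.5) .* (2 / sqrt(n_in))),
          initb=(_dims...) -> Float32.((rand(n_out) .- 0.5) .* (2 / sqrt(n_in))))
end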

At least, for PyTorch’s Linear layers that’s how it works. You can easily verify this by creating a PyTorch Linear layer and looking at the minimum and maximum weight and bias values.
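A quick standalone check of the same idea on the Julia side, in plain Julia with no Flux needed (the 512 => 128 sizes are just an example): build the weight matrix with the same expression as the `initW` closure in `Linear` and compare its extrema to ±1/sqrt(n_in).

```julia
# Same expression as the initW closure in Linear, run without Flux.
n_in, n_out = 512, 128
W = Float32.((rand(n_out, n_in) .- 0.5) .* (2 / sqrt(n_in)))

bound = 1 / sqrt(n_in)   # ≈ 0.0442 for n_in = 512
lo, hi = extrema(W)
println("min = $lo, max = $hi, expected ±$bound")
```

With 65,536 samples the observed min and max should sit very close to ±bound, which is what the min/max check on a freshly constructed PyTorch `nn.Linear(512, 128)` should also show.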


I never did manage to replicate the particular Q-learning algorithm I was working on. I eventually noticed that the RMSProp implementations in PyTorch and Flux differ: PyTorch's RMSProp has a smoothing parameter and Flux's doesn't. I don't know whether that alone caused my problems.

Full story: I successfully replicated a simpler Q-learning algorithm and got metrics very similar to those of the implementation I was replicating. Then I tried this more complicated algorithm and couldn't get similar results with similar hyperparameters. Ultimately I got both algorithms to work in Flux, but it sometimes required different hyperparameters.
