I am trying to reproduce the results of this paper: https://arxiv.org/abs/1808.03856. I have implemented the same model twice, once in PyTorch (called from Julia via PyCall) and once in Flux.jl.
There is no equivalent of torch.gather in Flux.jl, so I wrote a custom gather function, and I have checked that it returns the correct values.
The problem is that the PyTorch loss decreases during training, while the Flux.jl loss does not. Accordingly, the PyTorch model learns the target distribution correctly (the samples concentrate around the two modes of f at (0.25, 0.25) and (0.75, 0.75)), whereas the Flux.jl model produces a nearly uniform distribution, which is wrong. This makes me suspect a bug in Flux.jl.
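For concreteness, here is a small standalone check of the behaviour I expect from gather (the (D/2, batch, k) array layout matches the Flux.jl code below; the numbers are only an illustration and not part of the model):

# Same definition as in the Flux.jl code below.
gather(x, bins) = getindex(x, LinearIndices(bins) .+ length(bins)*bins)

# For x of size (d, n, k) and 0-based bin indices bins of size (d, n),
# gather(x, bins)[i, j] should equal x[i, j, bins[i, j] + 1],
# i.e. torch.gather along the last dimension.
x = reshape(collect(1.0:24.0), (1, 3, 8))   # d = 1, n = 3, k = 8
bins = [0 3 7]                              # one 0-based bin index per column
@assert gather(x, bins) == [x[1, 1, 1] x[1, 2, 4] x[1, 3, 8]]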
Here is my PyTorch code (run from Julia through PyCall):
using PyCall
using Plots
torch = pyimport("torch")
nn = pyimport("torch.nn")
optim = pyimport("torch.optim")
D = 2
h = 13
k = 16
@pydef mutable struct Flow <: nn.Module
    function __init__(self, D, h, k, flip)
        pybuiltin(:super)(Flow, self).__init__()
        self.D = D
        self.h = h
        self.k = k
        self.flip = flip
        self.net = nn.Sequential(nn.Linear(div(D, 2), h),
                                 nn.ReLU(),
                                 nn.Linear(h, h),
                                 nn.ReLU(),
                                 nn.Linear(h, div(D, 2)*k)).double()
    end
    function forward(self, x)
        if self.flip == false
            xa = x.T[:__getitem__](pybuiltin(:slice)(0, div(self.D, 2))).T
            xb = x.T[:__getitem__](pybuiltin(:slice)(div(self.D, 2), self.D)).T
            out = self.net(xa)
            Q = torch.reshape(out, (-1, div(self.D,2), self.k))
            Q = torch.softmax(Q, dim=2)
            Qsum = torch.cat([torch.zeros(x.size()[1], div(self.D,2), 1).double(), torch.cumsum(Q, dim=2)], dim=2)
            alpha = self.k*xb
            bins = torch.floor(alpha)
            alpha -= bins
            bins = bins.long()
            Qcurr = torch.squeeze(torch.gather(Q, -1, torch.unsqueeze(bins, -1)), dim=-1)
            Qprev = torch.squeeze(torch.gather(Qsum, -1, torch.unsqueeze(bins, -1)), dim=-1)
            cb = alpha*Qcurr + Qprev
            adj = torch.prod(self.k*Qcurr, dim=1)
            return torch.cat([xa, cb], dim=-1), adj
        else
            xa = x.T[:__getitem__](pybuiltin(:slice)(0, div(self.D, 2))).T
            xb = x.T[:__getitem__](pybuiltin(:slice)(div(self.D, 2), self.D)).T
            out = self.net(xb)
            Q = torch.reshape(out, (-1, div(self.D,2), self.k))
            Q = torch.softmax(Q, dim=2)
            Qsum = torch.cat([torch.zeros(x.size()[1], div(self.D,2), 1).double(), torch.cumsum(Q, dim=2)], dim=2)
            alpha = self.k*xa
            bins = torch.floor(alpha)
            alpha -= bins
            bins = bins.long()
            Qcurr = torch.squeeze(torch.gather(Q, -1, torch.unsqueeze(bins, -1)), dim=-1)
            Qprev = torch.squeeze(torch.gather(Qsum, -1, torch.unsqueeze(bins, -1)), dim=-1)
            ca = alpha*Qcurr + Qprev
            adj = torch.prod(self.k*Qcurr, dim=1)
            return torch.cat([ca, xb], dim=-1), adj
        end
    end
end
@pydef mutable struct Flows <: nn.Module
    function __init__(self, D, h, k, flips)
        pybuiltin(:super)(Flows, self).__init__()
        self[:_flows] = nn.ModuleList([Flow(D, h, k, flips[i]) for i in 1:length(flips)])
    end
    function forward(self, x)
        absdetjac = torch.ones(size(x)[1]).double()
        for _flow in self._flows
            x, z = _flow(x)
            absdetjac = absdetjac*z
        end
        return x, absdetjac
    end
end
flow = Flows(D, h, k, [true, false])
function f(x)
    10*torch.exp(-25*torch.sum((x - 0.25)^2, dim=1)) + 10*torch.exp(-25*torch.sum((x - 0.75)^2, dim=1))
end
optimizer = optim.Adam(flow.parameters(), lr=0.001, weight_decay=0.002)
for i=1:250
    x = torch.rand(10000, 2).double()
    z, J = flow(x)
    z = f(z).detach()
    loss = torch.mean((z*J)^2)
    @show loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
end
x = torch.rand(1000, 2).double()
y = flow(x)[1]
sy = y.detach().numpy()
scatter(sy[:,1], sy[:,2], aspect_ratio=:equal)
And here is my Flux.jl code:
using Flux
using Statistics
using Plots
const D = 2
const h = 13
const k = 16
# Prepend a zero slice along the third dimension before taking the cumulative sum.
cumsum2(x) = cat(zeros(size(x)[1:2]), cumsum(x, dims=3), dims=3)
# For x of size (d, n, k) and 0-based indices bins of size (d, n),
# gather(x, bins)[i, j] == x[i, j, bins[i, j] + 1] (my replacement for torch.gather).
gather(x, bins) = getindex(x, LinearIndices(bins) .+ length(bins)*bins)
net1 = Chain(Dense(div(D,2), h, relu), Dense(h, h, relu), Dense(h, div(D,2)*k))
net2 = Chain(Dense(div(D,2), h, relu), Dense(h, h, relu), Dense(h, div(D,2)*k))
net1 = fmap(f64, net1)
net2 = fmap(f64, net2)
function flow(net, x, flip)
    if flip == false
        # Split the D × N batch into its two halves.
        xa, xb = x[1:div(D,2),:], x[div(D,2)+1:D,:]
        out = net(xa)
        Q = reshape(out, (div(D,2), :, k))
        Q = softmax(Q, dims=3)               # bin probabilities along the third dimension
        Qsum = cumsum2(Q)
        alpha = k*xb
        bins = floor.(alpha)                 # 0-based bin index for each entry of xb
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins))
        Qprev = gather(Qsum, Int.(bins))
        cb = alpha.*Qcurr .+ Qprev           # piecewise-linear CDF applied to xb
        absdetjac = prod(k*Qcurr, dims=1)
        return cat(xa, cb, dims=1), absdetjac
    else
        xa, xb = x[1:div(D,2),:], x[div(D,2)+1:D,:]
        out = net(xb)
        Q = reshape(out, (div(D,2), :, k))
        Q = softmax(Q, dims=3)
        Qsum = cumsum2(Q)
        alpha = k*xa
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins))
        Qprev = gather(Qsum, Int.(bins))
        ca = alpha.*Qcurr .+ Qprev
        absdetjac = prod(k*Qcurr, dims=1)
        return cat(ca, xb, dims=1), absdetjac
    end
end
function flows(nets, x, flips)
    z = ones(1,size(x)[2])
    local absdetjac
    for i=1:length(nets)
        net = nets[i]
        flip = flips[i]
        x, absdetjac = flow(net, x, flip)
        absdetjac = absdetjac .* z
    end
    return x, absdetjac
end
ps = params(net1,net2)
opt = Optimiser(ADAM(), WeightDecay(0.002))
f(x) = 10*exp.(-25*sum((x .- 0.25).^2, dims=1)) .+ 10*exp.(-25*sum((x .- 0.75).^2, dims=1))
Flux.@nograd f
function loss(x,y)
    z, J = flows([net1, net2], x, [true, false])
    z = f(z)
    return mean((J.*z).^2)
end
for i=1:250
    x = rand(2,10000)
    @show loss(x,0)
    Flux.train!(loss, ps, [(x,0)], opt)
end
x = rand(2,1000)
sx = flows([net1, net2], x, [true, false])[1]
scatter(sx[1,:], sx[2,:])
Another question: how can I trace the behaviour of the gradients during training?
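What I have in mind is something along the lines of the sketch below, computing the gradients explicitly with Flux.gradient instead of going through Flux.train!, so that they can be inspected before the parameters are updated (I am not sure this is the recommended way):

x = rand(2, 10000)
gs = Flux.gradient(() -> loss(x, 0), ps)    # ps = params(net1, net2), as above
for p in ps
    g = gs[p]
    g === nothing && continue               # skip parameters that received no gradient
    @show size(p) maximum(abs, g)           # e.g. inspect the largest gradient magnitude
end
Flux.Optimise.update!(opt, ps, gs)          # roughly what train! does for one batch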