Different behaviour between Flux.jl and Pytorch

I am trying to generate the results of this paper https://arxiv.org/abs/1808.03856. I have created two similar models using Flux.jl and Pytorch.
There is no equivalence to pytorch.gather function in Flux.jl, so I created a custom function called gather and it returns the correct value.
The problem is that my Pytorch loss decreases but my Flux.jl doesn’t decrease and my Pytorch code generates the distribution correctly but my Flux.jl code only shows a nearly uniform distribution which is wrong.
It seems that there is a bug in Flux.jl.
Here is my Pytorch code:

using PyCall
using Plots

torch = pyimport("torch")
nn = pyimport("torch.nn")
optim = pyimport("torch.optim")

D = 2
h = 13
k = 16

@pydef mutable struct Flow <: nn.Module
    function __init__(self, D, h, k, flip)
        pybuiltin(:super)(Flow, self).__init__()
        self.D = D
        self.h = h
        self.k = k
        self.flip = flip
        self.net = nn.Sequential(nn.Linear(div(D, 2), h),
        nn.ReLU(),
        nn.Linear(h, h),
        nn.ReLU(),
        nn.Linear(h, div(D, 2)*k)).double()
    end
    function forward(self, x)
        if self.flip == false
            xa = x.T[:__getitem__](pybuiltin(:slice)(0, div(self.D, 2))).T
            xb = x.T[:__getitem__](pybuiltin(:slice)(div(self.D, 2), self.D)).T
            out = self.net(xa)
            Q = torch.reshape(out, (-1, div(self.D,2), self.k))
            Q = torch.softmax(Q, dim=2)
            Qsum = torch.cat([torch.zeros(x.size()[1], div(self.D,2), 1).double(), torch.cumsum(Q, dim=2)], dim=2)
            alpha = self.k*xb
            bins = torch.floor(alpha)
            alpha -= bins
            bins = bins.long()
            Qcurr = torch.squeeze(torch.gather(Q, -1, torch.unsqueeze(bins, -1)), dim=-1)
            Qprev = torch.squeeze(torch.gather(Qsum, -1, torch.unsqueeze(bins, -1)), dim=-1)
            cb = alpha*Qcurr + Qprev
            adj = torch.prod(self.k*Qcurr, dim=1)
            return torch.cat([xa, cb], dim=-1), adj
        else
            xa = x.T[:__getitem__](pybuiltin(:slice)(0, div(self.D, 2))).T
            xb = x.T[:__getitem__](pybuiltin(:slice)(div(self.D, 2), self.D)).T
            out = self.net(xb)
            Q = torch.reshape(out, (-1, div(self.D,2), self.k))
            Q = torch.softmax(Q, dim=2)
            Qsum = torch.cat([torch.zeros(x.size()[1], div(self.D,2), 1).double(), torch.cumsum(Q, dim=2)], dim=2)
            alpha = self.k*xa
            bins = torch.floor(alpha)
            alpha -= bins
            bins = bins.long()
            Qcurr = torch.squeeze(torch.gather(Q, -1, torch.unsqueeze(bins, -1)), dim=-1)
            Qprev = torch.squeeze(torch.gather(Qsum, -1, torch.unsqueeze(bins, -1)), dim=-1)
            ca = alpha*Qcurr + Qprev
            adj = torch.prod(self.k*Qcurr, dim=1)
            return torch.cat([ca, xb], dim=-1), adj
        end
    end
end

@pydef mutable struct Flows <: nn.Module
    function __init__(self, D, h, k, flips)
        pybuiltin(:super)(Flows, self).__init__()
        self[:_flows] = nn.ModuleList([Flow(D, h, k, flips[i]) for i in 1:length(flips)])
    end
    function forward(self, x)
        absdetjac = torch.ones(size(x)[1]).double()
        for _flow in self._flows
            x, z = _flow(x)
            absdetjac = absdetjac*z
        end
        return x, absdetjac
    end
end

flow = Flows(D, h, k, [true, false])

function f(x)
    10*torch.exp(-25*torch.sum((x - 0.25)^2, dim=1)) + 10*torch.exp(-25*torch.sum((x - 0.75)^2, dim=1))
end

optimizer = optim.Adam(flow.parameters(), lr=0.001, weight_decay=0.002)

for i=1:250
    x = torch.rand(10000, 2).double()
    z, J = flow(x)
    z = f(z).detach()
    loss = torch.mean((z*J)^2)
    @show loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
end

x = torch.rand(1000, 2).double()
y = flow(x)[1]
sy = y.detach().numpy()
scatter(sy[:,1], sy[:,2], aspect_ratio=:equal)

And here is my Flux.jl code:

using Flux
using Statistics
using Plots

const D = 2
const h = 13
const k = 16

cumsum2(x) = cat(zeros(size(x)[1:2]), cumsum(x, dims=3), dims=3)
gather(x, bins) = getindex(x, LinearIndices(bins) .+ length(bins)*bins)

net1 = Chain(Dense(div(D,2), h, relu), Dense(h, h, relu), Dense(h, div(D,2)*k))
net2 = Chain(Dense(div(D,2), h, relu), Dense(h, h, relu), Dense(h, div(D,2)*k))
net1 = fmap(f64, net1)
net2 = fmap(f64, net2)

function flow(net, x, flip)
    if flip == false
        xa, xb = x[1:div(D,2),:], x[div(D,2)+1:D,:]
        out = net(xa)
        Q = reshape(out, (div(D,2), :, k))
        Q = softmax(Q, dims=3)
        Qsum = cumsum2(Q)
        alpha = k*xb
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins))
        Qprev = gather(Qsum, Int.(bins))
        cb = alpha.*Qcurr .+ Qprev
        absdetjac = prod(k*Qcurr, dims=1)
        return cat(xa, cb, dims=1), absdetjac
    else
        xa, xb = x[1:div(D,2),:], x[div(D,2)+1:D,:]
        out = net(xb)
        Q = reshape(out, (div(D,2), :, k))
        Q = softmax(Q, dims=3)
        Qsum = cumsum2(Q)
        alpha = k*xa
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins))
        Qprev = gather(Qsum, Int.(bins))
        ca = alpha.*Qcurr .+ Qprev
        absdetjac = prod(k*Qcurr, dims=1)
        return cat(ca, xb, dims=1), absdetjac
    end
end

function flows(nets, x, flips)
    z = ones(1,size(x)[2])
    local absdetjac
    for i=1:length(nets)
        net = nets[i]
        flip = flips[i]
        x, absdetjac = flow(net, x, flip)
        absdetjac = absdetjac .* z
    end
    return x, absdetjac
end

ps = params(net1,net2)
opt = Optimiser(ADAM(), WeightDecay(0.002))

f(x) = 10*exp.(-25*sum((x .- 0.25).^2, dims=1)) .+ 10*exp.(-25*sum((x .- 0.75).^2, dims=1))
Flux.@nograd f

function loss(x,y)
    z, J = flows([net1, net2], x, [true, false])
    z = f(z)
    return mean((J.*z).^2)
end

for i=1:250
    x = rand(2,10000)
    @show loss(x,0)
    Flux.train!(loss, ps, [(x,0)], opt)
end

x = rand(2,1000)
sx = flows([net1, net2], x, [true, false])[1]
scatter(sx[1,:], sx[2,:])

Another question: How can I trace the behaviour of gradient?

Yes, use Flux.gradient or pullback instead of train! so that you can analyze the gradients and model parameters before each update.

Another point to consider is that Flux and PyTorch initialize Dense layers differently by default. See Initializing Flux weights the same as PyTorch? - #4 by DevJac. If you can verify that a) the initializations are similar, b) the outputs from each intermediate step of the forward pass are similar, and c) the gradients are similar, then I think the behaviour shouldn’t be much different between PyTorch and Flux. If there is a bug (which seems somewhat unlikely since you’re using a plain MLP on CPU), I would imagine it’s somewhere in the backwards pass (and thus will show up in the gradients).

Also note that there’s a WIP PR for gather/scatter in NNlib: Add scatter/gather operations and their gradients by yuehhua · Pull Request #255 · FluxML/NNlib.jl · GitHub. You could try this to see if it generates more coherent results (but I would walk through the process in the previous post first)

The problem isn’t initilization. The problem is that the loss function doesn’t change at all. I tried to run it for thousands of epochs but didn’t see any change.

I forgot about this, but you can use Utilities · Zygote to log gradients during intermediate steps of your model. If it’s logging mostly 0 or extremely large values on the first couple of iterations, than there’s no need to train for another 248 epochs.

ScatterNNlib.gather returns array mutation error. I tried to use Zygote.Buffer but it doesn’t work.

Have you tried using @showgrad on your original code that doesn’t throw any errors? That should be step one for troubleshooting this.

I used @showgrad. Here are examples of the gradient of loss with respect to some parameters of each network:
gs[net2[2].W] =

13×13 Matrix{Float64}:
  0.0205577    0.0224886   0.0  …   0.0148803   0.0   0.0247433
  0.0          0.0         0.0      0.0         0.0   0.0
 -0.0157534   -0.017233    0.0     -0.0114028   0.0  -0.0189608
 -0.00968057  -0.0105898   0.0     -0.00700712  0.0  -0.0116516
  0.0          0.0         0.0      0.0         0.0   0.0
  0.0          0.0         0.0  …   0.0         0.0   0.0
  0.0          0.0         0.0      0.0         0.0   0.0
 -0.030542    -0.0334107   0.0     -0.0221073   0.0  -0.0367604
 -0.00298348  -0.00326371  0.0     -0.00215954  0.0  -0.00359093
  0.0206564    0.0225966   0.0      0.0149518   0.0   0.0248621
  0.0392505    0.0429372   0.0  …   0.0284108   0.0   0.047242
  0.0382797    0.0418752   0.0      0.0277081   0.0   0.0460736
  0.0          0.0         0.0      0.0         0.0   0.0
`

gs[net2[2].b] =

13-element Vector{Float64}:
 -1.0430295949853674e-15
  0.0
 -2.4839071771642907e-16
  4.979876103497483e-16
  0.0
  0.0
  0.0
  4.629001175426861e-16
 -7.046636494797975e-16
 -2.2768245622195593e-18
 -1.8267451353665143e-16
 -2.0393842864452338e-16
  0.0

gs[net1[2].W] =

16×13 Matrix{Float64}:
 -5.5032e-6    -8.78365e-6   0.0  …  0.0  0.0  0.0  0.0  -3.92421e-6
 -5.80678e-5   -9.26818e-5   0.0     0.0  0.0  0.0  0.0  -4.14069e-5
 -9.8981e-5    -0.000157983  0.0     0.0  0.0  0.0  0.0  -7.05812e-5
  0.000173205   0.000276453  0.0     0.0  0.0  0.0  0.0   0.000123509
 -0.000614545  -0.000980873  0.0     0.0  0.0  0.0  0.0  -0.000438218
  6.32818e-5    0.000101004  0.0  …  0.0  0.0  0.0  0.0   4.51248e-5
  0.000187183   0.000298762  0.0     0.0  0.0  0.0  0.0   0.000133476
 -0.000231297  -0.000369172  0.0     0.0  0.0  0.0  0.0  -0.000164933
  0.000197706   0.000315557  0.0     0.0  0.0  0.0  0.0   0.00014098
 -9.71435e-5   -0.00015505   0.0     0.0  0.0  0.0  0.0  -6.92709e-5
 -0.000141942  -0.000226554  0.0  …  0.0  0.0  0.0  0.0  -0.000101216
  0.000200753   0.000320421  0.0     0.0  0.0  0.0  0.0   0.000143153
  0.000226922   0.000362189  0.0     0.0  0.0  0.0  0.0   0.000161813
  0.000414478   0.000661546  0.0     0.0  0.0  0.0  0.0   0.000295555
  0.000171626   0.000273932  0.0     0.0  0.0  0.0  0.0   0.000122383
  1.2138e-5     1.93734e-5   0.0  …  0.0  0.0  0.0  0.0   8.65535e-6

gs[net1[2].b] =

13-element Vector{Float64}:
  1.1746652912522637e-17
 -3.057788939588024e-19
  0.0
  1.577937677440036e-17
 -2.0298297548002053e-17
  7.108300493358088e-18
  1.6618786425129373e-17
  0.0
  0.0
  0.0
  0.0
  0.0
 -3.0696474008495844e-18

And the values of elements of (gs[neti[1].W], gs[neti[1].b]) and (gs[neti[3].W], gs[neti[1].b]) are similar to the values of elements of (gs[neti[2].W], gs[neti[2].b) which i = 1,2.

I am sure that these values are wrong. Why does my model generate these numbers?

Could someone help me? Should I write custom derivatives for my neural network?

What is the reason for calling Flux.train! in the for loop? No clue if it’s your problem, but I would try gradient, pullback, update! or something similar if you are running into errors or weird things. I find that stepping through things methodically, printing the gradients out on each iteration, and confirming they are what I expect to help in most situations. Flux.train! doesn’t necessarily provide the level of granularity you may need for your lower level debugging.

Thank you for your reply, but it is not my problem. I am sure the value of gradient is wrong. Just look at the gradient with respect to biases.

Yes, but where does it start being wrong? That was the purpose of @andrewdinhobl’s question and the reason I suggested using @showgrad.

For example, I took the liberty of trying your code and logging some basic stats about each parameter and gradient for a couple epochs (output format is norm(gradient) mean(gradient) std(gradient) | norm(param) mean(param) std(param)):

PyTorch:

loss.item()=33.864630556567505
1.479 -0.005 0.427 | 2.422 -0.302 0.624
2.930 -0.031 0.845 | 2.553 0.217 0.702
12.658 0.108 0.971 | 1.988 -0.018 0.152
7.191 0.298 2.052 | 0.586 0.096 0.137
14.012 -0.000 0.974 | 2.249 -0.004 0.156
18.170 -0.000 4.692 | 0.580 0.048 0.142
2.033 0.102 0.577 | 2.287 0.144 0.643
2.890 0.150 0.820 | 2.206 0.259 0.577
21.168 -0.110 1.629 | 1.966 -0.004 0.152
8.385 -0.265 2.405 | 0.594 0.001 0.171
20.064 0.000 1.395 | 2.249 0.016 0.155
17.908 -0.000 4.624 | 0.675 0.044 0.168
loss.item()=33.659579468582415
1.481 -0.012 0.427 | 2.419 -0.302 0.624
3.052 -0.046 0.880 | 2.553 0.217 0.701
12.824 0.098 0.984 | 1.981 -0.018 0.152
7.255 0.269 2.076 | 0.585 0.096 0.137
14.319 -0.000 0.995 | 2.244 -0.003 0.156
18.375 0.000 4.744 | 0.582 0.048 0.142
1.937 0.078 0.553 | 2.285 0.144 0.642
2.764 0.111 0.790 | 2.207 0.259 0.577
20.603 -0.145 1.583 | 1.960 -0.004 0.151
8.197 -0.333 2.341 | 0.592 0.001 0.171
20.417 -0.000 1.419 | 2.245 0.017 0.155
18.241 0.000 4.710 | 0.675 0.044 0.168
loss.item()=32.94436537376331
1.459 -0.014 0.421 | 2.417 -0.302 0.623
2.910 -0.044 0.839 | 2.553 0.218 0.701
12.160 0.077 0.935 | 1.975 -0.018 0.151
6.905 0.209 1.982 | 0.585 0.096 0.136
13.784 0.000 0.958 | 2.239 -0.003 0.156
17.681 0.000 4.565 | 0.583 0.048 0.142
1.888 0.080 0.539 | 2.284 0.144 0.642
2.742 0.101 0.784 | 2.207 0.260 0.577
20.723 -0.186 1.588 | 1.954 -0.003 0.151
8.216 -0.414 2.332 | 0.590 0.001 0.170
19.839 0.000 1.379 | 2.241 0.017 0.155
17.463 0.000 4.509 | 0.675 0.044 0.168

Flux:

loss = 35.86602093659519
0.092 0.009 0.025 | 1.663 -0.054 0.477
0.143 0.017 0.037 | 1.897 -0.227 0.494
0.217 0.002 0.017 | 2.137 -0.010 0.165
0.229 0.018 0.063 | 0.675 -0.012 0.194
0.190 0.001 0.013 | 2.276 0.002 0.158
0.000 0.000 0.000 | 0.713 0.055 0.175
0.032 -0.002 0.009 | 1.481 -0.088 0.417
0.038 0.000 0.011 | 1.903 0.069 0.545
0.125 0.001 0.010 | 2.134 -0.025 0.163
0.070 0.006 0.019 | 0.629 -0.048 0.175
0.052 0.000 0.004 | 2.284 -0.003 0.159
0.000 0.000 0.000 | 0.640 0.028 0.163
loss = 35.86116996476047
0.081 0.009 0.021 | 1.660 -0.054 0.476
0.094 0.010 0.025 | 1.893 -0.226 0.493
0.207 0.000 0.016 | 2.132 -0.010 0.164
0.140 0.007 0.040 | 0.673 -0.012 0.194
0.194 -0.000 0.013 | 2.271 0.002 0.158
0.000 -0.000 0.000 | 0.711 0.055 0.175
0.086 -0.005 0.024 | 1.478 -0.088 0.417
0.094 -0.003 0.027 | 1.898 0.069 0.543
0.237 0.003 0.018 | 2.129 -0.025 0.162
0.160 0.011 0.045 | 0.629 -0.048 0.175
0.056 0.000 0.004 | 2.280 -0.003 0.158
0.000 -0.000 0.000 | 0.639 0.028 0.162
loss = 35.170570526399544
0.068 0.008 0.018 | 1.656 -0.054 0.475
0.063 0.004 0.018 | 1.889 -0.226 0.492
0.212 0.002 0.016 | 2.127 -0.010 0.164
0.137 0.018 0.035 | 0.672 -0.012 0.194
0.161 -0.000 0.011 | 2.265 0.002 0.157
0.000 -0.000 0.000 | 0.710 0.054 0.174
0.039 0.002 0.011 | 1.474 -0.088 0.416
0.080 0.004 0.023 | 1.894 0.069 0.542
0.204 0.001 0.016 | 2.125 -0.025 0.162
0.145 0.003 0.042 | 0.629 -0.048 0.175
0.059 0.000 0.004 | 2.275 -0.003 0.158
0.000 0.000 0.000 | 0.638 0.028 0.162

Notice how the former has much larger values than the latter? Armed with this knowledge, you can now go back with something like @showgrad to determine which operation(s) are reducing the magnitude of the gradients far more than they should be.

Let me be clear though: there is no shortcut here. You likely will have to trudge through and print/interrogate the gradients and outputs at multiple points in the model to hunt down the discrepancy. We can offer some tips and anecdotes on common pitfalls, but debugging fancy non-standard DL models is a bit too much.

3 Likes

I can’t answer the question, I’ll just like to comment. I think it’s great that you or anyone is trying to use Flux, to reimplement important papers without Python code. Now you’re debugging, but is your goal to see if Flux is faster? Might:

be helpful? And since you used Pytorch without Flux, is using this a better way:

I’m trying to learn this stuff myself, and curious about the capabilities of Flux/Julia-only code vs. non-Julia or hybrid solutions.

One other thing I might add is that after debugging you find a function or operation that does return a bad gradient, we can help you open a ticket in the appropriate place. But we would need a specific function.

One difference that I’ve run into in Flux vs. Pytorch is that they handle broadcasting differently and make different assumptions about the shapes of input arrays. Might be worth double-checking the shapes of input and output arrays at every step to make sure Flux isn’t broadcasting when you don’t expect it to.

The shapes of input and output are the same. I still don’t understand what the problem is but I guess my mistake is reshaping 2d array to 3d array. Reshaping works well in PyTorch but I am not sure how it works in Flux.
I tried to drop some functions and write the code differently. The only functions which I didn’t change are reshape and cat. I don’t think cat is the source of problem.
I am still trying. I could simply use a 2d array and reshape it to another 2d array.

My goal is to do some research.

I was right. The problem is reshape function. This model works:

function gather(x, bins)
    s1 = size(x)[1]
    return getindex(x, bins .+ s1.*(LinearIndices(bins) .- 1))
end

function flow(net, x, flip)
    if flip == false
        xa, xb = x[1:1,:], x[2:2,:]
        out = net(xa)
        Q = softmax(out, dims=1)
        Qsum = cat(zeros(1,size(xa)[2]), cumsum(Q, dims=1), dims=1)
        alpha = 16.0*xb
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins .+ 1))
        Qprev = gather(Qsum, Int.(bins .+ 1))
        cb = alpha.*Qcurr .+ Qprev
        absdetjac = prod(16.0*Qcurr, dims=1)
        return cat(xa, cb, dims=1), absdetjac
    else
        xa, xb = x[1:1,:], x[2:2,:]
        out = net(xb)
        Q = softmax(out, dims=1)
        Qsum = cat(zeros(1,size(xa)[2]), cumsum(Q, dims=1), dims=1)
        alpha = 16.0*xa
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins .+ 1))
        Qprev = gather(Qsum, Int.(bins .+ 1))
        ca = alpha.*Qcurr .+ Qprev
        absdetjac = prod(16.0*Qcurr, dims=1)
        return cat(ca, xb, dims=1), absdetjac
    end
end

And loss value decreases:

loss(x, 0) = 12.255305478795886
...
loss(x, 0) = 10.806089610225168
...
loss(x, 0) = 8.449290100950817

My model now works and it transforms a uniform distribution to a target distribution so I could generalize it to n-dimensional functions without reshaping network output to a 3d array.

However it would be nice to ask @MikeInnes to check reshape function and solve this issue.

3 Likes