Different behaviour between Flux.jl and Pytorch

I was right: the problem is the `reshape` function. This model works:

"""
    gather(x, bins)

Select one entry of `x` for each element of `bins`, treating `bins[k]` as a
row index: element `k` of the result is the value of `x` at row `bins[k]` in
the `k`-th column slot (for a `1×n` `bins` and an `n`-column `x`, this picks
`x[bins[j], j]` for every column `j`). The result has the shape of `bins`.
"""
function gather(x, bins)
    nrows = size(x, 1)
    # Convert per-column row indices into linear indices of x:
    # each slot k in bins is offset by k-1 whole columns of x.
    linear = bins .+ nrows .* (LinearIndices(bins) .- 1)
    return x[linear]
end

"""
    flow(net, x, flip; nbins=16.0)

One coupling layer of a piecewise-linear flow on 2-row input `x` (row 1 is
coordinate `a`, row 2 is coordinate `b`, columns are samples).

- `flip == false`: condition `net` on row `a` and transform row `b`.
- otherwise: condition `net` on row `b` and transform row `a`.

`nbins` is the number of equal-width bins the transformed coordinate is
quantized into; it must match the number of rows `net` outputs (default
`16.0`, matching the original hard-coded constant).

Returns `(y, absdetjac)` where `y` is the transformed 2-row array and
`absdetjac` is a `1×n` array of per-column absolute Jacobian determinants.

NOTE(review): assumes the transformed coordinate lies in `[0, 1)` so that
`floor.(nbins .* t)` produces a valid 0-based bin index — confirm upstream.
"""
function flow(net, x, flip; nbins=16.0)
    # Split the two coordinates: xa is row 1, xb is row 2.
    xa, xb = x[1:1, :], x[2:2, :]
    if flip == false
        cb, absdetjac = _plf_transform(net, xa, xb, nbins)
        return cat(xa, cb, dims=1), absdetjac
    else
        ca, absdetjac = _plf_transform(net, xb, xa, nbins)
        return cat(ca, xb, dims=1), absdetjac
    end
end

# Apply the piecewise-linear CDF parameterized by `net(cond)` to `target`.
# Returns the transformed row and the 1×n absolute Jacobian determinant.
function _plf_transform(net, cond, target, nbins)
    # Per-column bin probabilities (each column sums to 1).
    Q = softmax(net(cond), dims=1)
    # Prepend a zero row so row k+1 of Qsum is the cumulative mass of bins 1..k.
    Qsum = cat(zeros(1, size(cond, 2)), cumsum(Q, dims=1), dims=1)
    alpha = nbins .* target
    bins = floor.(alpha)       # 0-based index of the bin containing each value
    alpha = alpha .- bins      # fractional position within the bin, in [0, 1)
    idx = Int.(bins .+ 1)      # 1-based row indices for gather
    Qcurr = gather(Q, idx)     # mass of the active bin
    Qprev = gather(Qsum, idx)  # cumulative mass before the active bin
    c = alpha .* Qcurr .+ Qprev
    # Each bin is linear with slope nbins*Qcurr; the per-column |det J| is
    # the product over transformed coordinates (here a single row).
    absdetjac = prod(nbins .* Qcurr, dims=1)
    return c, absdetjac
end

And loss value decreases:

loss(x, 0) = 12.255305478795886
...
loss(x, 0) = 10.806089610225168
...
loss(x, 0) = 8.449290100950817

My model now works: it transforms a uniform distribution into a target distribution, so I can generalize it to n-dimensional functions without reshaping the network output into a 3-D array.

However, it would be nice to ask @MikeInnes to check the `reshape` function and resolve this issue.

4 Likes