I was right: the problem is the `reshape` function. This model works:
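```julia
using Flux  # provides `softmax`

# For each column j, return x[bins[j], j] (bins is a 1×N row of 1-based indices).
function gather(x, bins)
    s1 = size(x, 1)
    return getindex(x, bins .+ s1 .* (LinearIndices(bins) .- 1))
end

# Piecewise-linear coupling layer with 16 bins. `x` is 2×N with entries in [0, 1);
# one row conditions `net` (expected to output 16 logits per column), the other row
# is transformed. `flip` selects which row plays which role.
function flow(net, x, flip)
    if flip == false
        xa, xb = x[1:1, :], x[2:2, :]
        out = net(xa)
        Q = softmax(out, dims=1)                       # bin masses, each column sums to 1
        Qsum = cat(zeros(1, size(xa, 2)), cumsum(Q, dims=1), dims=1)  # CDF at bin edges
        alpha = 16.0 * xb
        bins = floor.(alpha)
        alpha = alpha .- bins                          # position inside the bin
        Qcurr = gather(Q, Int.(bins .+ 1))
        Qprev = gather(Qsum, Int.(bins .+ 1))
        cb = alpha .* Qcurr .+ Qprev
        absdetjac = prod(16.0 * Qcurr, dims=1)         # |dcb/dxb| = bin mass / bin width
        return cat(xa, cb, dims=1), absdetjac
    else
        xa, xb = x[1:1, :], x[2:2, :]
        out = net(xb)
        Q = softmax(out, dims=1)
        Qsum = cat(zeros(1, size(xa, 2)), cumsum(Q, dims=1), dims=1)
        alpha = 16.0 * xa
        bins = floor.(alpha)
        alpha = alpha .- bins
        Qcurr = gather(Q, Int.(bins .+ 1))
        Qprev = gather(Qsum, Int.(bins .+ 1))
        ca = alpha .* Qcurr .+ Qprev
        absdetjac = prod(16.0 * Qcurr, dims=1)
        return cat(ca, xb, dims=1), absdetjac
    end
end
```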
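To check the coupling step in isolation, here is a Base-only sketch of the same piecewise-linear transform, with no network and no Flux. Assumptions: `colsoftmax` and `pwlinear` are hypothetical helper names, `colsoftmax` stands in for Flux's `softmax`, the number of bins is taken from the number of rows of `Q`, and inputs lie in [0, 1) (an input of exactly 1.0 would index bin K+1).

```julia
# Same indexing trick as `gather` above: pick x[bins[j], j] for every column j.
gather(x, bins) = x[bins .+ size(x, 1) .* (LinearIndices(bins) .- 1)]

# Numerically stable softmax over each column (hypothetical helper).
function colsoftmax(z)
    e = exp.(z .- maximum(z, dims=1))
    return e ./ sum(e, dims=1)
end

# Map a 1×N row `x` through the piecewise-linear CDF defined by the K×N bin masses `Q`.
# Returns the transformed row and |dy/dx| for each column.
function pwlinear(Q, x)
    K = size(Q, 1)
    Qsum  = vcat(zeros(1, size(Q, 2)), cumsum(Q, dims=1))  # CDF at the K+1 bin edges
    alpha = K .* x
    bins  = floor.(alpha)
    alpha = alpha .- bins                                  # fractional position inside the bin
    idx   = Int.(bins .+ 1)                                # 1-based bin index
    Qcurr = gather(Q, idx)
    y = alpha .* Qcurr .+ gather(Qsum, idx)
    return y, K .* Qcurr                                   # density of a bin = mass / width
end

K, N = 16, 4
Q = colsoftmax(randn(K, N))
x = [0.0 0.3 0.5 0.9]
y, jac = pwlinear(Q, x)
```

With uniform bin masses (`fill(1/K, K, N)`) the map reduces to the identity with unit Jacobian, and a finite difference `(pwlinear(Q, x .+ h)[1] .- y) ./ h` should match `jac` as long as `x + h` stays in the same bin.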
And the loss decreases:
```
loss(x, 0) = 12.255305478795886
...
loss(x, 0) = 10.806089610225168
...
loss(x, 0) = 8.449290100950817
```
My model now works: it transforms a uniform distribution into the target distribution, so I can generalize it to n-dimensional functions without reshaping the network output into a 3D array.
However, it would still be good if @MikeInnes could check the `reshape` function and fix this issue.
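For the n-dimensional generalization, one way to avoid a 3D `reshape` might be to let the network emit a (K·D)×N matrix and slice it into D row blocks of K logits, one block per transformed dimension. This is a sketch under assumptions, not tested with Flux or Zygote gradients: `colsoftmax`, `pwlinear` and `multi_pwlinear` are hypothetical helpers, and plain Base Julia is used throughout.

```julia
gather(x, bins) = x[bins .+ size(x, 1) .* (LinearIndices(bins) .- 1)]

function colsoftmax(z)
    e = exp.(z .- maximum(z, dims=1))
    return e ./ sum(e, dims=1)
end

# Piecewise-linear CDF for one dimension: Q is K×N bin masses, x is a 1×N row in [0, 1).
function pwlinear(Q, x)
    K = size(Q, 1)
    Qsum  = vcat(zeros(1, size(Q, 2)), cumsum(Q, dims=1))
    alpha = K .* x
    bins  = floor.(alpha)
    idx   = Int.(bins .+ 1)
    Qcurr = gather(Q, idx)
    return (alpha .- bins) .* Qcurr .+ gather(Qsum, idx), K .* Qcurr
end

# out: (K*D)×N logits, xb: D×N points in [0, 1). Row block d of `out` drives dimension d,
# so no 3D array is ever built; the slices are taken with ordinary row indexing.
function multi_pwlinear(out, xb, K)
    D = size(xb, 1)
    rows = [pwlinear(colsoftmax(out[(d-1)*K+1:d*K, :]), xb[d:d, :]) for d in 1:D]
    y   = reduce(vcat, first.(rows))               # D×N transformed points
    jac = prod(reduce(vcat, last.(rows)), dims=1)  # 1×N product of per-dimension densities
    return y, jac
end

K, D, N = 8, 3, 5
out = randn(K * D, N)
xb = rand(D, N)
y, jac = multi_pwlinear(out, xb, K)
```

Because each transformed dimension depends only on its own input and the conditioning variables, the Jacobian is diagonal in the transformed block, so the determinant is the product of the per-dimension densities.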