Tell Flux not to differentiate something

Hi, I’m trying to create a model using Flux.

The model is quite complicated (it builds on Normalizing Flows, specifically RealNVP) and stacks many layers on top of each other. There are two inputs: a vector Z and a matrix X. X is transformed into a matrix Y by neural networks that depend on Z, and Z is then transformed into Z_new by another neural network that also takes Y as input. The calculations are below; all the stack.* functions are neural networks.

function (stack::SetFlowStack)(input)
    X, logJxy, Z, logJz = input
    # calculate RealNVP part
    Y_hat = X .* exp.(stack.sx(Z)) .+ stack.tx(Z)
    Y, tmpJ = stack.nvp((Y_hat,logJxy))
    logJY = tmpJ .- sum(stack.sx(Z),dims=1)

    # calculate global noise part
    ds = stack.deep_set(Y)
    Z_new = Z .* exp.(stack.sz(ds)) .+ stack.tz(ds)
    logJZ = logJz .- sum(stack.sz(ds),dims=1)

    return Y, logJY, Z_new, logJZ
end

The loss function is

function loss_setflow(model::SetFlow,X)
    z0 = rand(model.base)
    Y, logJY, Z_new, logJZ = model((X,_init_logJ(X),z0,_init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1),1),Y) .+ logJY') - sum(logpdf(model.base,Z_new) .+ logJZ)
end

where model.base is just a simple isotropic Gaussian and SetFlow is just a Chain of SetFlowStacks.

The problem is that I can’t differentiate the loss function because of array mutation (I suspect Z and Y are being mutated somewhere).

Is there a way to tell Flux, at least for now, that Y_hat should take Z only as an input (without differentiating through it), and that Z_new should not differentiate through Y?


Hey,

I think Zygote.@ignore is what you are looking for:

Zygote.@ignore c .* v
Zygote.@ignore println("no differentiation on this unit")
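A minimal check of what @ignore does, using a made-up scalar function f rather than the model above: the ignored term is still evaluated in the forward pass, but it is treated as a constant when differentiating. (Newer Zygote releases may print deprecation warnings for @ignore and dropgrad and point to ChainRulesCore instead.)

```julia
using Zygote

# Toy function: both terms equal x^2 in the forward pass,
# but only the first one contributes to the gradient.
f(x) = x^2 + Zygote.@ignore(x^2)

val = f(3.0)                 # the ignored term is still evaluated: 9 + 9
(dfdx,) = gradient(f, 3.0)   # 2x from the first term only
```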

Z_new = Zygote.@ignore Z .* exp.(stack.sz(ds)) .+ stack.tz(ds)

The problem with this solution is that I do want to differentiate stack.sz, stack.tz and stack.deep_set. I need gradients through Z_new, just not all the way back into the functions that created Y. I don’t know if that makes sense.

I’m afraid there is something more complicated happening under the hood. I guess I need to inspect further and figure out what to do with it.

Does anyone have an idea how to find out where the array mutation is happening?

Why not do this:

depset = stack.sz(ds)
tz = stack.tz(ds)
Z_new = Zygote.@ignore Z .* exp.(depset) .+ tz

Is this any good? I’m not sure it’s right; I didn’t fully understand the problem, so this is just a way to separate out certain gradients and run only the forward pass for some of the functions.
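One caveat, sketched on a toy scalar version (z_new_ignored and z_new_targeted are hypothetical names; w stands in for the parameters of stack.sz): computing depset outside @ignore does not by itself keep its gradient, because depset is only used inside the ignored expression, so nothing reaches w. Ignoring only the input whose gradient should be cut keeps the gradient to w.

```julia
using Zygote

# depset is computed outside @ignore but only used inside it,
# so the parameter w receives no gradient at all.
function z_new_ignored(w, z)
    depset = 2.0 * w
    return Zygote.@ignore z * exp(depset)
end

# Cutting the gradient on the input z only: w still gets its gradient.
function z_new_targeted(w, z)
    depset = 2.0 * w
    return Zygote.@ignore(z) * exp(depset)
end

g_ignored  = gradient(w -> z_new_ignored(w, 1.5), 0.1)[1]   # nothing
g_targeted = gradient(w -> z_new_targeted(w, 1.5), 0.1)[1]  # 1.5 * 2 * exp(0.2)
```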

As a quick guess, try loading DistributionsAD.jl (using DistributionsAD) alongside your code. If it doesn’t help, could you please post the exact error you get?


Let me second the suggestion to use DistributionsAD. Without it, you’d have to reimplement the distributions themselves by hand in order to keep things AD friendly (e.g. the Flux model zoo examples don’t use Distributions at all).

More generally, Zygote.dropgrad (documented under Utilities in the Zygote docs) is a stop-gradient operation, equivalent to PyTorch’s detach or JAX’s stop_gradient. If DistributionsAD doesn’t work out, I think this should at least handle the Z -> Y_hat connection.
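A toy scalar sketch of stop-gradients placed on the two connections from the question (toy_stack and its parameters a and b are hypothetical: a plays the role of stack.sx/stack.tx, b the role of stack.sz/stack.tz). Each dropgrad cuts exactly one path, while both parameters still receive gradients.

```julia
using Zygote

# a: stand-in for the parameters of stack.sx / stack.tx  (Z -> Y_hat)
# b: stand-in for the parameters of stack.sz / stack.tz  (Y -> Z_new)
function toy_stack(a, b, x, z)
    y     = x * a * Zygote.dropgrad(z)   # no gradient flows back into z here
    z_new = z * b * Zygote.dropgrad(y)   # no gradient flows back into y (and so into a) here
    return y + z_new
end

da, db = gradient((a, b) -> toy_stack(a, b, 2.0, 3.0), 1.0, 1.0)
# da: only the direct path through y (x * z); without dropgrad(y) it would also include z * b * x * z
# db: z * y
```

With these inputs da works out to 6.0 (it would be 24.0 without dropgrad(y)) and db to 18.0.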

Apparently, Flux tries to differentiate the z0 = rand(model.base) expression. Can I use something like dropgrad on z0 to tell Flux not to differentiate that expression? I tried z0 = dropgrad, but it makes no difference. I also made model.base a DistributionsAD.jl TuringMvNormal, but that didn’t help either. Since the distribution is an isotropic Gaussian, I don’t want it to be trainable…

Do you mind posting the updated loss code? I’m not qualified to comment on the distributions side, but dropgrad is a function and should be used like z0 = dropgrad(rand(model.base)) or otherexpression(dropgrad(z0)).

I just added dropgrad as you suggested:

function loss_setflow(model::SetFlow,X)
    z0 = dropgrad(rand(model.base))
    Y, logJY, Z_new, logJZ = model((X,_init_logJ(X),z0,_init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1),1),Y) .+ logJY') - sum(logpdf(model.base,Z_new) .+ logJZ)
end

I still get this error:

julia>     Flux.train!(loss,ps,train_data,opt)
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.var"#364#365")(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\array.jl:58
 [3] (::Zygote.var"#2246#back#366"{Zygote.var"#364#365"})(::Nothing) at C:\Users\masen\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] _modify! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:83 [inlined]
 [5] (::typeof(∂(_modify!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0  
 [6] generic_mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:108 [inlined]
 [7] (::typeof(∂(generic_mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [8] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:126 [inlined]
 [9] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\matmul.jl:208 [inlined]
 [10] (::typeof(∂(mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0     
 [11] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\scalmat.jl:66 [inlined]
 [12] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [13] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\generics.jl:33 [inlined]
 [14] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [15] _rand! at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariate\mvnormal.jl:275 [inlined]
 [16] (::typeof(∂(_rand!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0   
 [17] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariates.jl:76 [inlined]
 [18] (::typeof(∂(rand)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [19] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\genericrand.jl:22 [inlined]
 [20] loss_setflow at .\REPL[527]:2 [inlined]
 [21] (::typeof(∂(loss_setflow)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [22] loss at .\REPL[537]:1 [inlined]
 [23] (::typeof(∂(loss)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [24] #150 at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
 [25] #1694#back at C:\Users\masen\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [26] #15 at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:83 [inlined]
 [27] (::typeof(∂(λ)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [28] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172
 [29] gradient(::Function, ::Params) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
 [30] macro expansion at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:82 [inlined]
 [31] macro expansion at C:\Users\masen\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [32] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM; cb::Flux.Optimise.var"#16#22") at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:80
 [33] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM) at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:78
 [34] top-level scope at REPL[540]:1

I have the loss function defined as loss(x) = loss_setflow(set_flow,x), where set_flow is an instantiated SetFlow model.

I finally figured out a workaround…

function loss_setflow(model::SetFlow,X)
    z0 = Float32.(randn(length(model.base)))
    Y, logJY, Z_new, logJZ = model((X,_init_logJ(X),z0,_init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1),1),Y) .+ logJY') - sum(logpdf(model.base,Z_new) .+ logJZ)
end

This works…
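For the record, the trace above shows Zygote stepping into rand itself (rand → _rand! → unwhiten! → mul!) and calling its pullback with nothing, so dropgrad(rand(model.base)) does not help: the argument is still evaluated inside the differentiated code. Wrapping the whole call in @ignore, i.e. z0 = Zygote.@ignore rand(model.base), should keep Zygote out of the sampler while still using model.base. The sketch below uses a hypothetical in-place sampler, sample_in_place, as a stand-in for the MvNormal one; it has not been checked against Distributions itself.

```julia
using Zygote

# Hypothetical stand-in for rand(model.base): like the MvNormal sampler
# (see the _rand!/unwhiten! frames in the trace), it fills an array in place.
function sample_in_place(n)
    v = zeros(n)
    for i in 1:n
        v[i] = 0.1 * i        # setindex!: a mutation Zygote cannot differentiate
    end
    return v
end

function toy_loss(w)
    # @ignore means Zygote never traces into the sampler;
    # in the thread's code this would read z0 = Zygote.@ignore rand(model.base)
    z0 = Zygote.@ignore sample_in_place(3)
    return sum(w .* z0)
end

g = gradient(toy_loss, [1.0, 2.0, 3.0])[1]   # equals the sampled z0
```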
