# Tell Flux not to differentiate something

Hi, I’m trying to create a model using Flux.

The model is fairly complicated (it builds on normalizing flows, specifically RealNVP) and stacks many layers on top of each other. There are two inputs: a vector `Z` and a matrix `X`. `X` is transformed into a matrix `Y` by neural networks that depend on `Z`, and `Z` is then transformed into `Z_new` by another neural network that also takes `Y` as input. The calculations are below; every `stack.*` field is a neural network.

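```julia
# Forward pass of one SetFlowStack layer; every `stack.*` field is a neural network
function (stack::SetFlowStack)(input)
    X, logJxy, Z, logJz = input

    # calculate RealNVP part: affine transform of X conditioned on Z, then the nvp layers
    Y_hat = X .* exp.(stack.sx(Z)) .+ stack.tx(Z)
    Y, tmpJ = stack.nvp((Y_hat, logJxy))
    logJY = tmpJ .- sum(stack.sx(Z), dims=1)

    # calculate global noise part: affine transform of Z conditioned on a set summary of Y
    ds = stack.deep_set(Y)
    Z_new = Z .* exp.(stack.sz(ds)) .+ stack.tz(ds)
    logJZ = logJz .- sum(stack.sz(ds), dims=1)

    return Y, logJY, Z_new, logJZ
end
```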

The loss function is

```julia
function loss_setflow(model::SetFlow, X)
    z0 = rand(model.base)
    Y, logJY, Z_new, logJZ = model((X, _init_logJ(X), z0, _init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1), 1), Y) .+ logJY') - sum(logpdf(model.base, Z_new) .+ logJZ)
end
```

where `model.base` is just a simple isotropic Gaussian and `SetFlow` is just a `Chain` of `SetFlowStack`s.
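Reading the code, the loss is a negative log-likelihood from the change-of-variables formula, summed over the columns $y_j$ of `Y`. A sketch of that reading, assuming `logJY` holds one log-Jacobian term per column of `Y` and `logJZ` holds a single term for `Z_new`:

$$
\mathcal{L}(X) = -\sum_{j}\Big[\log \mathcal{N}(y_j;\,0, I) + \log J_{Y,j}\Big] \;-\; \Big[\log p_{\text{base}}(z_{\text{new}}) + \log J_Z\Big]
$$

Here $\mathcal{N}(\cdot;\,0, I)$ is the `MvNormal(size(Y,1), 1)` term, $p_{\text{base}}$ is `model.base`, and $\log J_{Y,j}$, $\log J_Z$ are the entries of `logJY` and `logJZ`.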

The problem is that I can't differentiate the loss function because of array mutation (I suspect the mutation happens on `Z` and `Y`).

Is there a way to tell Flux (for now) that `Y_hat` should treat `Z` only as an input (i.e. not differentiate with respect to it), and that `Z_new` should not be differentiated with respect to `Y`?


Hey,

I think `Zygote.@ignore` is what you are looking for:

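```julia
using Zygote

c, v = 2.0, [1.0, 2.0]   # placeholder values for illustration
Zygote.@ignore c .* v    # evaluated, but no gradient flows through it
Zygote.@ignore println("no differentiation on this unit")

# inside the SetFlowStack forward pass; note this also stops gradients to stack.sz, stack.tz and stack.deep_set
Z_new = Zygote.@ignore Z .* exp.(stack.sz(ds)) .+ stack.tz(ds)
```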

The problem with this solution is that I do want to differentiate `stack.sz`, `stack.tz` and `stack.deep_set`. I need gradients through `Z_new`, but I don't want them to propagate all the way back into the functions that created `Y`. I'm not sure if that makes sense.
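The difference between ignoring a whole expression and detaching only its input can be seen on a toy example. This is a minimal sketch, assuming plain weight matrices `Wds` and `Wsz` as hypothetical stand-ins for `stack.deep_set` and `stack.sz`; the helper `z_scale` is also made up for illustration. Wrapping the whole computation in `Zygote.@ignore` leaves every argument without a gradient, while applying `Zygote.dropgrad` only to `Y` keeps the gradients for the weights and for `Z`. Newer Zygote releases point to `ChainRulesCore.@ignore_derivatives` and `ChainRulesCore.ignore_derivatives` for the same purpose.

```julia
using Zygote

# Hypothetical stand-ins for stack.deep_set and stack.sz
Wds = [0.5 -0.2; 0.1 0.3]
Wsz = [0.2 0.0; 0.0 0.1]

# Mean-pool the projected set elements (columns of Y), then map the summary to a log-scale for Z
function z_scale(Wds, Wsz, Y)
    ds = vec(sum(Wds * Y; dims = 2)) ./ size(Y, 2)
    return tanh.(Wsz * ds)
end

# Whole expression ignored: no gradient reaches Wds, Wsz, Y or Z
loss_ignore_all(Wds, Wsz, Y, Z) = sum(abs2, Zygote.@ignore(Z .* exp.(z_scale(Wds, Wsz, Y))))

# Only Y is detached: Wds, Wsz and Z still receive gradients
loss_detach_Y(Wds, Wsz, Y, Z) = sum(abs2, Z .* exp.(z_scale(Wds, Wsz, Zygote.dropgrad(Y))))

Y = [1.0 2.0 3.0; 0.5 -1.0 0.0]
Z = [0.3, -0.7]

g_all = Zygote.gradient(loss_ignore_all, Wds, Wsz, Y, Z)  # (nothing, nothing, nothing, nothing)
g_det = Zygote.gradient(loss_detach_Y, Wds, Wsz, Y, Z)    # arrays for Wds, Wsz, Z; nothing for Y
```

In the layer itself, the analogous change would be `ds = stack.deep_set(Zygote.dropgrad(Y))`, leaving the `Z_new` line un-ignored.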

I'm afraid something more complicated is happening under the hood. I guess I need to investigate further and figure out what to do about it.

Does anyone know how to find out where the array mutation is happening?
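One way to narrow it down is to differentiate each sub-function on its own and see which one raises the error. A minimal sketch, where `pure_piece`, `mutating_piece` and `check_piece` are hypothetical names; in the real model you would pass pieces such as `stack.deep_set` or the sampling step instead. Scanning the stack trace for frames like `setindex!`, `mul!` or other functions ending in `!` is usually another quick hint, since in-place operations are what Zygote rejects.

```julia
using Zygote

# Hypothetical model pieces: one pure, one that writes into an array in place
pure_piece(x) = 2 .* x

function mutating_piece(x)
    y = copy(x)
    y[1] = 0.0   # in-place write (setindex!): Zygote cannot differentiate this
    return y
end

# Differentiate a single piece; return "ok" or the error message
function check_piece(f, x)
    try
        Zygote.gradient(v -> sum(f(v)), x)
        return "ok"
    catch e
        return sprint(showerror, e)
    end
end

x = [1.0, 2.0, 3.0]
println(check_piece(pure_piece, x))       # ok
println(check_piece(mutating_piece, x))   # message mentions "Mutating arrays is not supported"
```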

Why not do this:

```julia
depset = stack.sz(ds)
tz = stack.tz(ds)
# @ignore wraps the whole expression, so no gradient reaches depset or tz through Z_new
Z_new = Zygote.@ignore Z .* exp.(depset) .+ tz
```

Is this any good? I'm not sure it's right… I didn't fully understand the exact problem, so I just did something like this to separate certain gradients and run only the forward pass for some of the functions…

As a quick guess, try loading DistributionsAD.jl in your code (`using DistributionsAD`). If it doesn't help, could you please post the exact error you get?


Let me second the suggestion to use DistributionsAD. Without it, you’d have to reimplement the distributions themselves by hand in order to keep things AD friendly (e.g. the Flux model zoo examples don’t use `Distributions` at all).
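For the isotropic Gaussian used here, writing the density by hand is short. A minimal sketch, assuming zero mean and unit variance (as with `MvNormal(size(Y,1), 1)`); `iso_logpdf` is a hypothetical name. It only uses broadcasting and reductions, so Zygote can differentiate it without running into array mutation.

```julia
using Zygote

# Hypothetical helper: log-density of N(0, I) evaluated at each column of Y
function iso_logpdf(Y::AbstractMatrix)
    d = size(Y, 1)
    # log N(y; 0, I) = -0.5 * (d * log(2π) + ||y||^2), computed for all columns at once
    return vec(-0.5 .* (d * log(2π) .+ sum(Y .^ 2; dims = 1)))
end

Y = [0.0 1.0; 0.0 -2.0]
lp = iso_logpdf(Y)                                 # one value per column
g, = Zygote.gradient(Y -> sum(iso_logpdf(Y)), Y)   # gradient of the summed log-density is -Y
```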

More generally, `Zygote.dropgrad` (documented on the Utilities page of the Zygote docs) is a stop-gradient operation, equivalent to PyTorch's `detach` or JAX's `jax.lax.stop_gradient`. If DistributionsAD doesn't work out, I think this should at least handle the `Z -> Y_hat` connection.
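Applied to the `Z -> Y_hat` connection, a sketch could look like the following. The weight matrices `Wsx` and `Wtx` are hypothetical stand-ins for `stack.sx` and `stack.tx`, and `yhat_loss` is a made-up toy objective, not the model's actual loss. Because `Z` only enters through `Zygote.dropgrad`, its gradient comes back as `nothing`, while the weights (and `X`) still get gradients.

```julia
using Zygote

# Hypothetical stand-ins for stack.sx and stack.tx: single weight matrices
Wsx = [0.1 0.2; -0.3 0.4]
Wtx = [1.0 0.0; 0.0 1.0]

# Toy objective (not the model's real loss) built around the Y_hat affine step
function yhat_loss(Wsx, Wtx, X, Z)
    Zc = Zygote.dropgrad(Z)         # cut the Z -> Y_hat edge: Z is treated as a constant
    s = tanh.(Wsx * Zc)             # per-row log-scale, length 2
    t = Wtx * Zc                    # per-row shift, length 2
    Y_hat = X .* exp.(s) .+ t       # s and t broadcast across the columns of X
    return sum(abs2, Y_hat) - size(X, 2) * sum(s)
end

X = [1.0 -1.0 0.5; 2.0 0.0 -0.5]
Z = [0.3, -0.7]
gWsx, gWtx, gX, gZ = Zygote.gradient(yhat_loss, Wsx, Wtx, X, Z)
```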

Apparently, Flux tries to differentiate the `z0 = rand(model.base)` expression. Can I use something like `dropgrad` on `z0` to tell Flux not to differentiate that expression? I tried `z0 = dropgrad`, but it makes no difference. I also made the `model.base` distribution a `TuringMvNormal` from DistributionsAD.jl, but that didn't help either. Since the distribution is an isotropic Gaussian, I don't want it to be trainable…

Do you mind posting the updated loss code? I'm not qualified to comment on the distributions side of things, but `dropgrad` is a function and should be used like `z0 = dropgrad(rand(model.base))` or `otherexpression(dropgrad(z0))`.

I just added `dropgrad` as you suggested:

```julia
function loss_setflow(model::SetFlow, X)
    z0 = dropgrad(rand(model.base))
    Y, logJY, Z_new, logJZ = model((X, _init_logJ(X), z0, _init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1), 1), Y) .+ logJY') - sum(logpdf(model.base, Z_new) .+ logJZ)
end
```

I still get this error:

```julia
julia> Flux.train!(loss, ps, train_data, opt)
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] (::Zygote.var"#364#365")(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\array.jl:58
[4] _modify! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:83 [inlined]
[5] (::typeof(∂(_modify!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[6] generic_mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:108 [inlined]
[7] (::typeof(∂(generic_mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[8] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:126 [inlined]
[9] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\matmul.jl:208 [inlined]
[10] (::typeof(∂(mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[11] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\scalmat.jl:66 [inlined]
[12] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[13] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\generics.jl:33 [inlined]
[14] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[15] _rand! at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariate\mvnormal.jl:275 [inlined]
[16] (::typeof(∂(_rand!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[17] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariates.jl:76 [inlined]
[18] (::typeof(∂(rand)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[19] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\genericrand.jl:22 [inlined]
[20] loss_setflow at .\REPL[527]:2 [inlined]
[21] (::typeof(∂(loss_setflow)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[22] loss at .\REPL[537]:1 [inlined]
[23] (::typeof(∂(loss)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[24] #150 at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
[26] #15 at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:83 [inlined]
[27] (::typeof(∂(λ)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[28] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172
[30] macro expansion at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:82 [inlined]
[31] macro expansion at C:\Users\masen\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
[32] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM; cb::Flux.Optimise.var"#16#22") at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:80
[33] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM) at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:78
[34] top-level scope at REPL[540]:1
``````

I have the `loss` function defined as `loss(x) = loss_setflow(set_flow, x)`, where `set_flow` is an initialized `SetFlow` model.

I finally figured out a workaround…

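```julia
function loss_setflow(model::SetFlow, X)
    # sample z0 with Base's randn instead of rand(model.base);
    # this assumes model.base is a zero-mean, unit-variance isotropic Gaussian
    z0 = Float32.(randn(length(model.base)))
    Y, logJY, Z_new, logJZ = model((X, _init_logJ(X), z0, _init_logJ(z0)))
    -sum(logpdf(MvNormal(size(Y,1), 1), Y) .+ logJY') - sum(logpdf(model.base, Z_new) .+ logJZ)
end
```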

This works…
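A likely reason `dropgrad(rand(model.base))` did not help: `rand` is still evaluated inside the traced code, so Zygote records its internals (including the in-place `_rand!` and `unwhiten!` calls) and later calls their pullbacks with `nothing`, which is what the `(::typeof(∂(rand)))(::Nothing)` frame in the stack trace suggests. Wrapping the sampling call itself, e.g. `z0 = Zygote.@ignore rand(model.base)`, should keep it out of tracing altogether without assuming a standard normal base. The sketch below illustrates this with a hypothetical sampler `fill_sample` that writes into its output in place (deterministic, to keep it reproducible), standing in for `rand(model.base)`.

```julia
using Zygote

# Hypothetical sampler that fills its output in place, like Distributions' `_rand!`
function fill_sample(n)
    z = zeros(n)
    for i in eachindex(z)
        z[i] = i / n     # in-place write that Zygote could not differentiate
    end
    return z
end

# Zygote.@ignore runs the sampler outside of Zygote's tracing entirely
function loss_ignore(w)
    z0 = Zygote.@ignore fill_sample(length(w))
    return sum(w .* z0)
end

w = [1.0, 2.0, 3.0, 4.0]
gw, = Zygote.gradient(loss_ignore, w)   # gradient of sum(w .* z0) w.r.t. w is z0
```

Writing `z0 = Zygote.dropgrad(fill_sample(length(w)))` instead would still trace `fill_sample`, which is the situation the stack trace above appears to show.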
