I just added dropgrad
as you suggested:
# Negative log-likelihood loss for a SetFlow model: the transformed data Y is
# scored under a standard MvNormal base with the flow's log-Jacobian
# correction logJY, and the latent Z_new is scored under `model.base` with
# its correction logJZ.
#
# NOTE(review): `dropgrad(rand(model.base))` is not enough — `dropgrad` only
# detaches the returned *value*, but Zygote still builds a pullback through
# the `rand` call itself, and Distributions' `rand(::MvNormal)` fills an
# array in place (`unwhiten!`/`_rand!` in the stack trace), which triggers
# "Mutating arrays is not supported". Wrapping the sampling in
# `Zygote.ignore` keeps the AD pass out of `rand` entirely (assumes
# `using Zygote` is in effect, as `dropgrad` already was — confirm).
function loss_setflow(model::SetFlow, X)
    # Sample the base-distribution seed outside the gradient tape.
    z0 = Zygote.ignore(() -> rand(model.base))
    Y, logJY, Z_new, logJZ = model((X, _init_logJ(X), z0, _init_logJ(z0)))
    return -sum(logpdf(MvNormal(size(Y, 1), 1), Y) .+ logJY') -
           sum(logpdf(model.base, Z_new) .+ logJZ)
end
I still get this error:
julia> Flux.train!(loss,ps,train_data,opt)
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] (::Zygote.var"#364#365")(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\array.jl:58
[3] (::Zygote.var"#2246#back#366"{Zygote.var"#364#365"})(::Nothing) at C:\Users\masen\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[4] _modify! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:83 [inlined]
[5] (::typeof(∂(_modify!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[6] generic_mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:108 [inlined]
[7] (::typeof(∂(generic_mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[8] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\generic.jl:126 [inlined]
[9] mul! at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\matmul.jl:208 [inlined]
[10] (::typeof(∂(mul!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[11] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\scalmat.jl:66 [inlined]
[12] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0 [13] unwhiten! at C:\Users\masen\.julia\packages\PDMats\G0Prn\src\generics.jl:33 [inlined]
[14] (::typeof(∂(unwhiten!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0 [15] _rand! at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariate\mvnormal.jl:275 [inlined]
[16] (::typeof(∂(_rand!)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[17] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\multivariates.jl:76 [inlined]
[18] (::typeof(∂(rand)))(::Nothing) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[19] rand at C:\Users\masen\.julia\packages\Distributions\HjzA0\src\genericrand.jl:22 [inlined]
[20] loss_setflow at .\REPL[527]:2 [inlined]
[21] (::typeof(∂(loss_setflow)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[22] loss at .\REPL[537]:1 [inlined]
[23] (::typeof(∂(loss)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[24] #150 at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191 [inlined]
[25] #1694#back at C:\Users\masen\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
[26] #15 at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:83 [inlined]
[27] (::typeof(∂(λ)))(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
[28] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:172
[29] gradient(::Function, ::Params) at C:\Users\masen\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
[30] macro expansion at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:82 [inlined]
[31] macro expansion at C:\Users\masen\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
[32] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM; cb::Flux.Optimise.var"#16#22") at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:80
[33] train!(::Function, ::Params, ::Array{Array{Float32,2},1}, ::ADAM) at C:\Users\masen\.julia\packages\Flux\05b38\src\optimise\train.jl:78
[34] top-level scope at REPL[540]:1
I have the loss
function defined as loss(x) = loss_setflow(set_flow,x)
, where set_flow
is an initialized SetFlow model.
I finally figured out a workaround…
# Negative log-likelihood loss for a SetFlow model (working version): samples
# the latent seed z0 with plain `randn`, which Zygote never needs to
# differentiate through, avoiding the in-place `rand(::MvNormal)` path that
# raised "Mutating arrays is not supported".
#
# NOTE(review): sampling `randn(Float32, n)` directly produces the Float32
# vector in one allocation, instead of `Float32.(randn(n))` which draws a
# Float64 vector and then converts it (two allocations, same N(0,1)
# distribution). Assumes `length(model.base)` is the base distribution's
# dimensionality — confirm against the SetFlow definition.
function loss_setflow(model::SetFlow, X)
    z0 = randn(Float32, length(model.base))
    Y, logJY, Z_new, logJZ = model((X, _init_logJ(X), z0, _init_logJ(z0)))
    return -sum(logpdf(MvNormal(size(Y, 1), 1), Y) .+ logJY') -
           sum(logpdf(model.base, Z_new) .+ logJZ)
end
This works…