Dear all,
I am encountering some problems with Flux.
I am using Julia 1.3.1 with Zygote v0.4.6, Flux v0.10.1 and CuArrays v1.6.0. No GPUs for the time being.
The (not-so) minimal, partially working example is the following:
using Flux, Zygote, CuArrays, LinearAlgebra

const q = 4
const N = 10
Z = rand(Float32, 3*q, 100)

vchain1 = [Chain(Dense(q, q), softmax),
           Chain(Dense(2*q, q), softmax)]
vchain2 = [Chain(Dense(q, q, relu), Dense(q, q), softmax),
           Chain(Dense(2*q, 2*q, relu), Dense(2*q, q), softmax)]

log0(x::Number) = x > 0 ? log(x) : zero(x)
CuArrays.@cufunc log0(x::Number) = x > 0 ? log(x) : zero(x)

function loss(x, vmodel)
    logeta = Float32(0.0)
    @inbounds for site in 1:2
        thechunk = vmodel[site]
        idxcond = 1:site*q
        zcond = x[idxcond, :]
        idxsite = site*q + 1 : (site+1)*q
        xsite = x[idxsite, :]
        logeta += dot(log0.(thechunk(zcond)), xsite)
    end
    return -logeta
end

myloss1(x) = loss(x, vchain1)
myloss2(x) = loss(x, vchain2)
myl2_loss1(x) = loss(x, vchain1) + sum(norm, Flux.params(vchain1))
myl2_loss2(x) = loss(x, vchain2) + sum(norm, Flux.params(vchain2))
All four loss functions run smoothly:
julia> [lf(Z) for lf in (myloss1,myloss2,myl2_loss1,myl2_loss2)]
4-element Array{Float32,1}:
585.86035
556.96045
589.7834
566.8873
Computing the gradients works for the loss functions without L2 regularisation:
julia> ∇1=Flux.gradient(()->myloss1(Z),Flux.params(vchain1))
Grads(...)
julia> ∇2=Flux.gradient(()->myloss2(Z),Flux.params(vchain2))
Grads(...)
But the regularised counterparts throw an error.
julia> ∇1=Flux.gradient(()->myl2_loss1(Z),Flux.params(vchain1))
ERROR: Mutating arrays is not supported
....
I believe this is related to https://github.com/FluxML/Zygote.jl/issues/231
There are some hints there, but I would like to know whether there is a recommended solution.
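A workaround that is sometimes suggested for this class of error, assuming the call to `norm` is what triggers it, is to write the penalty with `sum` and `abs2` only. The sketch below has not been verified against the versions listed above; `sqnorm` and `l2_penalty` are hypothetical helper names, and plain arrays stand in for the Flux parameters. Note that this is the squared norm, so the penalty is not numerically the same as `sum(norm, ...)`.

```julia
# Squared L2 norm of one parameter array, written with sum/abs2 only
# (no call to LinearAlgebra.norm).
sqnorm(p) = sum(abs2, p)

# Penalty over a collection of parameter arrays.
# `l2_penalty` is a hypothetical helper name, not part of Flux.
l2_penalty(ps) = sum(sqnorm, ps)

# Plain arrays standing in for Flux.params(vchain1):
W = Float32[1 2; 3 4]
b = Float32[0, 0]
penalty = l2_penalty((W, b))   # 1 + 4 + 9 + 16 + 0 + 0
```

Under that assumption, the regularised loss would read `myl2_loss1(x) = loss(x, vchain1) + sum(sqnorm, Flux.params(vchain1))`.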
A second problem arises when I try to train the non-regularised network, for which the gradient computation seems to work (the same error is thrown for myloss2):
julia> Flux.train!(myloss1,Flux.params(vchain1),Z,ADAM(0.001))
ERROR: MethodError: no method matching getindex(::Float32, ::UnitRange{Int64}, ::Colon)
Closest candidates are:
getindex(::Number) at number.jl:75
getindex(::Number, ::Integer) at number.jl:77
getindex(::Number, ::Integer...) at number.jl:82
...
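One possible reading of this error: as far as I understand, `Flux.train!` iterates over its data argument and calls the loss as `loss(d...)` for each element `d`. Iterating a matrix yields its scalar entries, so the loss would receive a single `Float32` and fail at `x[idxcond,:]`. The sketch below reproduces that behaviour in plain Julia; `toy_train!` is a hypothetical stand-in for the data loop, not the Flux implementation.

```julia
# `toy_train!` is a hypothetical stand-in for the data loop inside a
# training function: it calls loss(d...) once per element d of `data`.
function toy_train!(loss, data)
    for d in data
        loss(d...)
    end
end

Z = rand(Float32, 12, 100)

# Iterating a matrix yields its scalar entries, not the matrix itself.
first_item = first(Z)          # a Float32

# Passing the bare matrix therefore indexes into a scalar and fails.
err = try
    toy_train!(x -> x[1:4, :], Z)
catch e
    e
end

# Wrapping the matrix in a one-element collection of tuples makes the
# loop see (Z,) and therefore call loss(Z).
data = [(Z,)]
seen = Any[]
toy_train!(x -> push!(seen, size(x)), data)
```

If that reading is right, the call would become `Flux.train!(myloss1, Flux.params(vchain1), [(Z,)], ADAM(0.001))`, which runs a single update step on the whole of `Z`.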
Any ideas?