I have noticed a strange behaviour in Flux and wonder whether it is expected or not. The code below is part of a WGAN.
If D is defined by chaining Dense layers, everything works well.
using Flux, Random, Statistics
using Flux: params
using Flux.Optimise: update!
# Seeded generator used for the data, the latent noise and the interpolation weights.
rng = MersenneTwister(76543)

# Generator: 4-dimensional noise in, 14-dimensional sample out.
G = Chain(Dense(4, 28), Dense(28, 14), Dense(14, 14))

# Discriminator/critic built only from Dense layers: 14 features in, 1 score out.
D = Chain(Dense(14, 14), Dense(14, 28), Dense(28, 1, sigmoid))

# A batch of 100 "real" samples (one per column) ...
x = randn(rng, Float32, 14, 100)

# ... and a batch of 100 generated samples obtained from fresh noise.
noise = randn(rng, Float32, 4, 100)
z = G(noise)
"""
    update_D!(D, x, z, rng; opt=ADAM(0.001), λ=0.3f0)

Perform one optimisation step on the discriminator/critic `D` using a
gradient-penalty loss.

# Arguments
- `D`: the discriminator/critic; its trainable parameters are updated in place.
- `x`: batch of real samples, one sample per column.
- `z`: batch of generated samples, same size as `x`.
- `rng`: random number generator used to draw the interpolation weights.

# Keywords
- `opt`: optimiser passed to `update!`. The default builds a new `ADAM(0.001)` on
  every call; pass a single optimiser object across calls to keep its internal
  state between steps.
- `λ`: weight of the gradient-penalty term. Kept in `Float32` so that the loss is
  not promoted to `Float64`.

# Returns
Whatever `update!(opt, params(D), grads)` returns.
"""
function update_D!(D, x, z, rng; opt=ADAM(0.001), λ=0.3f0)
    pd = params(D)
    # One interpolation weight per sample (column), broadcast over the features.
    ϵ = rand(rng, Float32, 1, size(x, 2))
    # Random points on the segments between generated and real samples.
    xb = ϵ .* z .+ (1f0 .- ϵ) .* x
    ∇pdL = gradient(pd) do
        # Gradient of the critic output with respect to its input, taken inside the
        # outer `gradient` so the penalty is itself differentiated w.r.t. `pd`.
        ∇xbD, = gradient(xb -> sum(D(xb)), xb)
        # Per-sample gradient norm; the small constant keeps `sqrt` away from zero.
        penalty = mean((sqrt.(sum(∇xbD .^ 2, dims=1) .+ 1f-12) .- 1f0) .^ 2)
        mean(D(z)) - mean(D(x)) + λ * penalty
    end
    return update!(opt, pd, ∇pdL) # update discriminator/critic
end
# Run one critic update on the batches defined above.
update_D!(D,x,z,rng)
However, when a Dense layer is replaced by an LSTM layer, as below:
# Seeded generator used for the data, the latent noise and the interpolation weights.
rng = MersenneTwister(76543)

# Generator: same layout as before, with the middle layer replaced by an LSTM.
G = Chain(Dense(4, 28), LSTM(28, 14), Dense(14, 14))

# Discriminator/critic: the middle layer is now an LSTM instead of a Dense layer.
D = Chain(Dense(14, 14), LSTM(14, 28), Dense(28, 1, sigmoid))

# A batch of 100 "real" samples (one per column) ...
x = randn(rng, Float32, 14, 100)

# ... and a batch of 100 generated samples obtained from fresh noise.
noise = randn(rng, Float32, 4, 100)
z = G(noise)
"""
    update_D!(D, x, z, rng; opt=ADAM(0.001), λ=0.3f0)

Perform one optimisation step on the discriminator/critic `D` using a
gradient-penalty loss.

# Arguments
- `D`: the discriminator/critic; its trainable parameters are updated in place.
- `x`: batch of real samples, one sample per column.
- `z`: batch of generated samples, same size as `x`.
- `rng`: random number generator used to draw the interpolation weights.

# Keywords
- `opt`: optimiser passed to `update!`. The default builds a new `ADAM(0.001)` on
  every call; pass a single optimiser object across calls to keep its internal
  state between steps.
- `λ`: weight of the gradient-penalty term. Kept in `Float32` so that the loss is
  not promoted to `Float64`.

# Returns
Whatever `update!(opt, params(D), grads)` returns.
"""
function update_D!(D, x, z, rng; opt=ADAM(0.001), λ=0.3f0)
    pd = params(D)
    # One interpolation weight per sample (column), broadcast over the features.
    ϵ = rand(rng, Float32, 1, size(x, 2))
    # Random points on the segments between generated and real samples.
    xb = ϵ .* z .+ (1f0 .- ϵ) .* x
    ∇pdL = gradient(pd) do
        # Gradient of the critic output with respect to its input, taken inside the
        # outer `gradient`, so the outer pass has to differentiate through the
        # pullbacks of every layer in `D`.
        # NOTE(review): in the reported stack trace the failure occurs while Zygote
        # differentiates Flux's `multigate_pullback` (recurrent.jl), i.e. in this
        # nested pass. It looks like the LSTM's custom rule does not support being
        # differentiated a second time — confirm against the Flux/Zygote versions used.
        # NOTE(review): presumably the LSTM keeps hidden state between the successive
        # calls D(xb), D(z) and D(x); verify whether `Flux.reset!(D)` is needed.
        ∇xbD, = gradient(xb -> sum(D(xb)), xb)
        # Per-sample gradient norm; the small constant keeps `sqrt` away from zero.
        penalty = mean((sqrt.(sum(∇xbD .^ 2, dims=1) .+ 1f-12) .- 1f0) .^ 2)
        mean(D(z)) - mean(D(x)) + λ * penalty
    end
    return update!(opt, pd, ∇pdL) # update discriminator/critic
end
# Same call as in the Dense-only version; with the LSTM-based D it raises the
# MethodError reported below.
update_D!(D,x,z,rng)
This produces the error message below. I don’t understand where this error comes from, as D(x) is well defined and produces a 1x100 matrix.
julia> update_D!(D,x,z,rng)
ERROR: MethodError: no method matching size(::ChainRulesCore.Tangent{Any, NTuple{4, Matrix{Float32}}})
Closest candidates are:
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
@ LinearAlgebra ~/.julia/juliaup/julia-1.9.2+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:582
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
@ LinearAlgebra ~/.julia/juliaup/julia-1.9.2+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
@ LinearAlgebra ~/.julia/juliaup/julia-1.9.2+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:585
...
Stacktrace:
[1] axes
@ ./abstractarray.jl:98 [inlined]
[2] _tryaxes(x::ChainRulesCore.Tangent{Any, NTuple{4, Matrix{Float32}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:188
[3] map
@ ./tuple.jl:274 [inlined]
[4] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:322 [inlined]
[5] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[6] _pullback
@ ./iterators.jl:370 [inlined]
[7] _pullback(::Zygote.Context{true}, ::typeof(zip), ::NTuple{4, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::ChainRulesCore.Tangent{Any, NTuple{4, Matrix{Float32}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[8] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[9] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
[10] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[11] _pullback
@ ./abstractarray.jl:3074 [inlined]
[12] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::Flux.var"#265#267", ::NTuple{4, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::ChainRulesCore.Tangent{Any, NTuple{4, Matrix{Float32}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[13] _pullback
@ ~/.julia/packages/Flux/n3cOc/src/layers/recurrent.jl:12 [inlined]
[14] _pullback(ctx::Zygote.Context{true}, f::Flux.var"#multigate_pullback#266"{Matrix{Float32}, Int64, Val{4}}, args::ChainRulesCore.Tangent{Any, NTuple{4, Matrix{Float32}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[15] _pullback
@ ~/.julia/packages/Zygote/JeHtr/src/compiler/chainrules.jl:211 [inlined]
[16] _pullback(ctx::Zygote.Context{true}, f::Zygote.ZBack{Flux.var"#multigate_pullback#266"{Matrix{Float32}, Int64, Val{4}}}, args::NTuple{4, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[17] _pullback
@ ~/.julia/packages/Flux/n3cOc/src/layers/recurrent.jl:315 [inlined]
[18] _pullback(ctx::Zygote.Context{true}, f::Zygote.Pullback{Tuple{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#back#241"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, Matrix{Float32}}}}, Zygote.var"#back#242"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{4, 2, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}}, Zygote.Pullback{Tuple{Type{Pair}, Int64, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#309"{Zygote.Jnew{Pair{Int64, Int64}, Nothing, false}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}}}, Zygote.ZBack{Flux.var"#_size_check_pullback#201"{Tuple{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Matrix{Float32}, Pair{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}, Zygote.var"#back#242"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{4, 3, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}}, Zygote.ZBack{NNlib.var"#broadcasted_sigmoid_fast_pullback#151"{Matrix{Float32}}}, 
Zygote.ZBack{NNlib.var"#broadcasted_tanh_fast_pullback#145"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(Flux.reshape_cell_output), Matrix{Float32}, Matrix{Float32}}, Tuple{Zygote.var"#2065#back#228"{Zygote.var"#222#226"{2, UnitRange{Int64}}}, Zygote.Pullback{Tuple{typeof(lastindex), Tuple{Int64, Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#747"}}}, Zygote.ZBack{ChainRules.var"#:_pullback#276"{Tuple{Int64, Int64}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.var"#2799#back#621"{Zygote.var"#615#619"{Matrix{Float32}, Tuple{Colon, Int64}}}}}, Zygote.ZBack{ChainRules.var"#size_pullback#917"}}}, Zygote.var"#back#242"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{4, 4, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:Wh, Zygote.Context{false}, Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#muladd_pullback_1#1538"{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.ZBack{NNlib.var"#broadcasted_sigmoid_fast_pullback#151"{Matrix{Float32}}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Matrix{Float32}, Matrix{Float32}}}, Zygote.ZBack{Flux.var"#174#175"}, 
Zygote.var"#back#242"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Matrix{Float32}}}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:b, Zygote.Context{false}, Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Vector{Float32}}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:Wi, Zygote.Context{false}, Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Matrix{Float32}}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#size_pullback#919"}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#size_pullback#919"}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.ZBack{NNlib.var"#broadcasted_tanh_fast_pullback#145"{Matrix{Float32}}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Int64}}, Zygote.ZBack{NNlib.var"#broadcasted_sigmoid_fast_pullback#151"{Matrix{Float32}}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:Wi, Zygote.Context{false}, Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Matrix{Float32}}}, Zygote.var"#back#241"{Zygote.var"#2033#back#209"{Zygote.var"#back#207"{4, 1, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Matrix{Float32}, Matrix{Float32}}}, Zygote.ZBack{Flux.var"#multigate_pullback#266"{Matrix{Float32}, Int64, Val{4}}}, 
Zygote.ZBack{ChainRules.var"#muladd_pullback_1#1538"{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}}, Zygote.var"#1926#back#157"{Zygote.var"#153#156"}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{Matrix{Float32}, Matrix{Float32}}}}}}, args::Tuple{Nothing, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
G