Doh.
That's bad from me; thank you both for that.
That question was more of a prelude to a slightly more involved example where I was seeing similar behaviour, this time without my faux pas with function declarations.
using Zygote
using Statistics
using Optimisers

# Data, target, and optimiser state
a1 = 1f0 .+ randn(Float32, (1024, 1024))
y1 = ones(Float32, (1024, 1024))
iter = 100
rule = Optimisers.Adam()
state = Optimisers.setup(rule, a1)

# Elementwise (x^2 - y^2)^2, averaged
function loss(x, y)
    return mean(abs.((x .^ 2 .- y .^ 2) .^ 2))
end

# One step: gradient w.r.t. a1, then an Adam update (returns (state, a1))
function train_instance(a1, y1, state)
    grads = gradient(loss, a1, y1)[1]
    Optimisers.update!(state, a1, grads)
end

# Run `iter` steps, threading the state and parameters through the loop
function train(a1, y1, iter, state)
    for i in 1:iter
        state, a1 = train_instance(a1, y1, state)
    end
    return a1
end

p = train(a1, y1, iter, state)
This is still a pared-down version of what I have, but it includes more of the things I'm trying to do. It is contrived and does not perform well, but I am mostly interested in understanding where the type instability comes from. (Are my non-constant globals a1 and y1 causing this? I found @descend quite confusing on my first try.) This is my first time using Optimisers and Zygote, having come from JAX, so I'm sure I've missed some really basic things; please point them all out!
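For what it's worth, here is a sketch of what I mean by the globals question: wrapping the whole setup in a function so that a1 and y1 are never non-constant globals. I haven't checked whether this actually changes the inference result.

function main()
    # Same setup as above, but everything is local to the function
    a1 = 1f0 .+ randn(Float32, (1024, 1024))
    y1 = ones(Float32, (1024, 1024))
    state = Optimisers.setup(Optimisers.Adam(), a1)
    return train(a1, y1, 100, state)
end

p = main()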
When I look at the @code_warntype output for the train function, I get the following readout:
MethodInstance for train(::Matrix{Float32}, ::Matrix{Float32}, ::Int64, ::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}})
from train(a1, y1, iter, state) @ Main
Arguments
#self#::Core.Const(train)
a1@_2::Matrix{Float32}
y1::Matrix{Float32}
iter::Int64
state@_5::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}
Locals
@_6::Union{Nothing, Tuple{Int64, Int64}}
@_7::Int64
i::Int64
a1@_9::Any
state@_10::Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}
Body::Any
1 ─ (a1@_9 = a1@_2)
│ (state@_10 = state@_5)
│ %3 = (1:iter)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│ (@_6 = Base.iterate(%3))
│ %5 = (@_6 === nothing)::Bool
│ %6 = Base.not_int(%5)::Bool
└── goto #4 if not %6
2 ┄ %8 = @_6::Tuple{Int64, Int64}
│ (i = Core.getfield(%8, 1))
│ %10 = Core.getfield(%8, 2)::Int64
│ %11 = Main.train_instance(a1@_9, y1, state@_10)::Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Any}
│ %12 = Base.indexed_iterate(%11, 1)::Core.PartialStruct(Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Int64}, Any[Optimisers.Leaf{Adam{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Core.Const(2)])
│ (state@_10 = Core.getfield(%12, 1))
│ (@_7 = Core.getfield(%12, 2))
│ %15 = Base.indexed_iterate(%11, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(3)])
│ (a1@_9 = Core.getfield(%15, 1))
│ (@_6 = Base.iterate(%3, %10))
│ %18 = (@_6 === nothing)::Bool
│ %19 = Base.not_int(%18)::Bool
└── goto #4 if not %19
3 ─ goto #2
4 ┄ return a1@_9
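If it helps narrow things down, one variant I was considering (but have not tried yet) is asserting the gradient's type inside train_instance, to check whether the Any in the readout originates from Zygote's gradient call:

function train_instance(a1, y1, state)
    # Assert a concrete gradient type; if the Any comes from `gradient`
    # being inferred as Any, this should let `train` infer concretely.
    grads = gradient(loss, a1, y1)[1]::Matrix{Float32}
    Optimisers.update!(state, a1, grads)
end

Does that seem like a sensible way to isolate the source of the Any, or is there a better approach?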