I’ve been fighting for a couple of days to get my Flux/Zygote autodiff code type stable.
I am not sure what the problem is, but after coming up with an MWE, it looks like broadcasting a Flux.Chain over a vector of inputs is what makes the gradient type unstable?
using Flux, Zygote
import Statistics: mean
function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    Flux.mse(modelvals, y)
end

function internfunc_broad(m, x, y)
    modelvals = m.(x)
    mses = Flux.mse.(modelvals, y)
    return mean(mses)
end

function wrapfunc(model, xdata, ydata, func)
    grad = let xdata=xdata, ydata=ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end
Then run the following in the REPL:
julia> fc = Flux.Chain(Flux.Dense(5=>3, Flux.relu), Flux.Dense(3=>3, Flux.relu), Flux.Dense(3=>1))
julia> fx = [fill(5f0, 5) for _ in 1:10]
julia> fy = fill(2f0, 10)
julia> @code_warntype wrapfunc(fc, fx, fy, internfunc_broad) # type unstable
julia> @code_warntype wrapfunc(fc, fx[1], fy[1], internfunc_nobroad) # type stable
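To compare the two cases without reading through the full @code_warntype output, a small check along these lines could be used. This is only a sketch: `infers_concretely` is a helper name I made up (it is not part of Flux or Zygote), and it relies on `Base.return_types`, which is an unexported reflection utility, so its behavior may differ between Julia versions.

```julia
# Hypothetical helper: returns true when inference finds exactly one
# concrete return type for the call f(args...).
function infers_concretely(f, args...)
    # Build the argument type tuple the same way a call would be dispatched.
    argtypes = Tuple{map(Core.Typeof, args)...}
    rtypes = Base.return_types(f, argtypes)
    return length(rtypes) == 1 && isconcretetype(only(rtypes))
end

# Toy functions to illustrate the check (unrelated to Flux).
stable(x) = x + 1                    # always returns the type of x + 1
unstable(x) = x > 0 ? 1 : 1.0        # returns Int or Float64 depending on the value

println(infers_concretely(stable, 1))
println(infers_concretely(unstable, 1))
```

With the definitions from the MWE loaded, the same helper can be called as `infers_concretely(wrapfunc, fc, fx, fy, internfunc_broad)` and `infers_concretely(wrapfunc, fc, fx[1], fy[1], internfunc_nobroad)`. Going by the @code_warntype results above, I would expect the first to report an inference failure and the second to pass, but I have not confirmed that this check agrees with @code_warntype in every case.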
I also opened a similar issue in Flux.jl.