Flux chain type unstable when broadcasting inside gradient

I’ve been fighting for a couple of days to get my Flux/Zygote autodiff code type stable.
I am not sure what the problem is, but after coming up with an MWE, it looks like broadcasting a Flux.Chain is the problematic part?

using Flux, Zygote
import Statistics: mean

function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    return Flux.mse(modelvals, y)
end

function internfunc_broad(m, x, y)
    modelvals = m.(x)                 # apply the chain to each sample vector
    mses = Flux.mse.(modelvals, y)    # per-sample loss
    return mean(mses)
end

function wrapfunc(model, xdata, ydata, func)
    grad = let xdata=xdata, ydata=ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end

Run the following in the REPL:

julia> fc = Flux.Chain(Flux.Dense(5=>3, Flux.relu), Flux.Dense(3=>3, Flux.relu), Flux.Dense(3=>1))
julia> fx = [fill(5f0, 5) for _ in 1:10]
julia> fy = fill(2f0, 10)
julia> @code_warntype wrapfunc(fc, fx, fy, internfunc_broad) # type unstable

julia> @code_warntype wrapfunc(fc, fx[1], fy[1], internfunc_nobroad) # type stable

I opened a similar issue in Flux.jl.

Okay, I think I got it… I should pass the input as a matrix and not as a Vector of Vectors. Then Flux handles it nicely.

fobs_ar = fill(5f0, 5, 10)
labels_ar = fill(2f0, 1, 10)

@code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)
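If your data already lives in a Vector of Vectors like fx above, one way to get the matrix layout is to concatenate the samples as columns. This is only a minimal sketch using Base Julia; `samples` and `targets` are illustrative names standing in for fx and fy, and it assumes the features × batch layout that Dense layers expect.

```julia
# Illustrative data with the same shapes as fx and fy above
samples = [fill(5f0, 5) for _ in 1:10]   # 10 samples, 5 features each
targets = fill(2f0, 10)                  # one scalar label per sample

# Concatenate the sample vectors as columns: 5×10 Matrix{Float32}
xmat = reduce(hcat, samples)

# Reshape the labels into a 1×10 row so they line up with the model output
ymat = reshape(targets, 1, :)

@show size(xmat) size(ymat) eltype(xmat)
```

The resulting xmat and ymat have the same shapes and contents as fobs_ar and labels_ar, so they can be passed to wrapfunc in the same way.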

Well… after switching from Flux.mse to Flux.huber_loss, I get type-unstable code again…

function internfunc_nobroad_huberloss(m, x, y)
    modelvals = m(x)
    return Flux.huber_loss(modelvals, y)
end

@code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)

This definitely looks like a bug.
I opened an issue. Feel free to drop some hints if you know why this happens and how it could be tackled.