Enabling broadcasting for named tuples: broadcast? broadcasted? broadcastable? BroadcastStyle?

Like many before me (if my googling serves me right), I found myself in need of broadcasting over named tuples. In my case, I hope it will let MLUtils.jl handle batches that are (you guessed it) named tuples.

This has deliberately been left undefined pending a decision about which semantics are preferable. It has been pointed out many times, however, that in lieu of an official implementation, one is always free to define it oneself. I thought I'd try, and wrote the method below to override the broadcast behaviour in Julia 1.8:

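function Base.Broadcast.broadcast(f, nt::NamedTuple)
    fns = fieldnames(typeof(nt))                # field names, e.g. (:d₁, :d₂, :d₃, :l)
    vals = map(x -> getfield(nt, x), fns)       # the field values as a plain tuple
    # apply f to each field value, then zip names and results back into a NamedTuple
    return (; zip(fns, Base.Broadcast.broadcast(f, vals))...)
end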

I still get an error, though:

ArgumentError("broadcasting over dictionaries and `NamedTuple`s is reserved")

broadcastable(#unused#::NamedTuple{(:d₁, :d₂, :d₃, :l), Tuple{Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Matrix{Float32}}}) at broadcast.jl:683

I'm starting to think that broadcastable also needs to be overridden, especially since it is the function that throws the placeholder error message. What does that function do, though? It seems to be there mainly to answer the question "is this broadcastable?". It also allows a proxy object to be returned and broadcast in place of the original, should that be needed.
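To see what `broadcastable` does with a few ordinary arguments, a quick check like the one below may help. It is a sketch using only Base; the exact set of types that get special treatment may differ between Julia versions.

```julia
using Base.Broadcast: broadcastable

v = [1, 2, 3]
arr_arg = broadcastable(v)        # arrays already support axes and indexing: returned unchanged
tup_arg = broadcastable((1, 2))   # tuples are passed through unchanged as well
sym_arg = broadcastable(:sym)     # symbols are wrapped in a Ref, so they act as scalars
str_arg = broadcastable("abc")    # strings too, which is why this concatenation works:
joined = "abc" .* ["x", "y"]      # ["abcx", "abcy"]

# NamedTuples (and dictionaries) are rejected on purpose: this is the error from the question
nt_error = try
    broadcastable((a = 1,))
    nothing
catch e
    e
end
```

So `broadcastable` mostly decides whether an argument is iterated over (returned as-is) or treated as a single value (wrapped in `Ref`); it does not by itself say how the iteration happens.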

Since I already said what broadcast should do to a named tuple (at least in a subset of situations), I imagine that I only need:

Base.Broadcast.broadcastable(nt::NamedTuple) = nt

The previous error was the same one I got before overriding Base.Broadcast.broadcast at all. With broadcastable overridden as well, I now get a new error message:

MethodError(ndims, (NamedTuple{(:d₁, :d₂, :d₃, :l), Tuple{Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Matrix{Float32}}},), 0x00000000000075a7)

Base.Broadcast.BroadcastStyle(#unused#::Type{NamedTuple{(:d₁, :d₂, :d₃, :l), Tuple{Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Matrix{Float32}}}}) at broadcast.jl:103

combine_styles(c::NamedTuple{(:d₁, :d₂, :d₃, :l), Tuple{Array{Float32, 4}, Array{Float32, 4}, Array{Float32, 4}, Matrix{Float32}}}) at broadcast.jl:420

broadcasted at broadcast.jl:1309 [inlined]
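Judging from the trace, once `broadcastable` returns, `broadcasted` asks `combine_styles` for a style, which calls `BroadcastStyle` on the argument's type. For a type without a specific method, the fallback builds `DefaultArrayStyle{ndims(T)}`, and `ndims` has no method for a `NamedTuple` type. The same lookups on ordinary types look like this (a sketch; `combine_styles` is internal and unexported, so it may change between versions):

```julia
using Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Style, combine_styles

# BroadcastStyle is looked up on the *type* of each (already broadcastable) argument
vec_style  = BroadcastStyle(Vector{Float32})     # DefaultArrayStyle{1}()
arr4_style = BroadcastStyle(Array{Float32, 4})   # DefaultArrayStyle{4}()
num_style  = BroadcastStyle(Float32)             # DefaultArrayStyle{0}(): scalars count as 0-dimensional
tup_style  = BroadcastStyle(Tuple{Int, Int})     # Style{Tuple}()

# combine_styles merges the per-argument styles; here the higher dimension wins
mixed = combine_styles(rand(Float32, 3), 1f0)    # DefaultArrayStyle{1}()
```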

It seems that I should also override ndims, or perhaps subtype Base.Broadcast.BroadcastStyle. Perhaps a judicious extension of combine_styles would make the problem go away, or is it Base.Broadcast.broadcasted that needs to be overridden? I am not sure where in the documentation these things are fully explained.
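The interface described in the "Customizing broadcasting" section of the Julia manual combines several of these hooks. Below is a minimal sketch under my own assumptions. Instead of adding methods for `NamedTuple` itself (type piracy on behaviour Base deliberately reserves), it wraps the named tuple in a hypothetical `FieldWise` type with a hypothetical `FieldWiseStyle`, and applies `f` to each field as a whole, like the `broadcast` method above. Overriding `instantiate` skips the axes computation; that function is internal, so treat this part as version-dependent.

```julia
using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle

# Hypothetical wrapper type: broadcasting over it applies f to each field as a whole
struct FieldWise{NT<:NamedTuple}
    nt::NT
end

# Hypothetical style; having our own style avoids the DefaultArrayStyle{ndims(T)} fallback
struct FieldWiseStyle <: BroadcastStyle end

# 1. Take part in broadcasting as-is (the generic fallback would call collect)
Base.Broadcast.broadcastable(fw::FieldWise) = fw
# 2. Choose our style for the wrapper type
Base.Broadcast.BroadcastStyle(::Type{<:FieldWise}) = FieldWiseStyle()
# 3. Win against 0-dimensional styles (numbers and Refs), e.g. in fw .* 2
Base.Broadcast.BroadcastStyle(s::FieldWiseStyle, ::DefaultArrayStyle{0}) = s
# 4. Skip the axes computation, which assumes array-like arguments
Base.Broadcast.instantiate(bc::Broadcasted{FieldWiseStyle}) = bc

# What each kind of argument contributes to field k
fieldof(x::FieldWise, k::Symbol) = getfield(x.nt, k)
fieldof(r::Ref, k::Symbol) = r[]          # Ref(x): the same x for every field
fieldof(bc::Broadcasted, k::Symbol) =     # nested (fused) dot calls
    bc.f(map(a -> fieldof(a, k), bc.args)...)
fieldof(x, k::Symbol) = x                 # plain scalars pass through unchanged

# Field names come from the first FieldWise found among the (possibly nested) arguments
names_of(x::FieldWise) = keys(x.nt)
names_of(x) = nothing
function names_of(bc::Broadcasted)
    for a in bc.args
        ks = names_of(a)
        ks === nothing || return ks
    end
    return nothing
end

# 5. Evaluate the whole expression once per field and rebuild the wrapper
function Base.copy(bc::Broadcasted{FieldWiseStyle})
    ks = names_of(bc)
    return FieldWise(NamedTuple{ks}(map(k -> fieldof(bc, k), ks)))
end
```

With these definitions, `sum.(FieldWise(nt))` gives a `FieldWise` of per-field sums, `fw .* 2` scales every field, and fused calls such as `length.(size.(fw))` are evaluated field by field. Mixing a `FieldWise` with an ordinary array in one broadcast is not handled, since no rule is defined for that pair of styles.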

What I am really wondering is what to do, or whether there is a more natural path to what I want: taking advantage of mapobs in MLUtils together with Flux.gpu to deliver data to my graphics card in a timely fashion, without changing every implementation decision I've made over the last year.
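For the practical goal, a route without broadcasting may be enough: `map` already accepts a `NamedTuple` and returns one with the same field names, applying the function to each field, which is what the custom `broadcast` method above computes. A sketch with made-up sizes (the `batch` below only mimics the field names from the error messages):

```julia
# Made-up sizes; the field names mimic the batch in the error messages above
batch = (d₁ = rand(Float32, 2, 2, 1, 3), l = rand(Float32, 4, 3))

# map applies the function to each field and keeps the names
scaled = map(x -> 2f0 .* x, batch)

# Several named tuples with the same names can be combined field by field as well
summed = map(+, batch, scaled)
```

As far as I know, `Flux.gpu` also recurses into named tuples on its own (it is built on Functors.jl), so that step may not need broadcasting at all.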


Since my last post, I have found and read the documentation on this, as well as a blog post on the subject of overriding broadcasting.

Even so, I find it a bit hard to wrap my head around. Is there someone who can break it down for me?

In my experimentation I discovered that Base.Broadcast.broadcast(sin, my_test_nt) works if I override Base.Broadcast.broadcast(::F, ::NT) where {F<:Function, NT<:NamedTuple}, but sin.(my_test_nt) doesn’t.
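This difference comes from how dot syntax is lowered: `sin.(x)` becomes `Base.materialize(Base.broadcasted(sin, x))`, and `broadcast(f, x)` is an ordinary function that performs the same two steps. A method added to `broadcast` therefore only catches explicit `broadcast(...)` calls, never the dot syntax. A small check on a plain vector:

```julia
x = [0.0, 0.5, 1.0]

# Step 1: broadcasted builds a lazy Broadcasted object; nothing is computed yet
lazy = Base.Broadcast.broadcasted(sin, x)

# Step 2: materialize evaluates it into an ordinary array
eager = Base.Broadcast.materialize(lazy)
```

To intercept `sin.(nt)`, the hooks that matter are the ones `broadcasted` and `materialize` go through (`broadcastable`, `BroadcastStyle`, `copy`), not `broadcast`.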

After reading the links above, I realise that it is much more complicated than I thought. What is more, it seems that Base.Broadcast.broadcast doesn't need to be overridden at all. I'll try the documentation again, but I would still be grateful for any wise words about how to think about it. For example, I defined Base.Broadcast.broadcastable(nt::NamedTuple) = nt rather than Ref(nt), on the idea that MLUtils.mapobs needs to recurse into the elements (in this case multidimensional arrays). Is this right? I also wonder why MLUtils.mapobs can't operate on named tuples when Flux.DataLoader can.
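On the `Ref(nt)` question: wrapping an argument in `Ref` tells broadcasting to treat it as one scalar that is reused for every element of the other arguments, so it never recurses into the fields. A small illustration using plain Base (nothing from MLUtils):

```julia
nt = (a = [1, 2], b = [3, 4, 5])

# Ref(nt) is a single scalar; the vector of keys supplies the elements being iterated
picked = getindex.(Ref(nt), [:b, :a])     # [[3, 4, 5], [1, 2]]
present = haskey.(Ref(nt), [:a, :c])      # [true, false]
```

Returning `nt` itself from `broadcastable` asks for the opposite, iterating over the named tuple, and Base then still needs a `BroadcastStyle` that knows how to do that.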

In short, any advice or hints would be appreciated!