Type inference for recursive broadcasting operations

Apologies for the relatively long minimal example.

I want to write an efficient broadcast that sums the results from different objects. As an example, suppose I want to compute the sum of periodic functions:

abstract type Periodic{S <: Real} end

struct Sine{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(f::Sine) = Base.RefValue(f)

compute(h::Sine, x::Real) = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Tophat) = Base.RefValue(h)

compute(h::Tophat, x::Real) = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Periodicsum{S}
    f::Vector{Periodic{S}}
    name::String
end

Base.broadcastable(hs::Periodicsum) = Base.RefValue(hs)

function compute(hs::Periodicsum, x::S) where {S}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x) 
    end
    result
end 

function Base.broadcasted(::typeof(compute), hs::Periodicsum, xs::Array{S}) where {S}
    ℓ = length(hs.f)
    result = Broadcast.Broadcasted(compute, (Base.RefValue(hs.f[1]), xs))
    for n ∈ 2:ℓ
        @inbounds result = Broadcast.Broadcasted(+, (result, Broadcast.Broadcasted(compute, (Base.RefValue(hs.f[n]), xs))))
    end
    result
end
    
h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
hs = Periodicsum([h1, h2], "test")

So, essentially, I have two kinds of periodic functions, Sine and Tophat, and a container represented by the Periodicsum object. Calling compute returns the value of the periodic function, or of the sum of periodic functions, at a given point.

My problem is that a broadcast call to compute on a Periodicsum is inferred as returning Any, which can lead to unspecialized code:

julia> @code_warntype compute.(hs, rand(100))
Variables
  #self#::Core.Const(var"##dotfunction#258#2"())
  x1::Periodicsum{Float64}
  x2::Vector{Float64}

Body::Any
1 ─ %1 = Base.broadcasted(Main.compute, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, F, Args} where {F, Args<:Tuple}
│   %2 = Base.materialize(%1)::Any
└──      return %2

As far as I understand, this is because the length of the array hs.f is not known at compile time, so the Broadcast.Broadcasted object cannot be built with a stable type.

Is there any possible cure for this issue? Note that it is probably not wise to use tuples for hs.f, since in principle I might have to use hundreds of periodic functions in a Periodicsum.
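
To make the point about the unstable type concrete, here is a minimal sketch using sin, cos and tan as stand-ins for the periodic functions (they are not part of my code). It builds lazy broadcast trees with Base.Broadcast.broadcasted and shows that every extra term changes the type of the tree, so a loop over a run-time length cannot produce a single inferable type:

```julia
xs = rand(4)

# Lazy broadcast trees with one, two and three terms.
one_term    = Base.Broadcast.broadcasted(sin, xs)
two_terms   = Base.Broadcast.broadcasted(+, one_term, Base.Broadcast.broadcasted(cos, xs))
three_terms = Base.Broadcast.broadcasted(+, two_terms, Base.Broadcast.broadcasted(tan, xs))

# Each extra term wraps the previous tree, so the number of terms
# is encoded in the type of the Broadcasted object.
println(typeof(one_term) == typeof(two_terms))
println(typeof(two_terms) == typeof(three_terms))

# The values are unaffected: materializing the tree gives the usual result.
ys = Base.Broadcast.materialize(three_terms)
println(ys ≈ sin.(xs) .+ cos.(xs) .+ tan.(xs))
```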

Edit: my bad, I didn't really test the program in my first answer. Corrected version:

abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(f::Sine{S}) where S <: Real = Base.RefValue(f)

compute(h::Sine{S}, x::S) where S <: Real = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Tophat{S}) where S <: Real = Base.RefValue(h)

compute(h::Tophat{S}, x::S) where S <: Real = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Periodicsum{S <: Real, T <: Periodic{S}}
    f::Vector{T}
    name::String
end

Base.broadcastable(hs::Periodicsum{S, T}) where {S <: Real, T <: Periodic{S}} = Base.RefValue(hs)

function compute(hs::Periodicsum{S, T}, x::S) where {S <: Real, T <: Periodic{S}}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x)
    end
    result
end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
hs = Periodicsum([h1, h2], "test")
println(typeof(hs))
@code_warntype compute.(hs, rand(100))
compute.(hs, rand(100))

shows

Periodicsum{Float64, Periodic{Float64}}
Variables
  #self#::Core.Const(var"##dotfunction#257#7"())
  x1::Periodicsum{Float64, Periodic{Float64}}
  x2::Vector{Float64}

Body::Vector{Float64}
1 ─ %1 = Base.broadcasted(Main.compute, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(compute), Tuple{Base.RefValue{Periodicsum{Float64, Periodic{Float64}}}, Vector{Float64}}}
│   %2 = Base.materialize(%1)::Vector{Float64}
└──      return %2
100-element Vector{Float64}:
 -1.475712631100109
 -1.4546652819808716
 -0.052988573776913156
 -1.3419741424024791
  1.730537805636733
 -1.3103407114890016
 -0.920771920686877
  1.723220395758223
 -0.15091175831418654
 -0.6738947625140821
  2.018530913870344
  ⋮
 -0.6743266563361617
  0.01418025679568403
  1.9375723493945216
 -0.9912883945101121
  1.8683303901011636
 -0.7796450936086905
 -1.437739137958452
 -0.46149396055626707
  0.2785050637222302
 -1.0134150998205733

Thank you for your reply and for taking the time to check this. However, I think there is a problem with your code: both {S <: Periodic} in the where clauses should be {S <: Real} or just {S}. As written, a call to compute.(hs, rand(100)) raises an error, and the inferred type in your solution, Vector{Union{}}, is quite suspicious.

With the corrected code, however, nothing changes, unfortunately: I still have Any as a computed type, as in my original implementation.

Sorry for my mistake, please see the corrected version.

Thank you again @goerch. I am not entirely sure why your code works: it seems very similar to mine, with just a few additions constraining S <: Real, but apparently this is enough to help Julia understand the final type.

The issue I have now is that this minimal code does not really capture the complexity of my situation. For example, when I add a couple more Periodic functions, type inference fails: I guess the reason is that we now have a union type composed of too many options. Here is the new code that fails:

abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(f::Sine{S}) where S <: Real = Base.RefValue(f)

compute(h::Sine{S}, x::S) where S <: Real = h.a * sin(h.ω*x + h.φ)

struct Tophat{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Tophat{S}) where S <: Real = Base.RefValue(h)

compute(h::Tophat{S}, x::S) where S <: Real = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Saw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Saw{S}) where S <: Real = Base.RefValue(h)

compute(h::Saw{S}, x::S) where S <: Real = h.a * mod2pi(h.ω*x + h.φ)

struct Rsaw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Rsaw{S}) where S <: Real = Base.RefValue(h)

compute(h::Rsaw{S}, x::S) where S <: Real = h.a * trunc(mod2pi(h.ω*x*2 + h.φ))

struct Periodicsum{S <: Real}
    f::Vector{Periodic{S}}
    name::String
end

Base.broadcastable(hs::Periodicsum{S}) where {S <: Real} = Base.RefValue(hs)

function compute(hs::Periodicsum{S}, x::S) where {S <: Real}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x)
    end
    result
end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
h3 = Saw(3.0, 2.1, 4.3)
h4 = Rsaw(0.2, 0.3, 0.5)
hs = Periodicsum([h1, h2], "test")
println(typeof(hs))
@code_warntype compute.(hs, rand(100))
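
For reference, one workaround sometimes used for this kind of failure is a type assertion on the dynamic call. This is only a sketch, not something proposed elsewhere in this thread. It assumes that every compute method returns a value of type S for these inputs; if one does not (for instance with S = Int, where sin returns a Float64), the assertion throws at run time. As far as I understand, hs.f[n] is only known to be a Periodic{S}, so once more than a few compute methods match, inference stops enumerating them and falls back to Any; the assertion tells it what comes back.

```julia
abstract type Periodic{S <: Real} end

# Same four wave types as above, written compactly.
struct Sine{S <: Real}   <: Periodic{S}; a::S; ω::S; φ::S; end
struct Tophat{S <: Real} <: Periodic{S}; a::S; ω::S; φ::S; end
struct Saw{S <: Real}    <: Periodic{S}; a::S; ω::S; φ::S; end
struct Rsaw{S <: Real}   <: Periodic{S}; a::S; ω::S; φ::S; end

compute(h::Sine{S}, x::S) where {S <: Real} = h.a * sin(h.ω*x + h.φ)
compute(h::Tophat{S}, x::S) where {S <: Real} = h.a * round(cos(h.ω*x/2 + h.φ)^2)
compute(h::Saw{S}, x::S) where {S <: Real} = h.a * mod2pi(h.ω*x + h.φ)
compute(h::Rsaw{S}, x::S) where {S <: Real} = h.a * trunc(mod2pi(h.ω*x*2 + h.φ))

struct Periodicsum{S <: Real}
    f::Vector{Periodic{S}}
    name::String
end

Base.broadcastable(hs::Periodicsum) = Ref(hs)

function compute(hs::Periodicsum{S}, x::S) where {S <: Real}
    result = zero(S)
    for h in hs.f
        # Assumption: compute(h, x) always returns an S here. The call is still
        # dispatched at run time, but its result type is now known to inference.
        result += compute(h, x)::S
    end
    result
end

terms = Periodic{Float64}[Sine(2.0, 1.3, 2.7), Tophat(1.2, 2.4, 0.6),
                          Saw(3.0, 2.1, 4.3), Rsaw(0.2, 0.3, 0.5)]
hs = Periodicsum(terms, "test")

# Ask inference what the scalar method returns, then run the broadcast.
inferred = only(Base.return_types(compute, (typeof(hs), Float64)))
println(inferred)
ys = compute.(hs, rand(100))
println(typeof(ys))
```

The individual calls inside the loop remain dynamic, so this addresses the inferred return type and not the per-element dispatch cost.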

I believe the main issue was/is Periodicsum not being concrete. That is what the change

struct Periodicsum{S <: Real, T <: Periodic{S}}
    f::Vector{T}
    name::String
end

tried to address. Other than that, I add type annotations deliberately to better understand unfamiliar programs, sorry.

I am not entirely sure, since your code with my original definition of Periodicsum works nicely (it produces a concrete type with broadcasting). Also, your definition of the field, f::Vector{T} where T <: Periodic{S}, ends up creating a vector with a non-concrete element type when one initialises hs = Periodicsum([h1, h2], "test"), since h1 and h2 are of different types. In fact:

julia> typeof(hs)
Periodicsum{Float64, Periodic{Float64}}
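
The same effect can be seen on the vector alone. The following sketch redefines just two of the types so that it runs on its own, and compares the element type of a mixed vector with that of a homogeneous one:

```julia
abstract type Periodic{S <: Real} end

struct Sine{S <: Real}   <: Periodic{S}; a::S; ω::S; φ::S; end
struct Tophat{S <: Real} <: Periodic{S}; a::S; ω::S; φ::S; end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)

mixed = [h1, h2]   # two different concrete types
same  = [h1, h1]   # a single concrete type

# The mixed vector falls back to the common abstract supertype,
# so T in Periodicsum{S, T} is not concrete in that case.
println(eltype(mixed), " concrete? ", isconcretetype(eltype(mixed)))
println(eltype(same), " concrete? ", isconcretetype(eltype(same)))
```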