# Type inference for recursive broadcasting operations

Apologies for the relatively long minimal example.

I want to write an efficient broadcast operation that can sum the results of different objects. As an example, suppose I want to compute a sum of periodic functions:

```julia
abstract type Periodic{S <: Real} end

struct Sine{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

compute(h::Sine, x::Real) = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

compute(h::Tophat, x::Real) = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Periodicsum{S}
    f::Vector{Periodic{S}}
    name::String
end

function compute(hs::Periodicsum, x::S) where {S}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x)
    end
    result
end

function Base.broadcasted(::typeof(compute), hs::Periodicsum, xs::Array{S}) where {S}
    ℓ = length(hs.f)
    for n ∈ 2:ℓ
    end
    result
end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
hs = Periodicsum([h1, h2], "test")
```

So, essentially, I have two kinds of periodic functions, `Sine` and `Tophat`, and a container, the `Periodicsum` object. Calling `compute` returns the value of a periodic function, or of a sum of periodic functions, at a given point.
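For reference, a minimal sketch of evaluating the container at a single point and checking it against its individual components. It restates the definitions above in compact form and writes the summation loop with `sum` and its `init` keyword (available since Julia 1.6); the test point `x = 0.5` is arbitrary.

```julia
abstract type Periodic{S <: Real} end

struct Sine{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Sine, x::Real) = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Tophat, x::Real) = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Periodicsum{S}
    f::Vector{Periodic{S}}
    name::String
end

# Same result as the explicit loop: add up every component, starting from zero(S)
compute(hs::Periodicsum{S}, x::Real) where {S} = sum(h -> compute(h, x), hs.f; init = zero(S))

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
hs = Periodicsum([h1, h2], "test")   # [h1, h2] is a Vector{Periodic{Float64}}

x = 0.5
compute(hs, x) ≈ compute(h1, x) + compute(h2, x)   # true
```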

My problem is that a broadcast call to `compute` on a `Periodicsum` is inferred as returning `Any`, potentially leading to unspecialized code:

```
julia> @code_warntype compute.(hs, rand(100))
Variables
  #self#::Core.Const(var"##dotfunction#258#2"())
  x1::Periodicsum{Float64}
  x2::Vector{Float64}

Body::Any
│   %2 = Base.materialize(%1)::Any
└──      return %2
```

As far as I understand, this is because the length of the array `hs.f` is not known at compile time, so the `Broadcast.Broadcasted` object cannot be built with a stable type.
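One way to test this hypothesis in isolation: run the same accumulation loop over a vector whose element type is concrete. The length is still only known at run time, yet the return type is inferred. `total` is a hypothetical helper for this sketch, and `Base.return_types` is used only as a quick diagnostic (`@code_warntype` shows the same information).

```julia
abstract type Periodic{S <: Real} end

struct Sine{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Sine, x::Real) = h.a * sin(h.ω*x + h.φ)

# Hypothetical helper: the loop from `compute(::Periodicsum, x)`, but over a
# vector with a concrete element type. Its length is still a run-time value.
function total(fs::Vector{Sine{S}}, x::S) where {S}
    result = zero(S)
    for f in fs
        result += compute(f, x)
    end
    result
end

# Inferred return type for these argument types
Base.return_types(total, (Vector{Sine{Float64}}, Float64))
```

This suggests that the compiler is missing the concrete type of each element of `hs.f` (statically only the abstract `Periodic{S}`), rather than the number of elements.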

Is there any way to fix this? Note that using tuples for `hs.f` is probably not wise, since in principle I might need hundreds of periodic functions in a `Periodicsum`.
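A different technique, not taken from this thread, that avoids one tuple entry per function: group the components by concrete type, so that the tuple has only one entry per *kind* of function while each entry is a concretely typed `Vector`. `GroupedSum`, `group_sum` and `sum_groups` are hypothetical names for this sketch, and it assumes that all components share the same number type `S`.

```julia
abstract type Periodic{S <: Real} end

struct Sine{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Sine, x::Real) = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Tophat, x::Real) = h.a * round(cos(h.ω*x/2 + h.φ)^2)

# Hypothetical container: a tuple with one concretely typed vector per kind of
# function, e.g. (Vector{Sine{Float64}}, Vector{Tophat{Float64}}).
struct GroupedSum{S <: Real, G <: Tuple}
    groups::G
    name::String
end

# Build the groups once; this constructor is not type-stable, but it runs only once.
function GroupedSum(name::String, fs::Periodic{S}...) where {S}
    kinds = unique([typeof(f) for f in fs])
    groups = Tuple(T[f for f in fs if f isa T] for T in kinds)
    GroupedSum{S, typeof(groups)}(groups, name)
end

# Each group has a concrete element type, so this loop can be specialized.
group_sum(g::AbstractVector, x) = sum(h -> compute(h, x), g; init = zero(x))

# Recurse over the (short) tuple of groups so each step sees a concrete group type.
sum_groups(::Tuple{}, x) = zero(x)
sum_groups(t::Tuple, x) = group_sum(first(t), x) + sum_groups(Base.tail(t), x)

compute(hs::GroupedSum{S}, x::S) where {S} = sum_groups(hs.groups, x)

# Treat the container as a single value when broadcasting
Base.broadcastable(hs::GroupedSum) = Ref(hs)

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
h3 = Sine(1.0, 0.5, 0.0)
gs = GroupedSum("test", h1, h2, h3)   # two groups: Sine (h1, h3) and Tophat (h2)
ys = compute.(gs, rand(100))
```

Whether the recursion over the tuple is fully inferred depends on how many kinds there are; with a handful of kinds it is worth confirming with `@code_warntype` on the Julia version in use.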

Edit: my bad, I hadn't actually tested the code in my first answer. Here is a corrected version:

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(f::Sine{S}) where S <: Real = Base.RefValue(f)

compute(h::Sine{S}, x::S) where S <: Real = h.a * sin(h.ω*x + h.φ)

struct Tophat{S} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Tophat{S}) where S <: Real = Base.RefValue(h)

compute(h::Tophat{S}, x::S) where S <: Real = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Periodicsum{S <: Real, T <: Periodic{S}}
    f::Vector{T}
    name::String
end

Base.broadcastable(hs::Periodicsum{S, T}) where {S <: Real, T <: Periodic{S}} = Base.RefValue(hs)

function compute(hs::Periodicsum{S, T}, x::S) where {S <: Real, T <: Periodic{S}}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x)
    end
    result
end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
hs = Periodicsum([h1, h2], "test")
println(typeof(hs))
@code_warntype compute.(hs, rand(100))
compute.(hs, rand(100))
```

This prints:

```
Periodicsum{Float64, Periodic{Float64}}
Variables
  #self#::Core.Const(var"##dotfunction#257#7"())
  x1::Periodicsum{Float64, Periodic{Float64}}
  x2::Vector{Float64}

Body::Vector{Float64}
│   %2 = Base.materialize(%1)::Vector{Float64}
└──      return %2
100-element Vector{Float64}:
 -1.475712631100109
 -1.4546652819808716
 -0.052988573776913156
 -1.3419741424024791
  1.730537805636733
 -1.3103407114890016
 -0.920771920686877
  1.723220395758223
 -0.15091175831418654
 -0.6738947625140821
  2.018530913870344
  ⋮
 -0.6743266563361617
  0.01418025679568403
  1.9375723493945216
 -0.9912883945101121
  1.8683303901011636
 -0.7796450936086905
 -1.437739137958452
 -0.46149396055626707
  0.2785050637222302
 -1.0134150998205733
```

Thank you for your reply and for taking the time to check this. However, I think there is a problem with your code: both `{S <: Periodic}` in the `where` clauses should be `{S <: Real}` or just `{S}`. As written, calling `compute.(hs, rand(100))` raises an error, and the inferred type in your solution, `Vector{Union{}}`, is quite suspicious.

Unfortunately, even with the corrected code nothing changes: I still get `Any` as the inferred type, as in my original implementation.

Sorry for my mistake, please see the corrected version.

Thank you again @goerch. I am not entirely sure why your code works: it looks very similar to mine, with just a few additional `S <: Real` constraints, but apparently this is enough for Julia to infer the final type.
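One difference from the original code worth isolating: the corrected version has no custom `Base.broadcasted` method and instead tells broadcasting to treat `hs` as a single value through `Base.broadcastable`. A minimal sketch of that idea, keeping the original `Vector{Periodic{S}}` field and wrapping `hs` in `Ref` at the call site instead of defining `broadcastable`:

```julia
abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Sine{S}, x::S) where {S <: Real} = h.a * sin(h.ω*x + h.φ)

struct Tophat{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
compute(h::Tophat{S}, x::S) where {S <: Real} = h.a * round(cos(h.ω*x/2 + h.φ)^2)

# Original container definition: abstract element type, no custom broadcasting
struct Periodicsum{S <: Real}
    f::Vector{Periodic{S}}
    name::String
end

function compute(hs::Periodicsum{S}, x::S) where {S <: Real}
    result = zero(S)
    for h in hs.f
        result += compute(h, x)
    end
    result
end

hs = Periodicsum([Sine(2.0, 1.3, 2.7), Tophat(1.2, 2.4, 0.6)], "test")

# `Ref(hs)` makes broadcasting treat `hs` as a single value, which is what
# `Base.broadcastable(hs::Periodicsum) = Ref(hs)` does for `compute.(hs, xs)`.
ys = compute.(Ref(hs), rand(100))
```

What `@code_warntype compute.(Ref(hs), rand(100))` reports here may still depend on how many `compute` methods match an argument of type `Periodic{S}`.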

The issue I have now is that this minimal code does not really capture the complexity of my situation. For example, when I add a couple more `Periodic` functions, type inference fails. I guess this is because we now have a union type with too many options. Here is the new code that fails:

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(f::Sine{S}) where S <: Real = Base.RefValue(f)

compute(h::Sine{S}, x::S) where S <: Real = h.a * sin(h.ω*x + h.φ)

struct Tophat{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Tophat{S}) where S <: Real = Base.RefValue(h)

compute(h::Tophat{S}, x::S) where S <: Real = h.a * round(cos(h.ω*x/2 + h.φ)^2)

struct Saw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Saw{S}) where S <: Real = Base.RefValue(h)

compute(h::Saw{S}, x::S) where S <: Real = h.a * mod2pi(h.ω*x + h.φ)

struct Rsaw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

Base.broadcastable(h::Rsaw{S}) where S <: Real = Base.RefValue(h)

compute(h::Rsaw{S}, x::S) where S <: Real = h.a * trunc(mod2pi(h.ω*x*2 + h.φ))

struct Periodicsum{S <: Real}
    f::Vector{Periodic{S}}
    name::String
end

Base.broadcastable(hs::Periodicsum{S}) where {S <: Real} = Base.RefValue(hs)

function compute(hs::Periodicsum{S}, x::S) where {S <: Real}
    ℓ = length(hs.f)
    result = zero(S)
    for n ∈ 1:ℓ
        result += compute(hs.f[n], x)
    end
    result
end

h1 = Sine(2.0, 1.3, 2.7)
h2 = Tophat(1.2, 2.4, 0.6)
h3 = Saw(3.0, 2.1, 4.3)
h4 = Rsaw(0.2, 0.3, 0.5)
hs = Periodicsum([h1, h2], "test")
println(typeof(hs))
@code_warntype compute.(hs, rand(100))
```
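Regarding the guess about too many options: as far as I know, when a call such as `compute(hs.f[n], x)` has an argument of abstract type, inference only considers a limited number of matching methods (the default limit is 3 in recent Julia versions, I believe). With four `Periodic` subtypes it gives up and infers `Any`, regardless of how many elements `hs.f` actually holds. A common workaround, sketched here as an assumption rather than a tested fix for the full program, is to assert the type of each term so that the accumulator stays concrete:

```julia
abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
struct Tophat{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
struct Saw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end
struct Rsaw{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

compute(h::Sine{S}, x::S) where S <: Real = h.a * sin(h.ω*x + h.φ)
compute(h::Tophat{S}, x::S) where S <: Real = h.a * round(cos(h.ω*x/2 + h.φ)^2)
compute(h::Saw{S}, x::S) where S <: Real = h.a * mod2pi(h.ω*x + h.φ)
compute(h::Rsaw{S}, x::S) where S <: Real = h.a * trunc(mod2pi(h.ω*x*2 + h.φ))

struct Periodicsum{S <: Real}
    f::Vector{Periodic{S}}
    name::String
end

Base.broadcastable(hs::Periodicsum) = Ref(hs)

function compute(hs::Periodicsum{S}, x::S) where {S <: Real}
    result = zero(S)
    for h in hs.f
        # `h` has the abstract type Periodic{S}, so this call is dynamically
        # dispatched; the `::S` assertion keeps `result` concretely typed.
        result += compute(h, x)::S
    end
    result
end

hs = Periodicsum([Sine(2.0, 1.3, 2.7), Tophat(1.2, 2.4, 0.6),
                  Saw(3.0, 2.1, 4.3), Rsaw(0.2, 0.3, 0.5)], "test")
ys = compute.(hs, rand(100))
```

Each term is still dynamically dispatched, so this removes the type instability downstream of the loop, not the cost of the dispatch itself.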

I believe the main issue was (and is) that `Periodicsum` is not concrete. That is what the change

```julia
struct Periodicsum{S <: Real, T <: Periodic{S}}
    f::Vector{T}
    name::String
end
```

tried to address. Other than that, I add type annotations deliberately, to better understand unfamiliar programs; sorry.

I am not entirely sure. Your code works nicely with my original definition of `Periodicsum` (broadcasting produces a concrete type). Also, your definition of the field `f::Vector{T}` with `T <: Periodic{S}` still ends up with a non-concrete element type when one initialises `hs = Periodicsum([h1, h2], "test")`, since `h1` and `h2` have different types. In fact:

```
julia> typeof(hs)
Periodicsum{Float64, Periodic{Float64}}
```
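Both observations can be checked directly. A sketch, assuming the two-parameter definition from the corrected code: the struct type `Periodicsum{Float64, Periodic{Float64}}` is itself concrete, but its vector field has the abstract element type `Periodic{Float64}`. Only a homogeneous vector makes `T` concrete; the variables `mixed` and `same`, and the second `Sine` value, are hypothetical.

```julia
abstract type Periodic{S <: Real} end

struct Sine{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

struct Tophat{S <: Real} <: Periodic{S}
    a::S
    ω::S
    φ::S
end

struct Periodicsum{S <: Real, T <: Periodic{S}}
    f::Vector{T}
    name::String
end

# Hypothetical examples: mixed component types vs. a single component type
mixed = Periodicsum([Sine(2.0, 1.3, 2.7), Tophat(1.2, 2.4, 0.6)], "mixed")
same  = Periodicsum([Sine(2.0, 1.3, 2.7), Sine(1.0, 0.5, 0.0)], "same")

isconcretetype(typeof(mixed))     # true: the struct type itself is concrete
isconcretetype(eltype(mixed.f))   # false: the element type is Periodic{Float64}
isconcretetype(eltype(same.f))    # true: the element type is Sine{Float64}
```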