How to make a locally defined function type-stable?

Hi,
I’m wondering how and where I can add type annotations to make some code type-stable.

I’m working with a structure:

struct sensor{T} 
    location
    normal
    α
    parameters
    id
end

The fields location and normal are always 3-element float vectors. T is a type parameter that specifies the kind of sensor; it is not a float or similar numeric type.

I’m running into type-stability issues when I pass this struct into a function for integration with hcubature. In particular, the variables {xyz}_c and {xyz}_d are inferred as Any.

Does anyone have a tip on how to add type annotations to the sensor struct to make the integrand type-stable?

using HCubature      # provides hcubature
using StaticArrays   # provides SVector

function cylinder_sensor_response(
    s::sensor,
    X_c,
    r::T,
    h::T,
    I0::T;
    reltol::T = 1e-6,
    abstol::T = 1e-9
) where {T}
    # Unpack arguments
    x_c, y_c, z_c = X_c
    x_d, y_d, z_d = s.location
   
    
    # Define the integrand in cylindrical coords: (ρ, φ, z_local)
    # domain: ρ ∈ [0,r], φ ∈ [0,2π], z_local ∈ [-h/2, +h/2]
    integrand(ρφz::SVector{3, T2}) where {T2} = @inbounds begin
        ρ       = ρφz[1]
        φ       = ρφz[2]
        z_local = ρφz[3]
        
        # Convert from cylinder coords to cartesian
        x_s = x_c + ρ*cos(φ)
        y_s = y_c + ρ*sin(φ)
        z_s = z_c + z_local
        
        # Vector from detector to source point
        dx = x_s - x_d
        dy = y_s - y_d
        dz = z_s - z_d
        
        # Distance
        dist2 = dx*dx + dy*dy + dz*dz
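        # Inverse-square falloff weighted by the collimation factor; the trailing ρ is the cylindrical Jacobian
        return I0 * collimation_factor(s, Point3(x_s, y_s, z_s)) / dist2 * ρ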
    end
    
    # Bounds in cylindrical coordinates
    lower_bounds = [0.0,    0.0,   -h/2]
    upper_bounds = [r,      2π,    +h/2]
    
    # Perform the triple integral
    result, err = hcubature(
        integrand,
        lower_bounds,
        upper_bounds;
        rtol = reltol,
        atol = abstol
    )
    
    return result
end

Running @code_warntype on the integrand gives:

MethodInstance for integrand(::SVector{3, Float64})
  from integrand(ρφz::SVector{3, T2}) where T2 @ Main ~/julia_envs/phantoms_base/test/test_gridfree.jl:43
Static Parameters
  T2 = Float64
Arguments
  #self#::Core.Const(Main.integrand)
  ρφz::SVector{3, Float64}
Locals
  val::Union{}
  dist2::Any
  dz::Any
  dy::Any
  dx::Any
  z_s::Any
  y_s::Any
  x_s::Any
  z_local::Float64
  φ::Float64
  ρ::Float64
Body::Any
1 ─       nothing
│         (ρ = Base.getindex(ρφz, 1))
│         (φ = Base.getindex(ρφz, 2))
│         (z_local = Base.getindex(ρφz, 3))
│   %5  = Main.:+::Core.Const(+)
│   %6  = Main.x_c::Any
│   %7  = Main.:*::Core.Const(*)
│   %8  = ρ::Float64
│   %9  = φ::Float64
│   %10 = Main.cos(%9)::Float64
│   %11 = (%7)(%8, %10)::Float64
│         (x_s = (%5)(%6, %11))
│   %13 = Main.:+::Core.Const(+)
│   %14 = Main.y_c::Any
│   %15 = Main.:*::Core.Const(*)
│   %16 = ρ::Float64
│   %17 = φ::Float64
│   %18 = Main.sin(%17)::Float64
│   %19 = (%15)(%16, %18)::Float64
│         (y_s = (%13)(%14, %19))
│   %21 = Main.:+::Core.Const(+)
│   %22 = z_local::Float64
│         (z_s = (%21)(Main.z_c, %22))
│   %24 = x_s::Any
│         (dx = %24 - Main.x_d)
│   %26 = y_s::Any
│         (dy = %26 - Main.y_d)
│   %28 = z_s::Any
│         (dz = %28 - Main.z_d)
│   %30 = Main.:+::Core.Const(+)
│   %31 = dx::Any
│   %32 = dx::Any
│   %33 = (%31 * %32)::Any
│   %34 = dy::Any
│   %35 = dy::Any
│   %36 = (%34 * %35)::Any
│   %37 = dz::Any
│   %38 = dz::Any
│   %39 = (%37 * %38)::Any
│         (dist2 = (%30)(%33, %36, %39))
│   %41 = Main.:*::Core.Const(*)
│   %42 = Main.:/::Core.Const(/)
│   %43 = Main.:*::Core.Const(*)
│   %44 = Main.I0::Any
│   %45 = Main.collimation_factor::Core.Const(phantoms_base.collimation_factor)
│   %46 = Main.d::Any
│   %47 = x_s::Any
│   %48 = y_s::Any
│   %49 = z_s::Any
│   %50 = Main.Point3(%47, %48, %49)::Point
│   %51 = (%45)(%46, %50)::Any
│   %52 = (%43)(%44, %51)::Any
│   %53 = dist2::Any
│   %54 = (%42)(%52, %53)::Any
│   %55 = ρ::Float64
│   %56 = (%41)(%54, %55)::Any
└──       return %56
2 ─       Core.Const(:(val = nothing))
│         Core.Const(nothing)
│         Core.Const(:(val))
└──       Core.Const(:(return %60))

When you define a struct, any field without a type annotation defaults to type Any. When your function taking a sensor as input gets compiled, the compiler can’t know the types of those fields ahead of time.

Right. But when I define something like

s = sensor(Vec3d(0.0, 0.0, 0.0), Vec3d(1.0, 0.0, 0.0)...)

cylinder_sensor_response(s, Vec3d(0.0, 0.0, 0.0), ...)

shouldn’t it deduce the type of X_c and s.location when the function integrand is defined?

Anyway, I fixed type stability by explicitly defining the types of the struct members.

No. Think about what such a struct actually is. If you define

struct Foo
    x
end

then you are telling Julia to define a type Foo where foo.x can be any type of object. How could Julia store such a data structure? What would it look like in memory? The only way to store it is for foo.x to be a pointer to an object with a type tag that can be looked up at runtime.

In contrast, if you do:

struct Foo{T}
    x::T
end

then you are telling Julia to define not a single type, but a family of types, in which each Foo{T} type has a field x of type T. e.g. if foo is an object of type Foo{Float64}, it can be stored as a single Float64 in memory, not a pointer to a Float64 object with a type tag — no pointer and no type tag are needed for foo.x, because the type of x is already part of the type of foo.
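
The difference can be checked with reflection functions from Base. This is a minimal sketch; AnyFoo, ParamFoo and getx are hypothetical names standing in for the two Foo definitions above, renamed so both can live in one session:

```julia
struct AnyFoo
    x              # no annotation, equivalent to x::Any
end

struct ParamFoo{T}
    x::T           # the field type is part of the object's type
end

getx(foo) = foo.x

fieldtype(AnyFoo, :x)                  # Any
fieldtype(ParamFoo{Float64}, :x)       # Float64

isbitstype(AnyFoo)                     # false: x is stored as a reference
isbitstype(ParamFoo{Float64})          # true: stored inline as plain bits
sizeof(ParamFoo{Float64})              # 8, just the Float64 itself

# What the compiler can infer for a field access, given only the argument type
only(Base.return_types(getx, (AnyFoo,)))              # Any
only(Base.return_types(getx, (ParamFoo{Float64},)))   # Float64
```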

In your case, you defined struct sensor{T}, which defines a family of sensor{T} types, but you didn’t actually use T for the fields. So each sensor{T} object could still hold fields of arbitrary types (regardless of T), so they have to be stored as pointers to objects with type tags that have to be looked up at runtime.
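
To illustrate the unused type parameter, here is a small sketch. LooseSensor and the tag type Gamma are hypothetical; the struct is cut down to one field and is not the sensor definition from the question:

```julia
struct Gamma end             # hypothetical tag type, only used as a label

struct LooseSensor{Tag}
    location                 # untyped: may hold anything, whatever Tag is
end

s1 = LooseSensor{Gamma}((0.0, 0.0, 0.0))
s2 = LooseSensor{Gamma}("not a vector at all")

# Both objects have exactly the same type ...
typeof(s1) === typeof(s2)                  # true

# ... so the type alone says nothing about the field.
fieldtype(LooseSensor{Gamma}, :location)   # Any
typeof(s1.location)                        # Tuple{Float64, Float64, Float64}
typeof(s2.location)                        # String
```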

Thanks for the detailed explanation.
The example below describes better what confuses me. First define a=2.0 and it’s a Float64. Then define a closure, where the type of a is known when the function is defined. Calling f with typeof(x)=Float64 should be known to yield a Float64, but @code_warntype gives an Any.

julia> a = 2.0
2.0

julia> f(x) = a + x
f (generic function with 1 method)

julia> @code_warntype f(1.0)
MethodInstance for f(::Float64)
  from f(x) @ Main REPL[2]:1
Arguments
  #self#::Core.Const(Main.f)
  x::Float64
Body::Any
1 ─ %1 = Main.:+::Core.Const(+)
│   %2 = (%1)(Main.a, x)::Any
└──      return %2


If I define the same closure inside a function, it’s suddenly stable:

julia> function g(x)
       a = 2.0
       f(y) = a + x
       f(x)
       end
g (generic function with 1 method)

julia> g(2.0)
4.0

julia> @code_warntype g(2.0)
MethodInstance for g(::Float64)
  from g(x) @ Main REPL[1]:1
Arguments
  #self#::Core.Const(Main.g)
  x::Float64
Locals
  f::var"#f#1"{Float64, Float64}
  a::Float64
Body::Float64
1 ─       (a = 2.0)
│   %2  = Main.:(var"#f#1")::Core.Const(var"#f#1")
│   %3  = Core.typeof(x)::Core.Const(Float64)
│   %4  = a::Core.Const(2.0)
│   %5  = Core.typeof(%4)::Core.Const(Float64)
│   %6  = Core.apply_type(%2, %3, %5)::Core.Const(var"#f#1"{Float64, Float64})
│   %7  = a::Core.Const(2.0)
│         (f = %new(%6, x, %7))
│   %9  = f::Core.PartialStruct(var"#f#1"{Float64, Float64}, Any[Float64, Core.Const(2.0)])
│   %10 = (%9)(x)::Float64
└──       return %10

I thought that Julia compiles a function the first time it is called, specialized on the arguments that are passed. So along that line, I thought that when I call cylinder_sensor_response with a struct that has typeof(s.X_c) = Float64, the type of x_c, ... would be known.

This change makes the functions type-stable:

struct sensor{Tag, T2} 
    location::Vec3{T2}
    normal::Vec3{T2}
    α::T2
    parameters
    id::Int
end
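
A quick way to confirm that a definition like this is fully typed is to look at the field types of a concrete instance. The sketch below makes some assumptions: Vec3 is replaced by an NTuple alias because the real Vec3 comes from a package not shown in the thread, Collimated is a made-up tag type, and a third parameter P is added so that parameters is concretely typed too (in the struct above it is still Any).

```julia
# Stand-in for the thread's Vec3 (assumption): a 3-tuple with element type T
const Vec3{T} = NTuple{3, T}

struct Collimated end        # hypothetical tag type

struct Sensor{Tag, T2, P}
    location::Vec3{T2}
    normal::Vec3{T2}
    α::T2
    parameters::P            # extra parameter so this field is concrete as well
    id::Int
end

location_of(s) = s.location

s = Sensor{Collimated, Float64, Nothing}(
    (0.0, 0.0, 0.0), (1.0, 0.0, 0.0), 0.5, nothing, 1)

# Every field of the concrete type has a concrete type
all(isconcretetype, fieldtypes(typeof(s)))                # true

# Field access is now inferable from the type of s alone
only(Base.return_types(location_of, (typeof(s),)))        # NTuple{3, Float64}
```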

The first version is unstable because a is a global variable. It can literally change type at any time:

julia> a = 2.0
2.0

julia> f(x) = a + x
f (generic function with 1 method)

julia> a = 2//3
2//3

julia> f(1)
5//3

so when f is compiled Julia has to look up the type of a at runtime. It can be anything (Any).
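
For comparison, a global declared const cannot change type, so the compiler can rely on it. A minimal sketch with hypothetical names (on Julia 1.8 and later, a typed global such as a::Float64 = 2.0 should help in a similar way):

```julia
untyped_a = 2.0          # ordinary global: can be rebound to any type later
const CONST_A = 2.0      # const global: its type is fixed

f_untyped(x) = untyped_a + x
f_const(x)   = CONST_A + x

# Inferred return types for a Float64 argument
only(Base.return_types(f_untyped, (Float64,)))   # Any
only(Base.return_types(f_const, (Float64,)))     # Float64
```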

Inside g it is stable because a is no longer a global variable. When Julia compiles the function, it can tell that a never changes type within the lifetime of the function.

These are the very first two performance tips in the manual: don’t use untyped global variables in critical code, and instead put code into functions.
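
Applied to the integrand from the question: the @code_warntype output there reads Main.x_c, Main.x_d and so on, i.e. globals. One way to follow the second tip is to build the integrand inside a function so that it captures locals instead. This is only a sketch; make_integrand is a hypothetical helper and the collimation factor is left out:

```julia
function make_integrand(center, detector)
    # Locals assigned once: the closure captures them with concrete types
    x_c, y_c, z_c = center
    x_d, y_d, z_d = detector
    return function (ρφz)
        ρ, φ, z_local = ρφz
        dx = x_c + ρ * cos(φ) - x_d
        dy = y_c + ρ * sin(φ) - y_d
        dz = z_c + z_local - z_d
        return ρ / (dx * dx + dy * dy + dz * dz)   # ρ is the cylindrical Jacobian
    end
end

integrand = make_integrand((0.0, 0.0, 0.0), (0.0, 0.0, 2.0))

integrand((1.0, 0.0, 0.0))   # dx = 1, dy = 0, dz = -2, so 1/5

# The return type is inferred from the argument type alone
only(Base.return_types(integrand, (NTuple{3, Float64},)))   # Float64
```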

The type is known, but only at runtime, so when accessing s.X_c you will have runtime dispatch, not static dispatch.

The thing you are missing is that the compiler must generate code that handles any sensor object, not just one particular instance.

Suppose I tell you, @rkube, that you will receive an object of type sensor{SomeType}. Based on that information alone, what is the type of s.X_c? You get no runtime value, only the name of the type itself, with a type parameter. Do you see how this question is unanswerable?
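
If the struct cannot be changed, a function barrier (described in the manual's performance tips) limits the cost: read the untyped field once, then pass it to a separate function, which gets compiled for the field's actual runtime type. A minimal sketch with hypothetical names, not the code from the question:

```julia
struct UntypedSensor
    location                 # untyped on purpose, inferred as Any when read
end

# Kernel: specialized on the concrete types of loc and p
function squared_distance(loc, p)
    dx = p[1] - loc[1]
    dy = p[2] - loc[2]
    dz = p[3] - loc[3]
    return dx * dx + dy * dy + dz * dz
end

function response(s::UntypedSensor, p)
    loc = s.location                  # type unknown to the compiler here
    return squared_distance(loc, p)   # one dynamic dispatch, then fast code
end

response(UntypedSensor((0.0, 0.0, 0.0)), (1.0, 2.0, 2.0))   # 9.0

# Inside the barrier, everything is inferred
only(Base.return_types(squared_distance,
                       (NTuple{3, Float64}, NTuple{3, Float64})))   # Float64
```

The dynamic dispatch still happens once per call of response, so typing the struct fields remains the cleaner fix when that is possible.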

BTW, according to Julia convention (it’s a strong convention), your type should be called Sensor, capitalized, not sensor.
