Type inference infers Any for a function that returns Float64

Hello everyone, I am currently writing an optical ray tracing package in Julia. While optimising the code for performance, I found that the function trace, which performs the tracing, allocates heavily when it calls intersection_distance, even though that function has return type Float64.

Here’s part of the trace function in question:

function trace(sys::OpticalSystem{N}) where {N}
    traceable_rays::Vector{Rays{N}} = rays(sys)

    system_surfaces::Vector{Surface{N}} = surfaces(sys)
    for ray in traceable_rays
        min_distance = Inf
        min_index = 0
        for (index, surface) in enumerate(system_surfaces)
            distance = intersection_distance(surface, ray)
            if distance < min_distance
                min_distance = distance
                min_index = index
            end
        end
    end
end

I checked the intersection_distance return types using the Base.return_types method and the output is as follows:

julia> Base.return_types(intersection_distance)
4-element Vector{Any}:
 Float64
 Float64
 Float64
 Float64

The intersection_distance function takes two parameters: the first is of type <:Surface{N} and the second is of type Ray{N}. However, when I run @code_warntype on the trace function defined above, the return type of intersection_distance is inferred as Any, which is what causes so many allocations.
Here’s the output of @code_warntype:

MethodInstance for trace(::OpticalSystem{2})
  from trace(sys::OpticalSystem{N}) where N @ Main REPL[7]:1
Static Parameters
  N = 2
Arguments
  #self#::Core.Const(trace)
  sys::OpticalSystem{2}
Locals
  @_3::Union{Nothing, Tuple{Any, Int64}}
  system_surfaces::Vector{Surface{2}}
  traceable_rays::Vector
  @_6::Union{Nothing, Tuple{Tuple{Int64, Surface{2}}, Tuple{Int64, Int64}}}
  ray::Any
  min_index::Int64
  min_distance::Any
  @_10::Int64
  surface::Surface{2}
  index::Int64
  distance::Any
  @_14::Vector
  @_15::Vector{Surface{2}}
Body::Nothing
1 ──       Core.NewvarNode(:(@_3))
│          Core.NewvarNode(:(system_surfaces))
│          Core.NewvarNode(:(traceable_rays))
│    %4  = Main.rays(sys)::Vector{Ray{2}}
│    %5  = Main.Vector::Core.Const(Vector)
│    %6  = Core.apply_type(Main.Rays, $(Expr(:static_parameter, 1)))::Any
│    %7  = Core.apply_type(%5, %6)::Type{Vector{_A}} where _A
│          (@_14 = %4)
│    %9  = (@_14::Vector{Ray{2}} isa %7)::Bool
└───       goto #3 if not %9
2 ──       goto #4
3 ── %12 = Base.convert(%7, @_14::Vector{Ray{2}})::Vector
└───       (@_14 = Core.typeassert(%12, %7))
4 ┄─       (traceable_rays = @_14)
│    %15 = Main.surfaces(sys)::Vector{Surface{2}}
│    %16 = Main.Vector::Core.Const(Vector)
│    %17 = Core.apply_type(Main.Surface, $(Expr(:static_parameter, 1)))::Core.Const(Surface{2})
│    %18 = Core.apply_type(%16, %17)::Core.Const(Vector{Surface{2}})
│          (@_15 = %15)
│    %20 = (@_15 isa %18)::Core.Const(true)
└───       goto #6 if not %20
5 ──       goto #7
6 ──       Core.Const(:(Base.convert(%18, @_15)))
└───       Core.Const(:(@_15 = Core.typeassert(%23, %18)))
7 ┄─       (system_surfaces = @_15)
│    %26 = traceable_rays::Vector
│          (@_3 = Base.iterate(%26))
│    %28 = (@_3 === nothing)::Bool
│    %29 = Base.not_int(%28)::Bool
└───       goto #15 if not %29
8 ┄─ %31 = @_3::Tuple{Any, Int64}
│          (ray = Core.getfield(%31, 1))
│    %33 = Core.getfield(%31, 2)::Int64
│          (min_distance = Main.Inf)
│          (min_index = 0)
│    %36 = Main.enumerate(system_surfaces)::Base.Iterators.Enumerate{Vector{Surface{2}}}
│          (@_6 = Base.iterate(%36))
│    %38 = (@_6 === nothing)::Bool
│    %39 = Base.not_int(%38)::Bool
└───       goto #13 if not %39
9 ┄─ %41 = @_6::Tuple{Tuple{Int64, Surface{2}}, Tuple{Int64, Int64}}
│    %42 = Core.getfield(%41, 1)::Tuple{Int64, Surface{2}}
│    %43 = Base.indexed_iterate(%42, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│          (index = Core.getfield(%43, 1))
│          (@_10 = Core.getfield(%43, 2))
│    %46 = Base.indexed_iterate(%42, 2, @_10::Core.Const(2))::Core.PartialStruct(Tuple{Surface{2}, Int64}, Any[Surface{2}, Core.Const(3)])
│          (surface = Core.getfield(%46, 1))
│    %48 = Core.getfield(%41, 2)::Tuple{Int64, Int64}
│          (distance = Main.intersection_distance(surface, ray))
│    %50 = (distance < min_distance)::Any
└───       goto #11 if not %50
10 ─       (min_distance = distance)
└───       (min_index = index)
11 ┄       (@_6 = Base.iterate(%36, %48))
│    %55 = (@_6 === nothing)::Bool
│    %56 = Base.not_int(%55)::Bool
└───       goto #13 if not %56
12 ─       goto #9
13 ┄       (@_3 = Base.iterate(%26, %33))
│    %60 = (@_3 === nothing)::Bool
│    %61 = Base.not_int(%60)::Bool
└───       goto #15 if not %61
14 ─       goto #8
15 ┄       return nothing

As you can see, the distance and min_distance variables have only been inferred as Any, although as far as I understand they should be inferred as Float64, since intersection_distance only returns Float64. I would be very thankful for any help with getting Julia to infer the correct return type for intersection_distance.

Welcome! Can you post the definition of intersection_distance too? I think it would help to resolve your problem.

Sure, of course. So there are multiple definitions of intersection_distance for the different types of surfaces. All the types are subtypes of Surface{N}. I’ll add them all:

function intersection_distance(cylinder::Cylinder{2}, ray::Ray{2})
    ray_direction = direction(ray)
    distance_apart = center(cylinder) - origin(ray)
    cylinder_axis = axis(cylinder)
    cylinder_radius=radius(cylinder)

    na = cross(ray_direction, cylinder_axis)
    ba = cross(distance_apart, cylinder_axis)

    d1 = (ba - cylinder_radius) / na
    d2 = (ba + cylinder_radius) / na

    alpha::Float64 = 0.0
    # Pick the nearest intersection that lies in front of the ray origin.
    if d1 > 0.0 && d2 > 0.0
        alpha = min(d1, d2)
    elseif d1 > 0.0
        alpha = d1
    elseif d2 > 0.0
        alpha = d2
    else
        return Inf
    end

    distance_from_center = dot(cylinder_axis, origin(ray, alpha) - center(cylinder))

    0 <= distance_from_center <= height(cylinder) && return alpha

    return Inf
end

function intersection_distance(spherical_cap::SphericalCap{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(spherical_cap)
    ray_direction = direction(ray)
    spherical_cap_radius = radius(spherical_cap)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - spherical_cap_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf

    alpha_2 = oncap(spherical_cap, origin(ray, alpha_2)) ? alpha_2 : Inf

    alpha_1 < 1e-10 && return alpha_2

    oncap(spherical_cap, origin(ray, alpha_1)) && return alpha_1

    return alpha_2
end

function intersection_distance(plane::Plane{N}, ray::Ray{N})::Float64 where {N}
    ray_direction_to_plane = dot(normal(plane), direction(ray))
    distance_apart = distance(plane) - dot(normal(plane), origin(ray))
    alpha =  distance_apart / ray_direction_to_plane
    alpha < 1e-10 && return Inf
    return alpha
end

function intersection_distance(sphere::OpticalSphere{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(sphere)
    ray_direction = direction(ray)
    sphere_radius = radius(sphere)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - sphere_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)::Tuple{Float64, Float64}

    alpha_2 < 1e-10 && return Inf
    alpha_1 < 1e-10 && return alpha_2
    return alpha_1
end

function quadraticroots(a::Float64, b::Float64, c::Float64)::Tuple{Float64, Float64}
    temp = b^2 - 4 * a * c
    temp < 1e-10 && return 0.0, 0.0
    return (-b - sqrt(temp)) / 2a, (-b + sqrt(temp)) / 2a
end

I appreciate any help.

The strangest thing I find is that despite rays(sys) being inferred as Vector{Ray{2}} (see %4 in the output), traceable_rays ends up as a plain Vector. As a result, ray has become Any, which appears to lead to a chain of Any.

Since you appear to be working in global scope, perhaps you could try re-running it once in a clean context.

Edit:
Wait, are Rays and Ray interchangeable?

Please post a self-contained example so that someone can actually reproduce your results.

Hey,
sorry for the late reply. Here’s the self-contained example:

using LinearAlgebra
using StaticArrays

import Base: push!, append!

abstract type Surface{N} end
abstract type ParametricSurface{N} <: Surface{N} end
abstract type CompoundSurface{N} <: Surface{N} end

surfaces(surface::T) where {T <: Surface} = (surface, )

struct SphericalCap{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    direction::SVector{N, Float64}
    radius::Float64
    theta::Float64
    refractive_index::Float64
end

radius(spherical_cap::SphericalCap) = spherical_cap.radius
theta(spherical_cap::SphericalCap) = spherical_cap.theta
center(spherical_cap::SphericalCap) = spherical_cap.center
direction(spherical_cap::SphericalCap) = spherical_cap.direction
refractive_index(spherical_cap::SphericalCap) = spherical_cap.refractive_index

struct Plane{N} <: ParametricSurface{N}
    point::SVector{N, Float64}
    normal::SVector{N, Float64}
    refractive_index::Float64
    distance::Float64

    function Plane(point::SVector{N, Float64}, normal::SVector{N, Float64}, refractive_index::Float64) where {N}
        normal_unit = normalize(normal)
        new{N}(point, normal_unit, refractive_index,  dot(normal, point))
    end
end

point(plane::Plane) = plane.point
normal(plane::Plane) = plane.normal
distance(plane::Plane) = plane.distance
refractive_index(plane::Plane) = plane.refractive_index

struct OpticalSphere{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    radius::Float64
    refractive_index::Float64
end

radius(sphere::OpticalSphere) = sphere.radius
center(sphere::OpticalSphere) = sphere.center
refractive_index(sphere::OpticalSphere) = sphere.refractive_index

struct Cylinder{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    height::Float64
    radius::Float64
    refractive_index::Float64

    function Cylinder(center::SVector{N, Float64}, height::Float64, radius::Float64, refractive_index::Float64) where {N}
        @assert N==2 || N==3 "Cylinders can only be two- or three-dimensional"
        new{N}(center, height, radius, refractive_index)
    end

end

center(cylinder::Cylinder) = cylinder.center
radius(cylinder::Cylinder) = cylinder.radius
height(cylinder::Cylinder) = cylinder.height
axis(::Cylinder{N}) where {N} = @SVector [1.0, 0.0]
refractive_index(cylinder::Cylinder) = cylinder.refractive_index

struct Lens{N} <: CompoundSurface{N}
    surfaces::Tuple{SphericalCap{N}, Cylinder{N}, SphericalCap{N}}
end

surfaces(lens::Lens{N}) where {N} = lens.surfaces
front_cap(lens::Lens{N}) where {N} = lens.surfaces[1]
cylinder(lens::Lens{N}) where {N} = lens.surfaces[2]
back_cap(lens::Lens{N}) where {N} = lens.surfaces[3]

function Lens(center::SVector{N,T}, radius::T, radius_lens::T, thickness::T, refractive_index::T) where {N, T<:Real}
    x_position= radius - thickness/2
    theta = asin(radius_lens / radius)
    cap_offset = SVector(x_position,0.0)
    
    front_cap_center = center + cap_offset
    front_cap = SphericalCap(front_cap_center, SVector(-1.0,0.0), radius, theta, refractive_index)
    
    back_cap_center = center - cap_offset
    back_cap = SphericalCap(back_cap_center, SVector(1.0, 0.0), radius, theta, refractive_index)
    
    distance = radius * (1 - cos(theta))
    height = thickness-2*distance

    cylinder= Cylinder(center - SVector(height/2,0), height, radius_lens, refractive_index)

    return Lens((front_cap, cylinder, back_cap))
end

struct Ray{N}
    origin::SVector{N, Float64}
    direction::SVector{N, Float64}

    function Ray(origin::SVector{N, Float64}, direction::SVector{N, Float64}) where {N}
        new{N}(origin, normalize(direction))
    end
end

origin(ray::Ray) = ray.origin
origin(ray::Ray, alpha::Float64) = origin(ray) + alpha * direction(ray)
direction(ray::Ray) = ray.direction

struct OpticalSystem{N}
    objects::Vector{Surface{N}}
    rays::Vector{Ray{N}}
end

OpticalSystem{N}() where {N} = OpticalSystem(Vector{Surface{N}}(), Vector{Ray{N}}())

objects(sys::OpticalSystem) = sys.objects
rays(sys::OpticalSystem) = sys.rays
traced_rays(sys::OpticalSystem) = sys.traced_rays
function surfaces(sys::OpticalSystem{N})::Vector{Surface{N}} where {N}
    obj_surfaces = Vector{Surface{N}}()
    for obj in objects(sys)
        append!(obj_surfaces, surfaces(obj))
    end
    obj_surfaces
end

push!(sys::OpticalSystem, object::T) where {T<:Surface} = push!(sys.objects, object)
append!(sys::OpticalSystem{N}, rays::Vector{Ray{N}}) where {N} = append!(sys.rays, rays)

function quadraticroots(a::Float64, b::Float64, c::Float64)::Tuple{Float64, Float64}
    temp = b^2 - 4 * a * c
    temp < 1e-10 && return 0.0, 0.0
    return (-b - sqrt(temp)) / 2a, (-b + sqrt(temp)) / 2a
end

function intersection_distance(spherical_cap::SphericalCap{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(spherical_cap)
    ray_direction = direction(ray)
    spherical_cap_radius = radius(spherical_cap)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - spherical_cap_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf

    alpha_2 = oncap(spherical_cap, origin(ray, alpha_2)) ? alpha_2 : Inf

    alpha_1 < 1e-10 && return alpha_2

    oncap(spherical_cap, origin(ray, alpha_1)) && return alpha_1

    return alpha_2
end

function oncap(spherical_cap::SphericalCap{N}, point::SVector{N, Float64}) where {N}
    cosine_normal_direction = dot(normal(spherical_cap, point), direction(spherical_cap))
    cosine_normal_direction <= 0 && return false
    acos(cosine_normal_direction) > theta(spherical_cap) && return false

    return true
end

normal(spherical_cap::SphericalCap{N}, point::SVector{N,Float64}) where {N} = normalize(point - center(spherical_cap))


function intersection_distance(plane::Plane{N}, ray::Ray{N})::Float64 where {N}
    ray_direction_to_plane = dot(normal(plane), direction(ray))
    distance_apart = distance(plane) - dot(normal(plane), origin(ray))
    alpha =  distance_apart / ray_direction_to_plane
    alpha < 1e-10 && return Inf
    return alpha
end

function intersection_distance(sphere::OpticalSphere{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(sphere)
    ray_direction = direction(ray)
    sphere_radius = radius(sphere)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - sphere_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf
    alpha_1 < 1e-10 && return alpha_2
    return alpha_1
end

function intersection_distance(cylinder::Cylinder{2}, ray::Ray{2})
    ray_direction = direction(ray)
    distance_apart = center(cylinder) - origin(ray)
    cylinder_axis = axis(cylinder)
    cylinder_radius=radius(cylinder)

    na = cross(ray_direction, cylinder_axis)
    ba = cross(distance_apart, cylinder_axis)

    d1 = (ba - cylinder_radius) / na
    d2 = (ba + cylinder_radius) / na

    alpha::Float64 = 0.0
    # Pick the nearest intersection that lies in front of the ray origin.
    if d1 > 0.0 && d2 > 0.0
        alpha = min(d1, d2)
    elseif d1 > 0.0
        alpha = d1
    elseif d2 > 0.0
        alpha = d2
    else
        return Inf
    end

    distance_from_center = dot(cylinder_axis, origin(ray, alpha) - center(cylinder))

    0 <= distance_from_center <= height(cylinder) && return alpha

    return Inf
end

function trace(sys::OpticalSystem)
    traceable_rays = rays(sys)
    system_surfaces = surfaces(sys)

    for ray in traceable_rays
        min_distance = Inf
        min_index = 0
        for (index, surface) in enumerate(system_surfaces)
            distance = intersection_distance(surface, ray)
            if distance < min_distance
                min_distance = distance
                min_index = index
            end
        end
    end
end

sys = OpticalSystem{2}()
lens = Lens(SVector(0.0, 0.0), 4.0, 2.0, 2.0, 1.33)
plane = Plane(SVector(5.0, 0.0), SVector(-1.0, 0.0), 1.33)
sphere = OpticalSphere(SVector(-3.0, 0.0), 1.0, 1.33)

push!(sys, sphere)
push!(sys, lens)
push!(sys, plane)

r = [Ray(SVector(-10.0, i), SVector(1.0, 0.0)) for i in  -2.0:0.01:2.0]
append!(sys, r);

@code_warntype trace(sys)

Sorry, that was my mistake: it should have been Ray{N}. The type Rays does not exist.

I don’t know where to find the proper documentation, but in any case, by default inference gives up on a call once the number of methods that could match it exceeds 3.
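To make that cutoff concrete, here is a minimal sketch with toy types (Shape, A to D, three, four and the call_* wrappers are all made-up names, not part of the package in question). It wraps each function in a caller and asks inference what it sees when only the abstract type is known. The behaviour shown depends on the default limit of the Julia version in use, so treat it as an illustration rather than a guarantee.

```julia
# Toy stand-ins for the surface types in this thread (hypothetical names).
abstract type Shape end
struct A <: Shape end
struct B <: Shape end
struct C <: Shape end
struct D <: Shape end

# Three methods, all returning Float64.
three(::A) = 1.0
three(::B) = 2.0
three(::C) = 3.0

# Four methods, all returning Float64: one more than the default limit.
four(::A) = 1.0
four(::B) = 2.0
four(::C) = 3.0
four(::D) = 4.0

# Callers that only know the abstract type, like the loop in trace.
call_three(s::Shape) = three(s)
call_four(s::Shape) = four(s)

# Base.return_types on the callers reports what inference concludes
# for the inner call when the argument is only known to be a Shape.
inferred_three = only(Base.return_types(call_three, (Shape,)))
inferred_four = only(Base.return_types(call_four, (Shape,)))

println("call_three(::Shape) infers as ", inferred_three)
println("call_four(::Shape)  infers as ", inferred_four)
```

With the default settings, the three-method call should come back as Float64 and the four-method call as Any, which matches what the @code_warntype output in the question shows for intersection_distance.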

I think it is better to manually annotate the return type in this case.
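A minimal sketch of what such an annotation could look like, assuming it is placed at the call site inside the loop. The types Wall, Mirror, Prism and Slit and the functions dist and nearest are hypothetical stand-ins for the surfaces, intersection_distance and trace.

```julia
# Hypothetical stand-ins for the four surface types.
abstract type Shape end
struct Wall <: Shape;   pos::Float64; end
struct Mirror <: Shape; pos::Float64; end
struct Prism <: Shape;  pos::Float64; end
struct Slit <: Shape;   pos::Float64; end

# Four methods, so a call on a plain Shape is not inferred by default.
dist(s::Wall, x::Float64) = abs(s.pos - x)
dist(s::Mirror, x::Float64) = abs(s.pos - x)
dist(s::Prism, x::Float64) = abs(s.pos - x)
dist(s::Slit, x::Float64) = abs(s.pos - x)

function nearest(shapes::Vector{Shape}, x::Float64)
    min_distance = Inf
    min_index = 0
    for (index, shape) in enumerate(shapes)
        # The type assertion tells inference what the dynamic call returns,
        # so min_distance stays a Float64 for the rest of the loop.
        d = dist(shape, x)::Float64
        if d < min_distance
            min_distance = d
            min_index = index
        end
    end
    return min_distance, min_index
end

shapes = Shape[Wall(5.0), Mirror(2.0), Prism(9.0), Slit(3.5)]
result = nearest(shapes, 0.0)
inferred = only(Base.return_types(nearest, (Vector{Shape}, Float64)))
println(result, " inferred as ", inferred)
```

The call to dist is still dispatched at run time, since the element type is abstract. The assertion only stops the Any from spreading into min_distance and the comparison.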

Perhaps the order of the surfaces does not matter unless the distances are exactly the same, so grouping or filtering the surfaces by type is another option. In this case, however, I do not think it would be very effective in terms of performance.
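For reference, the grouping idea could be sketched roughly as follows: one concretely typed vector per surface type, with the inner loop behind a function barrier. All names (Wall, Mirror, dist, nearest_in, nearest) are invented for the illustration, and the sketch only returns the distance, not the index that trace also tracks.

```julia
# Hypothetical stand-ins for two of the surface types.
abstract type Shape end
struct Wall <: Shape;   pos::Float64; end
struct Mirror <: Shape; pos::Float64; end

dist(s::Wall, x::Float64) = abs(s.pos - x)
dist(s::Mirror, x::Float64) = abs(s.pos - x)

# Function barrier: inside this method S is a concrete type,
# so dist(s, x) is resolved statically.
function nearest_in(group::Vector{S}, x::Float64) where {S<:Shape}
    best = Inf
    for s in group
        best = min(best, dist(s, x))
    end
    return best
end

# One call per group; map over a tuple keeps each element's type.
nearest(groups::Tuple, x::Float64) = minimum(map(g -> nearest_in(g, x), groups))

groups = (Wall[Wall(5.0), Wall(7.0)], Mirror[Mirror(2.0)])
closest = nearest(groups, 0.0)
println("closest distance: ", closest)
```

Whether this pays off depends on how many surfaces of each type there are, which is in line with the caveat above.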

Of course, there are many other measures, but that would be a discussion of total performance and should be in a separate topic.

The output type isn’t inferred because the input type isn’t inferred:

julia> eltype(surfaces(sys))
Surface{2}

julia> isconcretetype(ans)
false
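This may also explain why Base.return_types(intersection_distance) looked fine in the question: called without argument types, it reports one entry per method, not what a particular call site infers. A small sketch with made-up types (Disk, Box, Ring, Line, size_of, first_size) shows the difference between the per-method view, a call with a concrete element type, and a call with an abstract one.

```julia
# Hypothetical types; four methods, each returning a Float64 field.
abstract type Shape end
struct Disk <: Shape; r::Float64; end
struct Box <: Shape;  w::Float64; end
struct Ring <: Shape; r::Float64; end
struct Line <: Shape; l::Float64; end

size_of(s::Disk) = s.r
size_of(s::Box) = s.w
size_of(s::Ring) = s.r
size_of(s::Line) = s.l

# A caller whose argument type decides what inference knows about the element.
first_size(v) = size_of(first(v))

# One entry per method, like the REPL output in the question.
per_method = Base.return_types(size_of)

# Concrete element type: the call is resolved to a single method.
concrete_call = only(Base.return_types(first_size, (Vector{Disk},)))

# Abstract element type: four candidate methods, so inference gives up by default.
abstract_call = only(Base.return_types(first_size, (Vector{Shape},)))

println(per_method)
println("Vector{Disk}:  ", concrete_call)
println("Vector{Shape}: ", abstract_call)
```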

The vector is never created with a concrete element type: surfaces(sys) builds it as Vector{Surface{N}}().

To be precise, the result is just not fully inferred; the inference itself is behaving reasonably.

For example, since inf_params.max_union_splitting::Int = 4, the return type of the following could be inferred as Float64.

function md_intersection_distance(surface::Union{Plane{N}, OpticalSphere{N}, Cylinder{2}, SphericalCap{N}}, ray::Ray{N}) where {N}
    intersection_distance(surface, ray)
end
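The same effect can be reproduced on a toy example, shown here as a sketch with invented names (A to D, four, via_abstract, via_union). It assumes the default union-splitting limit of 4 mentioned above, so the result may differ if that default changes.

```julia
# Hypothetical stand-ins for the four surface types.
abstract type Shape end
struct A <: Shape end
struct B <: Shape end
struct C <: Shape end
struct D <: Shape end

four(::A) = 1.0
four(::B) = 2.0
four(::C) = 3.0
four(::D) = 4.0

# The argument is only known to be a Shape: four candidate methods.
via_abstract(s::Shape) = four(s)

# The argument is a Union of the four concrete types: inference can
# split the Union and look at each concrete case separately.
const AnyShape = Union{A, B, C, D}
via_union(s::AnyShape) = four(s)

inferred_abstract = only(Base.return_types(via_abstract, (Shape,)))
inferred_union = only(Base.return_types(via_union, (AnyShape,)))

println("via_abstract infers as ", inferred_abstract)
println("via_union    infers as ", inferred_union)
```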

Hello,

Thank you very much for your answer. I had no idea this limit existed, but it is very good to know. As soon as I commented out one of the intersection_distance methods so that only 3 remained, Float64 was inferred, which is very nice to see. Thank you so much for your help; this was very frustrating for me.

Thank you also for your advice on solving the issue.

NB: the exact cutoff value is an implementation detail, and there are discussions about lowering this default in the future. There are also experimental APIs to override the default, under Base.Experimental.
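As a rough sketch of what using such an override might look like, recent Julia versions provide Base.Experimental.@max_methods, which can be applied to a single generic function via a forward declaration. Being experimental, its availability and behaviour may change between versions. The types and the functions four and call_four below are hypothetical.

```julia
# Hypothetical stand-ins for the four surface types.
abstract type Shape end
struct A <: Shape end
struct B <: Shape end
struct C <: Shape end
struct D <: Shape end

# Experimental: ask inference to consider up to 4 matching methods
# for this one generic function (assumes the macro exists in the
# Julia version in use).
Base.Experimental.@max_methods 4 function four end

four(::A) = 1.0
four(::B) = 2.0
four(::C) = 3.0
four(::D) = 4.0

call_four(s::Shape) = four(s)

# Print what inference reports for the abstractly typed call.
println(only(Base.return_types(call_four, (Shape,))))
```

If the override takes effect, the printed type would be expected to be Float64 instead of Any, but since the API is experimental it is worth checking on the version you actually run.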
