# Type inference infers `Any` for a function that returns `Float64`

Hello everyone, I am currently writing an optical ray tracing package in Julia. While optimising the code for performance, I ran into an issue: the function `trace`, which performs the tracing, makes a lot of allocations when it calls `intersection_distance`, even though that function returns a `Float64`.

Hereβs part of the trace function in question:

```julia
function trace(sys::OpticalSystem{N}) where {N}
    traceable_rays::Vector{Rays{N}} = rays(sys)

    system_surfaces::Vector{Surface{N}} = surfaces(sys)
    for ray in traceable_rays
        min_distance = Inf
        min_index = 0
        for (index, surface) in enumerate(system_surfaces)
            distance = intersection_distance(surface, ray)
            if distance < min_distance
                min_distance = distance
                min_index = index
            end
        end
    end
end
```

I checked the return types of `intersection_distance` with `Base.return_types`, and the output is as follows:

```julia
julia> Base.return_types(intersection_distance)
4-element Vector{Any}:
 Float64
 Float64
 Float64
 Float64
```

The `intersection_distance` function takes two arguments: the first is a subtype of `Surface{N}` and the second is a `Ray{N}`. However, when I run `@code_warntype` on the `trace` function defined above, the result of `intersection_distance` is inferred as `Any`, which is what causes so many allocations.

Here's the output of `@code_warntype`:

```
MethodInstance for trace(::OpticalSystem{2})
  from trace(sys::OpticalSystem{N}) where N @ Main REPL[7]:1
Static Parameters
  N = 2
Arguments
  #self#::Core.Const(trace)
  sys::OpticalSystem{2}
Locals
  @_3::Union{Nothing, Tuple{Any, Int64}}
  system_surfaces::Vector{Surface{2}}
  traceable_rays::Vector
  @_6::Union{Nothing, Tuple{Tuple{Int64, Surface{2}}, Tuple{Int64, Int64}}}
  ray::Any
  min_index::Int64
  min_distance::Any
  @_10::Int64
  surface::Surface{2}
  index::Int64
  distance::Any
  @_14::Vector
  @_15::Vector{Surface{2}}
Body::Nothing
1 ──       Core.NewvarNode(:(@_3))
│          Core.NewvarNode(:(system_surfaces))
│          Core.NewvarNode(:(traceable_rays))
│    %4  = Main.rays(sys)::Vector{Ray{2}}
│    %5  = Main.Vector::Core.Const(Vector)
│    %6  = Core.apply_type(Main.Rays, $(Expr(:static_parameter, 1)))::Any
│    %7  = Core.apply_type(%5, %6)::Type{Vector{_A}} where _A
│          (@_14 = %4)
│    %9  = (@_14::Vector{Ray{2}} isa %7)::Bool
└───       goto #3 if not %9
2 ──       goto #4
3 ── %12 = Base.convert(%7, @_14::Vector{Ray{2}})::Vector
└───       (@_14 = Core.typeassert(%12, %7))
4 ──       (traceable_rays = @_14)
│    %15 = Main.surfaces(sys)::Vector{Surface{2}}
│    %16 = Main.Vector::Core.Const(Vector)
│    %17 = Core.apply_type(Main.Surface, $(Expr(:static_parameter, 1)))::Core.Const(Surface{2})
│    %18 = Core.apply_type(%16, %17)::Core.Const(Vector{Surface{2}})
│          (@_15 = %15)
│    %20 = (@_15 isa %18)::Core.Const(true)
└───       goto #6 if not %20
5 ──       goto #7
6 ──       Core.Const(:(Base.convert(%18, @_15)))
└───       Core.Const(:(@_15 = Core.typeassert(%23, %18)))
7 ──       (system_surfaces = @_15)
│    %26 = traceable_rays::Vector
│          (@_3 = Base.iterate(%26))
│    %28 = (@_3 === nothing)::Bool
│    %29 = Base.not_int(%28)::Bool
└───       goto #15 if not %29
8 ── %31 = @_3::Tuple{Any, Int64}
│          (ray = Core.getfield(%31, 1))
│    %33 = Core.getfield(%31, 2)::Int64
│          (min_distance = Main.Inf)
│          (min_index = 0)
│    %36 = Main.enumerate(system_surfaces)::Base.Iterators.Enumerate{Vector{Surface{2}}}
│          (@_6 = Base.iterate(%36))
│    %38 = (@_6 === nothing)::Bool
│    %39 = Base.not_int(%38)::Bool
└───       goto #13 if not %39
9 ── %41 = @_6::Tuple{Tuple{Int64, Surface{2}}, Tuple{Int64, Int64}}
│    %42 = Core.getfield(%41, 1)::Tuple{Int64, Surface{2}}
│    %43 = Base.indexed_iterate(%42, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│          (index = Core.getfield(%43, 1))
│          (@_10 = Core.getfield(%43, 2))
│    %46 = Base.indexed_iterate(%42, 2, @_10::Core.Const(2))::Core.PartialStruct(Tuple{Surface{2}, Int64}, Any[Surface{2}, Core.Const(3)])
│          (surface = Core.getfield(%46, 1))
│    %48 = Core.getfield(%41, 2)::Tuple{Int64, Int64}
│          (distance = Main.intersection_distance(surface, ray))
│    %50 = (distance < min_distance)::Any
└───       goto #11 if not %50
10 ─       (min_distance = distance)
└───       (min_index = index)
11 ─       (@_6 = Base.iterate(%36, %48))
│    %55 = (@_6 === nothing)::Bool
│    %56 = Base.not_int(%55)::Bool
└───       goto #13 if not %56
12 ─       goto #9
13 ─       (@_3 = Base.iterate(%26, %33))
│    %60 = (@_3 === nothing)::Bool
│    %61 = Base.not_int(%60)::Bool
└───       goto #15 if not %61
14 ─       goto #8
15 ─       return nothing
```

You can see that `distance` and `min_distance` are only inferred as `Any`, although as far as I understand they should be inferred as `Float64`, since every `intersection_distance` method returns a `Float64`. I would be very thankful for any help on how to get Julia to infer the correct return type for `intersection_distance`.

Welcome! Can you post the definition of `intersection_distance` too? I think it would help to resolve your problem.

Sure, of course. There are multiple methods of `intersection_distance`, one for each type of surface, and all of these types are subtypes of `Surface{N}`. I'll add them all:

```julia
function intersection_distance(cylinder::Cylinder{2}, ray::Ray{2})
    ray_direction = direction(ray)
    distance_apart = center(cylinder) - origin(ray)
    cylinder_axis = axis(cylinder)

    na = cross(ray_direction, cylinder_axis)
    ba = cross(distance_apart, cylinder_axis)

    d1 = (ba - cylinder_radius) / na
    d2 = (ba + cylinder_radius) / na

    alpha::Float64 = 0.0
    if d1 > 0.0 && d2 > 0.0
        d = min(d1, d2)
    elseif d1 > 0.0
        d = d1
    elseif d2 > 0.0
        d = d2
    else
        return Inf
    end

    distance_from_center = dot(cylinder_axis, origin(ray, alpha) - center(cylinder))

    0 <= distance_from_center <= height(cylinder) && return alpha

    return Inf
end

function intersection_distance(spherical_cap::SphericalCap{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(spherical_cap)
    ray_direction = direction(ray)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - spherical_cap_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf

    alpha_2 = oncap(spherical_cap, origin(ray, alpha_2)) ? alpha_2 : Inf

    alpha_1 < 1e-10 && return alpha_2

    oncap(spherical_cap, origin(ray, alpha_1)) && return alpha_1

    return alpha_2
end

function intersection_distance(plane::Plane{N}, ray::Ray{N})::Float64 where {N}
    ray_direction_to_plane = dot(normal(plane), direction(ray))
    distance_apart = distance(plane) - dot(normal(plane), origin(ray))
    alpha = distance_apart / ray_direction_to_plane
    alpha < 1e-10 && return Inf
    return alpha
end

function intersection_distance(sphere::OpticalSphere{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(sphere)
    ray_direction = direction(ray)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - sphere_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)::Tuple{Float64, Float64}

    alpha_2 < 1e-10 && return Inf
    alpha_1 < 1e-10 && return alpha_2
    return alpha_1
end

function quadraticroots(a, b, c)
    temp = b^2 - 4 * a * c
    temp < 1e-10 && return 0.0, 0.0
    return (-b - sqrt(temp)) / 2a, (-b + sqrt(temp)) / 2a
end
```
``````

I appreciate any help.

The strangest thing I find is that despite `Main.rays(sys)::Vector{Ray{2}}`, the result is `traceable_rays::Vector`.

As a result, `ray` has become `Any`, which appears to lead to a chain of `Any`.

Since you appear to be working in global scope, perhaps you could try re-running it once in a clean context.

Edit:
Wait, are `Rays` and `Ray` interchangeable?

Please post a self-contained example so that someone can actually reproduce your results.


Hey,
sorry for the late reply. Here's the self-contained example:

```julia
using LinearAlgebra
using StaticArrays
using InteractiveUtils  # provides @code_warntype outside the REPL

import Base: push!, append!

abstract type Surface{N} end
abstract type ParametricSurface{N} <: Surface{N} end
abstract type CompoundSurface{N} <: Surface{N} end

surfaces(surface::T) where {T <: Surface} = (surface, )

struct SphericalCap{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    direction::SVector{N, Float64}
    theta::Float64
    refractive_index::Float64
end

theta(spherical_cap::SphericalCap) = spherical_cap.theta
center(spherical_cap::SphericalCap) = spherical_cap.center
direction(spherical_cap::SphericalCap) = spherical_cap.direction
refractive_index(spherical_cap::SphericalCap) = spherical_cap.refractive_index

struct Plane{N} <: ParametricSurface{N}
    point::SVector{N, Float64}
    normal::SVector{N, Float64}
    refractive_index::Float64
    distance::Float64

    function Plane(point::SVector{N, Float64}, normal::SVector{N, Float64}, refractive_index::Float64) where {N}
        normal_unit = normalize(normal)
        new{N}(point, normal_unit, refractive_index, dot(normal, point))
    end
end

point(plane::Plane) = plane.point
normal(plane::Plane) = plane.normal
distance(plane::Plane) = plane.distance
refractive_index(plane::Plane) = plane.refractive_index

struct OpticalSphere{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    refractive_index::Float64
end

center(sphere::OpticalSphere) = sphere.center
refractive_index(sphere::OpticalSphere) = sphere.refractive_index

struct Cylinder{N} <: ParametricSurface{N}
    center::SVector{N, Float64}
    height::Float64
    refractive_index::Float64

    function Cylinder(center::SVector{N, Float64}, height::Float64, radius::Float64, refractive_index::Float64) where {N}
        @assert N == 2 || N == 3 "Cylinders can only be two- or three-dimensional"
    end

end

center(cylinder::Cylinder) = cylinder.center
height(cylinder::Cylinder) = cylinder.height
axis(::Cylinder{N}) where {N} = @SVector [1.0, 0.0]
refractive_index(cylinder::Cylinder) = cylinder.refractive_index

struct Lens{N} <: CompoundSurface{N}
    surfaces::Tuple{SphericalCap{N}, Cylinder{N}, SphericalCap{N}}
end

surfaces(lens::Lens{N}) where {N} = lens.surfaces
front_cap(lens::Lens{N}) where {N} = lens.surfaces[1]
cylinder(lens::Lens{N}) where {N} = lens.surfaces[2]
back_cap(lens::Lens{N}) where {N} = lens.surfaces[3]

    cap_offset = SVector(x_position, 0.0)

    front_cap_center = center + cap_offset
    front_cap = SphericalCap(front_cap_center, SVector(-1.0, 0.0), radius, theta, refractive_index)

    back_cap_center = center - cap_offset
    back_cap = SphericalCap(back_cap_center, SVector(1.0, 0.0), radius, theta, refractive_index)

    distance = radius * (1 - cos(theta))
    height = thickness - 2 * distance

    cylinder = Cylinder(center - SVector(height / 2, 0), height, radius_lens, refractive_index)

    return Lens((front_cap, cylinder, back_cap))
end

struct Ray{N}
    origin::SVector{N, Float64}
    direction::SVector{N, Float64}

    function Ray(origin::SVector{N, Float64}, direction::SVector{N, Float64}) where {N}
        new{N}(origin, normalize(direction))
    end
end

origin(ray::Ray) = ray.origin
origin(ray::Ray, alpha::Float64) = origin(ray) + alpha * direction(ray)
direction(ray::Ray) = ray.direction

struct OpticalSystem{N}
    objects::Vector{Surface{N}}
    rays::Vector{Ray{N}}
end

OpticalSystem{N}() where {N} = OpticalSystem(Vector{Surface{N}}(), Vector{Ray{N}}())

objects(sys::OpticalSystem) = sys.objects
rays(sys::OpticalSystem) = sys.rays
traced_rays(sys::OpticalSystem) = sys.traced_rays
function surfaces(sys::OpticalSystem{N})::Vector{Surface{N}} where {N}
    obj_surfaces = Vector{Surface{N}}()
    for obj in objects(sys)
        append!(obj_surfaces, surfaces(obj))
    end
    obj_surfaces
end

push!(sys::OpticalSystem, object::T) where {T<:Surface} = push!(sys.objects, object)
append!(sys::OpticalSystem{N}, rays::Vector{Ray{N}}) where {N} = append!(sys.rays, rays)

function quadraticroots(a, b, c)
    temp = b^2 - 4 * a * c
    temp < 1e-10 && return 0.0, 0.0
    return (-b - sqrt(temp)) / 2a, (-b + sqrt(temp)) / 2a
end

function intersection_distance(spherical_cap::SphericalCap{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(spherical_cap)
    ray_direction = direction(ray)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - spherical_cap_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf

    alpha_2 = oncap(spherical_cap, origin(ray, alpha_2)) ? alpha_2 : Inf

    alpha_1 < 1e-10 && return alpha_2

    oncap(spherical_cap, origin(ray, alpha_1)) && return alpha_1

    return alpha_2
end

function oncap(spherical_cap::SphericalCap{N}, point::SVector{N, Float64}) where {N}
    cosine_normal_direction = dot(normal(spherical_cap, point), direction(spherical_cap))
    cosine_normal_direction <= 0 && return false
    acos(cosine_normal_direction) > theta(spherical_cap) && return false

    return true
end

normal(spherical_cap::SphericalCap{N}, point::SVector{N,Float64}) where {N} = normalize(point - center(spherical_cap))

function intersection_distance(plane::Plane{N}, ray::Ray{N})::Float64 where {N}
    ray_direction_to_plane = dot(normal(plane), direction(ray))
    distance_apart = distance(plane) - dot(normal(plane), origin(ray))
    alpha = distance_apart / ray_direction_to_plane
    alpha < 1e-10 && return Inf
    return alpha
end

function intersection_distance(sphere::OpticalSphere{N}, ray::Ray{N}) where {N}
    distance_apart = origin(ray) - center(sphere)
    ray_direction = direction(ray)

    a = dot(ray_direction, ray_direction)
    b = 2 * dot(ray_direction, distance_apart)
    c = dot(distance_apart, distance_apart) - sphere_radius^2

    alpha_1, alpha_2 = quadraticroots(a, b, c)

    alpha_2 < 1e-10 && return Inf
    alpha_1 < 1e-10 && return alpha_2
    return alpha_1
end

function intersection_distance(cylinder::Cylinder{2}, ray::Ray{2})
    ray_direction = direction(ray)
    distance_apart = center(cylinder) - origin(ray)
    cylinder_axis = axis(cylinder)

    na = cross(ray_direction, cylinder_axis)
    ba = cross(distance_apart, cylinder_axis)

    d1 = (ba - cylinder_radius) / na
    d2 = (ba + cylinder_radius) / na

    alpha::Float64 = 0.0
    if d1 > 0.0 && d2 > 0.0
        d = min(d1, d2)
    elseif d1 > 0.0
        d = d1
    elseif d2 > 0.0
        d = d2
    else
        return Inf
    end

    distance_from_center = dot(cylinder_axis, origin(ray, alpha) - center(cylinder))

    0 <= distance_from_center <= height(cylinder) && return alpha

    return Inf
end

function trace(sys::OpticalSystem)
    traceable_rays = rays(sys)
    system_surfaces = surfaces(sys)

    for ray in traceable_rays
        min_distance = Inf
        min_index = 0
        for (index, surface) in enumerate(system_surfaces)
            distance = intersection_distance(surface, ray)
            if distance < min_distance
                min_distance = distance
                min_index = index
            end
        end
    end
end

sys = OpticalSystem{2}()
lens = Lens(SVector(0.0, 0.0), 4.0, 2.0, 2.0, 1.33)
plane = Plane(SVector(5.0, 0.0), SVector(-1.0, 0.0), 1.33)
sphere = OpticalSphere(SVector(-3.0, 0.0), 1.0, 1.33)

push!(sys, sphere)
push!(sys, lens)
push!(sys, plane)

r = [Ray(SVector(-10.0, i), SVector(1.0, 0.0)) for i in -2.0:0.01:2.0]
append!(sys, r);

@code_warntype trace(sys)
```

Sorry, that was a mistake on my part; it should have been `Ray{N}`. The type `Rays` does not exist.

I donβt know where to find the proper documentation, but in any case, the default is to give up inference if the number of methods exceeds 3.

I think it is better to annotate the return type manually in this case.
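A minimal sketch of that manual annotation, with hypothetical `Obstacle` types and a `hit_distance` function in place of the surfaces and `intersection_distance`. The assertion goes on the call site: annotating the method definitions themselves with `::Float64`, as the `Plane` method already does, does not help here, because inference gives up before looking at the individual methods. The call is still dynamically dispatched, but `Any` no longer spreads to `min_distance`.

```julia
abstract type Obstacle end
struct Wall   <: Obstacle; d::Float64 end
struct Pillar <: Obstacle; d::Float64 end
struct Door   <: Obstacle; d::Float64 end
struct Window <: Obstacle; d::Float64 end

# Four methods: too many to be combined when the argument is only known to be an `Obstacle`.
hit_distance(o::Wall)   = o.d
hit_distance(o::Pillar) = o.d
hit_distance(o::Door)   = o.d
hit_distance(o::Window) = o.d

function nearest(obstacles::Vector{Obstacle})
    min_distance = Inf
    min_index = 0
    for (index, obstacle) in enumerate(obstacles)
        # Still a dynamic dispatch, but the type assertion tells inference
        # that the result is a Float64 from here on.
        distance = hit_distance(obstacle)::Float64
        if distance < min_distance
            min_distance = distance
            min_index = index
        end
    end
    return min_index, min_distance
end

obstacles = Obstacle[Wall(3.0), Pillar(1.5), Door(2.0), Window(4.0)]
nearest(obstacles)  # (2, 1.5)
```

If a method ever returned something other than a `Float64`, the assertion would throw a `TypeError` at run time rather than silently widening the type.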

The order of the surfaces probably does not matter unless two distances are exactly equal, so grouping or filtering the surfaces by type is another option. However, in this case I do not think it would be very effective in terms of performance.
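A sketch of the grouping idea, assuming each surface type is kept in its own concretely typed vector and the groups are stored in a `Tuple`. `FlatSurface`, `RoundSurface`, `surface_distance`, `nearest_in_group` and `nearest_grouped` are hypothetical names. Ties are resolved in favour of the earlier group, which is where the loss of the original order would show up.

```julia
struct FlatSurface;  d::Float64 end
struct RoundSurface; d::Float64 end

surface_distance(s::FlatSurface)  = s.d
surface_distance(s::RoundSurface) = s.d

# Nearest surface within one concretely typed group: fully inferred, no dynamic dispatch.
function nearest_in_group(group::AbstractVector)
    min_distance, min_index = Inf, 0
    for (index, surface) in enumerate(group)
        distance = surface_distance(surface)
        if distance < min_distance
            min_distance, min_index = distance, index
        end
    end
    return min_distance, min_index
end

# `map` over a Tuple is unrolled, so `nearest_in_group` is specialised for each group type.
function nearest_grouped(groups::Tuple)
    per_group = map(nearest_in_group, groups)  # one (distance, index) pair per group
    min_distance, min_group, min_index = Inf, 0, 0
    for (g, (distance, index)) in enumerate(per_group)
        if distance < min_distance
            min_distance, min_group, min_index = distance, g, index
        end
    end
    return min_distance, min_group, min_index
end

groups = ([FlatSurface(3.0), FlatSurface(2.5)], [RoundSurface(1.5), RoundSurface(4.0)])
nearest_grouped(groups)  # (1.5, 2, 1): distance 1.5, group 2, first element of that group
```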

Of course, there are many other possible optimisations, but those concern overall performance and belong in a separate topic.

The output type isnβt inferred because the input type isnβt inferred:

```julia
julia> eltype(surfaces(sys))
Surface{2}

julia> isconcretetype(ans)
false
```

The vector is never created with a concrete element type: `obj_surfaces = Vector{Surface{N}}()`.
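Since the set of surface types is small and known (once a `Lens` is expanded, the vector holds `SphericalCap{2}`, `Cylinder{2}`, `Plane{2}` and `OpticalSphere{2}`), one option is to use a small `Union` of concrete types as the element type instead of the abstract `Surface{N}`. The sketch below uses hypothetical types `Lamp`, `Mirror`, `Prism` and `Screen`. With four members, the `Union` is assumed to be within the default union-splitting limit, so the call inside the loop can be split into direct calls and inferred as `Float64`.

```julia
struct Lamp;   d::Float64 end
struct Mirror; d::Float64 end
struct Prism;  d::Float64 end
struct Screen; d::Float64 end

reach(o::Lamp)   = o.d
reach(o::Mirror) = o.d
reach(o::Prism)  = o.d
reach(o::Screen) = o.d

# A small Union of concrete types as the element type, instead of an abstract supertype.
const Element = Union{Lamp, Mirror, Prism, Screen}

function closest(elements::Vector{Element})
    min_distance, min_index = Inf, 0
    for (index, element) in enumerate(elements)
        distance = reach(element)  # union-split into one direct call per member type
        if distance < min_distance
            min_distance, min_index = distance, index
        end
    end
    return min_distance, min_index
end

elements = Element[Lamp(2.0), Mirror(0.5), Prism(3.0), Screen(1.0)]
closest(elements)  # (0.5, 2)
```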

To be precise, the type is not *fully* inferred, but inference itself is behaving reasonably here.

For example, since `inf_params.max_union_splitting::Int = 4`, the return type of the following can be inferred as `Float64`:

```julia
function md_intersection_distance(surface::Union{Plane{N}, OpticalSphere{N}, Cylinder{2}, SphericalCap{N}}, ray::Ray{N}) where {N}
    intersection_distance(surface, ray)
end
```
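A self-contained sketch of the same wrapper pattern, with hypothetical `Part` types and a `weight` function in place of the surfaces. Calling the four-method function directly on the abstract type is inferred as `Any`, while going through a wrapper whose argument is declared as a `Union` of the four types is inferred as `Float64`.

```julia
abstract type Part end
struct Gear   <: Part end
struct Spring <: Part end
struct Lever  <: Part end
struct Wheel  <: Part end

weight(::Gear)   = 1.0
weight(::Spring) = 0.5
weight(::Lever)  = 2.0
weight(::Wheel)  = 3.0

# Direct call on the abstract type: four candidate methods, so inference gives up.
plain_weight(p::Part) = weight(p)

# Same idea as `md_intersection_distance`: inside this method `p` is a Union of
# four concrete types, so the call to `weight` can be union-split.
union_weight(p::Union{Gear, Spring, Lever, Wheel}) = weight(p)
via_union(p::Part) = union_weight(p)

println(Base.return_types(plain_weight, (Part,)))  # Any[Any]
println(Base.return_types(via_union, (Part,)))     # Any[Float64]
```

The dispatch to `union_weight` itself still happens at run time; what changes is that its result type is known to the caller.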

Hello,

Thank you very much for your answer. I had no idea this limit existed, but it is very good to know. As soon as I commented out one of the `intersection_distance` methods so that only 3 remained, it inferred `Float64`, which is very nice to see. Thank you so much for your help; this was very frustrating for me.

Thank you also for your advice on solving the issue.

NB: the exact cutoff value is an implementation detail, and there are discussions about lowering this default in the future. There are also experimental APIs to override the default, under `Base.Experimental`.
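For example, `Base.Experimental.@max_methods` sets the limit for code defined in a given module. A minimal sketch, assuming a recent Julia version that provides this macro (being experimental, it may change or disappear); the module `MaxMethodsDemo` and everything in it are hypothetical.

```julia
module MaxMethodsDemo

# Experimental: allow inference of code in this module to consider up to 4 matching
# methods per call (supported values are 1 to 4; the default is currently 3).
Base.Experimental.@max_methods 4

abstract type Tool end
struct Hammer <: Tool end
struct Saw    <: Tool end
struct Drill  <: Tool end
struct Wrench <: Tool end

mass(::Hammer) = 1.0
mass(::Saw)    = 0.8
mass(::Drill)  = 1.5
mass(::Wrench) = 0.3

# With the module-level limit raised to 4, this call on the abstract type can be inferred.
via_tool(t::Tool) = mass(t)

end

println(Base.return_types(MaxMethodsDemo.via_tool, (MaxMethodsDemo.Tool,)))  # Any[Float64]
```

Raising the limit makes inference do more work for every call with abstract argument types in the module, so it is a trade-off rather than a free fix.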
