Type Instability when composing functor and function

I am training a neural net using Flux.jl. I have a custom function that normalizes inputs before evaluating the model:

function normalize(state)
    state_min = [-1.2f0, -0.07f0]
    state_max = [0.6f0, 0.07f0]
    normalized_state = (state .- state_min) ./ (state_max .- state_min)
    return normalized_state
end

using Flux

mdl = Chain(Dense(2 => 32, swish),
    Dense(32 => 32, swish),
    Dense(32 => 1))

Define:

f(x) = mdl(normalize(x))
xtest = [-0.1f0, 0.2f0]

Why is mdl(normalize(xtest)) type-stable while f(xtest) is not? Checking f gives:

julia> @code_warntype f(xtest)
MethodInstance for f(::Vector{Float32})
  from f(x) @ Main Untitled-1:16
Arguments
  #self#::Core.Const(f)
  x::Vector{Float32}
Body::Any
1 ─ %1 = Main.normalize(x)::Vector{Float32}
│   %2 = Main.mdl(%1)::Any
└──      return %2

Dug into this a bit more. This seems like a general pattern for functors, not specific to anything in Flux:

struct MyStruct{T}
    k::T
end

function (S::MyStruct)(x)
    return S.k * x
end

S = MyStruct(2)
f(x) = S(2*x)

Here are the @code_warntype results:

julia> @code_warntype S(2*1)
MethodInstance for (::MyStruct{Int64})(::Int64)
  from (S::MyStruct)(x) @ Main Untitled-1:61
Arguments
  S::MyStruct{Int64}
  x::Int64
Body::Int64
1 ─ %1 = Base.getproperty(S, :k)::Int64
│   %2 = (%1 * x)::Int64
└──      return %2


julia> @code_warntype f(1)
MethodInstance for f(::Int64)
  from f(x) @ Main Untitled-1:66
Arguments
  #self#::Core.Const(f)
  x::Int64
Body::Any
1 ─ %1 = (2 * x)::Int64
│   %2 = Main.S(%1)::Any
└──      return %2
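The same difference can also be caught programmatically, for example in a test suite. A minimal sketch, assuming the `Test` standard library's `@inferred` macro, which throws an error when the compiler's inferred return type does not match the type actually returned:

```julia
using Test

struct MyStruct{T}
    k::T
end
(S::MyStruct)(x) = S.k * x

S = MyStruct(2)   # non-constant global, as in the example above
f(x) = S(2 * x)

# Passes: @inferred evaluates S first, so the call is analyzed as (::MyStruct{Int64})(::Int64)
@inferred S(2)

# f reads the non-constant global S, so its return type is inferred as Any
# and @inferred throws an ErrorException
@test_throws ErrorException @inferred f(1)
```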

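This is a consequence of working in global scope: a non-constant global is inferred as Any, because the compiler can't prove it won't be rebound to a value of a different type. @code_warntype S(2*1) (and likewise @code_warntype mdl(normalize(xtest))) doesn't hit this, because the macro evaluates S first and analyzes the call using the concrete type of that value; inside f, the same name is a lookup of a non-constant global. Declaring the global const fixes it: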
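To see why the compiler can't assume anything about `S`, note that rebinding the global changes what the very same method of `f` returns. A small sketch reusing the `MyStruct` example (the names `r1` and `r2` are just for illustration):

```julia
struct MyStruct{T}
    k::T
end
(S::MyStruct)(x) = S.k * x

S = MyStruct(2)      # untyped, non-constant global
f(x) = S(2 * x)

r1 = f(1)            # S.k * 2 = 4, an Int64
S = MyStruct(0.5)    # rebinding is allowed, even to a different type
r2 = f(1)            # same method of f, now returns 1.0, a Float64
```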

julia> const S = MyStruct(2)
MyStruct{Int64}(2)

julia> f(x) = S(2*x)
f (generic function with 1 method)

julia> @code_warntype f(1)
MethodInstance for f(::Int64)
  from f(x) @ Main REPL[4]:1
Arguments
  #self#::Core.Const(f)
  x::Int64
Body::Int64
1 ─ %1 = (2 * x)::Int64
│   %2 = Main.S(%1)::Int64
└──      return %2

Or capture it in a closure:

julia> function g(y)
       S = MyStruct(y)
       h(x) = S(2*x)
       end
g (generic function with 1 method)

julia> h=g(2)
h (generic function with 1 method)

julia> @code_warntype h(1)
MethodInstance for (::var"#h#8"{MyStruct{Int64}})(::Int64)
  from (::var"#h#8")(x) @ Main REPL[7]:3
Arguments
  #self#::var"#h#8"{MyStruct{Int64}}
  x::Int64
Body::Int64
1 ─ %1 = Core.getfield(#self#, :S)::MyStruct{Int64}
│   %2 = (2 * x)::Int64
│   %3 = (%1)(%2)::Int64
└──      return %3
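Another common workaround, in line with the Julia performance tips' advice to pass globals into functions as arguments, is to make the callable a parameter so that its concrete type becomes part of the method signature. A sketch; `f_arg`, `y` and `rt` are hypothetical names, not part of the original code:

```julia
struct MyStruct{T}
    k::T
end
(S::MyStruct)(x) = S.k * x

# Hypothetical variant of f that receives the callable as an argument,
# so each concrete type of s gets its own specialized method instance
f_arg(s, x) = s(2 * x)

S = MyStruct(2)   # may stay a non-constant global
y = f_arg(S, 1)   # 4

# Inference sees MyStruct{Int64} in the signature and infers Int64
rt = Base.return_types(f_arg, (MyStruct{Int}, Int))
```

For the original Flux example, the same idea would be something like `f(m, x) = m(normalize(x))`, called as `f(mdl, xtest)`.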

Thanks! Following your suggestion, I confirmed that making mdl constant in my initial example resolves the issue there too.

Since I always end up declaring stuff in global scope when debugging or trying things out, I guess the lesson is to declare stuff constant when checking for type instability. BenchmarkTools does something similar when one interpolates variables with $, but only for the interpolated expressions themselves: a non-constant global referenced inside the function body (like mdl inside f) is still looked up dynamically, so this kind of type instability can still show up as extra allocations when benchmarking f($xtest).
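If declaring `const` is inconvenient while experimenting, Julia 1.8 and later also accept a type annotation on a global. The variable can still be reassigned, but only to values of the declared type, so the compiler can infer through it. A minimal sketch with the `MyStruct` example, assuming Julia 1.8 or newer; for `mdl` the annotation would have to be the full `Chain` type, which is much less convenient to write out:

```julia
struct MyStruct{T}
    k::T
end
(S::MyStruct)(x) = S.k * x

# Typed global (Julia 1.8+): reassignable, but only to values of this type
S::MyStruct{Int} = MyStruct(2)
f(x) = S(2 * x)

f(1)              # 4
S = MyStruct(3)   # allowed: still a MyStruct{Int}
f(1)              # 6
```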