Hi there,
I simulated some data which I stored as an Array{Observation}
where Observation
is a composite type that stores the elements of each observation. I then I wrote a function to compute the log-likelihood of the data as:
struct Observation{T}
y::T
x1::T
x2::T
end
# this function will generate the logit data for Q4
function simulate_logit_data(; n_obs = 300000)
# define the DGP
X1 = 1.0:0.01:10.0
X2 = 18:0.01:65.0
α = -1.0
θ_1 = 0.5
θ_2 = 0.02
# pre-allocate the container for the data
data = Array{Observation}(undef, n_obs)
for i in eachindex(data)
x1 = rand(X1)
x2 = rand(X2)
u = rand()
y = (α + θ_1*x2 + θ_2*x2 + log(u/(1-u)) > 0) ? 1.0 : 0.0
data[i] = Observation(y, x1, x2)
end
return data
end
function Xθ(θ, obs; f = fieldnames(Observation))
result::Float64 = θ[1]
for i in 2:length(f)
result += θ[i]*getfield(obs, f[i])
end
return result
end
# this function will compute the log-likelihood given a parameter space point
function LL(θ, data)
ll::Float64 = 0.0
for i in eachindex(data)
ll += -log(1+exp(Xθ(θ, data[i]))) + getfield(data[i], :y)*Xθ(θ, data[i])
end
return ll
end
function main()
data = simulate_logit_data()
@btime LL($(rand(3)), $data)
end
main()
But the performance is not as expected:
- I am getting a lot of allocations. I don’t seem to understand why this is happening.
julia> @btime LL($rand(3), $data)
37.647 ms (2400002 allocations: 36.62 MiB)
-21070.048454177613
- I tried to check for type stability using
@code_warntype
but I can not understand where the issues come. Here I attach a screenshot with partial output (due to space constraint) but which contains the two issues:
MethodInstance for LL(::Vector{Float64}, ::Vector{Observation})
from LL(θ, data) @ Main REPL[12]:1
Arguments
#self#::Core.Const(Main.LL)
θ::Vector{Float64}
data::Vector{Observation}
Locals
@_4::Union{Nothing, Tuple{Int64, Int64}}
ll::Float64
i::Int64
@_7::Float64
@_8::Any
Body::Float64
1 ── Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(ll))
│ (@_7 = 0.0)
│ %4 = @_7::Core.Const(0.0)
│ %5 = (%4 isa Main.Float64)::Core.Const(true)
└─── goto #3 if not %5
2 ── goto #4
3 ── Core.Const(:(@_7))
│ Core.Const(:(Base.convert(Main.Float64, %8)))
│ Core.Const(:(Main.Float64))
└─── Core.Const(:(@_7 = Core.typeassert(%9, %10)))
4 ┄─ %12 = @_7::Core.Const(0.0)
│ (ll = %12)
│ %14 = Main.eachindex(data)::Base.OneTo{Int64}
│ (@_4 = Base.iterate(%14))
│ %16 = @_4::Union{Nothing, Tuple{Int64, Int64}}
│ %17 = (%16 === nothing)::Bool
│ %18 = Base.not_int(%17)::Bool
└─── goto #10 if not %18
5 ┄─ %20 = @_4::Tuple{Int64, Int64}
│ (i = Core.getfield(%20, 1))
│ %22 = Core.getfield(%20, 2)::Int64
│ %23 = Main.:+::Core.Const(+)
│ %24 = ll::Float64
│ %25 = Main.:+::Core.Const(+)
│ %26 = Main.:-::Core.Const(-)
│ %27 = Main.log::Core.Const(log)
│ %28 = Main.:+::Core.Const(+)
│ %29 = Main.exp::Core.Const(exp)
│ %30 = Main.Xθ::Core.Const(Main.Xθ)
│ %31 = i::Int64
│ %32 = Base.getindex(data, %31)::Observation
│ %33 = (%30)(θ, %32)::Float64
│ %34 = (%29)(%33)::Float64
│ %35 = (%28)(1, %34)::Float64
│ %36 = (%27)(%35)::Float64
│ %37 = (%26)(%36)::Float64
│ %38 = Main.:*::Core.Const(*)
│ %39 = Main.getfield::Core.Const(getfield)
│ %40 = i::Int64
│ %41 = Base.getindex(data, %40)::Observation
│ %42 = (%39)(%41, :y)::Any
│ %43 = Main.Xθ::Core.Const(Main.Xθ)
│ %44 = i::Int64
│ %45 = Base.getindex(data, %44)::Observation
│ %46 = (%43)(θ, %45)::Float64
│ %47 = (%38)(%42, %46)::Any
│ %48 = (%25)(%37, %47)::Any
│ %49 = (%23)(%24, %48)::Any
│ (@_8 = %49)
│ %51 = @_8::Any
│ %52 = (%51 isa Main.Float64)::Bool
└─── goto #7 if not %52
6 ── goto #8
7 ── %55 = @_8::Any
│ %56 = Base.convert(Main.Float64, %55)::Any
│ %57 = Main.Float64::Core.Const(Float64)
└─── (@_8 = Core.typeassert(%56, %57))
8 ┄─ %59 = @_8::Float64
│ (ll = %59)
│ (@_4 = Base.iterate(%14, %22))
│ %62 = @_4::Union{Nothing, Tuple{Int64, Int64}}
│ %63 = (%62 === nothing)::Bool
│ %64 = Base.not_int(%63)::Bool
└─── goto #10 if not %64
9 ── goto #5
10 ┄ %67 = ll::Float64
└─── return %67
thanks a lot in advance.