Augmenting a function


#1

Say I wanted to write a classical (non-sampling) profiler. In Lisp/Python, I can replace any to-be-profiled function F with a function F’ that takes note of the time, calls F, then takes note of the time again. Something like this:

function profile(fname::Symbol)
   old_f = eval(fname)
   eval(quote
     function $fname(args...; kwargs...)
       time0 = time()
       res = $old_f(args...; kwargs...)
       dt = time() - time0 # store it somewhere
       return res
     end
   end
end

Obviously, multiple dispatch throws a wrench into that scheme. Is there any way of making this work (even if I have to use non-exported Base/Core functions)? Can I somehow ball up the method table for F and assign it to a different name?
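A minimal sketch of why the direct translation fails, assuming Julia 1.x (the function area and the variable old are hypothetical). The new definition does not replace F. It adds one more method to the same generic function, so old_f and F are the same object, and the original, more specific methods still win dispatch:

```julia
# Hypothetical function with a single, specific method
area(r::Float64) = pi * r^2

old = area   # "save" the original: this is the very same generic function object

# Re-defining under the same name ADDS a method instead of replacing `area`
@eval function area(args...)
    return $old(args...)   # old === area, so this calls back into `area`
end

# The more specific Float64 method is still selected, so no wrapping happens.
# A call such as area(1) would hit the Vararg method and recurse until a StackOverflowError.
a = area(1.0)
```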


#2

Do you want to profile all methods of a function? Or just one particular method?

Maybe a macro like https://github.com/simonster/Memoize.jl/blob/master/src/Memoize.jl would work for you? @memoize takes an arbitrary function and stores its arguments and results in a lookup table; you want to take an arbitrary function and record the total time it has run in a lookup table.
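A minimal sketch of such a macro, written by analogy with @memoize rather than taken from Memoize.jl. The names @profiled, TIMINGS and record_time! are hypothetical, and it assumes Julia 1.x and a long-form function name(args...) ... end definition without a where clause or return-type annotation:

```julia
# Hypothetical global table: function name => (total seconds, number of calls)
const TIMINGS = Dict{Symbol,Tuple{Float64,Int}}()

function record_time!(name::Symbol, dt::Float64)
    total, n = get(TIMINGS, name, (0.0, 0))
    TIMINGS[name] = (total + dt, n + 1)
    return nothing
end

macro profiled(def)
    (def isa Expr && def.head === :function) ||
        error("@profiled expects a `function name(args...) ... end` definition")
    sig, body = def.args[1], def.args[2]
    key = QuoteNode(sig.args[1])            # the function name, e.g. :square
    t0, res = gensym("t0"), gensym("res")   # fresh names cannot clash with the body's locals
    newbody = quote
        $t0 = time()
        $res = $body                        # note: an early `return` in the body skips the timing
        $(record_time!)($key, time() - $t0)
        $res
    end
    return esc(Expr(:function, sig, newbody))
end

@profiled function square(x)
    x * x
end

for i in 1:5
    square(i)
end
```

As with @memoize, the macro only has to be added to the definition, not to each call site.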


#3

Right, macros are the obvious solution, but that involves manually modifying the source code to insert the macro. In Python/Lisp, I can use the above technique on arbitrary already-compiled functions.


#4

You would just have to add it to the definition, not to every call site. But I guess if you are trying to add it to some Base functions, then that wouldn’t work for you.

There is the methods function, which gives you the list of a function's methods. For example:

julia> methods(sum)
# 16 methods for generic function "sum":
sum(::Base.#abs, x::AbstractSparseArray{Tv,Ti,1} where Ti where Tv) in Base.SparseArrays at sparse/sparsevector.jl:1386
sum(::Base.#abs2, x::AbstractSparseArray{Tv,Ti,1} where Ti where Tv) in Base.SparseArrays at sparse/sparsevector.jl:1386
sum(x::Tuple{Any,Vararg{Any,N} where N}) in Base at tuple.jl:323
sum(r::StepRangeLen) in Base at twiceprecision.jl:279
sum(r::Base.Use_StepRangeLen_Instead) in Base at deprecated.jl:1252
sum(r::Range{#s45} where #s45<:Real) in Base at range.jl:851
sum(f::Union{Function, Type}, a) in Base at reduce.jl:347
sum(B::BitArray) in Base at bitarray.jl:1825
sum(A::BitArray, region) in Base at bitarray.jl:1824
sum(x::AbstractSparseArray{Tv,Ti,1} where Ti where Tv) in Base.SparseArrays at sparse/sparsevector.jl:1361
sum(a::AbstractArray{Bool,N} where N) in Base at reduce.jl:360
sum(f::Function, A::AbstractArray, region) in Base at reducedim.jl:570
sum(arr::AbstractArray{BigInt,N} where N) in Base.GMP at gmp.jl:520
sum(arr::AbstractArray{BigFloat,N} where N) in Base.MPFR at mpfr.jl:623
sum(A::AbstractArray, region) in Base at reducedim.jl:572
sum(a) in Base at reduce.jl:359

julia> methods(map)
# 47 methods for generic function "map":
map(::Base.#zero, A::BitArray) in Base at bitarray.jl:1860
map(::Base.#one, A::BitArray) in Base at bitarray.jl:1861
map(::Base.#identity, A::BitArray) in Base at bitarray.jl:1862
map(::Base.#<=, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Base.#==, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Base.#<, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Base.#>, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Type{T}, r::StepRange) where T<:Real in Base at abstractarray.jl:851
map(::Type{T}, r::UnitRange) where T<:Real in Base at abstractarray.jl:852
map(::Type{T}, r::StepRangeLen) where T<:AbstractFloat in Base at abstractarray.jl:853
map(::Type{T}, r::LinSpace) where T<:AbstractFloat in Base at abstractarray.jl:855
map(::Union{Base.#!, Base.#~}, A::BitArray) in Base at bitarray.jl:1859
map(::Union{Base.#&, Base.#*, Base.#min}, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Union{Base.#max, Base.#|}, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Union{Base.#!=, Base.#xor}, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(::Union{Base.#>=, Base.#^}, A::BitArray, B::BitArray) in Base at bitarray.jl:1877
map(f::Function, bi::Base.LibGit2.GitBranchIter) in Base.LibGit2 at libgit2/reference.jl:336
map(f::Function, walker::Base.LibGit2.GitRevWalker; oid, range, by, rev, count) in Base.LibGit2 at libgit2/walker.jl:59
map(f, v::SimpleVector) in Base at essentials.jl:274
map(f, t::Tuple{}) in Base at tuple.jl:157
map(f, t::Tuple{}, s::Tuple{}) in Base at tuple.jl:176
map(f::Tf, A::SparseVector) where Tf in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:68
map(f::Tf, A::SparseMatrixCSC) where Tf in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:69
map(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf, N} in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:70
map(f::Tf, A::Union{SparseMatrixCSC, SparseVector}, Bs::Vararg{Union{SparseMatrixCSC, SparseVector},N}) where {Tf, N} in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:72
map(f::Tf, A::Union{Bidiagonal, Diagonal, SymTridiagonal, Tridiagonal}) where Tf in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:1121
map(f::Tf, A::Union{Bidiagonal, Diagonal, SparseMatrixCSC, SymTridiagonal, Tridiagonal}, Bs::Vararg{Union{Bidiagonal, Diagonal, SparseMatrixCSC, SymTridiagonal, Tridiagonal},N}) where {Tf, N} in Base.SparseArrays.HigherOrderFns at sparse/higherorderfns.jl:1122
map(f) in Base at abstractarray.jl:1930
map(f, ::Tuple{}...) in Base at tuple.jl:194
map(f, t::Tuple{Any}) in Base at tuple.jl:158
map(f, t::Tuple{Any}, s::Tuple{Any}) in Base at tuple.jl:177
map(f, t::Tuple{Any,Any}) in Base at tuple.jl:159
map(f, t::Tuple{Any,Any}, s::Tuple{Any,Any}) in Base at tuple.jl:178
map(f, t::Tuple{Any,Any,Any}) in Base at tuple.jl:160
map(f, t::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N) in Base at tuple.jl:168
map(f, t::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N, s::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N) in Base at tuple.jl:184
map(f, t1::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N, t2::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N, ts::Tuple{Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Any,Vararg{Any,N}} where N...) in Base at tuple.jl:200
map(f, t::Tuple) in Base at tuple.jl:161
map(f, t::Tuple, s::Tuple) in Base at tuple.jl:180
map(f, t1::Tuple, t2::Tuple, ts::Tuple...) in Base at tuple.jl:196
map(f, x::Number, ys::Number...) in Base at number.jl:123
map(f, rowvecs::RowVector...) in Base.LinAlg at linalg/rowvector.jl:151
map(f, A::Union{AbstractArray, AbstractSet, Associative}) in Base at abstractarray.jl:1865
map(f, s::AbstractString) in Base at strings/basic.jl:449
map(f, x::Nullable{T}) where T in Base at nullable.jl:281
map(f, A) in Base at abstractarray.jl:1888
map(f, iters...) in Base at abstractarray.jl:1931
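As a small sketch of working with that list programmatically (assuming Julia 1.x, where each entry is a Method object; the function h is hypothetical), you can count the methods and read each one's signature type. This only inspects the method table; it does not rebind it under another name:

```julia
# Hypothetical function with two methods
h(x::Int) = 1
h(x::String) = 2

ms = methods(h)
println(length(ms), " methods")
for m in ms
    # m.sig is the full signature type, with typeof(h) as its first parameter
    println(m.sig, "  defined at ", m.file, ":", m.line)
end
sigs = [m.sig for m in ms]
```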

#5

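I think this works:

function profile(f::Function)
    time_store = Ref((0.0, 0)) # accumulator and counter
    (args...) -> begin
        time0 = time()
        res = f(args...)
        dt = time() - time0
        # store it somewhere
        time_store[] = (time_store[][1] + dt, time_store[][2] + 1)
        return res
    end
end

# in Julia 1.x, sum = profile(sum) raises an error because sum is imported from Base,
# so bind the wrapper to a new name
psum = profile(sum)

for i = 1:10
    psum([1, 2, 3])
end

# the closure's captured variables are accessible as fields
total_time = psum.time_store[][1]
n = psum.time_store[][2]
avg = total_time / n

# un-profile: recover the original function
original_sum = psum.f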

#6

Ah, this works by creating a new function that takes any number of arguments of any type. Will keyword arguments work with this approach?


#7

Adding keyword-argument support requires a slight modification, because the compiler adds another layer of indirection to handle keyword arguments.

Another issue is that you cannot replace the names of functions that were defined in global scope, only the names of imported (import or using) functions.

I use Juno’s structure method, which is a very convenient way to examine the contents of an object without going through the print and display functions: just raw data and internal naming.

Try calling structure(g) from within Juno on the code below, before and after the un-profiling statement at the end.

Here is the code:

module temp
    function g(x; y = 2)
        x * y
    end
    export g
end
using .temp

function profile(f::Function)
    time_store = Ref((0.0,0)) # accumulator and counter
    (args...;kwargs...) -> begin
        time0 = time()
        res = f(args...;kwargs...)
        dt = time() - time0
        # store it somewhere
        time_store.x = (time_store.x[1] + dt , time_store.x[2]+1)
        return res
    end
end

g = profile(g)

for i=1:10
    g(1,y=i)
end

ts = getfield(g,1).contents.time_store.x
total_time = ts[1]
n = ts[2]
avg = total_time/n

#un-profile
g = getfield(g,1).contents.f
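A sketch of a variant that avoids reaching into closure internals such as getfield(g,1).contents, assuming Julia 1.x: a callable struct keeps the wrapped function and the counters in ordinary named fields and forwards keyword arguments explicitly. Profiled and scale are hypothetical names, and, as in the code above, the wrapper still has to be bound to a name that callers actually use:

```julia
# Callable wrapper whose timings live in ordinary, named mutable fields
mutable struct Profiled{F}
    f::F             # the wrapped (original) function
    total::Float64   # accumulated seconds
    calls::Int       # number of calls
end
Profiled(f) = Profiled(f, 0.0, 0)

# Calling a Profiled forwards positional and keyword arguments to the original
function (p::Profiled)(args...; kwargs...)
    t0 = time()
    res = p.f(args...; kwargs...)
    p.total += time() - t0
    p.calls += 1
    return res
end

scale(x; y = 2) = x * y   # hypothetical function with a keyword argument

pscale = Profiled(scale)
for i in 1:10
    pscale(1, y = i)
end

avg = pscale.total / pscale.calls
original = pscale.f       # "un-profile"
```

The design choice here is that the timing state is part of a declared type, so reading it does not depend on how the compiler lowers closures or keyword arguments.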

#8

I’m writing a tracing library (like the Common Lisp TRACE). It records the call graph, including argument and return values, of calls to arbitrary functions in arbitrary packages.

To achieve this, I have to get the module’s file via an unexported Base function, parse it to find all of its include statements and function definitions, then create an alternate version of each function that stores its arguments and return values somehow. The new definitions are temporarily activated using eval, then reverted. That’s a lot more complicated (and brittle) than the Lisp equivalent. Is there a better way?
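For comparison, here is a minimal sketch of the per-function part only, assuming Julia 1.x and a wrapper applied by hand instead of generated from parsed source; traced, TRACE, DEPTH, inner and outer are hypothetical names. Each entry records the nesting depth, the function, its arguments and its return value, which is enough to rebuild the call graph. Entries are pushed when a call returns, so callees appear before their callers:

```julia
# Hypothetical global trace log and current nesting depth
const TRACE = Any[]
const DEPTH = Ref(0)

function traced(f)
    return (args...; kwargs...) -> begin
        depth = DEPTH[]
        DEPTH[] = depth + 1
        try
            res = f(args...; kwargs...)
            push!(TRACE, (depth = depth, f = f, args = args, result = res))
            return res
        finally
            DEPTH[] = depth   # restore the depth even if f throws
        end
    end
end

inner(x) = x + 1
tinner = traced(inner)
outer(x) = 2 * tinner(x)   # outer calls the traced version of inner
touter = traced(outer)

touter(3)
```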


#9

Sounds like you need the AST representation of the function you are about to trace…
Then maybe replace every call expression (in Julia's AST, an Expr with head :call) with a call to your version of the function that logs time.

You can keep a dictionary with the original functions as keys and their profiled versions as values, to cache the generated profiling wrappers.
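A minimal sketch of both ideas, assuming Julia 1.x and that you already hold the expression to trace (obtaining a function's AST is the hard part and is not shown); instrument, wrapped, WRAPPERS and TIMES are hypothetical names. instrument rewrites every call f(args...) into wrapped(f)(args...), and wrapped caches one timing wrapper per function in a dictionary. It is meant for plain expressions: applied to a definition like f(x) = ..., it would also rewrite the signature.

```julia
# Hypothetical caches: original function => timing wrapper, and => (seconds, calls)
const WRAPPERS = IdDict{Any,Any}()
const TIMES = IdDict{Any,Tuple{Float64,Int}}()

# Return the cached timing wrapper for f, creating it on first use
function wrapped(f)
    get!(WRAPPERS, f) do
        (args...; kwargs...) -> begin
            t0 = time()
            res = f(args...; kwargs...)
            total, n = get(TIMES, f, (0.0, 0))
            TIMES[f] = (total + (time() - t0), n + 1)
            res
        end
    end
end

# Recursively rewrite every f(args...) into wrapped(f)(args...)
instrument(x) = x   # symbols, literals, line numbers: unchanged
function instrument(ex::Expr)
    args = map(instrument, ex.args)
    if ex.head === :call
        return Expr(:call, Expr(:call, wrapped, args[1]), args[2:end]...)
    end
    return Expr(ex.head, args...)
end

ex = instrument(:(sqrt(abs(-16.0)) + 1))
result = eval(ex)
```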


#10

I think all of the questions in this thread might be answered soon by Jarrett Revels’s in-progress Cassette.jl: https://github.com/jrevels/Cassette.jl/issues/1. It will allow you to intercept arbitrary method invocations with custom code and can implement tracers/debuggers/profilers/autodiff tools/etc.