Dear All,
I am trying to understand the role of generated functions in Zygote / Cassette, as want to properly understand “overdubbing design pattern”. I decided to test the idea by writing a profiler, which will log the start and end of each function called. This is kind of an alternative to the sampling profiler available in Julia. Note that the implementation below is not a package (and likely never going to be), it is something I decided to implement to educate myself. The full code is
"Full code
@generated function overdub(ctx::Context, f::F, args...) where {F}
ci = retrieve_code_info((F, args...))
slot_vars = Dict(enumerate(ci.slotnames))
# ssa_vars = Dict(i => gensym(:left) for i in 1:length(ci.code))
ssa_vars = Dict(i => Symbol(:L, i) for i in 1:length(ci.code))
used = assigned_vars(ci.code) |> distinct
exprs = []
for i in 1:length(args)
push!(exprs, Expr(:(=), ci.slotnames[i+1], :(args[$(i)])))
end
for (i, ex) in enumerate(ci.code)
ex = rename_args(ex, slot_vars, ssa_vars)
if ex isa Core.ReturnNode
push!(exprs, Expr(:return, ex.val))
continue
end
if timable(ex)
fname = exportname(ex)
fname = :(Symbol($(fname)))
push!(exprs, Expr(:call, :push!, :to, :(:start), fname))
ex = overdubbable(ex) ? Expr(:call, :overdub, :ctx, ex.args...) : ex
ex = i ∈ used ? Expr(:(=) , ssa_vars[i], ex) : ex
push!(exprs, ex)
push!(exprs, Expr(:call, :push!, :to, :(:stop), fname))
else
ex = i ∈ used ? Expr(:(=) , ssa_vars[i], ex) : ex
push!(exprs, ex)
end
end
r = Expr(:block, exprs...)
@show r
# println(" ")
r
end
For debugging purposes I have not hygienized inserted symbols and also I print the generated code, which is against the generated functions not to have side effects, but it is for debugging and final code will not have it. The above code does not recursively descend into the function calls, and it produces the result I expect, as
julia> function foo(x, y)
z = x * y
z + sin(y)
end
foo (generic function with 1 method)
julia> reset!(to)
0
julia> overdub(Context(), foo, 1.0, 1.0)
r = quote
x = args[1]
y = args[2]
z = x * y
L2 = z
push!(to, :start, Symbol(Main.sin))
L3 = Main.sin(y)
push!(to, :stop, Symbol(Main.sin))
push!(to, :start, Symbol(Main.:+))
L4 = L2 + L3
push!(to, :stop, Symbol(Main.:+))
return L4
end
1.8414709848078965
julia> to
0.0 start sin
1.9073486328125e-6 stop sin
3.814697265625e-6 start +
3.814697265625e-6 stop +
If I know enable the recursion by redefining overdubbable
as
function overdubbable(ex::Expr)
ex.head != :call && return(false)
length(ex.args) < 2 && return(false)
(ex.args[1] isa Core.IntrinsicFunction) && return(false)
return(true)
end
the overdubbing goes nuts and I am apparently overdubbing a function that I have already overdubbed.
A recursive overdubbing that went nuts
julia> overdub(Context(), foo, 1.0, 1.0)
r = quote
x = args[1]
y = args[2]
z = x * y
L2 = z
push!(to, :start, Symbol(Main.sin))
L3 = overdub(ctx, Main.sin, y)
push!(to, :stop, Symbol(Main.sin))
push!(to, :start, Symbol(Main.:+))
L4 = overdub(ctx, Main.:+, L2, L3)
push!(to, :stop, Symbol(Main.:+))
return L4
end
r = quote
x = args[1]
Core.NewvarNode(:(_3))
Core.NewvarNode(:(_4))
Core.NewvarNode(:(_5))
absx = Base.Math.abs(x)
L5 = absx
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L6 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.pi)
push!(to, :stop, Symbol((Expr(:static_parameter, 1))))
push!(to, :start, Symbol(Base.Math.:/))
L7 = overdub(ctx, Base.Math.:/, L6, 4)
push!(to, :stop, Symbol(Base.Math.:/))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L5, L7)
push!(to, :stop, Symbol(Base.Math.:<))
goto %18 if not %8
L10 = absx
push!(to, :start, Symbol(Base.Math.eps))
L11 = overdub(ctx, Base.Math.eps, (Expr(:static_parameter, 1)))
push!(to, :stop, Symbol(Base.Math.eps))
push!(to, :start, Symbol(Base.Math.sqrt))
L12 = overdub(ctx, Base.Math.sqrt, L11)
push!(to, :stop, Symbol(Base.Math.sqrt))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L10, L12)
push!(to, :stop, Symbol(Base.Math.:<))
goto %16 if not %13
return x
push!(to, :start, Symbol(Base.Math.sin_kernel))
L16 = overdub(ctx, Base.Math.sin_kernel, x)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L16
push!(to, :start, Symbol(Base.Math.isnan))
overdub(ctx, Base.Math.isnan, x)
push!(to, :stop, Symbol(Base.Math.isnan))
goto %22 if not %18
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L20 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.NaN)
push!(to, :stop, Symbol($(Expr(:static_parameter, 1))))
return L20
push!(to, :start, Symbol(Base.Math.isinf))
overdub(ctx, Base.Math.isinf, x)
push!(to, :stop, Symbol(Base.Math.isinf))
goto %25 if not %22
push!(to, :start, Symbol(Base.Math.sin_domain_error))
overdub(ctx, Base.Math.sin_domain_error, x)
push!(to, :stop, Symbol(Base.Math.sin_domain_error))
push!(to, :start, Symbol(Base.Math.rem_pio2_kernel))
L25 = overdub(ctx, Base.Math.rem_pio2_kernel, x)
push!(to, :stop, Symbol(Base.Math.rem_pio2_kernel))
push!(to, :start, Symbol(Base.indexed_iterate))
L26 = overdub(ctx, Base.indexed_iterate, L25, 1)
push!(to, :stop, Symbol(Base.indexed_iterate))
n = Core.getfield(L26, 1)
var"" = Core.getfield(L26, 2)
push!(to, :start, Symbol(Base.indexed_iterate))
L29 = overdub(ctx, Base.indexed_iterate, L25, 2, var"")
push!(to, :stop, Symbol(Base.indexed_iterate))
y = Core.getfield(L29, 1)
n = n & 3
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 0)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %36 if not %32
push!(to, :start, Symbol(Base.Math.sin_kernel))
L34 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L34
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 1)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %40 if not %36
push!(to, :start, Symbol(Base.Math.cos_kernel))
L38 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
return L38
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 2)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %45 if not %40
push!(to, :start, Symbol(Base.Math.sin_kernel))
L42 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L43 = overdub(ctx, Base.Math.:-, L42)
push!(to, :stop, Symbol(Base.Math.:-))
return L43
push!(to, :start, Symbol(Base.Math.cos_kernel))
L45 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L46 = overdub(ctx, Base.Math.:-, L45)
push!(to, :stop, Symbol(Base.Math.:-))
return L46
end
r = quote
x = args[1]
Core.NewvarNode(:(_3))
Core.NewvarNode(:(_4))
Core.NewvarNode(:(_5))
absx = Base.Math.abs(x)
L5 = absx
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L6 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.pi)
push!(to, :stop, Symbol((Expr(:static_parameter, 1))))
push!(to, :start, Symbol(Base.Math.:/))
L7 = overdub(ctx, Base.Math.:/, L6, 4)
push!(to, :stop, Symbol(Base.Math.:/))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L5, L7)
push!(to, :stop, Symbol(Base.Math.:<))
goto %18 if not %8
L10 = absx
push!(to, :start, Symbol(Base.Math.eps))
L11 = overdub(ctx, Base.Math.eps, (Expr(:static_parameter, 1)))
push!(to, :stop, Symbol(Base.Math.eps))
push!(to, :start, Symbol(Base.Math.sqrt))
L12 = overdub(ctx, Base.Math.sqrt, L11)
push!(to, :stop, Symbol(Base.Math.sqrt))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L10, L12)
push!(to, :stop, Symbol(Base.Math.:<))
goto %16 if not %13
return x
push!(to, :start, Symbol(Base.Math.sin_kernel))
L16 = overdub(ctx, Base.Math.sin_kernel, x)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L16
push!(to, :start, Symbol(Base.Math.isnan))
overdub(ctx, Base.Math.isnan, x)
push!(to, :stop, Symbol(Base.Math.isnan))
goto %22 if not %18
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L20 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.NaN)
push!(to, :stop, Symbol($(Expr(:static_parameter, 1))))
return L20
push!(to, :start, Symbol(Base.Math.isinf))
overdub(ctx, Base.Math.isinf, x)
push!(to, :stop, Symbol(Base.Math.isinf))
goto %25 if not %22
push!(to, :start, Symbol(Base.Math.sin_domain_error))
overdub(ctx, Base.Math.sin_domain_error, x)
push!(to, :stop, Symbol(Base.Math.sin_domain_error))
push!(to, :start, Symbol(Base.Math.rem_pio2_kernel))
L25 = overdub(ctx, Base.Math.rem_pio2_kernel, x)
push!(to, :stop, Symbol(Base.Math.rem_pio2_kernel))
push!(to, :start, Symbol(Base.indexed_iterate))
L26 = overdub(ctx, Base.indexed_iterate, L25, 1)
push!(to, :stop, Symbol(Base.indexed_iterate))
n = Core.getfield(L26, 1)
var"" = Core.getfield(L26, 2)
push!(to, :start, Symbol(Base.indexed_iterate))
L29 = overdub(ctx, Base.indexed_iterate, L25, 2, var"")
push!(to, :stop, Symbol(Base.indexed_iterate))
y = Core.getfield(L29, 1)
n = n & 3
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 0)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %36 if not %32
push!(to, :start, Symbol(Base.Math.sin_kernel))
L34 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L34
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 1)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %40 if not %36
push!(to, :start, Symbol(Base.Math.cos_kernel))
L38 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
return L38
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 2)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %45 if not %40
push!(to, :start, Symbol(Base.Math.sin_kernel))
L42 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L43 = overdub(ctx, Base.Math.:-, L42)
push!(to, :stop, Symbol(Base.Math.:-))
return L43
push!(to, :start, Symbol(Base.Math.cos_kernel))
L45 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L46 = overdub(ctx, Base.Math.:-, L45)
push!(to, :stop, Symbol(Base.Math.:-))
return L46
end
r = quote
x = args[1]
Core.NewvarNode(:(_3))
Core.NewvarNode(:(_4))
Core.NewvarNode(:(_5))
absx = Base.Math.abs(x)
L5 = absx
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L6 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.pi)
push!(to, :stop, Symbol((Expr(:static_parameter, 1))))
push!(to, :start, Symbol(Base.Math.:/))
L7 = overdub(ctx, Base.Math.:/, L6, 4)
push!(to, :stop, Symbol(Base.Math.:/))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L5, L7)
push!(to, :stop, Symbol(Base.Math.:<))
goto %18 if not %8
L10 = absx
push!(to, :start, Symbol(Base.Math.eps))
L11 = overdub(ctx, Base.Math.eps, (Expr(:static_parameter, 1)))
push!(to, :stop, Symbol(Base.Math.eps))
push!(to, :start, Symbol(Base.Math.sqrt))
L12 = overdub(ctx, Base.Math.sqrt, L11)
push!(to, :stop, Symbol(Base.Math.sqrt))
push!(to, :start, Symbol(Base.Math.:<))
overdub(ctx, Base.Math.:<, L10, L12)
push!(to, :stop, Symbol(Base.Math.:<))
goto %16 if not %13
return x
push!(to, :start, Symbol(Base.Math.sin_kernel))
L16 = overdub(ctx, Base.Math.sin_kernel, x)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L16
push!(to, :start, Symbol(Base.Math.isnan))
overdub(ctx, Base.Math.isnan, x)
push!(to, :stop, Symbol(Base.Math.isnan))
goto %22 if not %18
push!(to, :start, Symbol((Expr(:static_parameter, 1))))
L20 = overdub(ctx, (Expr(:static_parameter, 1)), Base.Math.NaN)
push!(to, :stop, Symbol($(Expr(:static_parameter, 1))))
return L20
push!(to, :start, Symbol(Base.Math.isinf))
overdub(ctx, Base.Math.isinf, x)
push!(to, :stop, Symbol(Base.Math.isinf))
goto %25 if not %22
push!(to, :start, Symbol(Base.Math.sin_domain_error))
overdub(ctx, Base.Math.sin_domain_error, x)
push!(to, :stop, Symbol(Base.Math.sin_domain_error))
push!(to, :start, Symbol(Base.Math.rem_pio2_kernel))
L25 = overdub(ctx, Base.Math.rem_pio2_kernel, x)
push!(to, :stop, Symbol(Base.Math.rem_pio2_kernel))
push!(to, :start, Symbol(Base.indexed_iterate))
L26 = overdub(ctx, Base.indexed_iterate, L25, 1)
push!(to, :stop, Symbol(Base.indexed_iterate))
n = Core.getfield(L26, 1)
var"" = Core.getfield(L26, 2)
push!(to, :start, Symbol(Base.indexed_iterate))
L29 = overdub(ctx, Base.indexed_iterate, L25, 2, var"")
push!(to, :stop, Symbol(Base.indexed_iterate))
y = Core.getfield(L29, 1)
n = n & 3
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 0)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %36 if not %32
push!(to, :start, Symbol(Base.Math.sin_kernel))
L34 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
return L34
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 1)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %40 if not %36
push!(to, :start, Symbol(Base.Math.cos_kernel))
L38 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
return L38
push!(to, :start, Symbol(Base.Math.:(==)))
overdub(ctx, Base.Math.:(==), n, 2)
push!(to, :stop, Symbol(Base.Math.:(==)))
goto %45 if not %40
push!(to, :start, Symbol(Base.Math.sin_kernel))
L42 = overdub(ctx, Base.Math.sin_kernel, y)
push!(to, :stop, Symbol(Base.Math.sin_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L43 = overdub(ctx, Base.Math.:-, L42)
push!(to, :stop, Symbol(Base.Math.:-))
return L43
push!(to, :start, Symbol(Base.Math.cos_kernel))
L45 = overdub(ctx, Base.Math.cos_kernel, y)
push!(to, :stop, Symbol(Base.Math.cos_kernel))
push!(to, :start, Symbol(Base.Math.:-))
L46 = overdub(ctx, Base.Math.:-, L45)
push!(to, :stop, Symbol(Base.Math.:-))
return L46
end
Can someone help me by explaining what I doing wrong or which part went nuts? I would like to be able to recursively overdubb functions that are called within overdubbed functions (for example here the function sin
. Also, I am also totally surprised where the nodes Core.NewvarNode(:(_3))
came from.
Note that I know about IRTools and Cassette, but as I have said, I would like to understand the mechanics of those packages.
Thank you very much in advance.
Tomas Pevny