Improve performance of computation graph evaluation

Hi,

NaiveNASlib.jl uses a run-of-the-mill representation of a computation graph composed of AbstractVertex structs, where each vertex is associated with an operation (e.g. a neural network layer) as well as its input vertices.

TL;DR: are there any cool/efficient functional/Julian ways to evaluate a computation graph with this structure?

Background and current thinking:

As actually evaluating the graph is not really what the package is about, I initially just left it at the following simple recursive function:

```julia
function output!(memo::Dict{AbstractVertex, Any}, v::AbstractVertex)
    # Calculate outputs which are not already calculated (memo maps vertex => output)
    return get!(memo, v) do
        inpt = map(iv -> output!(memo, iv), inputs(v))
        v(inpt...)
    end
end
```
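For readers without NaiveNASlib at hand, here is a minimal, self-contained sketch of the same memoized recursion. `MiniVertex`, `MiniInput` and `MiniOp` are hypothetical stand-ins, not NaiveNASlib's real types. An `IdDict` is used so that vertices are keyed by identity. The example shows that a vertex feeding several others (`a` below) is computed only once.

```julia
# Hypothetical stand-ins for NaiveNASlib's vertex types (not its real API).
# Declared `mutable` so that IdDict compares vertices by identity.
abstract type MiniVertex end

mutable struct MiniInput <: MiniVertex
    name::Symbol
end

mutable struct MiniOp <: MiniVertex
    op::Function
    inputs::Vector{MiniVertex}
end

inputs(v::MiniOp) = v.inputs
(v::MiniOp)(args...) = v.op(args...)

# Memoized recursion: each vertex is computed once and stored in memo.
function output!(memo::AbstractDict, v::MiniOp)
    return get!(memo, v) do
        inpt = map(iv -> output!(memo, iv), inputs(v))
        v(inpt...)
    end
end
# Input values must be seeded in memo; a missing one throws a KeyError.
output!(memo::AbstractDict, v::MiniInput) = memo[v]

x = MiniInput(:x)
a = MiniOp(t -> t + 1, [x])   # a = x + 1
b = MiniOp(t -> 2t, [a])      # b = 2a
c = MiniOp(+, [a, b])         # c = a + b; a is looked up in memo, not recomputed
memo = IdDict{MiniVertex,Any}(x => 3)
output!(memo, c)  # 12
```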

Now, when trying to differentiate through this with Zygote, I first ran into the problem that get! hits non-differentiable code. This was easy to work around, either by defining an adjoint for get! or by rewriting the function to not use get!. I then ran into the harder issue that performance was very poor. The first evaluation of a graph generally takes upwards of a minute, and in extreme cases it takes hours or stalls completely.

I have tried Juno's profiler. For “mild cases” I could not see anything special in the flame chart (comparing a simple linear graph with the equivalent Chain). In worse cases I got the message that the profiling buffer was full, and the output I got looked just like the “mild cases”.

Current workaround is to do the recursion before execution and handle the actual execution with a loop:

```julia
Flux.Zygote.@adjoint function output!(memo::Dict{AbstractVertex, Any}, v::AbstractVertex)
    return Flux.Zygote.pullback(__context__, output_loop!, memo, v)
end

function output_loop!(memo, v)
    # The flatten call is non-differentiable as it does array mutation. It is also not part of the computation we want to differentiate.
    vs = nograd() do
        # flatten returns all input ancestors of v in topological order.
        # We also provide all vertices for which we already have the output in memo
        # so we don't do unnecessary calculations (recursion will stop when it hits an existing vertex).
        flatten(v, collect(AbstractVertex, keys(memo)))[length(memo)+1:end]
    end

    for vn in vs
        inpt = map(iv -> memo[iv], inputs(vn))
        memo[vn] = vn(inpt...)
    end
    return memo[v]
end
```
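The same loop can be sketched without Zygote or NaiveNASlib. This sketch assumes a hypothetical `GVertex` type and a hypothetical `topological_order` helper that stands in for `flatten`. As in the `flatten` call above, vertices whose outputs are already in `memo` stop the traversal. Evaluation itself is then a flat loop with one lookup per input.

```julia
# Hypothetical minimal vertex: an operation and its input vertices.
# `mutable` so that vertices are compared by identity in IdDict/IdSet.
mutable struct GVertex
    op::Function
    inputs::Vector{GVertex}
end

# Depth-first post-order: every vertex comes after all of its inputs.
# Vertices already present in memo are skipped (plays the role of flatten above).
function topological_order(v::GVertex, memo::AbstractDict)
    order = GVertex[]
    seen = Base.IdSet{GVertex}()
    function visit(u)
        (haskey(memo, u) || u in seen) && return
        push!(seen, u)
        foreach(visit, u.inputs)
        push!(order, u)
    end
    visit(v)
    return order
end

# Loop-based evaluation: no recursion at evaluation time.
function output_loop!(memo::AbstractDict, v::GVertex)
    for vn in topological_order(v, memo)
        inpt = map(iv -> memo[iv], vn.inputs)
        memo[vn] = vn.op(inpt...)
    end
    return memo[v]
end

x = GVertex(identity, GVertex[])   # input vertex; its value is seeded in memo
a = GVertex(t -> t + 1, [x])
b = GVertex(t -> 2t, [a])
c = GVertex(+, [a, b])
output_loop!(IdDict{GVertex,Any}(x => 3), c)  # 12
```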

This for some reason seems to get rid of the extremely long first-call times I see, although I don’t know why exactly.

Now this might be OOP damage, but whenever I need to represent something with a Dict, I always get a feeling that there is some aspect of the domain which I have not captured correctly.

I have peeked at a few other DL libraries with graphs (TensorFlow, Deeplearning4j, Owl), and they all seem to have a nullable mutable output field in their vertices. To me that is an even worse design smell than the Dict. Although it is probably more performant, I would rather take the performance hit of the Dict lookup than go this way.

I have also briefly looked at Memoize.jl, but it also seems to use the Dict approach. In addition, it was not clear to me how to ensure that all vertices taking a given vertex as input see the same memoized function.
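One way to keep the Dict out of the evaluation itself is to give every vertex a position in the topological order once, up front, and then store outputs in a plain vector indexed by that position. The sketch below assumes this position encoding. `Step` and `evaluate` are hypothetical names. Note that it still mutates a vector (`push!`), which Zygote does not support, so this only addresses the data-structure question.

```julia
# Hypothetical "compiled" graph: vertices in topological order, each referring
# to its inputs by position instead of by vertex object.
struct Step
    op::Function
    inputs::Vector{Int}  # positions of graph inputs or earlier steps
end

# Positions 1:length(args) hold the graph inputs; later positions hold step outputs.
function evaluate(steps::Vector{Step}, args...)
    outs = Any[args...]
    for s in steps
        push!(outs, s.op(outs[s.inputs]...))
    end
    return outs[end]
end

# Graph: a = x + 1 (pos 2), b = 2a (pos 3), c = a + b (pos 4)
steps = [Step(t -> t + 1, [1]), Step(t -> 2t, [2]), Step(+, [2, 3])]
evaluate(steps, 3)  # 12
```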

One thing I considered is that it is trivial to generate an expression which evaluates to an “unrolled” function that evaluates the graph, e.g. something like this:

```julia
function (in1, in2, …)
    out1 = vertex1_op(in1)
    out2 = vertex2_op(out1)
    out3 = vertex3_op(in2, out1)
    out4 = vertex4_op(out3, out2)
    # etc…
end
```

In other words, write a simple “compiler” for computation graphs defined in the language of NaiveNASlib.
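Generating such an unrolled function can be sketched as follows. The sketch assumes the graph has already been flattened into a list of operations plus the input positions of each one. `unrolled_expr` and this position encoding are hypothetical. Function objects are spliced directly into the expression, so the generated code needs no names for the operations.

```julia
# ops[k] is the k-th operation in topological order; inputs[k] lists positions,
# where 1:nin are the graph inputs and nin+k is the output of ops[k].
function unrolled_expr(ops::Vector, inputs::Vector{Vector{Int}}, nin::Int)
    vars = [Symbol(:in, i) for i in 1:nin]
    body = Expr[]
    for (op, idx) in zip(ops, inputs)
        out = Symbol(:out, length(vars) - nin + 1)
        # e.g. :(out3 = +(out1, out2)) with the function object spliced in
        push!(body, :($out = $op($(vars[idx]...))))
        push!(vars, out)
    end
    return :(($(vars[1:nin]...),) -> begin
        $(body...)
        $(vars[end])
    end)
end

ops = [t -> t + 1, t -> 2t, +]
ins = [[1], [2], [2, 3]]
ex = unrolled_expr(ops, ins, 1)  # roughly: (in1,) -> (out1 = f1(in1); ...; out3)
f = eval(ex)
```

Calling `f` from a later top-level statement works. Calling it from inside the same function that `eval`ed it runs into the world-age issue.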

However, after encountering and reading up on world-age problems, I abandoned this approach. I don't think it is possible to do this without using invokelatest, which, apart from any performance issues, is non-differentiable, and I don't think it is possible to write an adjoint for it. Is this the correct conclusion w.r.t. metaprogramming in this case?
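The world-age problem and the `invokelatest` escape hatch can be shown in a few lines. `run_generated` is a hypothetical helper. The sketch says nothing about whether Zygote can differentiate through it.

```julia
# Build and run a freshly generated function within a single call.
function run_generated(ex::Expr, args...)
    f = eval(ex)
    # Calling f(args...) directly here can fail with a "method too new"
    # MethodError: f was defined in a newer world than this call is running in.
    # Base.invokelatest performs the call in the latest world instead.
    return Base.invokelatest(f, args...)
end

run_generated(:(x -> 2x + 1), 5)  # 11
```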

I don’t know much about the capabilities of Zygote, but maybe I can provide some context from the old days of ReverseDiff. Long, long ago, ReverseDiff did something like your proposed expression-generation step. It worked, but it had exactly the problems you’ve predicted: the generated expression lived in the wrong world age, and working with it ended up being pretty gross (see https://github.com/JuliaDiff/ReverseDiff.jl/issues/70 ).

The fix was to instead have each instruction in the tape act as a closure over its inputs and any pre-allocated outputs. We could then store that closure (which took no arguments and returned nothing) as a FunctionWrapper{Nothing, Tuple{}}. The important thing is that FunctionWrapper{Nothing, Tuple{}} is a concrete type. Iterating over the tape and calling the wrapped closure for each operation therefore involves no type instability. You can see the PR that changed this behavior here: https://github.com/JuliaDiff/ReverseDiff.jl/pull/71/files

Here’s an example of what I mean. Let’s create a pre-allocated input and output and a tape of operations. Each operation in the tape is a function that takes no inputs or outputs and simply operates over its captured (closed-over) variables:

```julia
julia> function make_tape()
         a = [0, 0, 0]
         b = similar(a)
         op1 = function()
           b .= a .+ 1
         end
         c = similar(b)
         op2 = function()
           c .= b .* 2.0
         end
         a, c, [op1, op2]
       end
make_tape (generic function with 1 method)

julia> in, out, tape = make_tape();
```

We can use this by modifying the input, calling each op in the tape, and then reading the output:

```julia
julia> in .= [1, 2, 3]
3-element Array{Int64,1}:
 1
 2
 3

julia> foreach(tape) do op
         op()
       end

julia> out
3-element Array{Int64,1}:
 4
 6
 8
```

However, as written, the foreach step is inefficient because each element of tape has a different type:

```julia
julia> @code_warntype foreach(op -> op(), tape)
Variables
  #self#::Core.Compiler.Const(foreach, false)
  f::Core.Compiler.Const(getfield(Main, Symbol("##23#24"))(), false)
  itr::Array{Function,1}
  @_4::Union{Nothing, Tuple{Function,Int64}}
  x::Function
```

We can fix this by wrapping each op in a no-argument FunctionWrapper:

```julia
julia> using FunctionWrappers: FunctionWrapper

julia> wrap(op) = FunctionWrapper{Nothing, Tuple{}}(op)
wrap (generic function with 1 method)

julia> wrapped_tape = wrap.(tape)
```

The wrapped_tape can be used to propagate data through the tape just like the original tape:

```julia
julia> in .= [1, 1, 1]
3-element Array{Int64,1}:
 1
 1
 1

julia> foreach(wrapped_tape) do op
         op()
       end

julia> out
3-element Array{Int64,1}:
 4
 4
 4
```

but now our foreach (i.e. the forward pass) is type-stable!

```julia
julia> @code_warntype foreach(op -> op(), wrapped_tape)
Variables
  #self#::Core.Compiler.Const(foreach, false)
  f::Core.Compiler.Const(getfield(Main, Symbol("##27#28"))(), false)
  itr::Array{FunctionWrapper{Nothing,Tuple{}},1}
  @_4::Union{Nothing, Tuple{FunctionWrapper{Nothing,Tuple{}},Int64}}
  x::FunctionWrapper{Nothing,Tuple{}}
```

I’m not sure this will work with Zygote, but I hope this is at least somewhat helpful.
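A Base-only alternative to FunctionWrappers, sketched here as a swapped-in technique (not what ReverseDiff does), is to store the tape as a tuple and walk it recursively. Each closure's concrete type is then part of the tuple type, so every call can be specialized. The trade-off is that each new tape is a new type that gets compiled anew, and very long tuples can be slow to compile. `run_tape` and `make_tuple_tape` are hypothetical names.

```julia
# Run a tuple of zero-argument closures in order; recursion on Base.tail lets the
# compiler see the concrete type of every element.
run_tape(ops::Tuple{}) = nothing
function run_tape(ops::Tuple)
    first(ops)()
    return run_tape(Base.tail(ops))
end

function make_tuple_tape()
    a = [0, 0, 0]
    b = similar(a)
    c = similar(b)
    op1 = () -> (b .= a .+ 1; nothing)
    op2 = () -> (c .= b .* 2; nothing)
    return a, c, (op1, op2)
end

input, output, tape = make_tuple_tape()
input .= [1, 2, 3]
run_tape(tape)
output  # [4, 6, 8]
```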

Have you seen https://github.com/phipsgabler/DynamicComputationGraphs.jl ?

@rdeits: Thanks, this was quite helpful. FunctionWrappers or not, I guess I could try out the tape approach to compilation and see if it has any advantages. It's not crystal clear to me right now how to do it. I think it could be something like traversing the graph in topological order and creating closures, or maybe even something similar to the instructions used in ReverseDiff, and putting them on a “tape” (i.e. in an array or tuple) using some kind of placeholders for inputs/outputs. For it to be compatible with Zygote, this must be done without relying on array mutation, but maybe mutable structs will work.
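One mutation-free variant of the tape is sketched below. The tape is a tuple of operations with input positions. Each step returns a new, longer tuple of outputs instead of writing into an array. `eval_tape` and the position encoding are hypothetical. The sketch only avoids array mutation, which Zygote does not support, and has not been tested with Zygote here.

```julia
# ops[k] is the k-th operation in topological order; inputs[k] is a tuple of
# positions, where 1:length(args) are graph inputs and later ones are op outputs.
function eval_tape(ops::Tuple, inputs::Tuple, args...)
    outs = foldl(zip(ops, inputs); init = args) do acc, (op, idx)
        # Append this op's output by building a new tuple (no mutation).
        (acc..., op(map(i -> acc[i], idx)...))
    end
    return last(outs)
end

ops = (t -> t + 1, t -> 2t, +)
ins = ((1,), (2,), (2, 3))
eval_tape(ops, ins, 3)  # 12
```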

@Ratingulate: I wish I had seen this package when I started. It's not super clear to me what its ambition is, but transforming a seemingly opaque function into a graph that can be handled programmatically is something I guess a non-naive NAS library would do. This use case is somewhat the opposite of what I'm after here. Still, assuming one can modify the graph created by DynamicComputationGraphs without world-age problems, a non-naive NAS library built on it would perhaps not need to bother with evaluating the computation graph.

Not sure I understand all the subtleties of your task, but here are some design decisions from Yota, which seems to be somewhat similar.

First of all, the overhead of Base.invokelatest() might not be a big problem unless your operations are very trivial. As an example, consider in-place multiplication of two 100x100 matrices X and Y:

```julia
julia> @btime mul!(Z, X, Y);
  27.355 μs (0 allocations: 0 bytes)

julia> @btime Base.invokelatest(mul!, Z, X, Y);
  27.349 μs (1 allocation: 32 bytes)
```

Base.invokelatest added so little overhead that the benchmark couldn't even catch the difference. It's also usually much faster than allocating memory for buffers or interpreting graphs. You can still use artificial hacks to avoid dynamic dispatch altogether, by compiling your function before defining another one that uses it. I strongly recommend benchmarking invokelatest thoroughly in your setting before making more complex decisions.
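To repeat the comparison without BenchmarkTools, a rough sketch with `@elapsed` follows. The random 100x100 `Float64` matrices are an assumption, since the setup for X, Y and Z is not shown above. Timings will vary by machine, and `@btime` gives more reliable numbers.

```julia
using LinearAlgebra

X, Y = rand(100, 100), rand(100, 100)   # assumed setup
Z = similar(X)

# Warm up both call paths so compilation is not included in the timings.
mul!(Z, X, Y); Base.invokelatest(mul!, Z, X, Y)

# Minimum over repeated runs as a rough estimate.
t_direct = minimum(@elapsed(mul!(Z, X, Y)) for _ in 1:100)
t_latest = minimum(@elapsed(Base.invokelatest(mul!, Z, X, Y)) for _ in 1:100)
println("direct: $(t_direct) s, invokelatest: $(t_latest) s")
```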

Second, from your description I understand you want to compile a new function and then differentiate it using Zygote. In Yota I use the opposite approach:

  1. Yota.grad(f, args...) traces execution of function f, creating an unrolled tape, similar to what you described.
  2. Differentiation is done on this tape, and gradient operations are recorded onto it. You end up with a single tape doing both the forward (evaluation) and the reverse (differentiation) pass.
  3. The tape is compiled and cached. Base.invokelatest() is used to call the compiled function even though it's defined in a later world than the current one. Since the compiled tape already contains instructions for calculating gradients, there's no need to define adjoints for Base.invokelatest().
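The forward/reverse idea in step 2 can be illustrated with a toy Wengert list for scalar `+` and `*`. This is a hand-written sketch, not Yota's implementation. Yota records the gradient operations onto the tape itself, whereas here the reverse sweep is simply interpreted. `Entry`, `forward` and `backward` are hypothetical names.

```julia
# Toy tape entry: an op and the tape indices of its arguments.
struct Entry
    op::Symbol          # :input, :add or :mul (hypothetical minimal op set)
    args::Vector{Int}
end

# Forward pass: evaluate entries in order; :input entries take the next input.
function forward(tape::Vector{Entry}, inputs::Vector{Float64})
    vals = zeros(length(tape))
    k = 0
    for (i, e) in enumerate(tape)
        if e.op === :input
            k += 1
            vals[i] = inputs[k]
        elseif e.op === :add
            vals[i] = vals[e.args[1]] + vals[e.args[2]]
        elseif e.op === :mul
            vals[i] = vals[e.args[1]] * vals[e.args[2]]
        end
    end
    return vals
end

# Reverse pass: walk the tape backwards, accumulating d(output)/d(entry).
function backward(tape::Vector{Entry}, vals::Vector{Float64})
    grads = zeros(length(tape))
    grads[end] = 1.0  # seed: the last entry is the output
    for i in length(tape):-1:1
        e = tape[i]
        if e.op === :add
            grads[e.args[1]] += grads[i]
            grads[e.args[2]] += grads[i]
        elseif e.op === :mul
            p, q = e.args
            grads[p] += grads[i] * vals[q]
            grads[q] += grads[i] * vals[p]
        end
    end
    return grads
end

# f(x, y) = (x + y) * x
tape = [Entry(:input, Int[]), Entry(:input, Int[]), Entry(:add, [1, 2]), Entry(:mul, [3, 1])]
vals = forward(tape, [2.0, 3.0])  # vals[end] == 10.0
grads = backward(tape, vals)      # df/dx = 2x + y = 7.0, df/dy = x = 2.0
```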

A few words about performance. Yota uses Cassette (or JuliaInterpreter) to trace the function call, which is pretty slow, so the first call may take from a few seconds to a few minutes. I haven't followed Zygote recently, but previous versions used a similar approach and may be subject to the same issue. Subsequent executions of the compiled code, however, are much faster (usually as fast as hand-written code).

I also recommend optimizing the tape/graph before compilation. For example, eliminate common subexpressions, remove unused nodes (e.g. those produced during tracing but not needed in the unrolled graph), use pre-allocated buffers, and use in-place operations where possible. In my experience the difference between an optimized and an unoptimized graph can be as large as 100x.
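Removing unused nodes can be sketched as a backward liveness pass over a tape, followed by remapping argument indices. The sketch covers only this one optimization, not common subexpression elimination. `Node`, `prune` and `run_nodes` are hypothetical names, and zero-argument nodes act as constants.

```julia
# Hypothetical tape node: an operation plus the tape indices of its arguments.
struct Node
    op::Function
    args::Vector{Int}
end

# Keep only the nodes the output (last node) depends on, remapping argument indices.
function prune(tape::Vector{Node})
    live = fill(false, length(tape))
    live[end] = true
    for i in length(tape):-1:1
        live[i] || continue
        for j in tape[i].args
            live[j] = true
        end
    end
    newidx = cumsum(live)  # new position of each live node
    return [Node(n.op, [newidx[j] for j in n.args]) for (n, l) in zip(tape, live) if l]
end

function run_nodes(tape::Vector{Node})
    vals = Vector{Any}(undef, length(tape))
    for (i, n) in enumerate(tape)
        vals[i] = n.op(vals[n.args]...)
    end
    return vals[end]
end

tape_nodes = [
    Node(() -> 2.0, Int[]),   # 1: constant x
    Node(() -> 3.0, Int[]),   # 2: constant y
    Node(sin, [1]),           # 3: traced but never used
    Node(+, [1, 2]),          # 4: x + y
    Node(*, [4, 1]),          # 5: (x + y) * x, the output
]
pruned = prune(tape_nodes)    # 4 nodes; the sin node is gone
run_nodes(pruned)             # 10.0
```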

Hope this is something useful for your package.


@dfdx Thanks a lot for the insights!

I’m not sure I understand the subtleties either :slight_smile:

It's something along these lines: the old recursive way of executing the graph seems to have some problems with Zygote, and rewriting it from recursion to a loop seems to alleviate them. Is there an even better way to rewrite it?

I'm not really desperate for a solution, since the loop workaround seems to work OK, but if nothing else it can make for a useful coding exercise.

The “compilation” approach was just one kind of extreme variant I considered, but not the “ultimate goal”.

I have indeed not been worried about overhead. The main reason for abandoning invokelatest is that it is non-differentiable, and I can't imagine how one would write an adjoint for it.

Maybe not super important, but one difference between the computation graph in NaiveNASlib and the graphs of the other libraries in this thread is that it does not seek to decompose a function into its smallest pieces. Each vertex can be (and often is) a fairly large computation, e.g. a neural network layer with weights, bias and activation function, or even several such layers. The user defines these arbitrarily as interesting building blocks for the search algorithm.
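Such a coarse-grained vertex operation might look like the sketch below. A dense layer with weights, bias and activation is one opaque callable, and two of them composed with `∘` form a single building block. `DenseLayer` and `relu` are hypothetical illustrations, not NaiveNASlib or Flux types.

```julia
# Hypothetical coarse-grained operation: weights, bias and activation in one callable.
struct DenseLayer{M,V,F}
    W::M
    b::V
    σ::F
end
(d::DenseLayer)(x) = d.σ.(d.W * x .+ d.b)

relu(x) = max(zero(x), x)

# Two layers treated as one building block; the graph only sees the composition.
# `∘` applies the right-hand layer first.
block = DenseLayer(ones(3, 2), zeros(3), relu) ∘ DenseLayer(ones(2, 4), fill(-1.0, 2), identity)
block(ones(4))  # [6.0, 6.0, 6.0]
```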

This is how I now justify not going back and refactoring all the type mistakes I made when I started out :slight_smile:

That said, it would be super cool to make a non-naive NAS library which uses Cassette or IRTools to mutate arbitrary functions, e.g. look for closed-over variables and enable mucking around with their dimensions. It seems like a pretty hard task. Maybe it is similar to the AD task, in the sense that it boils down to a few “primitives” plus what I assume is a long tail of edge cases and performance optimizations.

What I see matches your description of the performance: after the first call, all subsequent calls are indeed quite fast. The initial compilation time still matters a lot, since the intended use case implies training many (many) models. The main headache, however, was the unreasonably long extreme cases, which eventually make the search program stall completely.

I have tried stripping the graph of everything not needed for execution (e.g. the metadata for mutation operations, which wraps the actual computations) before fitting parameters. This did not seem to affect performance in any measurable way.