Hi,
NaiveNASlib.jl uses a run-of-the-mill representation of a computation graph composed of AbstractVertex structs, where each vertex is associated with an operation (e.g. a neural network layer) as well as its input vertices.
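Concretely, the shape of it is roughly like this (an illustrative sketch with made-up names, not the package's actual types):

# Illustrative sketch only; NaiveNASlib's actual types look different.
abstract type AbstractVertex end

struct OpVertex{F} <: AbstractVertex
    op::F                           # the operation, e.g. a Flux layer
    inputs::Vector{AbstractVertex}  # vertices whose outputs feed this vertex
end

inputs(v::OpVertex) = v.inputs
(v::OpVertex)(x...) = v.op(x...)   # a vertex is callable on its input values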
Tl;dr question: are there any cool/efficient functional/julian ways to evaluate a computation graph with this structure?
Background and current thinking:
As actually evaluating the graph is not really what the package is about, I initially just left it at the following simple recursive function:
function output!(memo::Dict{AbstractVertex, Any}, v::AbstractVertex)
    # Calculate outputs which are not already calculated
    return get!(memo, v) do
        inpt = map(iv -> output!(memo, iv), inputs(v))
        out = v(inpt...)
    end
end
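For context, evaluation then amounts to seeding the memo with the data for the graph's input vertices and asking for the output of the last vertex, roughly like this (hypothetical variable names, not necessarily how the package calls it):

# Hypothetical usage: map each input vertex to its data, then request
# the output of the final vertex; intermediate results end up in memo.
memo = Dict{AbstractVertex, Any}(invertex => xdata)
y = output!(memo, outvertex)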
Now, when trying to differentiate through this with Zygote, I first ran into get! hitting non-differentiable code. While this was easy to work around, either by defining an adjoint for get! or by rewriting the function to not use get! (a sketch of the latter below), I then ran into the harder issue that performance was very poor: the first time a graph is evaluated generally takes upwards of a minute, and in extreme cases hours or a complete stall.
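The get!-free variant I have in mind is along these lines (just a sketch; whether Zygote handles the remaining Dict write as-is I have not verified):

# Sketch of a get!-free evaluator: the haskey branch replaces get!,
# but the setindex! on memo is still mutation.
function output_noget!(memo::Dict{AbstractVertex, Any}, v::AbstractVertex)
    haskey(memo, v) && return memo[v]
    out = v(map(iv -> output_noget!(memo, iv), inputs(v))...)
    memo[v] = out
    return out
end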
I have tried Juno's profiler, and for “mild cases” I could not see anything special in the flame chart (comparing a simple linear graph with the equivalent Chain). In worse cases I got the message that the profiling buffer was full, and the output looked just like that of the “mild cases”.
Current workaround is to do the recursion before execution and handle the actual execution with a loop:
Flux.Zygote.@adjoint function output!(memo::Dict{AbstractVertex, Any}, v::AbstractVertex)
    return Flux.Zygote.pullback(__context__, output_loop!, memo, v)
end

function output_loop!(memo, v)
    # The flatten call is non-differentiable as it does array mutation.
    # It is also not part of the computation we want to differentiate.
    vs = nograd() do
        # flatten returns all input ancestors to v in topological order.
        # We also provide all vertices for which the output is already in memo
        # so we don't do unnecessary calculations (recursion stops when it hits an existing vertex).
        flatten(v, collect(AbstractVertex, keys(memo)))[length(memo)+1:end]
    end

    for vn in vs
        inpt = map(iv -> memo[iv], inputs(vn))
        memo[vn] = vn(inpt...)
    end
    return memo[v]
end
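(The nograd helper is not shown above; a minimal version of it could look like the following, although the real one may differ:)

# Minimal sketch of the nograd helper used above: run the closure as usual
# in the forward pass, but give it a pullback which produces no gradients.
nograd(f) = f()
Flux.Zygote.@adjoint nograd(f) = f(), _ -> (nothing,)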
For some reason this loop-based version gets rid of the extremely long first-call times I see, although I don’t know why exactly.
Now this might be OOP damage, but whenever I need to represent something with a Dict, I always get a feeling that there is some aspect of the domain which I have not captured correctly.
I have peeked at a few other DL libraries with graphs (tensorflow, deeplearning4j, owl), and they all seem to have a nullable, mutable output field in their vertices, which to me is an even worse design smell than the dict. Even though it is probably more performant, I would rather take the performance hit of the dict lookup than go that way.
I have also briefly looked at Memoize.jl, but it seems to use the Dict approach as well, and in addition it was not clear to me how to make sure that all vertices sharing an input vertex see the same memoized function.
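For reference, the Memoize.jl route I looked at was roughly the below (a sketch, likely not workable as-is):

using Memoize

# The cache here belongs to the method, so every caller of output shares it,
# but it would also have to be reset between forward passes, and it is not
# clear where the input data (which the memo Dict currently carries) should enter.
@memoize function output(v::AbstractVertex)
    v(map(output, inputs(v))...)
end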
One thing I considered is that it is trivial to generate an expression which evaluates to an “unrolled” function for the graph, e.g. something like this:
function(in1, in2, …)
    out1 = vertex1_op(in1)
    out2 = vertex2_op(out1)
    out3 = vertex3_op(in2, out1)
    out4 = vertex4_op(out3, out2)
    # etc…
end
In other words, write a simple “compiler” for the computation graphs defined in the language of NaiveNASlib.
However, after encountering and reading up on world age problems, I abandoned this approach, as I don’t think it is possible to do this without using invokelatest (which, apart from any performance concerns, is non-differentiable, and I don’t think it is possible to write an adjoint for it). Is this the correct conclusion w.r.t. metaprogramming in this case?
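For reference, the kind of generated-expression approach I mean is roughly this (made-up helper, not part of NaiveNASlib):

# Sketch of the "compile the graph" idea. vs is assumed to be all vertices
# in topological order, with the graph's input vertices first.
function compile(vs::Vector{<:AbstractVertex})
    name = Dict(v => gensym("out") for v in vs)
    args = [name[v] for v in vs if isempty(inputs(v))]
    body = [:($(name[v]) = $(v)($(map(iv -> name[iv], inputs(v))...)))
            for v in vs if !isempty(inputs(v))]
    fexpr = Expr(:->, Expr(:tuple, args...), Expr(:block, body..., name[last(vs)]))
    f = eval(fexpr)
    # f is defined in a newer world than the caller, hence the need for invokelatest
    return (xs...) -> Base.invokelatest(f, xs...)
end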