I’ve been looking into memoisation in Julia for a bit, and although there are quite a few packages, as far as I can tell all of them either store the results in some variation of Dict{Any, Any} or require you to commit to a single type signature for the memoised function.
I think I’ve come up with a way to avoid that (i.e. to get type-stable, yet generic memoisation). The idea is to let the macro create a generated function, which has access to the argument types, and then use Core.Compiler.return_type to get the return type. With this information all types are fully specified at compile time and the container (a Dict or whatever) can be concretely typed.
There are some obvious caveats, chief among them that the memoised function has all the constraints of a generated function. However, for the usual use cases of memoisation that should not be an issue, I think.
Is there anything else wrong with this approach that I might have missed?
Working implementation below.
using MacroTools

macro cached(f)
    fun = splitdef(f)
    fname = fun[:name]
    newfname = gensym(string(fname))
    fun[:name] = newfname
    args = fun[:args]
    # needs to be generated here, otherwise quoting gets wonky
    fcall = Meta.quot(quote
        get!($(Expr(:$, :cache)), ($(args...),), $newfname($(args...)))
    end)
    quote
        # generate the actual (renamed) function
        $(esc(combinedef(fun)))
        @generated function $(esc(fname))($(esc.(args)...))
            # inside the generator the argument names refer to their types
            arg_types = Tuple{$(esc.(args)...)}
            ret_type = Core.Compiler.return_type($(esc(newfname)), arg_types)
            cache = Dict{arg_types, ret_type}()
            # TODO store somewhere appropriate
            global cfuns
            cfuns[($(esc(fname)), $(esc.(args)...))] = cache
            $fcall
        end
    end
end
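For concreteness, usage would look something like this (a sketch; note that the macro expects a global cfuns registry to exist already, see the TODO above):

cfuns = Dict()   # hypothetical global registry the macro writes to

@cached function double(x)
    println("computing ", x)
    2x
end

double(3)   # prints "computing 3" and returns 6
double(3)   # should return 6 straight from the cache (but see the get! problem below)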
Good idea: even if what return_type returns changes, it will remain correct, and it will be efficient for functions whose return type can be inferred. You probably want to use an IdDict instead of a Dict for the cache. I think it would make sense to add this logic to the Memoize.jl implementation, which I think was originally written before IdDict had type parameters.
Interesting, I didn’t even know those existed. I’m not sure I would feel comfortable trusting the compiler to decide which version to use, though. In my particular use case the function call is happening hundreds or thousands of times in the inner loop of a simulation, so I know I want the fastest possible code.
I believe the non-generated path is mainly for when Julia is interpreting the code instead of compiling it, because then it’s faster to interpret the non-generated branch than to run the generator and interpret its output. So I don’t think you need to worry.
Currently the @generated branch is always used. In the future, which branch is used will mostly depend on whether the JIT compiler is enabled and available.
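For reference, an optionally generated function spells out both branches explicitly; a minimal sketch of the construct:

# Minimal sketch of an optionally generated function. In the @generated
# branch the argument names refer to types and an expression is returned;
# the else branch is ordinary code the runtime can fall back to.
function opt_square(x)
    if @generated
        :(x * x)
    else
        x * x
    end
end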
This is the body of the generated function, using get! to look up the value and set it to the calculated result if it’s not present. However, I accidentally used the non-closure form, so the value is actually calculated on every call, which is pretty pointless if you want to do memoisation…
Unfortunately the obvious solution, using get!(fcall, dict, key), doesn’t work due to the restrictions that apply to generated functions. So I think the only way to do it is to check and set manually. This is sub-optimal, as it does the lookup twice, but maybe in the grand scheme of things that’s still not so bad.
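Concretely, the expression returned by the generator would then have to expand to something like this (a sketch; cache and memoised_f stand for the interpolated cache object and the renamed original function):

# manual check-and-set replacing get!; note the key is looked up twice
let key = (args...,)
    if haskey(cache, key)
        cache[key]
    else
        cache[key] = memoised_f(args...)
    end
end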
There’s a really clean way to implement this without any macros:
julia> struct Closure{F,A} <: Function
           f::F
           args::A
       end

julia> (x::Closure)() = x.f(x.args...)

julia> @generated function _memo_invoke(::Type{R}, f, args::Tuple) where R
           cache = IdDict{args,R}()
           :(get!(Closure(f, args), $cache, args))
       end
_memo_invoke (generic function with 1 method)
julia> function memo_invoke(f, args...)
           R = Core.Compiler.return_type(f, typeof(args))
           _memo_invoke(R, f, args)
       end
memo_invoke (generic function with 1 method)
julia> f(x) = @show (x, x+1)
f (generic function with 1 method)
julia> memo_invoke(f, 3)
(x, x + 1) = (3, 4)
(3, 4)
julia> memo_invoke(f, 3)
(3, 4)
julia> @code_warntype memo_invoke(f, 3)
MethodInstance for memo_invoke(::typeof(f), ::Int64)
  from memo_invoke(f, args...) in Main at REPL[4]:1
Arguments
  #self#::Core.Const(memo_invoke)
  f::Core.Const(f)
  args::Tuple{Int64}
Body::Tuple{Int64, Int64}
1 ─ %1 = Core.Compiler::Core.Const(Core.Compiler)
│   %2 = Base.getproperty(%1, :return_type)::Core.Const(Core.Compiler.return_type)
│   %3 = Main.typeof(args)::Core.Const(Tuple{Int64})
│   %4 = (%2)(f, %3)::Core.Const(Tuple{Int64, Int64})
│   %5 = Main._memo_invoke(%4, f, args)::Tuple{Int64, Int64}
└── return %5
The Closure thing is just because you can’t use actual closures in generated functions. Comparing the two memo_invoke(f, 3) lines, you can see that f(3) was not called the second time. And the @code_warntype output shows the result is type-stable.
Of course you could still write a @memoize macro around this.
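For instance, something along these lines (an untested sketch using MacroTools; splitarg extracts the argument names from possibly-typed signatures):

using MacroTools

# untested sketch: rename the original function and route the public
# name through memo_invoke from above
macro memoize(f)
    def = splitdef(f)
    inner = gensym(string(def[:name]))
    outer = def[:name]
    def[:name] = inner
    argnames = [splitarg(a)[1] for a in def[:args]]
    quote
        $(esc(combinedef(def)))
        $(esc(outer))($(map(esc, def[:args])...)) =
            memo_invoke($(esc(inner)), $(map(esc, argnames)...))
    end
end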
That’s a very elegant solution and the Closure trick is neat, very nice!
Edit:
The one issue I have with it is that I don’t see an easy way to swap out the cache construction per cached function, at least as long as I want to allow user-specified constructor expressions.
Well, thinking about it, I guess I could use a custom type as a functor by having its call method wrap the provided expression. It feels a bit wrong, but it seems like it should work (sketch below).
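Roughly what I have in mind (hypothetical; the macro would emit a fresh type per memoised function and splice the user’s cache expression into its call method):

# hypothetical expansion for one memoised function, where the user
# supplied `Dict{K, V}()` as the cache expression
struct CacheCtor_f end
(::CacheCtor_f)(::Type{K}, ::Type{V}) where {K, V} = Dict{K, V}()

# the generated function could then build its cache via
#   cache = CacheCtor_f()(arg_types, ret_type)
# without needing a closure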
It sounds like you might be yak shaving. If you need top performance, just memoize by hand. It’s pretty easy to do.
foo(i::Integer) = 2i

mem_foo = let
    cache = Dict{Int, Int}()
    function(i::Integer)
        get!(cache, i) do
            foo(i)
        end
    end
end
Also, sometimes in situations where you are using memoization, there’s actually a dynamic programming algorithm that you could use instead, so keep an eye out for that.
Or maybe you’re just going down this rabbit hole for fun and to provide a generic, type stable @memoize.
You are right, of course: doing it manually, or simply pre-calculating the cache, would be simple and would guarantee the best performance. And that’s how I used to do it, especially back in my C++ days.
However, I’m now in a situation where this kind of pattern pops up all over the place (agent-based models of social systems, happy to elaborate). Over the years I have also become more and more convinced that good simulation code should read as close to pseudo-code as possible, i.e. that all “technical” stuff should be hidden away (happy to elaborate on this point as well). And Julia actually makes it reasonably straightforward to find an optimal (or close enough to optimal) general solution. So taking a couple of days to solve this properly seems like a good investment at this point.
It doesn’t hurt that it’s fun, though.
There’s a less magical way to do this that doesn’t involve @generated or Core.Compiler.return_type. All you have to do is provide an input to your function when you memoize it.
function memoize(f, args...)
    y = f(args...)
    cache = Dict(args => y)
    function mem_f(args...)
        get!(cache, args) do
            f(args...)
        end
    end
    y, mem_f
end
julia> f(x) = 2x
f (generic function with 1 method)
julia> y, mem_f = memoize(f, 1)
(2, mem_f)
julia> mem_f(2)
4
julia> mem_f.cache
Dict{Tuple{Int64}, Int64} with 2 entries:
(2,) => 4
(1,) => 2
There might be ways to make this more generic. I haven’t thought through what happens if you’re calling different methods of f with different input and output types.
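For instance, since the cache’s key and value types are fixed by the first call, a later call whose arguments don’t convert should fail, presumably with something like:

julia> mem_f(2.5)
ERROR: InexactError: Int64(2.5)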
I’m not sure if it makes sense to use IdDict. It might depend on your use case. Dict does conversion in setindex!, which is usually what you want. IdDict does not do conversion in setindex!:
julia> d = Dict(1 => 2);
julia> d[3.0] = 4.0; d
Dict{Int64, Int64} with 2 entries:
3 => 4
1 => 2
julia> id = IdDict(1 => 2);
julia> id[3.0] = 4.0; id
ERROR: ArgumentError: 3.0 is not a valid key for type Int64
Sure, this would work as well, but it would actually be quite inconvenient to use. It’s not unrealistic to assume that my simulation code has no idea which types it is going to be called with (or rather, should have no idea). The point where the types are decided is a few layers up, and at that point no details about memoised functions are known, nor should they be.
That is in fact one of my main motivations behind this. I want to be able to use memoisation for generic code i.e. without committing to a specific input/output type but at the same time I would like the result to be close to “hand-memoised” code in terms of efficiency.
I even have some cases where it would make sense to use an Array. Ideally, I would like to be able to provide a type or an expression for the cache at the time of memoisation. It’s not even difficult to do; it’s more a question of finding a nice syntax for it.
It’s not quite as straightforward, I think.
First, we either have to provide a way to refer to the key and value types when specifying the container type, or we have to accept that the type parameters can only ever be specified in one way (e.g. as {K, V}, as for Dict). Then there might be situations where we want to either pre-initialise the container somehow or use an existing one, so it would be nice to be able to provide an arbitrary expression (some of the existing memoisation packages do that, notably Memoize.jl and Memoization.jl); a rough idea of possible syntax is sketched below.
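To make that concrete, a hypothetical syntax (not implemented; K and V stand for the key and value types the macro computes) might look like:

# hypothetical syntax, not implemented
@cached Dict{K, V}() f(x) = 2x        # arbitrary constructor expression
@cached my_cache g(x) = x + 1         # reuse an existing container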
A somewhat related issue is the corner case of single argument functions. Internally the argument type (which ends up as the key type of the container) will always be a tuple. For 1-tuples, however, it might make sense to flatten them (at least in some cases), so that the plain values can be used as keys. This depends on the container type, however. For a Dict (or equivalent) it doesn’t really matter if the key type is Int or Tuple{Int}, but for an array only the former will work.
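A small helper pair illustrates the kind of flattening I mean (a sketch; unwrap is a hypothetical name):

# sketch: flatten 1-tuples so e.g. array-backed caches can be indexed directly
unwrap(key::Tuple{T}) where {T} = key[1]   # single argument: use the bare value
unwrap(key::Tuple) = key                   # general case: keep the tuple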
As pretty as this is, I run into issues as soon as I want to access the user-facing or the helper function in the generating part of the generated function. I need this because a) I want to store the per-method cache in a globally accessible Dict (so that the cache can be reset by the user) and b) I would prefer to determine the return type in the generated function (unless someone can convince me that Core.Compiler.return_type is a 0-cost call).
I could use the previously mentioned hack of abusing types as functions, but I think at that point I’d rather generate everything in a macro.
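What I’m aiming for, roughly (a sketch of the global registry and the user-facing reset; all names are placeholders):

# sketch: map (function, argument types) to their caches so users can reset them
const CACHES = IdDict{Any, Any}()

function reset_cache!(f)
    for (key, cache) in CACHES
        first(key) === f && empty!(cache)
    end
end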
I have the bare bones of a working package now (here), but unfortunately I’m running into issues either with the compiler or the capabilities of generated functions (discussion here). Any help or suggestions would be greatly appreciated.