A typical pattern in JAX to check when a function gets compiled (useful for debugging excessive recompilation) is:
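```python
import jax

@jax.jit
def f():
    # This print runs only while JAX traces f, not when the compiled code runs.
    print("compiling")
    return 0

f()  # prints "compiling"
f()  # doesn't print anything, since the print is not part of the compiled version
```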
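To make the "excessive recompilation" use case concrete, here is a minimal sketch of the same trick on a function that takes an argument. The names `g` and `trace_log` are hypothetical, added only for illustration, and the sketch assumes default `jax.jit` caching, where a new input shape or dtype triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per trace of g

@jax.jit
def g(x):
    # Python side effects here run only during tracing,
    # so each entry in trace_log corresponds to one (re)compilation.
    trace_log.append(x.shape)
    print("compiling for shape", x.shape)
    return x * 2

g(jnp.ones(3))   # first call: traces and compiles
g(jnp.zeros(3))  # same shape and dtype: served from the cache, no print
g(jnp.ones(4))   # new shape: traced and compiled again
```

If the message shows up more often than the number of distinct input shapes and dtypes you expect, that is a hint that something is forcing retracing.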
Is there anything similarly easy in Julia? (I'm aware of SnoopCompile, which I think would get you this, just maybe not as easily.)
In Julia, a method is compiled when the function is first called, but some things can invalidate the compiled method and lead to recompilation. A nice tool for looking into this is SnoopCompile.jl (GitHub: timholy/SnoopCompile.jl).
Yes, I am aware of SnoopCompile, but I'm just curious whether there's any quick-and-dirty way to get the info by doing something simple, like inserting a statement into the function, much as in JAX.