Re-using compiled code

I’m trying to reduce the latency of an application by using precompilation. I went through most of the steps described in the SnoopCompile manual. It seems that I’m running into the issue which is mentioned there, namely that the code that is causing most of the compilation time is not ‘owned’ by my application or package, but by others (mainly Zygote), and thus the precompilation statements emitted by SnoopCompile.write do not make a difference.

I guess an option is to execute some dummy workloads in the module initialization, but it seems that this needs to be redone in every new Julia session. Is there a way to cache and re-use the code generated there?


No: just use SnoopPrecompile (EDIT: now replaced by PrecompileTools), install Julia 1.9 (currently in beta), and profit. The *.jl files in your package are build scripts, and Julia just stores a snapshot of the result of building; nothing besides __init__ gets re-executed when you load the package.


As Tim mentioned, Julia 1.9 now caches all of this compiled code (thanks to herculean efforts he spearheaded together with a few other Julia devs).

However, keep in mind that moving to 1.9 “only” fixes compilation delays. So in using Library; Library.do_thing(), the do_thing call can now be instantaneous even the first time (no compilation necessary). Loading, however, can still be slow (i.e. using Library can still take a while). While there is work underway to make this fast by default too, the current workaround for a slow using Library is to compile a sysimage. Then both load time and first-call time can be instantaneous.

If you use VS Code, the “Compiling Sysimages” page of the Julia in VS Code documentation is a convenient guide to automating sysimage builds.


Thanks. I briefly tried SnoopPrecompile before going through the SnoopCompile steps, but didn’t have much success. I will retry more thoroughly now.

I’m not sure I understand the concept of build scripts. Are you suggesting something other than what is written in the SnoopPrecompile docs? According to those, I still have to put some kind of high-level function call, one that makes most of the calls that are typically slow to compile, inside the @precompile_all_calls block, right? Or are you now saying that we can put the entire package into that block?

No, my point is that your package code isn’t executed when you load the package. It runs only during precompilation ([ Info: Precompiling XYZ [abcde...]). What gets cached is a snapshot of the changes your package makes to the running system. That’s what I mean by “your .jl files are just build scripts”: they describe the module that Julia should create, but what gets saved is the module itself, not the code that created it.

This was in response to your comment:

I guess an option is to execute some dummy workloads in the module initialization, but it seems that this needs to be redone in every new Julia session

to make sure you knew that the workload inside @precompile_all_calls only gets run during precompilation. But on 1.9 and higher it will deliver benefits every time you use the package.

Or are you now saying that we can put the entire package into that block?

Definitely not. All that block does is mark a section of code that tells Julia that everything it compiles in order to run it is also something that should end up in the cache file. So put methods, types, consts, etc, outside the block just as usual, and then put a small demo of the operations you want to make fast inside the @precompile_all_calls block.
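As a sketch of that layout, here is a minimal hypothetical package module written against PrecompileTools, the successor to SnoopPrecompile, where @precompile_setup and @precompile_all_calls became @setup_workload and @compile_workload. The module name DemoPkg, the Scaler type, and the total function are made up for illustration; only the two macros come from PrecompileTools.

```julia
module DemoPkg  # hypothetical package, for illustration only

using PrecompileTools: @setup_workload, @compile_workload

# Types, methods and constants are defined outside the workload, as usual.
struct Scaler
    factor::Float64
end

(s::Scaler)(v::AbstractVector{<:Real}) = s.factor .* v

total(s::Scaler, v::AbstractVector{<:Real}) = sum(s(v))

export Scaler, total

@setup_workload begin
    # Runs only while the package is being precompiled; use it to build inputs.
    v = collect(1.0:10.0)
    @compile_workload begin
        # Everything Julia compiles to execute these calls goes into the
        # package cache (on Julia >= 1.9 this includes native code).
        s = Scaler(2.0)
        total(s, v)
    end
end

end # module DemoPkg
```

The workload is deliberately tiny: it only needs to exercise the call signatures you care about (here, Scaler on a Vector{Float64}), not do real work. When the module is loaded without precompiling (e.g. pasted into the REPL), both blocks are skipped and the module behaves like any other.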

Keep in mind that before Julia 1.9, Julia couldn’t cache native code, so be sure you’re on 1.9-beta or your results will be less than spectacular.


Thanks. I briefly tried SnoopPrecompile before going through the SnoopCompile steps, but didn’t have much success. I will retry more thoroughly now.

This issue in JuMP is a useful guide on our experience adding SnoopPrecompile:

Here are the relevant PRs:


Ok, thanks for all the pointers and explanations!
One issue that I think explains why I saw no benefit the first time I tried SnoopPrecompile: the code that Zygote processes and that is problematic (i.e. slow to compile) raised an exception, which was caught, so the calls were never executed. I’ve now stopped catching the error.
The weird thing, though, is that the exception happens at all: it doesn’t happen when the same command is executed in a normal script, only during precompilation. It’s Zygote complaining about a mutation, but it has never complained at normal runtime. Does this ring any bells? I need to reduce it to an MWE.

EDIT: I cannot reproduce the MWE now, but the problem described above still exists. The code below does not qualify as an MWE, because without the Zygote.@ignore the compute_gradient function does not compile even outside precompilation.

Old post

Below is an MWE that exposes the same error I get during precompilation (but not at runtime). I guess it makes some sense given how reverse AD works: during precompilation it is not yet possible to determine that the code containing the mutation will not contribute to the gradient, whereas once the code for AD has been compiled, this branch can be dropped. I still wonder why there is no error during normal compilation, i.e. when the code is compiled on the first call.

Inserting a Zygote.@ignore solves the issue, but it will be harder to solve it in my actual application since it would be needed in the function pairwise of Distances.jl.

module SnoopPrecompileMWE

using Distances
using Zygote: gradient, @ignore

using SnoopPrecompile

struct Loss{Tx, Ty}
    x::Tx
    y::Ty
end

function (l::Loss)(θ)
    dist = SqEuclidean()
    # The following line contains mutating code,
    # but does not need to be differentiated.
    # Without the @ignore, the precompilation directive will fail.
    D = @ignore pairwise(dist, l.x, l.y)
    return sum(θ.variance .* exp.(D ./ θ.lengthscale))
end

function compute_gradient(l, θ)
    return only(gradient(l, θ))
end

export Loss, compute_gradient

@precompile_setup begin
    l = Loss(randn(10), randn(10))
    θ = (variance = 1., lengthscale = 1.)

    @precompile_all_calls begin
        l(θ)
        compute_gradient(l, θ)
    end
end

end # module SnoopPrecompileMWE

This is an MWE that actually works. But I don’t have a fix without going inside the kernelmatrix method, where I would need to Zygote.ignore the pairwise function from Distances.jl, which is not a safe thing to do in general:

module SnoopPrecompileMWE

using KernelFunctions
using Zygote: gradient, @ignore

using SnoopPrecompile

struct Loss{Tx, Ty}
    x::Tx
    y::Ty
end

function (l::Loss)(θ)
    k = θ.variance * SEKernel() ∘ ScaleTransform(θ.lengthscale)
    return sum(kernelmatrix(k, l.x, l.y))
end

# This function compiles when the precompilation block below is commented out,
# but it generates an error 
# "ERROR: LoadError: Mutating arrays is not supported"
# when called within the precompilation block.
function compute_gradient(l, θ)
    return only(gradient(l, θ))
end

export Loss, compute_gradient

@precompile_setup begin
    l = Loss(randn(10), randn(10))
    θ = (variance = 1., lengthscale = 1.)

    @precompile_all_calls begin
        l(θ)
        compute_gradient(l, θ)
    end
end

end # module SnoopPrecompileMWE

It’s almost certainly an issue with Zygote itself, and a Zygote developer would be better suited to answer.

The only precompilation-related fact that seems like it might be relevant: note that __init__ does not run during precompilation, but does when you load the module. So something that runs with the __init__ may be causing the difference. Note that Zygote has a couple of @init blocks, which (if memory serves) are a way of generating __init__.
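To illustrate the distinction, here is a small hypothetical module in plain Base Julia (all names made up). If it were a package, the top-level const would be evaluated once while precompiling and its value saved in the cache file, whereas __init__ would run on every using in each new session. When a module like this is simply evaluated in a script or the REPL, both happen right away, one after the other.

```julia
module InitDemo  # hypothetical module, for illustration only

# Top-level code: for a precompiled package this runs during precompilation
# and the resulting value is stored in the cache file.
const TABLE = [i^2 for i in 1:5]

# Per-session state (e.g. handles, pointers, environment lookups) must be
# (re)created at load time, so it is set up in __init__.
const LOAD_COUNT = Ref(0)

function __init__()
    # Called every time the module is loaded, never while it is being precompiled.
    LOAD_COUNT[] += 1
    return nothing
end

end # module InitDemo
```

So if some package state that Zygote relies on is only set up inside an __init__ (or an @init block), it will be missing while your own package's workload runs during precompilation.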


I think the problem is in the pairwise function, which preallocates an array for the results and then fills it. It therefore performs a mutation, which is something Zygote cannot differentiate through.
You can solve this either by writing a rule, or by rewriting pairwise so that it does not mutate an existing array. For example, the squared Euclidean distance could be computed as

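# x: d×n, y: d×m (one observation per column); the result is the n×m matrix of squared distances
sum(abs2, x; dims = 1)' .+ sum(abs2, y; dims = 1) .- 2 .* (x' * y)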
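A sketch of this non-mutating approach, with the broadcasting dots written out explicitly so that x' * y stays a matrix product, checked against a plain loop. It assumes observations are stored as columns (x is d×n, y is d×m), and both function names are hypothetical. It uses only reductions, matrix multiplication and broadcasting, so no array is written in place.

```julia
using LinearAlgebra  # adjoint (') and matrix multiplication

# Non-mutating pairwise squared Euclidean distances between the columns of
# x (d×n) and y (d×m); returns an n×m matrix.
sqeuclidean_pairwise(x::AbstractMatrix, y::AbstractMatrix) =
    sum(abs2, x; dims=1)' .+ sum(abs2, y; dims=1) .- 2 .* (x' * y)

# Reference implementation with explicit loops (mutating; for checking only).
function sqeuclidean_loop(x::AbstractMatrix, y::AbstractMatrix)
    D = zeros(size(x, 2), size(y, 2))
    for j in axes(y, 2), i in axes(x, 2)
        D[i, j] = sum(abs2, view(x, :, i) .- view(y, :, j))
    end
    return D
end
```

One trade-off of the expanded formula: floating-point round-off can make entries for nearly identical points come out slightly negative, which matters if you take a square root afterwards.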
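The other option mentioned above, writing a rule, could look roughly like this with ChainRulesCore, whose rrules Zygote uses. Here pairwise_sqeuclid is a hypothetical mutating stand-in for a function like Distances.pairwise, not its actual code. The pullback follows from D[i,j] = |x_i|² + |y_j|² - 2 x_i'y_j with observations in columns. Treat it as a sketch: a production rule would also need to cover other input and tangent types.

```julia
using ChainRulesCore

# Hypothetical mutating implementation (columns of x and y are observations).
function pairwise_sqeuclid(x::AbstractMatrix, y::AbstractMatrix)
    D = zeros(promote_type(eltype(x), eltype(y)), size(x, 2), size(y, 2))
    for j in axes(y, 2), i in axes(x, 2)
        D[i, j] = sum(abs2, view(x, :, i) .- view(y, :, j))
    end
    return D
end

# Hand-written reverse rule, so Zygote never traces the in-place writes above.
function ChainRulesCore.rrule(::typeof(pairwise_sqeuclid), x::AbstractMatrix, y::AbstractMatrix)
    D = pairwise_sqeuclid(x, y)
    function pairwise_sqeuclid_pullback(ΔD)
        Δ = unthunk(ΔD)
        # dL/dx_i = 2 Σ_j Δ[i,j] (x_i - y_j);  dL/dy_j = 2 Σ_i Δ[i,j] (y_j - x_i)
        x̄ = 2 .* (x .* sum(Δ; dims=2)' .- y * Δ')
        ȳ = 2 .* (y .* sum(Δ; dims=1) .- x * Δ)
        return NoTangent(), x̄, ȳ
    end
    return D, pairwise_sqeuclid_pullback
end
```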

Yes, it has been a long-standing issue in KernelFunctions.jl to implement those distance computations in a way that is performant, AD-friendly, and GPU-friendly. So for the last few hours I’ve been trying to rewrite them using broadcasting.

A generic implementation which (aside from AD) covers all cases of interest is:

function KernelFunctions.kernelmatrix(
    κ::KernelFunctions.SimpleKernel, x::AbstractVector, y::AbstractVector
)
    KernelFunctions.validate_inputs(x, y)
    dist = metric(κ)
    return kappa.(Ref(κ), dist.(x, y'))
end
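Below is a Base-only sketch of the same broadcast pattern, with hypothetical stand-ins sqeuclid and se_kappa in place of KernelFunctions’ metric and kappa. One detail worth checking, separately from the compiler error further down: ' is recursive in Julia, so for a vector whose elements are themselves vectors (such as ColVecs or RowVecs), y' adjoints each element as well. The stack trace further down does show Adjoint{Float64, SubArray...} arguments reaching Distances._evaluate. permutedims(y) produces a 1×m row without touching the elements.

```julia
using LinearAlgebra  # provides adjoint (') for arrays; used in the comparison below

# Hypothetical stand-ins for metric(κ) and kappa(κ, d) of a squared-exponential kernel.
sqeuclid(a, b) = sum(abs2, a .- b)
se_kappa(d) = exp(-d / 2)

# Broadcast-based kernel matrix: no output array is preallocated and filled.
# permutedims turns y into a 1×m row *without* adjointing its elements.
kernel_matrix(x::AbstractVector, y::AbstractVector) =
    se_kappa.(sqeuclid.(x, permutedims(y)))

xv = [[0.0, 0.0], [1.0, 1.0]]   # vector of vectors, one observation per element
yv = [[1.0, 1.0]]

kernel_matrix(xv, yv)                  # 2×1 matrix: [exp(-1); 1]
sqeuclid.(xv, yv')[1, 1]               # 4.0: yv' holds a 1×2 row, so a .- b is 2×2
sqeuclid.(xv, permutedims(yv))[1, 1]   # 2.0: the intended squared distance
```

For plain Vector{Float64} inputs both forms agree, which may be why the scalar case above behaves as expected while vector-valued inputs do not.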

And with this monkey patch put into the module SnoopPrecompileMWE above, the precompilation statement gives the results one would expect for x, y of type Vector{Float64}:

using SnoopPrecompileMWE 

l = Loss(randn(10), randn(10))
θ = (variance = 1., lengthscale = 1.)

@time l(θ)
# w/o precompile statement: 0.182458 seconds (571.37 k allocations: 38.590 MiB, 5.52% gc time)
# w/  precompile statement: 0.000028 seconds (7 allocations: 2.172 KiB)

@time compute_gradient(l, θ)
# w/o precompile statement: 18.626561 seconds (36.44 M allocations: 2.297 GiB, 4.92% gc time)
# w/  precompile statement: 0.000907 seconds (1.38 k allocations: 95.891 KiB)

which is great!
However, it can’t be differentiated if the input is RowVecs or ColVecs (which are conceptually similar to eachrow or eachcol, but custom structs exported by KernelFunctions.jl):

using SnoopPrecompileMWE
using KernelFunctions: RowVecs

l = Loss(RowVecs(randn(10, 5)), RowVecs(randn(10, 5)))
θ = (variance = 1., lengthscale = 1.)

@time l(θ)
@time compute_gradient(l, θ)

yields a rather nasty error (I skipped some parts to fit the character limit, but you get the idea):

Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Core.Compiler.VarState, (56,)}[Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(Distances._evaluate), Distances.SqEuclidean, Base.SubArray{Float64, 1, Array{Float64, 2}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, LinearAlgebra.Adjoint{Float64, Base.SubArray{Float64, 1, Array{Float64, 2}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}, Nothing}, Any}, undef=false), Core.Compiler.VarState(typ=Float64, undef=false), Core.Compiler.VarState(typ=Core.Const(val=nothing), undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, 
undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=true), Core.Compiler.VarState(typ=Any, undef=false), Core.Compiler.VarState(typ=Any, undef=false), Core.Compiler.VarState(typ=Any, undef=false), Core.Compiler.VarState(typ=Any, undef=false), Core.Compiler.VarState(typ=Any, undef=false), Core.Compiler.VarState(typ=Any, undef=false)], i=(57,))
ijl_bounds_error_ints at /cache/build/default-amdci4-7/julialang/julia-release-1-dot-9/src/rtutils.c:194
setindex! at ./array.jl:969 [inlined]
stoverwrite1! at ./compiler/typelattice.jl:599
typeinf_local at ./compiler/abstractinterpretation.jl:2742
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_apply at ./compiler/abstractinterpretation.jl:1545
abstract_call_known at ./compiler/abstractinterpretation.jl:1830
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2536
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
jfptr_abstract_call_13817.clone_1 at /home/simone_a/.julia/juliaup/julia-1.9.0-beta3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-7/julialang/julia-release-1-dot-9/src/gf.c:2681 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-7/julialang/julia-release-1-dot-9/src/gf.c:2863
return_type_tfunc at ./compiler/tfuncs.jl:2317
abstract_call_known at ./compiler/abstractinterpretation.jl:1867
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2536
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1916
abstract_call at ./compiler/abstractinterpretation.jl:1987
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2536
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2536
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:947
abstract_call_method at ./compiler/abstractinterpretation.jl:609
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call at ./compiler/abstractinterpretation.jl:1984
abstract_call at ./compiler/abstractinterpretation.jl:1966
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2133
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2347
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2560
typeinf_local at ./compiler/abstractinterpretation.jl:2735
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2841
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_ext at ./compiler/typeinfer.jl:1072
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1105
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1101
jfptr_typeinf_ext_toplevel_16609.clone_1 at /home/simone_a/.julia/juliaup/julia-1.9.0-beta3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
...
ERROR: BoundsError: attempt to access 56-element Vector{Core.Compiler.VarState} at index [57]
Stacktrace:
  [1] setindex!
    @ ./array.jl:969 [inlined]
  [2] stoverwrite1!(state::Vector{Core.Compiler.VarState}, change::Core.Compiler.StateUpdate)
    @ Core.Compiler ./compiler/typelattice.jl:599
  [3] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2742
  [4] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2841
  [5] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:244
  [6] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:215
  [7] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:947
  [8] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:609
  [9] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:153
 [10] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1984
 [11] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1966
 [12] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2133
 [13] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2347
 [14] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2560
 [15] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2735
--- the last 12 lines are repeated 1 more time ---
 [28] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2841
 [29] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:244
 [30] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:215
 [31] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:947
 [32] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:609
 [33] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:153
 [34] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1916
 [35] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1987
 [36] abstract_apply(interp::Core.Compiler.NativeInterpreter, argtypes::Vector{Any}, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1545
 [37] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1830
 [38] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1987
 [39] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1966
 [40] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2133
 [41] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2347
 [42] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2560
 [43] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2735
 [44] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2841
 [45] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:244
 [46] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:215
 [47] typeinf_edge(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:947
 [48] abstract_call_method(interp::Core.Compiler.NativeInterpreter, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:609
 [49] abstract_call_gf_by_type(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:153
 [50] abstract_call_known(interp::Core.Compiler.NativeInterpreter, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1916
--- the last 13 lines are repeated 1 more time ---
 [64] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1987
 [65] abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:1966
 [66] abstract_eval_statement_expr(interp::Core.Compiler.NativeInterpreter, e::Expr, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState, mi::Nothing)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2133
 [67] abstract_eval_statement(interp::Core.Compiler.NativeInterpreter, e::Any, vtypes::Vector{Core.Compiler.VarState}, sv::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2347
 [68] abstract_eval_basic_statement(interp::Core.Compiler.NativeInterpreter, stmt::Any, pc_vartable::Vector{Core.Compiler.VarState}, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2536
 [69] typeinf_local(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2735
 [70] typeinf_nocycle(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/abstractinterpretation.jl:2841
 [71] _typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:244
 [72] typeinf(interp::Core.Compiler.NativeInterpreter, frame::Core.Compiler.InferenceState)
    @ Core.Compiler ./compiler/typeinfer.jl:215
 [73] typeinf
    @ ./compiler/typeinfer.jl:12 [inlined]
 [74] typeinf_type(interp::Core.Compiler.NativeInterpreter, method::Method, atype::Any, sparams::Core.SimpleVector)
    @ Core.Compiler ./compiler/typeinfer.jl:1094
 [75] return_type(interp::Core.Compiler.NativeInterpreter, t::DataType)
    @ Core.Compiler ./compiler/typeinfer.jl:1155
 [76] return_type(f::Any, t::DataType)
    @ Core.Compiler ./compiler/typeinfer.jl:1127
 [77] collect(itr::Base.Generator)
    @ Base ./array.jl:757
 [78] map(::Function, ::Matrix{Tuple{Float64, typeof(∂(λ))}}, ::Matrix{Float64})
    @ Base ./abstractarray.jl:3377
 [79] (::Zygote.var"#∇broadcasted#943")(ȳ::Any)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:205
 [80] (::Zygote.var"#3903#back#947")(Δ::Any)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [81] (::Zygote.var"#208#209")(Δ::Any)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:206
 [82] (::Zygote.var"#2084#back#210")(Δ::Any)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [83] Pullback
    @ ./broadcast.jl:1317 [inlined]
 [84] (::typeof(∂(broadcasted)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [85] Pullback
    @ ~/Documents/projects/n/SnoopPrecompileMWE/src/SnoopPrecompileMWE.jl:14 [inlined]
 [86] (::typeof(∂(kernelmatrix)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [87] Pullback
    @ ~/.julia/dev/KernelFunctions/src/kernels/scaledkernel.jl:29 [inlined]
 [88] (::typeof(∂(kernelmatrix)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [89] Pullback
    @ ~/.julia/dev/KernelFunctions/src/kernels/transformedkernel.jl:117 [inlined]
 [90] (::typeof(∂(kernelmatrix)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [91] Pullback
    @ ~/Documents/projects/n/SnoopPrecompileMWE/src/SnoopPrecompileMWE.jl:47 [inlined]
 [92] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [93] (::Zygote.var"#60#61")(Δ::Any)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:45
 [94] gradient(f::Loss{RowVecs{Float64, Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}, RowVecs{Float64, Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}}, args::NamedTuple{(:variance, :lengthscale), Tuple{Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:97
...

(The error with RowVecs replaced by eachrow is very similar.)

So the __init__ you are talking about would be in Zygote, not in my module? I don’t have an explicit __init__ function, so unless @precompile_all_calls generates one, my module doesn’t have one at all.
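Here is a minimal plain-Julia sketch of the distinction behind this question. Top-level statements in a module run when the module is built. For a precompiled package, that happens once, during precompilation, and the resulting state is saved. A function named __init__, if the module defines one, is called again on every load. A module without __init__ simply has nothing re-executed. The module name InitDemo and its counters are illustrative only.

```julia
module InitDemo

# Top-level statements run when the module is built. For a precompiled
# package this happens during precompilation, and the resulting state is
# stored in the cache file that later sessions load.
const TOPLEVEL_RUNS = Ref(0)
TOPLEVEL_RUNS[] += 1

# __init__ is the exception: Julia calls it each time the module is loaded
# into a session. For a module defined directly in a script or the REPL, as
# here, it runs once, right after the module block is evaluated.
const INIT_RUNS = Ref(0)
function __init__()
    INIT_RUNS[] += 1
    return nothing
end

end # module InitDemo

println("top-level runs: ", InitDemo.TOPLEVEL_RUNS[])
println("__init__ runs: ", InitDemo.INIT_RUNS[])
```

When the module is defined directly like this, both counters end up at 1. In a precompiled package, the top-level increment happens only during precompilation and its result is loaded from the cache, whereas __init__ runs again in each new session.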


Ah. Someone needs to make a weakdeps PR.

Thanks! This looks quite involved.
It seems to me that there are already so many custom tricks in place to get Distances.jl to play nicely with Zygote that what I’m trying to do messes things up.

I just tried bypassing Distances.jl and the kernel structs from KernelFunctions.jl entirely, and everything works nicely with plain broadcasting and no custom rrules:

module SnoopPrecompileMWE

using LinearAlgebra
using Zygote: gradient

struct Loss{Tx, Ty}
    x::Tx
    y::Ty
end

sqdist(x::Real, y::Real) = (x-y)^2
sqdist(x, y) = dot(x, x) + dot(y, y) - 2dot(x, y)
sekernel(v, l, x, y) = v * exp(-sqdist(x, y) / (2l^2))

function (l::Loss)(θ)
    # k = θ.variance * SEKernel() ∘ ScaleTransform(1/θ.lengthscale)
    # return sum(kernelmatrix(k, l.x, l.y))
    return sum(sekernel.(θ.variance, θ.lengthscale, l.x, permutedims(l.y)))
end

function compute_gradient(l, θ)
    return only(gradient(l, θ))
end

export Loss, compute_gradient

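using SnoopPrecompile
using KernelFunctions: RowVecs, ColVecs

# Setup for the precompile workload. Like the block nested inside it, this
# only runs while the package is being precompiled.
# (In PrecompileTools, the successor of SnoopPrecompile, the equivalent
# macros are @setup_workload and @compile_workload.)
@precompile_setup begin
    losses = [
        Loss(randn(10), randn(10)),
        Loss(RowVecs(randn(10, 5)), RowVecs(randn(10, 5))),
        Loss(eachrow(randn(10, 5)), eachrow(randn(10, 5))),
        Loss(ColVecs(randn(10, 5)), ColVecs(randn(10, 5))),
        Loss(eachcol(randn(10, 5)), eachcol(randn(10, 5))),
        Loss([randn(5) for _ in 1:10], [randn(5) for _ in 1:10])
    ]
    θ = (variance = 1., lengthscale = 1.)

    # Calls made in this block, including the ones Zygote makes internally,
    # are compiled and stored in the package's precompile cache.
    @precompile_all_calls begin
        for loss in losses
            loss(θ)                    # forward pass
            compute_gradient(loss, θ)  # reverse pass via Zygote
        end
    end
end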

end # module SnoopPrecompileMWE

This yields:

@time using SnoopPrecompileMWE 

## w/o precompile statements
# 2.568953 seconds (5.10 M allocations: 359.509 MiB, 5.17% gc time, 0.39% compilation time)

## w/ precompile statements, first time
# 17.620369 seconds (5.48 M allocations: 385.582 MiB, 0.77% gc time, 0.27% compilation time)

## w/ precompile statements, second time
# 2.794685 seconds (5.46 M allocations: 383.742 MiB, 5.05% gc time, 0.41% compilation time)

using KernelFunctions: RowVecs, ColVecs
losses = [
    Loss(randn(10), randn(10)),
    Loss(RowVecs(randn(10, 5)), RowVecs(randn(10, 5))),
    Loss(eachrow(randn(10, 5)), eachrow(randn(10, 5))),
    Loss(ColVecs(randn(10, 5)), ColVecs(randn(10, 5))),
    Loss(eachcol(randn(10, 5)), eachcol(randn(10, 5))),
    Loss([randn(5) for _ in 1:10], [randn(5) for _ in 1:10])
];
θ = (variance = 1., lengthscale = 1.);

for loss in losses
    @time loss(θ)
end

## w/o precompile statements
# 0.081053 seconds (289.35 k allocations: 19.437 MiB)
# 0.159219 seconds (570.63 k allocations: 38.681 MiB, 6.74% gc time)
# 0.130257 seconds (511.75 k allocations: 35.031 MiB, 8.63% gc time)
# 0.150489 seconds (525.82 k allocations: 35.878 MiB, 7.86% gc time)
# 0.115821 seconds (499.20 k allocations: 34.447 MiB)
# 0.066783 seconds (215.01 k allocations: 14.603 MiB, 11.55% gc time)

## w/ precompile statements
# 0.000013 seconds (4 allocations: 1008 bytes)
# 0.000039 seconds (2 allocations: 912 bytes)
# 0.000007 seconds (2 allocations: 912 bytes)
# 0.000005 seconds (2 allocations: 272 bytes)
# 0.000004 seconds (2 allocations: 272 bytes)
# 0.000005 seconds (4 allocations: 1008 bytes)

for loss in losses
    @time compute_gradient(loss, θ)
end

## w/o precompile statements
# 12.661393 seconds (24.57 M allocations: 1.553 GiB, 5.08% gc time)
# 1.160352 seconds (3.17 M allocations: 203.534 MiB, 9.06% gc time)
# 0.255233 seconds (920.16 k allocations: 58.944 MiB)
# 0.490022 seconds (1.40 M allocations: 90.359 MiB, 7.70% gc time)
# 0.253476 seconds (912.18 k allocations: 58.553 MiB)
# 0.401538 seconds (1.11 M allocations: 72.077 MiB, 11.36% gc time)

## w/ precompile statements
# 0.000291 seconds (198 allocations: 19.953 KiB)
# 0.009303 seconds (4.42 k allocations: 368.150 KiB)
# 0.000117 seconds (1.33 k allocations: 166.328 KiB)
# 0.000047 seconds (426 allocations: 59.594 KiB)
# 0.000044 seconds (426 allocations: 59.906 KiB)
# 0.000079 seconds (1.32 k allocations: 146.625 KiB)
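For comparison with the workload approach timed above, Base also provides precompile(f, argtypes). It asks the compiler to infer and compile one specific signature without running anything. The following sketch uses the sekernel definitions from the MWE, repeated so that the block runs on its own.

```julia
using LinearAlgebra: dot

# Kernel definitions copied from the MWE above.
sqdist(x::Real, y::Real) = (x - y)^2
sqdist(x, y) = dot(x, x) + dot(y, y) - 2dot(x, y)
sekernel(v, l, x, y) = v * exp(-sqdist(x, y) / (2l^2))

# Ask Julia to infer and compile these exact signatures without executing
# them; precompile returns true if the signature could be compiled.
ok_scalar = precompile(sekernel, (Float64, Float64, Float64, Float64))
ok_vector = precompile(sekernel, (Float64, Float64, Vector{Float64}, Vector{Float64}))

println((ok_scalar, ok_vector))
```

Writing directives like these by hand is impractical for the many specialized pullback types Zygote generates (compare the typeof(∂(...)) frames in the stack trace above). That is one reason why running a representative workload inside @precompile_all_calls tends to be the more practical route here.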

It shouldn’t be too hard either. We already have an almost complete implementation in FluxML/Zygote.jl PR #923 (“Changing Distances adjoints to ChainRules syntax” by theogf), so the biggest change would be removing ZygoteRuleConfig from each method. That isn’t technically difficult, but it may run into jurisdictional issues.
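For readers unfamiliar with the ChainRules syntax mentioned above, a rule is a method of ChainRulesCore.rrule(f, args...). It returns the primal result together with a pullback closure. The pullback maps the output cotangent to cotangents for f and for each argument; for a plain function, the one for f is usually NoTangent(). Config-aware rules also take a RuleConfig, here ZygoteRuleConfig, as an extra first argument, and that argument is the part the post proposes to drop.

The sketch below encodes the same contract for pairwise squared Euclidean distances between columns. It is written in plain Julia without ChainRulesCore so that it runs on its own. The function names are hypothetical, and this is neither the code in PR #923 nor the Distances.jl API.

```julia
# Hypothetical helper (not the Distances.jl API): squared Euclidean distances
# between every column of X (d×m) and every column of Y (d×n), as an m×n matrix.
function pairwise_sqeuclidean(X::AbstractMatrix, Y::AbstractMatrix)
    return [sum(abs2, view(X, :, i) .- view(Y, :, j)) for i in axes(X, 2), j in axes(Y, 2)]
end

# Same contract as an rrule: return the primal value and a pullback that maps
# the output cotangent Δ (m×n) to cotangents for X and Y. A real
# ChainRulesCore.rrule would return (NoTangent(), ∂X, ∂Y) from the pullback.
function pairwise_sqeuclidean_with_pullback(X::AbstractMatrix, Y::AbstractMatrix)
    D = pairwise_sqeuclidean(X, Y)
    function pullback(Δ::AbstractMatrix)
        # ∂D[i,j]/∂X[:,i] = 2(X[:,i] - Y[:,j]); summed over j, weighted by Δ[i,j].
        ∂X = 2 .* (X .* sum(Δ; dims=2)' .- Y * Δ')
        # ∂D[i,j]/∂Y[:,j] = -2(X[:,i] - Y[:,j]); summed over i, weighted by Δ[i,j].
        ∂Y = 2 .* (Y .* sum(Δ; dims=1) .- X * Δ)
        return ∂X, ∂Y
    end
    return D, pullback
end

X = [1.0 2.0; 0.5 -1.0]          # d = 2, m = 2
Y = [0.0 1.0 3.0; 2.0 -0.5 1.0]  # d = 2, n = 3
D, back = pairwise_sqeuclidean_with_pullback(X, Y)
∂X, ∂Y = back(ones(size(D)))     # cotangent of sum(D)
```

A central finite difference on sum(D) agrees with the analytic ∂X. ChainRulesTestUtils (test_rrule) automates this kind of check for real rules.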

Maybe I should check whether the original attempt at precompilation works on that branch of Zygote.

It won’t, since the code there is still gated behind Requires. That code needs to be moved into Distances.jl as a package extension, hence “jurisdictional issues” about who would move it, who would approve it, whether anyone would approve it (I hope so, but we’ve been here before), etc.

I’m happy to make a PR, but I’m not familiar enough with this to know what to do. Why would there be any opposition?