Accelerate autodiff in Zygote.jl

Hi there,
I’d like to know if there are any performance tips for using Zygote.jl.

I am computing a Hessian matrix with autodiff, and the autodiff part now runs quite slowly, so I’d like to speed it up.
For example:

@time loss = myfunc(x, other_args)
  4.369689 seconds (7.95 M allocations: 1.199 GiB, 5.27% gc time)

@time ∂l∂x, _ = gradient(myfunc, x, other_args)
 10.895003 seconds (26.71 M allocations: 2.392 GiB, 5.95% gc time)

@time ∂2l∂2x = Zygote.hessian(f, x)
480.086188 seconds (285.59 M allocations: 201.816 GiB, 4.54% gc time)
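For anyone trying to reproduce these numbers, here is a minimal, self-contained sketch of the same three measurements. The thread does not show `myfunc` or `other_args`, so this assumes a hypothetical quadratic loss and uses a matrix `A` as a stand-in for `other_args`; it also assumes `f` is the closure `x -> myfunc(x, other_args)`. Each call is run once before timing, because the first `@time` of a freshly compiled method also counts JIT compilation.

```julia
using Zygote

# Hypothetical stand-ins: the real `myfunc` and `other_args` are not shown in the thread.
myfunc(x, A) = sum(abs2, A * x) / 2   # loss l(x) = ½‖A x‖², Hessian is A'A
const A = randn(50, 50)               # plays the role of `other_args`; `const` avoids untyped-global overhead
x = randn(50)

# Assumed closure, so the Hessian is taken with respect to `x` only.
f = x -> myfunc(x, A)

# Warm-up: the first call of each method includes compilation time.
loss = myfunc(x, A)
∂l∂x, _ = gradient(myfunc, x, A)
∂2l∂2x = Zygote.hessian(f, x)

# Timed calls, now measuring steady-state cost only.
@time loss = myfunc(x, A)
@time ∂l∂x, _ = gradient(myfunc, x, A)
@time ∂2l∂2x = Zygote.hessian(f, x)
```

With a known test function like this one, the results can also be checked against the closed forms (gradient A'A x, Hessian A'A) before tuning anything.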

Is there a good way to optimize the performance of autodiff (e.g. parallelization, exploiting the sparsity of the matrix, or reducing memory I/O), and if so, how?
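On the memory point: if the downstream code never needs the full matrix, one common way to cut both time and memory is to compute Hessian-vector products H(x)·v instead of H(x). The sketch below does this forward-over-reverse, pushing a single ForwardDiff dual number along direction `v` through Zygote's gradient, so the n×n Hessian is never formed. This technique is not proposed in the thread; the helper name `hvp`, the test function, and `A` are all hypothetical.

```julia
using Zygote, ForwardDiff

# Hypothetical test problem: l(x) = ½‖A x‖², whose Hessian is A'A.
const A = randn(50, 50)
f(x) = sum(abs2, A * x) / 2

# Hypothetical helper: Hessian-vector product H(x) * v via forward-over-reverse.
function hvp(f, x, v)
    g(t) = Zygote.gradient(f, x .+ t .* v)[1]          # ∇f(x + t v), reverse mode
    return ForwardDiff.derivative(g, zero(eltype(x)))  # d/dt ∇f(x + t v) at t = 0, forward mode
end

x = randn(50)
v = randn(50)
Hv = hvp(f, x, v)   # one length-50 vector; no 50×50 matrix is allocated
```

Each product costs roughly a small constant multiple of one gradient, whereas building the full Hessian the same way needs one such pass per input dimension (up to chunking). So this only pays off when the algorithm can work with products, for example an iterative (conjugate-gradient) Newton step.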

Thx for any reply!

Zygote’s `hessian` is just bad. Try Diffractor and see if it works on your problem (though it is still at a very early stage).

Thanks, I’ll check it out!

Hi, I installed Diffractor with `] add https://github.com/JuliaDiff/Diffractor.jl.git`, but I got a precompilation error:

julia> using Diffractor 
[ Info: Precompiling Diffractor [9f5e2b26-1114-432f-b630-d3fe2085c51c]
ERROR: LoadError: LoadError: LoadError: UndefVarError: @aggressive_constprop not defined
Stacktrace:
 [1] include(::Function, ::Module, ::String) at ./Base.jl:380
 [2] include at ./Base.jl:368 [inlined]
 [3] include(::String) at /root/.julia/packages/Diffractor/2Ott3/src/Diffractor.jl:1
 [4] top-level scope at /root/.julia/packages/Diffractor/2Ott3/src/Diffractor.jl:7
 [5] include(::Function, ::Module, ::String) at ./Base.jl:380
 [6] include(::Module, ::String) at ./Base.jl:368
 [7] top-level scope at none:2
 [8] eval at ./boot.jl:331 [inlined]
 [9] eval(::Expr) at ./client.jl:467
 [10] top-level scope at ./none:3
in expression starting at /root/.julia/packages/Diffractor/2Ott3/src/runtime.jl:3
in expression starting at /root/.julia/packages/Diffractor/2Ott3/src/runtime.jl:3
in expression starting at /root/.julia/packages/Diffractor/2Ott3/src/Diffractor.jl:7
ERROR: Failed to precompile Diffractor [9f5e2b26-1114-432f-b630-d3fe2085c51c] to /root/.julia/compiled/v1.5/Diffractor/vzwwW_E53n4.ji.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] compilecache(::Base.PkgId, ::String) at ./loading.jl:1305
 [3] _require(::Base.PkgId) at ./loading.jl:1030
 [4] require(::Base.PkgId) at ./loading.jl:928
 [5] require(::Module, ::Symbol) at ./loading.jl:923
 [6] eval at ./boot.jl:331 [inlined]
 [7] eval at ./Base.jl:39 [inlined]
 [8] repleval(::Module, ::Expr, ::String) at /root/.vscode-server-insiders/extensions/julialang.language-julia-1.3.30/scripts/packages/VSCodeServer/src/repl.jl:157
 [9] (::VSCodeServer.var"#69#71"{Module,Expr,REPL.LineEditREPL,REPL.LineEdit.Prompt})() at /root/.vscode-server-insiders/extensions/julialang.language-julia-1.3.30/scripts/packages/VSCodeServer/src/repl.jl:123
 [10] with_logstate(::Function, ::Any) at ./logging.jl:408
 [11] with_logger at ./logging.jl:514 [inlined]
 [12] (::VSCodeServer.var"#68#70"{Module,Expr,REPL.LineEditREPL,REPL.LineEdit.Prompt})() at /root/.vscode-server-insiders/extensions/julialang.language-julia-1.3.30/scripts/packages/VSCodeServer/src/repl.jl:124
 [13] #invokelatest#1 at ./essentials.jl:710 [inlined]
 [14] invokelatest(::Any) at ./essentials.jl:709
 [15] macro expansion at /root/.vscode-server-insiders/extensions/julialang.language-julia-1.3.30/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
 [16] (::VSCodeServer.var"#53#54")() at ./task.jl:356

It requires Julia v1.7.

Thx! Do you mean the Julia v1.7 beta?

Yes

Got it! 😀