Yes, I agree that Zygote is not directly exposed to the user. However, the situation seems to be more complicated than that. I am referring to the code you pointed me to at https://docs.sciml.ai/Overview/dev/showcase/missing_physics/
You include the following modules:
# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL
# Standard Libraries
using LinearAlgebra, Statistics, Random
# External Libraries
using ComponentArrays, Lux, Plots
Lower in the code, there is the line:
# First train with ADAM for better convergence -> move the parameters into a
# favourable starting positing for BFGS
adtype = Optimization.AutoZygote()
What happens when Optimization.AutoZygote() is called? One gets an error (at least I get one now, and I understand it). AutoZygote requires that Optimization.jl have access to Zygote.jl. I say this because I forked Optimization.jl, searched for all references to AutoZygote, and found the following:
➜ Optimization.jl git:(master) findh '*.jl' Zygote
2:using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
109:optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
110:optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(), nothing)
./test/ADtests.jl
58: Optimization.AutoZygote())
./test/minibatch.jl
2:using FiniteDiff, ForwardDiff, ModelingToolkit, ReverseDiff, Tracker, Zygote
12: ForwardDiff, ModelingToolkit, ReverseDiff, Tracker, Zygote],
./docs/make.jl
1:using OptimizationOptimJL, OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote,
75: optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
86: optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), Optimization.AutoZygote())
99: optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), Optimization.AutoZygote(),
./lib/OptimizationOptimJL/test/runtests.jl
1:using OptimizationNLopt, Optimization, Zygote
10: optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), Optimization.AutoZygote())
15: optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
./lib/OptimizationNLopt/test/runtests.jl
1:using OptimizationMOI, Optimization, Ipopt, NLopt, Zygote, ModelingToolkit
37: optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), Optimization.AutoZygote())
43: optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
./lib/OptimizationMOI/test/runtests.jl
1:using OptimizationNonconvex, Optimization, Zygote, Pkg
9: optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
./lib/OptimizationNonconvex/test/runtests.jl
3:using Zygote
27: optprob = OptimizationFunction(sumfunc, Optimization.AutoZygote())
./lib/OptimizationOptimisers/test/runtests.jl
2:AutoZygote <: AbstractADType
8:OptimizationFunction(f,AutoZygote();kwargs...)
11:This uses the [Zygote.jl](https://github.com/FluxML/Zygote.jl) package.
14:forward-over-reverse mixing ForwardDiff.jl with Zygote.jl
23:Hessian is not defined via Zygote.
25:struct AutoZygote <: AbstractADType end
27:function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
28: num_cons != 0 && error("AutoZygote does not currently support constraints")
34: res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
42: Zygote.gradient(x -> _f(x, args...), θ)[1]
./src/function/zygote.jl
27: @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
./src/Optimization.jl
The last line,
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
is the key. While the user is not directly exposed to Zygote.jl, Optimization.jl is, in its __init__() function. Zygote.jl cannot be loaded unless Zygote is listed in either my Project.toml or the Project.toml of Optimization.jl. Looking at the Optimization module, we find that its Project.toml does not include Zygote.jl:
name = "Optimization"
uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
version = "3.10.0"
[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
[compat]
ArrayInterface = "6"
ArrayInterfaceCore = "0.1.1"
ConsoleProgressMonitor = "0.1"
DocStringExtensions = "0.8, 0.9"
LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
Requires = "1.0"
SciMLBase = "1.79.0"
TerminalLoggers = "0.1"
julia = "1.6"
[extras]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Therefore, my question is: why should the code you sent me work unless I specifically add Zygote.jl to my own Project.toml to address the omission in Optimization.jl?
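To make the question concrete, here is a minimal sketch of how one could check whether Zygote can be resolved from the active environment before relying on the @require hook. It uses only Base.identify_package, which looks a package name up in the current environment stack. The helper name package_available is my own (hypothetical) and is not part of Optimization.jl.

```julia
# Hypothetical helper: returns true if `name` can be resolved from the
# active environment stack (LOAD_PATH), false otherwise.
# Base.identify_package returns a Base.PkgId on success and `nothing` on failure.
package_available(name::AbstractString) =
    Base.identify_package(String(name)) !== nothing

if package_available("Zygote")
    # Assumption: with the @require line quoted above, loading Zygote is what
    # triggers the include of function/zygote.jl.
    println("Zygote is resolvable; `using Zygote` should trigger the @require hook.")
else
    println("Zygote is not in the active environment; try `import Pkg; Pkg.add(\"Zygote\")`.")
end
```

If this reports that Zygote is missing, then, as far as I can tell, the workaround is to add Zygote to my own environment and run using Zygote before calling Optimization.AutoZygote().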
I'd be interested to know if I made an error in reasoning. Thanks!