Is it possible to perform reverse-mode differentiation (Flux.jl with Zygote.jl) of a forward-mode differentiation result (e.g. ForwardDiff)?

Say I have some function f: \mathbb{R}^m \times \mathbb{R} \to \mathbb{R}^n, i.e. f(\theta, x) is some vector for a given set of parameters \theta and a scalar x. I would now like to calculate the gradient \nabla_\theta L(\theta, x) of a loss function of the form L(\theta, x) = g\left(\frac{\partial f(\theta, x)}{\partial x} \right) \in \mathbb{R} (where I need to be able to access the intermediate result \frac{\partial f(\theta, x)}{\partial x} \in \mathbb{R}^n in order to calculate L(\theta, x)).

Is it possible to accomplish this within Flux? Here’s a minimal not working example (MNWE?) of what I’d like to do:

using Flux, ForwardDiff

f = Chain(x -> fill(x, 3), Dense(3, 3, softplus))
df(x) = ForwardDiff.derivative(f, x)

x = rand()
f(x) #Works
df(x) #Works
g = gradient(() -> sum(df(x)), params(f)) #Fails

which fails with the following stack trace:

ERROR: setindex! not defined for ForwardDiff.Partials{1,Float64}
Stacktrace:
 [1] error(::String, ::Type) at ./error.jl:42
 [2] error_if_canonical_setindex(::IndexLinear, ::ForwardDiff.Partials{1,Float64}, ::Int64) at ./abstractarray.jl:1082
 [3] setindex! at ./abstractarray.jl:1073 [inlined]
 [4] (::getfield(Zygote, Symbol("##976#978")){ForwardDiff.Partials{1,Float64},Tuple{Int64}})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:32
 [5] (::getfield(Zygote, Symbol("##2585#back#980")){getfield(Zygote, Symbol("##976#978")){ForwardDiff.Partials{1,Float64},Tuple{Int64}}})(::Float64) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [6] partials at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/dual.jl:96 [inlined]
 [7] (::getfield(Zygote, Symbol("##153#154")){typeof(∂(partials)),Tuple{Tuple{Nothing},Tuple{Nothing}}})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142
 [8] (::getfield(Zygote, Symbol("##283#back#155")){getfield(Zygote, Symbol("##153#154")){typeof(∂(partials)),Tuple{Tuple{Nothing},Tuple{Nothing}}}})(::Float64) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [9] extract_derivative at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/dual.jl:101 [inlined]
 [10] #72 at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/derivative.jl:81 [inlined]
 [11] (::typeof(∂(#72)))(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
 [12] (::getfield(Zygote, Symbol("##1104#1108")))(::typeof(∂(#72)), ::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:134
 [13] (::getfield(Base, Symbol("##3#4")){getfield(Zygote, Symbol("##1104#1108"))})(::Tuple{typeof(∂(#72)),Float64}) at ./generator.jl:36
 [14] iterate at ./generator.jl:47 [inlined]
 [15] collect at ./array.jl:606 [inlined]
 [16] map at ./abstractarray.jl:2155 [inlined]
 [17] (::getfield(Zygote, Symbol("##1103#1107")){Array{typeof(∂(#72)),1}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:134
 [18] (::getfield(Zygote, Symbol("##2842#back#1109")){getfield(Zygote, Symbol("##1103#1107")){Array{typeof(∂(#72)),1}}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [19] derivative at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/derivative.jl:81 [inlined]
 [20] df at ./REPL[3]:1 [inlined]
 [21] (::typeof(∂(df)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
 [22] #5 at ./REPL[7]:1 [inlined]
 [23] (::typeof(∂(#5)))(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
 [24] (::getfield(Zygote, Symbol("##38#39")){Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:101
 [25] gradient(::Function, ::Zygote.Params) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:47
 [26] top-level scope at REPL[7]:1

I can, of course, approximate the desired result by using a finite difference approximation of some kind, like

Δf(x, δ = 0.001) = (f(x + δ) - f(x)) ./ δ

Δf(x) #Works
g = gradient(() -> sum(Δf(x)), params(f)) #Works
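As a small aside, a central difference gives O(δ²) truncation error instead of O(δ) for the one-sided version above, at the cost of one extra evaluation of f. A minimal sketch, reusing the Chain from the MNWE above:

```julia
using Flux

f = Chain(x -> fill(x, 3), Dense(3, 3, softplus))
x = rand()

# Central finite difference: symmetric stencil, O(δ^2) error.
Δf_central(x, δ = 1e-4) = (f(x + δ) - f(x - δ)) ./ (2δ)

# Works for the same reason the one-sided version does:
# no Dual numbers ever reach Zygote's reverse pass.
g = gradient(() -> sum(Δf_central(x)), params(f))
```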

but it would be great if this wasn’t necessary…

Ideas? Suggestions?


Some package information:

(v1.2) pkg> st Flux
    Status `~/.julia/environments/v1.2/Project.toml`
  [5ae59095] Colors v0.9.6
  [3a865a2d] CuArrays v1.4.2
  [587475ba] Flux v0.9.0 #master (https://github.com/FluxML/Flux.jl.git)
  [e5e0dc1b] Juno v0.7.2
  [872c559c] NNlib v0.6.0 #master (https://github.com/FluxML/NNlib.jl.git)
  [189a3867] Reexport v0.2.0
  [2913bbd2] StatsBase v0.32.0
  [e88e6eb3] Zygote v0.4.1

(v1.2) pkg> st ForwardDiff
    Status `~/.julia/environments/v1.2/Project.toml`
  [163ba53b] DiffResults v0.0.4
  [f6369f11] ForwardDiff v0.10.6
  [276daf66] SpecialFunctions v0.7.2
  [90137ffa] StaticArrays v0.11.1

Note that I’m on Flux#master.

I am also interested in this question; it seems like it comes up frequently, but the answer is also in flux! Bad puns notwithstanding, are there any updates in regards to this working with Zygote?

There is Zygote.hessian that might work for your use case. Doing multiple AD passes is something that is currently being actively worked on, and it will probably become more straightforward once ChainRules.jl becomes more widely adopted by AD packages.
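To illustrate: Zygote.hessian nests the two AD passes for you (forward over reverse). A minimal sketch, assuming the usual hessian(f, x) signature for a scalar-valued f of a vector argument:

```julia
using Zygote, LinearAlgebra

# For L(v) = sum(v.^2) the gradient is 2v, so the Hessian is 2I.
H = Zygote.hessian(v -> sum(v .^ 2), [1.0, 2.0, 3.0])

H ≈ 2I(3)  # true
```

This covers the case where the outer function is scalar-valued; for the vector-valued f in the original question one still needs the per-component nested derivatives.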

DiffEqFlux.jl has a fix for this, and the sciml_train training loop handles a related bug in Zygote by manually de-dualizing the gradient.

So if you use DiffEqFlux.sciml_train you'll see that this magically works, and feel free to take the fix. More discussion here: https://github.com/FluxML/Zygote.jl/pull/510
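The "de-dualizing" idea can be sketched as follows. This is a hypothetical helper written for illustration (not the actual DiffEqFlux code): it recursively strips ForwardDiff.Dual wrappers from gradient entries, keeping only the primal values that an optimiser expects.

```julia
using ForwardDiff

# Hypothetical helper: recursively unwrap Dual numbers.
dedual(x::ForwardDiff.Dual) = dedual(ForwardDiff.value(x))  # handles nested Duals too
dedual(x::AbstractArray) = dedual.(x)
dedual(x) = x  # plain numbers (and anything else) pass through

dedual(ForwardDiff.Dual(1.5, 2.0))  # 1.5
```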

FWIW, this could be fixed by a Zygote forward mode, and one is in the works: https://github.com/FluxML/Zygote.jl/pull/503
