Say I have a function f: \mathbb{R}^m \times \mathbb{R} \to \mathbb{R}^n, i.e. f(\theta, x) is a vector for a given set of parameters \theta and a scalar x. I would now like to calculate the gradient \nabla_\theta L(\theta, x) of a loss function of the form L(\theta, x) = g\left(\frac{\partial f(\theta, x)}{\partial x}\right) \in \mathbb{R}, where I need to be able to access the intermediate result \frac{\partial f(\theta, x)}{\partial x} \in \mathbb{R}^n in order to calculate L(\theta, x).
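For concreteness, if we take g = \mathrm{sum} (just one possible choice of g), the gradient I'm after is the mixed second derivative

\frac{\partial L(\theta, x)}{\partial \theta_j} = \frac{\partial}{\partial \theta_j} \sum_{i=1}^n \frac{\partial f_i(\theta, x)}{\partial x} = \sum_{i=1}^n \frac{\partial^2 f_i(\theta, x)}{\partial \theta_j \, \partial x},

i.e. a reverse-mode pass over \theta applied to a forward-mode derivative in x.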
Is it possible to accomplish this within Flux? Here’s a minimal not working example (MNWE?) of what I’d like to do:
using Flux, ForwardDiff
f = Chain(x -> fill(x, 3), Dense(3, 3, softplus)) # maps a scalar x to a vector in R^3
df(x) = ForwardDiff.derivative(f, x) # forward-mode derivative ∂f/∂x
x = rand()
f(x) #Works
df(x) #Works
g = gradient(() -> sum(df(x)), params(f)) #Fails
which fails with the following stack trace:
ERROR: setindex! not defined for ForwardDiff.Partials{1,Float64}
Stacktrace:
[1] error(::String, ::Type) at ./error.jl:42
[2] error_if_canonical_setindex(::IndexLinear, ::ForwardDiff.Partials{1,Float64}, ::Int64) at ./abstractarray.jl:1082
[3] setindex! at ./abstractarray.jl:1073 [inlined]
[4] (::getfield(Zygote, Symbol("##976#978")){ForwardDiff.Partials{1,Float64},Tuple{Int64}})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:32
[5] (::getfield(Zygote, Symbol("##2585#back#980")){getfield(Zygote, Symbol("##976#978")){ForwardDiff.Partials{1,Float64},Tuple{Int64}}})(::Float64) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[6] partials at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/dual.jl:96 [inlined]
[7] (::getfield(Zygote, Symbol("##153#154")){typeof(∂(partials)),Tuple{Tuple{Nothing},Tuple{Nothing}}})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142
[8] (::getfield(Zygote, Symbol("##283#back#155")){getfield(Zygote, Symbol("##153#154")){typeof(∂(partials)),Tuple{Tuple{Nothing},Tuple{Nothing}}}})(::Float64) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[9] extract_derivative at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/dual.jl:101 [inlined]
[10] #72 at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/derivative.jl:81 [inlined]
[11] (::typeof(∂(#72)))(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[12] (::getfield(Zygote, Symbol("##1104#1108")))(::typeof(∂(#72)), ::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:134
[13] (::getfield(Base, Symbol("##3#4")){getfield(Zygote, Symbol("##1104#1108"))})(::Tuple{typeof(∂(#72)),Float64}) at ./generator.jl:36
[14] iterate at ./generator.jl:47 [inlined]
[15] collect at ./array.jl:606 [inlined]
[16] map at ./abstractarray.jl:2155 [inlined]
[17] (::getfield(Zygote, Symbol("##1103#1107")){Array{typeof(∂(#72)),1}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/Zygote/8dVxG/src/lib/array.jl:134
[18] (::getfield(Zygote, Symbol("##2842#back#1109")){getfield(Zygote, Symbol("##1103#1107")){Array{typeof(∂(#72)),1}}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[19] derivative at /home/troels/.julia/packages/ForwardDiff/yPcDQ/src/derivative.jl:81 [inlined]
[20] df at ./REPL[3]:1 [inlined]
[21] (::typeof(∂(df)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[22] #5 at ./REPL[7]:1 [inlined]
[23] (::typeof(∂(#5)))(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[24] (::getfield(Zygote, Symbol("##38#39")){Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float64) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:101
[25] gradient(::Function, ::Zygote.Params) at /home/troels/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:47
[26] top-level scope at REPL[7]:1
I can, of course, approximate the desired result by using a finite difference approximation of some kind, like
Δf(x, δ = 0.001) = (f(x + δ) - f(x)) ./ δ
Δf(x) #Works
g = gradient(() -> sum(Δf(x)), params(f)) #Works
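A central-difference variant along the same lines should also be differentiable by Zygote, since it still only involves primal evaluations of f (just a sketch of the same workaround, using a hypothetical Δf_central analogous to Δf above):

Δf_central(x, δ = 0.001) = (f(x + δ) - f(x - δ)) ./ (2δ) # second-order accurate in δ
g = gradient(() -> sum(Δf_central(x)), params(f)) # should also work: only primal calls to f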
but it would be great if this weren’t necessary…
Ideas? Suggestions?
Some package information:
(v1.2) pkg> st Flux
Status `~/.julia/environments/v1.2/Project.toml`
[5ae59095] Colors v0.9.6
[3a865a2d] CuArrays v1.4.2
[587475ba] Flux v0.9.0 #master (https://github.com/FluxML/Flux.jl.git)
[e5e0dc1b] Juno v0.7.2
[872c559c] NNlib v0.6.0 #master (https://github.com/FluxML/NNlib.jl.git)
[189a3867] Reexport v0.2.0
[2913bbd2] StatsBase v0.32.0
[e88e6eb3] Zygote v0.4.1
(v1.2) pkg> st ForwardDiff
Status `~/.julia/environments/v1.2/Project.toml`
[163ba53b] DiffResults v0.0.4
[f6369f11] ForwardDiff v0.10.6
[276daf66] SpecialFunctions v0.7.2
[90137ffa] StaticArrays v0.11.1
Note that I’m on Flux#master.