While working on a nested optimization problem, I noticed that a lot of time was being spent in gradient calls, and a little digging showed that out-of-place gradients were introducing a type instability. I followed the advice here, but it does not appear to have helped; see the MWE below:
```julia
using ForwardDiff, LinearAlgebra
f(x) = dot(x, x)
x = randn(100)
gconfig = ForwardDiff.GradientConfig(f, x)
# all of these are type-unstable
g1(x) = ForwardDiff.gradient(f, x)
g2(x) = ForwardDiff.gradient(f, x, gconfig)
g3(x::T) where T = ForwardDiff.gradient(f, x)::T
g4(x::T) where T = ForwardDiff.gradient(f, x, gconfig)::T
@code_warntype g1(x)
@code_warntype g2(x)
@code_warntype g3(x)
@code_warntype g4(x)
```
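As a scriptable complement to @code_warntype, Test.@inferred throws an error when the return type inferred for a call differs from the type the call actually returns. Below is a minimal sketch; passes_inferred is a hypothetical helper, and f, x, g1 and g3 are redefined so the block runs on its own. Keep in mind that @inferred only checks the return type, so the ::T assertion in g3 lets it pass even when the body is still unstable internally; @code_warntype remains the tool for seeing where the instability is.

```julia
using ForwardDiff, LinearAlgebra, Test

f(x) = dot(x, x)
x = randn(100)

g1(x) = ForwardDiff.gradient(f, x)
g3(x::T) where T = ForwardDiff.gradient(f, x)::T

# Hypothetical helper: true if the inferred return type of h(x) matches the
# actual result type. @inferred signals a mismatch with an ErrorException.
function passes_inferred(h, x)
    try
        @inferred h(x)
        return true
    catch err
        err isa ErrorException || rethrow()
        return false
    end
end

println("g1: ", passes_inferred(g1, x))
# The ::T assertion pins the return type, so this passes regardless of the body.
println("g3: ", passes_inferred(g3, x))
```

What g1 reports may depend on the Julia and ForwardDiff versions, so treat its printed value as something to check rather than a fixed result.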
Using the in-place gradient is type-stable, as expected:
```julia
# these are good
g1!(G, x) = ForwardDiff.gradient!(G, f, x)
g2!(G, x) = ForwardDiff.gradient!(G, f, x, gconfig)
G = zero(x)
@code_warntype g1!(G, x)
@code_warntype g2!(G, x)
@time g1!(G, x)
@time g2!(G, x)
```
So I should probably just use that. Still, I’m curious what causes the instability. This is on Julia 1.6.2 with ForwardDiff v0.10.19. Thanks!
I don’t see how the compiler could infer CHK from the arguments f and x. And I don’t know if there is a syntactic way to specify CHK only and let the compiler infer T.
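For context on the point above: when no chunk size is passed, GradientConfig(f, x) picks one from length(x), which is only known at run time, so the chunk size (a type parameter of the config and of its Dual numbers) cannot be inferred. A minimal sketch of one workaround, assuming a fixed chunk size of 10 is acceptable for the problem, passes ForwardDiff.Chunk{10}() explicitly when building the config inside the function; grad_fixed_chunk is a hypothetical name.

```julia
using ForwardDiff, LinearAlgebra

f(x) = dot(x, x)

# Hypothetical helper: the chunk size is a literal type parameter (10 here),
# so the GradientConfig type, and therefore the gradient's return type,
# is known at compile time. The config is a local, not a global.
function grad_fixed_chunk(x::AbstractVector{<:Real})
    cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{10}())
    return ForwardDiff.gradient(f, x, cfg)
end

x = randn(100)
g = grad_fixed_chunk(x)
# @code_warntype grad_fixed_chunk(x)  # inspect interactively
```

The value 10 is an arbitrary choice for illustration: a larger chunk means fewer evaluations of f but more work per evaluation. Note that this version rebuilds the config on every call; hoisting it out (without making it a non-const global) avoids that allocation.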
Actually, those don’t work: the macro says they’re type-stable, but they don’t actually run. With g1 and x defined as above, for example:
```
julia> g1(x)
ERROR: TypeError: in Type{...} expression, expected UnionAll, got a value of type typeof(ForwardDiff.gradient)
 [1] g1(x::Vector{Float64})
   @ Main .\REPL[8]:1
 [2] top-level scope
   @ REPL[17]:1
```
No worries, I didn’t try to actually execute it at first either. The following seems to work, and the reason I thought it didn’t earlier seems to be that I had defined the GradientConfig as a (non-const) global variable. In the example below, g1 is not type-stable, but g2 is.
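```julia
using ForwardDiff, LinearAlgebra
x = randn(100)
f(x) = dot(x, x)
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{2}())
# g1 reads the non-const global cfg, whose type the compiler cannot assume
g1(x) = ForwardDiff.gradient(f, x, cfg)
# g2 captures cfg as a local binding, so its concrete type is known
g2 = let cfg = cfg
    x -> ForwardDiff.gradient(f, x, cfg)
end
@code_warntype g1(x)
@code_warntype g2(x)
```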
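The underlying issue is that a non-const global like cfg can be rebound to a value of any type, so the compiler cannot infer its type inside g1. Besides the let-captured closure, two other common fixes are sketched below, assuming the same f, x and chunk size; CFG, g_const and g_arg are illustrative names. A const binding fixes the type, and passing the config as an argument lets Julia specialize the method on the config's concrete type (a function barrier).

```julia
using ForwardDiff, LinearAlgebra

f(x) = dot(x, x)
x = randn(100)

# Option 1: a const global has a fixed type, so functions that read it infer.
# The config's work buffers are sized for inputs like x (length 100 here).
const CFG = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{2}())
g_const(x) = ForwardDiff.gradient(f, x, CFG)

# Option 2: pass the config explicitly; g_arg is compiled for its concrete type.
g_arg(x, cfg) = ForwardDiff.gradient(f, x, cfg)

# @code_warntype g_const(x)
# @code_warntype g_arg(x, CFG)
```

Passing the config as an argument is usually the more flexible of the two, since the caller can build it wherever the input is available (for example, inside the outer optimization routine) without touching global state.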
I’ll need to check my original, non-MWE code, but I suspect the issue there may also have been scoping-related.