Dear all,
At the moment, I have the following set-up:
Opts = Optim.Options(
    x_tol = 0.0,
    f_tol = 0.0,
    g_abstol = 0.0,
    g_reltol = 0.0,
    iterations = 5000,
    store_trace = true,
    show_trace = true,
    show_every = 100
)
tape = ReverseDiff.GradientTape(SSM, get_initializer())
g!(G, x) = ReverseDiff.gradient!(G, tape, x)
opti = optimize(SSM, g!, get_initializer(), LBFGS(linesearch=LineSearches.BackTracking()), Opts)
I noticed that ReverseDiff.GradientTape is sensitive to starting values: for certain starting values it throws a DomainError. Generating the tape works now and the optimization runs, but it occasionally breaks with another DomainError. The error is raised in the forward pass, while the optimizer is updating the gradient. Here is the output plus part of the error:
Optimizing via reverse auto-differentiation, LBFGS ...
Iter     Function value   Gradient norm
     0    -2.486828e+02     1.468568e+00
 * time: 0.021384000778198242
   100    -2.497133e+02     1.133929e+00
 * time: 6.159498929977417
   200    -2.605060e+02     2.121255e+00
 * time: 12.05714201927185
ERROR: LoadError: DomainError with -0.04907639254690035:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).
Stacktrace:
[1] throw_complex_domainerror(f::Symbol, x::Float64)
@ Base.Math ./math.jl:33
[2] sqrt
@ ./math.jl:567 [inlined]
[3] sqrt
@ ~/.julia/packages/ForwardDiff/pDtsf/src/dual.jl:240 [inlined]
[4] derivative!
@ ~/.julia/packages/ForwardDiff/pDtsf/src/derivative.jl:46 [inlined]
[5] unary_scalar_forward_exec!(f::typeof(sqrt), output::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, input::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, cache::Base.RefValue{Float64})
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/derivatives/scalars.jl:96
[6] scalar_forward_exec!(instruction::ReverseDiff.ScalarInstruction{typeof(sqrt), ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, Base.RefValue{Float64}})
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/derivatives/scalars.jl:86
[7] forward_exec!(instruction::ReverseDiff.ScalarInstruction{typeof(sqrt), ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, Base.RefValue{Float64}})
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/tape.jl:82
[8] forward_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/tape.jl:77
[9] forward_pass!
@ ~/.julia/packages/ReverseDiff/YkVxM/src/api/tape.jl:34 [inlined]
[10] seeded_forward_pass!
@ ~/.julia/packages/ReverseDiff/YkVxM/src/api/tape.jl:42 [inlined]
[11] gradient!
@ ~/.julia/packages/ReverseDiff/YkVxM/src/api/gradients.jl:79 [inlined]
[12] (::var"#g!#156"{ReverseDiff.GradientTape{var"#SSM#155"{StateSpaceModel{Vector{Float64}, Matrix{Float64}, Vector{Matrix{Float64}}, Int64}, MethodOptions{Bool, String, Int64, Float64}, Vector{Float64}, Vector{Distribution{Multivariate, Continuous}}, Vector{Any}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}})(G::Vector{Float64}, x::Vector{Float64})
@ Main ~/MCMC.jl:113
[13] gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
@ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:63
[14] value_gradient!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
@ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:73
[15] update_g!(d::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, state::Optim.LBFGSState{Vector{Float64}, Vector{Vector{Float64}}, Vector{Vector{Float64}}, Float64, Vector{Float64}}, method::LBFGS{Nothing, InitialStatic{Float64}, BackTracking{Float64, Int64}, Optim.var"#20#22"})
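In case it helps, I believe the same kind of failure can be reproduced with a much smaller example. This is only a sketch: `toy` is a hypothetical stand-in objective, not my actual SSM, and the evaluation points are made up. The tape is recorded at a point where the argument of sqrt is positive, and then replayed at a point where it is negative:

```julia
using ReverseDiff

# Hypothetical toy objective (NOT the real SSM): sqrt of the first
# parameter plus the square of the second.
toy(x) = sqrt(x[1]) + x[2] * x[2]

x0 = [1.0, 2.0]                            # sqrt argument is positive here
tape = ReverseDiff.GradientTape(toy, x0)   # recording the tape succeeds

G = similar(x0)
ReverseDiff.gradient!(G, tape, x0)         # fine at the recorded point
G_ok = copy(G)

# Replaying the tape where x[1] < 0 re-runs sqrt on a negative Float64
# during the forward pass, which throws a DomainError.
threw_domain_error = try
    ReverseDiff.gradient!(G, tape, [-0.05, 2.0])
    false
catch err
    err isa DomainError
end
```

If this is the right picture, then the optimizer (or the line search) is simply proposing a point at which my objective is not defined, and the tape fails when it replays sqrt there.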
Is there any way to circumvent this issue? How should I think about this? Also, does anyone have any tips on making code “ReverseDiff friendly”?
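One workaround I am considering, as a sketch only, is to reparameterize so that the quantity under the square root cannot become negative, for example by optimizing over its logarithm. Everything below is hypothetical (`toy_unconstrained` and the values are for illustration, not taken from my model):

```julia
using ReverseDiff

# Hypothetical toy objective: θ[1] is an unconstrained log-variance,
# so v = exp(θ[1]) is positive for every real θ[1].
function toy_unconstrained(θ)
    v = exp(θ[1])
    return sqrt(v) + θ[2] * θ[2]
end

θ0 = [0.0, 2.0]
tape_u = ReverseDiff.GradientTape(toy_unconstrained, θ0)

G_u = similar(θ0)
# Any real value of θ[1] is now admissible, including negative ones.
ReverseDiff.gradient!(G_u, tape_u, [-3.0, 2.0])
```

I am not sure whether this is the recommended approach. I also understand that a prerecorded tape only replays the operations recorded at the initial point, so any value-dependent branching inside the objective would not be re-evaluated.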