I’m trying to minimize a function that involves an integral, the integrand of which is one of the optimizing parameters. I can compute the function and get the gradient with Zygote, but I’m getting an odd conversion error when I try to run an optimization routine. Might anyone know what is going on?
using Distributions
using Random
using ForwardDiff
using Zygote
using QuadGK
using Optim
theta_lower = 1.0
theta_upper = 5.0
uniform_dist = Uniform(theta_lower, theta_upper)
function linearCost(θ)
out = 2.5 - 0.3*θ
return(out)
end
function getAlpha(θ₁, θ₂, B, costFn, Fdist)
intcdF1 = quadgk(θ -> (linearCost(θ)*pdf(Fdist,θ)), θ₁, θ₂, rtol=1e-8)[1]
intcdF2 = quadgk(θ -> (linearCost(θ)*pdf(Fdist,θ)), θ₂, theta_upper, rtol=1e-8)[1]
alpha_num = B + θ₂*(1.0-cdf(Fdist,θ₂)) - intcdF2
alpha_denom = intcdF1 + θ₂*(1.0-cdf(Fdist,θ₂)) - θ₁*(1-cdf(Fdist,θ₁))
α = alpha_num / alpha_denom
return(α)
end
function socialPlannerObjective(x, params)
θ₁, Δθ = x
B, costFn, Fdist = params
θ₂ = θ₁ + Δθ
intdiffcdF1= quadgk(θ -> ((θ-linearCost(θ))*pdf(Fdist,θ)), θ₁, θ₂, rtol=1e-8)[1]
intdiffcdF2 = quadgk(θ -> ((θ-linearCost(θ))*pdf(Fdist,θ)), θ₂, theta_upper, rtol=1e-8)[1]
alpha = getAlpha(θ₁, θ₂, B, costFn, Fdist)
@show out = alpha*intdiffcdF1 + intdiffcdF2
return(out)
end
gradsocialPlannerObjective(y) = Zygote.gradient(x -> socialPlannerObjective(x,[10, linearCost, uniform_dist]), y)[1]
## test
@show socialPlannerObjective([1.3, 0.3], [10, linearCost, uniform_dist])
@show gradsocialPlannerObjective([1.3, 0.6])
optimize(x -> socialPlannerObjective(x,[10, linearCost, uniform_dist]), x->gradsocialPlannerObjective(x), [1.3, 0.6], LBFGS(); inplace = false)
Output being received:
out = alpha * intdiffcdF1 + intdiffcdF2 = 0.03361344537815181
socialPlannerObjective([1.3, 0.3], [10, linearCost, uniform_dist]) = 0.03361344537815181
out = alpha * intdiffcdF1 + intdiffcdF2 = 0.40183246073298395
gradsocialPlannerObjective([1.3, 0.6]) = [2.7550300156245733, 1.3382479921054806]
out = alpha * intdiffcdF1 + intdiffcdF2 = 0.40183246073298395
out = alpha * intdiffcdF1 + intdiffcdF2 = 0.40183246073298395
MethodError: Cannot `convert` an object of type Array{QuadGK.Segment{Float64,Float64,Float64},1} to an object of type QuadGK.Segment{Float64,Float64,Float64}
Closest candidates are:
convert(::Type{T}, !Matched::QuadGK.Segment) where T<:QuadGK.Segment at /Users/svass/.julia/packages/QuadGK/jmDk8/src/evalrule.jl:10
convert(::Type{T}, !Matched::T) where T at essentials.jl:168
QuadGK.Segment{Float64,Float64,Float64}(::Any, !Matched::Any, !Matched::Any, !Matched::Any) where {TX, TI, TE} at /Users/svass/.julia/packages/QuadGK/jmDk8/src/evalrule.jl:3
Stacktrace:
[1] setindex!(::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Int64) at ./array.jl:782
[2] adjoint at /Users/svass/.julia/packages/Zygote/YeCEW/src/lib/array.jl:60 [inlined]
[3] _pullback at /Users/svass/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[4] percolate_down! at /Users/svass/.julia/packages/DataStructures/GvsTk/src/heaps/arrays_as_heaps.jl:28 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(DataStructures.percolate_down!), ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Int64, ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Base.Order.ReverseOrdering{Base.Order.ForwardOrdering}, ::Int64) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[6] percolate_down! at /Users/svass/.julia/packages/DataStructures/GvsTk/src/heaps/arrays_as_heaps.jl:18 [inlined]
[7] _pullback(::Zygote.Context, ::typeof(DataStructures.percolate_down!), ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Int64, ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Base.Order.ReverseOrdering{Base.Order.ForwardOrdering}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[8] heappop! at /Users/svass/.julia/packages/DataStructures/GvsTk/src/heaps/arrays_as_heaps.jl:59 [inlined]
[9] adapt at /Users/svass/.julia/packages/QuadGK/jmDk8/src/adapt.jl:36 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(QuadGK.adapt), ::var"#64#66"{Uniform{Float64}}, ::Array{QuadGK.Segment{Float64,Float64,Float64},1}, ::Float64, ::Float64, ::Int64, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Int64, ::Float64, ::Float64, ::Int64, ::typeof(LinearAlgebra.norm)) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[11] do_quadgk at /Users/svass/.julia/packages/QuadGK/jmDk8/src/adapt.jl:28 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(QuadGK.do_quadgk), ::var"#64#66"{Uniform{Float64}}, ::Tuple{Float64,Float64}, ::Int64, ::Nothing, ::Float64, ::Int64, ::typeof(LinearAlgebra.norm)) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[13] #28 at /Users/svass/.julia/packages/QuadGK/jmDk8/src/adapt.jl:159 [inlined]
[14] handle_infinities at /Users/svass/.julia/packages/QuadGK/jmDk8/src/adapt.jl:93 [inlined]
[15] _pullback(::Zygote.Context, ::typeof(QuadGK.handle_infinities), ::QuadGK.var"#28#29"{Nothing,Float64,Int64,Int64,typeof(LinearAlgebra.norm)}, ::var"#64#66"{Uniform{Float64}}, ::Tuple{Float64,Float64}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[16] #quadgk#27 at /Users/svass/.julia/packages/QuadGK/jmDk8/src/adapt.jl:157 [inlined]
[17] adjoint at /Users/svass/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:168 [inlined]
[18] _pullback at /Users/svass/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[19] #quadgk at ./none:0 [inlined]
[20] _pullback(::Zygote.Context, ::QuadGK.var"#kw##quadgk", ::NamedTuple{(:rtol,),Tuple{Float64}}, ::typeof(quadgk), ::var"#64#66"{Uniform{Float64}}, ::Float64, ::Float64) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[21] socialPlannerObjective at ./In[8]:42 [inlined]
[22] _pullback(::Zygote.Context, ::typeof(socialPlannerObjective), ::Array{Float64,1}, ::Array{Any,1}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[23] #67 at ./In[8]:52 [inlined]
[24] _pullback(::Zygote.Context, ::var"#67#68", ::Array{Float64,1}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[25] _pullback(::Function, ::Array{Float64,1}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:39
[26] pullback(::Function, ::Array{Float64,1}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:45
[27] gradient(::Function, ::Array{Float64,1}) at /Users/svass/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:54
[28] gradsocialPlannerObjective(::Array{Float64,1}) at ./In[8]:52
[29] #70 at ./In[8]:58 [inlined]
[30] (::NLSolversBase.var"#gg!#2"{var"#70#72"})(::Array{Float64,1}, ::Array{Float64,1}) at /Users/svass/.julia/packages/NLSolversBase/mGaJg/src/objective_types/inplace_factory.jl:21
[31] (::NLSolversBase.var"#fg!#8"{var"#69#71",NLSolversBase.var"#gg!#2"{var"#70#72"}})(::Array{Float64,1}, ::Array{Float64,1}) at /Users/svass/.julia/packages/NLSolversBase/mGaJg/src/objective_types/abstract.jl:13
[32] value_gradient!!(::OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}) at /Users/svass/.julia/packages/NLSolversBase/mGaJg/src/interface.jl:82
[33] value_gradient!(::OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}) at /Users/svass/.julia/packages/NLSolversBase/mGaJg/src/interface.jl:69
[34] value_gradient!(::Optim.ManifoldObjective{OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}}, ::Array{Float64,1}) at /Users/svass/.julia/packages/Optim/L5T76/src/Manifolds.jl:50
[35] (::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}},Array{Float64,1},Array{Float64,1},Array{Float64,1}})(::Float64) at /Users/svass/.julia/packages/LineSearches/WrsMD/src/LineSearches.jl:84
[36] (::HagerZhang{Float64,Base.RefValue{Bool}})(::Function, ::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}},Array{Float64,1},Array{Float64,1},Array{Float64,1}}, ::Float64, ::Float64, ::Float64) at /Users/svass/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:140
[37] HagerZhang at /Users/svass/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:101 [inlined]
[38] perform_linesearch!(::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}, ::LBFGS{Nothing,InitialStatic{Float64},HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.ManifoldObjective{OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}}) at /Users/svass/.julia/packages/Optim/L5T76/src/utilities/perform_linesearch.jl:56
[39] update_state!(::OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}, ::LBFGS{Nothing,InitialStatic{Float64},HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}) at /Users/svass/.julia/packages/Optim/L5T76/src/multivariate/solvers/first_order/l_bfgs.jl:198
[40] optimize(::OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::LBFGS{Nothing,InitialStatic{Float64},HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,Nothing}, ::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}) at /Users/svass/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:57
[41] optimize(::OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::LBFGS{Nothing,InitialStatic{Float64},HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,Nothing}) at /Users/svass/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:33
[42] #optimize#95 at /Users/svass/.julia/packages/Optim/L5T76/src/multivariate/optimize/interface.jl:129 [inlined]
[43] (::Optim.var"#kw##optimize")(::NamedTuple{(:inplace,),Tuple{Bool}}, ::typeof(optimize), ::Function, ::Function, ::Array{Float64,1}, ::LBFGS{Nothing,InitialStatic{Float64},HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,Nothing}) at ./none:0 (repeats 2 times)
[44] top-level scope at In[8]:57