Passing a Cache to an optimization function

Hi!

In my original problem the computation of the objective function is expensive, and during that computation I can collaterally compute the constraints of my problem. To make things more efficient I tried to implement a version where I store the constraint info associated with the last evaluated solution in a Cache. When trying to use AutoEnzyme() in Optimization.jl this seems to break, because I'm forced to pass the Cache through the parameter vector p, and my understanding is that p is treated as a Const by Enzyme under the hood. If I explicitly call autodiff on a version of my objective function where the Cache is a separate argument, passing the Cache as DuplicatedNoNeed, it differentiates the function correctly (I sketch that direct call right after the MWE). To give a better sense of what I'm trying, here is my MWE:

using Optimization 
using OptimizationMOI 
using Ipopt 
using Enzyme 
using SpecialFunctions

# normal CDF of N(mean, stddev) evaluated at x
cdf_eval(x, mean, stddev) = 0.5 * (1 + erf((x - mean) / (stddev * sqrt(2.0))))

struct Student{T}
    θ::T
    σ::T 
    u1::T
    u2::T  
end

mutable struct Cache{T}
    last_cutoffs::Vector{T}   # cutoffs at the last evaluation
    welfare::T                # objective value at last_cutoffs
    demand::Vector{T}         # per-school demand at last_cutoffs
end

"""
    student_choice!(cutoffs, student, prob)

Simulate the school choice of a student given a cutoff vector and compute the
student's expected value. Modifies `prob` in place for efficiency.
"""
function student_choice!(cutoffs, student, prob)
    # unpack student 
    (; θ, σ, u1, u2) = student 
    # cutoffs are constrained so that c2 > c1
    if u2 ≥ u1
        prob[2] = 1.0 - cdf_eval(cutoffs[2], θ, σ)
        prob[1] = prob[2] - cdf_eval(cutoffs[1], θ, σ)
        expected_u = prob[2] * u2 + prob[1] * u1
    else
        prob[2] = 0.0
        prob[1] = 1.0 - cdf_eval(cutoffs[1], θ, σ)
        expected_u = prob[1] * u1
    end
    return expected_u
end

"""
    eval!(cutoffs, p)

Objective function: updates the Cache with the demand calculation for the
constraints and the welfare calculation for the objective value.
"""
function eval!(cutoffs, p)
    array_students = p[1]::Vector{Student{Float64}}
    cache = p[2]::Cache
    demand = cache.demand
    demand .= zero(eltype(cutoffs))

    probs = zeros(eltype(cutoffs), 2)
    welfare = zero(eltype(cutoffs))
    @inbounds for student ∈ array_students
        probs .= zero(eltype(cutoffs))
        welfare += student_choice!(cutoffs, student, probs)
        demand .+= probs
    end

    cache.last_cutoffs .= cutoffs
    cache.welfare = welfare
    return welfare             
end

function f(cutoffs, p)
    cache = p[2]::Cache
    # recompute only if the cutoffs changed since the last eval!
    cache.last_cutoffs == cutoffs || eval!(cutoffs, p)
    return cache.welfare
end

function cons(res, cutoffs, p)
    cache = p[2]::Cache
    capacities = p[3]::Vector{Float64}
    # recompute only if the cutoffs changed since the last eval!
    cache.last_cutoffs == cutoffs || eval!(cutoffs, p)
    res[1] = capacities[1] - cache.demand[1]
    res[2] = capacities[2] - cache.demand[2]
    res[3] = cutoffs[2] - cutoffs[1]
    return nothing
end

function main()
    ## SIMULATE DATA ## 
    number_students = 1000
    array_students = Array{Student{Float64}}(undef, number_students)
    for student_id ∈ 1:number_students
        θ = rand()
        σ = 0.2
        u1 = rand() + 1.0 
        u2 = rand() + 1.0
        array_students[student_id] = Student(θ, σ, u1, u2)
    end
    capacities = [400.0, 200.0]
    ## PERFORM OPTIMIZATION ## 
    cache = Cache(zeros(2), 0.0, zeros(2))
    p = [array_students, cache, capacities]
    x0 = rand(2)
    optprob = OptimizationFunction(f, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), cons = cons)
    prob = OptimizationProblem(optprob, x0, p, lcons = [0.0, 0.0, 0.0], ucons = [Inf, Inf, Inf])
    sol = solve(prob, Ipopt.Optimizer())
end
main()
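
For reference, here is roughly the direct Enzyme call that works for me. eval_split! is a hypothetical variant of eval! that takes the students and the Cache as separate arguments instead of through p, with x0, array_students and cache built as in main():

# hypothetical variant of eval! with the cache as its own argument
function eval_split!(cutoffs, array_students, cache)
    demand = cache.demand
    demand .= zero(eltype(cutoffs))
    probs = zeros(eltype(cutoffs), 2)
    welfare = zero(eltype(cutoffs))
    for student in array_students
        probs .= zero(eltype(cutoffs))
        welfare += student_choice!(cutoffs, student, probs)
        demand .+= probs
    end
    cache.last_cutoffs .= cutoffs
    cache.welfare = welfare
    return welfare
end

dcutoffs = zeros(2)                       # gradient accumulates here
shadow = Cache(zeros(2), 0.0, zeros(2))   # shadow cache for the derivative values

Enzyme.autodiff(Enzyme.Reverse, eval_split!, Active,
    Duplicated(x0, dcutoffs),
    Const(array_students),
    DuplicatedNoNeed(cache, shadow))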

Is there a way around this? Like telling AutoEnzyme that a given element of the parameter vector needs to be treated as differentiable even though we are not optimizing over it?

The error when I try the above is:

Allocation could not have its type statically determined   %150 = call noalias nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.gc_alloc_obj({}* nonnull %6, i64 %143, {} addrspace(10)* nonnull %148) #42, !dbg !118
Stacktrace:
  [1] shadow_alloc_rewrite(V::Ptr{…}, gutils::Ptr{…}, Orig::Ptr{…}, idx::UInt64, prev::Ptr{…}, used::UInt8)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:681
  [2] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/.julia/packages/Enzyme/hu9gq/src/api.jl:338
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:1793
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:4669
  [5] codegen
    @ ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:3455 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:5533
  [7] _thunk
    @ ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:5533 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:5585 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:5696
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/hu9gq/src/compiler.jl:5881
 [11] autodiff(::ForwardMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::BatchDuplicated{…}, ::BatchDuplicatedNoNeed{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/hu9gq/src/Enzyme.jl:641
 [12] autodiff
    @ ~/.julia/packages/Enzyme/hu9gq/src/Enzyme.jl:545 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/hu9gq/src/Enzyme.jl:517 [inlined]
 [14] (::OptimizationEnzymeExt.var"#lag_h!#54"{…})(h::SubArray{…}, θ::Vector{…}, σ::Float64, μ::Vector{…}, p::Vector{…})
    @ OptimizationEnzymeExt ~/.julia/packages/OptimizationBase/UXLhR/ext/OptimizationEnzymeExt.jl:344
 [15] lag_h!
    @ ~/.julia/packages/OptimizationBase/UXLhR/ext/OptimizationEnzymeExt.jl:341 [inlined]
 [16] eval_hessian_lagrangian(evaluator::OptimizationMOI.MOIOptimizationNLPEvaluator{…}, h::SubArray{…}, x::Vector{…}, σ::Float64, μ::SubArray{…})
    @ OptimizationMOI ~/.julia/packages/OptimizationMOI/L0B28/src/nlp.jl:378
 [17] eval_hessian_lagrangian(model::IpoptMathOptInterfaceExt.Optimizer, H::Vector{Float64}, x::Vector{Float64}, σ::Float64, μ::Vector{Float64})
    @ IpoptMathOptInterfaceExt ~/.julia/packages/Ipopt/Fbuwv/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl:1303
 [18] (::IpoptMathOptInterfaceExt.var"#eval_h_cb#11"{…})(x::Vector{…}, rows::Vector{…}, cols::Vector{…}, obj_factor::Float64, lambda::Vector{…}, values::Vector{…})
    @ IpoptMathOptInterfaceExt ~/.julia/packages/Ipopt/Fbuwv/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl:1387
 [19] _Eval_H_CB(n::Int32, x_ptr::Ptr{…}, ::Int32, obj_factor::Float64, m::Int32, lambda_ptr::Ptr{…}, ::Int32, nele_hess::Int32, iRow::Ptr{…}, jCol::Ptr{…}, values_ptr::Ptr{…}, user_data::Ptr{…})
    @ Ipopt ~/.julia/packages/Ipopt/Fbuwv/src/C_wrapper.jl:0
 [20] #5
    @ ~/.julia/packages/Ipopt/Fbuwv/src/C_wrapper.jl:407 [inlined]
 [21] disable_sigint(f::Ipopt.var"#5#6"{IpoptProblem, Base.RefValue{Float64}})
    @ Base ./c.jl:167
 [22] IpoptSolve
    @ ~/.julia/packages/Ipopt/Fbuwv/src/C_wrapper.jl:406 [inlined]
 [23] optimize!(model::IpoptMathOptInterfaceExt.Optimizer)
    @ IpoptMathOptInterfaceExt ~/.julia/packages/Ipopt/Fbuwv/ext/IpoptMathOptInterfaceExt/MOI_wrapper.jl:1523
 [24] __solve(cache::OptimizationMOI.MOIOptimizationNLPCache{OptimizationMOI.MOIOptimizationNLPEvaluator{…}, IpoptMathOptInterfaceExt.Optimizer})
    @ OptimizationMOI ~/.julia/packages/OptimizationMOI/L0B28/src/nlp.jl:552
 [25] solve!(cache::OptimizationMOI.MOIOptimizationNLPCache{OptimizationMOI.MOIOptimizationNLPEvaluator{…}, IpoptMathOptInterfaceExt.Optimizer})
    @ SciMLBase ~/.julia/packages/SciMLBase/iHgIu/src/solve.jl:227
 [26] solve(::OptimizationProblem{…}, ::IpoptMathOptInterfaceExt.Optimizer; kwargs::@Kwargs{})
    @ SciMLBase ~/.julia/packages/SciMLBase/iHgIu/src/solve.jl:129
 [27] solve(::OptimizationProblem{…}, ::IpoptMathOptInterfaceExt.Optimizer)
    @ SciMLBase ~/.julia/packages/SciMLBase/iHgIu/src/solve.jl:126
 [28] main()
    @ Main ./REPL[37]:19
 [29] top-level scope
    @ REPL[40]:1
Some type information was truncated. Use `show(err)` to see complete types.

Nowadays, DifferentiationInterface.jl supports both Constants and Caches, but the Optimization.jl bindings consider everything to be constant. Perhaps it would be worth splitting p in two?
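
Something like this, conceptually. A minimal sketch with a made-up function f_di, just to show the Constant/Cache distinction at the DifferentiationInterface level (not an actual Optimization.jl API):

import DifferentiationInterface as DI

# x is differentiated, data is constant, buf is an overwritable scratch buffer
f_di(x, data, buf) = (buf .= x .* data; sum(buf))

grad = DI.gradient(f_di, AutoEnzyme(), ones(2),
    DI.Constant([2.0, 3.0]),   # constant part of p: treated as Const
    DI.Cache(zeros(2)))        # cache part of p: shadowed, not Const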

Hi there, thanks for the comment. I do not fully follow your suggestion; can you elaborate a bit on what you mean by splitting p in two, and how that would let me tell Optimization.jl that one part is differentiable?