Enzyme not working

Hi there,

Opening this issue to see whether I could get some help (@wsmoses) with making reverse mode in autodiff work. I made it work on trivial examples, even ones containing conditional statements, but for some reason it won’t work on the function I am interested in. This function takes in various parameters and then a vector of floats with respect to which we want to calculate the gradient. It is a long function, so for now I will only provide a reduced example; I am sorry if this is still too long. Below you can find self-contained code that includes, at the end, a function calling autodiff using Enzyme. After the code I post the error message.

To explain the code a bit more:

  1. The function we want to differentiate is get_ll_single; all the functions above it are auxiliary functions used within it.
  2. Enzyme.jl is used in the last function, call_autodiff_example, which just defines some parameters and calls autodiff.

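For reference, the autodiff call in call_autodiff_example follows Enzyme's usual reverse-mode pattern: the extra (non-differentiated) arguments are fixed by a closure, the scalar return is marked Active, and the input vector is paired with a zero-initialized shadow through Duplicated, into which the gradient is accumulated. A minimal sketch of that pattern on a hypothetical toy objective (toy_ll and run_toy are illustrative names, not part of the example below):

```julia
using Enzyme

# Hypothetical toy objective standing in for get_ll_single:
# log(1 + scale * sum(x_i^2)), written as a plain loop.
function toy_ll(x::AbstractVector{Float64}, scale::Float64)
    s = 0.0
    for xi in x
        s += scale * xi^2
    end
    return log(1.0 + s)
end

function run_toy()
    x = [0.5, -1.0, 2.0]
    dx = zeros(length(x))      # shadow vector; Enzyme adds ∂f/∂x into it
    scale = 3.0
    f(v) = toy_ll(v, scale)    # closure fixing the non-differentiated argument
    # Active: the scalar return is differentiated; Duplicated: x paired with dx
    autodiff(Reverse, f, Active, Duplicated(x, dx))
    return x, dx
end

x, dx = run_toy()
# Analytic check: ∂f/∂x_i = 2 * scale * x_i / (1 + s), with s = 15.75 here
```

Because the shadow is accumulated into rather than overwritten, dx has to be re-zeroed before it is reused for another gradient call.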
using Distributions 
using Polynomials
using Combinatorics
using Statistics
using FastGaussQuadrature
using Interpolations
using LinearAlgebra
using Printf
using BenchmarkTools
using StaticArrays
using Enzyme

struct SpecFormat
    special_banks::Vector{Int64} # list of banks that are special
    home_bank::Int64
    max_banks_in_choice_set::Int64 # max number of banks in the choice set 

    hermite_order::Int64
    bernstein_order::Int64

    initial_search_cost_X::Vector{Union{Symbol, String}}
    per_bank_search_cost_X::Vector{Union{Symbol, String}}
    ρ_X::Vector{Union{Symbol, String}}
    γ_X::Vector{Union{Symbol, String}}
    γ_std_X::Vector{Union{Symbol, String}}
    κ_X::Vector{Union{Symbol, String}}
    initial_mean_shifters::Vector{Union{Symbol, String}}
    initial_std_shifters::Vector{Union{Symbol, String}}
    final_mean_shifters::Vector{Union{Symbol, String}}
    final_std_shifters::Vector{Union{Symbol, String}}

end

struct SubParamIndices
    κ::Int64
    γ::Int64
    γ_std::Int64
    initial_search_cost::Int64
    ρ::Int64
    initial_bank_μ::Int64
    initial_bank_σ::Int64
    final_bank_μ::Int64
    final_bank_σ::Int64
    σinv_choice::Int64
    σinv_start_search::Int64

    per_bank_search_cost::Vector{Int64}
    bank_params::Vector{Vector{Int64}}
    jump_bid_params::Vector{Int64}
    implied_jump_bid_params::Vector{Int64}

    n_params::Int64

end

struct ParamIndices
    bank_params::Vector{Vector{Int64}}
    jump_bid_params::Vector{Int64}
    implied_jump_bid_params::Vector{Int64}
    initial_search_cost_X::Vector{Int64}
    per_bank_search_cost_X::Vector{Int64}
    ρ_X::Vector{Int64}
    γ_X::Vector{Int64}
    γ_std_X::Vector{Int64}
    κ_X::Vector{Int64}
    bank_mean_shifters::Vector{Int64}
    bank_std_shifters::Vector{Int64}
    initial_search_cost_X_indices::Vector{Int64}
    ρ_X_indices::Vector{Int64}
    γ_X_indices::Vector{Int64}
    γ_std_X_indices::Vector{Int64}
    κ_X_indices::Vector{Int64}
    initial_mean_indices::Vector{Int64}
    initial_std_indices::Vector{Int64}
    final_mean_indices::Vector{Int64}
    final_std_indices::Vector{Int64}
    σinv_choice::Int64
    σinv_start_search::Int64
    n_params::Int64
end

function cumtrapz_upper(x::Vector{Float64}, y::AbstractArray{TV, 1})::Vector{TV} where TV
    z::Vector{TV} = zeros(Float64, length(x))
    z[1] = 0.0
    @inbounds @fastmath for i = 1:length(x)-1
        z[i+1] = z[i] + (x[i+1] - x[i]) * (y[i+1] + y[i]) 
    end
    z = z[end] .- z
    return 0.5 * z
end

function hermite5_cdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    poly_coeffs::Vector{TV} = [-48 * a[1] * a[2] - 24 * sqrt(2) * a[2] * a[3] + 8 * sqrt(6) * a[1] * a[4] - 24 * sqrt(3) * a[3] * a[4] + 4 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5], 
        -24 * a[2]^2 - 24 * sqrt(2) * a[1] * a[3] - 12 * a[3]^2 - 24 * a[4]^2 + 12 * sqrt(6) * a[1] * a[5] - 12 * sqrt(3) * a[3] * a[5] - 15 * a[5]^2, 
        -24 * sqrt(2) * a[2] * a[3] - 8 * sqrt(6) * a[1] * a[4] + 8 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5],
        -12 * a[3]^2 - 8 * sqrt(6) * a[2] * a[4] + 4 * a[4]^2 - 4 * sqrt(6) * a[1] * a[5] + 8 * sqrt(3) * a[3] * a[5] - 17 * a[5]^2,
        -8 * sqrt(3) * a[3] * a[4] - 4 * sqrt(6) * a[2] * a[5] + 12 * a[4] * a[5],
        -4 * a[4]^2 - 4 * sqrt(3) * a[3] * a[5] + 5 * a[5]^2,
        -4 * a[4] * a[5],
        -a[5]^2]
    func(x) = σ * Polynomial(poly_coeffs).((x .- μ) / σ) .* pdf.(Normal(μ, σ), x) ./ 24.0 .+ cdf.(Normal(μ, σ), x)
    return func
end

function hermite5_pdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    poly_coeffs::Vector{TV} = [a[1] - a[3] / sqrt(2) + sqrt(3/8) * a[5], 
        a[2] - sqrt(3/2) * a[4], 
        a[3] / sqrt(2) - sqrt(3/2) * a[5], 
        a[4] / sqrt(6), 
        a[5] / sqrt(24)]
    func(x) = (Polynomial(poly_coeffs).((x .- μ) / σ)).^2 .* pdf.(Normal(μ, σ), x)
    return func
end

function bernstein3(a::AbstractArray{TV, 1}; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    base_func(x) = a[1] * (1 - x)^3 + 3 * a[2] * x * (1 - x)^2 + 3 * a[3] * x^2 * (1 - x) + a[4] * x^3
    func(x) = base_func((x - lower) / (upper - lower))
    return func
end

# From PolynomialRoots.jl
function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T
    # Cubic equation solver for complex polynomial (degree=3)
    # http://en.wikipedia.org/wiki/Cubic_function   Lagrange's method
    a1  =  1 / poly[4]
    E1  = -poly[3]*a1
    E2  =  poly[2]*a1
    E3  = -poly[1]*a1
    s0  =  E1
    E12 =  E1*E1
    A   =  2*E1*E12 - 9*E1*E2 + 27*E3 # = s1^3 + s2^3
    B   =  E12 - 3*E2                 # = s1 s2
    # quadratic equation: z^2 - Az + B^3=0  where roots are equal to s1^3 and s2^3
    Δ = sqrt(A*A - 4*B*B*B)
    if real(conj(A)*Δ)>=0 # scalar product to decide the sign yielding bigger magnitude
        s1 = exp(log(0.5 * (A + Δ)) * (1/3))
    else
        s1 = exp(log(0.5 * (A - Δ)) * (1/3))
    end
    if s1 == 0
        s2 = s1
    else
        s2 = B / s1
    end
    zeta1 = complex(-0.5, sqrt(T(3.0))*0.5)
    zeta2 = conj(zeta1)
    # return third*(s0 + s1 + s2), third*(s0 + s1*zeta2 + s2*zeta1), third*(s0 + s1*zeta1 + s2*zeta2)

    sol1 = (1/3) * (s0 + s1 + s2)
    sol2 = (1/3) * (s0 + s1 * zeta2 + s2 * zeta1)
    sol3 = (1/3) * (s0 + s1 * zeta1 + s2 * zeta2)

    if abs(imag(sol1)) < 1e-8 && real(sol1) > 0.0 && real(sol1) < 1.0
        return real(sol1)
    elseif abs(imag(sol2)) < 1e-8 && real(sol2) > 0.0 && real(sol2) < 1.0
        return real(sol2)
    elseif abs(imag(sol3)) < 1e-8 && real(sol3) > 0.0 && real(sol3) < 1.0
        return real(sol3)
    else
        return NaN
    end
end

function bernstein3_inv(p::AbstractArray{TV, 1}, val::Float64; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    # Get this in the form ax^3 + bx^2 + cx + d = 0
    a = p[4] + 3 * p[2] - 3 * p[3] - p[1]
    b = 3 * p[1] - 6 * p[2] + 3 * p[3]
    c = 3 * p[2] - 3 * p[1]
    d = p[1] - val
    coeffs = [d, c, b, a]

    if d >= 0.0 # (poly(0.0))
        return lower
    elseif a + b + c + d <= 0.0 # (poly(1.0))
        return upper
    end

    x = solve_cubic_eq(Complex.(coeffs))
    return x * (upper - lower) + lower
end

# Here, we take in a set of choice sets, each of which can be repeated N times.
function logit_probabilities!(inside_probabilities::Array{TV, 3}, outside_probabilities::Matrix{TV}, utilities::Array{TV, 3}, N::Vector{Int64}, σinv::TV) where TV

    for uJ_index in axes(utilities, 2)
        for bias_index in axes(utilities, 3)
            this_exp_probs = @views exp.(utilities[:, uJ_index, bias_index] * σinv)
            inv_denom = 1.0 / (1.0 + dot(this_exp_probs, N))
            inside_probabilities[:, uJ_index, bias_index] .= inv_denom * this_exp_probs
            outside_probabilities[uJ_index, bias_index] = inv_denom
        end
    end

    return nothing
end

function partial_interpolate(x::AbstractArray{Float64, 1}, y::AbstractArray{TV, 1}, idx::Int64, val::Float64) where TV
    z::TV = 0.0
    if idx == 0
        z = y[1]      # val below the grid: clamp to the first y value
    elseif idx == length(x)
        z = y[end]    # val above the grid: clamp to the last y value
    else
        z = ((y[idx+1] - y[idx]) * (val - x[idx]) / (x[idx+1] - x[idx])) + y[idx]
    end

    return z
end

function get_spec()::SpecFormat

    special_banks::Vector{Int64} = [1, 14, 16] 
    home_bank::Int64 = 12
    max_banks_in_choice_set::Int64 = 3

    hermite_order::Int64 = 5
    bernstein_order::Int64 = 4

    initial_search_cost_X::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income]
    per_bank_search_cost_X::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income, :s_n_branches]
    ρ_X::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income]
    γ_X::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income]
    γ_std_X::Vector{Union{Symbol, String}} = [:ones]
    κ_X::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income]
    initial_mean_shifters::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income, :b_term, :b_amount]
    initial_std_shifters::Vector{Union{Symbol, String}} = [:ones, :b_term, :b_amount]
    final_mean_shifters::Vector{Union{Symbol, String}} = [:ones, :x_educ_high, :x_income, :b_term, :b_amount]
    final_std_shifters::Vector{Union{Symbol, String}} = [:ones, :b_term, :b_amount]

    spec::SpecFormat = SpecFormat(special_banks, home_bank, max_banks_in_choice_set, hermite_order, bernstein_order,
        initial_search_cost_X, per_bank_search_cost_X, ρ_X, γ_X, γ_std_X, κ_X, initial_mean_shifters, initial_std_shifters,
        final_mean_shifters, final_std_shifters)
    return spec

end

function get_indices(grid::AbstractArray{Float64, 1}, vals::AbstractArray{Float64, 1})::Vector{Int64}
    indices = zeros(Int64, length(vals))
    for (i, val) in enumerate(vals)
        indices[i] = searchsortedlast(grid, val)
    end

    return indices
end

function get_param_indices(spec::SpecFormat)::Tuple{ParamIndices, SubParamIndices}
    param_val::Int64 = 0

    bank_params::Vector{Vector{Int64}} = []
    for b = 1:(length(spec.special_banks)+1)
        this_bank_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
        push!(bank_params, this_bank_params)
    end

    jump_bid_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
    implied_jump_bid_params::Vector{Int64}, param_val = param_update(spec.bernstein_order, param_val)

    initial_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.initial_search_cost_X), param_val)
    per_bank_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.per_bank_search_cost_X), param_val)
    ρ_X::Vector{Int64}, param_val = param_update(length(spec.ρ_X), param_val)
    γ_X::Vector{Int64}, param_val = param_update(length(spec.γ_X), param_val)
    γ_std_X::Vector{Int64}, param_val = param_update(length(spec.γ_std_X), param_val)
    κ_X::Vector{Int64}, param_val = param_update(length(spec.κ_X), param_val)

    @assert(length(spec.initial_mean_shifters) == length(spec.final_mean_shifters))
    @assert(length(spec.initial_std_shifters) == length(spec.final_std_shifters))
    bank_mean_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_mean_shifters), param_val)
    bank_std_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_std_shifters), param_val)
    σinv_choice::Int64, param_val = param_update(1, param_val; return_as_vector = false)
    σinv_start_search::Int64, param_val = param_update(1, param_val; return_as_vector = false)

    # Now get the indices. There will be one X matrix at the individual level (N_individuals x total parameters)
    # and then indices for which subsets of this X to use for the various parameters
    all_X = sort!(union(spec.initial_search_cost_X, spec.ρ_X, spec.γ_X, spec.γ_std_X, spec.κ_X, 
        spec.initial_mean_shifters, spec.initial_std_shifters, spec.final_mean_shifters, spec.final_std_shifters)) # per-bank search cost will be at the individual-bank level
    initial_search_cost_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_search_cost_X]
    ρ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.ρ_X]
    γ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.γ_X]
    γ_std_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.γ_std_X]
    κ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.κ_X]
    initial_mean_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_mean_shifters]
    initial_std_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_std_shifters]
    final_mean_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.final_mean_shifters]
    final_std_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.final_std_shifters]

    param_indices = ParamIndices(bank_params, jump_bid_params, implied_jump_bid_params, 
        initial_search_cost_X, per_bank_search_cost_X, ρ_X, γ_X, γ_std_X, κ_X,
        bank_mean_shifters, bank_std_shifters,
        initial_search_cost_X_indices, ρ_X_indices, γ_X_indices, γ_std_X_indices, κ_X_indices, 
        initial_mean_indices, initial_std_indices, final_mean_indices, final_std_indices, 
        σinv_choice, σinv_start_search, param_val)

    sub_param_val::Int64 = 0
    sub_κ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ_std::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_search_cost::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_ρ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_choice::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_start_search::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_per_bank_search_cost::Vector{Int64}, sub_param_val = param_update(length(spec.special_banks)+1, sub_param_val)
    sub_bank_params::Vector{Vector{Int64}} = []
    for b = 1:(length(spec.special_banks)+1)
        this_bank_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
        push!(sub_bank_params, this_bank_params)
    end
    sub_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
    sub_implied_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.bernstein_order, sub_param_val)

    sub_param_indices = SubParamIndices(sub_κ, sub_γ, sub_γ_std, sub_initial_search_cost, sub_ρ, 
        sub_initial_bank_μ, sub_initial_bank_σ, sub_final_bank_μ, sub_final_bank_σ, 
        sub_σinv_choice, sub_σinv_start_search, sub_per_bank_search_cost, sub_bank_params, 
        sub_jump_bid_params, sub_implied_jump_bid_params, sub_param_val)
        

    return (param_indices, sub_param_indices)
end

function param_update(N::Int64, param_val::Int64; return_as_vector::Bool = true)
    if N > 1 || return_as_vector
        indices::Vector{Int64} = collect(1:N) .+ param_val
        param_val += N
        return indices, param_val
    else
        indices_int::Int64 = param_val + 1
        param_val += 1
        return indices_int, param_val
    end
end

function get_ll_single(params::AbstractArray{TV, 1}, param_indices::SubParamIndices, uJ_grid::Vector{Float64},
        base_nodes::SVector{9, Float64}, weights::SVector{9, Float64}, u_current_idx::Int64, remaining_term::Float64,
        amount::Float64, chosen_bank::Int64, current_monthly_payment::Float64, final_monthly_payment::Float64,
        search_type::Int64, choice_sets::Vector{Vector{Int64}}, choice_sets_N::Vector{Int64}, chosen_choice_set::Int64) where TV


    # Initialize the values of the frictions -- this applies for everyone
    κ::TV = params[param_indices.κ] 
    γ::TV = params[param_indices.γ] 
    γ_std::TV = params[param_indices.γ_std]
    initial_search_cost::TV = params[param_indices.initial_search_cost] 
    ρ::TV = params[param_indices.ρ] 
    initial_bank_μ::TV = params[param_indices.initial_bank_μ]  
    initial_bank_σ::TV = params[param_indices.initial_bank_σ] 
    final_bank_μ::TV = params[param_indices.final_bank_μ] 
    final_bank_σ::TV = params[param_indices.final_bank_σ] 
    σinv_choice::TV = params[param_indices.σinv_choice]
    σinv_start_search::TV = params[param_indices.σinv_start_search]
    
    ## Now the loop begins
    # Note that the cdf of zero-profit rates is written in terms of monthly payments if A = 1, scaled by number of 
    # payments. So, a value of m corresponds to a monthly payment of m * A / T.
    # 
    # We know that monthly utility = (delta - m * A / T) - kappa / NPVrate(T), 
    # So, Pr(utility <= u) = Pr((delta - m * A / T) - kappa/NPV <= u) = Pr(scaled monthly payment >= (T/A) * (delta - (u + kappa/NPV)))
    # The bias is in terms of monthly payments. So, we just add the bias.
    β_customer_m::Float64 = 0.95^(1/12)
    per_bank_search_costs::Vector{TV} = params[param_indices.per_bank_search_cost]
    npv_scale::Float64 = (1.0 - β_customer_m^remaining_term) / (1 - β_customer_m)
    overall_scale::Float64 = remaining_term / amount
    bias_grid = base_nodes * γ_std * sqrt(2) .+ γ # in units of dollars
    scaled_κ::TV = κ / npv_scale # in units of dollars per month


    # Generate the cdfs and pdfs for the banks: cdfs at the time of search, and both later if needed 

    # Maps utilities ($) to utilities ($) <--- could rethink whether this is the best map
    implied_jump_bid::Function = bernstein3(params[param_indices.implied_jump_bid_params]; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)
    implied_jump_bid_inv(val::Float64) = bernstein3_inv(params[param_indices.implied_jump_bid_params], val; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)

    
    initial_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], initial_bank_μ, initial_bank_σ)(-overall_scale .* (x .+ scaled_κ)) for a in param_indices.bank_params]
    initial_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x) 
    initial_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_home_bank_cdf(x) = initial_jump_bid_cdf(implied_jump_bid.(x))

    final_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_bank_pdf = [x -> overall_scale .* hermite5_pdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x) 
    final_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x)
    final_home_bank_cdf(x) = final_jump_bid_cdf(implied_jump_bid.(x))


    refi_at_home_bank::Bool = chosen_bank == 0
    u_current::Float64 = -current_monthly_payment
    u_final::TV = -final_monthly_payment - scaled_κ * refi_at_home_bank


    choice_set_utilities::Array{TV, 3} = zeros(length(choice_sets), length(uJ_grid), length(bias_grid)) ### CONTAINER 3D-FLOAT

    base_0::Vector{TV} = zeros(length(uJ_grid)) ### CONTAINER 1D FLOAT SIZE UJ_GRID
    base_1::Vector{TV} = zeros(length(uJ_grid)) ### CONTAINER 1D FLOAT SIZE UJ_GRID
    to_add::Vector{TV} = zeros(length(uJ_grid)) ### CONTAINER 1D FLOAT SIZE UJ_GRID


    initial_bank_cdf_vals::Matrix{Float64} = zeros(Float64, length(initial_bank_cdf), length(uJ_grid))
    one_minus_initial_bank_cdf_vals::Matrix{TV} = zeros(Float64, length(initial_bank_cdf), length(uJ_grid))


    @inbounds for (bias_index, bias) in enumerate(bias_grid)

        biased_uJ = uJ_grid .+ bias 

        for b = 1:length(initial_bank_cdf)
            initial_bank_cdf_vals[b, :] .= initial_bank_cdf[b](biased_uJ) 
            one_minus_initial_bank_cdf_vals[b, :] .= @views 1.0 .- initial_bank_cdf_vals[b, :]
        end

        initial_home_bank_cdf_vals::Vector{Float64} = initial_home_bank_cdf(biased_uJ)
        good_indices::BitVector = (initial_home_bank_cdf_vals[end] .- initial_home_bank_cdf_vals) .>= 1e-7
        
        @inbounds for (choice_set_index, choice_set) in enumerate(choice_sets)

            base_0 .= @views initial_bank_cdf_vals[choice_set[1], :]
            for (b_idx, b) in enumerate(choice_set)
                b_idx > 1 || continue
                base_0 .*= @views initial_bank_cdf_vals[b, :]
            end
            cumtrapz_0 = cumtrapz_upper(uJ_grid, base_0)

            base_1 .= 0.0
            for (b_index, b) in enumerate(choice_set)
                to_add .= @views one_minus_initial_bank_cdf_vals[b, :]
                for (bprime_index, bprime) in enumerate(choice_set)
                    if bprime_index != b_index
                        to_add .*= @views initial_bank_cdf_vals[bprime, :]
                    end
                end
                base_1 .+= to_add
            end
            cumtrapz_1 = cumtrapz_upper(uJ_grid, base_1)

            base_1 .*= initial_home_bank_cdf_vals
            cumtrapz_1_with_home = cumtrapz_upper(uJ_grid, base_1)

            next_part = base_0
            for i in eachindex(next_part)
                # Assign in both branches (in `cond ? a[i] = x : y` the else value is computed but never stored)
                if good_indices[i]
                    next_part[i] = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i] - (cumtrapz_1_with_home[i] - initial_home_bank_cdf_vals[i] * cumtrapz_1[i]) / (initial_home_bank_cdf_vals[end] - initial_home_bank_cdf_vals[i])
                else
                    next_part[i] = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i] - cumtrapz_1[i]
                end
            end


            this_search_cost::TV = sum([per_bank_search_costs[b] for b in choice_set])
            choice_set_utilities[choice_set_index, :, bias_index] .= @. next_part - this_search_cost            

        end
    end


    choice_set_probabilities::Array{TV, 3} = zeros(size(choice_set_utilities))
    outside_option_probabilities::Array{TV, 2} = zeros(length(uJ_grid), length(bias_grid))
    logit_probabilities!(choice_set_probabilities, outside_option_probabilities,
        choice_set_utilities, choice_sets_N, σinv_choice)

    benefit_of_search_uJ::Array{TV, 2} = dropdims(sum((choice_sets_N .* choice_set_probabilities) .* choice_set_utilities, dims = 1), dims = 1)
    S_base::Vector{TV} = zeros(length(bias_grid)) 

    for (bias_index, bias) in enumerate(bias_grid)
        initial_jump_bid_pdf_vals = initial_jump_bid_pdf(uJ_grid .+ bias)

        S_base[bias_index] = @views partial_cumtrapz(uJ_grid, benefit_of_search_uJ[1:(u_current_idx+1), bias_index] .* initial_jump_bid_pdf_vals[1:(u_current_idx+1)], u_current_idx, u_current) + initial_jump_bid_cdf(u_current + bias) * 
            partial_interpolate(uJ_grid, benefit_of_search_uJ[:, bias_index], u_current_idx, u_current)

    end

    pr_search::Vector{TV} = @. (exp(σinv_start_search * (S_base - initial_search_cost))) / (1.0 + (exp(σinv_start_search * (S_base - initial_search_cost)))) # vectors of length(bias_grid) = 9, one entry per quadrature node
    pr_no_search = 1 .- pr_search


    likelihood = 1 - ρ + ρ * dot(pr_no_search, weights)

    
    return log(likelihood)

end 

function call_autodiff_example()

    spec::SpecFormat = get_spec()
    sub_param_indices::SubParamIndices = get_param_indices(spec)[2]

    uJ_grid::Vector{Float64} = union(collect(0:0.05:2.0), [2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]) # this will be a constant 
    uJ_grid = sort!(-uJ_grid)

    base_nodes, base_weights = SVector{9}.(FastGaussQuadrature.gausshermite(9))
    weights = SVector{9}(base_weights / sqrt(π))

    params = [0.0, 0.0, 0.25, 0.0, 0.1, 1.2, 0.25, 1.2, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 
                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0]

    u_current_idx = 44
    remaining_term = 67.0
    amount = 13.32
    chosen_bank = -1
    current_monthly_payment = 0.24
    final_monthly_payment = NaN
    search_type = 1 
    choice_sets = [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3], [1, 3, 4], 
                    [1, 4], [1, 4, 4], [2], [2, 3], [2, 3, 4], [2, 4], [2, 4, 4], 
                    [3], [3, 4], [3, 4, 4], [4], [4, 4], [4, 4, 4]]
    choice_sets_N = [1, 1, 1, 15, 1, 15, 15, 105, 1, 1, 15, 15, 105, 1, 15, 105, 15, 105, 455]
    chosen_choice_set = 0
    dx = zeros(44)

    single_ll(x) = get_ll_single(x, sub_param_indices, uJ_grid, base_nodes, weights, u_current_idx,
    remaining_term, amount, chosen_bank, current_monthly_payment, final_monthly_payment, search_type, choice_sets, choice_sets_N, chosen_choice_set)
    autodiff(Reverse, single_ll, Active, Duplicated(params, dx))

end

call_autodiff_example()
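Note that partial_cumtrapz is called in get_ll_single (in the S_base loop) but is not defined in the code above, so the posted script cannot run as-is. The sketch below is a hypothetical stand-in, assuming from its call site and from partial_interpolate that it integrates y with the trapezoidal rule from x[1] up to val, where x[idx] <= val <= x[idx+1]. It is only meant to make the script runnable and may differ from the helper in the full code.

```julia
# Hypothetical stand-in for the partial_cumtrapz helper called in get_ll_single.
# Assumes: x is sorted ascending, x[idx] <= val <= x[idx+1], and y[i] is the
# integrand at x[i] for i = 1:idx+1. Returns the trapezoidal integral of y
# from x[1] to val (0 when idx == 0, i.e. val lies below the grid).
function partial_cumtrapz(x::AbstractVector{Float64}, y::AbstractVector{TV},
                          idx::Int64, val::Float64) where TV
    total = zero(TV)
    # Full trapezoids on [x[i], x[i+1]] for i = 1, ..., idx-1
    for i in 1:idx-1
        total += 0.5 * (x[i+1] - x[i]) * (y[i+1] + y[i])
    end
    # Partial trapezoid on [x[idx], val], with y(val) linearly interpolated
    if 1 <= idx < length(x)
        y_val = y[idx] + (y[idx+1] - y[idx]) * (val - x[idx]) / (x[idx+1] - x[idx])
        total += 0.5 * (val - x[idx]) * (y[idx] + y_val)
    end
    return total
end

# Exact for a linear integrand: ∫_0^1.5 u du = 1.125
partial_cumtrapz([0.0, 1.0, 2.0], [0.0, 1.0, 2.0], 2, 1.5)
```

With the grid used in call_autodiff_example, u_current = -0.24 falls between uJ_grid[44] = -0.25 and uJ_grid[45] = -0.20, which is consistent with u_current_idx = 44 and with the y slice of length u_current_idx + 1 passed at the call site.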

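The comments at the start of the main loop in get_ll_single, and the comment above logit_probabilities!, describe two small derivations. Written out (my reading of the code: δ appears to be set to 0, F_b and f_b are the Hermite-expansion cdf/pdf from hermite5_cdf/hermite5_pdf, F_b is assumed continuous, T = remaining_term, A = amount, and the bias is added to u before evaluation), they are:

```latex
\begin{aligned}
\beta &= 0.95^{1/12}, \qquad \mathrm{NPV}(T) = \frac{1-\beta^{T}}{1-\beta}, \qquad
U = \Bigl(\delta - \frac{mA}{T}\Bigr) - \frac{\kappa}{\mathrm{NPV}(T)}, \\
\Pr(U \le u) &= \Pr\Bigl(m \ge \tfrac{T}{A}\bigl(\delta - u - \kappa/\mathrm{NPV}(T)\bigr)\Bigr)
  = 1 - F_b\Bigl(\tfrac{T}{A}\bigl(\delta - u - \kappa/\mathrm{NPV}(T)\bigr)\Bigr), \\
\frac{\partial}{\partial u}\Pr(U \le u) &= \tfrac{T}{A}\, f_b\Bigl(\tfrac{T}{A}\bigl(\delta - u - \kappa/\mathrm{NPV}(T)\bigr)\Bigr), \\
p_c &= \frac{e^{\sigma^{-1} u_c}}{1 + \sum_k N_k\, e^{\sigma^{-1} u_k}}, \qquad
p_0 = \frac{1}{1 + \sum_k N_k\, e^{\sigma^{-1} u_k}}, \qquad p_0 + \sum_c N_c\, p_c = 1.
\end{aligned}
```

With δ = 0, the argument of F_b is -(T/A)(u + κ/NPV(T)), i.e. `-overall_scale * (x + scaled_κ)` in the code, and the factor T/A in the density is the `overall_scale .*` prefactor in final_bank_pdf. In the last line, σ⁻¹ is σinv_choice, N_k is choice_sets_N, and p_c, p_0 are what logit_probabilities! writes into inside_probabilities and outside_probabilities.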
Then, after some time, it crashes with the following stack trace:

ERROR: LoadError: Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %value_phi5253395, {} addrspace(10)** %.fca.1.gep, align 8, !dbg !1671, !noalias !216 const val:   %value_phi5253395 = phi {} addrspace(10)* [ %arrayref514, %L1631.lr.ph ], [ %arrayref1084, %L3207 ]
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}
 llvalue=  %arrayref514 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %arrayptr5112966, align 8, !dbg !1096, !tbaa !1097, !alias.scope !197, !noalias !198
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] get_ll_single
   @ ~/Documents/refinancing/multistep/example_Enzyme.jl:527

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:1612
  [2] get_ll_single
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
  [3] *
    @ ./float.jl:411 [inlined]
  [4] trapz
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:87
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6587 [inlined]
  [6] enzyme_call
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6188 [inlined]
  [7] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6065 [inlined]
  [8] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:309 [inlined]
  [9] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
 [10] call_autofiff_example()
    @ Main ~/Documents/refinancing/multistep/example_Enzyme.jl:592
 [11] top-level scope
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:596
 [12] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [13] top-level scope
    @ REPL[2]:1
in expression starting at /Users/miguelborrero/Documents/refinancing/multistep/example_Enzyme.jl:596
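The error message names two ways forward: rewriting the code so that no variable is conditionally active, or turning on runtime activity as a stop-gap. Below is a minimal sketch of the second option on a hypothetical toy function standing in for single_ll. `Enzyme.API.runtimeActivity!(true)` is the switch named in the error (for the Enzyme version in the trace); my understanding is that newer Enzyme.jl releases replaced it with `set_runtime_activity(Reverse)`, so the sketch checks which one is available. Runtime activity comes with a performance cost, so this is a workaround rather than a fix.

```julia
using Enzyme

# Hypothetical toy function standing in for single_ll: log(1 + ||x||^2)
toy(x) = log(1.0 + sum(abs2, x))

# Reverse-mode gradient with runtime activity enabled, choosing whichever
# switch the installed Enzyme.jl version provides.
function grad_with_runtime_activity(f, x::Vector{Float64})
    dx = zeros(length(x))
    if isdefined(Enzyme, :set_runtime_activity)
        # Newer releases: runtime activity is set on the AD mode itself
        autodiff(Enzyme.set_runtime_activity(Reverse), f, Active, Duplicated(x, dx))
    else
        # Older releases: the global switch named in the error message above
        Enzyme.API.runtimeActivity!(true)
        autodiff(Reverse, f, Active, Duplicated(x, dx))
    end
    return dx
end

grad_with_runtime_activity(toy, [1.0, 2.0])   # analytic: 2x / (1 + ||x||^2) = [1/3, 2/3]
```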

I would really appreciate some help, since it is important for me to make this work. Thanks a lot in advance!

Hey! Can you post a complete example with the necessary imports and function / object definitions?

And apologies: when I suggested opening an issue, I meant on GitHub (EnzymeAD/Enzyme.jl: Julia bindings for the Enzyme automatic differentiator).


Okay, I converged on a minimal example. It might still be too long, but I cannot make it shorter while preserving relevance. I edited the post with the code and the error message I am getting. It produces a different error message than the full function, but fixing this may be a first step toward making the complete example work.

Oh, sorry, I’m new, so I misunderstood! I am opening an issue now. Thanks so much!

Any update?

So there’s been a GitHub issue open for some time. @wsmoses, any chance it gets fixed?

If it does, I will update this post, since we converged on a much simpler MWE.

Is it the following issue?

Yes, that’s the one.

Yes, it should be possible to fix, but I haven’t had a chance to debug it. Posting the simplified version would be appreciated.

Thanks a lot! I understand you must be very busy. Regarding further simplification: it’s already been reduced a lot, and reducing it further gave me a different error, sorry.