Hi there,
Opening this issue to see whether I could get some help (@wsmoses) with making reverse mode in `autodiff` work. I made it work on trivial examples, even ones containing conditional statements, but for some reason it won't work on the function I am interested in. This function takes in various parameters and then a vector of floats with respect to which we want to calculate the gradient. It is a long function, so for now I will only provide a reduced example; I am sorry if this is still too long. Below you can find self-contained code (apart from the helper `partial_cumtrapz`, which is called but not included), ending with a function that calls `autodiff` using Enzyme. After the code, I post the error message.
To add a bit more explanation about the code here:
- The function we want to differentiate is called `get_ll_single`; all the functions above it are auxiliary functions used within it.
- Enzyme.jl is used in the last function, `call_autodiff_example`, which just defines some parameters and calls `autodiff`.
using Distributions
using Polynomials
using Combinatorics
using Statistics
using FastGaussQuadrature
using Interpolations
using LinearAlgebra
using Printf
using BenchmarkTools
using StaticArrays
using Enzyme
"""
    SpecFormat

Model specification consumed by `get_param_indices` (`get_spec` builds the instance used
in this example). The `*_X` and `*_shifters` fields are lists of covariate names
(`Symbol` or `String`); `get_param_indices` allocates one coefficient per listed name.
"""
struct SpecFormat
special_banks::Vector{Int64} # list of banks that are special; one bank-parameter block per entry, plus one more
home_bank::Int64 # NOTE(review): not read by the functions in this listing
max_banks_in_choice_set::Int64 # max number of banks in the choice set
hermite_order::Int64 # length of each Hermite coefficient block (bank and jump-bid parameters)
bernstein_order::Int64 # length of the implied-jump-bid Bernstein coefficient block
initial_search_cost_X::Vector{Union{Symbol, String}} # covariates of the initial search cost
per_bank_search_cost_X::Vector{Union{Symbol, String}} # covariates of the per-bank search cost (kept out of the individual-level union `all_X`)
ρ_X::Vector{Union{Symbol, String}} # covariates of ρ
γ_X::Vector{Union{Symbol, String}} # covariates of γ
γ_std_X::Vector{Union{Symbol, String}} # covariates of γ_std
κ_X::Vector{Union{Symbol, String}} # covariates of κ
initial_mean_shifters::Vector{Union{Symbol, String}} # must have the same length as final_mean_shifters
initial_std_shifters::Vector{Union{Symbol, String}} # must have the same length as final_std_shifters
final_mean_shifters::Vector{Union{Symbol, String}}
final_std_shifters::Vector{Union{Symbol, String}}
end
"""
    SubParamIndices

Positions of each model quantity inside the parameter vector read by `get_ll_single`
(which does e.g. `params[param_indices.κ]`). `get_param_indices` fills it sequentially:
the eleven scalar slots first (in field order), then `per_bank_search_cost`,
`bank_params`, `jump_bid_params` and `implied_jump_bid_params`.
"""
struct SubParamIndices
κ::Int64 # index of κ
γ::Int64 # index of γ (centre of the bias grid)
γ_std::Int64 # index of γ_std (scale of the bias grid)
initial_search_cost::Int64
ρ::Int64
initial_bank_μ::Int64 # μ passed to hermite5_cdf/pdf for the initial-stage distributions
initial_bank_σ::Int64 # σ passed to hermite5_cdf/pdf for the initial-stage distributions
final_bank_μ::Int64 # μ for the final-stage distributions
final_bank_σ::Int64 # σ for the final-stage distributions
σinv_choice::Int64 # inverse logit scale for choosing a choice set
σinv_start_search::Int64 # inverse logit scale for starting a search
per_bank_search_cost::Vector{Int64} # one index per bank group (special banks + 1)
bank_params::Vector{Vector{Int64}} # one Hermite-coefficient block per bank group
jump_bid_params::Vector{Int64} # Hermite coefficients of the jump bid
implied_jump_bid_params::Vector{Int64} # Bernstein coefficients of the implied jump bid
n_params::Int64 # total number of entries allocated
end
"""
    ParamIndices

Layout of the full parameter vector, as produced by `get_param_indices`. Index blocks
are allocated consecutively in field order up to `bank_std_shifters`, followed by the
two scalar inverse scales. The `*_indices` fields are not parameter positions: they are
positions of covariate names inside the sorted union `all_X` built in
`get_param_indices`.
"""
struct ParamIndices
bank_params::Vector{Vector{Int64}} # one Hermite-coefficient block per bank group
jump_bid_params::Vector{Int64}
implied_jump_bid_params::Vector{Int64}
initial_search_cost_X::Vector{Int64} # coefficient indices, one per covariate in spec.initial_search_cost_X
per_bank_search_cost_X::Vector{Int64}
ρ_X::Vector{Int64}
γ_X::Vector{Int64}
γ_std_X::Vector{Int64}
κ_X::Vector{Int64}
bank_mean_shifters::Vector{Int64} # sized by spec.initial_mean_shifters
bank_std_shifters::Vector{Int64} # sized by spec.initial_std_shifters
initial_search_cost_X_indices::Vector{Int64} # positions of spec.initial_search_cost_X names in all_X
ρ_X_indices::Vector{Int64}
γ_X_indices::Vector{Int64}
γ_std_X_indices::Vector{Int64}
κ_X_indices::Vector{Int64}
initial_mean_indices::Vector{Int64}
initial_std_indices::Vector{Int64}
final_mean_indices::Vector{Int64}
final_std_indices::Vector{Int64}
σinv_choice::Int64 # scalar index
σinv_start_search::Int64 # scalar index
n_params::Int64 # total number of entries allocated
end
"""
    cumtrapz_upper(x, y)

Upper-tail cumulative trapezoidal integral of `y` sampled on the grid `x`.

Element `i` of the result approximates ∫_{x[i]}^{x[end]} y dx with the trapezoid rule, so
the last element is zero. `x` and `y` must have the same length (a `DimensionMismatch` is
thrown otherwise); empty input returns an empty vector. The result has the element type
`TV` of `y`, so generic number types (e.g. dual numbers) propagate.
"""
function cumtrapz_upper(x::Vector{Float64}, y::AbstractArray{TV, 1})::Vector{TV} where TV
    n = length(x)
    # The loop below runs under @inbounds, so a short `y` would be read out of bounds.
    length(y) == n ||
        throw(DimensionMismatch("x has length $n but y has length $(length(y))"))
    z = zeros(TV, n)
    n == 0 && return z
    # Running (lower-tail) sum of twice the trapezoid areas; the factor 1/2 is applied once at the end.
    @inbounds @fastmath for i in 1:n-1
        z[i+1] = z[i] + (x[i+1] - x[i]) * (y[i+1] + y[i])
    end
    # Lower-tail cumulative sums -> upper-tail ones, then halve.
    return 0.5 .* (z[end] .- z)
end
"""
    hermite5_cdf(a, μ, σ)

Return a closure that maps `x` (scalar or array, evaluated elementwise) to

    σ * P(z) * pdf(Normal(μ, σ), x) / 24 + cdf(Normal(μ, σ), x),   z = (x - μ) / σ,

where `P` is the polynomial (degree ≤ 7) whose ascending coefficients are the quadratic
forms in `a[1:5]` below. Presumably this is the CDF matching the density that
`hermite5_pdf` builds from the same arguments — not verified here.
"""
function hermite5_cdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    # Ascending-power coefficients c0 … c7 of P(z).
    c0 = -48 * a[1] * a[2] - 24 * sqrt(2) * a[2] * a[3] + 8 * sqrt(6) * a[1] * a[4] - 24 * sqrt(3) * a[3] * a[4] + 4 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5]
    c1 = -24 * a[2]^2 - 24 * sqrt(2) * a[1] * a[3] - 12 * a[3]^2 - 24 * a[4]^2 + 12 * sqrt(6) * a[1] * a[5] - 12 * sqrt(3) * a[3] * a[5] - 15 * a[5]^2
    c2 = -24 * sqrt(2) * a[2] * a[3] - 8 * sqrt(6) * a[1] * a[4] + 8 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5]
    c3 = -12 * a[3]^2 - 8 * sqrt(6) * a[2] * a[4] + 4 * a[4]^2 - 4 * sqrt(6) * a[1] * a[5] + 8 * sqrt(3) * a[3] * a[5] - 17 * a[5]^2
    c4 = -8 * sqrt(3) * a[3] * a[4] - 4 * sqrt(6) * a[2] * a[5] + 12 * a[4] * a[5]
    c5 = -4 * a[4]^2 - 4 * sqrt(3) * a[3] * a[5] + 5 * a[5]^2
    c6 = -4 * a[4] * a[5]
    c7 = -a[5]^2
    poly = Polynomial(TV[c0, c1, c2, c3, c4, c5, c6, c7])
    return x -> σ * poly.((x .- μ) / σ) .* pdf.(Normal(μ, σ), x) ./ 24.0 .+ cdf.(Normal(μ, σ), x)
end
"""
    hermite5_pdf(a, μ, σ)

Return a closure that maps `x` (scalar or array, evaluated elementwise) to
`Q(z)^2 * pdf(Normal(μ, σ), x)` with `z = (x - μ) / σ`, where `Q` is the polynomial
(degree ≤ 4) whose ascending coefficients are the linear combinations of `a[1:5]` below.
"""
function hermite5_pdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    # Ascending-power coefficients of Q(z).
    q0 = a[1] - a[3] / sqrt(2) + sqrt(3/8) * a[5]
    q1 = a[2] - sqrt(3/2) * a[4]
    q2 = a[3] / sqrt(2) - sqrt(3/2) * a[5]
    q3 = a[4] / sqrt(6)
    q4 = a[5] / sqrt(24)
    poly = Polynomial(TV[q0, q1, q2, q3, q4])
    return x -> (poly.((x .- μ) / σ)) .^ 2 .* pdf.(Normal(μ, σ), x)
end
"""
    bernstein3(a; upper = 1.0, lower = 0.0)

Return a closure evaluating the cubic Bernstein polynomial with coefficients `a[1:4]`
at `x`, after mapping `x` linearly from `[lower, upper]` onto `[0, 1]`.
"""
function bernstein3(a::AbstractArray{TV, 1}; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    width = upper - lower
    return function (x)
        t = (x - lower) / width
        s = 1 - t
        return a[1] * s^3 + 3 * a[2] * t * s^2 + 3 * a[3] * t^2 * s + a[4] * t^3
    end
end
# NOTE: `solve_cubic_eq` is defined once, further down (after the "From PolynomialRoots.jl"
# attribution). An identical second definition here would just be overwritten by that one.
# NOTE: `bernstein3_inv` is defined once, further down, right after `solve_cubic_eq`.
# From PolynomialRoots.jl
"""
    solve_cubic_eq(poly; imag_tol = 1e-8)

Find a root in the open interval (0, 1) of
`poly[1] + poly[2]*x + poly[3]*x^2 + poly[4]*x^3` (complex coefficients).

For a genuine cubic the three candidate roots from Lagrange's method are checked in
order, and the first whose imaginary part has magnitude below `imag_tol` and whose real
part lies strictly in (0, 1) is returned as a real number. If `poly[4]` is exactly zero,
the quadratic (or linear) equation is solved instead. Returns `NaN` when no such root
exists.
"""
function solve_cubic_eq(poly::AbstractVector{Complex{T}}; imag_tol::Real = 1e-8) where T
    # Accept a candidate only if it is numerically real and strictly inside (0, 1).
    accept(r) = abs(imag(r)) < imag_tol && real(r) > 0.0 && real(r) < 1.0
    # With a zero leading coefficient the formulas below divide by zero and every
    # candidate becomes NaN, so fall back to the lower-degree equation.
    iszero(poly[4]) && return _solve_low_degree_in_unit(poly, accept)
    # Cubic equation solver for complex polynomial (degree=3)
    # http://en.wikipedia.org/wiki/Cubic_function Lagrange's method
    a1 = 1 / poly[4]
    E1 = -poly[3]*a1
    E2 = poly[2]*a1
    E3 = -poly[1]*a1
    s0 = E1
    E12 = E1*E1
    A = 2*E1*E12 - 9*E1*E2 + 27*E3 # = s1^3 + s2^3
    B = E12 - 3*E2                 # = s1 s2
    # quadratic equation: z^2 - Az + B^3=0 where roots are equal to s1^3 and s2^3
    Δ = sqrt(A*A - 4*B*B*B)
    if real(conj(A)*Δ)>=0 # scalar product to decide the sign yielding bigger magnitude
        s1 = exp(log(0.5 * (A + Δ)) * (1/3))
    else
        s1 = exp(log(0.5 * (A - Δ)) * (1/3))
    end
    if s1 == 0
        s2 = s1
    else
        s2 = B / s1
    end
    zeta1 = complex(-0.5, sqrt(T(3.0))*0.5)
    zeta2 = conj(zeta1)
    sol1 = (1/3) * (s0 + s1 + s2)
    sol2 = (1/3) * (s0 + s1 * zeta2 + s2 * zeta1)
    sol3 = (1/3) * (s0 + s1 * zeta1 + s2 * zeta2)
    # Candidates are checked in this fixed order; the first acceptable one wins.
    if accept(sol1)
        return real(sol1)
    elseif accept(sol2)
        return real(sol2)
    elseif accept(sol3)
        return real(sol3)
    else
        # Same float type as the root branches, keeping the return type stable.
        return oftype(real(sol1), NaN)
    end
end

# Root in (0, 1) of poly[1] + poly[2]*x + poly[3]*x^2 accepted by `accept` (poly[4] is
# ignored; the caller uses this only when it is zero), or NaN if there is none.
function _solve_low_degree_in_unit(poly::AbstractVector{<:Complex}, accept)
    c0, c1, c2 = poly[1], poly[2], poly[3]
    if !iszero(c2)
        Δ = sqrt(c1 * c1 - 4 * c2 * c0)
        r1 = (-c1 + Δ) / (2 * c2)
        accept(r1) && return real(r1)
        r2 = (-c1 - Δ) / (2 * c2)
        accept(r2) && return real(r2)
    elseif !iszero(c1)
        r = -c0 / c1
        accept(r) && return real(r)
    end
    return oftype(float(real(c0)), NaN)
end
"""
    bernstein3_inv(p, val; upper = 1.0, lower = 0.0)

Invert the cubic Bernstein polynomial `B` with coefficients `p[1:4]` (the same
polynomial `bernstein3` evaluates) at level `val`.

Returns `lower` when `B(0) >= val` and `upper` when `B(1) <= val`. Otherwise the root in
(0, 1) of `B(t) - val` found by `solve_cubic_eq` (possibly `NaN`) is mapped back to
`[lower, upper]`.
"""
function bernstein3_inv(p::AbstractArray{TV, 1}, val::Float64; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    # Power-basis form of B(t) - val = cubic*t^3 + quad*t^2 + lin*t + con.
    cubic = p[4] + 3 * p[2] - 3 * p[3] - p[1]
    quad = 3 * p[1] - 6 * p[2] + 3 * p[3]
    lin = 3 * p[2] - 3 * p[1]
    con = p[1] - val
    # B(0) - val == con: already at or above the target at the lower end.
    con >= 0.0 && return lower
    # B(1) - val == sum of all coefficients: still at or below the target at the upper end.
    cubic + quad + lin + con <= 0.0 && return upper
    t = solve_cubic_eq(Complex.([con, lin, quad, cubic]))
    return t * (upper - lower) + lower
end
# Here, we take in a set of choice sets, each of which can be repeated N times.
"""
    logit_probabilities!(inside_probabilities, outside_probabilities, utilities, N, σinv)

Fill multinomial-logit choice probabilities in place and return `nothing`.

For each grid cell `(uJ_index, bias_index)` (dimensions 2 and 3 of `utilities`), choice
set `j` (dimension 1), which is repeated `N[j]` times, gets probability

    exp(σinv * u_j) / (1 + Σ_k N[k] * exp(σinv * u_k)),

written to `inside_probabilities[j, uJ_index, bias_index]`. The outside option has
utility 0, and its probability `1 / (1 + Σ_k N[k] * exp(σinv * u_k))` is written to
`outside_probabilities[uJ_index, bias_index]`.

The largest scaled utility (floored at the outside option's 0) is factored out before
exponentiating. This leaves the probabilities unchanged but prevents `exp` from
overflowing into `Inf / Inf = NaN` for large utilities.
"""
function logit_probabilities!(inside_probabilities::Array{TV, 3}, outside_probabilities::Matrix{TV}, utilities::Array{TV, 3}, N::Vector{Int64}, σinv::TV) where TV
    # Cells are independent; iterating the last dimension outermost follows column-major layout.
    for bias_index in axes(utilities, 3), uJ_index in axes(utilities, 2)
        scaled = @views utilities[:, uJ_index, bias_index] .* σinv
        # Log-sum-exp shift over all options, including the outside option's utility 0.
        shift = maximum(scaled; init = zero(eltype(scaled)))
        exp_scaled = exp.(scaled .- shift)
        outside_weight = exp(-shift)
        inv_denom = 1.0 / (outside_weight + dot(exp_scaled, N))
        inside_probabilities[:, uJ_index, bias_index] .= inv_denom .* exp_scaled
        outside_probabilities[uJ_index, bias_index] = outside_weight * inv_denom
    end
    return nothing
end
"""
    partial_interpolate(x, y, idx, val)

Linearly interpolate the tabulated function `y` (sampled on the sorted grid `x`) at `val`.

`idx` is the bracketing index, presumably `searchsortedlast(x, val)` (cf. `get_indices`),
so that `x[idx] <= val < x[idx+1]`. Outside the grid (`idx == 0` or `idx == length(x)`),
the value is held flat at the nearest sample, `y[1]` or `y[end]`. The result has type
`TV`.
"""
function partial_interpolate(x::AbstractArray{Float64, 1}, y::AbstractArray{TV, 1}, idx::Int64, val::Float64) where TV
    z::TV = if idx == 0
        # Below the grid: flat extrapolation with the first function value (not the grid point x[1]).
        y[1]
    elseif idx == length(x)
        # At/above the last grid point: flat extrapolation with the last function value.
        y[end]
    else
        ((y[idx+1] - y[idx]) * (val - x[idx]) / (x[idx+1] - x[idx])) + y[idx]
    end
    return z
end
"""
    get_spec()::SpecFormat

Return the fixed model specification used in this example.

Special banks are 1, 14 and 16, the home bank is 12, choice sets hold at most 3 banks,
and the Hermite and Bernstein blocks have 5 and 4 coefficients respectively. Every
covariate list is given as `Symbol`s in a `Vector{Union{Symbol, String}}`.
"""
function get_spec()::SpecFormat
    # Build a covariate list with the element type SpecFormat declares.
    covars(xs...) = Union{Symbol, String}[xs...]
    return SpecFormat(
        Int64[1, 14, 16],                                            # special_banks
        12,                                                          # home_bank
        3,                                                           # max_banks_in_choice_set
        5,                                                           # hermite_order
        4,                                                           # bernstein_order
        covars(:ones, :x_educ_high, :x_income),                      # initial_search_cost_X
        covars(:ones, :x_educ_high, :x_income, :s_n_branches),       # per_bank_search_cost_X
        covars(:ones, :x_educ_high, :x_income),                      # ρ_X
        covars(:ones, :x_educ_high, :x_income),                      # γ_X
        covars(:ones),                                               # γ_std_X
        covars(:ones, :x_educ_high, :x_income),                      # κ_X
        covars(:ones, :x_educ_high, :x_income, :b_term, :b_amount),  # initial_mean_shifters
        covars(:ones, :b_term, :b_amount),                           # initial_std_shifters
        covars(:ones, :x_educ_high, :x_income, :b_term, :b_amount),  # final_mean_shifters
        covars(:ones, :b_term, :b_amount),                           # final_std_shifters
    )
end
"""
    get_indices(grid, vals)::Vector{Int64}

For each value in `vals`, return `searchsortedlast(grid, v)`: the index of the last grid
point `<= v`. This is 0 when `v` lies below the grid and `length(grid)` when it is at or
above the last point.
"""
function get_indices(grid::AbstractArray{Float64, 1}, vals::AbstractArray{Float64, 1})::Vector{Int64}
    return Int64[searchsortedlast(grid, v) for v in vals]
end
"""
    get_param_indices(spec::SpecFormat)::Tuple{ParamIndices, SubParamIndices}

Lay out the parameter vectors implied by `spec` and return their index maps.

* `ParamIndices` (full vector). Consecutive blocks, allocated via `param_update`, for:
  each bank group's Hermite coefficients, the jump-bid Hermite coefficients, the
  implied-jump-bid Bernstein coefficients, the covariate coefficients of each structural
  parameter, the bank mean/std shifters, and the two logit inverse scales. It also
  records, for every covariate list, the positions of its names inside `all_X`, the
  sorted union of those lists.
* `SubParamIndices` (the vector `get_ll_single` reads). Eleven scalar slots, one
  search-cost slot per bank group, then the bank, jump-bid and implied-jump-bid
  coefficient blocks.

Throws `ArgumentError` if the initial and final shifter lists differ in length.
"""
function get_param_indices(spec::SpecFormat)::Tuple{ParamIndices, SubParamIndices}
    # One group per special bank plus one extra group (presumably the remaining banks).
    n_bank_groups = length(spec.special_banks) + 1

    # ---- Layout of the full parameter vector ----------------------------------
    param_val::Int64 = 0
    bank_params = Vector{Int64}[]
    for _ in 1:n_bank_groups
        this_bank_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
        push!(bank_params, this_bank_params)
    end
    jump_bid_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
    implied_jump_bid_params::Vector{Int64}, param_val = param_update(spec.bernstein_order, param_val)
    initial_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.initial_search_cost_X), param_val)
    per_bank_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.per_bank_search_cost_X), param_val)
    ρ_X::Vector{Int64}, param_val = param_update(length(spec.ρ_X), param_val)
    γ_X::Vector{Int64}, param_val = param_update(length(spec.γ_X), param_val)
    γ_std_X::Vector{Int64}, param_val = param_update(length(spec.γ_std_X), param_val)
    κ_X::Vector{Int64}, param_val = param_update(length(spec.κ_X), param_val)
    # The shifter blocks are sized from the initial lists only, so the final lists must match.
    length(spec.initial_mean_shifters) == length(spec.final_mean_shifters) ||
        throw(ArgumentError("initial_mean_shifters and final_mean_shifters must have the same length"))
    length(spec.initial_std_shifters) == length(spec.final_std_shifters) ||
        throw(ArgumentError("initial_std_shifters and final_std_shifters must have the same length"))
    bank_mean_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_mean_shifters), param_val)
    bank_std_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_std_shifters), param_val)
    σinv_choice::Int64, param_val = param_update(1, param_val; return_as_vector = false)
    σinv_start_search::Int64, param_val = param_update(1, param_val; return_as_vector = false)

    # One individual-level X matrix (N_individuals x covariates) will hold the sorted union of
    # all covariate names; each parameter picks its columns by position. Per-bank search-cost
    # covariates live at the individual-bank level and are not part of this union.
    # NOTE(review): `sort!` throws a MethodError if Symbols and Strings are mixed in these lists.
    all_X = sort!(union(spec.initial_search_cost_X, spec.ρ_X, spec.γ_X, spec.γ_std_X, spec.κ_X,
                        spec.initial_mean_shifters, spec.initial_std_shifters,
                        spec.final_mean_shifters, spec.final_std_shifters))
    # Position of each name in `all_X`; every name is present because `all_X` is their union.
    positions(names) = Int64[findfirst(==(name), all_X) for name in names]

    param_indices = ParamIndices(bank_params, jump_bid_params, implied_jump_bid_params,
        initial_search_cost_X, per_bank_search_cost_X, ρ_X, γ_X, γ_std_X, κ_X,
        bank_mean_shifters, bank_std_shifters,
        positions(spec.initial_search_cost_X), positions(spec.ρ_X), positions(spec.γ_X),
        positions(spec.γ_std_X), positions(spec.κ_X),
        positions(spec.initial_mean_shifters), positions(spec.initial_std_shifters),
        positions(spec.final_mean_shifters), positions(spec.final_std_shifters),
        σinv_choice, σinv_start_search, param_val)

    # ---- Layout of the sub-parameter vector read by `get_ll_single` ----------
    sub_param_val::Int64 = 0
    sub_κ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ_std::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_search_cost::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_ρ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_choice::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_start_search::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_per_bank_search_cost::Vector{Int64}, sub_param_val = param_update(n_bank_groups, sub_param_val)
    sub_bank_params = Vector{Int64}[]
    for _ in 1:n_bank_groups
        this_bank_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
        push!(sub_bank_params, this_bank_params)
    end
    sub_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
    sub_implied_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.bernstein_order, sub_param_val)
    sub_param_indices = SubParamIndices(sub_κ, sub_γ, sub_γ_std, sub_initial_search_cost, sub_ρ,
        sub_initial_bank_μ, sub_initial_bank_σ, sub_final_bank_μ, sub_final_bank_σ,
        sub_σinv_choice, sub_σinv_start_search, sub_per_bank_search_cost, sub_bank_params,
        sub_jump_bid_params, sub_implied_jump_bid_params, sub_param_val)
    return (param_indices, sub_param_indices)
end
"""
    param_update(N, param_val; return_as_vector = true)

Reserve the next `N` consecutive parameter indices after `param_val` and return
`(indices, new_param_val)`.

`indices` is a `Vector{Int64}` unless `return_as_vector` is `false` and `N <= 1`. In that
case a single `Int64` index is returned and the counter advances by exactly one, even
when `N == 0`.
"""
function param_update(N::Int64, param_val::Int64; return_as_vector::Bool = true)
    if !return_as_vector && N <= 1
        next_val = param_val + 1
        return next_val, next_val
    end
    return collect(param_val+1:param_val+N), param_val + N
end
"""
    get_ll_single(params, param_indices, uJ_grid, base_nodes, weights, u_current_idx,
                  remaining_term, amount, chosen_bank, current_monthly_payment,
                  final_monthly_payment, search_type, choice_sets, choice_sets_N,
                  chosen_choice_set)

Log-likelihood contribution of a single observation (reduced example).

1. Read the scalar frictions and distribution parameters from `params` via
   `param_indices`.
2. Build the bias grid `base_nodes * γ_std * sqrt(2) .+ γ` from the caller-supplied nodes.
3. For every bias node and choice set, compute on `uJ_grid` a utility of searching that
   set. It is built from the initial bank CDFs and the home-bank CDF implied by the jump
   bid, minus the set's summed per-bank search costs.
4. Convert these utilities into logit choice probabilities (`logit_probabilities!`) and
   form the benefit of search `Σ_j N_j p_j u_j` on `uJ_grid`.
5. Integrate that benefit against the jump-bid density up to `u_current` to get `S_base`,
   form the logit probability of not starting a search, and return
   `log(1 - ρ + ρ * dot(pr_no_search, weights))`.

`chosen_bank`, `final_monthly_payment`, `search_type` and `chosen_choice_set` are not used
by this reduced version; they are kept so the signature matches existing callers.

All temporaries holding `params`-dependent values are `TV`-typed, so non-`Float64` `TV`
(e.g. dual numbers) works.

NOTE(review): `partial_cumtrapz` (called below) is not defined alongside the other helpers
in this listing — make sure it is in scope.
"""
function get_ll_single(params::AbstractArray{TV, 1}, param_indices::SubParamIndices,
                       uJ_grid::Vector{Float64}, base_nodes::SVector{9, Float64},
                       weights::SVector{9, Float64}, u_current_idx::Int64,
                       remaining_term::Float64, amount::Float64, chosen_bank::Int64,
                       current_monthly_payment::Float64, final_monthly_payment::Float64,
                       search_type::Int64, choice_sets::Vector{Vector{Int64}},
                       choice_sets_N::Vector{Int64}, chosen_choice_set::Int64) where TV
    # Frictions and distribution parameters shared by every choice set.
    κ::TV = params[param_indices.κ]
    γ::TV = params[param_indices.γ]
    γ_std::TV = params[param_indices.γ_std]
    initial_search_cost::TV = params[param_indices.initial_search_cost]
    ρ::TV = params[param_indices.ρ]
    initial_bank_μ::TV = params[param_indices.initial_bank_μ]
    initial_bank_σ::TV = params[param_indices.initial_bank_σ]
    σinv_choice::TV = params[param_indices.σinv_choice]
    σinv_start_search::TV = params[param_indices.σinv_start_search]

    # Note that the cdf of zero-profit rates is written in terms of monthly payments if A = 1, scaled by number of
    # payments. So, a value of m corresponds to a monthly payment of m * A / T.
    #
    # We know that monthly utility = (delta - m * A / T) - kappa / NPVrate(T),
    # So, Pr(utility <= u) = Pr((delta - m * A / T) - kappa/NPV <= u) = Pr(scaled monthly payment >= (T/A) * (delta - (u + kappa/NPV)))
    # The bias is in terms of monthly payments. So, we just add the bias.
    β_customer_m::Float64 = 0.95^(1/12)
    per_bank_search_costs::Vector{TV} = params[param_indices.per_bank_search_cost]
    npv_scale::Float64 = (1.0 - β_customer_m^remaining_term) / (1 - β_customer_m)
    overall_scale::Float64 = remaining_term / amount
    bias_grid = base_nodes * γ_std * sqrt(2) .+ γ   # in units of dollars
    scaled_κ::TV = κ / npv_scale                   # in units of dollars per month
    u_current::Float64 = -current_monthly_payment

    # Distribution closures at the time of search. They are deliberately not annotated
    # `::Function`: that annotation makes each variable abstractly typed and forces dynamic
    # dispatch on every call.
    # Maps utilities ($) to utilities ($) <--- could rethink whether this is the best map
    implied_jump_bid = bernstein3(params[param_indices.implied_jump_bid_params];
                                  upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)
    initial_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], initial_bank_μ, initial_bank_σ)(-overall_scale .* (x .+ scaled_κ))
                        for a in param_indices.bank_params]
    initial_jump_bid_cdf = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_jump_bid_pdf = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_home_bank_cdf(x) = initial_jump_bid_cdf(implied_jump_bid.(x))

    n_sets, n_grid, n_bias = length(choice_sets), length(uJ_grid), length(bias_grid)
    choice_set_utilities = zeros(TV, n_sets, n_grid, n_bias)
    base_0 = zeros(TV, n_grid)
    base_1 = zeros(TV, n_grid)
    to_add = zeros(TV, n_grid)
    # These hold `params`-dependent values, so they must be TV-typed (Float64 storage breaks generic TV).
    initial_bank_cdf_vals = zeros(TV, length(initial_bank_cdf), n_grid)
    one_minus_initial_bank_cdf_vals = zeros(TV, length(initial_bank_cdf), n_grid)

    @inbounds for (bias_index, bias) in enumerate(bias_grid)
        biased_uJ = uJ_grid .+ bias
        for b in eachindex(initial_bank_cdf)
            initial_bank_cdf_vals[b, :] .= initial_bank_cdf[b](biased_uJ)
            one_minus_initial_bank_cdf_vals[b, :] .= @views 1.0 .- initial_bank_cdf_vals[b, :]
        end
        initial_home_bank_cdf_vals = initial_home_bank_cdf(biased_uJ)
        # Grid points where the home-bank CDF is at least 1e-7 below its value at the top of
        # the grid; only there is it safe to divide by that gap below.
        good_indices = (initial_home_bank_cdf_vals[end] .- initial_home_bank_cdf_vals) .>= 1e-7

        for (choice_set_index, choice_set) in enumerate(choice_sets)
            # base_0: product over the set's banks of their CDF values.
            base_0 .= @views initial_bank_cdf_vals[choice_set[1], :]
            for b in Iterators.drop(choice_set, 1)
                base_0 .*= @views initial_bank_cdf_vals[b, :]
            end
            cumtrapz_0 = cumtrapz_upper(uJ_grid, base_0)

            # base_1: Σ_b (1 - F_b) * Π_{b' ≠ b} F_b' over the positions in the set.
            base_1 .= 0.0
            for (b_index, b) in enumerate(choice_set)
                to_add .= @views one_minus_initial_bank_cdf_vals[b, :]
                for (bprime_index, bprime) in enumerate(choice_set)
                    bprime_index == b_index && continue
                    to_add .*= @views initial_bank_cdf_vals[bprime, :]
                end
                base_1 .+= to_add
            end
            cumtrapz_1 = cumtrapz_upper(uJ_grid, base_1)
            base_1 .*= initial_home_bank_cdf_vals
            cumtrapz_1_with_home = cumtrapz_upper(uJ_grid, base_1)

            # Utility of searching this set at each grid point, before search costs
            # (reuses base_0's storage; base_0 is refilled at the next iteration).
            next_part = base_0
            for i in eachindex(next_part)
                if good_indices[i]
                    next_part[i] = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i] -
                        (cumtrapz_1_with_home[i] - initial_home_bank_cdf_vals[i] * cumtrapz_1[i]) /
                        (initial_home_bank_cdf_vals[end] - initial_home_bank_cdf_vals[i])
                else
                    next_part[i] = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i] - cumtrapz_1[i]
                end
            end
            this_search_cost = sum(@view per_bank_search_costs[choice_set])
            choice_set_utilities[choice_set_index, :, bias_index] .= next_part .- this_search_cost
        end
    end

    choice_set_probabilities = zeros(TV, size(choice_set_utilities))
    outside_option_probabilities = zeros(TV, n_grid, n_bias)
    logit_probabilities!(choice_set_probabilities, outside_option_probabilities,
                         choice_set_utilities, choice_sets_N, σinv_choice)
    # Benefit of search per (grid point, bias node): Σ_j N_j * p_j * u_j over choice sets.
    benefit_of_search_uJ = dropdims(sum((choice_sets_N .* choice_set_probabilities) .* choice_set_utilities; dims = 1); dims = 1)

    S_base = zeros(TV, n_bias)
    head = 1:(u_current_idx + 1)
    for (bias_index, bias) in enumerate(bias_grid)
        initial_jump_bid_pdf_vals = initial_jump_bid_pdf(uJ_grid .+ bias)
        S_base[bias_index] = @views begin
            below = partial_cumtrapz(uJ_grid, benefit_of_search_uJ[head, bias_index] .* initial_jump_bid_pdf_vals[head],
                                     u_current_idx, u_current)
            at_current = initial_jump_bid_cdf(u_current + bias) *
                partial_interpolate(uJ_grid, benefit_of_search_uJ[:, bias_index], u_current_idx, u_current)
            below + at_current
        end
    end

    # Logit probability of NOT starting a search, one entry per bias node. It is computed as
    # 1 / (1 + exp(v)) rather than 1 - exp(v) / (1 + exp(v)), which becomes Inf / Inf = NaN
    # for large v.
    pr_no_search = @. 1 / (1 + exp(σinv_start_search * (S_base - initial_search_cost)))
    likelihood = 1 - ρ + ρ * dot(pr_no_search, weights)
    return log(likelihood)
end
"""
    call_autodiff_example()

Differentiate `get_ll_single` with respect to `params` on a hard-coded example
observation, using Enzyme reverse mode. The gradient is accumulated into the shadow vector
`dx` (same length as `params`); the function returns whatever `autodiff` returns.
"""
function call_autodiff_example()
    spec::SpecFormat = get_spec()
    sub_param_indices::SubParamIndices = get_param_indices(spec)[2]
    # Utility grid 0, -0.05, …, -2.0 plus a coarser tail down to -6.0, sorted ascending.
    # It is built in a single assignment because `single_ll` captures it, and a captured
    # variable that is reassigned gets boxed (a type-unstable closure), which hurts both
    # speed and Enzyme.
    uJ_grid::Vector{Float64} = sort!(-union(collect(0:0.05:2.0), [2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]))
    base_nodes, base_weights = SVector{9}.(FastGaussQuadrature.gausshermite(9))
    # Gauss–Hermite weights sum to √π; rescale them so they sum to one.
    weights = SVector{9}(base_weights / sqrt(π))
    params = [0.0, 0.0, 0.25, 0.0, 0.1, 1.2, 0.25, 1.2, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0,
        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0]
    u_current_idx = 44
    remaining_term = 67.0
    amount = 13.32
    chosen_bank = -1
    current_monthly_payment = 0.24
    final_monthly_payment = NaN
    search_type = 1
    choice_sets = [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3], [1, 3, 4],
        [1, 4], [1, 4, 4], [2], [2, 3], [2, 3, 4], [2, 4], [2, 4, 4],
        [3], [3, 4], [3, 4, 4], [4], [4, 4], [4, 4, 4]]
    choice_sets_N = [1, 1, 1, 15, 1, 15, 15, 105, 1, 1, 15, 15, 105, 1, 15, 105, 15, 105, 455]
    chosen_choice_set = 0
    # The shadow must match `params` in length; size it from `params` rather than hard-coding it.
    dx = zero(params)
    single_ll(x) = get_ll_single(x, sub_param_indices, uJ_grid, base_nodes, weights, u_current_idx,
        remaining_term, amount, chosen_bank, current_monthly_payment, final_monthly_payment, search_type, choice_sets, choice_sets_N, chosen_choice_set)
    # If Enzyme reports "Mismatched activity", its error message suggests enabling runtime
    # activity (`Enzyme.API.runtimeActivity!(true)` in the Enzyme version used here) as a workaround.
    return autodiff(Reverse, single_ll, Active, Duplicated(params, dx))
end
call_autodiff_example()
After some time, it crashes with the following stack trace:
ERROR: LoadError: Enzyme execution failed.
Mismatched activity for: store {} addrspace(10)* %value_phi5253395, {} addrspace(10)** %.fca.1.gep, align 8, !dbg !1671, !noalias !216 const val: %value_phi5253395 = phi {} addrspace(10)* [ %arrayref514, %L1631.lr.ph ], [ %arrayref1084, %L3207 ]
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}
llvalue= %arrayref514 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %arrayptr5112966, align 8, !dbg !1096, !tbaa !1097, !alias.scope !197, !noalias !198
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] get_ll_single
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:1612
[2] get_ll_single
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
[3] *
@ ./float.jl:411 [inlined]
[4] trapz
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:87
[5] macro expansion
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6587 [inlined]
[6] enzyme_call
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6188 [inlined]
[7] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6065 [inlined]
[8] autodiff
@ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:309 [inlined]
[9] autodiff
@ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
[10] call_autofiff_example()
@ Main ~/Documents/refinancing/multistep/example_Enzyme.jl:592
[11] top-level scope
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:596
[12] include(fname::String)
@ Base.MainInclude ./client.jl:489
[13] top-level scope
@ REPL[2]:1
in expression starting at /Users/miguelborrero/Documents/refinancing/multistep/example_Enzyme.jl:596
I would really appreciate some help, since it is important for me to make this work. Thanks a lot in advance!