Hi,
I hope this finds you well. I recently moved over from Matlab in search of better computational performance, and I am writing this post to ask about best practices for higher-order approximations of large problems such as the one in the title.
The current problem is to solve a multi-country trade model using bilateral trade data for 44 countries. Currently, we have set up a simple Armington model, but in the future we would like to extend it to allow for multiple industries and within-country geographies.
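For reference, the equilibrium condition coded in `eha_wzet_combined` in the MWE below can be written as follows, for countries i = 1, …, N with wages w_i = `wzet[i]` and trade-cost shock ζ = `wzet[N+1]`:

$$
E_i(w,\zeta) \;=\; w_i \;-\; \sum_{j=1}^{N} y_{ij}\,
\frac{(\tau_{ij}\, w_i)^{1-\sigma}}{\sum_{k=1}^{N} x_{kj}\,(\tau_{kj}\, w_k)^{1-\sigma}}\; w_j ,
\qquad
\tau_{ij} =
\begin{cases}
\zeta, & i = N,\ j < N,\\
1, & \text{otherwise},
\end{cases}
$$

where y and x are the row- and column-normalized versions of the flow matrix X (`y = X./sum(X,dims=2)`, `x = X./sum(X,dims=1)`). The approximation targets the wages w_1, …, w_{N−1} that solve E_1 = … = E_{N−1} = 0 as functions of ζ, with w_N held fixed.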
As a first approach, I tried automatic differentiation with ForwardDiff.jl. It was very slow, threw a StackOverflowError at the 4th-order approximation, and made a very large number of allocations. I therefore attributed the poor performance to ForwardDiff.jl and moved on to the next method. I can post that MWE as well if needed.
Next, I tried numerical differentiation with FiniteDifferences.jl as a more intuitive way to obtain derivatives. I had already confirmed that even Matlab can handle the approximation up to 3rd order, so I did not expect speed problems. The official page of FiniteDifferences.jl also states that “FiniteDifferences.jl supports higher-order approximation and step size adaptation” and that “Finite difference methods are optimized to minimize allocations,” so I was optimistic about the results.
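To see why nesting finite differences gets expensive regardless of how well each call is optimized, here is a minimal sketch. `fd_jacobian` is a hypothetical, hand-rolled 5-point central-difference Jacobian, a simplified stand-in for `jacobian(central_fdm(5,1), f, x)` *without* step-size adaptation, and `NCALLS` is a hypothetical counter that records how often the residual is evaluated. The second part takes the Jacobian of one Jacobian column, which is the pattern `ddw[j]` follows in the MWE.

```julia
# Hypothetical call counter, used only to illustrate cost.
const NCALLS = Ref(0)

# 5-point central-difference Jacobian (no step-size adaptation), a simplified
# stand-in for jacobian(central_fdm(5,1), f, x) from FiniteDifferences.jl.
function fd_jacobian(f, x::AbstractVector{<:Real}; h = 1e-3)
    fx = f(x)                                   # one call, only to size the output
    J = zeros(length(fx), length(x))
    for k in eachindex(x)
        e = zeros(length(x))
        e[k] = h
        # weights (1, -8, 8, -1) / (12h) on offsets (-2h, -h, +h, +2h): 4 calls per column
        J[:, k] = (f(x .- 2 .* e) .- 8 .* f(x .- e) .+ 8 .* f(x .+ e) .- f(x .+ 2 .* e)) ./ (12h)
    end
    return J
end

# Toy residual that counts how often it is evaluated.
F(x) = (NCALLS[] += 1; [x[1]^2 * x[2], sin(x[2]) + x[3]])

x0 = [1.0, 2.0, 0.5]
NCALLS[] = 0
J = fd_jacobian(F, x0)
first_order_calls = NCALLS[]             # 1 + 4 * 3 = 13

# Second derivatives as the Jacobian of one Jacobian column
NCALLS[] = 0
H1 = fd_jacobian(z -> fd_jacobian(F, z)[:, 1], x0)
nested_calls = NCALLS[]                  # 13 * 13 = 169
```

With n = N + 1 = 45 unknowns as in the MWE, one such Jacobian costs 1 + 4 × 45 = 181 residual evaluations in this sketch, and one nested Jacobian costs 181² = 32,761; `ddw` repeats this for each of the N − 1 = 43 columns. FiniteDifferences.jl's own cost per Jacobian may differ (for example with step-size adaptation enabled), but the multiplicative effect of nesting is the same.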
But to my surprise, my Julia code was already impractically slow at the second-order approximation. Again, it made a very large number of allocations, which I suspect hurt performance. Below is the MWE; I have also attached the @profview results for the 2nd-order approximation using FiniteDifferences.jl.
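A common Julia remedy for many small allocations is to write the residual in place into preallocated buffers. The sketch below assumes the same trade-cost pattern as `eha_wzet_combined`; `excess_alloc` is that computation lightly restructured, and `excess_inplace!` is a hypothetical in-place variant. Its buffers are `Float64`, so it suits finite differences but not ForwardDiff dual numbers as written.

```julia
using LinearAlgebra

# Allocating reference: same computation as eha_wzet_combined in the MWE, lightly restructured.
function excess_alloc(wzet, y, x, sig, N)
    w = wzet[1:N]
    tau_mat = vcat(ones(N-1, N), [wzet[N+1] * ones(N-1); 1]')
    P = (tau_mat .* w) .^ (1 - sig)
    return w - sum(y .* (P ./ sum(x .* P, dims = 1)) .* w', dims = 2)[:, 1]
end

# In-place sketch: writes into preallocated buffers `out` (N), `P` (N×N) and `D` (N).
function excess_inplace!(out, P, D, wzet, y, x, sig, N)
    zeta = wzet[N+1]
    @inbounds for j in 1:N, i in 1:N
        tau = (i == N && j < N) ? zeta : 1.0   # same pattern as tau_mat above
        P[i, j] = (tau * wzet[i])^(1 - sig)
    end
    @inbounds for j in 1:N                      # D[j] = sum_k x[k, j] * P[k, j]
        s = 0.0
        for k in 1:N
            s += x[k, j] * P[k, j]
        end
        D[j] = s
    end
    @inbounds for i in 1:N                      # out[i] = w_i - sum_j y[i, j] * P[i, j] / D[j] * w_j
        s = 0.0
        for j in 1:N
            s += y[i, j] * P[i, j] / D[j] * wzet[j]
        end
        out[i] = wzet[i] - s
    end
    return out
end

# Usage on a small version of the MWE's data (N = 5 instead of 44).
N = 5; sig = 4; homebias = 0.9
X = homebias * Matrix{Float64}(I, N, N) + (1 - homebias) / N * ones(N, N)
y = X ./ sum(X, dims = 2)
x = X ./ sum(X, dims = 1)
wzet = [1.0, 1.1, 0.9, 1.05, 0.95, 1.2]
out, P, D = zeros(N), zeros(N, N), zeros(N)    # allocate once, reuse on every call
out_inplace = copy(excess_inplace!(out, P, D, wzet, y, x, sig, N))
out_alloc = excess_alloc(wzet, y, x, sig, N)
```

When wrapping this for a Jacobian routine, returning `copy(out)` is a precaution so that successive evaluations do not alias the shared buffer. Separately, the MWE's `ed` closes over the non-constant globals `y`, `x`, `sig` and `N`; the Julia performance tips recommend declaring such globals `const` or passing them as arguments.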
Given this, my current suspicion is that the slowdown comes from the internal behavior of these packages, regardless of how the differentiation is performed (ForwardDiff.jl or FiniteDifferences.jl). How can I quickly obtain a higher-order approximation for the problem in the title? Any help is highly appreciated.
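One way to avoid building second-derivative tensors altogether (a swapped-in technique, not what the MWE does): the right-hand side `H2` in the MWE equals D²F[v, v] restricted to the first N − 1 equations, with v = (Dg1, 0, 1). That is just the second derivative of the scalar-parameter function t ↦ F(u0 + t v) at t = 0. A minimal sketch, assuming the MWE's layout u = (w_1, …, w_N, ζ); `second_directional_derivative` and `implicit_taylor2` are hypothetical helpers, and the second derivative in t uses a hand-written 5-point stencil.

```julia
using LinearAlgebra

# Second directional derivative D²F(u0)[v, v] = d²/dt² F(u0 + t v) at t = 0,
# estimated with a 5-point central-difference stencil in the scalar t only.
function second_directional_derivative(F, u0, v; h = 1e-3)
    f(t) = F(u0 .+ t .* v)
    return (-f(2h) .+ 16 .* f(h) .- 30 .* f(0.0) .+ 16 .* f(-h) .- f(-2h)) ./ (12 * h^2)
end

# Hypothetical helper for the layout u = (w_1, ..., w_N, zeta): w_1..w_{N-1} adjust,
# w_N is held fixed, and only the first N-1 equations are kept.
# J is the N × (N+1) Jacobian of F at u0, computed once by any method.
function implicit_taylor2(F, J, u0, N)
    Fy = lu(J[1:N-1, 1:N-1])             # factor once; every order reuses it
    g1 = -(Fy \ J[1:N-1, N+1])           # first order: g' = -F_y \ F_z
    v = [g1; 0.0; 1.0]                   # direction (g', dw_N = 0, dzeta = 1)
    H2 = second_directional_derivative(F, u0, v)[1:N-1]
    g2 = -(Fy \ H2)                      # second order: g'' = -F_y \ D²F[v, v]
    return g1, g2
end

# Toy check with N = 2: F_1 = w_1 - w_2 * exp(zeta), so w_1 = exp(zeta) when w_2 = 1,
# hence g'(0) = g''(0) = 1.
Ftoy(u) = [u[1] - u[2] * exp(u[3]), u[2] - 1.0]
Jtoy = [1.0 -1.0 -1.0; 0.0 1.0 0.0]      # exact Jacobian of Ftoy at u0 = (1, 1, 0)
g1, g2 = implicit_taylor2(Ftoy, Jtoy, [1.0, 1.0, 0.0], 2)
```

For the MWE one would pass `F = ed` and `J = dfwzet(wzet)` evaluated once; the second-order term then costs five residual evaluations instead of N − 1 nested Jacobians. The idea extends to order k: if T_{k−1}(t) is the Taylor polynomial of the solution path (w(t), ζ = t) up to order k − 1, the k-th coefficient solves F_y g^{(k)} = −d^k/dt^k F(T_{k−1}(t)) at t = 0, with the same F_y factorization. For k ≥ 3, finite differences in t lose accuracy quickly, so a Taylor-mode package such as TaylorSeries.jl may be a better fit for those scalar derivatives.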
Best,
Daisuke
#---------------------------------------
# housekeeping
#---------------------------------------
using FiniteDifferences, Profile, LinearAlgebra
# equilibrium condition (excess labor demand) for each of the N countries
function eha_wzet_combined(wzet,y,x,sig,N)
    # trade-cost matrix: ones everywhere except row N, whose first N-1 entries equal wzet[N+1]
    tau_mat = vcat(ones(N-1,N), [wzet[N+1]*ones(N-1); 1]');
    # Armington condition
    excess = wzet[1:N] -
        sum(
            y .* ((tau_mat .* wzet[1:N]) .^ (1-sig) ./
            sum( x .* (tau_mat .* wzet[1:N]) .^ (1-sig), dims = 1 )) .*
            wzet[1:N]',
            dims = 2
        )[:,1];
    return excess
end
# sum over j of d2F/(dw dw_j) * Dg1_j on the first N-1 equations and wages;
# Dg1(wzet) and ddw[j](wzet) are re-evaluated by finite differences in every iteration
function sum_dDyFdyj_D11(wzet,N,ddw,Dg1)
    out = ddw[1](wzet)[1:N-1,1:N-1] * Dg1(wzet)[1]
    if N > 2
        for j in 2:N-1
            out += ddw[j](wzet)[1:N-1,1:N-1] * Dg1(wzet)[j];
        end
    end
    return out
end
#---------------------------------------
# parameters and data
#---------------------------------------
N = 44; # number of countries
sig = 4; # elasticity of substitution
homebias = 0.9; # strength of home bias in consumption
X = homebias * Matrix{Float64}(I,N,N) + (1-homebias) * 1/N * ones(N,N);
y = X./sum(X,dims=2);
x = X./sum(X,dims=1);
wzet = ones(N+1); # approximation point: N wages and zeta (wzet[N+1]), all equal to 1
#---------------------------------------
# 1st order solution
#---------------------------------------
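# residual as a function of wzet only
ed(wzet) = eha_wzet_combined(wzet,y,x,sig,N);
# N × (N+1) Jacobian of the equilibrium conditions (5-point central differences)
dfwzet(wzet) = jacobian(central_fdm(5,1), ed, wzet)[1];
# first-order coefficient: -F_w \ F_zeta on the first N-1 equations and wages
# (each call evaluates dfwzet, i.e. a full Jacobian, twice)
Dg1(wzet) = - dfwzet(wzet)[1:N-1,1:N-1] \ dfwzet(wzet)[1:N-1,N+1];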
#---------------------------------------
# 2nd order solution
#---------------------------------------
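# dF/dzeta and its Jacobian: mixed derivatives d2F/(dzeta dw) and d2F/dzeta2
dfzet = wzet -> dfwzet(wzet)[:,N+1];
ddfzwzet(wzet) = jacobian(central_fdm(5,1), dfzet, wzet)[1];
# ddw[j]: Jacobian of the j-th column of dfwzet, i.e. second derivatives d2F/(dw_j dwzet)
ddw = Array{Function}(undef, N-1)
for j in 1:N-1
    dfwj = wzet -> dfwzet(wzet)[:,j];
    ddw[j] = wzet -> jacobian(central_fdm(5,1), dfwj, wzet)[1]
end
# right-hand side of the second-order implicit-function equation
H2(wzet) = ( sum_dDyFdyj_D11(wzet,N,ddw,Dg1) + ddfzwzet(wzet)[1:N-1,1:N-1] ) * Dg1(wzet) +
           ddfzwzet(wzet)[1:N-1,1:N-1] * Dg1(wzet) +
           ddfzwzet(wzet)[1:N-1,N+1];
# second-order coefficient
Dg2(wzet) = - dfwzet(wzet)[1:N-1,1:N-1] \ H2(wzet);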
Dg2(ones(N+1)) # run once to force compilation
@time Dg2(ones(N+1))
#@profview Dg2(ones(N+1))
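For reference, writing y = (w_1, …, w_{N−1}) for the adjusting wages, z = ζ for the shock, and F for the first N − 1 equilibrium conditions (with w_N held fixed), `Dg1`, `H2` and `Dg2` above implement the first- and second-order implicit-function-theorem coefficients obtained by differentiating F(g(z), z) = 0 once and twice:

$$
g'(z) = -F_y^{-1} F_z ,
\qquad
g''(z) = -F_y^{-1}\left[\Big(\sum_{j=1}^{N-1} F_{y y_j}\, g'_j\Big)\, g' \;+\; 2\,F_{zy}\, g' \;+\; F_{zz}\right],
$$

where F_{y y_j} = ∂F_y/∂y_j corresponds to `ddw[j](wzet)[1:N-1,1:N-1]`, F_{zy} to `ddfzwzet(wzet)[1:N-1,1:N-1]`, and F_{zz} to `ddfzwzet(wzet)[1:N-1,N+1]`. The factor 2 reflects the two separate `ddfzwzet(...)[1:N-1,1:N-1] * Dg1(wzet)` terms in `H2`.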