I am training a neural network by computing the gradient of the loss with Enzyme. Part of my computation involves associated Legendre polynomials. I did not believe these to be a crucial part of the differentiation, because they are not a function of the neural network parameters; however, by profiling my code to find performance hotspots, I discovered that most of my compute time was spent differentiating through the Legendre polynomials. I tried moving these computations outside the training loop and accessing the precomputed values through a Dict, but that led to other issues. I'd like to know the simplest way to prevent Enzyme from differentiating through the Legendre polynomials.
# MWE preventing Enzyme from differentiating through Legendre polynomials
# Package setup: the ML stack (Lux + Enzyme + ComponentArrays), the
# special-function dependency under investigation
# (AssociatedLegendrePolynomials), and benchmarking/profiling tools.
begin
using Lux
using Enzyme
using ComponentArrays
using Random
using Statistics
using AssociatedLegendrePolynomials
using BenchmarkTools
using Profile
using ProfileView
# Seed the global RNG so the dummy data generated below is reproducible.
Random.seed!(1234)
end
begin # Functions
"""
    build_model(n_in, n_out, n_layers, n_nodes; act_fun=leakyrelu, last_fun=relu)

Build a fully connected Lux `Chain` with `n_in` inputs, `n_layers` hidden
layers of `n_nodes` nodes each (activation `act_fun`), and `n_out` outputs
(activation `last_fun`).
"""
function build_model(n_in, n_out, n_layers, n_nodes;
                     act_fun=leakyrelu, last_fun=relu)
    # Input layer — Pair form `in => out` for consistency with the
    # hidden and output layers below.
    first_layer = Lux.Dense(n_in => n_nodes, act_fun)
    # Hidden block: `n_layers` identical Dense layers.
    hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)
    # Output layer; `last_fun` (default `relu`) keeps predictions non-negative.
    last_layer = Lux.Dense(n_nodes => n_out, last_fun)
    return Chain(first_layer, hidden..., last_layer)
end
"""
    eval_model(m, params, st, x)

Run model `m` on input `x` with parameters `params` and state `st`,
returning only the prediction (the updated state is discarded).
"""
function eval_model(m, params, st, x)
    prediction, _ = m(x, params, st)
    return prediction
end
"""
    recursive_convert(T, x)

Recursively convert the contents of `x` to element type `T`:
arrays are converted elementwise, named tuples are rebuilt with each
value converted recursively, and every other value is returned as-is
(note: bare scalars are deliberately left untouched).

Rewritten with multiple dispatch instead of `isa` branching — the
idiomatic Julia way to select behavior by type.
"""
recursive_convert(T, x::AbstractArray) = convert.(T, x)  # elementwise convert
recursive_convert(T, x::NamedTuple) =
    NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
recursive_convert(T, x) = x  # leaf fallback: leave non-container values alone
"""
    lossDiff(p, args)

Mean squared error between the measured cross sections (`args[2]`) and
the cross sections predicted from network parameters `p`.
"""
function lossDiff(p, args)
    observed = args[2]
    predicted = calculateMultiDiffCrossSections(p, args)
    residual = observed - predicted
    return mean(residual .^ 2)
end
"""
    calculateDifferentialCrossSection(U, theta)

Differential cross section at each angle in `theta` (degrees), scaled by
the network output `U`.

The original loop added the full-length vector
`sum(U)*cos.(theta).*Plm(2, 1, cos(theta[i]*pi/180))` once per angle —
accidental O(n²) work with an O(n) allocation per iteration.  Since the
Legendre factor is a scalar per angle and `sum(U)*cos.(theta)` is
loop-invariant, the whole accumulation collapses to one O(n) broadcast.

NOTE(review): the `Plm` values depend only on the fixed angles, never on
the network parameters. To stop Enzyme from differentiating through them
at all, consider declaring the function inactive — confirm against the
Enzyme custom-rules docs:
    Enzyme.EnzymeRules.inactive(::typeof(AssociatedLegendrePolynomials.Plm), args...) = nothing
"""
function calculateDifferentialCrossSection(U, theta)
    # Match the original's `zeros(n)` result for an empty angle list.
    isempty(theta) && return zeros(0)
    # Sum of the scalar Legendre factors over all angles.
    coeff = sum(AssociatedLegendrePolynomials.Plm(2, 1, cos(t * pi / 180)) for t in theta)
    return (sum(U) * coeff) .* cos.(theta)
end
"""
    calculateMultiDiffCrossSections(p, args)

Predict the differential cross sections for every experiment, concatenated
into one vector.  `args` is laid out as `(X, y, thetas, model, st)`; only
`X` (`args[1]`), `thetas` (`args[3]`), `model` (`args[4]`) and `st`
(`args[5]`) are read here.  Experiment `i` consumes its own contiguous
block of `rlen` input columns and contributes `length(thetas[i])` entries.
"""
function calculateMultiDiffCrossSections(p, args)
    X = args[1]
    thetas = args[3]
    M = args[4]
    st = args[5]
    # Generalized from the hard-coded nlen = 2 / rlen = 100: one equal block
    # of columns per experiment (the original 4x200 input with 2 experiments
    # reproduces the 100-column blocks exactly).
    nlen = length(thetas)
    rlen = size(X, 2) ÷ nlen
    # Total number of predicted values across all experiments.
    datalen = sum(length, thetas)
    dσ = zeros(eltype(X), datalen)
    j = 1
    for i in 1:nlen
        exp_len = length(thetas[i])
        j_next = j + exp_len
        # Columns belonging to experiment i; a view avoids copying the slice.
        U = eval_model(M, p, st, view(X, :, (i - 1) * rlen + 1:i * rlen))
        dσ[j:j_next - 1] = calculateDifferentialCrossSection(U, thetas[i])
        j = j_next
    end
    return dσ
end
end
# All network inputs/outputs are kept in single precision.
data_type = Float32
# Generate dummy data
# 9 targets: one per angle across both experiments (5 + 4 below).
XSdiff_train = rand(data_type, 9)
# Scattering angles (degrees) for two experiments; note the eltype is
# Float64 (the trailing integer literals promote).
theta_train = [
[1.0, 10.0, 20.0, 45.0, 60],
[5.0, 15.0, 25.0, 60]
]
# 4 features x 200 samples: 100 columns per experiment.
X_train = rand(data_type, 4,200)
# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
# f32 converts the freshly initialized parameters/state to Float32.
ps, st = f32(Lux.setup(Random.default_rng(), model))
# Flat parameter vector for Enzyme to differentiate with respect to.
p = ComponentArray(recursive_convert(data_type, ps))
# const avoids the untyped-global penalty for the state captured below.
const _st = st
# Bundle everything the loss needs: (X, y, thetas, model, st).
# NOTE(review): `args` itself is a non-const global captured by the
# benchmark closure below — consider `const args = ...` too.
args = (X_train, XSdiff_train, theta_train, model, _st)
# Test loss function evaluation
losstest = lossDiff(p, args)
# losstest
# Reverse-mode derivative of the scalar loss with respect to the parameters.
# The inner closure argument is renamed `q` so it no longer shadows the
# global `p`; only `q` is active — Enzyme treats the captured `args` as data.
dl_dp(p) = Enzyme.jacobian(Reverse, q -> lossDiff(q, args), p)
# Interpolate `p` with `$` so BenchmarkTools benchmarks the call itself
# instead of including non-const global-variable access in the timing.
@btime dl_dp($p)