Preventing Enzyme from differentiating through constant computations

I am training a neural network by computing the gradient of the loss with Enzyme. Part of the computation involves associated Legendre polynomials. I did not expect these to matter for the differentiation, since they are not a function of the neural network parameters; however, when I profiled the code to find performance hotspots, I discovered that most of the compute time was spent differentiating through the Legendre polynomials. I tried moving these computations outside the training loop and accessing the values through a Dict, but that led to other issues. What is the simplest way to prevent Enzyme from differentiating through the Legendre polynomials?

# MWE preventing Enzyme from differentiating through Legendre polynomials
begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Statistics
    using AssociatedLegendrePolynomials
    using BenchmarkTools
    using Profile
    using ProfileView
    Random.seed!(1234)
end

begin # Functions
    function build_model(n_in, n_out, n_layers, n_nodes;
                        act_fun=leakyrelu, last_fun=relu)
        # Input layer
        first_layer = Lux.Dense(n_in, n_nodes, act_fun)

        # Hidden block
        hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)

        # Output layer
        last_layer = Lux.Dense(n_nodes => n_out, last_fun)

        return Chain(first_layer, hidden..., last_layer)
    end

    function eval_model(m, params, st, x)
        M = first(m(x, params, st))
        return M
    end


    function recursive_convert(T, x)
        if x isa AbstractArray
            return convert.(T, x)  # elementwise convert
        elseif x isa NamedTuple
            return NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
        else
            return x
        end
    end

    function lossDiff(p, args)
        y = args[2]
        dsigs = calculateMultiDiffCrossSections(p, args)
        return mean((y - dsigs).^2)
    end

    function calculateDifferentialCrossSection(U, theta)
        n = length(theta)
        ret = zeros(n)
        for i in 1:n
            ret += sum(U)*cos.(theta).*AssociatedLegendrePolynomials.Plm(2, 1, cos(theta[i]*pi/180))
        end
        return ret
    end

    function calculateMultiDiffCrossSections(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2
        rlen = 100
        datalen = 0
        for i in 1:nlen
            datalen += length(thetas[i])
        end
        dσ = zeros(eltype(X), datalen)
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            U = eval_model(M, p, st, X[:, (i-1)*rlen+1 : i*rlen])
            sig = calculateDifferentialCrossSection(U, thetas[i])
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end


data_type = Float32

# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60.0],
    [5.0, 15.0, 25.0, 60.0]
]
X_train = rand(data_type, 4, 200)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = ComponentArray(recursive_convert(data_type, ps))
const _st = st

args = (X_train, XSdiff_train, theta_train, model, _st)

# Test loss function evaluation
losstest = lossDiff(p, args)
# losstest
dl_dp(p) = Enzyme.jacobian(Reverse, p -> lossDiff(p, args), p)
@btime dl_dp($p)

Your theta is going through it, so if you wanted the gradient with respect to theta, that differentiation would need to be there. You can, however, write your own reverse pass, even though that is hard in Enzyme (or write a ChainRules rule and import it into Enzyme, which may be a lot easier). There is also a function in Enzyme to tell it that something should be considered constant, but I'm quite sure that's not what you want.
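For the custom-rule / "considered constant" options, here is a minimal sketch of what I mean, assuming the EnzymeRules.inactive mechanism is available in your Enzyme version (please check the docs for the exact spelling; legendre_factor is a helper name I made up):

# Wrap the Legendre evaluation in a small helper so the rule only affects
# this call site and not every other use of Plm.
legendre_factor(theta_deg) =
    AssociatedLegendrePolynomials.Plm(2, 1, cos(theta_deg * pi / 180))

# Declare the helper inactive: Enzyme then treats its result as a constant
# and never differentiates through its body. This should be safe here because
# the value depends only on the fixed angles, not on the network parameters p.
Enzyme.EnzymeRules.inactive(::typeof(legendre_factor), args...) = nothing

calculateDifferentialCrossSection would then call legendre_factor(theta[i]) instead of Plm directly, and the rest of the loss stays unchanged.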

Oh, I just realised: is it differentiating through the evaluation of the Legendre polynomial itself? In that case, just write a struct holding the polynomial coefficients and your own function to evaluate it; that will be much easier. It is really weird that AssociatedLegendrePolynomials doesn't provide a way to cache that, though, so I would check again.
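Since the angles never change during training, the cheapest version of that idea may be to evaluate Plm once per experiment outside the loss and pass the results in as plain vectors, so the differentiated code never touches AssociatedLegendrePolynomials at all. A rough sketch against the MWE above (plm_train and the cached variant are made-up names):

# Precompute P_2^1(cos θ) for every experiment's angle grid; these values
# depend only on the fixed angles, never on the network parameters.
plm_train = [
    [AssociatedLegendrePolynomials.Plm(2, 1, cos(t * pi / 180)) for t in thetas]
    for thetas in theta_train
]

# Variant of the cross section that consumes the cached values, so the loss
# never calls Plm and Enzyme has nothing Legendre-related to differentiate.
function calculateDifferentialCrossSectionCached(U, theta, plm_vals)
    n = length(theta)
    ret = zeros(eltype(U), n)
    for i in 1:n
        ret += sum(U) * cos.(theta) .* plm_vals[i]
    end
    return ret
end

calculateMultiDiffCrossSections would then carry plm_train through args and call the cached variant; everything else, including the Enzyme call, stays the same.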