Automatic Differentiation for existing AMG libraries

Hi all,

I recently started using Julia for my MSc thesis and have some doubts about the automatic differentiation tools (Enzyme, Zygote, the JuliaDiff ecosystem, etc.).

In short, I want to use an existing linear algebra library that ships algebraic multigrid (AMG) solvers, so that I can compute gradients through a whole V- or W-cycle instead of building a custom, fully differentiable AMG solver from scratch (which I could do in PyTorch; reusing an existing solver is exactly why I wanted to move to Julia). One package that is not a wrapper around C code (unlike, say, PETSc.jl) is AlgebraicMultigrid.jl.

I have been experimenting with it and Enzyme on a really simple example, but I keep getting zero gradients for the variables I want to track (e.g. alpha). This makes me suspect that copyto! and similar in-place functions (which I assumed are not differentiable) break the chain rule somewhere, so I cannot obtain dLoss/dAlpha.
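For context, here is roughly the kind of standalone example (no AlgebraicMultigrid.jl involved, just Enzyme) that I would expect to work; the function name and the toy 2x2 problem are only for illustration:

using Enzyme, LinearAlgebra

# A few damped-Jacobi sweeps with an in-place update, differentiated
# with respect to the damping factor alpha.
function damped_jacobi_loss(alpha, A, b)
    x = zeros(length(b))
    inv_diag = 1.0 ./ diag(A)
    for _ in 1:3
        r = b - A * x
        x .+= alpha .* inv_diag .* r   # in-place update of x
    end
    return norm(b - A * x)
end

A = [4.0 -1.0; -1.0 4.0]
b = [1.0, 2.0]

grads = Enzyme.gradient(Enzyme.Reverse, damped_jacobi_loss, 1.0, Const(A), Const(b))
println("dLoss/dalpha = ", grads[1])   # nonzero for this toy problem

If something like this works but the AMG version does not, I assume the issue lies in how the hierarchy calls my smoother.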

Do you have any pointers on this kind of approach? Is there a key element that I might be missing?

Thanks!

Hi, welcome to the community!
Can you share a minimal working example, as simple as possible? It would help a lot in debugging your code and figuring out where the limitation lies.

Hi gdalle,

Sure, here is the simplest example I could think of. The idea is:

  1. Define a trainable smoother: a damped Jacobi smoother whose update is scaled by a parameter \alpha.
  2. Define the problem variables (grid size and initial \alpha).
  3. Initialize the smoother and build the AMG hierarchy (ruge_stuben) with it as pre- and post-smoother.
  4. Define a loss function for one AMG V-cycle based on the final residual.
  5. Wrap the AMG loss in a function for Enzyme and run Enzyme.gradient on it (computing only dLoss/d\alpha, with ml and b passed as Const).

Thanks for your time!

using AlgebraicMultigrid
using SparseArrays
using LinearAlgebra
using Enzyme # Import Enzyme

# 1 Define the Trainable Damped Jacobi Smoother
mutable struct TrainableSmoother <: AlgebraicMultigrid.Smoother
    alpha::Float64 # Damping factor (trainable)
    iter::Int    # Number of smoothing iterations per application
end

TrainableSmoother(alpha::Float64 = 1.0; iter::Int=1) = TrainableSmoother(alpha, iter)

# Make the smoother equivalent to AlgebraicMultigrid SMOOTHERS
function (smoother::TrainableSmoother)(A::SparseMatrixCSC{Tv, Ti}, x::AbstractVector{Tv}, b::AbstractVector{Tv}) where {Tv, Ti}
    LinearAlgebra.checksquare(A)
    length(x) == size(A, 1) == length(b) || throw(DimensionMismatch())
    # Adding epsilon for numerical stability in case of near-zero diagonal entries
    inv_diag = 1.0 ./ (diag(A) .+ eps(Tv))
    for _ in 1:smoother.iter
        r = b - A * x
        x .+= smoother.alpha .* inv_diag .* r
    end
    return nothing
end

function (smoother::TrainableSmoother)(A::SparseMatrixCSC{Tv, Ti}, x::AbstractMatrix{Tv}, b::AbstractMatrix{Tv}) where {Tv, Ti}
    LinearAlgebra.checksquare(A)
    size(A, 1) == size(x, 1) == size(b, 1) || throw(DimensionMismatch("Matrix and vectors dimensions must match"))
    size(x, 2) == size(b, 2) || throw(DimensionMismatch("x and b must have the same number of columns"))
    inv_diag = 1.0 ./ (diag(A) .+ eps(Tv)) # Add epsilon for stability
    for _ in 1:smoother.iter
        @inbounds for col in 1:size(x, 2)
            xc = view(x, :, col)
            bc = view(b, :, col)
            r = bc - A * xc
            xc .+= smoother.alpha .* inv_diag .* r
        end
    end
    return nothing
end

# "Poisson" matrix
function poisson_matrix(N::Int)
    if N <= 0 error("N must be positive") end
    T = sparse(Tridiagonal(fill(-1.0, N-1), fill(2.0, N), fill(-1.0, N-1)))
    Id = sparse(I, N, N)
    A = kron(Id, T) + kron(T, Id)
    return A
end

# 2 Setup the Problem 
println("\n ***************** TEST GRADIENTS (USING ENZYME - CORRECTED v3) ***************** \n")
n = 3 # Size of "grid"
A = poisson_matrix(n);
N_total = size(A, 1)
b = rand(N_total) 
ALPHA = 1.0 # Starting alpha for the smoother

# 3 Initiate Smoother and Build AMG Hierarchy
my_smoother = TrainableSmoother(ALPHA, iter=3)

println("Building AMG hierarchy using the custom TrainableSmoother...")
ml = ruge_stuben(A,
                 presmoother = my_smoother,
                 postsmoother = my_smoother,
                 max_levels = 10,
                 max_coarse = 10)

println("Hierarchy built:")
println(ml)

# 4 Define Loss Function Using the Library's `_solve`
function loss_function_amg_solve(ml_obj, target_b; solve_iters=1) # Default to a single V-cycle
    # Run solve_iters V-cycle(s) and skip the internal residual calculation
    x_final = AlgebraicMultigrid._solve(ml_obj, target_b, maxiter=solve_iters, calculate_residual=false)

    if isempty(ml_obj.levels) && size(ml_obj.final_A,1) == 0
         error("AMG hierarchy appears empty.")
    end
    # Get finest matrix A
    finest_A = isempty(ml_obj.levels) ? ml_obj.final_A : ml_obj.levels[1].A

    # Calculate residual *after* the solve
    residual = target_b - finest_A * x_final
    loss_val = norm(residual)

    return loss_val
end

# 5 Test Initial Loss and Calculate Gradient with Enzyme

# Calculate initial loss using the modified function
initial_loss = loss_function_amg_solve(ml, b, solve_iters=1)
println("Initial Alpha: ", my_smoother.alpha)
println("Initial Loss: ", initial_loss)

# Check if initial loss is finite before proceeding
if !isfinite(initial_loss)
    error("Initial loss is non-finite ($initial_loss).")
end

println("\nEnzyme gradient calculation through AlgebraicMultigrid._solve...")

# Enzyme Wrapper Function
function loss_wrapper_for_enzyme(alpha_value, ml_obj, target_b)
    if !(ml_obj.presmoother isa TrainableSmoother)
        error("ml_obj.presmoother is not a TrainableSmoother instance.")
    end
    smoother_instance = ml_obj.presmoother
    original_alpha = smoother_instance.alpha
    try
        # Temporarily set the alpha value
        smoother_instance.alpha = alpha_value
        # Call the loss function (a single V-cycle, matching the initial-loss call)
        loss = loss_function_amg_solve(ml_obj, target_b, solve_iters=1) # Ensure consistency
        return loss
    finally
        # Restore the original alpha value
        smoother_instance.alpha = original_alpha
    end
end

# Store the initial alpha value to pass to Enzyme
alpha_current = my_smoother.alpha

try
    # Call Enzyme.gradient
    gradient_result = Enzyme.gradient(Enzyme.Reverse, loss_wrapper_for_enzyme, alpha_current, Const(ml), Const(b))

    # Print the type and value of the result
    println("Type returned by Enzyme.gradient: ", typeof(gradient_result))
    println("Value returned by Enzyme.gradient: ", gradient_result)

    #  Explicitly handle the tuple structure
    alpha_grad_enzyme = nothing # Initialize
    if gradient_result isa Tuple && !isempty(gradient_result)
        println("Result was a tuple, extracting first element.")
        alpha_grad_enzyme = gradient_result[1] # Assume the first element is the gradient
    else
        println("Result was not a tuple (or was empty), using directly.")
        alpha_grad_enzyme = gradient_result # Assume it's the scalar gradient or Nothing
    end

    # Check the extracted gradient
    # Check if it's a Real number AND finite
    if alpha_grad_enzyme !== nothing && isa(alpha_grad_enzyme, Real) && isfinite(alpha_grad_enzyme)
        println("Enzyme gradient calculation successful!")
        println("Gradient w.r.t alpha: ", alpha_grad_enzyme)

###
### More code that does not add anything ...
###

And the output is:

 ***************** TEST GRADIENTS (USING ENZYME - CORRECTED v3) ***************** 

Building AMG hierarchy using the custom TrainableSmoother...
Hierarchy built:
Multilevel Solver
-----------------
Operator Complexity: 1.0
Grid Complexity: 1.0
No. of Levels: 1
Coarse Solver: Pinv
Level     Unknowns     NonZeros
-----     --------     --------
    1            9           33 [100.00%]

Initial Alpha: 1.0
Initial Loss: 2.4986185198476507e-15

Enzyme gradient calculation through AlgebraicMultigrid._solve...
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler C:\Users\eric_\.julia\packages\GPUCompiler\2MI6e\src\utils.jl:61
Type returned by Enzyme.gradient: Tuple{Float64, Nothing, Nothing}
Value returned by Enzyme.gradient: (0.0, nothing, nothing)
Result was a tuple, extracting first element.
Enzyme gradient calculation successful!
Gradient w.r.t alpha: 0.0
Updated Alpha (example): 1.0

Finished.

Differentiation of a linear solve is just a linear solve on the duals. If you use LinearSolve.jl then that should already be specialized with Enzyme and Zygote (and we’ll add the ForwardDiff overloads hopefully soon).
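Concretely: if x solves A(p) x = b(p), then differentiating gives A (dx/dp) = db/dp - (dA/dp) x, so the tangent is just another solve with the same A, and in reverse mode the adjoint is a solve with A'. A rough sketch of what that looks like with LinearSolve.jl and Enzyme (the loss and the toy problem here are only for illustration, and I haven't tuned the solver choice):

using LinearSolve, Enzyme

# Scalar loss of a linear solve, computed through LinearSolve.jl.
function solve_loss(A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)          # default dense factorization
    return sum(sol.u)
end

A = [4.0 -1.0; -1.0 4.0]
b = [1.0, 2.0]
dA = zero(A)                   # shadow that will accumulate dLoss/dA
db = zero(b)                   # shadow that will accumulate dLoss/db

Enzyme.autodiff(Reverse, solve_loss, Active, Duplicated(A, dA), Duplicated(b, db))
# By the usual adjoint formulas, db should equal lambda with A' * lambda = ones(2),
# and dA should equal -lambda * x', both coming from linear solves on the duals.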