Hi gdalle,
Sure, here is the simplest example I could think of. The idea is:
- Define a trainable smoother: a Jacobi smoother whose update is scaled by a parameter \alpha.
- Define variables for the problem (grid size and initial \alpha).
- Initialize the smoother and build the AMG hierarchy (ruge_stuben) with it as pre- and post-smoother.
- Define a loss function for one AMG V-cycle based on the final residual.
- Wrap the AMG loss in a function for Enzyme and run Enzyme.gradient on it (computing only dLoss/d\alpha, not gradients w.r.t. ml or b).
Thanks for your time!
using AlgebraicMultigrid
using SparseArrays
using LinearAlgebra
using Enzyme # Import Enzyme
# 1 Define the Trainable Damped Jacobi Smoother
mutable struct TrainableSmoother <: AlgebraicMultigrid.Smoother
    alpha::Float64 # Damping factor (trainable)
    iter::Int      # Number of smoothing iterations per application
end
TrainableSmoother(alpha::Float64 = 1.0; iter::Int=1) = TrainableSmoother(alpha, iter)
# Make the smoother callable like the built-in AlgebraicMultigrid smoothers (functor interface)
function (smoother::TrainableSmoother)(A::SparseMatrixCSC{Tv, Ti}, x::AbstractVector{Tv}, b::AbstractVector{Tv}) where {Tv, Ti}
    LinearAlgebra.checksquare(A)
    length(x) == size(A, 1) == length(b) || throw(DimensionMismatch())
    # Adding epsilon for numerical stability in case of near-zero diagonal entries
    inv_diag = 1.0 ./ (diag(A) .+ eps(Tv))
    for _ in 1:smoother.iter
        r = b - A * x
        x .+= smoother.alpha .* inv_diag .* r
    end
    return nothing
end
function (smoother::TrainableSmoother)(A::SparseMatrixCSC{Tv, Ti}, x::AbstractMatrix{Tv}, b::AbstractMatrix{Tv}) where {Tv, Ti}
    LinearAlgebra.checksquare(A)
    size(A, 1) == size(x, 1) == size(b, 1) || throw(DimensionMismatch("Matrix and vector dimensions must match"))
    size(x, 2) == size(b, 2) || throw(DimensionMismatch("x and b must have the same number of columns"))
    inv_diag = 1.0 ./ (diag(A) .+ eps(Tv)) # Add epsilon for stability
    for _ in 1:smoother.iter
        @inbounds for col in 1:size(x, 2)
            xc = view(x, :, col)
            bc = view(b, :, col)
            r = bc - A * xc
            xc .+= smoother.alpha .* inv_diag .* r
        end
    end
    return nothing
end
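# (Illustrative sanity check, not part of the original MWE; the *_chk names are made up.)
# One application of the damped-Jacobi update x .+= alpha .* D^{-1} .* (b - A*x)
# should reduce the residual on a small SPD system:
let
    A_chk = sparse(Tridiagonal(fill(-1.0, 4), fill(2.0, 5), fill(-1.0, 4)))
    b_chk = ones(5)
    x_chk = zeros(5)
    r0 = norm(b_chk - A_chk * x_chk)
    s_chk = TrainableSmoother(0.8, iter = 5)
    s_chk(A_chk, x_chk, b_chk)
    @assert norm(b_chk - A_chk * x_chk) < r0
end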
# "Poisson" matrix
function poisson_matrix(N::Int)
    if N <= 0 error("N must be positive") end
    T = sparse(Tridiagonal(fill(-1.0, N-1), fill(2.0, N), fill(-1.0, N-1)))
    Id = sparse(I, N, N)
    A = kron(Id, T) + kron(T, Id)
    return A
end
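# (Illustrative, not part of the original MWE) The 5-point stencil matrix is symmetric
# positive definite, which is the setting Jacobi smoothing and Ruge-Stuben AMG expect:
let A_chk = poisson_matrix(3)
    @assert issymmetric(A_chk) && isposdef(Matrix(A_chk))
end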
# 2 Setup the Problem
println("\n ***************** TEST GRADIENTS (USING ENZYME - CORRECTED v3) ***************** \n")
n = 3 # Grid points per dimension (the matrix is n^2 x n^2)
A = poisson_matrix(n);
N_total = size(A, 1)
b = rand(N_total)
ALPHA = 1.0 # Starting alpha for the smoother
# 3 Initiate Smoother and Build AMG Hierarchy
my_smoother = TrainableSmoother(ALPHA, iter=3)
println("Building AMG hierarchy using the custom TrainableSmoother...")
ml = ruge_stuben(A,
                 presmoother = my_smoother,
                 postsmoother = my_smoother,
                 max_levels = 10,
                 max_coarse = 10)
println("Hierarchy built:")
println(ml)
# 4 Define Loss Function Using the Library's `_solve`
function loss_function_amg_solve(ml_obj, target_b; solve_iters=1) # Default to 1 iteration
    # Run `solve_iters` V-cycle(s) with the internal residual calculation disabled
    x_final = AlgebraicMultigrid._solve(ml_obj, target_b, maxiter=solve_iters, calculate_residual=false)
    if isempty(ml_obj.levels) && size(ml_obj.final_A, 1) == 0
        error("AMG hierarchy appears empty.")
    end
    # Get finest matrix A
    finest_A = isempty(ml_obj.levels) ? ml_obj.final_A : ml_obj.levels[1].A
    # Calculate residual *after* the solve
    residual = target_b - finest_A * x_final
    loss_val = norm(residual)
    return loss_val
end
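# (Illustrative sketch, not part of the original MWE) A central finite-difference
# approximation of dLoss/d(alpha), useful as a cross-check for the Enzyme result.
# `fd_grad_alpha` and the step size `h` are made-up names for this sketch; it mutates
# the smoother's alpha in place and restores it afterwards.
function fd_grad_alpha(ml_obj, target_b, smoother; h = 1e-6)
    a0 = smoother.alpha
    smoother.alpha = a0 + h
    loss_plus = loss_function_amg_solve(ml_obj, target_b, solve_iters = 1)
    smoother.alpha = a0 - h
    loss_minus = loss_function_amg_solve(ml_obj, target_b, solve_iters = 1)
    smoother.alpha = a0 # restore the original value
    return (loss_plus - loss_minus) / (2h)
end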
# 5 Test Initial Loss and Calculate Gradient with Enzyme
# Calculate initial loss using the modified function
initial_loss = loss_function_amg_solve(ml, b, solve_iters=1)
println("Initial Alpha: ", my_smoother.alpha)
println("Initial Loss: ", initial_loss)
# Check if initial loss is finite before proceeding
if !isfinite(initial_loss)
error("Initial loss is non-finite ($initial_loss).")
end
println("\nEnzyme gradient calculation through AlgebraicMultigrid._solve...")
# Enzyme Wrapper Function
function loss_wrapper_for_enzyme(alpha_value, ml_obj, target_b)
    if !(ml_obj.presmoother isa TrainableSmoother)
        error("ml_obj.presmoother is not a TrainableSmoother instance.")
    end
    smoother_instance = ml_obj.presmoother
    original_alpha = smoother_instance.alpha
    try
        # Temporarily set the alpha value
        smoother_instance.alpha = alpha_value
        # Call the loss function with the same number of iterations as the initial loss
        loss = loss_function_amg_solve(ml_obj, target_b, solve_iters=1) # Ensure consistency
        return loss
    finally
        # Restore the original alpha value
        smoother_instance.alpha = original_alpha
    end
end
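# (Illustrative, not part of the original MWE) The finite-difference sketch defined above
# can be used as a cross-check against the Enzyme gradient computed below; kept commented
# out so it does not change the printed output shown later:
# fd_alpha_grad = fd_grad_alpha(ml, b, my_smoother)
# println("Finite-difference gradient w.r.t alpha: ", fd_alpha_grad)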
# Store the initial alpha value to pass to Enzyme
alpha_current = my_smoother.alpha
try
    # Call Enzyme.gradient
    gradient_result = Enzyme.gradient(Enzyme.Reverse, loss_wrapper_for_enzyme, alpha_current, Const(ml), Const(b))
    # Print the type and value of the result
    println("Type returned by Enzyme.gradient: ", typeof(gradient_result))
    println("Value returned by Enzyme.gradient: ", gradient_result)
    # Explicitly handle the tuple structure
    alpha_grad_enzyme = nothing # Initialize
    if gradient_result isa Tuple && !isempty(gradient_result)
        println("Result was a tuple, extracting first element.")
        alpha_grad_enzyme = gradient_result[1] # Assume the first element is the gradient
    else
        println("Result was not a tuple (or was empty), using directly.")
        alpha_grad_enzyme = gradient_result # Assume it's the scalar gradient or Nothing
    end
    # Check that the extracted gradient is a finite Real number
    if alpha_grad_enzyme !== nothing && isa(alpha_grad_enzyme, Real) && isfinite(alpha_grad_enzyme)
        println("Enzyme gradient calculation successful!")
        println("Gradient w.r.t alpha: ", alpha_grad_enzyme)
###
### More code that does not add anything ...
###
And the output is:
***************** TEST GRADIENTS (USING ENZYME - CORRECTED v3) *****************
Building AMG hierarchy using the custom TrainableSmoother...
Hierarchy built:
Multilevel Solver
-----------------
Operator Complexity: 1.0
Grid Complexity: 1.0
No. of Levels: 1
Coarse Solver: Pinv
Level Unknowns NonZeros
----- -------- --------
1 9 33 [100.00%]
Initial Alpha: 1.0
Initial Loss: 2.4986185198476507e-15
Enzyme gradient calculation through AlgebraicMultigrid._solve...
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler C:\Users\eric_\.julia\packages\GPUCompiler\2MI6e\src\utils.jl:61
Type returned by Enzyme.gradient: Tuple{Float64, Nothing, Nothing}
Value returned by Enzyme.gradient: (0.0, nothing, nothing)
Result was a tuple, extracting first element.
Enzyme gradient calculation successful!
Gradient w.r.t alpha: 0.0
Updated Alpha (example): 1.0
Finished.