Linear problem with multiple shifts

I want to solve (A - c_i I) x = b for multiple shifts c_i. A is a Hermitian matrix for which computing the product A*x is much cheaper than forming A explicitly, and each A - c_i I is positive definite. I know that Krylov subspace methods (Lanczos iteration) can solve this problem efficiently for multiple shifts.

Is there a Julia package where I can easily solve this problem? Or is there an alternative, better algorithm implemented in Julia?
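One Krylov run can serve all shifts because (A - cI)ʲb is a polynomial of degree j in A applied to b, so the Krylov subspace K_k(A - cI, b) equals K_k(A, b) for every scalar c. The quick numerical check below illustrates this on a small random Hermitian matrix. The matrix, the sizes, and the helper name krylov_matrix are illustrative assumptions. It builds raw Krylov matrices only for the demonstration; practical solvers orthogonalize the basis with Lanczos instead.

```julia
using LinearAlgebra, Random

# Illustrative helper: unorthogonalized Krylov matrix [b, A b, ..., A^(k-1) b].
function krylov_matrix(A, b, k)
    K = zeros(eltype(b), length(b), k)
    K[:, 1] = b
    for j in 2:k
        K[:, j] = A * K[:, j-1]
    end
    return K
end

Random.seed!(0)
n, k, c = 8, 4, 0.7
A = Hermitian(randn(ComplexF64, n, n))   # random Hermitian test matrix
b = randn(ComplexF64, n)

K  = krylov_matrix(A, b, k)              # spans K_k(A, b)
Kc = krylov_matrix(A - c * I, b, k)      # spans K_k(A - cI, b)

# Least-squares projection of Kc onto span(K): a tiny relative residual
# means every column of Kc already lies in span(K).
rel_resid = norm(Kc - K * (K \ Kc)) / norm(Kc)
println(rel_resid)
```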

Krylov.jl seems to address this: its cg_lanczos solver supports multiple shifts. See the example provided.


It seems that cg_lanczos in Krylov.jl supports only real-valued vectors (Complex linear systems · Issue #103 · JuliaSmoothOptimizers/Krylov.jl · GitHub), while I need Hermitian matrices and complex-valued vectors.
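The essential change for complex Hermitian systems is the inner product: real-arithmetic CG/Lanczos uses vᵀw, whereas the complex case needs the conjugated form vᴴw. That makes vᴴv = ‖v‖² and, for Hermitian A, keeps the Lanczos coefficients δₖ = vₖᴴAvₖ real. In Julia, LinearAlgebra's dot conjugates its first argument while transpose(v) * v does not, as this small check shows (the vector is an arbitrary example):

```julia
using LinearAlgebra

v = [1.0 + 2.0im, 3.0 - 1.0im]

h = dot(v, v)           # vᴴv: dot conjugates its first argument, giving Σ|vᵢ|² (real, ≥ 0)
t = transpose(v) * v    # vᵀv: no conjugation, complex in general and not a norm
println(h)              # 15.0 + 0.0im
println(t)              # 5.0 - 2.0im
```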

A slight modification of the function makes this possible. I'm sharing the code here for anyone interested.

```julia
using LinearAlgebra, Printf
using Krylov
using Krylov: CgLanczosShiftSolver, allocate_if, @kdot, krylov_dot, @kscal!, krylov_scal!, @kaxpy!, krylov_axpy!, @kaxpby!, krylov_axpby!

"cg_lanczos of Krylov.jl extended to complex vectors: solves (A + sᵢI)xᵢ = b for each real shift sᵢ, with A Hermitian"
function ccg_lanczos(A, b :: AbstractVector{Complex{T}}, shifts :: AbstractVector{T}; kwargs...) where T <: AbstractFloat
    nshifts = length(shifts)
    solver = CgLanczosShiftSolver(A, b, nshifts)
    ccg_lanczos!(solver, A, b, shifts; kwargs...)
    return (solver.x, solver.stats)
end

function ccg_lanczos!(solver :: CgLanczosShiftSolver{CT,S}, A, b :: AbstractVector{CT}, shifts :: AbstractVector{T};
                      M=I, atol :: T=√eps(T), rtol :: T=√eps(T), itmax :: Int=0,
                      check_curvature :: Bool=false, verbose :: Int=0, history :: Bool=false) where {T <: AbstractFloat, CT <: Number, S <: DenseVector{CT}}
    n, m = size(A)
    m == n || error("System must be square")
    length(b) == n || error("Inconsistent problem size")

    nshifts = length(shifts)
    (verbose > 0) && @printf("CG Lanczos: system of %d equations in %d variables with %d shifts\n", n, n, nshifts)

    # Tests M == Iₙ
    MisI = (M == I)

    # Check type consistency
    eltype(A) == CT || error("eltype(A) ≠ $CT")
    typeof(b) == S || error("typeof(b) ≠ $S")
    MisI || (eltype(M) == T) || error("eltype(M) ≠ $T")

    # Set up workspace.
    allocate_if(!MisI, solver, :v, S, n)
    Mv, Mv_prev, Mv_next = solver.Mv, solver.Mv_prev, solver.Mv_next
    x, p, σ, δhat = solver.x, solver.p, solver.σ, solver.δhat
    ω, γ, rNorms, converged = solver.ω, solver.γ, solver.rNorms, solver.converged
    not_cv, stats = solver.not_cv, solver.stats
    rNorms_history, indefinite = stats.residuals, stats.indefinite
    Krylov.reset!(stats)
    v = MisI ? Mv : solver.v

    # Initial state.
    ## Distribute x similarly to shifts.
    for i = 1 : nshifts
        x[i] .= zero(T)                       # x₀
    end
    Mv .= b                                   # Mv₁ ← b
    MisI || mul!(v, M, Mv)                    # v₁ = M⁻¹ * Mv₁
    β = sqrt(@kdot(n, v, Mv))                 # β₁ = √(v₁ᴴ M v₁)
    rNorms .= β
    if history
        for i = 1 : nshifts
            push!(rNorms_history[i], rNorms[i])
        end
    end

    # Keep track of shifted systems with negative curvature if required.
    indefinite .= false

    if β == 0
        stats.solved = true
        stats.status = "x = 0 is a zero-residual solution"
        return solver
    end

    # Initialize each p to v.
    for i = 1 : nshifts
        p[i] .= v
    end

    # Initialize Lanczos process.
    # β₁Mv₁ = b
    @kscal!(n, one(T)/β, v)          # v₁  ←  v₁ / β₁
    MisI || @kscal!(n, one(T)/β, Mv) # Mv₁ ← Mv₁ / β₁
    Mv_prev .= Mv

    # Initialize some constants used in recursions below.
    ρ = one(T)
    σ .= β
    δhat .= zero(T)
    ω .= zero(T)
    γ .= one(T)

    # Define stopping tolerance.
    ε = atol + rtol * real(β)

    # Keep track of shifted systems that have converged.
    for i = 1 : nshifts
        converged[i] = real(rNorms[i]) ≤ ε
        not_cv[i] = !converged[i]
    end
    iter = 0
    itmax == 0 && (itmax = 2 * n)

    # Build format strings for printing.
    if Krylov.display(iter, verbose)
        fmt = "%5d" * repeat("  %8.1e", nshifts) * "\n"
        # precompile printf for our particular format
        local_printf(data...) = Core.eval(Main, :(@printf($fmt, $(data)...)))
        local_printf(iter, real(rNorms)...)
    end

    solved = sum(not_cv) == 0
    tired = iter ≥ itmax
    status = "unknown"

    # Main loop.
    while ! (solved || tired)
        # Form next Lanczos vector.
        # βₖ₊₁Mvₖ₊₁ = Avₖ - δₖMvₖ - βₖMvₖ₋₁
        mul!(Mv_next, A, v)                  # Mvₖ₊₁ ← Avₖ
        δ = @kdot(n, v, Mv_next)             # δₖ = vₖᴴ A vₖ
        @kaxpy!(n, -δ, Mv, Mv_next)          # Mvₖ₊₁ ← Mvₖ₊₁ - δₖMvₖ
        if iter > 0
            @kaxpy!(n, -β, Mv_prev, Mv_next) # Mvₖ₊₁ ← Mvₖ₊₁ - βₖMvₖ₋₁
            @. Mv_prev = Mv                  # Mvₖ₋₁ ← Mvₖ
        end
        @. Mv = Mv_next                      # Mvₖ ← Mvₖ₊₁
        MisI || mul!(v, M, Mv)               # vₖ₊₁ = M⁻¹ * Mvₖ₊₁
        β = sqrt(@kdot(n, v, Mv))            # βₖ₊₁ = √(vₖ₊₁ᴴ M vₖ₊₁)
        @kscal!(n, one(T)/β, v)              # vₖ₊₁  ←  vₖ₊₁ / βₖ₊₁
        MisI || @kscal!(n, one(T)/β, Mv)     # Mvₖ₊₁ ← Mvₖ₊₁ / βₖ₊₁

        # Check curvature: vₖᴴ(A + sᵢI)vₖ = vₖᴴAvₖ + sᵢ‖vₖ‖² = δₖ + ρₖ * sᵢ with ρₖ = ‖vₖ‖².
        # It is possible to show that σₖ² (δₖ + ρₖ * sᵢ - ωₖ₋₁ / γₖ₋₁) = pₖᴴ (A + sᵢ I) pₖ.
        MisI || (ρ = @kdot(n, v, v))
        for i = 1 : nshifts
            δhat[i] = δ + ρ * shifts[i]
            γ[i] = 1 / (δhat[i] - ω[i] / γ[i])
            indefinite[i] |= real(γ[i]) ≤ 0
        end

        # Compute next CG iterate for each shifted system that has not yet converged.
        # Stop iterating on indefinite problems if requested.
        for i = 1 : nshifts
            not_cv[i] = check_curvature ? !(converged[i] || indefinite[i]) : !converged[i]
            if not_cv[i]
                @kaxpy!(n, γ[i], p[i], x[i])
                ω[i] = β * γ[i]
                σ[i] *= -ω[i]
                ω[i] *= ω[i]
                @kaxpby!(n, σ[i], v, ω[i], p[i])

                # Update list of systems that have not converged.
                rNorms[i] = abs(σ[i])
                converged[i] = real(rNorms[i]) ≤ ε
            end
        end

        if length(not_cv) > 0 && history
            for i = 1 : nshifts
                not_cv[i] && push!(rNorms_history[i], rNorms[i])
            end
        end

        # Is there a better way than to update this array twice per iteration?
        for i = 1 : nshifts
            not_cv[i] = check_curvature ? !(converged[i] || indefinite[i]) : !converged[i]
        end
        iter = iter + 1
        # rNorms is complex here; print its real part (%e cannot format complex numbers)
        Krylov.display(iter, verbose) && local_printf(iter, real(rNorms)...)

        solved = sum(not_cv) == 0
        tired = iter ≥ itmax
    end
    (verbose > 0) && @printf("\n")

    status = tired ? "maximum number of iterations exceeded" : "solution good enough given atol and rtol"

    # Update stats. TODO: Estimate Anorm and Acond.
    stats.solved = solved
    stats.status = status
    return solver
end
```
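To cross-check results from ccg_lanczos without Krylov.jl, a minimal pure-Julia multi-shift CG can help. The sketch below is not the Krylov.jl code; multishift_cg is a hypothetical name, and it assumes A is Hermitian, the shifts are real, and every A + sᵢI is positive definite. It follows the sign convention stated in the comments of the routine above, solving (A + sᵢI)xᵢ = b, so the systems (A - c_i I)x = b from the question correspond to shifts sᵢ = -c_i. Since all Lanczos scalars are real for Hermitian A, the sketch keeps the scalars real and only the vectors complex.

```julia
using LinearAlgebra, Random

# Minimal multi-shift CG sketch (hypothetical helper, not Krylov.jl's code).
# Solves (A + s*I) x = b for every real s in `shifts`, assuming A is Hermitian
# and every A + s*I is positive definite. All shifts share one Lanczos process,
# so A is applied once per iteration however many shifts there are.
function multishift_cg(A, b::AbstractVector{<:Complex}, shifts::AbstractVector{<:Real};
                       atol = 1e-10, rtol = 1e-10, itmax = 2 * length(b))
    R  = real(eltype(b))
    ns = length(shifts)
    x  = [zero(b) for _ in 1:ns]          # x₀ = 0 for every shift
    β  = norm(b)                          # β₁ = ‖b‖
    β == 0 && return x
    p  = [copy(b) for _ in 1:ns]          # p₀ = r₀ = b
    σ  = fill(β, ns)                      # rₖ = σₖ vₖ₊₁, hence ‖rₖ‖ = |σₖ|
    γ  = ones(R, ns)                      # CG step lengths
    ω  = zeros(R, ns)                     # (βₖ₊₁γₖ)², reused in p update and next γ
    v, v_prev, β_prev = b / β, zero(b), zero(R)
    ε  = atol + rtol * β
    active = fill(β > ε, ns)
    for _ in 1:itmax
        any(active) || break
        # Lanczos step: βₖ₊₁ vₖ₊₁ = A vₖ - δₖ vₖ - βₖ vₖ₋₁
        w = A * v
        δ = real(dot(v, w))               # vₖᴴ A vₖ is real for Hermitian A
        w .-= δ .* v .+ β_prev .* v_prev
        β_next = norm(w)
        v_next = β_next > 0 ? w / β_next : zero(w)
        # CG update of each unconverged shifted system
        for i in 1:ns
            active[i] || continue
            γ[i] = 1 / (δ + shifts[i] - ω[i] / γ[i])
            x[i] .+= γ[i] .* p[i]
            ω[i] = β_next * γ[i]
            σ[i] *= -ω[i]
            ω[i] = ω[i]^2
            p[i] .= σ[i] .* v_next .+ ω[i] .* p[i]
            active[i] = abs(σ[i]) > ε
        end
        v_prev, v, β_prev = v, v_next, β_next
    end
    return x
end

# Usage: target systems (A - cᵢ I) x = b, i.e. shifts sᵢ = -cᵢ.
Random.seed!(42)
n = 50
B = randn(ComplexF64, n, n)
A = Hermitian(B' * B / n + I)            # HPD test matrix, eigenvalues ≥ 1
b = randn(ComplexF64, n)
c = [0.0, 0.25, 0.5]                     # every A - cᵢ I stays positive definite
xs = multishift_cg(A, b, -c)
for (ci, xi) in zip(c, xs)
    println(ci, "  ", norm((A - ci * I) * xi - b) / norm(b))
end
```

Like the Krylov.jl routine, the sketch does no reorthogonalization, so in floating point it may need more than the n iterations of exact arithmetic; that is why itmax defaults to 2n, mirroring the default in the code above.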

Have you also checked KrylovKit.jl?
Please read here.


Thanks a lot for pointing out the relevant packages. I did read that part of the KrylovKit documentation, but I did not find any code for what I need (solving a linear system with multiple shifts using CG based on Lanczos).

Maybe I could use the Lanczos iteration provided there to implement a solver myself (i.e., implement the convergence check, initialization, etc.). But I would rather avoid that, because I want to leverage existing functionality as much as possible.
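Building a shifted solver on top of a plain Lanczos iteration could look roughly like this: run Lanczos once, store the basis Vₖ and the tridiagonal Tₖ, then for each shift s take xₖ(s) = β₁ Vₖ (Tₖ + sI)⁻¹ e₁. For Hermitian positive definite systems this Galerkin solution coincides with the k-th CG iterate in exact arithmetic. The residual norm β₁ βₖ₊₁ |eₖᵀ (Tₖ + sI)⁻¹ e₁| is available without applying A again, which gives a per-shift convergence check. The helpers lanczos and shifted_solutions are hypothetical, hand-written names (not KrylovKit's API), and the sketch assumes no Lanczos breakdown and a fixed number of steps k.

```julia
using LinearAlgebra, Random

# Hypothetical helper: k steps of plain Lanczos (no reorthogonalization, no
# breakdown handling). Returns the basis V (n×k), the tridiagonal Tk, β₁ = ‖b‖
# and βₖ₊₁, so that A*V ≈ V*Tk + βₖ₊₁ vₖ₊₁ eₖᵀ.
function lanczos(A, b, k)
    R = real(eltype(b))
    V = zeros(eltype(b), length(b), k)
    α = zeros(R, k)                        # diagonal of Tk
    β = zeros(R, k)                        # β[j] couples vⱼ and vⱼ₊₁
    β1 = norm(b)
    v, v_prev = b / β1, zero(b)
    for j in 1:k
        V[:, j] = v
        w = A * v
        α[j] = real(dot(v, w))             # real because A is Hermitian
        w .-= α[j] .* v
        j > 1 && (w .-= β[j-1] .* v_prev)
        β[j] = norm(w)
        v_prev, v = v, w / β[j]
    end
    return V, SymTridiagonal(α, β[1:k-1]), β1, β[k]
end

# Hypothetical helper: for each shift s, xₖ(s) = β₁ V (Tk + sI)⁻¹ e₁ and the
# residual norm estimate ‖b - (A + sI) xₖ(s)‖ = β₁ βₖ₊₁ |eₖᵀ (Tk + sI)⁻¹ e₁|.
function shifted_solutions(V, Tk, β1, βk1, shifts)
    e1 = zeros(eltype(Tk), size(Tk, 1))
    e1[1] = 1
    xs  = Vector{Vector{eltype(V)}}()
    est = Vector{real(eltype(V))}()
    for s in shifts
        y = (Tk + s * I) \ e1              # k×k tridiagonal solve per shift
        push!(xs, β1 * (V * y))
        push!(est, β1 * βk1 * abs(y[end]))
    end
    return xs, est
end

Random.seed!(1)
n, k = 60, 40
B = randn(ComplexF64, n, n)
A = Hermitian(B' * B / n + I)             # HPD test matrix, eigenvalues ≥ 1
b = randn(ComplexF64, n)
c = [0.0, 0.25, 0.5]                      # target systems (A - cᵢ I) x = b
V, Tk, β1, βk1 = lanczos(A, b, k)        # one Lanczos run serves every shift
xs, est = shifted_solutions(V, Tk, β1, βk1, -c)
for (ci, xi, ei) in zip(c, xs, est)
    println(ci, "  ", norm((A - ci * I) * xi - b) / norm(b), "  ", ei / norm(b))
end
```

Compared with short-recurrence multi-shift CG, this variant stores k basis vectors, but each additional shift costs only a k×k tridiagonal solve plus one product V*y.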

@Jae-Mo_Lihm, do you often have to solve complex linear systems?
I'm just curious, because I don't see many users requesting support for complex numbers.

I should add it before the release 1.0. :slight_smile:

Yes, I often solve complex linear systems, so this functionality would be very useful! :slight_smile:

The context is perturbation theory in quantum mechanics (Perturbation theory (quantum mechanics) - Wikipedia), which requires solving (w*I - H)x = b, where H is a Hermitian operator and w is a real number (a frequency).
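Mapping this onto a shifted solver only needs a sign flip, since wI - H = -(H - wI). Assuming, as in the original question, that the shifted operator is definite (for example, w below the lowest eigenvalue of H), one can solve (H + sI)y = b with A = H and s = -w, the convention of the routine above, and take x = -y. The sketch below checks that mapping with dense Cholesky solves on a random Hermitian matrix standing in for H; in practice each direct solve would be replaced by one shift of a multi-shift solver.

```julia
using LinearAlgebra, Random

Random.seed!(3)
n = 40
G = randn(ComplexF64, n, n)
H = Hermitian((G + G') / 2)               # random Hermitian stand-in for the operator H
b = randn(ComplexF64, n)
ws = eigmin(H) .- [0.5, 1.0, 2.0]         # frequencies below the spectrum of H

# (w I - H) x = b  is equivalent to  (H + s I) y = b  with  s = -w  and  x = -y.
# For w < λmin(H), H - w I is positive definite, so a Cholesky (or CG) solve applies.
xs = map(ws) do w
    y = cholesky(H - w * I) \ b
    -y
end

for (w, x) in zip(ws, xs)
    println(w, "  ", norm((w * I - H) * x - b) / norm(b))
end
```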

Thanks @Jae-Mo_Lihm!
It will motivate us to add support for complex numbers.
