How can I compute the gradient with respect to a circuit parameter?

  1. I don't think there is much room for improvement if the overhead is only 2x.

  2. In the ForwardDiff implementation, the correctness of the gradients has not been checked.

using YaoExtensions: LinearAlgebra
using Yao, Random, YaoExtensions
using ForwardDiff: Dual
using LinearAlgebra

# Pool of single-qubit rotation constructors from which each layer draws its gates.
const SINGLE_GATES = [Rx, Ry, Rz]
# Two-qubit Z⊗Z gate; used as the generator of the `rot(_ZZ, β[j])` couplings.
@const_gate _ZZ = mat(kron(Z,Z))
# File-level RNG and a pre-built sampler over SINGLE_GATES.  Both are `const` because
# the layer builders read them on every call; non-const globals are typed `Any` and
# make those functions type-unstable.  Seeding (`Random.seed!(rng, seed)`) makes the
# draws taken through `rng` reproducible.
const rng = MersenneTwister()
const randG = Random.Sampler(rng, SINGLE_GATES)

# An ugly patch: re-define LinearAlgebra's internal normalisation kernel with no type
# restriction on the norm argument `nrm`.
# NOTE(review): this is type piracy on a non-public LinearAlgebra function and affects
# every vector in the session.  It presumably exists so that `normalize!` accepts a
# ForwardDiff `Dual` norm — confirm, and consider narrowing the method to `nrm::Dual`.
@inline function LinearAlgebra.__normalize!(v::AbstractVector, nrm)
    # Smallest positive value whose reciprocal is still finite
    # (for IEEE floats typemax is Inf, so this is 1/floatmax).
    tiny = inv(prevfloat(typemax(nrm)))
    if nrm ≥ tiny
        # inv(nrm) cannot overflow here: scale by it directly.
        rmul!(v, inv(nrm))
    else
        # inv(nrm) would overflow: first boost the data by eps/tiny, then divide by
        # the equally boosted norm.
        boost = eps(one(nrm)) / tiny
        rmul!(v, boost)
        rmul!(v, inv(nrm * boost))
    end
    return v
end

"""
    circuit_layer(n, β, eps)

Build one random layer on `n` qubits.

The layer is:
1. On every qubit `loc`, a rotation whose type is drawn from `SINGLE_GATES` (via
   `randG`) and whose angle is `2 * rand(rng) * eps * π`, i.e. uniform in `[0, 2π·eps)`.
2. For `j = 1:n-1`, the coupling `rot(_ZZ, β[j])` on qubits `(j, j+1)`.

Both the gate type and the angle are drawn from the file-level `rng`, so seeding `rng`
reproduces the layer. `β` must hold at least `n - 1` angles. Its element type may be
`Dual`, which lets the gradient propagate through the couplings.
"""
function circuit_layer(n::Int, β::AbstractVector, eps::Real)
    U = chain(n)
    for loc in 1:n
        # Gate type is drawn before its angle, as in the original expression.
        push!(U, put(n, loc => chain(rand(rng, randG)(2 * rand(rng) * eps * π))))
    end
    for j in 1:n-1
        push!(U, put(n, (j, j + 1) => rot(_ZZ, β[j])))
    end
    return U
end

"""
    cgrad_flayer(n, β, eps, θ)

Build the first layer on `n` qubits with explicitly supplied single-qubit angles.

On qubit `loc`, a rotation whose type is drawn from `SINGLE_GATES` (via `rand(rng, randG)`)
is applied with angle `θ[loc]`. It is followed by `rot(_ZZ, β[j])` on qubits `(j, j+1)`
for `j = 1:n-1`.

`θ` and `β` may be any vectors, including ones with `Dual` elements. This allows
differentiating with respect to the single-qubit angles as well as the couplings.
`θ` needs at least `n` entries and `β` at least `n - 1`.

`eps` is accepted only for signature parity with `circuit_layer`; it is not used here.
"""
function cgrad_flayer(n::Int, β::AbstractVector, eps::Real, θ::AbstractVector)
    U = chain(n)
    for loc in 1:n
        push!(U, put(n, loc => chain(rand(rng, randG)(θ[loc]))))
    end
    for j in 1:n-1
        push!(U, put(n, (j, j + 1) => rot(_ZZ, β[j])))
    end
    return U
end

"""
    _onehot(n, i)

Return a length-`n` `Vector{Float64}` of zeros with a `1.0` at index `i`.
Throws a `BoundsError` when `i` is outside `1:n`.
"""
function _onehot(n, i)
    basis = zeros(n)
    basis[i] = one(eltype(basis))
    return basis
end
"""
    get_gradients(pmax, spacing, n, eps, O; wrt = 1)

Forward-mode derivative of `real(expect(O, ψ))` with respect to the ZZ angle `β[wrt]`.
It is computed by seeding `β[wrt]` with a unit ForwardDiff tangent and is sampled at
a geometric sequence of circuit depths.

The state starts as the uniform superposition on `n` qubits. It then passes through
one `cgrad_flayer` with angles uniform in `[0, 2π·eps)`. Next come freshly drawn
`circuit_layer`s that all share the same `β`, with entries uniform in `[-π/2, π/2)`.
The derivative is recorded after `t` layers for every `t` in
`unique(trunc.(Int, 2 .^ (1:spacing:trunc(log2(pmax)))))`.
All random draws go through the file-level `rng`.

# Arguments
- `pmax`: maximum depth considered (must be ≥ 1).
- `spacing`: step of the exponent grid (must be > 0).
- `n`: number of qubits (must be ≥ 2 so there is at least one ZZ angle).
- `eps`: scale of the random single-qubit angles.
- `O`: observable block passed to `expect`.
- `wrt`: index in `1:n-1` of the ZZ angle to differentiate (default `1`).

# Returns
`Vector{Float64}` with one derivative per recorded depth; empty when `pmax < 2`.

# Throws
`ArgumentError` for an invalid `pmax`, `spacing`, `n`, or `wrt`.
"""
function get_gradients(pmax::Int, spacing::Float64, n::Int, eps::Float64, O::AbstractBlock;
                       wrt::Integer = 1)
    pmax ≥ 1 || throw(ArgumentError("pmax must be ≥ 1, got $pmax"))
    spacing > 0 || throw(ArgumentError("spacing must be positive, got $spacing"))
    n ≥ 2 || throw(ArgumentError("n must be ≥ 2, got $n"))
    1 ≤ wrt ≤ n - 1 || throw(ArgumentError("wrt must be in 1:$(n - 1), got $wrt"))

    # Integer depths at which to record; strictly increasing after `unique`.
    fp = trunc(log2(pmax))
    times = unique(trunc.(Int, 2 .^ (1:spacing:fp)))
    grads = zeros(length(times))
    isempty(times) && return grads

    # Seed β[wrt] with a unit tangent; every other angle has a zero tangent.
    β = Dual.((π * 0.5) * (2.0 * rand(rng, n - 1) .- 1.0), _onehot(n - 1, wrt))
    ψ = Yao.uniform_state(Complex{eltype(β)}, n)
    θ = (2 * π * eps) * rand(rng, n)

    # `|>` on a register applies the circuit to it in place.
    ψ |> cgrad_flayer(n, β, eps, θ)

    k = 1  # index of the next depth in `times` to record
    for t in 1:last(times)
        ψ |> circuit_layer(n, β, eps)
        if t == times[k]
            # The single partial is ∂/∂β[wrt] of the expectation value.
            grads[k] = real(expect(O, ψ)).partials[1]
            k += 1
        end
    end
    return grads
end

# Time one full run on 10 qubits with a ZZ observable on qubits (1, 2) and keep the
# result in `data`.  A first call also includes JIT-compilation time.
data = let nq = 10
    @time get_gradients(2000, 0.15, nq, 0.1, put(nq, (1, 2) => _ZZ))
end