Error with gradient function in quantum reinforcement learning algorithm

Hello. I am trying to implement the REINFORCE algorithm with parametrized quantum circuits from this Python tutorial (Parametrized Quantum Circuits for Reinforcement Learning  |  TensorFlow Quantum). To do this, I’ve implemented a quantum system emulator of sorts with the code below:

begin
	# representations of a qubit and its properties
	struct Qubit
		# |Ψ⟩ = α|0⟩ + β|1⟩
		# |α|² + |β|² = 1
		α::Complex
		β::Complex
		Qubit() = new(1, 0)
		Qubit(α::ComplexF64, β::ComplexF64) = verify_magnitude_sum(α, β) ? new(α, β) : error("Invalid Probability Amplitudes")
		Qubit(θ::Real, ϕ::Real) = new(cos(θ/2)+0.0im, exp(im*ϕ)*sin(θ/2)+0.0im)
		Qubit(v::Matrix{ComplexF64}) = new(v[1:2]...)
	end
	qubit_vector(q::Qubit)::Matrix = convert.(ComplexF64, reshape([q.α, q.β], (2, 1)))
	multi_qubit_vector(qs::Qubit...)::Matrix = kron(qubit_vector.(qs)...)	
	struct NQubit{N}
		# |ψ⟩ = ∑ᵢαᵢ|bin(i)⟩
		# ∑ᵢ|αᵢ|² = 1
		coefficients::Matrix{ComplexF64}
		NQubit(qubits::Qubit...) = length(qubits)>1 ? new{length(qubits)}(multi_qubit_vector(qubits...)) : new{length(qubits)}(qubit_vector(qubits[1]))
		NQubit(qubits::Matrix{ComplexF64}) = verify_magnitude_sum(qubits...) ? new{Int(log2(length(qubits)))}(qubits) : error("Invalid Probability Amplitudes")
		NQubit(qubits::Vector{ComplexF64}) = verify_magnitude_sum(qubits...) ? new{Int(log2(length(qubits)))}(hvcat(1, qubits...)) : error("Invalid Probability Amplitudes")
	end
	custom_round(ψ::NQubit) = NQubit(custom_round.(ψ.coefficients))
	
	import Base: *
	a::Matrix{ComplexF64} * q::Qubit = Qubit(a * convert.(ComplexF64, reshape([q.α, q.β], (2, 1))))
	a::Matrix{ComplexF64} * q::NQubit = NQubit(a * q.coefficients)
end
begin
	
	# 1-qubit gates
	R_x(θ::Real)::Matrix{ComplexF64} = [cos(θ/2) -im*sin(θ/2); -im*sin(θ/2) cos(θ/2)]
	R_y(θ::Real)::Matrix{ComplexF64} = [cos(θ/2) -sin(θ/2); sin(θ/2) cos(θ/2)]
	R_z(θ::Real)::Matrix{ComplexF64} = [exp(-im*θ/2) 0; 0 exp(im*θ/2)]
	P(λ)::Matrix{ComplexF64} = [1 0; 0 exp(im*λ)]
	H = Matrix{ComplexF64}([1/√(2) 1/√(2); 1/√(2) -1/√(2)])
	X = Matrix{ComplexF64}([0 1; 1 0])
	Y = Matrix{ComplexF64}([0 -im; im 0])
	Z = Matrix{ComplexF64}([1 0; 0 -1])
	S = Matrix{ComplexF64}([1 0; 0 im])
	T = Matrix{ComplexF64}([1 0; 0 sqrt(im)])
	decompose(U::Matrix{ComplexF64}) = begin
		γ = atan(imag(det(U)),real(det(U)))/2
		V = exp(-im*γ)*U
		θ = abs(V[1, 1])≥abs(V[1, 2]) ? 2*acos(abs(V[1, 1])) : 2*asin(abs(V[1, 2]))
		if cos(θ/2) == 0
			λ = atan(imag(V[2, 1]/sin(θ/2)), real(V[2, 1]/sin(θ/2)))
			ϕ = -λ
		elseif sin(θ/2) == 0
			ϕ = atan(imag(V[2, 2]/cos(θ/2)), real(V[2, 2]/cos(θ/2)))
			λ = ϕ
		else
			ϕ = atan(imag(V[2, 2]/cos(θ/2)), real(V[2, 2]/cos(θ/2)))+atan(imag(V[2, 1]/sin(θ/2)), real(V[2, 1]/sin(θ/2)))
			λ = 2*atan(imag(V[2, 2]/cos(θ/2)), real(V[2, 2]/cos(θ/2)))-ϕ
		end
		(
			round(rad2deg(real(θ)), digits=5), 
			round(rad2deg(real(ϕ)), digits=5), 
			round(rad2deg(real(λ)), digits=5),
			round(rad2deg(real(γ)), digits=5)
		)
	end
	U₃(θ, ϕ, λ, γ) = custom_round.(exp(im*γ)*(R_z(ϕ)*R_y(θ)*R_z(λ)))

	# 2-qubit gates
	CU(num_registers::Int, U::Matrix{ComplexF64}, control_index::Int, target_index::Int) = 
		target_index > control_index ? kron(
			I(2^(control_index-1)), 
			[1, 0] * [1 0], 
			I(2^(num_registers-control_index))
		) + kron(
			I(2^(control_index-1)), 
			[0, 1] * [0 1], 
			I(2^(target_index-control_index-1)), 
			U, 
			I(2^(num_registers-target_index))
		) : kron(
			I(2^(control_index-1)), 
			[1, 0] * [1 0], 
			I(2^(num_registers-control_index))
		) + kron(
			I(2^(target_index-1)),
			U,
			I(2^(control_index-target_index-1)), 
			[0, 1] * [0 1], 
			I(2^(num_registers-control_index))
		)
	CX = CU(2, X, 1, 2)
	CZ = CU(2, Z, 1, 2)
	CS = CU(2, S, 1, 2)
	CH = CU(2, H, 1, 2)
	SWAP = Matrix{ComplexF64}([1 0 0 0; 0 0 1 0; 0 1 0 0; 0 0 0 1])

	# 3-qubit gates
	CCU(num_registers::Int, U::Matrix{ComplexF64}, control_index_1::Int, control_index_2::Int, target_index::Int) = 
		target_index > min(control_index_1, control_index_2) ? kron(
			I(2^(min(control_index_1, control_index_2)-1)), 
			[1, 0] * [1 0], 
			I(2^(num_registers-min(control_index_1, control_index_2)))
		) + kron(
			I(2^(min(control_index_1, control_index_2)-1)), 
			[0, 1] * [0 1], 
			CU(num_registers-min(control_index_1, control_index_2), U, target_index-min(control_index_1, control_index_2), max(control_index_1, control_index_2)-min(control_index_1, control_index_2))
		) : kron(
			I(2^(max(control_index_1, control_index_2)-1)), 
			[1, 0] * [1 0], 
			I(2^(num_registers-max(control_index_1, control_index_2)))
		) + kron(
			CU(max(control_index_1, control_index_2)-1, U, target_index, min(control_index_1, control_index_2)),
			[0, 1] * [0 1],
			I(2^(num_registers-max(control_index_1, control_index_2)))
		)
	CCU(num_registers::Int, U::Matrix{ComplexF64}, control_index_1::Int, control_index_2::Int, target_index::Int) = begin
		control_index_1, control_index_2 = sort([control_index_1, control_index_2])
		target_index > control_index_2 ? kron(
			I(2^(control_index_1-1)),
			[0, 1] * [0 1],
			I(2^(control_index_2-control_index_1-1)),
			CU(num_registers-control_index_2+1, U, 1, target_index-control_index_2+1)
		) + kron(
			I(2^(control_index_1-1)),
			[1, 0] * [1 0],
			I(2^(num_registers-control_index_1))
		) : target_index < control_index_1 ? kron(
			I(2^(target_index-1)),
			CU(control_index_1-target_index+1, U, control_index_1-target_index+1, 1),
			I(2^(control_index_2-control_index_1-1)),
			[0, 1] * [0 1],
			I(2^(num_registers-control_index_2))
		) + kron(
			I(2^(control_index_2-1)),
			[1, 0] * [1 0],
			I(2^(num_registers-control_index_2))
		) : kron(
			I(2^(control_index_1-1)),
			[0, 1] * [0 1],
			I(2^(target_index-control_index_1-1)),
			CU(num_registers-target_index+1, U, control_index_2-target_index+1, 1)
		) + kron(
			I(2^(control_index_1-1)),
			[1, 0] * [1 0],
			I(2^(num_registers-control_index_1))
		)
	end
	CB(B::Matrix{ComplexF64})::Matrix{ComplexF64} = hvcat(
		(2, 2),
		[1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1],
		[0 0 0 0; 0 0 0 0; 0 0 0 0; 0 0 0 0],
		[0 0 0 0; 0 0 0 0; 0 0 0 0; 0 0 0 0],
		B
	)
	CCX = CCU(3, X, 1, 2, 3)
	CSWAP = CB(SWAP)

	# N-qubit gates
	COLUMN(I_before::Int, gate::Matrix{ComplexF64}, I_after::Int) = kron(
		I(2^I_before),
		gate,
		I(2^I_after)
	)
	QFT(N::Int) = custom_round.([exp(2*i*j*π*im/(2^N))/sqrt(2^N) for i in 0:2^N-1, j in 0:2^N-1])
	IQFT(N::Int) = adjoint(QFT(N))
	NSWAP(a::Int, b::Int, N::Int) = begin
		# move register A to register B in an N-Qubit system
		U = kron((Iₙ(2) for i in 1:N)...)
		for i in 1:abs(b-a)
			U = kron((Iₙ(2) for j in 1:(b>a ? a+i-2 : a-i-1))..., SWAP, (Iₙ(2) for j in (b>a ? a+i+1 : a-i+2):N)...) * U
		end
		U
	end
	ENTANGLING_LAYER(N::Int) = N == 2 ? CZ : *((CU(N, Z, 1, N), (CU(N, Z, i, i+1) for i in 1:N-1)...)...)
	
end
begin
	struct Circuit{N}
		ψ::NQubit{N}
		columns::Array{Matrix{ComplexF64}}
		Circuit(N::Int) = new{N}(NQubit((Qubit() for i in 1:N)...), [])
		Circuit(N::Int, columns::Array{Matrix{ComplexF64}}) = new{N}(NQubit((Qubit() for i in 1:N)...), columns)
		Circuit(qubits::Qubit...) = new{length(qubits)}(NQubit(qubits...), [])
		Circuit(ψ::NQubit) = new{typeof(ψ).parameters[1]}(ψ, [])
		Circuit(ψ::NQubit, columns::Array{Matrix{ComplexF64}}) = new{typeof(ψ).parameters[1]}(ψ, columns)
	end
	
	add_gate!(C::Circuit, gate::Matrix{ComplexF64}, register::Int) = push!(C.columns, COLUMN(register-1, gate, typeof(C).parameters[1]-register))
	add_controlled_gate!(C::Circuit, gate::Matrix{ComplexF64}, control_register::Int, target_register::Int) = push!(C.columns, CU(typeof(C).parameters[1], gate, control_register, target_register))
	add_double_controlled_gate!(C::Circuit, gate::Matrix{ComplexF64}, control_register_1::Int, control_register_2::Int, target_register::Int) = push!(C.columns, CCU(typeof(C).parameters[1], gate, control_register_1, control_register_2, target_register))
	add_column!(C::Circuit, column::Matrix{ComplexF64}) = push!(C.columns, column)

	run(C::Circuit) = begin
		columns = [C.columns[i] for i in 1:length(C.columns) if isassigned(C.columns, i)]
		ψ = C.ψ
		for c in columns
			ψ = c * ψ
		end
		ψ
	end
end
make_PQC!(C::Circuit, params::Array{T}) where T<:Real = begin
	if length(params) != 6*typeof(C).parameters[1]
		return -1
	else
		N = typeof(C).parameters[1]
		for i in 1:N
			add_gate!(C, R_x(params[i]), i)
			add_gate!(C, R_y(params[i+N]), i)
			add_gate!(C, R_z(params[i+2N]), i)
		end
		for i in 1:(N-1)
			add_controlled_gate!(C, Z, i, i+1)
		end
		if N!=1
			add_controlled_gate!(C, Z, N, 1)
		end
		for i in 1:N
			add_gate!(C, R_x(params[i+3N]), i)
			add_gate!(C, R_y(params[i+4N]), i)
			add_gate!(C, R_z(params[i+5N]), i)
		end
	end
	C
end
begin
	encode(v::Vector{T}) where T<:Number = NQubit(
		vcat(
			v./sum((x->x^2).(v))^.5,
			zeros(ComplexF64, Int(2^ceil(log2(length(v)))-length(v)))
		)
	)
	encode(v) = encode(vcat(v...))
    decode(C::Circuit, num_outputs::Int) = Dense(rand(MersenneTwister(SEED), num_outputs, 2^typeof(C).parameters[1]), false)(run(C).coefficients)
end
begin
	space_length(s) = try
		convert(Int, s.n)
	catch
		length(s.sample())
	end
	
	struct Env
		env
		n_actions::Int
		n_states::Int
		Env(env) = new(env, space_length(env.action_space), space_length(env.observation_space))
	end

	step(env::Env, action) = env.env.step(action)
	reset(env::Env) = env.env.reset()
	play(env::Env, keys_to_action::Dict) = play_env(env.env, keys_to_action=keys_to_action)
	
end
begin
	mutable struct Policy{N_states, N_actions}
		env::Env
		parameters::Vector{Float64}
		C::Circuit
		decoder::Dense
		Policy(env::Env) = begin
			obs, info = reset(env)
			ψ = encode(obs)
			new{env.n_states, env.n_actions}(
				env,
				[rand(MersenneTwister(SEED)) for i in 1:typeof(ψ).parameters[1]*6], # zeros(typeof(ψ).parameters[1]*6)
				Circuit(ψ),
				Dense(rand(MersenneTwister(SEED), env.n_actions, 2^typeof(ψ).parameters[1]), false)
			)
		end
		Policy(env::Env, params::Vector{Float64}) = begin
			obs, info = reset(env)
			ψ = encode(obs)
			new{env.n_states, env.n_actions}(
				env,
				params, # zeros(typeof(ψ).parameters[1]*6)
				Circuit(ψ),
				Dense(rand(MersenneTwister(SEED), env.n_actions, 2^typeof(ψ).parameters[1]), false)
			)
		end
	end

	evaluate_policy(policy::Policy, state) = length(state)==policy.env.n_states ? policy.decoder(
		begin
			policy.C = Circuit(encode(state))
			make_PQC!(policy.C, policy.parameters)
			(x->abs(x)).(run(policy.C).coefficients)
		end
	) |> softmax : error("Invalid state")
end
begin
	struct Trajectory
		states::Vector
		actions::Vector
		rewards::Vector
		Trajectory() = new([], [], [])
		Trajectory(states::Vector, actions::Vector, rewards::Vector) = new(states, actions, rewards)
	end

	add_step!(T::Trajectory, state, action, reward) = begin
		push!(T.states, state)
		push!(T.actions, action)
		push!(T.rewards, reward)
	end

	Base.length(T::Trajectory) = length(T.states)
end
env = Env(gym.make("CartPole-v1", render_mode="rgb_array"))

Here is the training loop:

begin
	generate_trajectory(env::Env, policy::Policy) = begin
		T = Trajectory()
		obs, info = reset(env)
		action = argmax(evaluate_policy(policy, obs))[1] - 1
		obs, reward, terminated, truncated, info = step(env, action)
		add_step!(T, obs, action, reward)
		while !terminated && !truncated
			action = argmax(evaluate_policy(policy, obs))[1] - 1
			obs, reward, terminated, truncated, info = step(env, action)
			add_step!(T, obs, action, reward)
		end
		T
	end

	loss(π::Policy, T::Trajectory, γ::Float64) = -1/abs(length(T)) * sum(log(evaluate_policy(π, T.states[t])[T.actions[t]+1]) * sum(γ^(t2-t)*T.rewards[t2] for t2 in t:length(T)) for t in 1:length(T))
	
	REINFORCE(env::Env, N_EPISODES::Int=1000, α::Float64=0.01, γ::Float64=0.99) = begin
		π_θ = Policy(env)
		trajectory = generate_trajectory(env, π_θ)
		rewards = []
		for i in 1:N_EPISODES
			θ = π_θ.parameters
			trajectory = generate_trajectory(env, π_θ)
			for t in 0:length(trajectory)-1
				G = sum(γ^(k-t-1)*trajectory.rewards[k] for k in t+1:length(trajectory))
				grad = gradient(θ) do
					log(evaluate_policy(π_θ, trajectory.states[t+1])[trajectory.actions[t+1]+1])
		        end
			end
			π_θ.parameters = θ
		end
		π_θ, rewards
	end
end

When I run the REINFORCE function, I receive an error with the gradient function that says “MethodError: no method matching (::Main.var"workspace#7”.var"#40#44"{Main.var"workspace#7".Policy{4, 2}, Int64})(::Vector{Float64})". How exactly does the gradient function work and what are some ways to fix this error?