Hi, I am using `rrule` to define custom gradients for Zygote. I wrote a function that computes the gradient, and I call that function inside the overall gradient calculation. Can Zygote handle this kind of rule?
"""
    ldt_simulation(nsamples, iteration_counter, pf_vec)

Return `grad_logPf` computed for the design vector `pf_vec`.

NOTE(review): the body is elided (`.` placeholders) in this excerpt, so its
behaviour cannot be confirmed here. The name `grad_logPf` suggests a gradient
(of log Pf?) with respect to `pf_vec` — TODO confirm. `Dgfdpf` returns this
value unchanged.
"""
function ldt_simulation(nsamples,iteration_counter,pf_vec)
.
.
.
    return grad_logPf 
end
"""
    Dgfdpf(pf_vec; iteration_counter, β, η, fem_params)

Return `ldt_simulation(nsamples, iteration_counter, pf_vec)` unchanged.

NOTE(review): `nsamples` is not an argument. It is looked up in the enclosing
(global) scope. `β`, `η` and `fem_params` are accepted but unused. If this value
is returned from a pullback, it must be multiplied by the incoming cotangent.
"""
Dgfdpf(pf_vec; iteration_counter, β, η, fem_params) =
    ldt_simulation(nsamples, iteration_counter, pf_vec)
"""
    pf_p0(p0; r, fem_params)

Return the filtered design vector `Filter(p0; r, fem_params)`.
A custom `rrule` (whose pullback calls `Dgdp`) is written for this function,
so AD is meant to use that rule instead of tracing through `Filter`.
"""
pf_p0(p0; r, fem_params) = Filter(p0; r, fem_params)
# Reverse-mode rule for `pf_p0`, so Zygote uses `Dgdp` instead of tracing `Filter`.
#
# The method extends `ChainRulesCore.rrule` explicitly. A bare
# `function rrule(...)` only extends it if `rrule` was imported with
# `import ChainRulesCore: rrule`. Otherwise it silently creates an unrelated
# local function that Zygote never calls. Assumes ChainRulesCore is loaded
# (e.g. `using ChainRulesCore`).
#
# Returns the primal value `pf_p0(p0; ...)` and a pullback that maps the cotangent
# w.r.t. the filtered vector to a cotangent w.r.t. `p0`.
function ChainRulesCore.rrule(::typeof(pf_p0), p0; r, fem_params)
    pf_vec = pf_p0(p0; r, fem_params)
    function pf_pullback(dgdpf)
        # The first entry is the tangent of the function object itself (no fields).
        # `NoTangent()` replaces the deprecated `NO_FIELDS` (removed in
        # ChainRulesCore 1.0). `unthunk` materialises a lazily passed cotangent
        # and is a no-op for plain arrays. Keyword arguments get no tangents.
        return NoTangent(), Dgdp(unthunk(dgdpf); r, fem_params)
    end
    return pf_vec, pf_pullback
end
"""
    Dgdp(dgdpf; r, fem_params)

Map a cotangent `dgdpf` w.r.t. the output of `pf_p0` to a cotangent w.r.t. `p0`.
This function is called from the pullback of `pf_p0`.

It runs in three steps:
1. Assemble the matrix `Af` of the bilinear form `a_f(r, u, v) + v * u`
   integrated over `fem_params.dΩ`, on the spaces `fem_params.Pf` / `fem_params.Qf`.
2. Solve the adjoint system `Af' * w = dgdpf`.
3. Assemble `∫ w * dp dΩ` over `fem_params.P` and return the resulting vector.
"""
function Dgdp(dgdpf; r, fem_params)
    dΩ = fem_params.dΩ
    bilinear_form(u, v) = ∫(a_f(r, u, v))dΩ + ∫(v * u)dΩ
    filter_matrix = assemble_matrix(bilinear_form, fem_params.Pf, fem_params.Qf)

    adjoint_coeffs = filter_matrix' \ dgdpf
    adjoint_field = FEFunction(fem_params.Pf, adjoint_coeffs)

    return assemble_vector(dp -> ∫(adjoint_field * dp)dΩ, fem_params.P)
end
"""
    gf_p(p0::Vector; iteration_counter, r, β, η, fem_params)

Evaluate the objective at the design `p0`. The design is first filtered with
`pf_p0`, then `gf_pf` is evaluated on the filtered vector. The two-argument
`gf_p` method differentiates this composition with Zygote.
"""
function gf_p(p0::Vector; iteration_counter, r, β, η, fem_params)
    filtered = pf_p0(p0; r, fem_params)
    return gf_pf(filtered; iteration_counter, β, η, fem_params)
end
"""
    gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)

Objective callback in the `f(x, grad)` form (presumably for an NLopt-style optimiser).

Returns the objective value at `p0`. When `grad` is non-empty, it is overwritten in
place with the gradient of the objective w.r.t. `p0`, computed by Zygote.

A single `Zygote.pullback` call yields both the value and the pullback. The forward
model (filter + `gf_pf`) is therefore evaluated once rather than twice, which also
keeps the returned value consistent with the gradient.
"""
function gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)
    objective(p) = gf_p(p; iteration_counter, r, β, η, fem_params)

    # The optimiser passes an empty `grad` when it only needs the value: skip AD.
    isempty(grad) && return objective(p0)

    gvalue, back = Zygote.pullback(objective, p0)
    dgdp, = back(one(gvalue))

    # Zygote returns `nothing` when the output does not depend on the input.
    if dgdp === nothing
        fill!(grad, zero(eltype(grad)))
    else
        grad .= dgdp
    end
    return gvalue
end
Thank you