Hi, I am using a custom `rrule` to compute gradients with Zygote. I wrote a function that computes part of the gradient by hand (`Dgfdpf`, which calls `ldt_simulation`), and I call that function inside the overall gradient calculation. Can Zygote handle this kind of rule?
# Run the LDT simulation for the filtered vector `pf_vec` and return `grad_logPf`.
# NOTE(review): the body is elided in this excerpt. The `.` lines are placeholders,
# not valid Julia, and `grad_logPf` must be computed there. Judging by the name it
# is the gradient of log(Pf) with respect to `pf_vec`; confirm. `nsamples` is
# presumably the number of simulation samples; confirm.
function ldt_simulation(nsamples,iteration_counter,pf_vec)
.
.
.
return grad_logPf
end
"""
    Dgfdpf(pf_vec; iteration_counter, β, η, fem_params)

Return the result of `ldt_simulation` for the filtered vector `pf_vec`.
Judging by the name, this is the hand-computed derivative of `gf_pf` with
respect to `pf`; confirm.

NOTE(review): `nsamples` is not a parameter. It is looked up in the enclosing
(global) scope, so it must be defined there before this is called. `β`, `η` and
`fem_params` are accepted but not used here.

NOTE(review): Zygote only uses this if a custom `rrule` for `gf_pf` calls it from
its pullback. No such rule is visible in this excerpt.
"""
Dgfdpf(pf_vec; iteration_counter, β, η, fem_params) =
    ldt_simulation(nsamples, iteration_counter, pf_vec)
"""
    pf_p0(p0; r, fem_params)

Apply `Filter` to the design vector `p0` and return the filtered vector.
Zygote differentiates this through the custom `rrule` for `pf_p0` instead of
tracing into `Filter`.
"""
pf_p0(p0; r, fem_params) = Filter(p0; r, fem_params)
"""
    ChainRulesCore.rrule(::typeof(pf_p0), p0; r, fem_params)

Reverse-mode rule for `pf_p0`. The forward pass evaluates `pf_p0` normally. The
pullback maps the cotangent of the filtered vector back to `p0` with the adjoint
solve in `Dgdp`.

The method extends `ChainRulesCore.rrule` explicitly. A bare `function rrule`
would create a separate module-local function unless `rrule` had been imported,
and Zygote would silently ignore it.

Only positional arguments receive tangents: the pullback returns one tangent for
the function itself and one for `p0`. Keyword arguments (`r`, `fem_params`) are
not differentiated.
"""
function ChainRulesCore.rrule(::typeof(pf_p0), p0; r, fem_params)
    pf_vec = pf_p0(p0; r, fem_params)
    function pf_pullback(dgdpf)
        # A symbolic zero cotangent cannot be passed to the linear solve in Dgdp.
        if dgdpf isa ChainRulesCore.AbstractZero
            return ChainRulesCore.NoTangent(), ChainRulesCore.ZeroTangent()
        end
        # The cotangent may arrive lazily as a Thunk; materialize it before solving.
        dgdp = Dgdp(ChainRulesCore.unthunk(dgdpf); r, fem_params)
        # NoTangent() replaces the NO_FIELDS constant removed in ChainRulesCore 1.0.
        return ChainRulesCore.NoTangent(), dgdp
    end
    return pf_vec, pf_pullback
end
"""
    Dgdp(dgdpf; r, fem_params)

Pull the sensitivity `dgdpf` (with respect to the filtered field) back to the
design space:

1. Assemble the filter operator on `(fem_params.Pf, fem_params.Qf)` from the
   form `a_f(r, u, v)` plus the mass term `v * u`.
2. Solve the transposed system `Af' * w = dgdpf`.
3. Assemble the vector of `∫ w * dp dΩ` over the test space `fem_params.P`.

Returns the assembled vector.
"""
function Dgdp(dgdpf; r, fem_params)
    dΩ = fem_params.dΩ
    filter_matrix = assemble_matrix(fem_params.Pf, fem_params.Qf) do u, v
        return ∫(a_f(r, u, v))dΩ + ∫(v * u)dΩ
    end
    # Adjoint solve: transpose of the filter operator against the incoming sensitivity.
    adjoint_coeffs = adjoint(filter_matrix) \ dgdpf
    adjoint_field = FEFunction(fem_params.Pf, adjoint_coeffs)
    return assemble_vector(fem_params.P) do dp
        return ∫(adjoint_field * dp)dΩ
    end
end
"""
    gf_p(p0::Vector; iteration_counter, r, β, η, fem_params)

Objective as a function of the design vector `p0`: filter `p0` with `pf_p0`, then
evaluate `gf_pf` on the filtered vector and return its value.
"""
function gf_p(p0::Vector; iteration_counter, r, β, η, fem_params)
    return gf_pf(pf_p0(p0; r, fem_params); iteration_counter, β, η, fem_params)
end
"""
    gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)

Evaluate the objective `gf_p(p0; ...)` and return its value. When `grad` is
non-empty, also write the gradient with respect to `p0` into `grad` in place.
This in-place convention is the one NLopt-style optimizers use; confirm that
the caller is NLopt.
"""
function gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)
    objective(p) = gf_p(p; iteration_counter, r, β, η, fem_params)
    if isempty(grad)
        return objective(p0)
    end
    # One forward pass gives both the value and the pullback, so the expensive
    # filter/simulation chain is not evaluated a second time just for the value.
    gvalue, back = Zygote.pullback(objective, p0)
    dgdp, = back(one(gvalue))
    if dgdp === nothing
        # Zygote returns `nothing` when the output does not depend on the input.
        fill!(grad, zero(eltype(grad)))
    else
        grad .= dgdp
    end
    return gvalue
end
Thank you