rrule for Zygote

Hi, I am using an rrule for gradient calculation in Zygote. I defined a function that calculates the gradient, and I call that function inside the overall gradient calculation. Can Zygote handle this type of rule?

function ldt_simulation(nsamples,iteration_counter,pf_vec)
.
.
.
    return grad_logPf 
end
function Dgfdpf(pf_vec; iteration_counter,β, η, fem_params)
    pfg = ldt_simulation(nsamples,iteration_counter,pf_vec)
    return  pfg
end 

function pf_p0(p0; r, fem_params)
    pf_vec = Filter(p0; r, fem_params)
    pf_vec
end

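using ChainRulesCore  # provides rrule and NoTangent

# Extend ChainRulesCore.rrule (not a new local function) so that Zygote uses the rule
function ChainRulesCore.rrule(::typeof(pf_p0), p0; r, fem_params)
  function pf_pullback(dgdpf)
    # NoTangent() replaces NO_FIELDS in ChainRulesCore 1.0 and later
    NoTangent(), Dgdp(dgdpf; r, fem_params)
  end
  pf_p0(p0; r, fem_params), pf_pullback
end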

function Dgdp(dgdpf; r, fem_params)
    Af = assemble_matrix(fem_params.Pf, fem_params.Qf) do u, v
        ∫(a_f(r, u, v))fem_params.dΩ + ∫(v * u)fem_params.dΩ
    end
    wvec = Af' \ dgdpf
    wh = FEFunction(fem_params.Pf, wvec)
    l_temp(dp) = ∫(wh * dp)fem_params.dΩ
    return assemble_vector(l_temp, fem_params.P)
end 


function gf_p(p0::Vector;iteration_counter, r, β, η, fem_params)
    pf_vec = pf_p0(p0; r, fem_params)
    gf_pf(pf_vec;iteration_counter, β, η, fem_params)
end

function gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)
    if length(grad) > 0
        # Fill grad in place when the caller requests a gradient
        dgdp, = Zygote.gradient(p -> gf_p(p; iteration_counter, r, β, η, fem_params), p0)
        grad[:] = dgdp
    end
    return gf_p(p0; iteration_counter, r, β, η, fem_params)
end

Thank you

Not quite. An rrule is not just a function that computes a gradient; it computes a vector–Jacobian product.

That is, if you have a function f(x), then an rrule for f returns the value f(x) together with a “pullback” function g(v) that computes v^T f'(x), where f' is the Jacobian.

If you have a scalar-valued function f(x) of a vector x, then the Jacobian is a row vector f'(x) = (\nabla f)^T, and your pullback takes a scalar v and returns v (\nabla f)^T (in the correct shape).
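To make the scalar case concrete, here is a minimal sketch, assuming ChainRulesCore and Zygote are installed. The names `my_loss` and `my_loss_grad` are hypothetical stand-ins (they are not from the original post): `my_loss` plays the role of the scalar objective and `my_loss_grad` the role of a separate routine that computes its gradient. The point is that the pullback multiplies the gradient by the incoming scalar v instead of returning the bare gradient.

```julia
using ChainRulesCore
using Zygote

# Hypothetical scalar-valued function standing in for the real objective.
my_loss(x) = sum(abs2, x) / 2

# Hypothetical hand-written gradient routine (the role ldt_simulation plays in the question).
my_loss_grad(x) = copy(x)

# Extend ChainRulesCore.rrule so that Zygote picks up the custom rule.
function ChainRulesCore.rrule(::typeof(my_loss), x)
    y = my_loss(x)
    g = my_loss_grad(x)
    # The pullback receives the scalar cotangent v and returns v * ∇f.
    # NoTangent() is the tangent for the function object itself.
    my_loss_pullback(v) = (NoTangent(), v * g)
    return y, my_loss_pullback
end

x = [1.0, 2.0, 3.0]
grad_direct, = Zygote.gradient(my_loss, x)              # v is 1 here
grad_scaled, = Zygote.gradient(x -> 3 * my_loss(x), x)  # v is 3 here
```

If the pullback ignored v and always returned `g`, the first call would still look right, but the second would be off by a factor of 3. That is why the rule has to be a vector–Jacobian product and not only a gradient.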


I have a function f(x) = K[(ln(pf)-ln(pa))^+]^2, which is a positive scalar when the term inside the brackets is positive and zero otherwise. The ldt_simulation function calculates the gradient of this f(x), which is a vector. Based on these assumptions and what you explained, I think Zygote can handle the gradient calculation. Am I right? Is there any way to check this?
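For reference, the gradient and the corresponding pullback for this f can be written out as follows. This is a sketch under some assumptions: pa is a constant, pf depends differentiably on x, and (z)^+ means max(z, 0).

```latex
f(x) = K\,\bigl[(\ln p_f(x) - \ln p_a)^+\bigr]^2,
\qquad
\nabla f(x) = 2K\,(\ln p_f(x) - \ln p_a)^+\,\nabla \ln p_f(x),
\qquad
\text{pullback: } v \mapsto v\,\nabla f(x).
```

Since z ↦ (z^+)^2 is differentiable at z = 0 with derivative 0, the formula also covers the case where the bracket is not positive: the gradient is simply zero there.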


Yes, you can implement an rrule for anything (assuming the derivative exists). You normally check it against finite differences.
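A finite-difference check can be as simple as the sketch below, which uses only the standard library. The test function `f`, its hand-written gradient `grad_f`, and the helper `fd_gradient` are hypothetical examples, not code from the original post. In practice you would plug in your own function and the gradient your pullback returns for v = 1.

```julia
using LinearAlgebra

# Hypothetical smooth test function and its hand-written gradient.
f(x) = log(1 + sum(abs2, x))
grad_f(x) = 2 .* x ./ (1 + sum(abs2, x))

# Central finite-difference approximation of the gradient (hypothetical helper).
function fd_gradient(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [0.3, -1.2, 2.0]
# Relative error between the finite-difference and hand-written gradients.
rel_err = norm(fd_gradient(f, x) - grad_f(x)) / norm(grad_f(x))
```

A small relative error suggests the gradient is right; an error of order 1 usually points to a wrong sign, a missing factor, or a shape mismatch. If the function itself comes from a random simulation, the finite differences will be noisy unless the random seed is fixed between evaluations.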

Often I am able to check with ForwardDiff.jl as well (it may be less performant for a large number of inputs, but it is easy to use in a lot of cases).
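As a rough sketch of that kind of check, assuming ForwardDiff.jl is installed: the code below compares `ForwardDiff.gradient` against a hand-written gradient for a penalty of the form discussed in this thread. The values of `K` and `pa` and the closed-form `log_pf` are hypothetical placeholders; this only works when the function is plain Julia code that accepts ForwardDiff's dual numbers, which may not hold for a simulation-based pf.

```julia
using ForwardDiff

K = 10.0    # hypothetical penalty weight
pa = 1e-3   # hypothetical target probability

# Hypothetical closed-form stand-in for ln(pf) and its gradient.
log_pf(x) = -sum(abs2, x)
grad_log_pf(x) = -2 .* x

# Penalty f(x) = K * [(ln(pf) - ln(pa))^+]^2 and its hand-written gradient.
penalty(x) = K * max(log_pf(x) - log(pa), 0)^2
penalty_grad(x) = 2K * max(log_pf(x) - log(pa), 0) .* grad_log_pf(x)

x = [0.5, 1.0]
g_ad = ForwardDiff.gradient(penalty, x)
max_abs_err = maximum(abs.(g_ad .- penalty_grad(x)))
```

Here the bracket is positive at the chosen x, so both gradients are nonzero and the comparison is meaningful.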