rrule for Zygote

Hi, I am using an `rrule` for the gradient calculation in Zygote. I defined a function that calculates the gradient, and I just call that function inside the overall gradient calculation. Can Zygote handle this type of rule?

```julia
using ChainRulesCore, Zygote
using Gridap   # assemble_matrix, assemble_vector, FEFunction, ∫
# Filter, a_f, gf_pf and nsamples are defined elsewhere (not shown).

function ldt_simulation(nsamples, iteration_counter, pf_vec)
    # ... (simulation code omitted)
    return grad_logPf
end

function Dgfdpf(pf_vec; iteration_counter, β, η, fem_params)
    pfg = ldt_simulation(nsamples, iteration_counter, pf_vec)
    return pfg
end

function pf_p0(p0; r, fem_params)
    pf_vec = Filter(p0; r, fem_params)
    return pf_vec
end

# Must extend ChainRulesCore.rrule; a bare `function rrule` would define a new
# function that Zygote never calls.
function ChainRulesCore.rrule(::typeof(pf_p0), p0; r, fem_params)
    function pf_pullback(dgdpf)
        # NoTangent() (formerly NO_FIELDS): pf_p0 itself has no tangent
        return NoTangent(), Dgdp(unthunk(dgdpf); r, fem_params)
    end
    return pf_p0(p0; r, fem_params), pf_pullback
end

function Dgdp(dgdpf; r, fem_params)
    Af = assemble_matrix(fem_params.Pf, fem_params.Qf) do u, v
        ∫(a_f(r, u, v))fem_params.dΩ + ∫(v * u)fem_params.dΩ
    end
    wvec = Af' \ dgdpf            # adjoint solve: Af' * wvec = dgdpf
    wh = FEFunction(fem_params.Pf, wvec)
    l_temp(dp) = ∫(wh * dp)fem_params.dΩ
    return assemble_vector(l_temp, fem_params.P)
end

function gf_p(p0::Vector; iteration_counter, r, β, η, fem_params)
    pf_vec = pf_p0(p0; r, fem_params)
    return gf_pf(pf_vec; iteration_counter, β, η, fem_params)
end

function gf_p(p0::Vector, grad::Vector; iteration_counter, r, β, η, fem_params)
    if length(grad) > 0
        dgdp, = Zygote.gradient(p -> gf_p(p; iteration_counter, r, β, η, fem_params), p0)
        grad .= dgdp
    end
    gvalue = gf_p(p0; iteration_counter, r, β, η, fem_params)
    return gvalue
end
```

Thank you

Not quite. An rrule is not just a function that computes a gradient; it returns a pullback that computes a vector–Jacobian product.

That is, if you have a function $f(x)$, then an rrule for $f$ returns a “pullback” function $g(v)$ that computes $v^T f'(x)$, where $f'$ is the Jacobian.

If you have a scalar-valued function $f(x)$ of a vector $x$, then the Jacobian is a row vector $f'(x) = (\nabla f)^T$, and your pullback takes a scalar $v$ and returns $v (\nabla f)^T$ (in the correct shape).
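As a minimal sketch of this in code, assuming ChainRulesCore.jl 1.x and Zygote.jl are installed: `sumsq` below is a hypothetical stand-in for an objective whose gradient you compute yourself. The method has to extend `ChainRulesCore.rrule` (a bare `function rrule` in your own module defines an unrelated function), and in ChainRulesCore 1.x the tangent for the function itself is `NoTangent()`; the older `NO_FIELDS` constant was deprecated and later removed.

```julia
using ChainRulesCore   # rrule, NoTangent, unthunk
using Zygote           # uses ChainRulesCore.rrule methods automatically

# Hypothetical scalar-valued function of a vector, standing in for an
# objective whose gradient is computed by our own code.
sumsq(x) = sum(abs2, x)

# The rrule returns the primal value plus a pullback. The pullback maps the
# scalar cotangent dy to the vector-Jacobian product dy * ∇sumsq(x).
function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractVector)
    y = sumsq(x)
    grad = 2 .* x                       # hand-computed gradient ∇sumsq(x)
    function sumsq_pullback(dy)
        # First entry: tangent w.r.t. the function object itself (none).
        return NoTangent(), unthunk(dy) .* grad
    end
    return y, sumsq_pullback
end

x = [1.0, 2.0]
# Because of the outer factor 3, Zygote calls the pullback with dy = 3.
g, = Zygote.gradient(x -> 3 * sumsq(x), x)
```

Here `g` is `3 .* (2 .* x)`: the pullback scales the hand-computed gradient by the incoming cotangent, which is what makes it a vector–Jacobian product rather than a plain gradient.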


I have a function $f(x) = K\left[(\ln p_f - \ln p_a)^+\right]^2$, which is a positive scalar if the quantity inside the bracket is positive and zero otherwise. The `ldt_simulation` function calculates the gradient of this $f(x)$, which is a vector. Based on these assumptions and what you explained, I think Zygote can handle the gradient calculation. Am I right? Is there any way to check this?


Yes, you can implement an rrule for anything (assuming the derivative exists). You normally check it with finite differences.
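A sketch of such a check for an objective shaped like the one above, assuming ChainRulesCore.jl and Zygote.jl: `logpf`, `grad_logpf`, `K` and `logpa` are hypothetical toy stand-ins (in the real problem $\ln p_f$ and its gradient would come from the simulation). The rrule supplies $2K(\ln p_f - \ln p_a)^+ \nabla \ln p_f$, and `fd_gradient` compares it component by component with central differences.

```julia
using ChainRulesCore, Zygote, LinearAlgebra

# Hypothetical smooth stand-ins for ln p_f(x) and its gradient.
logpf(x) = -sum(abs2, x) / 2
grad_logpf(x) = -x

const K = 10.0            # hypothetical penalty weight
const logpa = log(1e-3)   # hypothetical allowed failure probability p_a

# Objective f(x) = K * [(ln p_f(x) - ln p_a)^+]^2
penalty(x) = K * max(logpf(x) - logpa, 0.0)^2

function ChainRulesCore.rrule(::typeof(penalty), x::AbstractVector)
    t = max(logpf(x) - logpa, 0.0)
    grad = 2K * t .* grad_logpf(x)   # ∇f = 2K (ln p_f - ln p_a)^+ ∇ln p_f
    penalty_pullback(dy) = (NoTangent(), unthunk(dy) .* grad)
    return K * t^2, penalty_pullback
end

# Central-difference approximation of each gradient component.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x0 = [0.5, -1.0, 2.0]                 # a point where the hinge is active
g_ad, = Zygote.gradient(penalty, x0)  # gradient from the custom rrule
g_fd = fd_gradient(penalty, x0)
relerr = norm(g_ad - g_fd) / norm(g_fd)
```

ChainRulesTestUtils.jl's `test_rrule` automates this kind of comparison; the hand-rolled loop is only meant to make the idea explicit. Component-wise differences cost two evaluations per input, so for many design variables the random-direction variant is cheaper.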


Often I can check it with ForwardDiff.jl as well (it may be slow for a large number of inputs, but it is easy to use in many cases).
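A minimal sketch of that comparison, assuming ForwardDiff.jl, Zygote.jl and ChainRulesCore.jl are installed; `prodfun` is a hypothetical test function with a hand-written rrule. `ForwardDiff.gradient` ignores the rrule and differentiates the primal code directly with dual numbers, so this only works when that code accepts generic number types.

```julia
using ChainRulesCore, Zygote, ForwardDiff

# Hypothetical test function with a hand-written gradient.
prodfun(x) = sum(sin, x) * sum(abs2, x)

function ChainRulesCore.rrule(::typeof(prodfun), x::AbstractVector)
    s, q = sum(sin, x), sum(abs2, x)
    grad = cos.(x) .* q .+ s .* 2 .* x   # product rule, computed by hand
    return s * q, dy -> (NoTangent(), unthunk(dy) .* grad)
end

x0 = [0.3, -1.2, 2.5]
g_rrule, = Zygote.gradient(prodfun, x0)   # goes through the custom rrule
g_fwd = ForwardDiff.gradient(prodfun, x0) # dual numbers through the primal code
```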


I think the OP @mary is trying to do topology optimization through Gridap.jl FEM, similar to this tutorial, in which case ForwardDiff may not be so easy.

I would usually use finite differences to compare the gradient against some random directional derivatives, i.e. check that $f(x + \delta x) - f(x) \approx \delta x^T \nabla f$ for some small random vectors $\delta x$.
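The random-direction check can be sketched in plain Julia (standard library only); `f`, `gradf` and `directional_check` are hypothetical names. A central difference $\tfrac12[f(x+\delta x) - f(x-\delta x)]$ is used in place of the one-sided difference to reduce truncation error, and the error is scaled by $\|\delta x\|\,\|\nabla f\|$ rather than $|\delta x^T \nabla f|$, so a direction that happens to be nearly orthogonal to the gradient does not inflate it.

```julia
using LinearAlgebra, Random

# Compare a claimed gradient with finite differences along random directions:
# f(x + δx) - f(x - δx) ≈ 2 δxᵀ∇f(x) for small δx.
function directional_check(f, gradf, x; ntrials = 5, scale = 1e-6,
                           rng = Random.default_rng())
    g = gradf(x)
    errs = Float64[]
    for _ in 1:ntrials
        δx = scale .* randn(rng, length(x))      # small random direction
        fd = (f(x .+ δx) - f(x .- δx)) / 2       # finite-difference estimate
        ad = dot(δx, g)                          # prediction from the gradient
        push!(errs, abs(fd - ad) / (norm(δx) * norm(g)))
    end
    return errs
end

# Hypothetical example: f(x) = ‖x‖² + sin(x₁), ∇f = 2x + (cos(x₁), 0, …, 0)
f(x) = sum(abs2, x) + sin(x[1])
gradf(x) = 2 .* x .+ [cos(x[1]); zeros(length(x) - 1)]

errs = directional_check(f, gradf, [1.0, -2.0, 0.5, 3.0]; rng = MersenneTwister(0))
```

With the correct gradient the scaled errors come out far below 1e-6 here; a gradient missing a term (for example, dropping the $\cos(x_1)$ part) gives errors orders of magnitude larger.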


I wanted to check that with finite differences, but the objective is $K\left[(\ln p_f - \ln p_a)^+\right]^2$, where $p_f$ is the probability of failure. When I perturb the input by a small value, $p_f$ does not change, so the finite-difference value is 0.0. But I get a huge gradient, because in addition to the factor $2K(\ln p_f - \ln p_a)^+$ it is multiplied by $\nabla \ln p_f$. I am not sure whether finite differences work here or not.
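For reference, applying the chain rule to this objective, with $p_f(x)$ the failure probability as a function of the design variables and $K$, $p_a$ constants, gives the expression below. Since $(t^+)^2$ is continuously differentiable with derivative $2t^+$, the gradient is zero whenever $p_f \le p_a$; when $p_f > p_a$ it is the outer factor times $\nabla \ln p_f$, which, judging by its `grad_logPf` return variable, is what `ldt_simulation` computes.

$$
\begin{aligned}
f(x) &= K\left[\left(\ln p_f(x) - \ln p_a\right)^+\right]^2,\\
\frac{d}{dt}\left[(t^+)^2\right] &= 2\,t^+,\\
\nabla f(x) &= 2K\left(\ln p_f(x) - \ln p_a\right)^+\,\nabla \ln p_f(x)
= 2K\left(\ln p_f(x) - \ln p_a\right)^+\,\frac{\nabla p_f(x)}{p_f(x)}.
\end{aligned}
$$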

If you perturb the inputs by too little, then the accuracy of finite differences is destroyed by roundoff error from finite precision. See also chapter 4 of our matrix-calculus notes.
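A small stdlib-only Julia sketch of this effect, using a hypothetical quadratic test function whose directional derivative is known exactly: the one-sided difference has truncation error $O(h)$ and roundoff error of order $\epsilon_{\text{mach}}|f|/h$, so the relative error falls as $h$ shrinks, bottoms out at an intermediate step (near $\sqrt{\epsilon_{\text{mach}}} \approx 1.5\times 10^{-8}$ for this well-scaled example), and then grows again.

```julia
# Forward-difference error versus step size h for a function whose exact
# directional derivative is known. Truncation error shrinks like O(h), but
# roundoff error grows like O(eps/h), so the best h is intermediate.
f(x) = sum(abs2, x)                 # hypothetical test function
x = collect(1.0:5.0)
dx = ones(5)                        # perturbation direction
exact = 2 * sum(x .* dx)            # ∇f(x)ᵀdx = 2xᵀdx = 30

hs = 10.0 .^ (-1:-1:-15)
errs = [abs((f(x .+ h .* dx) - f(x)) / h - exact) / abs(exact) for h in hs]

for (h, e) in zip(hs, errs)
    println("h = ", h, "   relative error = ", e)
end
```

For this function $f(x + h\,dx) - f(x) = 30h + 5h^2$ exactly, so any error beyond the truncation term $5h/30$ comes from roundoff.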