# Strategies to use ReverseDiff.jl with NamedTuples (or ComponentArrays)

It seems that ReverseDiff.jl can be used with functions that accept NamedTuples, for instance through LogDensityProblems.jl and TransformVariables.jl, but I can’t figure out the details. Could you help me understand how to change my attempt below so that it works with, say, the function `g` instead of just `f`?

Here’s an example showing that LogDensityProblems and TransformVariables work well with ReverseDiff; it also highlights where I’m stuck in my attempt.

``````
using TransformVariables
using LogDensityProblems
using ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x
f(x) = -0.5 * x' * x

el = TransformedLogDensity(as((x = as(Array, 2),)), g);

# my attempt: store a compiled tape and a result buffer alongside the function
struct RAD{F,T,R}
    lp::F
    tape::T
    result::R
end

function RAD(lp, x)
    result = similar(x)  # this is where the error below comes from
    tape = ReverseDiff.compile(ReverseDiff.GradientTape(lp, x))
    RAD(lp, tape, result)
end

x = randn(2);

RAD(g, (x = x,)) # errors
``````

The immediate error is “no method matching similar(::NamedTuple{(:x,), Tuple{Vector{Float64}}})”, but I suspect this is just the first of many errors to follow.

Would you help me understand the strategy used in LogDensityProblems.jl and TransformVariables.jl? I tried reading the code, but was stumped by this function: https://github.com/tpapp/TransformVariables.jl/blob/e6efa6ac266a3bf5d5fd3b26be443bb35391f1c9/src/aggregation.jl#L57

ReverseDiff does not play well with many structs. I would use a different AD if you need to handle such cases.
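Zygote is one AD that differentiates functions of NamedTuples directly and returns the gradient in the same NamedTuple shape. (This is my own illustration; the reply above does not name a specific package.)

``````
using Zygote

g(θ) = -0.5 * θ.x' * θ.x

# Zygote returns the gradient as a NamedTuple with the same field names,
# wrapped in a tuple (one entry per argument of g)
∇θ, = Zygote.gradient(g, (x = [1.0, 2.0],))
``````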

I plan to write the introduction one of these weekends.


The standard approach is to flatten the struct and take the gradient with respect to a vector, then unflatten the gradient as a post-processing step.
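That flatten/unflatten round trip can be sketched by hand for a NamedTuple of vectors. The helpers `flatten` and `unflatten` below are my own names, not from any package; ReverseDiff only ever sees the flat vector.

``````
using ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x

# concatenate all fields into one flat Vector
flatten(θ::NamedTuple) = reduce(vcat, values(θ))

# rebuild a NamedTuple with the same field names and lengths as `template`
function unflatten(template::NamedTuple, v::AbstractVector)
    i = 0
    chunks = map(values(template)) do t
        chunk = v[i+1:i+length(t)]
        i += length(t)
        chunk
    end
    return NamedTuple{keys(template)}(chunks)
end

θ = (x = randn(2),)
# differentiate through unflatten, so the gradient is taken wrt a Vector
∇v = ReverseDiff.gradient(v -> g(unflatten(θ, v)), flatten(θ))
∇θ = unflatten(θ, ∇v)  # gradient back in NamedTuple form
``````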

Is this what you’re looking for?

``````
using ComponentArrays, ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x
θ = ComponentArray(x = randn(2))

∇θ = ReverseDiff.gradient(g, θ)  # same structure as θ, so ∇θ.x works
``````

edit: I’m not really sure what the `TransformedLogDensity` stuff is doing (I’m not really familiar with TransformVariables.jl), but it seems that `logdensity_and_gradient` is just calculating the value and gradient of `g(θ)`. If that’s the case, you don’t really need anything besides `ReverseDiff.gradient`, I think.