It seems like ReverseDiff.jl can be used with functions that accept NamedTuples, for instance via LogDensityProblems.jl and TransformVariables.jl, but I can’t figure out the details. Could you help me understand how to change my attempt below so that it works with ReverseDiff.jl for, say, the function `g` instead of just `f`?
Here’s an example that shows LogDensityProblems and TransformVariables working well with ReverseDiff, and that also highlights where my attempt gets stuck.
```julia
using TransformVariables
using LogDensityProblems
using ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x   # takes a NamedTuple with field `x`
f(x) = -0.5 * x' * x       # takes a plain vector

# Working version: TransformVariables maps a flat vector to (x = ...,),
# and LogDensityProblems differentiates the result with ReverseDiff.
el = TransformedLogDensity(as((x = as(Array, 2),)), g);
adg = ADgradient(:ReverseDiff, el);

# my attempt: a compiled ReverseDiff tape plus preallocated results
struct RAD{F, T, R}
    lp::F
    tape::T
    result::R
end

function RAD(lp, x)
    tape = ReverseDiff.compile(ReverseDiff.GradientTape(lp, (x,)))
    res = map(ReverseDiff.DiffResults.GradientResult, (similar(x), ))
    return RAD(lp, tape, res)
end

x = randn(2);
adf = RAD(f, similar(x));

LogDensityProblems.logdensity_and_gradient(adg, x)   # works
ReverseDiff.gradient!(adf.result, adf.tape, (x,))      # works
RAD(g, (x = x,))                                       # errors
```
The immediate error is `no method matching similar(::NamedTuple{(:x,), Tuple{Vector{Float64}}})`, but I suspect this is only the first of many errors to follow.
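One workaround I’m considering: as far as I can tell, ReverseDiff records tapes over arrays (or tuples of arrays), not NamedTuples, so `g` could be wrapped in a function of a flat vector that rebuilds the NamedTuple internally. A minimal sketch, assuming every field of the NamedTuple is a vector; `unflatten`, `θ0`, `v0`, and `g_flat` are hypothetical names of my own, not part of any package:

```julia
using ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x

# Hypothetical helper (my own, not from a package): rebuild a NamedTuple of
# vectors from one flat vector, using the field names and lengths of `template`.
function unflatten(template::NamedTuple{names}, v::AbstractVector) where {names}
    # offsets[i]+1 : offsets[i+1] is the slice of `v` belonging to field i
    offsets = cumsum([0; collect(map(length, values(template)))])
    pieces = ntuple(i -> v[offsets[i]+1:offsets[i+1]], length(names))
    return NamedTuple{names}(pieces)
end

θ0 = (x = randn(2),)            # a NamedTuple like the one `g` expects
v0 = vcat(values(θ0)...)        # the same numbers as one flat vector

# A function of a plain vector, which ReverseDiff can record on a tape.
g_flat(v) = g(unflatten(θ0, v))

tape = ReverseDiff.compile(ReverseDiff.GradientTape(g_flat, v0))
grad = ReverseDiff.gradient!(similar(v0), tape, v0)   # gradient of -0.5 x'x is -x
```

With this, the `RAD` struct from my attempt should also work as `RAD(g_flat, v0)`, with the flat vector standing in for the NamedTuple.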
Could you help me understand the strategy that LogDensityProblems.jl and TransformVariables.jl use? I tried reading the code, but got stuck on the function at https://github.com/tpapp/TransformVariables.jl/blob/e6efa6ac266a3bf5d5fd3b26be443bb35391f1c9/src/aggregation.jl#L57
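My current reading of the general idea (not confirmed against the linked source): TransformVariables never hands ReverseDiff a NamedTuple. The transformation `t` describes the NamedTuple’s layout, `transform(t, v)` rebuilds the NamedTuple from a flat vector `v` of length `dimension(t)`, and the gradient is taken with respect to `v`. A minimal sketch of that idea, using only `as`, `transform`, and `dimension` from TransformVariables and `ReverseDiff.gradient`; `g_vec` is a name of my own:

```julia
using TransformVariables
using ReverseDiff

g(θ) = -0.5 * θ.x' * θ.x

# The transformation from my example: a NamedTuple with one field `x`,
# an unconstrained vector of length 2 (an identity map).
t = as((x = as(Array, 2),))

n = TransformVariables.dimension(t)   # length of the flat vector
v = randn(n)
θ = transform(t, v)                   # a NamedTuple with field `x`

# Compose: a function of the flat vector that calls `g` on the NamedTuple.
g_vec(w) = g(transform(t, w))

grad = ReverseDiff.gradient(g_vec, v)  # gradient w.r.t. the flat vector
```

As I understand it, `TransformedLogDensity` also adds the log-Jacobian of the transformation; for this identity transformation that term is zero, so `grad` should match the gradient returned by `LogDensityProblems.logdensity_and_gradient(adg, v)`.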
Thanks in advance.