Hey all,
I’m trying to write some manually specialized derivatives for a function so that they will play nicely with ForwardDiff, and I’m running into some issues with second derivatives (no surprises with first derivatives). In particular, I can build an object whose types and values agree with both the value and partials fields of what ForwardDiff produces, but somehow my call to the constructor Dual calls a promote_type function that bonks the equality. Here’s an example of the behavior:
using ForwardDiff
import ForwardDiff: Dual, value, partials, derivative, gradient, Partials, Tag
# Some type parameter getters for clarity below.
getN(::Type{Dual{T,V,N}}) where{T,V,N} = N
getT(::Type{Dual{T,V,N}}) where{T,V,N} = T
getV(::Type{Dual{T,V,N}}) where{T,V,N} = V
getN(x::Partials{N,V}) where{N,V} = N
getV(x::Partials{N,V}) where{N,V} = V
getters = Dict(:N=>getN, :V=>getV, :T=>getT)
# A toy function with its actual derivatives.
myfn(x,y) = x*sin(x)*cos(x+y)
_myfn(x,y) = x*sin(x)*cos(x+y) # to avoid stack overflow in manual version.
myfndx(x,y) = x*cos(x)*cos(x+y) + sin(x)*(cos(x+y) - x*sin(x+y))
myfnd2x(x,y) = -2*(sin(x)*(sin(x+y)+x*cos(x+y)) + cos(x)*(x*sin(x+y)-cos(x+y)))
function myfn(x::Dual{T,V,N}, y) where{T,V<:Dual,N}
    println("test") # confirm this is getting run, like above.
    out = _myfn(x,y) # okay, this is what ForwardDiff returns. I want to recreate it.
    # Try to do this recreation entirely by hand, building the tag, values, and
    # partials.
    fn = _myfn(value(value(x)), y) # function eval
    fndx = myfndx( value(value(x)), y) # first deriv
    fnd2x = myfnd2x(value(value(x)), y) # second deriv
    TT = Tag{typeof(d0),V} # manual tag
    vl = Dual{T}(Dual{getT(V)}(fn, fndx), Dual{getT(V)}(fndx,fnd2x)) # manual values
    pt = Partials((Dual{getT(V)}(fndx, fnd2x),)) # manual partials
    # Comparison. I can't check equality of the partials and values because of
    # floating point annoyances, but these print statements show that they are
    # equal.
    println("\nNote that all of these types and objects agree:")
    println("Compare the partials:")
    po = partials(out)
    @assert typeof(pt) == typeof(po)
    @show pt
    @show po
    # Compare the tags:
    @assert getT(typeof(out)) == TT
    println("Compare the values:")
    ov = values(out)
    @assert typeof(ov) == typeof(vl)
    @show ov
    @show vl
    # So: now make our own manual version of "out". But here's where it gets
    # weird: the type of _out disagrees with the type of out in a serious way,
    # even though we've checked that all components and their types are equal. It
    # seems like the creation changes the type parameter V here. And if I try to
    # manually make sure it doesn't get bonked, I get a MethodError. It looks like
    # promote_type(typeof(val), getV(partials)) is the culprit, but I'm not sure I
    # understand whether it's buggy or not.
    _out = Dual{TT}(vl, pt)
    # Confirmation of the type differences:
    if typeof(Dual{TT}(vl, pt)) != typeof(out)
        for pr in (:N, :T, :V)
            pr_out = getters[pr](typeof(out))
            pr_try = getters[pr](typeof(_out))
            if pr_out != pr_try
                println()
                @warn "Type parameter $pr disagrees:"
                println(pr_out)
                println(pr_try)
            end
        end
    end
    out
end
# I ultimately want this to work for curried versions like
# ForwardDiff.derivative(t->myfn(t, 1.5), 1.1), but I'm trying to get this
# working on its own first.
d0(v) = myfn(v, 1.5)
d1(v) = ForwardDiff.derivative(d0, v)
d2(v) = ForwardDiff.derivative(d1, v)
d2(1.1) # Trigger the check.
If you run this script, you get this output:
test
Note that all of these types and objects agree:
Compare the partials:
pt = Partials(Dual{Tag{typeof(d1),Float64}}(-1.696575598637723,-0.5305605534221305),)
po = Partials(Dual{Tag{typeof(d1),Float64}}(-1.6965755986377231,-0.5305605534221305),)
Compare the values:
ov = Dual{Tag{typeof(d0),Dual{Tag{typeof(d1),Float64},Float64,1}}}(Dual{Tag{typeof(d1),Float64}}(-0.8400321201319014,-1.6965755986377231),Dual{Tag{typeof(d1),Float64}}(-1.6965755986377231,-0.5305605534221305))
vl = Dual{Tag{typeof(d0),Dual{Tag{typeof(d1),Float64},Float64,1}}}(Dual{Tag{typeof(d1),Float64}}(-0.8400321201319014,-1.696575598637723),Dual{Tag{typeof(d1),Float64}}(-1.696575598637723,-0.5305605534221305))
┌ Warning: Type parameter V disagrees:
└ @ Main ~/Scratch/burnerbrowserhome/mwe.jl:65
Dual{Tag{typeof(d1),Float64},Float64,1}
Dual{Tag{typeof(d0),Dual{Tag{typeof(d1),Float64},Float64,1}},Dual{Tag{typeof(d1),Float64},Float64,1},1}
As you can see, the V type parameter in Dual{T,V,N} seems to be getting promoted via the rule here in the source. I could be mistaken here, but considering that out does have that type, it seems like there isn’t any actual incompatibility, and promote_type is perhaps being overzealous. But then, maybe I’m missing something. All of this work is largely from ignorant reverse-engineering, so I could be pretty off the mark.
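One thing that may be worth checking in the script above is the difference between ForwardDiff's value and Base's values. The sketch below is only an illustration: it uses two hypothetical empty tag types (InnerTag and OuterTag, which are not part of ForwardDiff) instead of real Tag objects, and it never triggers any promotion. It builds a nested Dual by hand and checks that value strips exactly one layer, while the generic Base.values fallback returns its argument unchanged.

```julia
using ForwardDiff
import ForwardDiff: Dual, value, partials

# Hypothetical stand-ins for the Tag{typeof(d1),...} and Tag{typeof(d0),...} tags.
struct InnerTag end
struct OuterTag end

# Two first-order duals playing the roles of (fn, fndx) and (fndx, fnd2x).
inner_val = Dual{InnerTag}(1.0, 2.0)
inner_der = Dual{InnerTag}(2.0, 3.0)

# A second-order dual: its value and its single partial are both first-order duals.
nested = Dual{OuterTag}(inner_val, inner_der)

# ForwardDiff.value peels off one layer and returns the inner dual.
@assert value(nested) === inner_val

# Base.values has a generic fallback that returns its argument, so this is
# still the complete nested dual, not its value component.
@assert values(nested) === nested

# The value slot and the partials slot hold the same element type.
@assert typeof(value(nested)) == typeof(partials(nested, 1))
```

If the same thing is happening in the MWE, then ov and vl would both be complete second-order duals rather than value components, which seems consistent with how they print above. Under that reading, Dual{TT}(vl, pt) would be wrapping an already complete Dual in one more layer, and a V parameter equal to typeof(vl) is what one would then expect to see.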
Can somebody advise? I recognize that this MWE is ridiculous and there’s no reason to do this, but I’d appreciate any help in trying to make this work, or if there is a better way to manually intervene and manually give some information to ForwardDiff in this way.
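For comparison, here is a minimal sketch of a different way to hand ForwardDiff a manual derivative, offered as one possibility rather than the recommended approach. It applies the chain rule to a single Dual layer and recurses on value(x) for any inner layers. The names toyfn, toyfn_dx, toyfn_d2x, g0, g1 and g2 are renamed copies of the functions in the MWE so that nothing above is overwritten. The y::AbstractFloat restriction is an assumption that y is a plain constant and is never itself a Dual. Note that in this version only the hand-written first derivative is used: the second derivative comes from ForwardDiff differentiating toyfn_dx, and toyfn_d2x serves only as a reference value.

```julia
using ForwardDiff
import ForwardDiff: Dual, value, partials

# Same toy function and hand-written derivatives as in the MWE, under new names.
toyfn(x, y)     = x*sin(x)*cos(x+y)
toyfn_dx(x, y)  = x*cos(x)*cos(x+y) + sin(x)*(cos(x+y) - x*sin(x+y))
toyfn_d2x(x, y) = -2*(sin(x)*(sin(x+y)+x*cos(x+y)) + cos(x)*(x*sin(x+y)-cos(x+y)))

# Chain rule on one Dual layer. value(x) may itself be a Dual (second
# derivatives), in which case toyfn(xv, y) dispatches to this method again and
# toyfn_dx(xv, y) is evaluated with ordinary Dual arithmetic.
function toyfn(x::Dual{T}, y::AbstractFloat) where {T}
    xv = value(x)
    return Dual{T}(toyfn(xv, y), toyfn_dx(xv, y) * partials(x))
end

# Curried first and second derivatives, mirroring d0, d1 and d2 above.
g0(v) = toyfn(v, 1.5)
g1(v) = ForwardDiff.derivative(g0, v)
g2(v) = ForwardDiff.derivative(g1, v)

# Compare against the hand-written derivatives at the same point as the MWE.
@assert isapprox(g1(1.1), toyfn_dx(1.1, 1.5))
@assert isapprox(g2(1.1), toyfn_d2x(1.1, 1.5))
```

Because the tag T is taken from the incoming Dual and reused, this sketch never constructs a Tag or a nested Dual by hand, which sidesteps the question of what the V parameter should be.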