Using ForwardDiff for Symbolics.derivative() definition

I’ve only been playing with MTK.jl for a day or two, but I’m really impressed: the workflow here is so much nicer than using a dedicated acausal modelling language! I’m running into situations where judicious function registration is advisable, and I’d like to understand how flexibly derivatives can be specified. ModelingToolkit.jl (through Symbolics.jl) lets you register functions via @register_symbolic and attach derivative information to them. This is handy when a function is too big and scary to be converted entirely into symbolic form. For example, we can do the following:

```julia
using Symbolics

function foo_f(t)
    return t^2
end
@variables t
@register_symbolic foo_f(t)
Symbolics.derivative(::typeof(foo_f), args::NTuple{1,Any}, ::Val{1}) = 2*args[1]
dfdt = Symbolics.derivative(foo_f(t), t) # 2t
substitute(dfdt, t => 2) # 4
```
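As a quick check of what registration buys you, here is a sketch (assuming a recent Symbolics.jl; `expr`, `d` and `val` are just illustrative names). It shows that `foo_f(t)` stays an opaque call instead of being traced to `t^2`, and that the three-argument rule is picked up by the chain rule inside a larger expression:

```julia
using Symbolics

foo_f(x) = x^2
@variables t
@register_symbolic foo_f(x)
# Hand-written rule for the first (and only) argument of foo_f
Symbolics.derivative(::typeof(foo_f), args::NTuple{1,Any}, ::Val{1}) = 2 * args[1]

# Registration keeps foo_f(t) as an opaque call instead of expanding its body to t^2
expr = foo_f(t)

# expand_derivatives applies the rule above via the chain rule: cos(foo_f(t)) times 2t
d = Symbolics.derivative(sin(expr), t)

# substitute folds the registered call numerically once t is a number
val = Symbolics.value(substitute(d, t => 2.0))
```

This composability is the main reason to prefer the three-argument `derivative(f, args, ::Val{i})` form: it is what `expand_derivatives` looks up whenever `foo_f` appears anywhere in an expression.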

I’ve been playing around with automatically calculating derivatives via ForwardDiff.jl, but I’m running into some issues. Directly defining a new method for Symbolics.derivative works fine:

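```julia
using Symbolics, ForwardDiff

function foo_f2(x)
    return x^2
end
@variables t
@register_symbolic foo_f2(t)
# Two-argument method: only hit by a direct call such as derivative(foo_f2, t)
Symbolics.derivative(::typeof(foo_f2), t::Num) = ForwardDiff.derivative(foo_f2, t)
dfdt = Symbolics.derivative(foo_f2, t) # 2t
substitute(dfdt, t => 2) # 4
```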

But doing the proper thing as described in the docs gives an error:

```julia
using Symbolics, ForwardDiff

function foo_f3(x)
    return x^2
end
@variables t
@register_symbolic foo_f3(t)
Symbolics.derivative(::typeof(foo_f3), args::NTuple{1,Any}, ::Val{1}) = ForwardDiff.derivative(foo_f3, args[1])
dfdt_3 = Symbolics.derivative(foo_f3, t)
substitute(dfdt_3, t => 2)
```

```
ERROR: MethodError: no method matching derivative(::typeof(foo_f3), ::SymbolicUtils.BasicSymbolic{Real})
The function `derivative` exists, but no method is defined for this combination of argument types.
```

I’m curious whether ForwardDiff and Symbolics.derivative can be made to play nicely together for a simple example like this. Also, will the second approach above (directly overriding the helper function Symbolics.derivative(::typeof(foo_f2), t::Num)) work, or is it asking for trouble later on?
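One way to probe the second question is to differentiate `foo_f2(t)` as an expression instead of passing `foo_f2` itself. This is a sketch that assumes the lookup path visible in the stack trace further down: `expand_derivatives` asks for rules through the three-argument `derivative(f, args, ::Val{i})` method, so the two-argument override is never consulted there. Without a three-argument rule, the expression form is expected to stay an unexpanded `Differential` term rather than become `2t`.

```julia
using Symbolics, ForwardDiff

foo_f2(x) = x^2
@variables t
@register_symbolic foo_f2(x)
# The two-argument override from the question
Symbolics.derivative(::typeof(foo_f2), t::Num) = ForwardDiff.derivative(foo_f2, t)

# Direct call: dispatches to the override above and gives 2t
direct = Symbolics.derivative(foo_f2, t)

# Expression call: goes through expand_derivatives, which does not use the override
via_expr = Symbolics.derivative(foo_f2(t), t)
```

So the override is harmless for direct calls, but any expression that merely contains `foo_f2(t)` will not see it.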

What about dfdt_3 = Symbolics.derivative(foo_f3(t), t)?

Same error, sorry.

```julia
using Symbolics, ForwardDiff

function foo_f3(x)
    return x^2
end
@variables t
@register_symbolic foo_f3(t)
Symbolics.derivative(::typeof(foo_f3), args::NTuple{1,Any}, ::Val{1}) = ForwardDiff.derivative(foo_f3, args[1])
dfdt_3 = Symbolics.derivative(foo_f3(t), t)
substitute(dfdt_3, t => 2)
```
```
ERROR: MethodError: no method matching derivative(::typeof(foo_f3), ::SymbolicUtils.BasicSymbolic{Real})
The function `derivative` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  derivative(::Any, ::Complex)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/derivative.jl:76
  derivative(::Any, ::AbstractArray)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/derivative.jl:75
  derivative(::F, ::AbstractArray, ::Real, ::ForwardDiff.DerivativeConfig{T}, ::Val{CHK}) where {F, T, CHK}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/derivative.jl:25
  ...

Stacktrace:
 [1] derivative(::typeof(foo_f3), args::Tuple{SymbolicUtils.BasicSymbolic{Real}}, ::Val{1})
   @ Main ~/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Misc/codeplay/MToolkit2.jl:142
 [2] derivative_idx(O::SymbolicUtils.BasicSymbolic{Real}, idx::Int64)
   @ Symbolics ~/.julia/packages/Symbolics/YbNrd/src/diff.jl:347
 [3] expand_derivatives(O::SymbolicUtils.BasicSymbolic{Real}, simplify::Bool; occurrences::Nothing)
   @ Symbolics ~/.julia/packages/Symbolics/YbNrd/src/diff.jl:262
 [4] expand_derivatives(O::SymbolicUtils.BasicSymbolic{Real}, simplify::Bool)
   @ Symbolics ~/.julia/packages/Symbolics/YbNrd/src/diff.jl:183
 [5] derivative(O::Num, var::Num; simplify::Bool)
   @ Symbolics ~/.julia/packages/Symbolics/YbNrd/src/diff.jl:445
 [6] derivative(O::Num, var::Num)
   @ Symbolics ~/.julia/packages/Symbolics/YbNrd/src/diff.jl:441
 [7] top-level scope
   @ ~/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Misc/codeplay/MToolkit2.jl:143
```
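Frame [1] of the stack trace shows why this fails: inside the rule, `args[1]` arrives as an unwrapped `SymbolicUtils.BasicSymbolic{Real}`, which is not a subtype of `Real`, so none of ForwardDiff's `derivative` methods match. In the `foo_f2` version the argument was a `Num`, which is a `Real`. One possible workaround, sketched here and not checked against every Symbolics version, is to re-wrap the argument in `Num` before calling ForwardDiff and unwrap the result with `Symbolics.value`. Note that this still traces the derivative into symbolic form (here `2t`), which may not be what you want for a large function.

```julia
using Symbolics, ForwardDiff

foo_f3(x) = x^2
@variables t
@register_symbolic foo_f3(x)

# args[1] is a BasicSymbolic (not a Real), so wrap it in Num for ForwardDiff,
# then unwrap the Num result because rules work on unwrapped expressions
Symbolics.derivative(::typeof(foo_f3), args::NTuple{1,Any}, ::Val{1}) =
    Symbolics.value(ForwardDiff.derivative(foo_f3, Num(args[1])))

dfdt_3 = Symbolics.derivative(foo_f3(t), t)   # traced through ForwardDiff to 2t
val = Symbolics.value(substitute(dfdt_3, t => 2))
```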

I think you need to define a wrapper function around ForwardDiff.derivative, register it, and use that instead:

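```julia
using Symbolics, ForwardDiff
# Registered wrapper: Symbolics records a call to it instead of tracing into ForwardDiff
AD_derivative(f, x) = ForwardDiff.derivative(f, x)
foo_f3(x) = x^2
@register_symbolic AD_derivative(f::Function, x)
@register_symbolic foo_f3(x)
Symbolics.derivative(::typeof(foo_f3), args::NTuple{1,Any}, ::Val{1}) = AD_derivative(foo_f3, args[1])
```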

You want to stop Symbolics from tracing into both the original function and its derivative.
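End to end, the wrapper pattern looks like the sketch below (assuming current Symbolics.jl and ForwardDiff.jl; `AD_derivative` is the helper name from the reply above, not a library function). Because both `foo_f3` and `AD_derivative` are registered, the derivative stays as an opaque `AD_derivative(foo_f3, t)` call, and ForwardDiff only runs once numbers are substituted in.

```julia
using Symbolics, ForwardDiff

# Opaque wrapper: registered, so Symbolics records a call instead of tracing ForwardDiff
AD_derivative(f, x) = ForwardDiff.derivative(f, x)
@register_symbolic AD_derivative(f::Function, x)

foo_f3(x) = x^2
@register_symbolic foo_f3(x)
Symbolics.derivative(::typeof(foo_f3), args::NTuple{1,Any}, ::Val{1}) = AD_derivative(foo_f3, args[1])

@variables t
dfdt_3 = Symbolics.derivative(foo_f3(t), t)   # an unevaluated AD_derivative(foo_f3, t) call

# With a number substituted in, the call is folded and ForwardDiff evaluates it
val = Symbolics.value(substitute(dfdt_3, t => 2))

# The rule also composes with the chain rule: mathematically 3 * AD_derivative(foo_f3, 3t)
d_scaled = Symbolics.derivative(foo_f3(3t), t)
val_scaled = Symbolics.value(substitute(d_scaled, t => 1))
```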


Chef’s Kiss
Thanks!

And we could add an option to make this automatic. I don’t know if I’ve talked with @cryptic.ax about it yet, but it has been on my mind to integrate with DifferentiationInterface and allow an ADType to be passed to set up a DI call.

Interesting. So something like

```julia
@register_symbolic foo_f(t) derivative = AutoForwardDiff()
```

Which registers the function, registers a call to DI for it, and defines the derivative?
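Until something like that exists (the `derivative = AutoForwardDiff()` keyword is only a proposal in this thread), the same three steps can be approximated by hand with DifferentiationInterface.jl. This is a sketch under assumptions: `di_derivative`, `AD_BACKEND` and `foo_f4` are hypothetical names, the backend is fixed to `AutoForwardDiff()` (which DifferentiationInterface re-exports from ADTypes.jl), and `ForwardDiff` must be loaded so the matching DI extension is available.

```julia
using Symbolics
import DifferentiationInterface as DI
import ForwardDiff  # loads DI's ForwardDiff extension

# Hypothetical opaque wrapper around a DI call with a fixed backend
const AD_BACKEND = DI.AutoForwardDiff()
di_derivative(f, x) = DI.derivative(f, AD_BACKEND, x)
@register_symbolic di_derivative(f::Function, x)

# Register the function and point its derivative rule at the DI wrapper
foo_f4(x) = x^2
@register_symbolic foo_f4(x)
Symbolics.derivative(::typeof(foo_f4), args::NTuple{1,Any}, ::Val{1}) = di_derivative(foo_f4, args[1])

@variables t
dfdt_4 = Symbolics.derivative(foo_f4(t), t)   # an unevaluated di_derivative(foo_f4, t) call
val = Symbolics.value(substitute(dfdt_4, t => 3))
```

Switching backends would then only mean changing `AD_BACKEND`, provided the function is supported by the chosen backend.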


Yes, exactly. It would also be really nice if registration could check whether a function is, say, compilable with Enzyme, and default to Enzyme when no backend is given and the function compiles well. But that check could add a bit of overhead.
