Compile FastDifferentiation derivatives one time only

Is it, though? The type FunctionWithPreppedGradient is owned by the user, so I don't think it's type piracy in the strict sense. Depending on the dispatch rules inside DifferentiationInterface (of which you are the expert :slight_smile: ), it could cause method-ambiguity problems, and perhaps the signature needs to be more specific.

In this case it is convenient to define this additional method: the resulting code would then work regardless of whether the gradient has been prepared or not (if I understand the documentation correctly, value_and_gradient does not require the prepared object).
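For reference, here is a minimal sketch of the two call forms being discussed, with and without preparation. The test function and backend choice are arbitrary examples, not from the thread:

```julia
using DifferentiationInterface
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work

f(x) = sum(abs2, x)
backend = AutoForwardDiff()
x = [1.0, 2.0, 3.0]

# Without preparation: works directly, but setup cost is paid on every call.
y1, g1 = value_and_gradient(f, backend, x)

# With preparation: setup cost is paid once and amortized over repeated calls.
prep = prepare_gradient(f, backend, x)
y2, g2 = value_and_gradient(f, prep, backend, x)
```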

Thinking this over while typing, I agree that it would probably be better to define a custom function and just forward to DifferentiationInterface.jl.
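A sketch of what "define a custom function and forward to DifferentiationInterface.jl" could look like. The wrapper type, its fields, and the operator name are all hypothetical; the point is that no method is added to DifferentiationInterface itself, so no ambiguities can arise:

```julia
using DifferentiationInterface
import ForwardDiff  # example backend; FastDifferentiation would work the same way

# Hypothetical user-owned wrapper bundling a function with its preparation.
struct GradientWrapper{F,B,P}
    f::F
    backend::B
    prep::P
end

function GradientWrapper(f, backend, x)
    prep = prepare_gradient(f, backend, x)
    return GradientWrapper(f, backend, prep)
end

# Custom operator that forwards to DI; DI's own methods are untouched.
my_value_and_gradient(g::GradientWrapper, x) =
    value_and_gradient(g.f, g.prep, g.backend, x)

g = GradientWrapper(x -> sum(abs2, x), AutoForwardDiff(), zeros(3))
y, grad = my_value_and_gradient(g, [1.0, 2.0, 3.0])
```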


Sorry, yes: not piracy, but it will definitely cause ambiguities, because DI specializes on the backend rather than the function, whereas this wrapper goes the other way.


I think you could design FunctionWithPreppedGradient so that it lazily saves the prepared gradient/Hessian/etc. on the first call. Then it would not be very cumbersome, I think. You'd just have to wrap each of the functions in that struct, and the rest of the code would look like the Tensorial.jl code, no? I don't think you could get any closer, at least not easily.
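A sketch of the lazy variant under the same hypothetical naming. The preparation is computed on the first call and cached; note that the untyped cache field introduces a type instability, and the struct is not thread-safe as written:

```julia
using DifferentiationInterface
import ForwardDiff  # example backend

# Hypothetical wrapper that prepares lazily on first use.
mutable struct LazyGradient{F,B}
    f::F
    backend::B
    prep::Any  # untyped cache: `nothing` until the first call (type-unstable)
end
LazyGradient(f, backend) = LazyGradient(f, backend, nothing)

# Make the wrapper callable: prepare once, then always forward with `prep`.
function (g::LazyGradient)(x)
    if g.prep === nothing
        g.prep = prepare_gradient(g.f, g.backend, x)
    end
    return value_and_gradient(g.f, g.prep, g.backend, x)
end

g = LazyGradient(x -> sum(abs2, x), AutoForwardDiff())
y, grad = g([1.0, 2.0, 3.0])  # first call triggers preparation
```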

But if your version with Tensorial's custom differentiation operators works well, maybe there's no need for more.

Sorry for the late reply. If I understand correctly, I would still need to pass a FunctionWithPreppedGradient object instead of prep to each function. In that case, I don't think it's a fundamental solution: if I want to call func4 from within func3, for example, I would need to modify all the other functions as well. That structure doesn't seem good to me.

With a static array (or tensor) you can deduce the size from the type, so maybe a macro inside a generated function could be enough? I don't know whether that's possible; just grasping at straws here.

Thank you as always for your thoughtful suggestions. However, to build the framework I want, it seems essential to evaluate the given function f at compile time inside derivative(f, x), for example, and that appears to be infeasible. Alternatively, the preparation step would have to return a bitstype, as ForwardDiff.jl effectively does in the SArray case.