# Problem with using ODE solvers with autodifferentiation

I am using the DifferentialEquations package to simulate Hodgkin-Huxley neurons. Certain parameters inside the ODE function depend on the state variables, so I compute them by calling a function from within the ODE function, e.g.

```julia
V = u[1]
param1, param2 = my_function(V)   # parameters that depend on the state V
du[1] = (V - param1) / param2
```

To speed up the code, I have made `my_function` accept only `Float64` via a type assertion:

```julia
function my_function(V::Float64)
    # operations
end
```

The problem is that with solvers that use autodiff (e.g. `QNDF()` or `Rodas4()`), I get an error because a `Dual` number is passed instead of a `Float64`. The solution here describes two workarounds. The first is obvious (turn off autodiff), but I don't understand the second. Could anyone please explain it to me? Also, if I just turn off autodiff, the accuracy of the solution will take a hit, right?
Thanks

Restricting the input types has no effect on performance. Just remove the restriction and you'll be fine.

https://docs.julialang.org/en/v1/manual/performance-tips/index.html

See the “Type declarations” section.
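A minimal sketch of why the `Float64` restriction clashes with autodiff, assuming ForwardDiff.jl (the package whose `Dual` numbers appear in the error). `strict_params`, `generic_params`, `rate_strict` and `rate_generic` are hypothetical stand-ins for `my_function` and the ODE right-hand side, not the poster's actual code.

```julia
using ForwardDiff

# Hypothetical stand-ins for my_function: same formula, different signatures
strict_params(V::Float64) = (V / 2, 1.0 + V^2)  # rejects ForwardDiff.Dual
generic_params(V)         = (V / 2, 1.0 + V^2)  # accepts Float64 and Dual

function rate_strict(V)
    param1, param2 = strict_params(V)
    return (V - param1) / param2
end

function rate_generic(V)
    param1, param2 = generic_params(V)
    return (V - param1) / param2
end

# Differentiating passes a Dual number (value + partial) through the function,
# so the Float64-only method has no match and a MethodError is thrown
err = try
    ForwardDiff.derivative(rate_strict, 0.0)
    nothing
catch e
    e
end
println(typeof(err))

# The unrestricted version works with both number types
println(rate_generic(0.0))
d = ForwardDiff.derivative(rate_generic, 0.0)
println(d)
```

Removing the annotation (or widening it to `::Real`, since `Dual <: Real`) is enough for the solver's Jacobian computation to go through.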


Actually, it does seem to make a difference. Please see the attached screenshot:
I used `QNDF()` from the DifferentialEquations package and compared two versions: without any type assertions on the inputs of the functions computing the parameters (hence with `autodiff=true`) vs. with `Float64` type assertions (hence `autodiff=false`). With autodiff turned on the run took longer, but the solution (the neuron membrane voltage in the plot window) did not drift off (blue = autodiff, orange = no autodiff). So, assuming autodiff gives the better solution, I guess I have to compromise on time and memory efficiency.
Thanks

Type assertions in function signatures don’t do anything for performance. Functions autospecialize on types. That’s the very essence of why Julia works.
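A rough way to check this using only Base: `Base.return_types` reports what type inference concludes for a given argument signature. The two functions below are hypothetical examples with identical bodies. Matching concrete results for a `Float64` call are consistent with the unannotated method being compiled into a `Float64`-specialized instance; comparing `@code_native` output would be a more direct check.

```julia
# Hypothetical example: identical bodies, with and without an argument annotation
f_annotated(V::Float64) = (V - V / 2) / (1.0 + V^2)
f_generic(V)            = (V - V / 2) / (1.0 + V^2)

# Inferred return types for a Float64 argument; for f_generic, Julia compiles
# a method instance specialized on Float64 when it is called with a Float64
rt_annotated = Base.return_types(f_annotated, (Float64,))
rt_generic   = Base.return_types(f_generic, (Float64,))
println(rt_annotated, " ", rt_generic)
```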

Thanks for clarifying. So is it using/not using autodiff that makes the difference here?

Yes. With autodiff it has to build a second set of caches for the `Dual` numbers, and those need `(m+1)*n` entries for `m` partials (where `n` is the number of states) instead of `n`.
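A quick check of the `(m+1)*n` figure, assuming ForwardDiff.jl's `Dual` type: each `Dual` stores one value plus `m` partials, so a cache of `n` of them takes `m + 1` times the memory of a plain `Float64` cache. Using `n = 4` (the four Hodgkin-Huxley states V, m, h, n) and one partial per state are assumptions for illustration; the chunk size the solver actually picks may differ.

```julia
using ForwardDiff

nstates  = 4          # assumed: V, m, h, n of a single Hodgkin-Huxley compartment
npartial = nstates    # assumed chunk size: one partial per state

# Dual number type carrying npartial partials (the tag is irrelevant for sizing)
DualT = ForwardDiff.Dual{Nothing, Float64, npartial}

float_cache = Vector{Float64}(undef, nstates)  # what a solve without AD needs
dual_cache  = Vector{DualT}(undef, nstates)    # extra cache used for the Jacobian

println(sizeof(float_cache))   # 8 * nstates bytes
println(sizeof(dual_cache))    # 8 * (npartial + 1) * nstates bytes
```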

Thanks a lot for the help!

You can probably use `StaticArrays.jl` to reduce the difference in timing, since you seem to have a low-dimensional ODE.
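A minimal sketch of the out-of-place `StaticArrays.jl` style, assuming OrdinaryDiffEq.jl (the ODE solvers inside DifferentialEquations) and a made-up two-variable model rather than the full Hodgkin-Huxley equations. `my_params`, `rhs` and the model are hypothetical. The right-hand side returns a new `SVector` instead of writing into `du`, which for small systems typically avoids heap allocations, and the parameter function has no type restriction, so `Rodas4()` can keep autodiff on.

```julia
using OrdinaryDiffEq, StaticArrays

# Hypothetical parameter function standing in for my_function; left unannotated
# so it accepts the ForwardDiff.Dual numbers used to build the Jacobian
my_params(V) = (V / 2, 1.0 + V^2)

# Out-of-place right-hand side: takes (u, p, t) and returns a new SVector
function rhs(u, p, t)
    V, w = u
    param1, param2 = my_params(V)
    dV = -(V - param1) / param2     # made-up voltage equation
    dw = p[1] * (V - w)             # made-up slow recovery variable
    return SVector(dV, dw)
end

u0   = SVector(1.0, 0.0)
prob = ODEProblem(rhs, u0, (0.0, 10.0), (0.5,))
sol  = solve(prob, Rodas4())        # autodiff is left at its default (on)
println(sol.u[end])
```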


Thanks for the advice. I am new to Julia and am learning how to use StaticArrays… hoping for the best.