Questions about modeling a Fokker-Planck equation

I’m interested in solving a certain Fokker-Planck equation. The goal is to learn an underlying potential energy function given a time series of the evolution of a distribution.

I’m planning to use the approach of this paper. Briefly, they discretize the distribution in space and set up a master equation describing hops between adjacent grid points where the rate is governed by the difference in energy. This results in physically and thermodynamically consistent evolution of the distribution.
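As a rough sketch of that construction (not the paper's exact rates, which are not reproduced here), the following plain-Python snippet builds a rate matrix for nearest-neighbour hops on a 1D grid. The rate form `k0 * exp(-(U[j] - U[i]) / (2 kT))`, the double-well potential, and all names are illustrative assumptions. Any rate form that satisfies detailed balance with respect to the Boltzmann weights would serve equally well.

```python
import math

def build_rate_matrix(U, k0=1.0, kT=1.0):
    """Dense generator R for nearest-neighbour hops on a 1D grid.

    R[j][i] is the rate of hopping from site i to site j, so columns are
    sources. The rate form below is an illustrative assumption.
    """
    n = len(U)
    R = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in (i - 1, i + 1):
            if 0 <= j < n:  # reflecting ends: no hops off the grid
                rate = k0 * math.exp(-(U[j] - U[i]) / (2.0 * kT))
                R[j][i] += rate  # gain at j from i
                R[i][i] -= rate  # matching loss at i
    return R

# Hypothetical double-well potential on 21 grid points in [-2, 2].
xs = [-2.0 + 0.2 * i for i in range(21)]
U = [(x * x - 1.0) ** 2 for x in xs]
R = build_rate_matrix(U)
n = len(U)

# Each column sums to zero, so total probability is conserved.
col_sums = [sum(R[i][j] for i in range(n)) for j in range(n)]

# Detailed balance makes the Boltzmann distribution stationary: R p_eq = 0.
Z = sum(math.exp(-u) for u in U)
p_eq = [math.exp(-u) / Z for u in U]
residual = [sum(R[i][j] * p_eq[j] for j in range(n)) for i in range(n)]
```

The two checks at the end (`col_sums` and `residual` both close to zero) are the discrete counterparts of conservation and thermodynamic consistency.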

To predict the dynamics at some time t, they use the matrix exponential:

p(x, t) = \exp(R(t-t_0)) p(x, t_0).

This has the nice property of ensuring normalization and positivity given a p(x, t_0) with those properties. This is critically important, as p is a probability density.
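A short derivation of why this holds, assuming R is a valid generator (non-negative off-diagonal entries and columns that sum to zero, which is what the hopping construction should give). Here \mathbf{1} is the vector of ones, \tau = t - t_0, and \lambda is any constant at least as large as the largest |R_{ii}|:

```latex
% Normalization: columns of R sum to zero
\mathbf{1}^\top R = 0
\;\Rightarrow\;
\frac{d}{dt}\,\mathbf{1}^\top p(t) = \mathbf{1}^\top R\, p(t) = 0 .

% Positivity: shift R so that every entry is non-negative
e^{R\tau} = e^{-\lambda\tau}\, e^{\tau (R + \lambda I)},
\qquad
\lambda \ge \max_i |R_{ii}| .
```

Since R + \lambda I has only non-negative entries, every term of the power series for e^{\tau(R + \lambda I)} is entrywise non-negative, and so is e^{R\tau}. A non-negative p(x, t_0) therefore stays non-negative.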

That said, I have two concerns about this.

First, the rate matrix R will be very sparse, but its exponential will be dense, resulting in a huge blow-up in memory usage and limiting problem sizes.

Second, my understanding is that the matrix exponential uses BLAS. Will it be possible to differentiate through this? I.e., is there an adjoint defined?

Alternatively, I could imagine solving the master equation using a standard ODE routine. Is there any way to ensure normalization and positivity? Is there a preferred solver for this type of problem?
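One known sufficient condition, sketched here for the simplest possible integrator: an explicit Euler step p + dt * R p conserves total probability exactly and keeps p non-negative whenever dt times the largest total exit rate is at most 1, because I + dt * R is then a column-stochastic matrix. The code below is a minimal plain-Python illustration. The `up`/`down` rate arrays and the 5-site grid are hypothetical, and this is not a recommendation of Euler over an adaptive or implicit solver.

```python
def apply_generator(p, up, down):
    """Return R @ p for a 1D nearest-neighbour master equation.

    up[i] is the rate i -> i+1 and down[i] the rate i+1 -> i.
    Writing the update in flux form makes conservation automatic.
    """
    dp = [0.0] * len(p)
    for i in range(len(p) - 1):
        flux = up[i] * p[i] - down[i] * p[i + 1]  # net flow i -> i+1
        dp[i] -= flux
        dp[i + 1] += flux
    return dp

def max_exit_rate(up, down):
    """Largest total rate of leaving any single site."""
    exits = [0.0] * (len(up) + 1)
    for i in range(len(up)):
        exits[i] += up[i]        # site i can hop right
        exits[i + 1] += down[i]  # site i+1 can hop left
    return max(exits)

def euler_step(p, up, down, dt):
    """One explicit Euler step; refuses steps that could break positivity."""
    if dt * max_exit_rate(up, down) > 1.0:
        raise ValueError("dt too large: positivity is no longer guaranteed")
    dp = apply_generator(p, up, down)
    return [pi + dt * di for pi, di in zip(p, dp)]

# Hypothetical rates on a 5-site grid, biased to the right.
up = [2.0, 2.0, 2.0, 2.0]
down = [1.0, 1.0, 1.0, 1.0]
p = [1.0, 0.0, 0.0, 0.0, 0.0]
dt = 0.5 / max_exit_rate(up, down)  # safety factor of 2 below the bound
for _ in range(200):
    p = euler_step(p, up, down, dt)
```

The step-size bound scales like the inverse of the fastest hopping rate, so for fine grids this restriction can become expensive, which is one reason stiff or exponential integrators are often considered instead.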

Finally, setting up the rate matrix likely involves all sorts of indexing gymnastics, as the discrete grid needs to be flattened into a vector to express the problem in matrix form. Are there any routines to help make this less error-prone?
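One way to keep the bookkeeping contained is to hide the flattening in a single helper and generate all (source, destination) pairs from it, so the rate-matrix code never does index arithmetic directly. Below is a minimal plain-Python sketch for a 2D grid with reflecting boundaries. The function names are hypothetical, and in Julia the built-in `CartesianIndices` and `LinearIndices` play the same role.

```python
def flat_index(i, j, ny):
    """Map grid coordinates (i, j) to a position in the flattened vector."""
    return i * ny + j

def neighbor_pairs(nx, ny):
    """Yield (src, dst) flat-index pairs for every directed nearest-neighbour
    hop on an nx-by-ny grid. Hops that would leave the grid are skipped."""
    for i in range(nx):
        for j in range(ny):
            for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                ii, jj = i + di, j + dj
                if 0 <= ii < nx and 0 <= jj < ny:
                    yield flat_index(i, j, ny), flat_index(ii, jj, ny)

# Example: all hops on a 3-by-4 grid.
pairs = list(neighbor_pairs(3, 4))
```

Each pair then contributes one off-diagonal entry (a gain at `dst`) and one diagonal decrement (a loss at `src`), which maps directly onto a coordinate-format sparse matrix constructor.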


Just do what’s done in the Optimizing DiffEq Code tutorial, with upwinding on the advection term, and it should be rather simple to solve. Using exponentials isn’t much better, and you’d want to use ExponentialUtilities.jl etc. instead of straight matrix exponentiation if you want to get any speed.
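To illustrate why straight matrix exponentiation is avoidable, here is a minimal plain-Python sketch that computes exp(R t) p using only products R v, so the dense exponential is never formed. It uses uniformization (a Poisson-weighted sum of stochastic steps), which is a swapped-in technique chosen for brevity and is not what ExponentialUtilities.jl does internally. Krylov-based routines are the usual production choice. The function name, the two-state test system, and the parameters are all assumptions for the example.

```python
import math

def expm_action(apply_R, p0, t, lam, tol=1e-12, max_terms=100_000):
    """Approximate exp(R t) @ p0 by uniformization.

    apply_R(v) must return R @ v; lam must be >= the largest exit rate,
    so that I + R/lam is a column-stochastic matrix.
    """
    if lam * t > 700.0:
        raise ValueError("lam * t too large for this naive sketch; use substeps")
    v = list(p0)
    weight = math.exp(-lam * t)          # Poisson weight for k = 0
    out = [weight * x for x in v]
    total = weight
    for k in range(1, max_terms + 1):
        if 1.0 - total <= tol:           # remaining Poisson tail is negligible
            break
        Rv = apply_R(v)
        v = [x + y / lam for x, y in zip(v, Rv)]  # v <- (I + R/lam) v
        weight *= lam * t / k
        out = [o + weight * x for o, x in zip(out, v)]
        total += weight
    return out

# Two-state check with a known closed-form solution.
a, b = 2.0, 1.0  # rate 0 -> 1 and rate 1 -> 0

def apply_R(v):
    return [-a * v[0] + b * v[1], a * v[0] - b * v[1]]

t = 0.7
p_t = expm_action(apply_R, [1.0, 0.0], t, lam=max(a, b))
p0_exact = b / (a + b) + (1.0 - b / (a + b)) * math.exp(-(a + b) * t)
```

Every term in the sum is a non-negative weight times a probability vector, so positivity holds by construction and normalization holds up to the truncation tolerance.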

Thanks for the reply.

After doing some more research, the Chang-Cooper discretization (also called exponential upwinding) seems to be a standard approach to this problem. It maintains conservation and positivity, which is critical for this application.
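For reference, here is a minimal sketch of the hopping rates this discretization produces in 1D, assuming the equation has the form ∂p/∂t = ∂/∂x [ D (∂p/∂x + p U'(x)/kT) ] with constant D on a uniform grid of spacing h. The Chang-Cooper interface weights can be rearranged into rates D/h² · B(±w), where B(w) = w / (exp(w) − 1). Using the potential difference across each cell for w, in place of h U'/kT at the midpoint, is an additional assumption made here so that the discrete Boltzmann distribution is an exact steady state. Function names and the test potential are hypothetical.

```python
import math

def bernoulli(w):
    """B(w) = w / (exp(w) - 1), with B(0) = 1 handled explicitly."""
    if w == 0.0:
        return 1.0
    return w / math.expm1(w)  # expm1 keeps accuracy for small |w|

def chang_cooper_rates(U, h, D=1.0, kT=1.0):
    """Return (up, down): up[i] is the rate i -> i+1, down[i] the rate i+1 -> i.

    Note: math.expm1 overflows for |w| above roughly 700; this sketch does
    not guard against that.
    """
    up, down = [], []
    for i in range(len(U) - 1):
        w = (U[i + 1] - U[i]) / kT  # dimensionless energy step across the cell
        up.append(D / h**2 * bernoulli(w))
        down.append(D / h**2 * bernoulli(-w))
    return up, down

# Hypothetical double-well potential on 21 grid points in [-2, 2].
h = 0.2
xs = [-2.0 + h * i for i in range(21)]
U = [(x * x - 1.0) ** 2 for x in xs]
up, down = chang_cooper_rates(U, h)

# Detailed balance: up[i] * exp(-U[i]) should equal down[i] * exp(-U[i+1]).
imbalance = [
    up[i] * math.exp(-U[i]) - down[i] * math.exp(-U[i + 1])
    for i in range(len(up))
]
```

All rates are strictly positive and enter the update in flux form, which is where the conservation and positivity properties come from.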

I should be able to implement this following your tutorial link.

One more question: I’m ultimately trying to learn the potential function, so I’ll need to backpropagate through all of this. In the past, I’ve had lots of difficulty with this whenever in-place operations are used. Is this something that I should be concerned about?