I am experimenting with Enzyme.jl to get the derivative of a numerical function with respect to its initial condition. The underlying system is solved using DifferentialEquations.jl. For example:
```julia
using DifferentialEquations
using Enzyme

"""
The differential equation
"""
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8/3) * u[3]
    nothing
end

"""
Given initial conditions a, b, c, integrate the system lorenz!()
"""
function timestepping(a, b, c)
    u0 = [a, b, c]
    tspan = (0.0, 100.0)
    prob = ODEProblem(lorenz!, u0, tspan)
    ode_solution = solve(prob, Tsit5())
    return ode_solution
end

"""
An arbitrary loss function that returns the final value of the `a` variable at the end of the integration
"""
function loss_function(a, b, c)
    # Do an integration
    solution = timestepping(a, b, c)
    # Get just the last element of the first variable as an arbitrary loss
    critical_value = last(solution[1, :])
    return critical_value
end

a = 1.0
b = 2.0
c = 3.0

# Just run once, make sure everything works OK
first_attempt = loss_function(a, b, c)

# Now try to get a derivative
second_attempt = Enzyme.autodiff(Forward, loss_function, Duplicated, Duplicated(a, 1.0), Const(b), Const(c))
```
That is, I am trying to get the derivative with respect to the initial condition `a`, holding all other variables constant.
However, this throws `LoadError: Tuple{DataType} is not a dispatch tuple`, which seems to be related to the `GPUCompiler.specialization_id(job)` call. Can anyone offer some guidance on where I am going wrong here? Thanks!
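One way to narrow this down is to check the same forward-mode call pattern on a plain scalar function with no DifferentialEquations.jl involved. The sketch below uses a hypothetical `toy` function with a known derivative and, unlike the call above, passes no return activity; the exact shape of the returned tuple has changed across Enzyme.jl versions, so treat the indexing as an assumption. If this succeeds while the ODE version fails, the problem is more likely in how Enzyme handles the solver internals than in the `autodiff` call itself.

```julia
using Enzyme

# Hypothetical toy function with a known derivative: d/dx (x^2 + 3x) = 2x + 3
toy(x) = x^2 + 3x

# Forward mode: seed the tangent of x with 1.0 to get d(toy)/dx at x = 2.0,
# which analytically is 2*2.0 + 3 = 7.0.
# No return activity is passed, so the result is assumed to be a
# one-element tuple holding the derivative.
result = Enzyme.autodiff(Forward, toy, Duplicated(2.0, 1.0))
println(result[1])
```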