Hi,
I’m trying to use Enzyme.jl to compute gradients of a nonlinear least-squares objective function that I’d like to optimize for data assimilation (to my understanding, this is equivalent to what is known as weak-constraint 4D-Var). The objective function discretizes a differential equation using Simpson–Hermite quadrature and is defined in the following code:
module SimpsonHermite

# Weighted Simpson defect of one interval, given vector-field values that were
# already evaluated at t_i (`vi`), t_{i+1/2} (`vmi`) and t_{i+1} (`vip`).
@inline function _simpson_defect(xi, xip, vi, vmi, vip, dt, Rf)
    return Rf .* (xip - xi - (dt/6)*(vi + 4*vmi + vip))
end

# Weighted Hermite midpoint defect of one interval, given vector-field values
# that were already evaluated at t_i (`vi`) and t_{i+1} (`vip`).
@inline function _hermite_defect(xi, xmi, xip, vi, vip, dt, Rf)
    return Rf .* (xmi - ((xi + xip)/2 + (dt/8)*(vi - vip)))
end

"""
    simpson_residual(xi, xmi, xip, p, fi, fmi, fip, dt, vf, Rf)

Return the weighted Simpson residual of a single time interval,

    Rf .* (xip - xi - (dt/6)*(vf(xi,p,fi) + 4*vf(xmi,p,fmi) + vf(xip,p,fip)))

# Arguments
- `xi`, `xmi`, `xip`: state vectors at times `t_i`, `t_{i+1/2}` and `t_{i+1}`.
- `p`: parameter vector.
- `fi`, `fmi`, `fip`: forcing (time-dependent parameters), vector or scalar, at
  `t_i`, `t_{i+1/2}` and `t_{i+1}`.
- `dt`: discretization timestep of the interval.
- `vf`: vector field, called as `vf(u, p, force)`.
- `Rf`: weights for each coordinate of the residual.
"""
@inline function simpson_residual(xi, xmi, xip, p, fi, fmi, fip, dt, vf, Rf)
    return _simpson_defect(
        xi, xip, vf(xi, p, fi), vf(xmi, p, fmi), vf(xip, p, fip), dt, Rf
    )
end

"""
    hermite_residual(xi, xmi, xip, p, fi, fmi, fip, dt, vf, Rf)

Return the weighted Hermite midpoint residual of a single time interval,

    Rf .* (xmi - ((xi + xip)/2 + (dt/8)*(vf(xi,p,fi) - vf(xip,p,fip))))

The arguments are the same as for [`simpson_residual`](@ref); `fmi` is accepted
for a parallel signature but is not used.
"""
@inline function hermite_residual(xi, xmi, xip, p, fi, fmi, fip, dt, vf, Rf)
    return _hermite_defect(xi, xmi, xip, vf(xi, p, fi), vf(xip, p, fip), dt, Rf)
end

"""
    gen_simpson_hermite_objective(vf, dt, frec, data, Rf, D, P)

Build and return a closure `obj(x)` that evaluates the least-squares objective

    obj_data/((2N+1)*L) + obj_model/((2N+1)*D)

where `obj_model` is the sum of squared Simpson and Hermite residuals over all
`N` intervals and `obj_data` is the sum of squared differences between the first
`L` coordinates of every node state and the corresponding entries of `data`.

# Arguments
- `vf`: vector field, called as `vf(u, p, force)`. It is evaluated once per node
  of each interval and the value is shared by both residuals, so it is assumed
  to be free of side effects.
- `dt`: timesteps; `N = length(dt)` is the number of intervals.
- `frec`: driving forces, length `(2N+1)*F` (`F` forces per node).
- `data`: observed data, length `(2N+1)*L` (`L` observed dimensions per node).
- `Rf`: weights for each coordinate of the residuals.
- `D`: state dimension.
- `P`: number of parameters.

The argument `x` of the returned closure holds the `2N+1` node states of length
`D` back to back, followed by the `P` parameters.

# Throws
- `ArgumentError` if `length(frec)` or `length(data)` is not a multiple of
  `2N+1`, or if the observed dimension `L` is not in `1:D`. Without these checks
  the integer division would silently drop trailing entries and the data slices
  would read coordinates belonging to a neighbouring node.
"""
function gen_simpson_hermite_objective(vf, dt, frec, data, Rf, D, P)
    N = length(dt)
    nodes = 2N + 1

    length(frec) % nodes == 0 || throw(ArgumentError(
        "length(frec) = $(length(frec)) is not a multiple of 2N+1 = $nodes"))
    length(data) % nodes == 0 || throw(ArgumentError(
        "length(data) = $(length(data)) is not a multiple of 2N+1 = $nodes"))

    F = length(frec) ÷ nodes   # number of driving forces per node
    L = length(data) ÷ nodes   # number of observed dimensions per node

    1 <= L <= D || throw(ArgumentError(
        "observed dimension L = $L must satisfy 1 <= L <= D = $D"))

    function obj(x)
        @views begin
            p = x[nodes*D+1:nodes*D+P]
            # Accumulators start as Float64 literals, as in the original
            # formulation, so the result promotes with Float64.
            obj_model = 0.0
            obj_data = 0.0
            for i in 1:N
                xo = 2D*(i-1)   # offset of node t_i inside x
                fo = 2F*(i-1)   # offset of node t_i inside frec
                yo = 2L*(i-1)   # offset of node t_i inside data

                xi = x[xo+1:xo+D]
                xmi = x[xo+D+1:xo+2D]
                xip = x[xo+2D+1:xo+3D]

                # One vector-field evaluation per node, reused by both residuals.
                vi = vf(xi, p, frec[fo+1:fo+F])
                vmi = vf(xmi, p, frec[fo+F+1:fo+2F])
                vip = vf(xip, p, frec[fo+2F+1:fo+3F])

                # squared Simpson residuals
                obj_model += sum(abs2, _simpson_defect(xi, xip, vi, vmi, vip, dt[i], Rf))
                # squared Hermite residuals
                obj_model += sum(abs2, _hermite_defect(xi, xmi, xip, vi, vip, dt[i], Rf))

                # squared data misfit at t_i and t_{i+1/2}
                obj_data += sum(abs2, x[xo+1:xo+L] - data[yo+1:yo+L])
                obj_data += sum(abs2, x[xo+D+1:xo+D+L] - data[yo+L+1:yo+2L])
            end
            # The last node is not the start of any interval; add its misfit here.
            obj_data += sum(abs2, x[2N*D+1:2N*D+L] - data[2N*L+1:2N*L+L])
        end
        return obj_data/(nodes*L) + obj_model/(nodes*D)
    end
    return obj
end

end
I then generate data from the Lorenz96 model, use this to build the objective, and compute the objective at a random point within box constraints of interest:
# Draw each coordinate of the starting point uniformly between lb[i] and ub[i].
u0 = [rand()*(ub[i] - lb[i]) + lb[i] for i=1:length(lb)]; #random initialization for first step
# Build the objective closure; `lor96` is passed as the vector field `vf`.
step_obj = SimpsonHermite.gen_simpson_hermite_objective(lor96, dt, frec, data, Rf, D, P) #frec, data, Rf, D, and P are all initialized to some constant values based on the generated Lorenz96 data
# Baseline evaluation, taken before any differentiation is attempted.
step_obj(u0) #returns ~862.68
Now, I compute the gradient using Enzyme:
# Work on a copy so that `u0` itself is not the array handed to Enzyme.
u0_enz = copy(u0);
# Zero-initialised shadow buffer, paired with `u0_enz` through `Duplicated` below.
bu0_enz = zeros(length(u0));
# NOTE(review): `step_obj` is a closure over `dt`, `frec`, `data` and `Rf`, and it is
# passed here without an explicit activity annotation. Worth checking whether
# `Const(step_obj)` plus an explicit `Active` return annotation changes the
# behaviour described below — TODO confirm against the Enzyme.jl documentation.
Enzyme.autodiff(Reverse, step_obj, Duplicated(u0_enz, bu0_enz))
Surprisingly, if I now recompute the objective at `u0` by calling `step_obj`, I receive a different result!
step_obj(u0) # now returns 9.241999372228449e9
Each time I run the code block with the Enzyme `autodiff` call, `step_obj` seems to be modified once more, and it eventually returns `Inf` or `NaN` results. Regenerating the objective function by rerunning the line `step_obj = SimpsonHermite.gen_simpson_hermite_objective(...)` does not fix this problem; only by re-including the `SimpsonHermite` module, and then rerunning the objective generation function, does the objective go back to returning correct values.
This is very strange, and I’m not sure if it’s a bug or if I’m doing something incorrect. Any help would be greatly appreciated!