Callback on Internal Stages of TSIT5

When I step through my simulation, I can see that some internal stages in Tsit5 are giving values that are too large. That is, the variable is constrained to lie between 0 and 1, but the internal stage gives something greater than 3. I want to create a callback function that forces the internal stages to be recomputed with half the step size when this happens. How do I implement such a callback function?

https://docs.sciml.ai/OrdinaryDiffEq/stable/nonstiff/explicitrk/#OrdinaryDiffEq.Tsit5

This is what the stage limiters are for.

oye, this is a bit terse – I’m a newbie.

Is there an example somewhere of how to implement a stage_limiter? My guess is:

using DifferentialEquations

# Define your differential equation and initial condition
function f(dx, x, p, t)
    dx[1] = ... 
end
x0 = [initial_value] 
tspan = (t_initial, t_final) 
prob = ODEProblem(f, x0, tspan)

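# Define a stage limiter. Recent OrdinaryDiffEq versions call it as
# limiter!(u, integrator, p, t) and expect it to modify u in place.
function my_stage_limiter!(u, integrator, p, t)
    # Entry 2 (linear index) of each of the 7 stored k's
    TestVals = [el[2] for el in integrator.k]
    if sum(TestVals .< 1) != 7
        # Changing dtpropose does not force the current step to be redone
        integrator.dtpropose = integrator.dtpropose / 2
    end
    return nothing
end

# Note the "!" in the keyword name and the spaces around "="
solver = Tsit5(stage_limiter! = my_stage_limiter!)
sol = solve(prob, solver)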

I think this is wrong, though, because it doesn't force the current step to be recomputed; it just halves the value for the next step.

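Stage limiters allow you to snap the stage value back into the domain while keeping the order of the method. So inside the limiter you would write, for example, u[1] < 0 && (u[1] = 0) on the stage value u that the limiter receives, and similar checks for the other bounds.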
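
A minimal sketch of such a limiter for three gating variables u[2], u[3] and u[4], assuming the limiter signature limiter!(u, integrator, p, t) used by recent OrdinaryDiffEq versions (check the Tsit5 docstring for your version); clamp_gates! is a hypothetical name. It modifies only the stage value u in place and ignores the other arguments, so it can be checked on a plain vector:

```julia
# Hypothetical stage limiter: clamp u[2], u[3], u[4] back into [0, 1].
# Assumes the signature limiter!(u, integrator, p, t); only `u` is modified.
function clamp_gates!(u, integrator, p, t)
    for i in 2:4
        u[i] = clamp(u[i], 0.0, 1.0)
    end
    return nothing
end

# Stand-alone check on a plain vector (integrator, p, t are unused here)
u = [-30.0, 1.2, -0.1, 0.5]
clamp_gates!(u, nothing, nothing, 0.0)
```

In a solve it would be passed as Tsit5(stage_limiter! = clamp_gates!). Clamping changes the values the method continues with; it does not reject or repeat the step.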

It’s not you, it’s me. And by me I mean that we have a PR in progress to overhaul these docs.

Complete docs of all OrdinaryDiffEq bells and whistles should come rather soon. It’s a process to get it all in there because we expose almost everything we can, and some of that is obscure :sweat_smile:.

haha, np np! :smiley: :melting_face:

Okay, so suppose I have three variables (u[2], u[3], and u[4]) that I know are between 0 and 1. Then the function should read:

function my_stage_limiter(u, t, integ)
    return integ.u[2] > 0 && integ.u[3] > 0 && integ.u[4] > 0 && integ.u[2] < 1 && integ.u[3] < 1 && integ.u[4] < 1
end

If this evaluates false during any of the internal stages, the step will recompute with a smaller stepsize?

No, stage limiters are for modifying the stage value. We don’t have a mechanism for forcing no violations between stages. @ranocha you had an algorithm you pushed for that though? Maybe we should look into adding something for this case.

Is there a way to let all the stage values compute, then check them all for violations. If a violation exists, recompute the stage values for half the stepsize?

The keyword argument isoutofdomain takes a function that rejects a step based on the value that comes out after the last stage.
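
A minimal sketch of isoutofdomain on a toy problem. The function receives (u, p, t) for the proposed end-of-step value and must return true when that value is NOT acceptable; the adaptive integrator then rejects the step and retries with a smaller dt. The decay problem and the name gates_out_of_bounds are hypothetical, chosen only to make the block runnable:

```julia
using OrdinaryDiffEq

# Toy in-place problem: every component decays exponentially.
decay!(du, u, p, t) = (du .= -p .* u)

# Return `true` to reject the proposed step. Hypothetical rule:
# components 2:4 must stay inside [0, 1].
gates_out_of_bounds(u, p, t) = any(x -> !(0.0 <= x <= 1.0), @view u[2:4])

prob = ODEProblem(decay!, [-30.0, 0.4, 0.5, 0.6], (0.0, 5.0), 2.0)
sol = solve(prob, Tsit5(); isoutofdomain = gates_out_of_bounds)
```

Only the value after the last stage is checked; intermediate stage values never reach this function.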

Okay, so it takes a function that returns true or false. If it returns false, the step is recomputed with a smaller stepsize?

nvm, it’s in the documentation lol

wait, but this takes only u,p,t as variables. Can I also pass the integrator into the function? I need to check the internal stages.

That’s a feature we can consider. It would be a breaking change but it does sound like a reasonable extension.
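
Until something like that exists, one possible workaround combines the two features discussed above. This is a sketch under assumptions: that your OrdinaryDiffEq version accepts stage_limiter! for Tsit5 and calls it on the intermediate stage values, and that isoutofdomain is evaluated once per attempted step. The limiter modifies nothing; it only records whether a watched entry left [lo, hi], and isoutofdomain then rejects the step so it is retried with a smaller dt (chosen by the integrator's own rule, not necessarily half). make_stage_bounds_check is a hypothetical helper name:

```julia
using OrdinaryDiffEq

# Hypothetical helper: returns a stage limiter and an isoutofdomain function
# that share one flag.
function make_stage_bounds_check(idxs, lo, hi)
    violated = Ref(false)

    # Called on stage values; it only records a violation and leaves u unchanged.
    function record_stage!(u, _integrator, _p, _t)
        for i in idxs
            # `!(lo <= x <= hi)` is also true when x is NaN
            if !(lo <= u[i] <= hi)
                violated[] = true
            end
        end
        return nothing
    end

    # Called once per attempted step; returning true rejects the step.
    function stage_was_out(_u, _p, _t)
        bad = violated[]
        violated[] = false  # reset before the next attempt
        return bad
    end

    return record_stage!, stage_was_out
end

# Toy problem, only to show how the pair is wired into solve
decay!(du, u, p, t) = (du .= -p .* u)
record_stage!, stage_was_out = make_stage_bounds_check(2:4, 0.0, 1.0)
prob = ODEProblem(decay!, [-30.0, 0.4, 0.5, 0.6], (0.0, 5.0), 2.0)
sol = solve(prob, Tsit5(stage_limiter! = record_stage!);
            isoutofdomain = stage_was_out)
```

The limiter only sees the stage values the method passes to it, so this is a check on those values, not a guarantee about the solution between stages.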

Yes, we have a stage callback (how we call stage limiters in Trixi.jl) to enforce such bounds for our discretizations of conservation laws; see the Callbacks page of the Trixi.jl documentation.

Sounds reasonable to me, too.

I know minimal working examples are a thing, but if you copy and paste this big block and then look at the internal stages, you’ll see what I’m trying to debug. I’ve spent the last couple of weeks trying to break up the model, but I keep introducing new errors.

######################################################################################

using DifferentialEquations, NaNMath

endd = 20000000.0

#####################################################################
### Parameters

## Sensor Parameters
alphaOffset = .045;
GF =  53;
GS =  3;
GD =  1.0;
gamma = 10e-7;
delta = 10e-4;
sensorFuseDelta = .001;

## Targets
FBar = .25;
SBar = .03; 
DBar = .02;

## Homeostatic Time Scales
tauG = 4000.0;
tauHalfac = 2100.0;
tauS = 2000.0;
tauAlpha = 2000.0;

## Reversal Potentials
EK(t) = -80;
EH = -20.0;
ENa = 30.0;
EL = -50.0;

#####################################################################

### Define generic functions
## Generic Activation/Inactivation Curves
sigmd(x,translate,center,reactance) = 1/(1+exp(((x-translate) + center)/reactance));
## Define the time constant function for ion channels
tauX(Volt, CT, DT, AT, BT, halfAc) = CT - DT/(1 + exp(((Volt-halfAc) + AT)/BT));
## Some intrinsic currents require a special time constant function
spectau(Volt, CT, DT, AT, BT, AT2, BT2, halfAc) = CT + DT/(exp(((Volt-halfAc) + AT)/BT) + exp(((Volt-halfAc) + AT2)/BT2));
## Define the ionic currents; q is the exponent of the activation variable m
iIonic(g, m, h, q, Volt, Erev) = g*(m^q)*h*(Volt - Erev);

#####################################################################

# Ion Channel Activation/Inactivation Curves
NaMinf(V,trns) = sigmd(V,trns, 25.5, -5.29);  # m^3
NaHinf(V,trns) = sigmd(V,trns, 48.9, 5.18);  # h
CaTMinf(V,trns) = sigmd(V,trns, 27.1, -7.20);  # m^3
CaTHinf(V,trns) = sigmd(V,trns, 32.1, 5.50);  # h
CaSMinf(V,trns) = sigmd(V,trns, 33.0, -8.1);  # m^3
CaSHinf(V,trns) = sigmd(V,trns, 60.0, 6.20);  # h
HMinf(V,trns) = sigmd(V,trns, 70.0, 6.0);  # m
KdMinf(V,trns) = sigmd(V,trns, 12.3, -11.8);  # m^4
KCaMinf(V,trns,IntCa) = (IntCa/(IntCa + 3.0))*sigmd(V,trns, 28.3, -12.6); # m^4
AMinf(V,trns) = sigmd(V,trns, 27.2, -8.70); # m^3
AHinf(V,trns) = sigmd(V,trns, 56.9, 4.90);  # h

# Time Constants (ms)
tauNaM(V,trns) = tauX(V, 1.32, 1.26, 120.0, -25.0, trns);
tauNaH(V,trns) = tauX(V, 0.00, -0.67, 62.9, -10.0, trns)*tauX(V, 1.50, -1.00, 34.9, 3.60, trns);
tauCaTM(V,trns) = tauX(V, 21.7, 21.3, 68.1, -20.5, trns);
tauCaTH(V,trns) = tauX(V, 105.0, 89.8, 55.0, -16.9, trns);
tauCaSM(V,trns) = spectau(V, 1.40, 7.00, 27.0, 10.0, 70.0, -13.0, trns);
tauCaSH(V,trns) = spectau(V, 60.0, 150.0, 55.0, 9.00, 65.0, -16.0, trns);
tauHM(V,trns) = tauX(V, 272.0, -1499.0, 42.2, -8.73, trns);
tauKdM(V,trns) = tauX(V, 7.20, 6.40, 28.3, -19.2, trns);
tauKCaM(V,trns) = tauX(V, 90.3, 75.1, 46.0, -22.7, trns);
tauAM(V,trns) = tauX(V, 11.6, 10.4, 32.9, -15.2, trns);
tauAH(V,trns) = tauX(V, 38.6, 29.2, 38.9, -26.5, trns);

# Calcium Reverse Potential
R = 8.314*1000.0;  # Ideal Gas Constant (*10^3 to put into mV)
temp = 10.0;  # Temperature in Celsius; Temperature of the Sea lawl
z = 2.0;  # Valence of Calcium Ions
Far = 96485.33;  # Faraday's Constant
CaOut = 3000.0;  # Outer Ca Concentration (uM)
ECa(CaIn) = ((R*(273.15 + temp))/(z*Far))*NaNMath.log(CaOut/CaIn);
#ECa(CaIn) = ((R*(273.15 + temp))/(z*Far))*log(CaOut/CaIn);

# Ionic Currents (mV / ms)
iNa(g,m,h,V) = iIonic(g, m, h, 3.0, V, ENa);
iCaT(g,m,h,V,CaIn) = iIonic(g, m, h, 3.0, V, ECa(CaIn));
iCaS(g,m,h,V,CaIn) = iIonic(g, m, h, 3.0, V, ECa(CaIn));
iH(g,m,V) = iIonic(g,  m, 1.0, 1.0, V, EH);
iKd(g,m,V,t) = iIonic(g, m, 1.0, 4.0, V, EK(t));
iKCa(g,m,V,t) = iIonic(g, m, 1.0, 4.0, V, EK(t));
iA(g,m,h,V,t) = iIonic(g, m, h, 3.0, V, EK(t));
iL(V) = iIonic(.01, 1.0, 1.0, 1.0, V, EL);
iApp(t) = 0;

# Sensor Equations
F(FM,FH) = GF*FM^2*FH;
S(SM,SH) = GS*SM^2*SH;
D(DM)    = GD*DM^2; 
sensorFuse(EF,ES,ED) = exp(-(((EF/(sensorFuseDelta*12))^8 + (ES/(sensorFuseDelta*1.5))^8 + (ED/(sensorFuseDelta*.2))^8)^(1/8)));

############################################################################################

function f(du,u,p,t)
    du[1] = -iL(u[1])-iNa(u[14],u[2],u[3],u[1])-iCaT(u[15],u[4],u[5],u[1],u[13])-iCaS(u[16],u[6],u[7],u[1],u[13])-iH(u[17],u[8],u[1])-iKd(u[18],u[9],u[1],t)-iKCa(u[19],u[10],u[1],t)-iA(u[20],u[11],u[12],u[1],t)
    du[2]  = (NaMinf(u[1],u[21]) - u[2])/tauNaM(u[1],u[21])
    du[3]  = (NaHinf(u[1],u[22]) - u[3])/tauNaH(u[1],u[22])
    du[4]  = (CaTMinf(u[1],u[23]) - u[4])/tauCaTM(u[1],u[23])
    du[5]  = (CaTHinf(u[1],u[24]) - u[5])/tauCaTH(u[1],u[24])
    du[6]  = (CaSMinf(u[1],u[25]) - u[6])/tauCaSM(u[1],u[25])
    du[7]  = (CaSHinf(u[1],u[26]) - u[7])/tauCaSH(u[1],u[26])
    du[8]  = (HMinf(u[1],u[27]) - u[8])/tauHM(u[1],u[27])
    du[9]  = (KdMinf(u[1],u[28]) - u[9])/tauKdM(u[1],u[28])
    du[10] = (KCaMinf(u[1],u[29],u[13]) - u[10])/tauKCaM(u[1],u[29])
    du[11] = (AMinf(u[1],u[30]) - u[11])/tauAM(u[1],u[30])
    du[12] = (AHinf(u[1],u[31]) - u[12])/tauAH(u[1],u[31])
    du[13] = (-.94*(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13]))-u[13]+.05)/20
    du[14] = ((1*(FBar-F(u[32],u[33])) + 0*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))*u[14]-gamma*u[14]^3)*u[40]/tauG
    du[15] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))*u[15]-gamma*u[15]^3)*u[40]/tauG
    du[16] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))*u[16]-gamma*u[16]^3)*u[40]/tauG
    du[17] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 1*(DBar-D(u[36])))*u[17]-gamma*u[17]^3)*u[40]/tauG
    du[18] = ((1*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))*u[18]-gamma*u[18]^3)*u[40]/tauG
    du[19] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + -1*(DBar-D(u[36])))*u[19]-gamma*u[19]^3)*u[40]/tauG
    du[20] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + -1*(DBar-D(u[36])))*u[20]-gamma*u[20]^3)*u[40]/tauG
    du[21] = ((-1*(FBar-F(u[32],u[33])) + 0*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[21]^3)*u[40]/tauHalfac
    du[22] = ((1*(FBar-F(u[32],u[33])) + 0*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[22]^3)*u[40]/tauHalfac
    du[23] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[23]^3)*u[40]/tauHalfac
    du[24] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[24]^3)*u[40]/tauHalfac
    du[25] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[25]^3)*u[40]/tauHalfac
    du[26] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[26]^3)*u[40]/tauHalfac
    du[27] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + -1*(DBar-D(u[36])))-delta*u[27]^3)*u[40]/tauHalfac
    du[28] = ((-1*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 0*(DBar-D(u[36])))-delta*u[28]^3)*u[40]/tauHalfac
    du[29] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 1*(DBar-D(u[36])))-delta*u[29]^3)*u[40]/tauHalfac
    du[30] = ((0*(FBar-F(u[32],u[33])) + 1*(SBar-S(u[34],u[35])) + 1*(DBar-D(u[36])))-delta*u[30]^3)*u[40]/tauHalfac
    du[31] = ((0*(FBar-F(u[32],u[33])) + -1*(SBar-S(u[34],u[35])) + -1*(DBar-D(u[36])))-delta*u[31]^3)*u[40]/tauHalfac
    du[32] = (sigmd(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13]),0,14.8,1) - u[32])/.5
    du[33] = (sigmd(-1*(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13])),0,-9.8,1) - u[33])/1.5
    du[34] = (sigmd(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13]),0,7.2,1) - u[34])/50
    du[35] = (sigmd(-1*(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13])),0,-2.8,1) - u[35])/60
    du[36] = (sigmd(iCaT(u[15],u[4],u[5],u[1],u[13])+iCaS(u[16],u[6],u[7],u[1],u[13]),0,3,1) - u[36])/500
    du[37] = ((FBar - F(u[32],u[33])) - u[37])/tauS
    du[38] = ((SBar - S(u[34],u[35])) - u[38])/tauS
    du[39] = ((DBar - D(u[36])) - u[39])/tauS
    du[40] = (sigmd(-1*sensorFuse(u[37],u[38],u[39]),0,alphaOffset,-.001)-u[40])/tauAlpha
end

############################################################################################

#Start
ICs = [-29.337327051014498 0.387305560452983 0.520851480339422 0.15123291866400054 0.7788086343207855 0.18292525224066983 0.40518167654812637 0.40305741894377733 0.09880328448541692 0.1192742729813851 0.17090849597998528 0.4853554113680752 2.455375230646913 79.93706465236964 0.2873998785854699 0.0318308211031974 0.15906997203306664 99.67308490480416 0.8899609102246536 2.6818203693916858 -6.016940524541976 6.0497029078690145 -0.08934230228658809 0.1845828676421055 0.11586968152835098 0.7364356646257856 2.4826107420682773 -6.238217364811495 -1.966199265610444 -2.1624666424135977 2.014861831075435 4.143056495983624e-7 0.999901341448654 0.16812145388254565 0.5889709108552483 0.4212141706080315 0.24284481424824048 -0.03505188961118154 -0.16940390920534412 1.0];

tspan = (0.0,endd);
prob = ODEProblem(f,ICs,tspan);
function dmnCheck(u,p,t)
    return u[1] > 200.0 || u[1] < -200.0 
end
integ = init(prob,Tsit5(),maxiters=1e10,isoutofdomain=dmnCheck);
solve!(integ)

You’ll see that isoutofdomain isn’t catching the failure. If you look at integ.k you should see a NaN. Some of the other values are also wacky. For instance, 1.5e74 is kinda big, and u[2] should be between 0 and 1.

julia> integ.k
7-element Vector{Matrix{Float64}}:
 [369.6253619957377 4.780405955884058 … 6.00843666970291e-6 0.0]
 [558.5762304399766 3.6783055725425826 … 6.01891607017653e-6 0.0]
 [163.5732051281901 2.7015515288225345 … 6.02970035701271e-6 0.0]
 [19068.06445943883 -0.7991962924634176 … 6.066390790152989e-6 0.0]
 [-6.30053467595449e17 -0.8521960049675845 … 6.086972935145541e-6 0.0]
 [1.584810201744311e74 -2.2604713798417015 … 6.089445492965473e-6 0.0]
 [NaN 13.454079037895752 … 6.420614871497879e-6 0.0]

Is there some workaround I can use right now that’s apparent to y’all? Or suggest what to modify directly in the code base so I can hack this thing and get it working?

The k are the derivatives, not the stage values.
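
To make the distinction concrete, here is a sketch using classical RK4 as a stand-in (not Tsit5, whose tableau has more stages and different coefficients). In an explicit RK method each stage value is Y_i = u + dt * sum_j a_ij * k_j and each slope is k_i = f(Y_i); only the slopes are kept. The gating-style right-hand side gate is a hypothetical example relaxing toward xinf = 0.9:

```julia
# Classical RK4 on a scalar ODE, returning the stage values Y_i,
# the slopes k_i = f(Y_i, t_i), and the new solution value.
function rk4_stages(f, u, t, dt)
    Y1 = u
    k1 = f(Y1, t)
    Y2 = u + dt / 2 * k1
    k2 = f(Y2, t + dt / 2)
    Y3 = u + dt / 2 * k2
    k3 = f(Y3, t + dt / 2)
    Y4 = u + dt * k3
    k4 = f(Y4, t + dt)
    unew = u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
    return (Y = [Y1, Y2, Y3, Y4], k = [k1, k2, k3, k4], unew = unew)
end

# Gating-variable style relaxation toward xinf with time constant tau
gate(x, t; xinf = 0.9, tau = 1.0) = (xinf - x) / tau

big_step = rk4_stages(gate, 0.1, 0.0, 3.0)
small_step = rk4_stages(gate, 0.1, 0.0, 0.5)
```

With dt = 3 the stage values are 0.1, 1.3, -0.5 and 4.3, the slopes change sign from stage to stage, and the step ends at -0.2. With dt = 0.5 every stage value stays in [0, 1]. The sign of a slope for this kind of equation depends on whether the stage value is above or below xinf, so the k's alone do not have a fixed sign.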

Wait, really? Is there documentation somewhere on this stuff? I’ve just been guessing what each field means from the output of fieldnames(typeof(integ)) while debugging.

Side note: I also used this for a long time to inspect fields/properties. It turns out that the Base function propertynames does more or less the same thing, and it feels more appropriate to me.
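
A small illustration on a plain struct (Demo is a made-up type). For ordinary structs propertynames falls back to the field names; types that overload getproperty can expose different properties and are expected to overload propertynames to match, which is why it is the more appropriate query:

```julia
struct Demo
    a::Int
    b::Float64
end

d = Demo(1, 2.0)
fieldnames(typeof(d))   # (:a, :b)
propertynames(d)        # (:a, :b)
```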

Anyway, there should be bounds on derivatives too, like some of these must be nonnegative and some of these must be nonpositive. u[2] should not then exhibit behavior that is both positive and negative in integ.k.

There is no intention to document things that are very internal to the algorithms. You can poke around all you want since the code is open (that’s why it’s open) and interpret it to use it at your own risk (though I will say the k stuff hasn’t changed in about 6 years for the ERK methods so it’s pretty stable).

Note that it’s not the derivatives for all algorithms so use it at your own risk and check what it means.

Though you’re trying to accomplish something rather odd. I think a quick fix would be to add a dispatch of the method in which the limiter is passed the integrator; that should be sufficient for what you need.

I mean, I tried passing a limiter to the integrator to check the u’s, but it’s not working. I’m not sure what else to try.

I’m trying to figure out where the NaN in integ.k is computed.
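
A minimal sketch for locating non-finite entries, assuming (as in the output above) that integ.k is a vector of arrays, one per stored slope. nonfinite_entries is a hypothetical helper; it reports (k index, component index) pairs, and because each k is a derivative, component j points at du[j] in f:

```julia
# Return (k index, component index) for every NaN or Inf in a vector of arrays.
function nonfinite_entries(ks)
    hits = Tuple{Int,Int}[]
    for (s, k) in enumerate(ks)
        # enumerate uses linear indexing, so this also works for 1×N matrices
        for (j, x) in enumerate(k)
            isfinite(x) || push!(hits, (s, j))
        end
    end
    return hits
end
```

Called as nonfinite_entries(integ.k), the earliest k index in the result is the first slope that went bad.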

Where in the code is Tsit5 implemented, and where is integ.k populated? Or, if you don’t know off the top of your head, where would you start looking?

I searched for tsit5 as a string in the SciML code base, but that wasn’t helpful.
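
One way to find the relevant source, sketched on a throwaway problem: OrdinaryDiffEq takes each step through a perform_step! method dispatched on the integrator's cache type, and (as far as I know) the initialize! method for the same cache type is where the k arrays get attached to integ.k. Searching the source for the cache type name is therefore more productive than searching for "tsit5". In newer releases the Tsit5 code may live in a sub-package rather than in the directory returned below:

```julia
using OrdinaryDiffEq

# Tiny in-place problem, only used to build an integrator to inspect
prob = ODEProblem((du, u, p, t) -> (du .= -u), [1.0], (0.0, 1.0))
integ = init(prob, Tsit5())

# Name of the cache type; grep the source for this name and perform_step!
cache_type = nameof(typeof(integ.cache))

# Directory holding the package's source files on disk
src_dir = dirname(pathof(OrdinaryDiffEq))
```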