Trying to Solve a 2nd Order ODE and need advice

Dear Community,
I'm new to Julia and have recently been trying to implement the algorithms from the paper Direct Runge-Kutta Discretization Achieves Acceleration, for which I need to solve the 2nd-order ODE

\ddot x(t) + \frac{2p+1}{t}\dot x(t)+p^2t^{p-2}\nabla f(x(t))=0

where \nabla f(x(t)) is the gradient of the loss function. My attempt with DifferentialEquations.jl is below:

using Random, Distributions, DifferentialEquations, Plots, LinearAlgebra

# Data generation
Random.seed!(12)
num_observations = 3
clusters = 1
dim = 10
spread = 10
margin = rand(0:10,dim)*50

C = Diagonal(ones(dim))
x1, x2 = rand(MvNormal(rand(0:spread,dim), C), num_observations), rand(MvNormal(rand(0:spread,dim), C), num_observations)

for i in 1:clusters
    if i == 1
        x1 = transpose(rand(MvNormal(rand(0:spread,dim) + margin, C), num_observations))
        x2 = transpose(rand(MvNormal(rand(0:spread,dim), C), num_observations))
    else
        x1_temp = transpose(rand(MvNormal(rand(0:spread,dim) + margin, C), num_observations))
        x1 = cat(x1, x1_temp)
        x2_temp = transpose(rand(MvNormal(rand(0:spread,dim), C), num_observations))
        x2 = cat(x2_temp)
    end
end

simulated_separableish_features = vcat(x1,x2)
simulated_labels = vcat(zeros(size(x1)[1]), ones(size(x2)[1]))
simulated_labels = reshape(simulated_labels,(length(simulated_labels),1))
intercept = ones(size(simulated_separableish_features)[1], 1)
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
dim+=1

# Loss and gradient of loss
function loss(X, Y, w)
    pred = Y - X * w
    return pred' * pred
end
function grad(X, Y, w)
    pred = X * w - Y
    return transpose(X) * pred
end

f = x->loss(simulated_separableish_features, simulated_labels, x)
df = x->grad(simulated_separableish_features, simulated_labels, x)

# Simulate ODE flow
function ODE(du, u, p, t)
    return -(2*p[1]+1) / t * du - p[1]^2 * t^(p[1]-2) * df(u)
end

function simulate(p, s, N, period, step, xinit, vinit)
    h = step
    if s==1
        alg = Euler()
    elseif s==2
        alg = SSPRK2()
    elseif s==4
        alg = ERKN4()
    elseif s==8
        alg = DPRK8()
    else
        return nothing
    end
    x = xinit
    v = vinit
    loop = Int(N/period)
    losses = zeros(loop)

    soln=nothing
    
    for j in 1:loop
        losses[j] = f(x)[1]
        if j%1000 == 0
            println("Loop="*string(j)*"\t Loss="*string(losses[j]))
        end
        t_start = 0.01 + (j-1)*h*period
        t_end = 0.01 + j*h*period
        tspan = (t_start, t_end)
        prob = SecondOrderODEProblem(ODE, v, x, tspan, p)
        soln = solve(prob, alg, dt=h, adaptive=false)
        x = soln[12:22, period+1]
        v = soln[1:11, period+1]
    end
    return losses, soln
end

N = 100000
truncate = 0
period = 10
pvalues = [2,3,4,5]
#nrange = period * range(1, Int(N/period)) - period + 1

svals = [4,4,4,4]
steps = [1e-7, 1e-7, 1e-7, 1e-7]

weights = reshape(rand(dim),(dim,1))
dweights = zeros(dim,1)

losslst = zeros(length(svals),Int(N/period))

for i in 1:length(svals)
    s = svals[i]
    p = pvalues[i]
    println("p="*string(p))
    p = [p]
    losses, soln = simulate(p, s, N, period, steps[i], weights, dweights)
    losslst[i,:] = losses
    plot!(losses, label="p= "*string(pvalues[i]), xaxis=:log, yaxis=:log)
end

But the weight (x in the simulate function) is not updated as I expect, and the loss stays unchanged. I cannot figure out what's wrong and am looking forward to your advice.
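
For reference, a minimal way to see the problem in isolation (reusing ODE, weights, and dweights from the script above) is to do a single short solve with the same settings as simulate and check whether the position block moves at all:

# One short solve, mirroring simulate() with p = 2, s = 4, dt = 1e-7, period = 10
p_chk, h = [2], 1e-7
prob_chk = SecondOrderODEProblem(ODE, dweights, weights, (0.01, 0.01 + 10*h), p_chk)
sol_chk = solve(prob_chk, ERKN4(), dt=h, adaptive=false)

x_new = sol_chk[12:22, 11]   # position block after 10 fixed steps (11 saved points)
println("retcode = ", sol_chk.retcode)
println("max |Δx| = ", maximum(abs.(vec(x_new) .- vec(weights))))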

Forgot to mention: when I change the initial value of du, the equation does get solved, but the loss keeps rising instead of decreasing. Also, in that case du (accessed via soln[1:11,:]) still remains unchanged.
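
As an aside on the indexing: if I understand the solution layout correctly, each saved state of a SecondOrderODEProblem solution is an ArrayPartition, so the derivative and position blocks can also be read off by component instead of by hard-coded row ranges. Sketch, using the soln returned by simulate:

v_end = soln[end].x[1]   # derivative block, same numbers as the soln[1:11, ...] rows
x_end = soln[end].x[2]   # position block, same numbers as the soln[12:22, ...] rows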

You used an in-place ODE definition for an out-of-place function. Write it as:

function ODE(u, p, t)
    return -(2*p[1]+1) / t * du - p[1]^2 * t^(p[1]-2) * df(u)
end

According to the documentation of SecondOrderODEProblem, f should be specified as f(du,u,p,t) (or in-place as f(ddu,du,u,p,t)), so it seems I'm not making a mistake here.
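
To make the two accepted forms concrete for this particular equation, here is a minimal, self-contained sketch with a toy stand-in for the gradient (df_toy and the 1-D initial conditions are illustrative only; the solver call mirrors the fixed-step ERKN4 usage from my script):

using DifferentialEquations

df_toy(u) = u   # toy stand-in for the gradient ∇f(x)

# Out-of-place form: f(du, u, p, t) returns the second derivative
ODE_oop(du, u, p, t) = -(2*p[1]+1)/t * du - p[1]^2 * t^(p[1]-2) * df_toy(u)

# In-place form: f(ddu, du, u, p, t) writes the second derivative into ddu
function ODE_ip!(ddu, du, u, p, t)
    ddu .= -(2*p[1]+1)/t .* du .- p[1]^2 * t^(p[1]-2) .* df_toy(u)
end

du0, u0, p = [0.0], [1.0], [2]
prob_oop = SecondOrderODEProblem(ODE_oop, du0, u0, (0.01, 1.0), p)
prob_ip  = SecondOrderODEProblem(ODE_ip!, du0, u0, (0.01, 1.0), p)

sol = solve(prob_oop, ERKN4(), dt=1e-3, adaptive=false)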

The sizes of your arrays were all wonky. In fact, on the latest versions of the packages it actually throws an error about some of the “vectors” being matrices. If you fix your vector sizes it works, like this:

using Random, Distributions, DifferentialEquations, LinearAlgebra # Plots

# Data generation
Random.seed!(12)
num_observations = 3
clusters = 1
dim = 10
spread = 10
margin = rand(0:10,dim)*50

C = Diagonal(ones(dim))
x1, x2 = rand(MvNormal(rand(0:spread,dim), C), num_observations), rand(MvNormal(rand(0:spread,dim), C), num_observations)

for i in 1:clusters
    if i == 1
        x1 = transpose(rand(MvNormal(rand(0:spread,dim) + margin, C), num_observations))
        x2 = transpose(rand(MvNormal(rand(0:spread,dim), C), num_observations))
    else
        x1_temp = transpose(rand(MvNormal(rand(0:spread,dim) + margin, C), num_observations))
        x1 = cat(x1, x1_temp)
        x2_temp = transpose(rand(MvNormal(rand(0:spread,dim), C), num_observations))
        x2 = cat(x2_temp)
    end
end

simulated_separableish_features = vcat(x1,x2)
simulated_labels = vcat(zeros(size(x1)[1]), ones(size(x2)[1]))
simulated_labels = reshape(simulated_labels,(length(simulated_labels),1))
intercept = ones(size(simulated_separableish_features)[1], 1)
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
dim+=1

# Loss and gradient of loss
function loss(X, Y, w)
    pred = Y - X * w
    return pred' * pred
end
function grad(X, Y, w)
    pred = X * w - Y
    return vec(transpose(X) * pred)   # vec(...) so the gradient is a Vector, not an n×1 Matrix
end

f = x->loss(simulated_separableish_features, simulated_labels, x)
df = x->grad(simulated_separableish_features, simulated_labels, x)

# Simulate ODE flow
function ODE(du, u, p, t)
    return -(2*p[1]+1) / t * du - p[1]^2 * t^(p[1]-2) * df(u)
end

function simulate(p, s, N, period, step, xinit, vinit)
    h = step
    if s==1
        alg = Euler()
    elseif s==2
        alg = SSPRK2()
    elseif s==4
        alg = ERKN4()
    elseif s==8
        alg = DPRK8()
    else
        return nothing
    end
    x = xinit
    v = vinit
    loop = Int(N/period)
    losses = zeros(loop)

    soln=nothing
    
    for j in 1:loop
        losses[j] = f(x)[1]
        if j%1000 == 0
            println("Loop="*string(j)*"\t Loss="*string(losses[j]))
        end
        t_start = 0.01 + (j-1)*h*period
        t_end = 0.01 + j*h*period
        tspan = (t_start, t_end)
        prob = SecondOrderODEProblem(ODE, v, x, tspan, p)
        soln = solve(prob, alg, dt=h, adaptive=false)
        x = soln[12:22, period+1]   # position block of the final saved state
        v = soln[1:11, period+1]    # derivative (velocity) block
    end
    return losses, soln
end

N = 100000
truncate = 0
period = 10
pvalues = [2,3,4,5]
#nrange = period * range(1, Int(N/period)) - period + 1

svals = [4,4,4,4]
steps = [1e-7, 1e-7, 1e-7, 1e-7]

weights = reshape(rand(dim),(dim))   # plain length-dim Vector (not a dim×1 Matrix)
dweights = zeros(dim)                # plain length-dim Vector

losslst = zeros(length(svals),Int(N/period))

for i in 1:length(svals)
    s = svals[i]
    p = pvalues[i]
    println("p="*string(p))
    p = [p]
    losses, soln = simulate(p, s, N, period, steps[i], weights, dweights)
    losslst[i,:] = losses
    #plot!(losses, label="p= "*string(pvalues[i]), xaxis=:log, yaxis=:log)
end

p=2
Loop=1000        Loss=1.1370352296014028e6
Loop=2000        Loss=558700.795600677
Loop=3000        Loss=210969.5668560976
Loop=4000        Loss=46713.953892804886
Loop=5000        Loss=869.1600901123418
Loop=6000        Loss=14136.22018135579
Loop=7000        Loss=43717.592335439964
Loop=8000        Loss=65406.289572243666
Loop=9000        Loss=70587.38391429758
Loop=10000       Loss=60986.03885638417
p=3
Loop=1000        Loss=1.529397566418039e6
Loop=2000        Loss=788531.2931497109
Loop=3000        Loss=231283.21188213775
Loop=4000        Loss=14019.069164400047
Loop=5000        Loss=18898.1528855852
Loop=6000        Loss=63363.246270909964
Loop=7000        Loss=57764.30638731056
Loop=8000        Loss=21374.125887804494
Loop=9000        Loss=980.3786965814796
Loop=10000       Loss=4972.182386471314
p=4
Loop=1000        Loss=1.9257352546918865e6
Loop=2000        Loss=1.9068147812419895e6
Loop=3000        Loss=1.879136151356047e6
Loop=4000        Loss=1.8431964827807134e6
Loop=5000        Loss=1.7987728502450513e6
Loop=6000        Loss=1.7454295328677325e6
Loop=7000        Loss=1.6827463787847036e6
Loop=8000        Loss=1.610446201562994e6
Loop=9000        Loss=1.528482124362498e6
Loop=10000       Loss=1.4371053342682128e6
p=5
Loop=1000        Loss=1.9332291461814602e6
Loop=2000        Loss=1.9329330522603847e6
Loop=3000        Loss=1.9324767525664547e6
Loop=4000        Loss=1.931840315207187e6
Loop=5000        Loss=1.9309874349720925e6
Loop=6000        Loss=1.9298718518708802e6
Loop=7000        Loss=1.9284392058824003e6
Loop=8000        Loss=1.9266274683366583e6
Loop=9000        Loss=1.9243669102391053e6
Loop=10000       Loss=1.9215799306837576e6
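
To spell out what "fix your vector sizes" means here: the original script built the weights, velocities, and gradient as n×1 matrices (reshape(rand(dim),(dim,1)), zeros(dim,1), transpose(X)*pred), whereas the initial conditions and the gradient returned by df should be plain vectors. A tiny illustration of the difference:

# Same numbers, different shapes: a length-n Vector vs. an n×1 Matrix
w_vec = rand(11)                 # size (11,)  — what the fixed script passes to the solver
w_mat = reshape(w_vec, 11, 1)    # size (11,1) — what the original script passed
@show w_vec == vec(w_mat)        # true: identical data
@show w_vec isa AbstractVector   # true
@show w_mat isa AbstractVector   # false — it's an 11×1 AbstractMatrix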

Also, it’s a bit odd that some of the s== choices are not real ODE solvers?

Thank you for your help! The code works now!