Trying to Solve a 2nd Order ODE and need advice

The array shapes in your code were inconsistent: several values that were meant to be vectors were actually matrices. In fact, on the latest versions of the packages this throws an error saying that some of the “vectors” are matrices. If you fix the vector sizes it works, like this:

using Random, Distributions, DifferentialEquations, LinearAlgebra # Plots

# Data generation
# Fixed seed so every run draws the same synthetic data set.
Random.seed!(12)
num_observations = 3    # samples drawn per class for each cluster
clusters = 1
dim = 10
spread = 10
# Per-coordinate offset (a multiple of 50 in 0:50:500) added to the class-1 means.
margin = 50 * rand(0:10, dim)

# Identity covariance shared by every Gaussian blob.
C = Diagonal(ones(dim))
# Initial draws, one (dim × num_observations) sample matrix per class.
x1 = rand(MvNormal(rand(0:spread, dim), C), num_observations)
x2 = rand(MvNormal(rand(0:spread, dim), C), num_observations)

# Draw one Gaussian blob per class for every cluster and stack the samples row-wise
# (one observation per row). Class 1 has its mean shifted by `margin`.
for i in 1:clusters
    # Without this declaration the assignments below create loop-local variables when
    # the file is run as a script, leaving the top-level x1/x2 untouched.
    global x1, x2
    x1_temp = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
    x2_temp = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
    if i == 1
        # The first cluster replaces the initial draws.
        x1 = x1_temp
        x2 = x2_temp
    else
        # Later clusters append rows; `cat` without `dims` throws, and the previous
        # code dropped the accumulated x2 entirely.
        x1 = vcat(x1, x1_temp)
        x2 = vcat(x2, x2_temp)
    end
end

# Stack both classes row-wise; rows coming from x1 are labelled 0, rows from x2 are labelled 1.
simulated_separableish_features = vcat(x1, x2)
simulated_labels = vcat(zeros(size(x1, 1)), ones(size(x2, 1)))
# Labels are stored as an (n × 1) matrix.
simulated_labels = reshape(simulated_labels, :, 1)
# Prepend a column of ones so the first weight acts as the intercept.
intercept = ones(size(simulated_separableish_features, 1), 1)
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
# The intercept column adds one feature dimension.
dim += 1

# Loss and gradient of loss
"""
    loss(X, Y, w)

Return the squared residual norm `(Y - X*w)' * (Y - X*w)` of the linear model `X * w`
against the targets `Y`.

The result is a scalar when `Y` is a vector and a 1×1 matrix when `Y` is an `n × 1`
matrix.
"""
function loss(X, Y, w)
    residual = Y - X * w
    return adjoint(residual) * residual
end
"""
    grad(X, Y, w)

Return `transpose(X) * (X*w - Y)` flattened to a plain vector.

NOTE: this is half of the analytic gradient of the squared residual norm
`(Y - X*w)' * (Y - X*w)`; the factor 2 is not applied here.
"""
function grad(X, Y, w)
    residual = X * w - Y
    # `vec` collapses an (n × 1) result to a vector when `Y` is a column matrix.
    return vec(transpose(X) * residual)
end

# Objective and its gradient, closed over the simulated data set so that both take
# only the weight vector.
f = w -> loss(simulated_separableish_features, simulated_labels, w)
df = w -> grad(simulated_separableish_features, simulated_labels, w)

# Simulate ODE flow
"""
    ODE(du, u, p, t)

Out-of-place right-hand side of the second-order flow; returns

    -(2*p[1] + 1) / t * du - p[1]^2 * t^(p[1] - 2) * df(u)

where `du` is the current first derivative, `u` the current state, `p[1]` the
exponent parameter and `t` the time. The `1 / t` term is singular at `t == 0`.
"""
function ODE(du, u, p, t)
    k = p[1]
    damping = -(2 * k + 1) / t
    forcing = k^2 * t^(k - 2)
    return damping * du - forcing * df(u)
end

"""
    simulate(p, s, N, period, step, xinit, vinit)

Integrate the second-order flow defined by `ODE` in `N / period` consecutive chunks,
each covering a time span of `period * step` solved with fixed step size `step`.
The value `f(x)[1]` is recorded at the start of every chunk.

# Arguments
- `p`: parameter container passed to the ODE problem (`ODE` reads `p[1]`).
- `s`: integrator selector: `1` → `Euler()`, `2` → `SSPRK2()`, `4` → `ERKN4()`,
  `8` → `DPRK8()`.
- `N`: total number of steps; `N / period` must be an integer.
- `period`: number of steps per chunk.
- `step`: fixed step size handed to the solver as `dt`.
- `xinit`, `vinit`: initial state and initial first derivative.

# Returns
`(losses, soln)`, where `losses[j]` is the loss at the start of chunk `j` and `soln` is
the solution object of the last chunk (`nothing` if no chunk was run). Returns
`nothing` when `s` is not one of the supported selectors.
"""
function simulate(p, s, N, period, step, xinit, vinit)
    alg = if s == 1
        Euler()
    elseif s == 2
        SSPRK2()
    elseif s == 4
        ERKN4()
    elseif s == 8
        DPRK8()
    else
        return nothing
    end

    # The solver state is laid out as [derivative; state], so the slices are derived
    # from the input lengths instead of being hard-coded for an 11-dimensional problem.
    nv = length(vinit)
    nx = length(xinit)
    vel_idx = 1:nv
    pos_idx = (nv + 1):(nv + nx)

    x = xinit
    v = vinit
    loop = Int(N / period)
    losses = zeros(loop)

    # Declared outside the loop so the last chunk's solution can be returned.
    soln = nothing

    for j in 1:loop
        losses[j] = f(x)[1]
        if j % 1000 == 0
            println("Loop=" * string(j) * "\t Loss=" * string(losses[j]))
        end
        # Time starts at 0.01 rather than 0; `ODE` divides by `t`.
        t_start = 0.01 + (j - 1) * step * period
        t_end = 0.01 + j * step * period
        prob = SecondOrderODEProblem(ODE, v, x, (t_start, t_end), p)
        soln = solve(prob, alg, dt=step, adaptive=false)
        # Continue the next chunk from the state reached after `period` steps.
        x = soln[pos_idx, period + 1]
        v = soln[vel_idx, period + 1]
    end
    return losses, soln
end

# Experiment configuration: total steps, steps per recorded chunk, and the exponent
# values to compare.
N = 100000
truncate = 0
period = 10
pvalues = [2, 3, 4, 5]

# One integrator selector and one step size per entry of `pvalues`.
svals = fill(4, 4)
steps = fill(1e-7, 4)

# Random initial weights and zero initial velocity.
weights = rand(dim)
dweights = zeros(dim)

# Row i holds the recorded losses for pvalues[i].
losslst = zeros(length(svals), Int(N / period))

for (i, s) in enumerate(svals)
    println("p=", pvalues[i])
    losses, soln = simulate([pvalues[i]], s, N, period, steps[i], weights, dweights)
    losslst[i, :] = losses
    # With Plots loaded, each curve can be drawn on log-log axes via
    # plot!(losses, label="p= "*string(pvalues[i]), xaxis=:log, yaxis=:log)
end
p=2
Loop=1000        Loss=1.1370352296014028e6
Loop=2000        Loss=558700.795600677
Loop=3000        Loss=210969.5668560976
Loop=4000        Loss=46713.953892804886
Loop=5000        Loss=869.1600901123418
Loop=6000        Loss=14136.22018135579
Loop=7000        Loss=43717.592335439964
Loop=8000        Loss=65406.289572243666
Loop=9000        Loss=70587.38391429758
Loop=10000       Loss=60986.03885638417
p=3
Loop=1000        Loss=1.529397566418039e6
Loop=2000        Loss=788531.2931497109
Loop=3000        Loss=231283.21188213775
Loop=4000        Loss=14019.069164400047
Loop=5000        Loss=18898.1528855852
Loop=6000        Loss=63363.246270909964
Loop=7000        Loss=57764.30638731056
Loop=8000        Loss=21374.125887804494
Loop=9000        Loss=980.3786965814796
Loop=10000       Loss=4972.182386471314
p=4
Loop=1000        Loss=1.9257352546918865e6
Loop=2000        Loss=1.9068147812419895e6
Loop=3000        Loss=1.879136151356047e6
Loop=4000        Loss=1.8431964827807134e6
Loop=5000        Loss=1.7987728502450513e6
Loop=6000        Loss=1.7454295328677325e6
Loop=7000        Loss=1.6827463787847036e6
Loop=8000        Loss=1.610446201562994e6
Loop=9000        Loss=1.528482124362498e6
Loop=10000       Loss=1.4371053342682128e6
p=5
Loop=1000        Loss=1.9332291461814602e6
Loop=2000        Loss=1.9329330522603847e6
Loop=3000        Loss=1.9324767525664547e6
Loop=4000        Loss=1.931840315207187e6
Loop=5000        Loss=1.9309874349720925e6
Loop=6000        Loss=1.9298718518708802e6
Loop=7000        Loss=1.9284392058824003e6
Loop=8000        Loss=1.9266274683366583e6
Loop=9000        Loss=1.9243669102391053e6
Loop=10000       Loss=1.9215799306837576e6

Also, it’s a bit odd that some of the s== choices are not real ODE solvers?

1 Like