The array sizes in your code were inconsistent. In fact, on the latest versions of the packages it throws an error because some of the “vectors” are actually matrices. Once the sizes are fixed it works, like this:
using Random, Distributions, DifferentialEquations, LinearAlgebra # Plots
# Data generation: two point clouds (class 0 shifted by `margin`, class 1 not) in `dim`
# dimensions, stacked into a feature matrix with an intercept column and a label column.
Random.seed!(12)
num_observations = 3
clusters = 1
dim = 10
spread = 10
margin = rand(0:10, dim) * 50
C = Diagonal(ones(dim))
# These two draws are overwritten by the loop below, but they advance the seeded RNG, so
# they are kept to reproduce the same data set.
x1 = rand(MvNormal(rand(0:spread, dim), C), num_observations)
x2 = rand(MvNormal(rand(0:spread, dim), C), num_observations)
for i in 1:clusters
    # Without `global`, assigning to the existing globals x1/x2 inside a top-level loop
    # creates loop-local variables when the file is run non-interactively (soft scope).
    global x1, x2
    # Transpose so that rows are observations (rows are counted as samples below).
    x1_new = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
    x2_new = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
    if i == 1
        x1, x2 = x1_new, x2_new
    else
        # Stack each further cluster below the rows collected so far.
        x1 = vcat(x1, x1_new)
        x2 = vcat(x2, x2_new)
    end
end
simulated_separableish_features = vcat(x1, x2)
# Label 0 for every x1 row and 1 for every x2 row, stored as an n×1 matrix.
simulated_labels = vcat(zeros(size(x1, 1)), ones(size(x2, 1)))
simulated_labels = reshape(simulated_labels, (length(simulated_labels), 1))
# Prepend a column of ones so the first weight acts as an intercept.
intercept = ones(size(simulated_separableish_features, 1), 1)
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
dim += 1  # one extra weight for the intercept column
# Loss and gradient of loss
"""
    loss(X, Y, w)

Squared norm of the residual `Y - X*w`, computed as `residualᵀ * residual`.

When `Y` is an `n×1` matrix (as the labels in this script are), the result is a `1×1`
matrix, which is why callers index it with `[1]`. With a plain vector `Y` it is a scalar.
"""
function loss(X, Y, w)
    residual = Y - X * w
    squared_norm = adjoint(residual) * residual
    return squared_norm
end
"""
    grad(X, Y, w)

Return `Xᵀ(X*w - Y)` as a plain vector, even when `Y` is an `n×1` matrix.

NOTE(review): `loss` computes `‖Y - X*w‖²`, whose gradient is `2Xᵀ(X*w - Y)`. This function
therefore returns half of it, i.e. the gradient of `½‖Y - X*w‖²`. Confirm the factor is
intended.
"""
grad(X, Y, w) = vec(transpose(X) * (X * w - Y))
# Objective and its gradient as functions of the weight vector alone. The feature matrix
# and labels are read from the globals defined above each time these are called.
f = function (w)
    return loss(simulated_separableish_features, simulated_labels, w)
end
df = function (w)
    return grad(simulated_separableish_features, simulated_labels, w)
end
# Simulate ODE flow
"""
    ODE(du, u, p, t)

Acceleration `u''` of the second-order flow

    u'' = -(2k + 1)/t * u' - k² t^(k-2) * df(u),   with k = p[1],

written in the out-of-place `(du, u, p, t)` form passed to `SecondOrderODEProblem`.
`du` is the velocity, `u` the position, and `df` the global gradient function. The damping
term divides by `t`, so the flow must not be evaluated at `t = 0`.
"""
function ODE(du, u, p, t)
    k = p[1]
    damping = (2 * k + 1) / t
    forcing = k^2 * t^(k - 2)
    return -damping * du - forcing * df(u)
end
# Map the solver selector `s` used by `simulate` to a DifferentialEquations algorithm.
function _flow_algorithm(s)
    if s == 1
        return Euler()
    elseif s == 2
        return SSPRK2()
    elseif s == 4
        return ERKN4()
    elseif s == 8
        return DPRKN8()
    end
    throw(ArgumentError("unsupported solver selector s=$s; expected 1, 2, 4 or 8"))
end

"""
    simulate(p, s, N, period, step, xinit, vinit)

Integrate the second-order flow `ODE` with a fixed step and record the loss `f` along the way.

The `N` steps are split into `Int(N / period)` windows of `period` steps each. Every window
is solved as a fresh `SecondOrderODEProblem` that starts from the previous window's final
position and velocity.

# Arguments
- `p`: parameter vector forwarded to `ODE`.
- `s`: solver selector — `1` Euler, `2` SSPRK2, `4` ERKN4, `8` DPRKN8.
- `N`: total number of steps; `N / period` must be an integer (else `InexactError`).
- `period`: number of steps per window.
- `step`: fixed step size `dt`.
- `xinit`, `vinit`: initial position and velocity (vectors of equal length).

# Returns
`(losses, soln)`. `losses[j]` is `f(x)[1]` at the start of window `j`, and `soln` is the
solution object of the last window (`nothing` when there are no windows).

# Throws
`ArgumentError` if `s` is not one of the supported selectors.
"""
function simulate(p, s, N, period, step, xinit, vinit)
    alg = _flow_algorithm(s)
    h = step
    n = length(xinit)
    x = xinit
    v = vinit
    loop = Int(N / period)
    losses = zeros(loop)
    soln = nothing
    for j in 1:loop
        losses[j] = f(x)[1]
        if j % 1000 == 0
            println("Loop=" * string(j) * "\t Loss=" * string(losses[j]))
        end
        # Time starts at 0.01 rather than 0 because the flow's damping term divides by t.
        t_start = 0.01 + (j - 1) * h * period
        t_end = 0.01 + j * h * period
        prob = SecondOrderODEProblem(ODE, v, x, (t_start, t_end), p)
        soln = solve(prob, alg, dt=h, adaptive=false)
        # The state is stacked as [velocity; position]. Take the last saved point rather
        # than assuming exactly `period` steps were taken.
        kend = length(soln.t)
        x = soln[n+1:2n, kend]
        v = soln[1:n, kend]
    end
    return losses, soln
end
# Experiment driver: one `simulate` run per (p, solver, step) configuration; the loss
# trajectory of run k is stored in row k of `losslst`.
N = 100000
truncate = 0
period = 10
pvalues = [2, 3, 4, 5]
svals = fill(4, 4)        # solver selector for every run (4 => ERKN4 in `simulate`)
steps = fill(1e-7, 4)     # fixed step size for every run
weights = rand(dim)       # random initial position, one entry per feature incl. intercept
dweights = zeros(dim)     # initial velocity
losslst = zeros(length(svals), Int(N / period))
for (k, order) in enumerate(svals)
    pexp = pvalues[k]
    println("p=" * string(pexp))
    run_losses, _ = simulate([pexp], order, N, period, steps[k], weights, dweights)
    losslst[k, :] = run_losses
    # With Plots loaded:
    # plot!(run_losses, label="p= "*string(pexp), xaxis=:log, yaxis=:log)
end
p=2
Loop=1000 Loss=1.1370352296014028e6
Loop=2000 Loss=558700.795600677
Loop=3000 Loss=210969.5668560976
Loop=4000 Loss=46713.953892804886
Loop=5000 Loss=869.1600901123418
Loop=6000 Loss=14136.22018135579
Loop=7000 Loss=43717.592335439964
Loop=8000 Loss=65406.289572243666
Loop=9000 Loss=70587.38391429758
Loop=10000 Loss=60986.03885638417
p=3
Loop=1000 Loss=1.529397566418039e6
Loop=2000 Loss=788531.2931497109
Loop=3000 Loss=231283.21188213775
Loop=4000 Loss=14019.069164400047
Loop=5000 Loss=18898.1528855852
Loop=6000 Loss=63363.246270909964
Loop=7000 Loss=57764.30638731056
Loop=8000 Loss=21374.125887804494
Loop=9000 Loss=980.3786965814796
Loop=10000 Loss=4972.182386471314
p=4
Loop=1000 Loss=1.9257352546918865e6
Loop=2000 Loss=1.9068147812419895e6
Loop=3000 Loss=1.879136151356047e6
Loop=4000 Loss=1.8431964827807134e6
Loop=5000 Loss=1.7987728502450513e6
Loop=6000 Loss=1.7454295328677325e6
Loop=7000 Loss=1.6827463787847036e6
Loop=8000 Loss=1.610446201562994e6
Loop=9000 Loss=1.528482124362498e6
Loop=10000 Loss=1.4371053342682128e6
p=5
Loop=1000 Loss=1.9332291461814602e6
Loop=2000 Loss=1.9329330522603847e6
Loop=3000 Loss=1.9324767525664547e6
Loop=4000 Loss=1.931840315207187e6
Loop=5000 Loss=1.9309874349720925e6
Loop=6000 Loss=1.9298718518708802e6
Loop=7000 Loss=1.9284392058824003e6
Loop=8000 Loss=1.9266274683366583e6
Loop=9000 Loss=1.9243669102391053e6
Loop=10000 Loss=1.9215799306837576e6
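Also, it’s a bit odd that some of the `s ==` choices are not real ODE solvers. For example, there is no `DPRK8`; the 8th-order Dormand–Prince Runge–Kutta–Nyström method is `DPRKN8`.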