Dear Community,

I'm new to Julia and am trying to implement the algorithms from "Direct Runge-Kutta Discretization Achieves Acceleration", for which I need to solve a second-order ODE

\ddot{x}(t) + \frac{2p+1}{t} \dot{x}(t) + p^2 t^{p-2} \nabla f(x(t)) = 0

where \nabla f(x(t)) is the gradient of the loss function. I wrote the following code with DifferentialEquations.jl:
using Random, Distributions, DifferentialEquations, Plots, LinearAlgebra
# Data generation
Random.seed!(12)
num_observations = 3
clusters = 1
dim = 10
spread = 10
margin = rand(0:10, dim) * 50      # offset that separates the two classes
C = Diagonal(ones(dim))            # identity covariance
# initial draws; both are overwritten in the first iteration of the loop below
x1 = rand(MvNormal(rand(0:spread, dim), C), num_observations)
x2 = rand(MvNormal(rand(0:spread, dim), C), num_observations)
for i in 1:clusters
    global x1, x2   # needed at top level when this runs as a script
    if i == 1
        x1 = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
        x2 = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
    else
        x1_temp = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
        x1 = vcat(x1, x1_temp)   # `cat` needs a `dims` argument; vcat stacks observations by row
        x2_temp = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
        x2 = vcat(x2, x2_temp)   # was `cat(x2_temp)`, which discarded the earlier draws
    end
end
simulated_separableish_features = vcat(x1, x2)
simulated_labels = vcat(zeros(size(x1, 1)), ones(size(x2, 1)))
simulated_labels = reshape(simulated_labels, (length(simulated_labels), 1))
intercept = ones(size(simulated_separableish_features, 1), 1)
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
dim += 1   # account for the intercept column
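As a quick shape check (my own addition, just to confirm the construction above): with num_observations = 3, clusters = 1, and the intercept column, the design matrix should be 6×11 and the labels 6×1.

@assert size(simulated_separableish_features) == (2 * num_observations, dim)
@assert size(simulated_labels) == (2 * num_observations, 1)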
# Loss and gradient of loss (least squares)
function loss(X, Y, w)
    pred = Y - X * w
    return pred' * pred   # 1×1 matrix for a column-vector w; index with [1] for the scalar
end

function grad(X, Y, w)
    pred = X * w - Y
    return transpose(X) * pred
end

f = x -> loss(simulated_separableish_features, simulated_labels, x)
df = x -> grad(simulated_separableish_features, simulated_labels, x)
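To rule out a bad gradient, here is a quick central finite-difference check (my own sanity test, not from the paper; w0, eps_fd, g_fd are names I made up). The printed deviation should be small relative to the gradient's norm:

w0 = reshape(rand(dim), (dim, 1))
eps_fd = 1e-6
g_fd = zeros(dim, 1)
for k in 1:dim
    e = zeros(dim, 1)
    e[k] = eps_fd
    g_fd[k] = (f(w0 + e)[1] - f(w0 - e)[1]) / (2 * eps_fd)
end
println(maximum(abs.(g_fd - df(w0))))   # should be tiny compared to norm(df(w0))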
# Simulate ODE flow. Out-of-place SecondOrderODEProblem form: given (x', x, p, t), return x''.
function ODE(du, u, p, t)
    return -(2 * p[1] + 1) / t * du - p[1]^2 * t^(p[1] - 2) * df(u)
end
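For reference, my understanding of the SecondOrderODEProblem convention (a minimal sketch with a toy oscillator, unrelated to the problem above): the dynamics function returns x'', and each saved state stacks [x'; x], so the velocities come first when indexing the solution.

osc(du, u, p, t) = -u   # x'' = -x, so x(t) = cos(t) with the initial values below
prob0 = SecondOrderODEProblem(osc, [0.0], [1.0], (0.0, 2pi))
sol0 = solve(prob0, DPRKN6())
println(sol0[1, end], "  ", sol0[2, end])   # velocity ≈ 0, position ≈ 1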
function simulate(p, s, N, period, step, xinit, vinit)
    h = step
    if s == 1
        alg = Euler()
    elseif s == 2
        alg = SSPRK22()   # the 2-stage, 2nd-order SSP method (there is no SSPRK2)
    elseif s == 4
        alg = ERKN4()
    elseif s == 8
        alg = DPRKN8()    # was DPRK8(), which is not a DifferentialEquations solver
    else
        return nothing
    end
    x = xinit
    v = vinit
    loop = Int(N / period)
    losses = zeros(loop)
    soln = nothing
    for j in 1:loop
        losses[j] = f(x)[1]
        if j % 1000 == 0
            println("Loop=" * string(j) * "\t Loss=" * string(losses[j]))
        end
        t_start = 0.01 + (j - 1) * h * period
        t_end = 0.01 + j * h * period
        tspan = (t_start, t_end)
        prob = SecondOrderODEProblem(ODE, v, x, tspan, p)
        soln = solve(prob, alg, dt=h, adaptive=false)
        # entries 1:dim of a saved state are the velocities, dim+1:2dim the positions;
        # reshape back to (dim, 1) columns so f and df keep receiving the same shapes
        v = reshape(soln[1:dim, period+1], (dim, 1))
        x = reshape(soln[dim+1:2*dim, period+1], (dim, 1))
    end
    return losses, soln
end
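A minimal smoke test of simulate (my own, with fresh initial values under hypothetical names) that runs ten short solves and prints the first and last recorded loss:

w_test = reshape(rand(dim), (dim, 1))
v_test = zeros(dim, 1)
losses_test, _ = simulate([2], 4, 100, 10, 1e-7, w_test, v_test)
println(losses_test[1], "  ", losses_test[end])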
N = 100000                          # total number of solver steps per run
truncate = 0
period = 10                         # steps per solve call
pvalues = [2, 3, 4, 5]
#nrange = period * range(1, Int(N/period)) - period + 1
svals = [4, 4, 4, 4]                # solver selector passed to simulate
steps = [1e-7, 1e-7, 1e-7, 1e-7]    # fixed step size h for each run
weights = reshape(rand(dim), (dim, 1))   # initial position x
dweights = zeros(dim, 1)                 # initial velocity v
losslst = zeros(length(svals), Int(N / period))
plot()   # start a figure so plot!() inside the loop has something to add to
for i in 1:length(svals)
    s = svals[i]
    p = pvalues[i]
    println("p=" * string(p))
    p = [p]
    losses, soln = simulate(p, s, N, period, steps[i], weights, dweights)
    losslst[i, :] = losses
    plot!(losses, label="p= " * string(pvalues[i]), xaxis=:log, yaxis=:log)
end
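One note on the plotting: when this runs as a script (rather than in the REPL or a notebook), the figure also has to be shown or saved explicitly, e.g.:

display(current())   # Plots.current() returns the active figure
# savefig("losses_by_p.png")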
But the weights (x in the simulate function) are not updated as I expect, and the loss stays essentially unchanged across iterations. I cannot figure out what is wrong and would appreciate any advice.