I’m new to Julia and am trying to implement the algorithms from the paper "Direct Runge-Kutta Discretization Achieves Acceleration", which requires solving a second-order ODE of the form

\ddot{x}(t) + \frac{2p+1}{t} \dot{x}(t) + p^2 t^{p-2} \nabla f(x(t)) = 0,

where \nabla f(x(t)) is the gradient of the loss function. I tried to write the code with DifferentialEquations.jl as below:
```julia
using Random, Distributions, DifferentialEquations, Plots, LinearAlgebra

# Data generation
Random.seed!(12)
num_observations = 3
clusters = 1
dim = 10
spread = 10
margin = rand(0:10, dim) * 50
C = Diagonal(ones(dim))
x1, x2 = rand(MvNormal(rand(0:spread, dim), C), num_observations),
         rand(MvNormal(rand(0:spread, dim), C), num_observations)
for i in 1:clusters
    if i == 1
        global x1 = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
        global x2 = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
    else
        x1_temp = transpose(rand(MvNormal(rand(0:spread, dim) + margin, C), num_observations))
        global x1 = vcat(x1, x1_temp)
        x2_temp = transpose(rand(MvNormal(rand(0:spread, dim), C), num_observations))
        global x2 = vcat(x2, x2_temp)
    end
end
simulated_separableish_features = vcat(x1, x2)
simulated_labels = vcat(zeros(size(x1, 1)), ones(size(x2, 1)))   # one label per observation
simulated_labels = reshape(simulated_labels, (length(simulated_labels), 1))
intercept = ones(size(simulated_separableish_features, 1))       # column of ones
simulated_separableish_features = hcat(intercept, simulated_separableish_features)
dim += 1

# Loss and gradient of loss
function loss(X, Y, w)
    pred = Y - X * w
    return dot(pred, pred)            # scalar squared error
end

function grad(X, Y, w)
    pred = X * w - Y
    return transpose(X) * pred
end

f  = x -> loss(simulated_separableish_features, simulated_labels, x)
df = x -> grad(simulated_separableish_features, simulated_labels, x)

# Simulate ODE flow: the right-hand side returns the acceleration x''(t)
function ODE(du, u, p, t)
    return -(2 * p + 1) / t * du - p^2 * t^(p - 2) * df(u)
end

function simulate(p, s, N, period, step, xinit, vinit)
    h = step
    if s == 1
        alg = Euler()
    elseif s == 2
        alg = SSPRK2()
    elseif s == 4
        alg = ERKN4()
    elseif s == 8
        alg = DPRKN8()
    else
        return nothing
    end
    x = xinit
    v = vinit
    loop = Int(N / period)
    losses = zeros(loop)
    soln = nothing
    for j in 1:loop
        losses[j] = f(x)
        if j % 1000 == 0
            println("Loop=" * string(j) * "\t Loss=" * string(losses[j]))
        end
        t_start = 0.01 + (j - 1) * h * period
        t_end   = 0.01 + j * h * period
        tspan   = (t_start, t_end)
        prob = SecondOrderODEProblem(ODE, v, x, tspan, p)
        soln = solve(prob, alg, dt = h, adaptive = false)
        # each saved state is an ArrayPartition: x[1] is the velocity, x[2] the position
        v = soln.u[end].x[1]
        x = soln.u[end].x[2]
    end
    return losses, soln
end

N = 100000
truncate = 0
period = 10
pvalues = [2, 3, 4, 5]
# nrange = period * range(1, Int(N/period)) - period + 1
svals = [4, 4, 4, 4]
steps = [1e-7, 1e-7, 1e-7, 1e-7]
weights = reshape(rand(dim), (dim, 1))
dweights = zeros(dim, 1)
losslst = zeros(length(svals), Int(N / period))
plot()                                    # start an empty plot so plot! below can add to it
for i in 1:length(svals)
    s = svals[i]
    p = pvalues[i]
    println("p=" * string(p))
    losses, soln = simulate(p, s, N, period, steps[i], weights, dweights)
    losslst[i, :] = losses
    plot!(losses, label = "p = " * string(pvalues[i]), xaxis = :log, yaxis = :log)
end
```
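For reference, my understanding of how the position and velocity are laid out in a `SecondOrderODEProblem` solution (which is what the `v`/`x` extraction inside `simulate` relies on) comes from this small standalone check on a plain harmonic oscillator. The oscillator and the `DPRKN6` solver here are just toy choices for illustration, unrelated to the regression problem above:

```julia
# Toy layout check: u'' = -u (harmonic oscillator), nothing to do with the data above
using DifferentialEquations

accel(du, u, p, t) = -u        # out-of-place form: returns the acceleration u''(t)

du0 = [0.0]                    # initial velocity
u0  = [1.0]                    # initial position
prob = SecondOrderODEProblem(accel, du0, u0, (0.0, 2pi))
sol  = solve(prob, DPRKN6())

# Each saved state is an ArrayPartition:
#   sol.u[end].x[1]  -> velocity at the final time (≈ 0 for this oscillator)
#   sol.u[end].x[2]  -> position at the final time (≈ 1 for this oscillator)
@show sol.u[end].x[1] sol.u[end].x[2]
```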
But the weight (x in the simulate function) is not updated as expected and the loss stays essentially unchanged across iterations. I cannot figure out what is wrong and would appreciate any advice.
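In case the problem is in f/df themselves rather than in the ODE integration, here is a plain gradient-descent loop I would compare against, reusing the f, df and weights defined above (the step size 1e-6 and the 1000 iterations are arbitrary guesses, not values from the paper):

```julia
# Hypothetical baseline: a few plain gradient-descent steps on the same data,
# reusing f, df and weights from the code above.
let w = copy(weights)
    for k in 1:1000
        w = w - 1e-6 * df(w)              # arbitrary fixed step size
        if k % 200 == 0
            println("GD step " * string(k) * "\t Loss=" * string(f(w)))
        end
    end
end
```

If this baseline decreases the loss while the ODE-based version does not, I would take that as a sign the issue is in the ODE setup rather than in the data or the gradient.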