How can I speed up my Neural ODE?

Hi @ChrisRackauckas ,
I would like to ask how I can speed up my code.
I am trying to perform system identification of an HVAC system.
Here’s the code:

using CSV,DataFrames,DifferentialEquations,Plots,Flux,DiffEqFlux,JLD 
using Flux: throttle
using Statistics

Data = CSV.read("/home/gandolfo/scripts/Data/dataset.csv",DataFrame,type = Float64)

# Raw series: every column is sliced from row `time_start` to the last row.
obs = 4 # 4 measurements per hour
time_start = 1
rows = time_start:nrow(Data)

# Outdoor conditions
outT = Data.Outdoor_Temp[rows] # outdoor temperature
outRH = Data.Outdoor_RH[rows] # to be looked up
wind_speed = Data.Wind_Speed[rows]
wind_direction = Data.Wind_Direction[rows]
diff_solar_rad = Data.Diff_Solar_Rad[rows] # to be looked up (diffraction?)
direct_solar_rad = Data.Direct_Solar_Rad[rows] # direct solar radiation

# Setpoints
Htg_Sp = Data.Htg_SP[rows] # HVAC heating setpoint
Clg_Sp = Data.Clg_SP[rows] # HVAC cooling?

# Indoor temperatures
inT_avg = Data.Indoor_Temp_avg[rows] # average room temperature
indoor_temp1 = Data.Indoor_Temp1[rows] # room 1
indoor_temp2 = Data.Indoor_Temp2[rows] # room 2
indoor_temp3 = Data.Indoor_Temp3[rows] # room 3
indoor_temp4 = Data.Indoor_Temp4[rows] # room 4
indoor_temp5 = Data.Indoor_Temp5[rows] # room 5
indoor_temp_setpoint = Data.Indoor_Temp_Setpoint[rows] # target value to reach

# Occupancy and power
Occupancy_flag = Data.Occupancy_Flag[rows] # room occupancy indicator (boolean)
Coil_power = Data.Coil_Power[rows] # HVAC parameter
HVAC_power = Data.HVAC_Power[rows] # HVAC parameter

# System inlet / outlet
Sys_In_temp = Data.Sys_In_Temp[rows] # inlet temperature (sys -> control action)
Sys_In_Mdot = Data.Sys_In_Mdot[rows] # inlet flow rate (sys -> control action)
Sys_Out_temp = Data.Sys_Out_Temp[rows] # outlet temperature (control action -> sys)
Sys_Out_Mdot = Data.Sys_Out_Mdot[rows] # outlet flow rate (control action -> sys)

# Outdoor air and mixed air
OA_temp = Data.OA_Temp[rows] # building outdoor-air temperature
OA_Mdot = Data.OA_Mdot[rows] # flow rate of the same stream
MA_temp = Data.MA_Temp[rows] # mixed air temperature
MA_Mdot = Data.MA_Mdot[rows]; # mixed air flow

# Time axis: one unit step per raw sample, from `time_start` to `len`.
len = length(Sys_In_temp)
time_ = Float64[time_start,len]
t_ = collect(first(time_):last(time_));

# Sanity check: the sliced series must be as long as the full source column.
acquired = length(outT) == length(Data.Outdoor_Temp)
println(acquired ? "Data acquired" : "ERROR IN DATA ACQUISITION")

# Plot the average temperature together with the five individual rooms
fig1 = plot(t_, inT_avg; title = "Temperature ", xlabel = "Time", ylabel = "degC",
            label = "T_avarage")
plot!(t_, indoor_temp1; label = "Stanza1", xlabel = "Time", ylabel = "degC",
      size = (1200,800))
# Rooms 2-4 share exactly the same plot attributes
for (room, series) in zip(2:4, (indoor_temp2, indoor_temp3, indoor_temp4))
    plot!(t_, series; label = " Stanza$(room) ", xlabel = "Time", ylabel = "degC")
end
plot!(t_, indoor_temp5; label = " Stanza5 ", xlabel = "Time", ylabel = "degC",
      color = "darkblue")

# Block-average every signal over consecutive windows of `obs` raw samples.
#
# Comprehensions replace `push!` into untyped `[]` accumulators: the element
# type is taken from the computed values, so the results are concretely typed
# vectors instead of `Vector{Any}`. The duplicated initialisation of
# `Occupancy_flag_obs` is gone as well.
#
# NOTE(review): when `len` is an exact multiple of `obs` this range stops one
# window early (e.g. len = 8, obs = 4 yields only i = 1). It is kept unchanged
# so the lengths of all downstream series stay the same -- confirm the intent.
window_starts = 1:obs:len-obs

# Mean of `series` over each window of `width` samples starting at `starts`.
window_mean(series, starts, width) = [mean(series[i:i+width-1]) for i in starts]

outT_obs = window_mean(outT, window_starts, obs)
outRH_obs = window_mean(outRH, window_starts, obs)
wind_speed_obs = window_mean(wind_speed, window_starts, obs)
wind_direction_obs = window_mean(wind_direction, window_starts, obs)
diff_solar_rad_obs = window_mean(diff_solar_rad, window_starts, obs)
direct_solar_rad_obs = window_mean(direct_solar_rad, window_starts, obs)
inT_avg_obs = Float64[mean(inT_avg[i:i+obs-1]) for i in window_starts]
Sys_Out_temp_obs = window_mean(Sys_Out_temp, window_starts, obs)
MA_temp_obs = window_mean(MA_temp, window_starts, obs)
# The occupancy flag is sampled at the start of each window, not averaged.
Occupancy_flag_obs = [Occupancy_flag[i] for i in window_starts]

# Sanity check: the number of averaged windows must equal round(len/4).
expected_windows = round(len/4)
if length(outT_obs) != expected_windows
    println("ERROR IN AVARAGE")
else
    hour = round(obs/4)
    println("Avarage misuration around $(hour) hour")
end

# Normalise the values

# Min-max scaling: subtract the minimum of `series` and divide by its range.
function minmax_scale(series)
    lo = minimum(series)
    return (series .- lo) / (maximum(series) - lo)
end

inT_avg_mean = minmax_scale(inT_avg_obs)
Sys_Out_temp_mean = minmax_scale(Sys_Out_temp_obs)
MA_temp_mean = minmax_scale(MA_temp_obs)
outT_mean = minmax_scale(outT_obs)
outRH_mean = minmax_scale(outRH_obs)
wind_speed_mean = minmax_scale(wind_speed_obs)
wind_direction_mean = minmax_scale(wind_direction_obs)
diff_solar_rad_mean = minmax_scale(diff_solar_rad_obs)
direct_solar_rad_mean = minmax_scale(direct_solar_rad_obs)
# The occupancy flag is passed through without scaling.
Occupancy_flag_mean = Occupancy_flag_obs
# Control input: scaled outlet temperature minus scaled mixed-air temperature.
input_action_mean = Sys_Out_temp_mean .- MA_temp_mean;
n_obs = size(inT_avg_mean,1)
plot(inT_avg_mean; title = "Temperature ", xlabel = "Time", ylabel = "degC",
     label = "T_avarage", size = (1200,800))

# Sanity check of the min-max scaling: every value must lie in [0, 1].
#
# The original `for i in length(x)` iterated over the single integer
# `length(x)`, so only the last element was ever inspected. The bounds are
# inclusive because min-max scaling maps the minimum to exactly 0 and the
# maximum to exactly 1.
if all(x -> 0.0 <= x <= 1.0, inT_avg_mean)
    println("Correct normalization")
else
    println("ERROR IN NORMALIZATION")
end

# Training data: the first `Training_data` samples of every normalised series
day = 24
Training_data = day*40
train_rows = 1:Training_data

outT_train = outT_mean[train_rows]
outRH_train = outRH_mean[train_rows]
wind_speed_train = wind_speed_mean[train_rows]
wind_direction_train = wind_direction_mean[train_rows]
diff_solar_rad_train = diff_solar_rad_mean[train_rows]
direct_solar_rad_train = direct_solar_rad_mean[train_rows]
inT_avg_train = inT_avg_mean[train_rows]
Occupancy_flag_train = Occupancy_flag_mean[train_rows]
Sys_Out_temp_train = Sys_Out_temp_mean[train_rows]
MA_temp_train = MA_temp_mean[train_rows]
# Control input on the training window
input_action_train = Sys_Out_temp_train .- MA_temp_train;

# Validation data: everything after the training window
val_start = Training_data + 1

outT_val = outT_mean[val_start:end]
outRH_val = outRH_mean[val_start:end]
wind_speed_val = wind_speed_mean[val_start:end]
wind_direction_val = wind_direction_mean[val_start:end]
diff_solar_rad_val = diff_solar_rad_mean[val_start:end]
direct_solar_rad_val = direct_solar_rad_mean[val_start:end]
inT_avg_val = inT_avg_mean[val_start:end]
Occupancy_flag_val = Occupancy_flag_mean[val_start:end]
Sys_Out_temp_val = Sys_Out_temp_mean[val_start:end]
MA_temp_val = MA_temp_mean[val_start:end]
# Control input on the validation window
input_action_val = Sys_Out_temp_val .- MA_temp_val;

# Input functions for the NN (training set).
# Each accessor maps a time `t` to the sample whose index is `t` rounded to
# the nearest integer.
sample_at(series, t) = series[round(Int, t)]

inT_avg_(t) = sample_at(inT_avg_train, t)
Sys_Out_temp_(t) = sample_at(Sys_Out_temp_train, t)
MA_temp_(t) = sample_at(MA_temp_train, t)
outT_(t) = sample_at(outT_train, t)
outRH_(t) = sample_at(outRH_train, t)
wind_speed_(t) = sample_at(wind_speed_train, t)
wind_direction_(t) = sample_at(wind_direction_train, t)
diff_solar_rad_(t) = sample_at(diff_solar_rad_train, t)
direct_solar_rad_(t) = sample_at(direct_solar_rad_train, t)
Occupancy_flag_(t) = sample_at(Occupancy_flag_train, t)
input_action(t) = sample_at(input_action_train, t)

# Exogenous input vector fed to the network at time `t` (seven entries).
states(t) = [
    outT_(t),
    outRH_(t),
    wind_speed_(t),
    wind_direction_(t),
    diff_solar_rad_(t),
    Occupancy_flag_(t),
    input_action(t),
]

# Number of exogenous inputs
dim = size(states(1),1)

# Input functions for the NN (validation set): same nearest-index lookup as
# the training accessors, applied to the validation series.
val_sample_at(series, t) = series[round(Int, t)]

VinT_avg_val(t) = val_sample_at(inT_avg_val, t)
VSys_Out_temp_val(t) = val_sample_at(Sys_Out_temp_val, t)
VMA_temp_val(t) = val_sample_at(MA_temp_val, t)
VoutT_val(t) = val_sample_at(outT_val, t)
VoutRH_val(t) = val_sample_at(outRH_val, t)
Vwind_speed_val(t) = val_sample_at(wind_speed_val, t)
Vwind_direction_val(t) = val_sample_at(wind_direction_val, t)
Vdiff_solar_rad_val(t) = val_sample_at(diff_solar_rad_val, t)
Vdirect_solar_rad_val(t) = val_sample_at(direct_solar_rad_val, t)
VOccupancy_flag_val(t) = val_sample_at(Occupancy_flag_val, t)
Vinput_action_val(t) = val_sample_at(input_action_val, t)

# Exogenous input vector at time `t` on the validation window (seven entries).
states_val(t) = [
    VoutT_val(t),
    VoutRH_val(t),
    Vwind_speed_val(t),
    Vwind_direction_val(t),
    Vdiff_solar_rad_val(t),
    VOccupancy_flag_val(t),
    Vinput_action_val(t),
]

println("Dataset divided")

# Network: (dim + 1) inputs -> n (tanh) -> 15 (swish) -> 1 output.
# The extra input is the current ODE state that `T_fun` prepends.
n = 30
ann = FastChain(
    FastDense(dim + 1, n, tanh),
    FastDense(n, 15, swish),
    FastDense(15, 1),
)
nn_p = initial_params(ann)

println(string(nn_p))

"""
    T_fun(u, p, t)

Right-hand side of the neural ODE on the training set: evaluate `ann` with
parameters `p` on the current state `u[1]` stacked on top of `states(t)`.
"""
function T_fun(u,p,t)
    net_input = vcat(u[1], states(t))
    return ann(net_input, p)
end

# Initial condition: first normalised average indoor temperature
T0 = [first(inT_avg_mean)]
# Integrate over the training window
t_span = (t_[1], t_[Training_data])
# Parameters are left as `nothing` here; `pred` fills them in via `remake`.
nn = ODEProblem(T_fun, T0, t_span, nothing; saveat = 1)

"""
    pred(p)

Solve the neural ODE `nn` with parameters `p` using `Tsit5`, saving at every
point of `t_` that does not exceed the latest entry of `time_fit`.
"""
function pred(p)
    _prob = remake(nn, p = p)
    save_times = t_[t_ .<= time_fit[end]]
    return solve(_prob, Tsit5(); saveat = save_times, abstol = 1e-8, reltol = 1e-6)
end

"""
    loss_(p)

Return `(loss, y_tilde)`: the sum of squared errors between the prediction
`y_tilde = Array(pred(p))` and the matching prefix of `inT_avg_mean`, together
with the prediction itself.
"""
function loss_(p)
    y_tilde = Array(pred(p))
    n_saved = size(y_tilde, 2)
    # Transpose the target to a row so it lines up with the columns of y_tilde
    y = inT_avg_mean[1:n_saved]'
    return sum(abs2, y .- y_tilde), y_tilde
end

# Optimisation settings
n_epochs = 500
learning_rate = 0.01
data = Iterators.repeated((),n_epochs)
opt = ADAM(learning_rate)
# Typed containers instead of `[]` (which is a `Vector{Any}`): `time_fit[end]`
# is read on every call of `pred`, and `LOSS` is appended to by the callbacks.
LOSS = Float64[]      # loss history recorded by the callbacks
time_fit = Float64[]  # end time of each training sub-interval processed so far

# Callback for the ADAM stage: records and prints the loss, and returns `true`
# (stop this stage) once the loss is at or below `toll_break`.
cb_ADAM = function(p,loss,y_tilde)
    toll_break = 1.0
    push!(LOSS, loss)
    current = LOSS[end]
    println("LOSS: ", current)

    if !(current <= toll_break)
        # Still above the threshold: report progress and keep optimising
        iteration = round(time_fit[end]/sub_time)
        println("Training: ADAM λ = $(learning_rate)")
        println("sub_time: $(iteration)/$(n_sub_set)")
        return false
    end

    # Threshold reached: stop this optimisation stage
    println("Condition satisfied: Loss < $(toll_break)")
    println("CHANGE OPTIMIZER")
    return true
end

# Callback for the second stage: records and prints the loss, and returns
# `true` (stop this stage) once the loss is at or below `toll_break`.
cb_LBFGS = function(p,loss,y_tilde)
    toll_break = 0.05
    push!(LOSS, loss)
    current = LOSS[end]
    println("LOSS: ", current)

    if !(current <= toll_break)
        # Still above the threshold: report progress and keep optimising
        iteration = round(time_fit[end]/sub_time)
        println("Training: LBFGS")
        println("sub_time: $(iteration)/$(n_sub_set)")
        return false
    end

    # Threshold reached: stop this optimisation stage
    return true
end

# Incremental fitting: the fitted time horizon grows by `sub_time` on every
# pass, and each pass runs two optimisation stages in sequence.
Integration_time = t_[end] - t_[1]
n_sub_set = 20
sub_time = Integration_time/n_sub_set
@time begin

    for t_fit in sub_time:sub_time:t_[end]
        push!(time_fit,t_fit)

        # `sciml_train` returns a result object; the parameters are in `.minimizer`.
        stage1 = DiffEqFlux.sciml_train(loss_,nn_p,opt,cb = throttle(cb_ADAM,50),maxiters = n_epochs)
        # NOTE(review): this stage still receives `opt` (ADAM) although its
        # callback is named LBFGS; a real switch needs an LBFGS optimiser
        # object, which the imports of this script do not obviously provide.
        stage2 = DiffEqFlux.sciml_train(loss_,stage1.minimizer,opt,cb = throttle(cb_LBFGS,50),maxiters = n_epochs)

        # Keep `nn_p` a plain parameter vector so the next pass can start from
        # it. Previously it was left as a result object, which the next call
        # of `sciml_train` would have received as its initial parameters.
        global nn_p = stage2.minimizer
    end
end
# Trained parameters, used by the plotting, validation and saving code below
# (`par` was referenced there but never assigned).
par = nn_p
println("TRAINING COMPLETE")

# Plot the fit over the complete training run
y_tilde = Array(pred(par))
n_fit = size(y_tilde, 2)
y = inT_avg_mean[1:n_fit]

fig1 = plot(t_[1:n_fit], y; linewidth = 3, color = "black", label = " T_avarage ",
            title = "500 epochs lr = 0.01", size = (2000,1200))
# Transpose the prediction so the saved time points run along the plot's x-axis
Plots.plot!(y_tilde'; label = " T_avarage_nn ", color = "red")

# Validation: time grid with one unit step per validation sample
t_end = Float64(length(inT_avg_val))
t_val = (1.0, t_end)
tval_ = collect(first(t_val):last(t_val))

"""
    T_fun_val(u, p, t)

Right-hand side of the neural ODE on the validation set: evaluate `ann` with
parameters `p` on the current state `u[1]` stacked on top of `states_val(t)`.
"""
function T_fun_val(u,p,t)
    net_input = vcat(u[1], states_val(t))
    return ann(net_input, p)
end

# Simulate the trained model on the validation window, starting from the first
# validation sample, and compare it with the measured series.
nn_val = ODEProblem(T_fun_val,[inT_avg_val[1]],t_val, par)

y_val = Array(solve(nn_val, Tsit5(), saveat = tval_, abstol = 1e-8, reltol = 1e-6))
y_true = inT_avg_val

fig2 = plot(y_true,linewidth = 3, color = "black",label = " T_avarage ")
Plots.plot!(y_val',label = " T_avarage_nn ",color = "red")

# `y_val` is a 1×N matrix (one state, N save points) while `y_true` is an
# N-vector. Broadcasting them directly builds an N×N matrix of all pairwise
# differences, so the prediction is flattened first to get the element-wise
# sum of squared errors.
Loss_validation = sum(abs2, y_true .- vec(y_val))

# Save results: trained parameters plus training and validation losses
# (earlier runs were stored under .../Result/Neural_v1.1/)
Parameters = par
LOSS_test = LOSS
Loss_validation
save("/home/gandolfo/scripts/Result/Neural_v1.2/Data.jld",
     "Parameters", Parameters,
     "LOSS_test", LOSS_test,
     "Loss_validation", Loss_validation)

# Save figures: training fit, validation fit and loss history
fig_loss = plot(LOSS)
for (figure, target) in (
    (fig1, "/home/gandolfo/scripts/Result/Neural_v1.2/Training.png"),
    (fig2, "/home/gandolfo/scripts/Result/Neural_v1.2/Validation.png"),
    (fig_loss, "/home/gandolfo/scripts/Result/Neural_v1.2/Loss.png"),
)
    savefig(figure, target)
end

The aim is to train the Neural ODE with the ADAM optimizer and, once a loss condition is satisfied, switch the optimizer to LBFGS.
When the second optimizer reaches a good loss value, the optimization stops.
I’ve read something at https://diffeqflux.sciml.ai/stable/sciml_train/ that sounds like what I want to do, but I don’t understand what “deterministic loss function” means.

In conclusion:

  1. Is there a way to speed up my code and avoid the for loop in which I change the optimizer?
  2. Is there a way to make this script leaner?

Thank you in advance.

Have you read through the Performance Tips?

I think there are many points that would apply

3 Likes

What do you mean by that? That could mean plenty of things.

Did you check whether EnzymeVJP is compatible? Reduce allocations? Fix the types to use less globals? Etc.

I tried to create a chain of optimizers, but it doesn’t work very well.
Any advice on improving the convergence of the Neural ODE?

Did you go through the Julia performance tips and the adjoint choices like mentioned above? Those still do not look addressed in the code you show.