Hi all,
I’m new to Julia and want to do gradient-based optimization with a time-stepping model. The model is fairly simple; here is a link to the Python version: tonic/snow17.py at master · UW-Hydro/tonic · GitHub
My question is about the optimization. Say I use the Adam optimizer to fit some parameters; that requires an outer loop over epochs. Meanwhile, the model itself iterates along the time dimension. With 150 epochs and 365 daily time steps, that is 150 × 365 = 54,750 sequential time steps. I have a working PyTorch version, but it is really slow, even when I run only the forward pass without any gradient-based optimization.
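To make the cost concrete, here is a minimal stdlib sketch of the loop structure. `run_epochs` and the additive toy `step` are hypothetical stand-ins, not the real Snow-17 update; the point is only that every epoch replays all 365 days in order.

```python
def run_epochs(n_epochs, forcings, step):
    """Mirror the nested loops: optimizer epochs outside, time steps inside.

    Returns the final state of the last epoch and the number of sequential
    state updates performed in total.
    """
    n_updates = 0
    state = 0.0
    for _ in range(n_epochs):
        state = 0.0  # the state is re-initialised at the start of every forward pass
        for x in forcings:
            # day t needs the state from day t-1, so the days run one after another
            state = step(state, x)
            n_updates += 1
    return state, n_updates


# Toy run: 150 epochs over 365 days with a trivial additive update.
final_state, n_updates = run_epochs(150, [1.0] * 365, lambda s, x: s + x)
print(n_updates)    # 54750
print(final_state)  # 365.0
```

Because day t depends on day t-1, the inner loop cannot simply be vectorized over time, so in PyTorch each of those 54,750 steps launches its own small tensor operations (and records them for autograd when gradients are needed).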
My question: is it possible to speed up the time-step iteration with Julia? Here is some pseudocode from my PyTorch version:
```python
import torch

# x_d (T, 2) forcings, x_d_norm (normalized forcings), x_s (static features),
# real_swe, scf, model, loss_fn, optimizer and device are defined elsewhere.
for epoch in range(150):  # epoch iteration for optimization
    # forward pass
    T = x_d.shape[0]  # number of time steps
    # state variables; initial values are constants, gradients flow through params
    ait = torch.zeros(1, device=device)      # antecedent temperature index
    deficit = torch.zeros(1, device=device)  # heat deficit
    w_i = torch.zeros(1, device=device)      # water equivalent of the ice portion, (batch,)
    w_q = torch.zeros(1, device=device)      # liquid water held by the snow, (batch,)
    dt = 24  # time step length in hours (daily time series)
    prec = x_d[:, 0]  # precipitation
    tair = x_d[:, 1]  # air temperature
    ele = x_s[-1]     # elevation from static features
    p_atm = 33.86 * (29.9 - (0.335 * ele / 100) +
                     (0.00022 * ((ele / 100) ** 2.4))).float()  # (batch,)
    params = model(x_d_norm)
    swe_steps, outflow_steps = [], []
    for i in range(T):
        # some parameters to be optimized
        fracsnow = params[i, 0]
        plwhc = 0.04
        precip = prec[i]      # precipitation at this time step (mm), (batch,)
        t_air_mean = tair[i]  # (batch,)
        fracrain = 1.0 - fracsnow
        pn = precip * fracsnow * scf
        w_i = w_i + pn
        rain = fracrain * precip
        # ... some equations to calculate snow water equivalent (swe) and outflow (e) ...
        # Collect per-step results in lists: writing in place into a leaf tensor
        # created with requires_grad=True raises an autograd error, and the
        # original model_swe[:, i] indexed a 1-D tensor with two indices.
        swe_steps.append(swe)
        outflow_steps.append(e)
    # finish time iteration
    model_swe = torch.cat(swe_steps)    # shape (T,), like the preallocated version
    outflow = torch.cat(outflow_steps)  # shape (T,)
    loss = loss_fn(model_swe, real_swe)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
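For reference, the per-day accumulation steps from the pseudocode above can be written as plain functions of (state, forcing, parameters), which is the form that would be easiest to port to a compiled loop. This is a sketch only: `atmospheric_pressure`, `partition_precip` and `accumulate_snow` are hypothetical names, `scf` is passed in as a plain number, and the rest of Snow-17 (melt, heat deficit, liquid-water routing) is left out.

```python
def atmospheric_pressure(ele):
    """Atmospheric pressure from elevation `ele`, same expression as p_atm above."""
    return 33.86 * (29.9 - (0.335 * ele / 100) + (0.00022 * ((ele / 100) ** 2.4)))


def partition_precip(w_i, precip, fracsnow, scf):
    """One day of the snow/rain split from the loop body.

    Returns the updated ice water equivalent w_i and the rain amount.
    """
    fracrain = 1.0 - fracsnow
    pn = precip * fracsnow * scf  # snowfall added to the pack (scf scales snowfall)
    rain = fracrain * precip
    return w_i + pn, rain


def accumulate_snow(prec, fracsnow, scf, w_i=0.0):
    """Run the split over all days; each day starts from the previous day's w_i.

    Results are appended to lists instead of being written in place into
    preallocated arrays.
    """
    w_i_series, rain_series = [], []
    for precip, fs in zip(prec, fracsnow):
        w_i, rain = partition_precip(w_i, precip, fs, scf)
        w_i_series.append(w_i)
        rain_series.append(rain)
    return w_i_series, rain_series


# Three days of hypothetical forcing; fracsnow varies per day like params[i, 0].
w_i_series, rain_series = accumulate_snow([10.0, 0.0, 4.0], [1.0, 1.0, 0.5], scf=1.0)
print(w_i_series)   # [10.0, 10.0, 12.0]
print(rain_series)  # [0.0, 0.0, 2.0]
```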
Any help would be appreciated!
Thanks!