I’m trying to do some ML experiments using Lux and Reactant. Without posting all of the code, I was wondering whether I’m using Reactant properly: I’m seeing compile times of around 3 minutes and gigabytes of memory usage for a single function, so compiling a second one can sometimes crash my WSL instance.
using Statistics: mean

function apply_no_prealloc(model, u_batch, parameters, state)
    (; rhel_layer, dropout, output_layer) = model
    layer_parameters = parameters.rhel_layer
    layer_state = state.rhel_layer
    layer_state = layer_no_prealloc(rhel_layer, u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1) # Concatenate input and hidden along the feature dim: (N_in + N, T, B)
    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)     # (N_in + N, 1, B)
        y = dropdims(y; dims=2) # (N_in + N, 1, B) -> (N_in + N, B)
        y, _ = output_layer(y, parameters.output_layer, state.output_layer) # (output_dim, B); the layer call returns (y, st)
    else
        # CODE
    end
    updated_state = (; state..., dropout = dropout_state)
    return y, updated_state
end
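For context, model here is just a container with the three fields destructured above, and parameters/state come from the usual Lux setup convention. Roughly (the RHEL layer is a custom layer whose definition I’m omitting; the types and sizes below are placeholders, not the real code):

struct RHELModel{L,D,O}
    rhel_layer::L   # custom leapfrog layer (definition omitted)
    dropout::D      # e.g. Lux.Dropout(0.1f0)
    output_layer::O # e.g. Lux.Dense(N_in + N => n_classes)
end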
function layer_no_prealloc(layer, input_sequence, params, state)
    # Input is (N_in, T, B)
    N_in, DT, B = size(input_sequence)
    (; W_in) = params
    # Apply the input layer to all timesteps and batches at once:
    # flatten time and batch into one axis for a single GEMM, then reshape back to (N, DT, B)
    W_dot_us = reshape(W_in * reshape(input_sequence, N_in, :), layer.N, DT, B)
    return apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
end
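As a sanity check that the single-GEMM reshape is equivalent to applying W_in separately at every timestep and batch element, here is a small self-contained snippet (arbitrary sizes; stack needs Julia 1.9+):

using Test

N_in, N, DT, B = 3, 4, 5, 2
W_in = randn(Float32, N, N_in)
u = randn(Float32, N_in, DT, B)
# One big GEMM over the flattened (time, batch) axis, then reshape back
flat = reshape(W_in * reshape(u, N_in, :), N, DT, B)
# Reference: apply W_in slice by slice
ref = stack([W_in * u[:, t, b] for t in 1:DT, b in 1:B]) # (N, DT, B)
@test flat ≈ ref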
function apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
    _, Tlen, _ = size(W_dot_us)
    for t in 1:Tlen
        @inline leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    end
    return state
end
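I suspect this loop is where the compile time goes: as far as I understand, tracing executes plain Julia control flow, so the loop over Tlen gets fully unrolled into the program and the IR grows with the sequence length. Reactant has a @trace macro (from ReactantCore) that is supposed to keep such loops as a single loop op in the compiled program; a sketch of what I think that would look like here (I haven’t verified that the in-place updates and views in the body trace cleanly under @trace):

function apply_leapfrog_traced!(layer, state, params, W_dot_us)
    _, Tlen, _ = size(W_dot_us)
    # Keep the time loop as one loop op instead of Tlen unrolled copies
    Reactant.@trace for t in 1:Tlen
        leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    end
    return state
end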
function leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    (; ϕ, π) = state
    (; W, b, α, β, ϵ) = params
    dt = layer.dt
    half_dt = dt / 2
    W_in_u = @view W_dot_us[:, t, :]
    ϕ_t = @view ϕ[:, t, :]
    # Half step in position
    kinetic_grad = grad_state_kinetic_hamiltonian(π)
    ϕhalf = ϕ_t .+ half_dt .* kinetic_grad
    # Full step in momentum
    grad = grad_state_potential_hamiltonian_no_prealloc(ϕhalf, W_in_u, W, b, α, β)
    π .-= dt .* grad
    # Second half step in position, then the constant momentum shift
    kinetic_grad = grad_state_kinetic_hamiltonian(π)
    ϕ_t .= ϕhalf .+ half_dt .* kinetic_grad
    π .+= ϵ
    return state
end
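Reading off the code, a single time step is the standard leapfrog scheme for H(ϕ, π) = K(π) + U(ϕ), with a constant shift ϵ added to the momentum at the end of each step:

$$
\begin{aligned}
\phi_{1/2} &= \phi_t + \tfrac{\Delta t}{2}\,\nabla_\pi K(\pi_t) \\
\pi' &= \pi_t - \Delta t\,\nabla_\phi U(\phi_{1/2}) \\
\phi_{t+1} &= \phi_{1/2} + \tfrac{\Delta t}{2}\,\nabla_\pi K(\pi') \\
\pi_{t+1} &= \pi' + \epsilon
\end{aligned}
$$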
function grad_state_potential_hamiltonian_no_prealloc(ϕ::AbstractArray{T}, W_in_u, W, b, alpha, beta) where T
    tβϕ = @. tanh(beta * ϕ)
    sech2β = @. T(1) - tβϕ^2    # sech²(βϕ), via sech² = 1 - tanh²
    sech2 = @. T(1) - tanh(ϕ)^2 # sech²(ϕ), computed once and reused below
    # Gradient of the norm term: ∂/∂ϕ [1/2 * α * ϕ²] = α * ϕ
    state_grad = @. alpha * ϕ
    # Gradient of the coupling term: W acts on tanh(βϕ) from both sides, hence W + Wᵀ
    W_tanh_ϕ = W * tβϕ
    WT_tanh_ϕ = W' * tβϕ
    @. state_grad += (T(1) / beta) * T(0.5) * sech2β * (W_tanh_ϕ + WT_tanh_ϕ)
    # Gradient of the bias term: ∂/∂ϕ [bᵀ tanh(ϕ)] = b * sech²(ϕ)
    @. state_grad += b * sech2
    # Gradient of the input term: ∂/∂ϕ [tanh(ϕ)ᵀ W_in_u] = W_in_u * sech²(ϕ)
    @. state_grad += W_in_u * sech2
    return state_grad
end
I’m compiling with

f = @compile apply_no_prealloc(model, u, params, state)

for a model I have. We had something similar in JAX before, and that was quite snappy to compile.
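In case it helps with diagnosing this, I can also dump the traced program with Reactant’s @code_hlo (assuming I’m using it right) to see how large the generated module is, e.g. whether the time loop really ends up fully unrolled:

hlo = @code_hlo optimize=false apply_no_prealloc(model, u, params, state)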