Whys is Enzyme code running slow

Title: Performance Issues with Enzyme Autodiff in Julia Code


I am new to Enzyme. When I run this sample code, the autodiff function is taking a lot of time. Where am I going wrong?

using Enzyme

function Convec_diffuse_derivative(phi_state, U_state, V_state, dx, dy, nu)
    # Find the 2D convection diffusion derivative using first order upwind
    m, n = size(phi_state);
    der_state = zeros(m, n);

    for i = 2:(m - 1), j = 2:(n - 1)
        # Let us define some quantities of interest
        U = U_state[i, j];
        V = V_state[i, j];
        dx_f = dx[i];
        dx_b = dx[i - 1];
        dy_f = dy[j];
        dy_b = dy[j - 1];
        phi_c = phi_state[i, j];
        phi_e = phi_state[i + 1, j];
        phi_w = phi_state[i - 1, j];
        phi_n = phi_state[i, j + 1];
        phi_s = phi_state[i, j - 1];

        # Now writing the advection terms
        con_x = 0.0;
        if U > 0.0
            con_x = U * (phi_c - phi_w) / dx_b;
        elseif U < 0.0
            con_x = U * (phi_e - phi_c) / dx_f;
        end

        con_y = 0.0;
        if V > 0.0
            con_y = V * (phi_c - phi_s) / dy_b;
        elseif v < 0.0
            con_y = V * (phi_n - phi_c) / dy_f;
        end
        # Now writing the diffusive terms
        dphidx_e = (phi_e - phi_c) / dx_f;
        dphidx_w = (phi_c - phi_w) / dx_b;
        D_xx = (dphidx_e - dphidx_w) / (0.5 * (dx_f + dx_b));
        # Now using both the terms
        dphidy_n = (phi_n - phi_c) / dy_f;
        dphidy_s = (phi_c - phi_s) / dy_b;
        D_yy = (dphidy_n - dphidy_s) / (0.5 * (dy_f + dy_b));

        der_state[i, j] = nu * (D_xx + D_yy) - con_x - con_y;

    end
    return der_state;
end

function Euler_integrator(phi, timestep, U_vel, V_vel, dx, dy, nu)
    d_phi = timestep * Convec_diffuse_derivative(phi, U_vel, V_vel, dx, dy, nu);
    return phi + d_phi;
end

function RK4_integrator(phi, timestep, U_vel, V_vel, dx, dy, nu)
    k1 = Convec_diffuse_derivative(phi, U_vel, V_vel, dx, dy, nu);
    k2 = Convec_diffuse_derivative(phi + 0.5 * timestep * k1, U_vel, V_vel, dx, dy, nu);
    k3 = Convec_diffuse_derivative(phi + 0.5 * timestep * k2, U_vel, V_vel, dx, dy, nu);
    k4 = Convec_diffuse_derivative(phi + timestep * k3, U_vel, V_vel, dx, dy, nu);
    K = (timestep / 6.0) * (k1 + k4 + 2 * (k2 + k3))
    return phi + K
end

# Now let us test the solver
# Let us write the meshgrid_function

function mesh_grid_ij(x, y)
    m = length(x);
    n = length(y);
    xx = zeros(Float64, m, n);
    yy = zeros(Float64, m, n);
    # Now looping for populating the values
    for i = 1:m
        yy[i, :] .= y;
    end

    for j = 1:n
        xx[:, j] .= x;
    end
    return xx, yy;
end

function init_state(xx, yy, phi)
    m, n = size(phi)
    phi_init = zeros(m, n)
    # We want a part of the phi to be non zeros
    case1 = xx .< 0.75;
    case2 = xx .> 0.25;
    case3 = yy .< 0.85;
    case4 = yy .> 0.25;
    cases = case1 .& case2 .& case3 .& case4;
    phi_init[cases] .= 2.0;

    return copy(phi_init);
end

x = [i for i = 0:0.01:1]
y = [i for i = 0:0.01:2]
dx = x[2:end] - x[1:end - 1]
dy = y[2:end] - y[1:end - 1]
xx, yy = mesh_grid_ij(x, y);
phi = zeros(Float64, length(x), length(y));
dphi = zeros(Float64, length(x), length(y));
curr_phi = init_state(xx, yy, phi);
nu = 0.0;
n_iter = 1000;
timestep = 1e-4;

function simulation(omega, xx, yy, curr_phi, ind_x, ind_y, nu, n_iter, timestep)
    uu = (-1.0 * omega) * yy
    vv = (1.0 * omega) * xx
    phi = zeros(Float64, length(x), length(y));

    # Let us define the simulation parameters
    for i = 1:n_iter
        curr_phi = Euler_integrator(curr_phi, timestep, uu, vv, dx, dy, nu)
    end
    return curr_phi[ind_x, ind_y]
end

val = simulation(3.0, xx, yy, curr_phi, 40, 40, nu, n_iter, timestep)
# First finding the gradient using autodiff
result = first(autodiff(Reverse, simulation, Active, Active(3.0), Const(xx), Const(yy), Duplicated(curr_phi, dphi), Const(40), Const(40), Const(nu), Const(n_iter), Const(timestep)))

# Now finding out the gradient manually
result = (simulation(3.0 + 0.0001, xx, yy, curr_phi, 40, 40, nu, n_iter, timestep) - simulation(3.0, xx, yy, curr_phi, 40, 40, nu, n_iter, timestep)) / 0.0001

The first time you run any Julia function, you pay the price of just-in-time compilation. This is especially true for Enzyme which does lots of compilation work to differentiate through your code. Is it much faster on the second run?

It seems to get stuck. I also reduced the number of iterations to 2 and reduced the size of the array if that does any good but the problem still persists. Does the code implementation seem to have any bugs?

It seems at least part of it is due to your code itself being slightly inefficient.
You can read the performance tips to learn how to speed it up.

In particular, type instability is really bad for Julia in general and Enzyme in particular.
There were two major inference issues in your code:

  • the variables x, y, dx and dy are global variables, you need to give them as arguments to simulation
  • the function Convec_diffuse_derivative seems to contain an undefined variable little v in this block:

Fixing both of these things got me a 4x speedup on the autodiff, from 5s to around 1.5s. You can probably get much better if you improve your function efficiency, for instance by reducing allocations.

My corrected code

using Enzyme
using JET

function Convec_diffuse_derivative(phi_state, U_state, V_state, dx, dy, nu)
    # Find the 2D convection diffusion derivative using first order upwind
    m, n = size(phi_state)
    der_state = zeros(m, n)

    for i = 2:(m-1), j = 2:(n-1)
        # Let us define some quantities of interest
        U = U_state[i, j]
        V = V_state[i, j]
        dx_f = dx[i]
        dx_b = dx[i-1]
        dy_f = dy[j]
        dy_b = dy[j-1]
        phi_c = phi_state[i, j]
        phi_e = phi_state[i+1, j]
        phi_w = phi_state[i-1, j]
        phi_n = phi_state[i, j+1]
        phi_s = phi_state[i, j-1]

        # Now writing the advection terms
        con_x = 0.0
        if U > 0.0
            con_x = U * (phi_c - phi_w) / dx_b
        elseif U < 0.0
            con_x = U * (phi_e - phi_c) / dx_f
        end

        con_y = 0.0
        if V > 0.0
            con_y = V * (phi_c - phi_s) / dy_b
        elseif V < 0.0
            con_y = V * (phi_n - phi_c) / dy_f
        end
        # Now writing the diffusive terms
        dphidx_e = (phi_e - phi_c) / dx_f
        dphidx_w = (phi_c - phi_w) / dx_b
        D_xx = (dphidx_e - dphidx_w) / (0.5 * (dx_f + dx_b))
        # Now using both the terms
        dphidy_n = (phi_n - phi_c) / dy_f
        dphidy_s = (phi_c - phi_s) / dy_b
        D_yy = (dphidy_n - dphidy_s) / (0.5 * (dy_f + dy_b))

        der_state[i, j] = nu * (D_xx + D_yy) - con_x - con_y

    end
    return der_state
end

function Euler_integrator(phi, timestep, U_vel, V_vel, dx, dy, nu)
    d_phi = timestep * Convec_diffuse_derivative(phi, U_vel, V_vel, dx, dy, nu)
    return phi + d_phi
end

function RK4_integrator(phi, timestep, U_vel, V_vel, dx, dy, nu)
    k1 = Convec_diffuse_derivative(phi, U_vel, V_vel, dx, dy, nu)
    k2 = Convec_diffuse_derivative(phi + 0.5 * timestep * k1, U_vel, V_vel, dx, dy, nu)
    k3 = Convec_diffuse_derivative(phi + 0.5 * timestep * k2, U_vel, V_vel, dx, dy, nu)
    k4 = Convec_diffuse_derivative(phi + timestep * k3, U_vel, V_vel, dx, dy, nu)
    K = (timestep / 6.0) * (k1 + k4 + 2 * (k2 + k3))
    return phi + K
end

# Now let us test the solver
# Let us write the meshgrid_function

function mesh_grid_ij(x, y)
    m = length(x)
    n = length(y)
    xx = zeros(Float64, m, n)
    yy = zeros(Float64, m, n)
    # Now looping for populating the values
    for i = 1:m
        yy[i, :] .= y
    end

    for j = 1:n
        xx[:, j] .= x
    end
    return xx, yy
end

function init_state(xx, yy, phi)
    m, n = size(phi)
    phi_init = zeros(m, n)
    # We want a part of the phi to be non zeros
    case1 = xx .< 0.75
    case2 = xx .> 0.25
    case3 = yy .< 0.85
    case4 = yy .> 0.25
    cases = case1 .& case2 .& case3 .& case4
    phi_init[cases] .= 2.0

    return copy(phi_init)
end

x = [i for i = 0:0.01:1]
y = [i for i = 0:0.01:2]
dx = x[2:end] - x[1:end-1]
dy = y[2:end] - y[1:end-1]
xx, yy = mesh_grid_ij(x, y);
phi = zeros(Float64, length(x), length(y));
dphi = zeros(Float64, length(x), length(y));
curr_phi = init_state(xx, yy, phi);
nu = 0.0;
n_iter = 1000;
timestep = 1e-4;

function simulation(omega, x, y, dx, dy, xx, yy, curr_phi, ind_x, ind_y, nu, n_iter, timestep)
    uu = (-1.0 * omega) * yy
    vv = (1.0 * omega) * xx
    phi = zeros(Float64, length(x), length(y))

    # Let us define the simulation parameters
    for i = 1:n_iter
        curr_phi = Euler_integrator(curr_phi, timestep, uu, vv, dx, dy, nu)
    end
    return curr_phi[ind_x, ind_y]
end

val = simulation(3.0, x, y, dx, dy, xx, yy, curr_phi, 40, 40, nu, n_iter, timestep)

JET.@test_opt simulation(3.0, x, y, dx, dy, xx, yy, curr_phi, 40, 40, nu, n_iter, timestep)  # this means your function is type-stable

@time autodiff(
    Reverse,
    simulation,
    Active,
    Active(3.0),
    Duplicated(x, make_zero(x)),
    Duplicated(y, make_zero(y)),
    Duplicated(dx, make_zero(dx)),
    Duplicated(dy, make_zero(dy)),
    Const(xx),
    Const(yy),
    Duplicated(curr_phi, dphi),
    Const(40),
    Const(40),
    Const(nu),
    Const(n_iter),
    Const(timestep)
)
2 Likes

Thank you for the help. The little ā€œvā€ was the major error.

I think if you restructure your code to reuse this matrix instead of re-allocating it every time, you can get absolutely massive speedups

1 Like