Excessive allocations when using Zygote.gradient

As an MWE I have a Lennard-Jones force calculation. I would like to define only energy functions and then use automatic differentiation to obtain their derivatives, and thus the forces.

However, for the LJ calculation, computing the force with Zygote.gradient is several orders of magnitude slower than the “by hand” analytical version. @btime gives 123.161 ms (910316 allocations: 40.93 MiB) for the Zygote version, compared with 67.900 μs (1 allocation: 2.50 KiB) for the by-hand force calculation.

A second problem is that I cannot exit the energy function early via an if statement when two particles are far apart. I am unsure how to make this work, given that Zygote does work with if statements.

using BenchmarkTools
using StaticArrays
using Zygote

"Vector between two coordinate values, accounting for mirror image separation"
@inline function vector1D(c1::Float64, c2::Float64, box_size::Float64)
    if c1 < c2
        return (c2 - c1) < (c1 - c2 + box_size) ? (c2 - c1) : (c2 - c1 - box_size)
    else
        return (c1 - c2) < (c2 - c1 + box_size) ? (c2 - c1) : (c2 - c1 + box_size)
    end
end

""" Potential energy between two atoms"""
@inline function pair_energy(r1::SVector{3, Float64}, r2::SVector{3, Float64}, box_size::SVector{3,Float64})

    # apply mirror image separation
    dx = vector1D(r1[1], r2[1], box_size[1])
    dy = vector1D(r1[2], r2[2], box_size[2])
    dz = vector1D(r1[3], r2[3], box_size[3])

    rij_sq = dx * dx + dy * dy + dz * dz

    ##### Would like to implement the below commented out section
    #if rij_sq > box_size[1] / 2
    #    return 0.0
    #end
    sr2 = 1 / rij_sq
    sr6 = sr2^3
    sr12 = sr6^2
    e = 4 * (sr12 - sr6)
    return e
end

funct(x, y, b) = gradient(x -> pair_energy(x, y, b), x)

function analytical_total_force(r::Vector{SVector{3, Float64}}, box_size::SVector{3,Float64})
    n = length(r)
    forces = [SVector{3}(0.0, 0.0, 0.0) for i=1:n ]

    for i = 1:(n-1)
        for j = (i+1):n
            dE_dr = funct(r[i], r[j], box_size)[1]  # gradient of the pair energy w.r.t. r[i]
            forces[i] = forces[i] - dE_dr
            forces[j] = forces[j] + dE_dr
        end
    end
    return forces
end

natoms = 100
box_size = SVector{3}(6.0, 6.0, 6.0)
r = [SVector{3}(rand(), rand(), rand()) .* box_size[1] for i = 1:natoms]
@btime test_analytical_force = analytical_total_force($r, $box_size)
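
The “by hand” version used for comparison is not shown in the post. As a rough sketch only, a hand-derived gradient of pair_energy with respect to r1 could look like the following. The name pair_gradient_by_hand is made up for this example, the type annotations on vector1D are dropped for brevity, and this is not necessarily the code that was benchmarked.

```julia
using StaticArrays

# Minimum-image displacement from c1 to c2 along one axis (same logic as vector1D in the post).
function vector1D(c1, c2, box_size)
    if c1 < c2
        return (c2 - c1) < (c1 - c2 + box_size) ? (c2 - c1) : (c2 - c1 - box_size)
    else
        return (c1 - c2) < (c2 - c1 + box_size) ? (c2 - c1) : (c2 - c1 + box_size)
    end
end

# Hypothetical helper: dE/dr1 for E = 4 * (q^-6 - q^-3), where q = rij_sq.
function pair_gradient_by_hand(r1::SVector{3}, r2::SVector{3}, box_size::SVector{3})
    d = SVector(vector1D(r1[1], r2[1], box_size[1]),
                vector1D(r1[2], r2[2], box_size[2]),
                vector1D(r1[3], r2[3], box_size[3]))
    rij_sq = sum(abs2, d)
    sr2 = 1 / rij_sq
    sr6 = sr2^3
    sr12 = sr6^2
    # Chain rule: dE/dq = 4 * (-6 * sr12 + 3 * sr6) * sr2 and dq/dr1 = -2 * d
    return (48 * sr12 - 24 * sr6) * sr2 * d
end

g_hand = pair_gradient_by_hand(SVector(1.0, 1.0, 1.0), SVector(2.0, 1.0, 1.0), SVector(6.0, 6.0, 6.0))
```

This returns the same quantity as funct(r1, r2, box_size)[1], so it could stand in for dE_dr in the double loop.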

I don’t have any advice about the performance part of the question, but I think the reason the conditional gives trouble is that AD systems have two ideas that look like 0. The first is “the derivative of the output with respect to that parameter at this point is zero”; Zygote uses 0 for this. The second is “the value at this point is independent of that parameter”; Zygote uses nothing for this.

The early return in the conditional is exactly this second case, so you have two possible choices. You can introduce a dependency to get back to the first case (return 0.0*rij_sq), or you can handle a gradient value of nothing when summing the forces. Surprisingly, the second option is a little faster on my machine, but you should probably test with your actual code to see how it performs for you.
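
To illustrate the two cases, here is a minimal sketch on a toy one-dimensional energy rather than the full pair_energy. The names cutoff_energy, cutoff_energy_dep and grad_or_zero are made up for this example.

```julia
using Zygote

# Toy 1-D energy with an early return past a cutoff (a stand-in for pair_energy).
cutoff_energy(x) = x > 1.0 ? 0.0 : x^2

# Variant that keeps a dependency on x in the cutoff branch.
cutoff_energy_dep(x) = x > 1.0 ? 0.0 * x : x^2

# Hypothetical helper: map a `nothing` gradient to a zero of the same shape as x.
grad_or_zero(g, x) = g === nothing ? zero(x) : g

g_const = gradient(cutoff_energy, 2.0)[1]      # constant branch: output independent of x
g_dep   = gradient(cutoff_energy_dep, 2.0)[1]  # dependent branch: a numeric zero
g_near  = gradient(cutoff_energy, 0.5)[1]      # inside the cutoff: ordinary derivative 2x
```

In the force loop, the second option would then amount to something like dE_dr = grad_or_zero(funct(r[i], r[j], box_size)[1], r[i]).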

Thanks, yes that does make it work. The time gets worse with the if statement, and is now 200.074 ms (1273661 allocations: 52.80 MiB). I don’t know if this is typical for AD, but it makes me want to just revert to coding derivatives by hand :frowning:

Given that the input dimension of the function you’re differentiating is really small (3), this is probably better suited to ForwardDiff. Try with:

using ForwardDiff
funct(x, y, b) = ForwardDiff.gradient(x -> pair_energy(x, y, b), x)

You’ll also need to remove the ::Float64 annotations on vector1D and change the pair_energy ones to just SVector{3} (or remove them completely). With that I get:

julia> @btime test_analytical_force = analytical_total_force($r, $box_size)
  924.744 μs (1 allocation: 2.50 KiB)
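
Putting those changes together, the modified definitions might look like the sketch below. This is a reconstruction from the description above, not necessarily the exact code that produced the timing.

```julia
using ForwardDiff
using StaticArrays

# vector1D without the ::Float64 annotations, so ForwardDiff's dual numbers can pass through.
function vector1D(c1, c2, box_size)
    if c1 < c2
        return (c2 - c1) < (c1 - c2 + box_size) ? (c2 - c1) : (c2 - c1 - box_size)
    else
        return (c1 - c2) < (c2 - c1 + box_size) ? (c2 - c1) : (c2 - c1 + box_size)
    end
end

# pair_energy annotated only with SVector{3}, leaving the element type free.
function pair_energy(r1::SVector{3}, r2::SVector{3}, box_size::SVector{3})
    dx = vector1D(r1[1], r2[1], box_size[1])
    dy = vector1D(r1[2], r2[2], box_size[2])
    dz = vector1D(r1[3], r2[3], box_size[3])
    rij_sq = dx * dx + dy * dy + dz * dz
    sr2 = 1 / rij_sq
    sr6 = sr2^3
    sr12 = sr6^2
    return 4 * (sr12 - sr6)
end

funct(x, y, b) = ForwardDiff.gradient(x -> pair_energy(x, y, b), x)

box = SVector(6.0, 6.0, 6.0)
g = funct(SVector(1.0, 1.0, 1.0), SVector(2.0, 1.0, 1.0), box)
```

Note that ForwardDiff.gradient returns the gradient itself rather than a tuple as Zygote.gradient does, so the [1] index on the funct call in the force loop would need to be dropped.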

Hi, thanks. With your changes I get 110.199 μs (1 allocation: 2.50 KiB) for ForwardDiff and 64.700 μs (1 allocation: 2.50 KiB) for the analytical “by-hand” version. That made a really big difference. I will have to profile it to see if the remaining factor of two can be improved.
