As an MWE, I have a Lennard-Jones force calculation. I would like to define only energy functions and then use automatic differentiation to calculate the derivative, and thus the force.
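For reference, this is the relation the automatic differentiation has to reproduce. It is written in the reduced units the code uses (ε = σ = 1), with **d** the minimum-image separation from atom 1 to atom 2:

$$
\mathbf{F}_1 = -\nabla_{\mathbf{r}_1} E, \qquad \mathbf{F}_2 = -\mathbf{F}_1, \qquad E(r) = 4\left(r^{-12} - r^{-6}\right)
$$

$$
\nabla_{\mathbf{r}_1} E = 24\, r^{-2}\left(2 r^{-12} - r^{-6}\right)\mathbf{d}, \qquad \mathbf{d} = \mathbf{r}_2 - \mathbf{r}_1, \quad r = \lvert \mathbf{d} \rvert
$$

So `forces[i] - dE_dr` and `forces[j] + dE_dr` in the code correspond to **F**₁ and **F**₂.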
However, for the LJ calculation, the `Zygote.gradient` force calculation is several orders of magnitude slower than the analytical version written by hand. `@btime` gives 123.161 ms (910316 allocations: 40.93 MiB) for the Zygote version, compared to 67.900 μs (1 allocation: 2.50 KiB) for the by-hand force calculation, roughly 1800× slower.
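The by-hand version isn't shown in the post. A minimal sketch of what it presumably computes is below; it is useful as a correctness check for the Zygote gradient. The names `lj_energy` and `lj_energy_gradient` are hypothetical, and the functions take the minimum-image separation `d = r2 - r1` directly. The code is written generically so it runs with plain tuples here; the same bodies should also accept an `SVector{3,Float64}`.

```julia
# Hypothetical by-hand gradient (not the poster's code).
# `d` is the minimum-image separation r2 - r1 (e.g. built from vector1D).

# Lennard-Jones pair energy in reduced units: E = 4(r^-12 - r^-6)
function lj_energy(d)
    sr6 = (1 / sum(abs2, d))^3
    return 4 * (sr6^2 - sr6)
end

# dE/dr1 = 24 * r^-2 * (2r^-12 - r^-6) * d; the force on atom 1 is its negative
function lj_energy_gradient(d)
    sr2 = 1 / sum(abs2, d)
    sr6 = sr2^3
    sr12 = sr6^2
    return (24 * sr2 * (2 * sr12 - sr6)) .* d
end
```

In `analytical_total_force`, `dE_dr` could then be `lj_energy_gradient(SVector(dx, dy, dz))`, with `dx`, `dy`, `dz` computed as in `pair_energy`. This avoids the per-pair `gradient` call entirely.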
A second problem is that I cannot exit early via an if statement in the energy function when two particles are far apart (the commented-out section below). I am unsure how to make this work, since Zygote does support if statements.
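The post doesn't include the error message, so the following is an assumption about the cause. Zygote returns `nothing` instead of a zero vector when the output doesn't depend on the input, which is exactly what happens in the early-return branch. `forces[i] - nothing` would then throw. Separately, `rij_sq` is a *squared* distance, so a half-box cutoff should compare it against `(box_size[1] / 2)^2`. A sketch of both points, with hypothetical helper names `lj_energy_cut` and `grad_or_zero`:

```julia
# Sketch (hypothetical helper names, not from the original post).

# Lennard-Jones energy from a squared distance, with an early return past
# the cutoff `rc`. Both sides of the comparison are squared, avoiding a sqrt.
function lj_energy_cut(rij_sq::Float64, rc::Float64)
    rij_sq > rc^2 && return 0.0
    sr6 = (1 / rij_sq)^3
    return 4 * (sr6^2 - sr6)
end

# If the gradient is `nothing` (output independent of the input, as in the
# early-return branch), replace it with a zero of the same type as `x`.
grad_or_zero(g, x) = something(g, zero(x))

# Usage with Zygote (not executed here):
#   g = gradient(x -> pair_energy(x, y, b), x)[1]
#   dE_dr = grad_or_zero(g, x)
```

With `SVector` positions, `zero(x)` gives a zero `SVector{3,Float64}`, so the force accumulation keeps a concrete element type.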
```julia
using BenchmarkTools
using StaticArrays
using Zygote

"Vector between two coordinate values, accounting for mirror image separation"
@inline function vector1D(c1::Float64, c2::Float64, box_size::Float64)
    if c1 < c2
        return (c2 - c1) < (c1 - c2 + box_size) ? (c2 - c1) : (c2 - c1 - box_size)
    else
        return (c1 - c2) < (c2 - c1 + box_size) ? (c2 - c1) : (c2 - c1 + box_size)
    end
end

"""Potential energy between two atoms"""
@inline function pair_energy(r1::SVector{3, Float64}, r2::SVector{3, Float64}, box_size::SVector{3,Float64})
    # apply mirror image separation
    dx = vector1D(r1[1], r2[1], box_size[1])
    dy = vector1D(r1[2], r2[2], box_size[2])
    dz = vector1D(r1[3], r2[3], box_size[3])
    rij_sq = dx * dx + dy * dy + dz * dz

    ##### Would like to implement the below commented-out section
    # (rij_sq is a squared distance, so compare it with the squared cutoff)
    #if rij_sq > (box_size[1] / 2)^2
    #    return 0.0
    #end

    sr2 = 1 / rij_sq
    sr6 = sr2^3
    sr12 = sr6^2
    e = 4 * (sr12 - sr6)
    return e
end

# Zygote gradient of the pair energy with respect to the first position
funct(x, y, b) = gradient(x -> pair_energy(x, y, b), x)

# Total force on every atom, using the Zygote gradient from funct
function analytical_total_force(r::Vector{SVector{3, Float64}}, box_size::SVector{3,Float64})
    n = length(r)
    forces = [SVector{3}(0.0, 0.0, 0.0) for i = 1:n]
    for i = 1:(n-1)
        for j = (i+1):n
            dE_dr = funct(r[i], r[j], box_size)[1]
            forces[i] = forces[i] - dE_dr  # F_i = -dE/dr_i
            forces[j] = forces[j] + dE_dr  # equal and opposite force on j
        end
    end
    return forces
end

natoms = 100
box_size = SVector{3}(6.0, 6.0, 6.0)
r = [SVector{3}(rand(), rand(), rand()) .* box_size[1] for i = 1:natoms]
@btime test_analytical_force = analytical_total_force($r, $box_size)
```
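As an aside on `vector1D`: the same minimum-image separation is commonly written without branches by subtracting the nearest whole number of box lengths. The sketch below is a hypothetical `min_image` helper, not from the post. It assumes both coordinates lie inside the same periodic box (so `|c2 - c1| < box_size`) and returns the signed separation `c2 - c1`, matching `vector1D`.

```julia
# Branch-free minimum-image separation c2 - c1 along one axis.
# Assumes |c2 - c1| < box_size (both coordinates inside the same box).
@inline function min_image(c1::Float64, c2::Float64, box_size::Float64)
    d = c2 - c1
    # subtract the nearest whole number of box lengths
    return d - box_size * round(d / box_size)
end
```

`round` defaults to rounding to nearest with ties to even. At a separation of exactly half a box, the result can therefore differ in sign from `vector1D`, though not in magnitude.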