Zygote returning wrong gradient on GPU

Hi! As an exercise, I want to optimize the potential energy of a system of springs, which are coupled if they are closer than some threshold. On CPU the gradient is correct, but on GPU it is `nothing`. Am I doing something wrong, or is this a bug?

using StaticArrays
using CUDA
using Zygote
using LinearAlgebra

pot(x, x0) = (x-x0)^2

# array based version without loops and mutation
function E(positions, r0, cutoff)
    dist_matrix = positions .- transpose(transpose.(positions))  # hack to "un-transpose" the position vectors
    distances = norm.(dist_matrix)
    valid = distances .< cutoff
    thresholded = distances .* valid
    mask = thresholded .> 0
    pots = pot.(thresholded, r0) .* mask
    E_ = sum(pots) / 2
    return E_
end

pos = [@SVector[0,0],@SVector[1,0],@SVector[0,1]]
gradient((x)->E(x, 1.1, Inf), pos)  # SVector...
gradient((x)->E(x, 1.1, Inf), cu(pos))  # (nothing,)
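For reference, one possible workaround (a sketch of my own, with a hypothetical name `E_matrix`; untested on GPU): store the coordinates in a plain real matrix instead of a vector of SVectors, so every broadcast only ever sees `Float64` scalars, which is the case Zygote's CUDA broadcast path is built for:

```julia
using LinearAlgebra

pot(x, x0) = (x - x0)^2

# Same energy as E above, but positions is a 2×N real matrix (columns are
# points), so all broadcasts operate on plain Float64 scalars.
function E_matrix(positions, r0, cutoff)
    dx = reshape(positions, 2, :, 1) .- reshape(positions, 2, 1, :)
    distances = sqrt.(sum(abs2, dx; dims=1))           # 1×N×N pairwise distances
    mask = (distances .< cutoff) .& (distances .> 0)   # drop self-pairs and far pairs
    return sum(pot.(distances, r0) .* mask) / 2
end

pos_mat = [0.0 1.0 0.0; 0.0 0.0 1.0]
E_matrix(pos_mat, 1.1, Inf)  # same value as E(pos, 1.1, Inf)
```

The avoided loops and mutation are the same as in the original; only the element type changes.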

[052768ef] CUDA v3.9.1
[587475ba] Flux v0.13.0
[e9467ef8] GLMakie v0.6.0
[ee78f7c6] Makie v0.17.0
[3bd65402] Optimisers v0.2.3
[90137ffa] StaticArrays v1.4.4
[e88e6eb3] Zygote v0.6.39

julia v1.7.2

Nothing stands out, though the broadcasted transpose and norm are unusual. You can try bisecting the function by inserting @showgrad around particular expressions until a nothing turns up where it shouldn’t. I suspect it may have something to do with Zygote’s broadcasting heuristics.
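For anyone following along, a minimal sketch of that bisection idea on CPU, with a toy function of my own:

```julia
using Zygote

# Wrap an intermediate expression in Zygote.@showgrad: during the backward
# pass it prints the adjoint flowing into that expression, so a `nothing`
# shows you exactly where the gradient is being dropped.
f(x) = sum(Zygote.@showgrad(x .^ 2))
g, = gradient(f, [1.0, 2.0])
g  # [2.0, 4.0], with the adjoint of x .^ 2 printed along the way
```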

I think this is a variant of the issue with complex broadcasting. Zygote’s broadcasting for CuArrays uses dual numbers, and hence supports only real numbers, but instead of raising an error in other cases it silently returns `nothing`.
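To illustrate with a toy example of my own: the `ForwardDiff.Dual` type that broadcast path relies on is itself a real scalar carrying partial derivatives, so a static-vector element cannot be dualized at all:

```julia
using ForwardDiff, StaticArrays

# A Dual is a Real that carries partial derivatives along for the ride:
d = ForwardDiff.Dual(2.0, 1.0)
sqrt(d)      # value sqrt(2.0), plus the derivative 1/(2*sqrt(2.0))
d isa Real   # true

# An SVector is not a Real, so it cannot be wrapped in a scalar Dual;
# this constructor call errors instead of producing a dual number:
try
    ForwardDiff.Dual(SVector(1.0, 0.0), SVector(1.0, 0.0))
catch err
    err  # no applicable Dual constructor for non-Real values
end
```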

Yes, I think so. In particular, this line removes the gradient for arbitrary types, including static arrays. Could it be changed to emit a warning or an error? That might be too heavy-handed a solution, though, as ignoring some types for gradients is often intentional.

I wrote some code extending the ForwardDiff broadcast path to static arrays for Molly.jl: Molly.jl/zygote.jl at v0.10.1 · JuliaMolSim/Molly.jl · GitHub. It is over-complicated and hard to read, but you could write similar cases for your code to get this working.


With the help of @showgrad I boiled it down to:

pos = cu([@SVector[0,0],@SVector[1,0],@SVector[0,1]])
gradient(pos) do positions
    dist_matrix = positions .- transpose(transpose.(positions))
    sum(norm, dist_matrix)
end  # nothing

This indeed resembles Complex broadcasting AD gives `nothing` when using CUDA · Issue #1215 · FluxML/Zygote.jl · GitHub.

It was actually intended as a demo for a friend, so for now I can’t invest more time in it. Anyway, thanks for the help!
