In the following example the gradient is zero (`update!` appears to do nothing), but if I replace the `argmax` calls with `sum` calls, the gradient is non-zero. Is `argmax` not supported, or is something else wrong with how I'm doing this?

It outputs:

```
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
Any[nothing, nothing, Base.RefValue{Any}((contents = nothing,)), Base.RefValue{Any}((contents = nothing,))]
```
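The constant loss is consistent with `argmax` being piecewise constant. It returns an integer index, which does not change when its input is nudged by a tiny amount, so its derivative is zero almost everywhere (and undefined at ties). `sum`, by contrast, has derivative 1 with respect to every element. The sketch below makes this visible in Base Julia only, using central finite differences rather than Flux/Zygote's automatic differentiation; `fd_partial` is a hypothetical helper written for illustration.

```julia
# Central finite-difference estimate of ∂f(v)/∂v[k]
function fd_partial(f, v::Vector{Float64}, k::Int; h::Float64 = 1e-6)
    vp = copy(v); vp[k] += h
    vm = copy(v); vm[k] -= h
    return (f(vp) - f(vm)) / (2h)
end

v = [0.1, 0.7, 0.2, 0.4, 0.3]

# argmax returns an index; nudging any element by h leaves it at 2, so every estimate is 0.0
argmax_partials = [fd_partial(argmax, v, k) for k in eachindex(v)]

# sum is linear in each element, so every estimate is (approximately) 1.0
sum_partials = [fd_partial(sum, v, k) for k in eachindex(v)]

println("argmax: ", argmax_partials)
println("sum:    ", sum_partials)
```

An AD system faces the same situation: there is no non-zero slope to propagate through `argmax`, so the parameters receive no gradient.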

```julia
using Flux

function testgrad()
    width = 5
    len = 6
    batchLen = 7
    opt = Flux.Optimise.Adam()
    model = Flux.Chain(
        Flux.flatten,                             # (width, len, batchLen) -> (width * len, batchLen)
        Flux.Dense(width * len => width * len),
        x -> reshape(x, (width, len, batchLen))   # back to (width, len, batchLen)
    )
    ps = Flux.params(model)
    function loss(batch)
        x, y = batch
        pred = model(x)
        s = 0.0
        for b in 1:batchLen
            for i in 1:len
                x1 = argmax(pred[:, i, b])
                x2 = argmax(y[:, i, b])
                # NOTE: Replacing the above two lines with the below two lines results in non-zero gradient.
                # x1 = sum(pred[:, i, b]) / len
                # x2 = sum(y[:, i, b]) / len
                s += abs(x2 - x1)
            end
        end
        return s / batchLen
    end
    x = rand(width, len, batchLen)
    y = rand(width, len, batchLen)
    xtest1 = rand(width, len, batchLen)
    ytest1 = rand(width, len, batchLen)
    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x, y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x, y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x, y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))
    return (grad, ps)
end
```
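If a differentiable stand-in for `argmax` is acceptable for this loss (an assumption about the task, since it changes what the loss measures), one common substitute is a "soft argmax": the expected index under a softmax of the scores, with a sharpness parameter `β`. As `β` grows it approaches the hard `argmax`, but unlike `argmax` it varies smoothly with its input. The names `softmax_vec`, `soft_argmax`, and `β` are hypothetical; this is a Base-only sketch, not a Flux API.

```julia
# Numerically stable softmax: subtracting the maximum avoids overflow in exp
function softmax_vec(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# "Soft argmax": expected index under softmax(β .* z).
# For large β this approaches argmax(z), but it changes smoothly as z changes.
function soft_argmax(z::AbstractVector{<:Real}; β::Real = 10.0)
    p = softmax_vec(β .* z)
    return sum(p .* (1:length(z)))
end

v = [0.1, 0.7, 0.2, 0.4, 0.3]
hard = argmax(v)                  # integer index
soft = soft_argmax(v; β = 100.0)  # very close to the hard index
```

In the question's `loss`, this would mean `x1 = soft_argmax(pred[:, i, b])` in place of `argmax(pred[:, i, b])`; `x2` can stay a hard `argmax`, because the gradient is only taken with respect to `ps`, not `y`. Flux also re-exports NNlib's `softmax`, which could replace the hand-written `softmax_vec`.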

## Version Info

```
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab7 (2022-08-17 13:38 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × AMD Ryzen 9 5900X 12-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 24 virtual cores

pkg> status
Status `C:\data\julia\Project.toml`
```

Side question: is there a quick, comprehensive way to troubleshoot why the gradient is zero when training with Flux, even when the loss function returns non-zero values that differ from call to call? Ideally it would avoid accidentally printing millions of parameters to the console, because I then have to restart Julia, which can take a while.
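A minimal sketch for the side question, assuming Flux's implicit-parameter style used above, where `grad[p]` returns the gradient array for parameter `p` (or `nothing` when no gradient reached it). The hypothetical helper `summarize_grads` prints one line per parameter array (its size, gradient norm, and number of non-zero entries) instead of the arrays themselves, so nothing large reaches the console. It uses only Base Julia; the stand-in data at the bottom mimics identity-keyed gradients without needing Flux.

```julia
# Print one compact line per parameter instead of the parameter or gradient arrays.
# `getgrad(p)` is expected to return p's gradient array, or `nothing`.
function summarize_grads(params, getgrad)
    lines = String[]
    for (i, p) in enumerate(params)
        g = getgrad(p)
        line = if g === nothing
            "param $i, size $(size(p)): gradient is nothing"
        else
            "param $i, size $(size(p)): norm = $(sqrt(sum(abs2, g))), " *
            "nonzero = $(count(!iszero, g)) of $(length(g))"
        end
        println(line)
        push!(lines, line)
    end
    return lines  # also returned, e.g. for filtering or testing
end

# Stand-in data: two "parameters" and a dictionary keyed by object identity
W = ones(2, 2)
b = zeros(2)
grads = IdDict{Any,Any}()
grads[W] = [1.0 0.0; 0.0 0.0]   # one non-zero entry
grads[b] = nothing              # no gradient reached b
report = summarize_grads([W, b], p -> grads[p]);
```

With Flux, after `grad = Flux.gradient(() -> loss((x, y)), ps)`, the call would be `summarize_grads(ps, p -> grad[p]);`. For the `argmax` version, every line would be expected to say `gradient is nothing`, matching the `nothing` entries in the output above. Ending a REPL expression with `;` also suppresses display of its value, which avoids the accidental printing.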