Zero gradient when using argmax

The following example shows a zero gradient (`update!` appears to do nothing), but if I change the `argmax` calls to `sum` calls, the gradient is non-zero. Is `argmax` not supported, or is something else wrong with how I'm doing this?

It outputs:

```
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
Any[nothing, nothing, Base.RefValue{Any}((contents = nothing,)), Base.RefValue{Any}((contents = nothing,))]
```
```julia
using Flux

function testgrad()
    width = 5
    len = 6
    batchLen = 7

    opt = Flux.Optimise.Adam()
    model = Flux.Chain(
        Flux.flatten,                               # (width, len, batchLen) -> (width*len, batchLen)
        Flux.Dense(width * len => width * len),
        x -> reshape(x, (width, len, batchLen)))    # back to (width, len, batchLen)
    ps = Flux.params(model)

    function loss(batch)
        x, y = batch
        pred = model(x)
        s = 0.0
        for b in 1:batchLen
            for i in 1:len
                x1 = argmax(pred[:,i,b])    # integer index of the largest entry
                x2 = argmax(y[:,i,b])
                # NOTE: Replacing the above two lines with the below two lines results in non-zero gradient.
                # x1 = sum(pred[:,i,b]) / len
                # x2 = sum(y[:,i,b]) / len
                s += abs(x2 - x1)
            end
        end
        return s / batchLen
    end

    x = rand(width, len, batchLen)
    y = rand(width, len, batchLen)
    xtest1 = rand(width, len, batchLen)
    ytest1 = rand(width, len, batchLen)

    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x,y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x,y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))
    grad = Flux.gradient(() -> loss((x,y)), ps)
    Flux.Optimise.update!(opt, ps, grad)
    println("loss: ", loss((xtest1, ytest1)))

    return (grad, ps)
end
```
Version Info

```
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab7 (2022-08-17 13:38 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × AMD Ryzen 9 5900X 12-Core Processor
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 24 virtual cores
```

```
pkg> status
Status C:\data\julia\Project.toml
  [6e4b80f9] BenchmarkTools v1.3.1
  [4f18b42c] BusinessDays v0.9.18
  [052768ef] CUDA v3.12.0
  [944b1d66] CodecZlib v0.7.0
  [861a8166] Combinatorics v1.0.2
  [a8cc5b0e] Crayons v4.1.1
  [a10d1c49] DBInterface v2.5.0
  [864edb3b] DataStructures v0.18.13
  [31c24e10] Distributions v0.25.68
  [631263cc] EnergyStatistics v0.2.0
  [4e289a0a] EnumX v1.0.2
  [7a1cc6ca] FFTW v1.5.0
  [fb4d412d] FixedPointDecimals v0.4.0
  [587475ba] Flux v0.13.5
  [53d20848] FluxArchitectures v0.2.1
  [59287772] Formatting v0.4.2
  [f7f18e0c] GLFW v3.4.1
  [e9467ef8] GLMakie v0.6.13
  [cd3eb016] HTTP v1.3.3
  [a98d9a8b] Interpolations v0.14.4
  [d8418881] Intervals v1.8.0
  [c8e1da08] IterTools v1.4.0
  [0f8b85d8] JSON3 v1.9.5
  [5ab0869b] KernelDensity v0.6.5
  [194296ae] LibPQ v1.14.0
  [e6f89c97] LoggingExtras v0.4.9
  [2fda8390] LsqFit v0.12.1
  [d9ec5142] NamedTupleTools v0.14.1
  [6fe1bfb0] OffsetArrays v1.12.7
  [9b87118b] PackageCompiler v2.0.9
  [fa939f87] Pidfile v1.3.0
  [08abe8d2] PrettyTables v1.3.1
  [295af30f] Revise v3.4.0
  [f2b01f46] Roots v2.0.2
  [c35d69d1] SMTPClient v0.6.3
  [276daf66] SpecialFunctions v2.1.7
  [6ec83bb0] StructEquality v2.1.0
  [bd369af6] Tables v1.7.0
  [b189fb0b] ThreadPools v2.1.1
  [9e3dc215] TimeSeries v0.23.0
  [f269a46b] TimeZones v1.9.0
  [3bb67fe8] TranscodingStreams v0.9.9
  [9d95972d] TupleTools v1.3.0
  [e88e6eb3] Zygote v0.6.45
  [ade2ca70] Dates
  [8bb1440f] DelimitedFiles
  [de0858da] Printf
  [9e88b42a] Serialization
  [10745b16] Statistics
  [8dfed614] Test
```

Side question: is there a quick, comprehensive way to troubleshoot why the gradient is zero when training with Flux, even when the loss function returns non-zero values that differ each time? Ideally one that avoids accidentally printing millions of parameters to the console, since that forces me to restart Julia, which can take a while.
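One low-noise way to check, sketched here as an assumption rather than a built-in Flux feature, is to reduce each gradient to a single summary line (size, norm, whether it is all zeros) instead of printing the arrays. `summarize_grads` is a hypothetical helper in plain Julia; it assumes the gradients arrive as an iterable whose entries are arrays or `nothing`, which is how Zygote reports a parameter the gradient never reached.

```julia
using LinearAlgebra: norm

# Hypothetical helper: one short line per gradient entry instead of dumping arrays.
# Entries may be arrays, `nothing` (no gradient reached that parameter),
# or something else (e.g. a Ref for captured variables), which is only named.
function summarize_grads(grads)
    lines = String[]
    for (k, g) in enumerate(grads)
        if g === nothing
            push!(lines, "param $k: nothing (no gradient)")
        elseif g isa AbstractArray
            push!(lines, "param $k: size=$(size(g)) norm=$(norm(g)) all_zero=$(iszero(g))")
        else
            push!(lines, "param $k: $(typeof(g))")
        end
    end
    return lines
end

# Toy input standing in for real gradients.
foreach(println, summarize_grads(Any[nothing, zeros(2, 3), [1.0, -2.0]]))
```

With implicit parameters as in the question, the intended call would be something like `summarize_grads([grad[p] for p in ps])`, relying on Zygote's `Grads` being indexable by parameter; treat that line as untested.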

The derivative of argmax is zero almost everywhere (and undefined or a delta distribution at discontinuities).
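A quick numerical way to see this, in plain Julia with no automatic differentiation involved, is to nudge each input by a small `h` and watch the output. `fd_grad` is a hypothetical helper written for this illustration. The test vector has a unique maximum, so no small nudge changes `argmax`, while a `sum`-based function like the one in the question responds to every nudge.

```julia
# Forward-difference estimate of each partial derivative of f at v.
function fd_grad(f, v; h = 1e-6)
    g = zeros(length(v))
    fv = f(v)
    for k in eachindex(v)
        vp = copy(v)
        vp[k] += h                       # nudge one coordinate
        g[k] = (f(vp) - fv) / h
    end
    return g
end

v = [0.1, 0.7, 0.2, 0.5]
g_argmax = fd_grad(argmax, v)           # argmax stays 2, so every entry is 0.0
g_mean   = fd_grad(x -> sum(x) / 4, v)  # every entry ≈ 0.25
```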

This is why people use differentiable approximations like softmax for optimization (or tricks like an epigraph reformulation to turn discrete minimax problems into differentiable NLPs).
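One common softmax-based relaxation (a swapped-in technique, not something prescribed in this thread) is a "soft argmax": the expected index under `softmax(β .* v)`. It is smooth for finite `β` and approaches `argmax(v)` as `β` grows, provided the maximum is unique. The sketch defines its own `stable_softmax` to stay dependency-free (Flux/NNlib also provide `softmax`); `soft_argmax` and `β` are hypothetical names.

```julia
# Softmax with the usual max-shift for numerical stability.
function stable_softmax(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Hypothetical helper: expected index under softmax(β .* v).
# Smooth in v for finite β; tends to argmax(v) as β → ∞ if the max is unique.
soft_argmax(v; β = 10.0) = sum(stable_softmax(β .* v) .* (1:length(v)))

v = [0.1, 0.7, 0.2, 0.5]
soft_argmax(v; β = 1.0)    # a weighted average strictly between 1 and 4
soft_argmax(v; β = 100.0)  # very close to argmax(v) == 2
```

In the question's loss this would mean `x1 = soft_argmax(pred[:, i, b])` while keeping the hard `x2 = argmax(y[:, i, b])` for the target, since `y` needs no gradient; how well that trains depends on `β` and the task.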


That makes sense. I saw argmax recommended in a thread somewhere, so I thought it might be OK, but I must have read it out of context. I was actually already using softmax, so I can figure out a differentiable solution with that.

Thanks for the answer! It seems it's past time for me to review Zygote's docs.

This is not specific to Zygote; it's about the definition of a derivative, so I would review your calculus.
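Concretely, let $e_k$ be the $k$-th unit vector and suppose $v$ has a unique largest entry $v_j$. Then every sufficiently small nudge leaves the index unchanged, so the difference quotient in the definition of the derivative is identically zero:

$$
\delta = v_j - \max_{i \ne j} v_i > 0, \qquad
0 < |h| < \delta \;\Longrightarrow\; \operatorname{argmax}(v + h\,e_k) = j = \operatorname{argmax}(v),
$$

$$
\text{so} \quad \frac{\partial\,\operatorname{argmax}}{\partial v_k}(v) = \lim_{h \to 0} \frac{\operatorname{argmax}(v + h\,e_k) - \operatorname{argmax}(v)}{h} = 0 .
$$

Only at ties ($\delta = 0$) does the index jump, and there the derivative does not exist.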

As mentioned by @stevengj, argmax is a piecewise constant function, so gradient-based optimization doesn't work: the gradient is zero wherever it is defined.

If you absolutely need argmax in your pipeline, you can check out InferOpt.jl, a package that provides tools for differentiating through combinatorial algorithm layers (argmax can be seen as a very simple combinatorial algorithm).
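One idea behind such layers is random perturbation: instead of `argmax(θ)`, use the expected one-hot argmax of `θ + εZ` with Gaussian noise `Z`, which is smooth in `θ` for `ε > 0`. The sketch below is a plain-Julia Monte Carlo illustration of that idea only; it does not use InferOpt.jl's API, and `onehot_argmax`, `perturbed_argmax`, `ε`, and `nsamples` are hypothetical names.

```julia
using Random

# Hypothetical helper: one-hot vector marking argmax(θ).
function onehot_argmax(θ)
    y = zeros(length(θ))
    y[argmax(θ)] = 1.0
    return y
end

# Monte Carlo estimate of E[onehot_argmax(θ + ε Z)] with Z ~ N(0, I).
# The exact expectation is smooth in θ for ε > 0; this estimate uses
# a finite number of samples, so it only approximates that function.
function perturbed_argmax(θ; ε = 0.5, nsamples = 10_000, rng = Random.default_rng())
    acc = zeros(length(θ))
    for _ in 1:nsamples
        acc .+= onehot_argmax(θ .+ ε .* randn(rng, length(θ)))
    end
    return acc ./ nsamples
end

θ = [0.1, 0.7, 0.2, 0.5]
p = perturbed_argmax(θ; rng = MersenneTwister(0))  # probability vector, most mass on index 2
```

With fixed samples the estimate is itself piecewise constant in `θ`, so differentiating it sample by sample would hit the same zero-gradient problem; perturbation methods instead use a separate estimator for the gradient of the expectation.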


Sorry, my comment about reading the docs lacked context. It referred to my side question about the development process and how to troubleshoot situations like this, not to the mathematical aspects. You're right, it was out of place and confusing.