The following example produces a zero gradient (`update!` appears to do nothing), but if I change the `argmax` calls to `sum` calls, the gradient is non-zero. Is `argmax` not supported, or is something else wrong with how I'm doing this?
It outputs:

```
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
loss: 11.142857142857142
Any[nothing, nothing, Base.RefValue{Any}((contents = nothing,)), Base.RefValue{Any}((contents = nothing,))]
```
```julia
using Flux

function testgrad()
    width = 5
    len = 6
    batchLen = 7
    opt = Flux.Optimise.Adam()
    model = Flux.Chain(
        Flux.flatten,                            # (width, len, batchLen) -> (width * len, batchLen)
        Flux.Dense(width * len => width * len),
        x -> reshape(x, (width, len, batchLen)), # back to (width, len, batchLen)
    )
    ps = Flux.params(model)

    function loss(batch)
        x, y = batch
        pred = model(x)
        s = 0.0
        for b in 1:batchLen
            for i in 1:len
                x1 = argmax(pred[:, i, b])
                x2 = argmax(y[:, i, b])
                # NOTE: Replacing the above two lines with the below two lines results in non-zero gradient.
                # x1 = sum(pred[:, i, b]) / len
                # x2 = sum(y[:, i, b]) / len
                s += abs(x2 - x1)
            end
        end
        return s / batchLen
    end

    x = rand(width, len, batchLen)
    y = rand(width, len, batchLen)
    xtest1 = rand(width, len, batchLen)
    ytest1 = rand(width, len, batchLen)

    println("loss: ", loss((xtest1, ytest1)))
    grad = nothing
    # Three training steps, printing the test loss after each one.
    for _ in 1:3
        grad = Flux.gradient(() -> loss((x, y)), ps)
        Flux.Optimise.update!(opt, ps, grad)
        println("loss: ", loss((xtest1, ytest1)))
    end
    return (grad, ps)
end
```
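A smaller comparison may help isolate the question. This sketch assumes Zygote (the AD backend Flux uses, listed in the environment below) is available, and takes gradients of `sum` and `argmax` of the same vector directly, with no model or optimiser involved. `argmax` returns an integer index that stays the same under small changes to its input, so no useful derivative can flow back through it; the exact value Zygote reports for that case (`nothing` or zeros) is not asserted here.

```julia
using Zygote

v = [0.1, 0.7, 0.2]

# sum depends smoothly on every element, so each partial derivative is 1.
g_sum = Zygote.gradient(x -> sum(x), v)[1]

# argmax returns an integer position; nudging v slightly does not change it,
# so nothing meaningful propagates back to v.
g_argmax = Zygote.gradient(x -> float(argmax(x)), v)[1]

println("gradient through sum:    ", g_sum)
println("gradient through argmax: ", g_argmax)
```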
Version Info

```
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab7 (2022-08-17 13:38 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × AMD Ryzen 9 5900X 12-Core Processor
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 24 virtual cores
```

```
pkg> status
Status C:\data\julia\Project.toml
  [6e4b80f9] BenchmarkTools v1.3.1
  [4f18b42c] BusinessDays v0.9.18
  [052768ef] CUDA v3.12.0
  [944b1d66] CodecZlib v0.7.0
  [861a8166] Combinatorics v1.0.2
  [a8cc5b0e] Crayons v4.1.1
  [a10d1c49] DBInterface v2.5.0
  [864edb3b] DataStructures v0.18.13
  [31c24e10] Distributions v0.25.68
  [631263cc] EnergyStatistics v0.2.0
  [4e289a0a] EnumX v1.0.2
  [7a1cc6ca] FFTW v1.5.0
  [fb4d412d] FixedPointDecimals v0.4.0
  [587475ba] Flux v0.13.5
  [53d20848] FluxArchitectures v0.2.1
  [59287772] Formatting v0.4.2
  [f7f18e0c] GLFW v3.4.1
  [e9467ef8] GLMakie v0.6.13
  [cd3eb016] HTTP v1.3.3
  [a98d9a8b] Interpolations v0.14.4
  [d8418881] Intervals v1.8.0
  [c8e1da08] IterTools v1.4.0
  [0f8b85d8] JSON3 v1.9.5
  [5ab0869b] KernelDensity v0.6.5
  [194296ae] LibPQ v1.14.0
  [e6f89c97] LoggingExtras v0.4.9
  [2fda8390] LsqFit v0.12.1
  [d9ec5142] NamedTupleTools v0.14.1
  [6fe1bfb0] OffsetArrays v1.12.7
  [9b87118b] PackageCompiler v2.0.9
  [fa939f87] Pidfile v1.3.0
  [08abe8d2] PrettyTables v1.3.1
  [295af30f] Revise v3.4.0
  [f2b01f46] Roots v2.0.2
  [c35d69d1] SMTPClient v0.6.3
  [276daf66] SpecialFunctions v2.1.7
  [6ec83bb0] StructEquality v2.1.0
  [bd369af6] Tables v1.7.0
  [b189fb0b] ThreadPools v2.1.1
  [9e3dc215] TimeSeries v0.23.0
  [f269a46b] TimeZones v1.9.0
  [3bb67fe8] TranscodingStreams v0.9.9
  [9d95972d] TupleTools v1.3.0
  [e88e6eb3] Zygote v0.6.45
  [ade2ca70] Dates
  [8bb1440f] DelimitedFiles
  [de0858da] Printf
  [9e88b42a] Serialization
  [10745b16] Statistics
  [8dfed614] Test
```
Side question: is there a quick, comprehensive way to troubleshoot why the gradient is zero when training with Flux, even when the loss function returns non-zero values that differ between calls? Ideally the approach would avoid accidentally printing millions of parameters to the console, since recovering from that by restarting Julia takes a while.
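One low-output approach, sketched here as an assumption rather than a built-in Flux tool: loop over the implicit `Params`, look up each array's entry in the `Grads` object returned by `Flux.gradient`, and print only the parameter's size and the gradient's norm (or `nothing` if no gradient reached it). The helper name `gradsummary` and the small `Dense(3 => 2)` model are hypothetical, chosen only for illustration.

```julia
using Flux
using LinearAlgebra: norm

# Hypothetical helper: prints one line per parameter array, never the array contents.
# Returns the gradient norms (or `nothing` where no gradient reached a parameter).
function gradsummary(ps, gs)
    norms = Any[]
    for (k, p) in enumerate(ps)
        g = gs[p]                       # `nothing` if AD found no path to this array
        n = g === nothing ? nothing : norm(g)
        push!(norms, n)
        desc = n === nothing ? "nothing (no gradient reached this array)" : "norm = $n"
        println("param ", k, "  size ", size(p), ":  ", desc)
    end
    return norms
end

# Small illustrative model with a loss that does have a gradient.
model = Flux.Dense(3 => 2)
ps = Flux.params(model)
x = rand(Float32, 3, 4)
gs = Flux.gradient(() -> sum(model(x)), ps)
norms = gradsummary(ps, gs)
```

With the `(grad, ps)` pair returned by `testgrad()` above, the same check would be `gradsummary(ps, grad)`; a `nothing` entry for every parameter points at an operation in the loss that cuts the gradient path.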