I learned of the undocumented gradcheck function in a recent question, but it was not clear how to apply it to a model (rather than an array). Doing so requires the use of Flux.destructure, as mentioned in the recent pull request 1168.
Here is a self-contained example for other beginners. It is based on these links:
https://github.com/FluxML/Flux.jl/issues/1168
https://github.com/FluxML/Zygote.jl/pull/464
https://github.com/FluxML/Zygote.jl/blob/master/test/gradcheck.jl
using Flux
using Random
#----------------------------------------------------------------
"""
    Mymodel(W, b)

Affine model holding a weight `W` and a bias `b`.

The field types are type parameters, so each instance stores concretely typed
fields instead of `Any`; `Mymodel(W, b)` infers the parameters from its arguments.
"""
struct Mymodel{TW,Tb}
    W::TW
    b::Tb
end
"""
    Mymodel(in::Integer, out::Integer)

Build a `Mymodel` whose weight is `randn(out, in)` and whose bias is `randn(out)`.
"""
function Mymodel(in::Integer, out::Integer)
    # Draw the weight before the bias so the random stream is consumed in a fixed order.
    weights = randn(out, in)
    bias = randn(out)
    return Mymodel(weights, bias)
end
# Register Mymodel with Flux's functor mechanism. Presumably this is what lets
# Flux.destructure (called in gradcheckmodel) collect W and b -- confirm in the Flux docs.
Flux.@functor Mymodel
"""
    (m::Mymodel)(x)

Apply the model to `x`: multiply by `m.W`, then broadcast-add `m.b`.
"""
function (m::Mymodel)(x)
    weights, bias = m.W, m.b
    return weights * x .+ bias
end
#----------------------------------------------------------------
using FiniteDifferences
# from https://github.com/FluxML/Zygote.jl/pull/464
"""
    neogradient(f, xs::AbstractArray...)

Estimate the gradient of `f` at `xs` with FiniteDifferences' `grad`, using the
finite-difference method built by `central_fdm(5, 1)`.
"""
function neogradient(f, xs::AbstractArray...)
    fdm = central_fdm(5, 1) # FiniteDifferences
    return grad(fdm, f, xs...)
end
# from https://github.com/FluxML/Zygote.jl/blob/master/test/gradcheck.jl
"""
    ngradient(f, xs::AbstractArray...)

Estimate the gradient of the scalar-valued `f(xs...)` by central differences.

Each element of each array in `xs` is perturbed in place by `±δ/2`, `f` is
re-evaluated, and the element is put back afterwards (also when `f` throws).
Returns a tuple with one gradient array per input array.

The step `δ` is `sqrt(eps(T))` for the real floating-point type `T` matching the
array's element type, so a `Float32` array gets a step it can actually represent
(a fixed `sqrt(eps(Float64))` step rounds away and yields all-zero gradients).
"""
function ngradient(f, xs::AbstractArray...)
    grads = zero.(xs)
    for (x, Δ) in zip(xs, grads)
        # Step size depends only on the element type, so compute it once per array.
        δ = sqrt(eps(real(float(eltype(x)))))
        for i in eachindex(x)
            tmp = x[i]
            try
                x[i] = tmp - δ/2
                y1 = f(xs...)
                x[i] = tmp + δ/2
                y2 = f(xs...)
                Δ[i] = (y2 - y1)/δ
            finally
                # Always undo the in-place perturbation, even if `f` raised.
                x[i] = tmp
            end
        end
    end
    return grads
end
# adapted from https://github.com/FluxML/Flux.jl/issues/1168
"""
    loss_restructure(θvec, re, loss, x, y)

Rebuild a model from the parameter vector `θvec` with `re`, then return
`loss(x, y, model)` for that rebuilt model.
"""
function loss_restructure(θvec, re, loss, x, y)
    rebuilt = re(θvec)
    return loss(x, y, rebuilt)
end
"""
    gradcheckmodel(model, loss, lossargs...)

Compare the gradient from `Flux.gradient` with the numerical gradient from
`neogradient`, both taken with respect to the flattened parameter vector that
`Flux.destructure(model)` returns.

`lossargs` are forwarded to `loss_restructure` as its `x` and `y`. Both gradients
and their types are printed; the largest absolute elementwise difference is
printed and returned.
"""
function gradcheckmodel(model, loss, lossargs...)
    θvec, re = Flux.destructure(model)
    # A single objective (parameter vector -> loss value) shared by both methods.
    objective = θ -> loss_restructure(θ, re, loss, lossargs...)
    g1 = Flux.gradient(objective, θvec)
    g2 = neogradient(objective, θvec) # or ngradient(objective, θvec)
    @show typeof(g1) g1
    @show typeof(g2) g2
    return @show maximum(abs, g1[1] .- g2[1])
end
#----------------------------------------------------------------
"""
    loss(x, y, model=model)

Mean squared error between the prediction `model(x)` and the target `y`.

When `model` is omitted, the global variable `model` is used.
"""
function loss(x, y, model=model)
    ŷ = model(x)
    # Flux.mse is documented as mse(ŷ, y): prediction first, target second.
    return Flux.mse(ŷ, y)
end
# example inputs
X = [1.,2.]
Y = [3.]
# model with 2 inputs and 1 output, matching the lengths of X and Y
model = Mymodel(2,1)
# test: evaluate the model and the loss
model(X)
# the two-argument call uses the global `model` through loss's default argument
loss(X,Y)
# check against numerical gradient
gradcheckmodel(model,loss,X,Y)