Thanks for all the help, and I’m glad the thread has been helpful to others. I only recently needed a twice-differentiable softmax. The only error I encountered was from NO_FIELDS, which has since been replaced by NoTangent()/ZeroTangent(). Changing only that lets the code run for both first and second derivatives. For some reason, however, none of the second-derivative functions are being hit, and I don’t understand ChainRules well enough to know whether that is expected behavior. Compared against finite differences the results look correct (a small check is sketched after the listing).
using Flux
using StatsBase
import Flux.Zygote.ChainRulesCore
import Flux.Zygote.ChainRulesCore: NoTangent, ZeroTangent
import Flux.NNlib
function logce(ŷ, y)
    # negative mean cross entropy of softmax(ŷ) against the target distribution y
    # (softmax output clamped at 1f-6 for numerical stability)
    softŷ = softmax(ŷ; dims = 1)
    l = sum(y .* log.(max.(1f-6, softŷ)); dims = 1)
    -mean(l)
end
function ∇lsoftmax(Δ, xs; dims = 1)
    # pullback of logsoftmax: Δ .- sum(Δ) .* softmax(xs)
    Δ .- sum(Δ; dims = dims) .* softmax(xs; dims = dims)
end
# function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
#                    x::AbstractArray, y::AbstractArray; dims = 1)
#     out .= Δ .* y
#     out .= out .- y .* sum(out; dims = dims)
# end
function ∇₂softmax(Δ₂, Δ::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
    # pullback of ∇softmax: tangents w.r.t. (Δ, x, y), where y = softmax(x; dims)
    println("second grad softmax")
    Δ₂y = Δ₂ .* y
    sΔ₂y = sum(Δ₂y; dims = dims)
    (Δ₂y .- sΔ₂y .* y,
     ZeroTangent(),
     Δ₂ .* Δ .- Δ₂ .* sum(Δ .* y; dims = dims) .- sΔ₂y .* Δ)
end
function ChainRulesCore.rrule(::typeof(NNlib.∇softmax), Δ, x, softx; dims = 1)
    println("rrule softmax")
    y = NNlib.∇softmax(Δ, x, softx; dims = dims)
    function ∇softmax_pullback(Δ₂)
        (ZeroTangent(), ∇₂softmax(Δ₂, Δ, x, softx; dims = dims)...)
    end
    return y, ∇softmax_pullback
end
function ∇logce(Δ, logŷ, y::Matrix, n)
    # pullback of logce for matrix targets: tangents w.r.t. (logŷ, y)
    println("grad logce matrix")
    ∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims = 1) ./ n
    ∇y = -Δ .* logsoftmax(logŷ; dims = 1) ./ n
    (∇logŷ, ∇y)
end
function ∇logce(Δ, logŷ, y::Vector, n)
    # pullback of logce for vector targets
    println("grad logce vector")
    ∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims = 1) ./ n
    ∇y = -mean(Δ .* logsoftmax(logŷ; dims = 1); dims = 2)[:]
    (∇logŷ, ∇y)
end
function ChainRulesCore.rrule(::typeof(logce), logŷ, y)
    println("rrule logce")
    o = logce(logŷ, y)
    function logce_pullback(Δ)
        (ZeroTangent(), ∇logce(Δ, logŷ, y, size(logŷ, 2))...)
    end
    return o, logce_pullback
end
function first_order_grad(loss, pred, target)
    # squared norm of the gradient of `loss` w.r.t. `pred`
    grads_inner = gradient(Flux.params(pred)) do
        loss(pred, target)
    end
    sum(grads_inner[pred] .^ 2)
end
function second_order_grad(loss, pred, target)
    # differentiates first_order_grad, so it needs second derivatives of `loss`
    grads = gradient(Flux.params(pred)) do
        first_order_grad(loss, pred, target)
    end
    return sum(grads[pred] .^ 2)
end
second_order_grad((x,y) -> sum((x .- y).^3), [2], [0]) #12
g1 = first_order_grad(logce, [0.5,0.5], [0.1, 0.9])
g2 = second_order_grad(logce, [0.5,0.5], [0.1, 0.9]) # Works, unlike Flux.logitcrossentropy
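The finite-difference comparison mentioned above can be reproduced with a small hand-rolled central-difference helper. The fd_grad below is only a sketch of that check (it is not a Flux/NNlib function, and the step size is arbitrary): it differentiates first_order_grad numerically and compares the result to the nested AD gradient that second_order_grad uses internally.
function fd_grad(f, x; h = 1e-4)
    # central-difference approximation of ∇f at x
    map(eachindex(x)) do i
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        (f(xp) - f(xm)) / (2h)
    end
end

pred, target = [0.5, 0.5], [0.1, 0.9]
# AD: gradient of first_order_grad, the same nesting second_order_grad performs
grads = gradient(Flux.params(pred)) do
    first_order_grad(logce, pred, target)
end
ad = grads[pred]
# FD: the same quantity, with first_order_grad differentiated numerically
fd = fd_grad(p -> first_order_grad(logce, p, target), pred)
maximum(abs.(ad .- fd))  # should be small if the second-order rules are consistent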