How to add norm of gradient to a loss function?

Thanks for all the help, and I'm glad the thread has been useful to others. I only recently needed a twice-differentiable softmax. The only error I ran into was with NO_FIELDS, which has since been replaced by NoTangent() (or ZeroTangent()). Changing just that lets the code run for both first and second derivatives. For some reason, however, none of the second-derivative functions are ever hit, and I don't understand ChainRules well enough to know whether that is expected behavior. Compared against finite differences (see the check at the end of this post), the results look correct.

using Flux
using StatsBase
import Flux.Zygote.ChainRulesCore
import Flux.Zygote.ChainRulesCore: NoTangent, ZeroTangent
import Flux.NNlib

# Cross-entropy of the softmax of ŷ against y, with the probabilities clipped
# at 1f-6 to avoid log(0).
function logce(ŷ, y)
    softŷ = softmax(ŷ; dims=1)
    l = sum(y .* log.(max.(1f-6, softŷ)), dims = 1)
    -mean(l)
end

# Gradient of logsoftmax: Δ .- sum(Δ) .* softmax(x).
function ∇lsoftmax(Δ, xs; dims=1)
    Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
end

# NNlib's ∇softmax! (kept here for reference; ∇₂softmax below is its
# hand-written adjoint):
# function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
#                    x::AbstractArray, y::AbstractArray; dims = 1)
#     out .= Δ .* y
#     out .= out .- y .* sum(out; dims = dims)
# end


# Adjoint of ∇softmax(Δ, x, y): given the cotangent Δ₂, return the tangents for
# (Δ, x, y). x never enters the computation, hence the ZeroTangent().
function ∇₂softmax(Δ₂, Δ::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
    println("second grad softmax")
    Δ₂y = Δ₂ .* y
    sΔ₂y = sum(Δ₂y, dims = dims)
    (Δ₂y .- sΔ₂y .* y), ZeroTangent(), (Δ₂ .* Δ .- Δ₂ .* sum(Δ .* y, dims = dims) .- sΔ₂y .* Δ)
end


# Second-order rule: an rrule for NNlib.∇softmax itself, so Zygote can
# differentiate the softmax pullback a second time.
function ChainRulesCore.rrule(::typeof(NNlib.∇softmax), Δ, x, softx; dims=1)
    println("rrule softmax")
    y = NNlib.∇softmax(Δ, x, softx; dims=dims)
    function ∇softmax_pullback(Δ₂)
        (NoTangent(), ∇₂softmax(Δ₂, Δ, x, softx; dims = dims)...)
    end
    return y, ∇softmax_pullback
end

# Pullback of logce for a matrix target (a batch of columns).
function ∇logce(Δ, logŷ, y::Matrix, n)
    println("grad logce matrix")
    ∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n
    ∇y = -Δ .* logsoftmax(logŷ; dims=1) ./ n
    (∇logŷ, ∇y)
end

# Pullback of logce for a single vector target.
function ∇logce(Δ, logŷ, y::Vector, n)
    println("grad logce vector")
    ∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n
    ∇y = -mean(Δ .* logsoftmax(logŷ; dims=1), dims = 2)[:]
    (∇logŷ, ∇y)
end


# First-order rule: a custom rrule for logce that uses the hand-written
# pullbacks above.
function ChainRulesCore.rrule(::typeof(logce), logŷ, y)
    println("rrule logce")
    o = logce(logŷ, y)
    function g(Δ)
        (NoTangent(), ∇logce(Δ, logŷ, y, size(logŷ, 2))...)
    end
    o, g
end

# Squared L2 norm of the gradient of loss with respect to pred.
function first_order_grad(loss, pred, target)
    grads_inner = gradient(Flux.params(pred)) do
        loss(pred, target)
    end
    sum(grads_inner[pred].^2)
end

# Differentiate the squared gradient norm once more (needs second derivatives).
function second_order_grad(loss, pred, target)
    grads = gradient(Flux.params(pred)) do
        first_order_grad(loss, pred, target)
    end
    return sum(grads[pred].^2)
end

second_order_grad((x, y) -> sum((x .- y).^3), [2], [0]) # sanity check that Zygote differentiates twice; d²/dx² of (x - y)^3 is 6(x - y) = 12 here
g1 = first_order_grad(logce, [0.5,0.5], [0.1, 0.9])
g2 = second_order_grad(logce, [0.5,0.5], [0.1, 0.9]) # Works, unlike Flux.logitcrossentropy
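
For completeness, the finite-difference comparison I mentioned at the top looks roughly like this. It is only a sketch and assumes FiniteDifferences.jl is installed; central_fdm and grad come from that package, everything else is defined above.

using FiniteDifferences  # not loaded above, assumed to be installed

xs = [0.5, 0.5]
ys = [0.1, 0.9]
fdm = central_fdm(5, 1)

# First derivative of the loss: custom rrule vs. finite differences.
g_ad = gradient(x -> logce(x, ys), xs)[1]
g_fd = grad(fdm, x -> logce(x, ys), xs)[1]
isapprox(g_ad, g_fd; rtol = 1e-5)

# Second-order check: finite-difference the squared gradient norm and compare
# its squared norm against what second_order_grad returns.
h_fd = grad(fdm, x -> first_order_grad(logce, x, ys), xs)[1]
isapprox(second_order_grad(logce, xs, ys), sum(abs2, h_fd); rtol = 1e-4)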