How to add norm of gradient to a loss function?

I am trying to add a gradient-norm penalty to a loss function, similar to WGAN-GP ([1704.00028] Improved Training of Wasserstein GANs, Alg. 1, Lines 7-9). However, I am running into problems with mutation. Is there another way to do this, or does something have to change in Zygote to accommodate this use case?
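
For reference, the critic loss in Alg. 1 of that paper is (up to notation) L = D_w(\tilde{x}) - D_w(x) + \lambda (\|\nabla_{\hat{x}} D_w(\hat{x})\|_2 - 1)^2, so the gradient-norm penalty term itself has to be differentiated a second time, with respect to the critic parameters w.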

The following code snippet,

using Flux, LinearAlgebra

m = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);
opt = Descent(0.01);

data, labels = rand(Float32, 100, 100), rand(0:1, 100);
labels = permutedims(hcat(labels, 1 .- labels))  # 2×100 one-hot targets

loss(m, x, y) = sum(Flux.crossentropy(m(x), y));

function get_grad(m, data, labels, ps)
    gs = gradient(ps) do
        l = sum(LinearAlgebra.norm, get_grad_inner(m, data, labels, Flux.params(data)))
    end
end

function get_grad_inner(m, data, labels, ps)
    gs = gradient(ps) do
        l = loss(m, data, labels)
    end
end

g1 = get_grad(m, data, labels, Flux.params(m))

results in an error

ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float32}, _...)
Stacktrace:
...

Flux uses a similar approach to PyTorch: it moves the penalty calculation to the gradient update step; see Optimisers · Flux. This saves having to go through AD and is generally more efficient.
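
For a plain parameter penalty (e.g. L2), the optimiser-side pattern from those docs looks roughly like this; a sketch only, and note it penalises the weights at update time rather than the gradient norm:

using Flux

# compose an L2 weight-decay step with the base optimiser; the penalty is applied
# during the parameter update, so Zygote never has to differentiate it
opt = Flux.Optimiser(WeightDecay(1f-4), Descent(0.01))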

The thing is, I need the gradient penalty term to go through the AD. I am trying to do something similar to WGAN-GP ([1704.00028] Improved Training of Wasserstein GANs, Alg 1, Lines 7-9). The loss function involves a norm of the gradient with respect to data: L(w) = \|\nabla_x f_w(x)\|_2. So I would need to run this loss function through the AD again to get the gradient with respect to the parameters.

I’ve updated the title to make this more clear.

I think what I am trying to do used to be possible in Zygote, based on the thread here: Gradient of gradient - #8 by martenlienen

The suggestion:

using Flux, Zygote  # Zygote is needed for the explicit pullback call below
net = Dense(10, 1)
x = randn(10, 128)  # dims, batch

function pred(x, net)
    y, pullback = Zygote.pullback(net, x)
    grads = pullback(ones(size(y)))[1]
    return grads
end

gradient(() -> sum(pred(x, net)), params(net))

Now throws the same error:

ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)

Based on the comments in the linked thread, this used to work fine.

That example still works for me on the latest versions of Flux and Zygote. What may be different for your example (and should also work now, IIRC) is that params is called inside gradient. params does quite a bit of work and mutation under the hood to collect model parameters, so our recommendation is to allocate it once and pass it into the gradient callback.
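
A minimal sketch of that recommendation, reusing the pred helper and the net, x from the snippet above:

using Flux, Zygote

# collect the parameters once, outside the AD call ...
ps = Flux.params(net)

# ... and pass the pre-built Params in, instead of calling params() inside gradient
gradient(() -> sum(pred(x, net)), ps)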

My mistake, I had instantiated a slightly different network that included a softmax output. It turns out, the issue was related to softmax.

using Flux, Zygote  # Zygote is needed for the explicit pullback call below

net = Dense(10, 1)
net_chain = Chain(Dense(10, 1))
net_softmax = Chain(Dense(10, 1), softmax)

x = randn(10, 128)  # dims, batch

function pred(x, net)
    y, pullback = Zygote.pullback(net, x)
    grads = pullback(ones(size(y)))[1]
    return grads
end

gradient(() -> sum(pred(x, net)), Flux.params(net)) # Works
gradient(() -> sum(pred(x, net_chain)), Flux.params(net_chain)) # Works
gradient(() -> sum(pred(x, net_softmax)), Flux.params(net_softmax)) # Mutation error

Interestingly, removing softmax from my original code snippet doesn’t fix the problem entirely. It seems like there’s a separate issue with the cross entropy loss:

using Flux, LinearAlgebra

m = Chain(Dense(100, 50, relu), Dense(50, 2, relu));
m_softmax = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);

opt = Descent(0.01);

data, labels = rand(Float32, 100, 100), rand(0:1, 100);
labels = permutedims(hcat(labels, 1 .- labels))  # 2×100 one-hot targets

loss(m, x, y) = sum(Flux.crossentropy(m(x), y));
loss_sum(m, x, y) = sum(m(x))

function get_grad(m, data, labels, ps, loss)
    gs = gradient(ps) do
        l = sum(LinearAlgebra.norm, get_grad_inner(m, data, labels, Flux.params(data), loss))
    end
end

function get_grad_inner(m, data, labels, ps, loss)
    gs = gradient(ps) do
        l = loss(m, data, labels)
    end
end

get_grad(m_softmax, data, labels, Flux.params(m_softmax), loss) # Mutation error
get_grad(m, data, labels, Flux.params(m), loss) # Mutation error
get_grad(m_softmax, data, labels, Flux.params(m_softmax), loss_sum) # Mutation error
get_grad(m, data, labels, Flux.params(m), loss_sum) # Works

I recall there being a discussion or issue about NNlib.softmax not being twice differentiable, but unfortunately not where it happened. If you wouldn’t mind posting the full stacktraces of both mutation error cases, we can take a deeper look at it and at the crossentropy loss as well.

Thanks, I have a really hard time understanding these stacktraces, especially for the second one.

Full stack trace for the mutation error that seems to come from softmax in the chain:

ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#432#433"{Matrix{Float64}})(#unused#::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/lib/array.jl:74
  [3] (::Zygote.var"#2358#back#434"{Zygote.var"#432#433"{Matrix{Float64}}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] Pullback
    @ ~/.julia/packages/NNlib/CSWJa/src/softmax.jl:71 [inlined]
  [8] (::typeof(∂(#∇softmax!#50)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/NNlib/CSWJa/src/softmax.jl:70 [inlined]
 [10] (::typeof(∂(∇softmax!##kw)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/NNlib/CSWJa/src/softmax.jl:61 [inlined]
 [12] (::typeof(∂(#∇softmax#48)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/NNlib/CSWJa/src/softmax.jl:61 [inlined]
 [14] (::typeof(∂(∇softmax##kw)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/NNlib/CSWJa/src/softmax.jl:81 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/chainrules.jl:140 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
 [22] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
 [24] (::typeof(∂(λ)))(Δ::Tuple{Nothing, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:41 [inlined]
 [26] (::typeof(∂(λ)))(Δ::Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [27] Pullback
    @ ./REPL[94]:3 [inlined]
 [28] (::typeof(∂(pred)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [29] Pullback
    @ ./REPL[97]:1 [inlined]
 [30] (::typeof(∂(#41)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [31] (::Zygote.var"#90#91"{Params, typeof(∂(#41)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:348
 [32] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:76
 [33] top-level scope
    @ REPL[97]:1
 [34] top-level scope
    @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

Full stack trace for the mutation error that seems to come from the cross entropy loss function:

ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float32}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#432#433"{Matrix{Float32}})(#unused#::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/lib/array.jl:74
  [3] (::Zygote.var"#2358#back#434"{Zygote.var"#432#433"{Matrix{Float32}}})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] Pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/lib/array.jl:279 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/losses/functions.jl:216 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/losses/functions.jl:215 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [15] Pullback
    @ ./REPL[75]:1 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [17] Pullback
    @ ./REPL[78]:3 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:348 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:76 [inlined]
 [22] (::typeof(∂(gradient)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [23] Pullback
    @ ./REPL[78]:2 [inlined]
 [24] (::typeof(∂(get_grad_inner)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [25] Pullback
    @ ./REPL[77]:3 [inlined]
 [26] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#90#91"{Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:348
 [28] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:76
 [29] get_grad(m::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, data::Matrix{Float32}, labels::Matrix{Int64}, ps::Params, loss::Function)
    @ Main ./REPL[77]:2
 [30] top-level scope
    @ REPL[88]:1
 [31] top-level scope
    @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

If you want to calculate a second-order gradient, you need to ensure that all pullbacks are differentiable, which is not the case for softmax. I have a twice differentiable softmax for Zygote somewhere on my computer, which I can send you once I am at the computer.
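
A minimal way to see this in isolation (a sketch; on the Flux/NNlib versions in this thread the second call reportedly hits the same mutation error, since the stacktrace above points at the in-place ∇softmax!):

using Flux  # re-exports gradient and softmax

x = randn(Float32, 3, 2)

# first-order gradient through softmax is fine
gradient(z -> sum(abs2, softmax(z)), x)

# second-order: differentiate the squared norm of that gradient again; this goes
# through softmax's pullback, which on these versions calls the mutating ∇softmax!
gradient(x) do z
    g, = gradient(w -> sum(abs2, softmax(w)), z)
    sum(abs2, g)
end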

I managed to get these working with loss_sum:

julia> function get_grad_inner(m, data, labels, loss)
           gradient(data -> loss(m, data, labels), data) |> only
       end

julia> function get_grad(m, data, labels, ps, loss)
           gradient(ps) do
               norm(get_grad_inner(m, data, labels, loss))
           end
       end
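
Called with the objects from the earlier snippet (and using LinearAlgebra for norm), that looks like:

julia> using LinearAlgebra

julia> get_grad(m, data, labels, Flux.params(m), loss_sum)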

If you can acquire a twice differentiable log softmax, then loss(m, x, y) = Flux.logitcrossentropy(m(x), y) with no final softmax activation should be the most numerically stable loss as well.
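
A sketch of that setup, reusing the shapes from the earlier snippet:

using Flux

# no softmax at the end: the model emits raw logits ...
m_logits = Chain(Dense(100, 50, relu), Dense(50, 2))

# ... and logitcrossentropy applies logsoftmax internally, which is more numerically
# stable and is the only piece that then needs to be twice differentiable
loss_logits(m, x, y) = Flux.logitcrossentropy(m(x), y)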

Here is the twice differentiable logit cross entropy, though it might be against an older version of ChainRulesCore. Bringing it up to date should be straightforward.

# Imports added for completeness; this targets the older ChainRulesCore API of the
# time (NO_FIELDS, Zero).
using Flux, Statistics
using ChainRulesCore
using NNlib: NNlib, softmax, logsoftmax, ∇softmax

@inline  Flux.Losses.logitcrossentropy(ŷ::Matrix, y::Matrix) = logce(ŷ, y)
@inline  Flux.Losses.logitcrossentropy(ŷ::Vector, y::Vector) = logce(ŷ, y)
@inline  Flux.Losses.logitcrossentropy(ŷ::Matrix, y::Vector) = logce(ŷ, y)

function logce(ŷ, y)
	softŷ = softmax(ŷ; dims=1)
	l = sum(y .* log.(max.(1f-6, softŷ)), dims = 1)
	n = length(l)
	-mean(l)
end

function ∇lsoftmax(Δ, xs; dims=1) 
	o = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
end

# function ∇softmax!(out::AbstractArray, Δ::AbstractArray, 
#                     x::AbstractArray, y::AbstractArray; dims = 1)
#     out .= Δ .* y
#     out .= out .- y .* sum(out; dims = dims)
# end


function ∇₂softmax(Δ₂, Δ::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
	Δ₂y = Δ₂ .* y
	sΔ₂y = sum(Δ₂y, dims = dims)
	(Δ₂ .* y .- sΔ₂y .* y), Zero(),  (Δ₂ .* Δ .- Δ₂ .* sum(Δ .* y, dims = dims) .- sΔ₂y .* Δ)
end


function ChainRulesCore.rrule(::typeof(NNlib.∇softmax),  Δ, x, softx; dims=1)
    y = ∇softmax(Δ, x, softx; dims=dims)
    function ∇softmax_pullback(Δ₂)
		NO_FIELDS, ∇₂softmax(Δ₂, Δ, x, softx; dims = dims)...
    end
    return y, ∇softmax_pullback
end

function ∇logce(Δ, logŷ, y::Matrix, n)
	∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n 
	∇y =  - Δ .* logsoftmax(logŷ; dims=1) ./ n
	(∇logŷ, ∇y)
end

function ∇logce(Δ, logŷ, y::Vector, n)
	∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n 
	∇y =  -mean(Δ .* logsoftmax(logŷ; dims=1), dims = 2)[:]
	(∇logŷ, ∇y)
end


function ChainRulesCore.rrule(::typeof(logce), logŷ, y)
	o = logce(logŷ, y)
	function g(Δ)
		(NO_FIELDS, ∇logce(Δ, logŷ, y, size(logŷ, 2))...)
	end
	o, g
end

Hello, I was also trying to implement WGAN-GP in Flux and ran into a similar problem with the gradient penalty. The approach(es) above work for simple fully connected networks but crash for convolutional networks, as in the toy code below, which crashes even on the CPU.
I was using Julia 1.6 and Flux v0.12.7.
Does anyone else get the same?

using Flux
using Statistics: mean

function pred(net, x)
    y, grad_func = Flux.pullback(net, x)
    grads = grad_func(ones(size(y)))[1]
    return mean(grads)
end

net = Chain(
        Conv((3, 3), 3 => 3; pad = 1),
        x->leakyrelu.(x, 0.2f0),
        x->reshape(x, 4 * 4 * 3, :),
        Dense(4 * 4 * 3, 1))    
x = rand(Float32, 4, 4, 3, 10) 

# This line is to check the validity of the network and the first gradient.
pred(net, x) 

ps = Flux.params(net)

y, b = Flux.pullback(ps) do 
    dx_grad = pred(net, x)
end # Crash
b(1f0)

ERROR: LoadError: TaskFailedException
Stacktrace:
  [1] wait
    @ ./task.jl:322 [inlined]
  [2] fetch(t::Task)
    @ Base ./task.jl:337
  [3] (::Zygote.var"#322#324"{Zygote.Context, Task})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/base.jl:75
  [4] (::Zygote.var"#2015#back#325"{Zygote.var"#322#324"{Zygote.Context, Task}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] macro expansion
    @ ./threadingconstructs.jl:173 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:256 [inlined]
  [7] macro expansion
    @ ./task.jl:387 [inlined]
  [8] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:252 [inlined]
  [9] (::typeof(∂(#∇conv_filter!#169)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:241 [inlined]
 [11] (::typeof(∂(∇conv_filter!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151 [inlined]
 [13] (::typeof(∂(#∇conv_filter!#126)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151 [inlined]
 [15] (::typeof(∂(∇conv_filter!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:118 [inlined]
 [17] (::typeof(∂(#∇conv_filter#93)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:115 [inlined]
 [19] (::typeof(∂(∇conv_filter)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [20] #203
    @ ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:203 [inlined]
 [21] #1734#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [22] Pullback
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:314 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/ChainRulesCore/7OROc/src/tangent_types/thunks.jl:194 [inlined]
 [25] Pullback
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/chainrules.jl:104 [inlined]
 [26] (::Zygote.var"#551#556")(::Tuple{Array{Float32, 4}, typeof(∂(wrap_chainrules_output))}, δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/array.jl:195
 [27] map
    @ ./tuple.jl:233 [inlined]
 [28] map (repeats 2 times)
    @ ./tuple.jl:236 [inlined]
 [29] #550
    @ ~/.julia/packages/Zygote/rv6db/src/lib/array.jl:195 [inlined]
 [30] #2577#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [31] Pullback
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/chainrules.jl:105 [inlined]
 [32] Pullback
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/chainrules.jl:179 [inlined]
 [33] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Array{Float32, 4}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/conv.jl:165 [inlined]
 [35] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Array{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [36] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
 [37] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Array{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [38] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
 [39] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Array{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [40] Pullback
    @ ~/.julia/packages/Zygote/rv6db/src/compiler/interface.jl:41 [inlined]
 [41] (::typeof(∂(λ)))(Δ::Tuple{Array{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [42] Pullback
    @ ~/test/test1/test2.jl:51 [inlined]
 [43] (::typeof(∂(pred)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [44] Pullback
    @ ~/test/test1/test2.jl:71 [inlined]
 [45] (::typeof(∂(#19)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
 [46] (::Zygote.var"#84#85"{Zygote.Params, typeof(∂(#19)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface.jl:343
 [47] top-level scope
    @ ~/test/test1/test2.jl:73

    nested task error: Can't differentiate gc_preserve_end expression
    Stacktrace:
     [1] error(s::String)
       @ Base ./error.jl:33
     [2] Pullback
       @ ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_im2col.jl:106 [inlined]
     [3] (::typeof(∂(#∇conv_filter_im2col!#390)))(Δ::Nothing)
       @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
     [4] Pullback
       @ ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_im2col.jl:75 [inlined]
     [5] (::typeof(∂(∇conv_filter_im2col!)))(Δ::Nothing)
       @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
     [6] Pullback
       @ ./threadingconstructs.jl:169 [inlined]
     [7] (::typeof(∂(λ)))(Δ::Nothing)
       @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
     [8] (::Zygote.var"#327#328"{Nothing, typeof(∂(λ))})()
       @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/base.jl:81

As always, please include a full stacktrace along with your MWE.

My bad, I have updated the original post to include it.
(And it is really nice to have such a quick reply.)


OK, this is tricky to accomplish with Zygote alone, but you could try mixing ADs as in Zygote push forward unable to differentiate through generic broadcast - #10 by RynoLaubscher. Unfortunately I don’t have any personal experience with that, so if you’re unable to get it working and nobody else replies, I’d recommend hitting up #autodiff on Slack.


Still helpful, thanks!

Thanks for all the help, and I’m glad the thread has been helpful to others. I only recently needed a twice differentiable softmax. The only error I encountered was regarding NO_FIELDS, which is now NoTangent() or ZeroTangent(). Changing only that lets the code go through for first and second derivatives. For some reason, however, none of the second-derivative functions are being hit; I don’t understand ChainRules well enough to know whether that’s correct behavior. Comparing against finite differencing, it seems to work fine.

using Flux
using StatsBase
import Flux.Zygote.ChainRulesCore
import Flux.Zygote.ChainRulesCore: NoTangent, ZeroTangent
import Flux.NNlib

function logce(ŷ, y)
	softŷ = softmax(ŷ; dims=1)
	l = sum(y .* log.(max.(1f-6, softŷ)), dims = 1)
	n = length(l)
	-mean(l)
end

function ∇lsoftmax(Δ, xs; dims=1)
	o = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
end

# function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
#                     x::AbstractArray, y::AbstractArray; dims = 1)
#     out .= Δ .* y
#     out .= out .- y .* sum(out; dims = dims)
# end


function ∇₂softmax(Δ₂, Δ::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
    println("second grad softmax")
	Δ₂y = Δ₂ .* y
	sΔ₂y = sum(Δ₂y, dims = dims)
	# `Zero()` from the old ChainRulesCore also needs to become `ZeroTangent()` here
	(Δ₂ .* y .- sΔ₂y .* y), ZeroTangent(),  (Δ₂ .* Δ .- Δ₂ .* sum(Δ .* y, dims = dims) .- sΔ₂y .* Δ)
end


function ChainRulesCore.rrule(::typeof(NNlib.∇softmax),  Δ, x, softx; dims=1)
    println("rrule softmax")
    y = NNlib.∇softmax(Δ, x, softx; dims=dims)  # qualified, since only `import Flux.NNlib` is in scope
    function ∇softmax_pullback(Δ₂)
		ZeroTangent(), ∇₂softmax(Δ₂, Δ, x, softx; dims = dims)...
    end
    return y, ∇softmax_pullback
end

function ∇logce(Δ, logŷ, y::Matrix, n)
    println("grad logce matrix")
	∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n
	∇y =  - Δ .* logsoftmax(logŷ; dims=1) ./ n
	(∇logŷ, ∇y)
end

function ∇logce(Δ, logŷ, y::Vector, n)
    println("grad logce vector")
	∇logŷ = -∇lsoftmax(Δ .* y, logŷ; dims=1) ./ n
	∇y =  -mean(Δ .* logsoftmax(logŷ; dims=1), dims = 2)[:]
	(∇logŷ, ∇y)
end


function ChainRulesCore.rrule(::typeof(logce), logŷ, y)
    println("rrule logce")
	o = logce(logŷ, y)
	function g(Δ)
		(ZeroTangent(), ∇logce(Δ, logŷ, y, size(logŷ, 2))...)
	end
	o, g
end

function first_order_grad(loss, pred, target)
    grads_inner = gradient(Flux.params(pred)) do
        loss(pred, target)
    end
    sum(grads_inner[pred].^2)
end

function second_order_grad(loss, pred, target)
    grads = gradient(Flux.params(pred)) do
        first_order_grad(loss, pred, target)
    end
    return sum(grads[pred].^2)
end

second_order_grad((x,y) -> sum((x .- y).^3), [2], [0]) #12
g1 = first_order_grad(logce, [0.5,0.5], [0.1, 0.9])
g2 = second_order_grad(logce, [0.5,0.5], [0.1, 0.9]) # Works, unlike Flux.logitcrossentropy
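
For the record, the kind of finite-difference comparison mentioned above could look like this (a sketch; it assumes FiniteDifferences.jl is available and reuses logce and g2 from above):

using FiniteDifferences

fdm = central_fdm(5, 1)

# same quantity as first_order_grad: squared norm of ∂logce/∂ŷ at the given target
fd_first(ŷ) = sum(abs2, grad(fdm, q -> logce(q, [0.1, 0.9]), ŷ)[1])

# and the squared norm of its gradient, matching second_order_grad
fd_g2 = sum(abs2, grad(fdm, fd_first, [0.5, 0.5])[1])

isapprox(fd_g2, g2; rtol=1e-3)  # expected to hold if the custom rrules are correct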