Thanks gdalle, I finally put together a working example. Here is what I did:
julia> ff(x) = sum((x .- 1).^2)  # quadratic test function: gradient is 2 .* (x .- 1), Hessian is 2I
ff (generic function with 1 method)
julia> using Zygote
julia> function hessz(f, x)
# Differentiates half the squared norm of the gradient of `f` using nested
# Zygote.gradient calls. By the chain rule the result is H(x) * ∇f(x), i.e. a
# Hessian-vector product with the gradient, not the full Hessian matrix.
gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)  # gs(x) = ‖∇f(x)‖² / 2
gg = Zygote.gradient(gs, x)[1]  # outer gradient: differentiates through the inner gradient call
return gg
end
hessz (generic function with 1 method)
julia> hessz(ff, zeros(3))
3-element Vector{Float64}:
-4.0
-4.0
-4.0
In a more complex example, I realized that the problem comes from using Flux.logitcrossentropy,
as in something like the following:
using Flux
using MLDatasets
using Flux: logitcrossentropy, normalise, onecold, onehotbatch
using Statistics: mean
using Zygote
using Parameters: @with_kw
# Hyperparameters with keyword defaults (constructor generated by `@with_kw`).
@with_kw mutable struct Args
    # NOTE(review): presumably the learning rate; it is not read in the code shown here.
    lr::Float64 = 0.5
    # Number of times the training data is repeated by `get_processed_data`.
    repeat::Int = 110
end
# Load the Iris data through MLDatasets and prepare it for training.
#
# Returns `(train_data, train_data_iter, test_data)`: `train_data` and `test_data` are
# `(features, one-hot labels)` tuples, and `train_data_iter` yields the training tuple
# `args.repeat` times. Only `args.repeat` is read from `args`.
function get_processed_data(args)
    raw_features = MLDatasets.Iris.features()
    raw_labels = MLDatasets.Iris.labels()

    # Subtract mean, divide by std dev for normed mean of 0 and std dev of 1.
    scaled = normalise(raw_features; dims=2)

    # One-hot encode the labels against the sorted list of distinct classes.
    classes = sort(unique(raw_labels))
    onehot = onehotbatch(raw_labels, classes)

    # Columns 3, 6, ..., 150 form the test set (1/3); the remaining columns,
    # ordered as 1:3:150 followed by 2:3:150, form the training set (2/3).
    train_cols = vcat(1:3:150, 2:3:150)
    test_cols = 3:3:150

    train_data = (scaled[:, train_cols], onehot[:, train_cols])
    test_data = (scaled[:, test_cols], onehot[:, test_cols])

    # Repeat the training tuple `args.repeat` times.
    train_data_iter = Iterators.repeated(train_data, args.repeat)

    return train_data, train_data_iter, test_data
end
# Hyperparameters: defaults, with `lr` overridden to 0.1.
args = Args(lr=0.1)

# Load the processed data, then unpack features and one-hot labels for each split.
train_data, train_data_iter, test_data = get_processed_data(args)
(x_train, yc_train) = train_data
(x_test, yc_test) = test_data
# Linear (affine) model on a flat parameter vector.
#
# `wbv` is reshaped to a matrix with 3 rows; all columns but the last are the
# weights applied to `x`, and the last column is a bias broadcast-added to the result.
function logit_model(wbv, x)
    params = reshape(wbv, 3, :)
    weights = params[:, 1:(end - 1)]
    bias = params[:, end]
    return weights * x .+ bias
end
# Logit cross-entropy of `logit_model` with flat parameters `wb`, evaluated on the
# globals `x_train` and `yc_train`.
function loss_train(wb)
    logits = logit_model(wb, x_train)
    return Flux.logitcrossentropy(logits, yc_train)
end
# Initial flat parameter vector; `logit_model` reshapes its 15 entries to a 3×5 matrix.
w0 = fill(1.0, 15)
Then, proceeding as before:
julia> function hessz(f, x)  # returns H(x) * ∇f(x) (a Hessian-vector product), not the Hessian matrix
gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)  # gs(x) = ‖∇f(x)‖² / 2
gg = Zygote.gradient(gs, x)[1]  # outer gradient: differentiates through the inner Zygote.gradient call
return gg
end
hessz (generic function with 1 method)
julia> hessz(loss_train, w0)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:102 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[10] getindex
@ ./tuple.jl:29 [inlined]
[11] map
@ ./tuple.jl:222 [inlined]
[12] unthunk_tangent
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:36 [inlined]
[13] #1630#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41 [inlined]
[16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76 [inlined]
[18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[19] Pullback
@ ./REPL[1]:2 [inlined]
[20] (::typeof(∂(gs)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[21] (::Zygote.var"#56#57"{typeof(∂(gs))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
[22] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
[23] hessz(f::Function, x::Vector{Float64})
@ Main ./REPL[1]:3
[24] top-level scope
@ REPL[2]:1