AD Troubles in Flux and Unusual Loss

I’m struggling with the following issue. I want to train on a somewhat nonstandard loss function involving the gradient of the model output with respect to its input. If everything worked, I would use the following code:

using Flux
using Random
using Printf
using LinearAlgebra
using Statistics

d = 2;
h = 1.0f0;
n_x = 10;

f_true(x) = sin(x[1]) * cos(x[2]); # arbitrary, just want a 2=>1 example function

# cook up data
Random.seed!(500);

x_train = [randn(Float32, d) for _ in 1:n_x];
y_train = Float32.(f_true.(x_train));

# convert to matrix
X_train = hcat(x_train...);
Y_train = hcat(y_train...);

W_train = [X_train[:, j] .- X_train for j in 1:n_x];
training_data = [(W_train, y_train)];

inner_layer = Dense(d => d, bias=false);
# what I would naively implement: a column vector of Gaussians of the column-wise norms
function problem_layer1(w)
    return [exp(-norm(w[:, j])^2 / (2h)) for j in 1:n_x]
end
outer_layer = Dense(n_x => 1, bias=false);
f1 = Chain(inner_layer, problem_layer1, outer_layer);

# nonstandard loss: mean squared norm of the gradient of the rest of the chain
# with respect to the first layer's output
function grad_loss(f, w, y)
    ∇g(w) = Flux.gradient(ϕ_ -> Chain(f.layers[2:end])(ϕ_)[1][1], f.layers[1](w))[1]
    return mean([norm(∇g(w_))^2 for w_ in w])
end
# now try training
opt_f_state1 = Flux.setup(Flux.Adam(0.01), f1)
Flux.train!(grad_loss, f1, training_data, opt_f_state1)

and this produces the error:

Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).

Next, I tried something slightly different, which did work:

# this is not the same function, but it has a similar flavor
function problem_layer2(w)
    return sum(abs2, w, dims=1)'
end
g2 = Chain(inner_layer, problem_layer2, outer_layer);
# try training
opt_g_state2 = Flux.setup(Flux.Adam(0.01), g2)
Flux.train!(grad_loss, g2, training_data, opt_g_state2)

and this executes without error (which is not to say it’s correct).

Next, I tried:

f2 = Chain(inner_layer, problem_layer2, norm2 -> exp.(-norm2 ./ (2h)), outer_layer);

opt_f_state2 = Flux.setup(Flux.Adam(0.01), f2)
Flux.train!(grad_loss, f2, training_data, opt_f_state2)

and I got

Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).

I even tried the simpler problem

h2 = Chain(inner_layer, problem_layer2, norm2 -> norm2 ./ (-2h), outer_layer);

which just applies a scaling, and this also gave me the same foreigncall error.

Thanks for any suggestions on how to sort this out.

Flux (Zygote) doesn’t handle nested gradients nicely.

You need to either use Enzyme for the autodiff, or, if you really want to use Zygote, call it via Lux (Nested Automatic Differentiation | Lux.jl Docs).
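For what it’s worth, here is a minimal sketch of the Lux route, following the pattern in that docs page and assuming a recent Lux version with nested-AD support. The 2 => 4 => 1 architecture, the name grad_loss_lux, and the random data are placeholders, not your model; the point is only the StatefulLuxLayer plus inner Zygote.gradient pattern.

using Lux, Zygote, Random

rng = Random.default_rng()
model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(rng, model)

# same flavor as grad_loss: penalize the squared norm of the gradient of the
# network output with respect to its input, averaged over the batch
function grad_loss_lux(ps, x)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    g = only(Zygote.gradient(x_ -> sum(smodel(x_)), x))   # inner gradient w.r.t. the input
    return sum(abs2, g) / size(x, 2)
end

x = randn(rng, Float32, 2, 10)
# Lux rewrites the inner Zygote.gradient call to forward-over-reverse, so the
# outer gradient with respect to the parameters goes through
∇ps = Zygote.gradient(p -> grad_loss_lux(p, x), ps)[1]

The important part is that the inner gradient is taken through the StatefulLuxLayer, which is how Lux knows to switch that call to forward-over-reverse; from there ∇ps can be fed to Optimisers.jl in the usual way.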