I’m struggling with the following issue. I want to train on a somewhat nonstandard loss function involving the gradient of the target. If everything worked, I would use the following code:
using Flux
using Random
using Printf
using LinearAlgebra
using Statistics
d = 2;
h = 1.0f0;
n_x = 10;
f_true(x) = sin(x[1]) * cos(x[2]); # arbitrary, just want a 2=>1 example function
# cook up data
Random.seed!(500);
x_train = [randn(Float32, d) for _ in 1:n_x];
y_train = Float32.(f_true.(x_train));
# convert to matrix
X_train = hcat(x_train...);
Y_train = hcat(y_train...);
W_train = [X_train[:, j] .- X_train for j in 1:n_x];
training_data = [(W_train, y_train)];
inner_layer = Dense(d => d, bias=false);
# what I would naively implement: a column vector holding the Gaussian of each column-wise norm
function problem_layer1(w)
return [exp(-norm(w[:, j])^2 / (2h)) for j in 1:n_x]
end
outer_layer = Dense(n_x => 1, bias=false);
f1 = Chain(inner_layer, problem_layer1, outer_layer);
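# nonstandard loss: mean squared norm of the gradient of layers 2:end with respect to their input,
# evaluated at the output of the first layer (note that y is not used)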
function grad_loss(f, w, y)
∇g(w) = Flux.gradient(ϕ_ -> Chain(f.layers[2:end])(ϕ_)[1][1], f.layers[1](w))[1]
return mean([norm(∇g(w_))^2 for w_ in w])
end
# now try training
opt_f_state1 = Flux.setup(Flux.Adam(0.01), f1)
Flux.train!(grad_loss, f1, training_data, opt_f_state1)
and this produces the error:
Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
Next, I tried something slightly different, which did work:
# this is not the same function, but it has a similar flavor
function problem_layer2(w)
return sum(abs2, w, dims=1)'
end
g2 = Chain(inner_layer, problem_layer2, outer_layer);
# try training
opt_g_state2 = Flux.setup(Flux.Adam(0.01), g2)
Flux.train!(grad_loss, g2, training_data, opt_g_state2)
and this executes without error (which is not to say it’s correct).
Next, I tried:
f2 = Chain(inner_layer, problem_layer2, norm2 -> exp.(-norm2 ./ (2h)), outer_layer);
opt_f_state2 = Flux.setup(Flux.Adam(0.01), f2)
Flux.train!(grad_loss, f2, training_data, opt_f_state2)
and I got
Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
I even tried the simpler problem
h2 = Chain(inner_layer, problem_layer2, norm2 -> norm2 ./ (-2h), outer_layer);
which just does a scaling, and this also gave me a foreigncall error.
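If the error comes from nesting `Flux.gradient` inside the loss (an assumption, not something confirmed here), one direction that might sidestep it is to write the inner gradient in closed form, so that only a single level of automatic differentiation is needed during training. The sketch below does this for the Gaussian layer followed by a bias-free `Dense(n_x => 1)`. The names `g`, `grad_g`, and `a` are hypothetical; `a` stands in for the outer weights (for example `vec(outer_layer.weight)`). The formula is checked against central finite differences using only `LinearAlgebra`, and no Flux calls are involved.

```julia
using LinearAlgebra

h = 1.0  # bandwidth, same role as h in the code above

# g(ϕ) = Σ_j a[j] * exp(-‖ϕ[:, j]‖² / (2h)):
# the Gaussian layer followed by a bias-free linear map with weights a (hypothetical name)
g(ϕ, a) = sum(a[j] * exp(-norm(ϕ[:, j])^2 / (2h)) for j in axes(ϕ, 2))

# Closed-form gradient with respect to ϕ:
# column j equals -(a[j] / h) * exp(-‖ϕ[:, j]‖² / (2h)) * ϕ[:, j]
function grad_g(ϕ, a)
    k = exp.(-vec(sum(abs2, ϕ, dims=1)) ./ (2h))  # Gaussian value per column
    return -ϕ .* (a .* k)' ./ h                   # scale each column of ϕ
end

# Sanity check against central finite differences on arbitrary test values
ϕ = [0.3 -1.2 0.5; 0.7 0.4 -0.9]
a = [1.0, -2.0, 0.5]
ε = 1e-6
fd = similar(ϕ)
for i in eachindex(ϕ)
    e = zeros(size(ϕ))
    e[i] = ε
    fd[i] = (g(ϕ .+ e, a) - g(ϕ .- e, a)) / (2ε)
end
@assert isapprox(fd, grad_g(ϕ, a); atol=1e-6)
```

With a closed form like this, the loss would reduce to ordinary array operations on the layer weights. Whether that actually avoids the foreigncall error in this setup is untested.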
Thanks for any suggestions on how to sort this out.