# Sobolev training in Flux - first derivative of output wrt input in loss function

I have a dataset with input images `x_train`, an objective value `y_train` for each image, and the derivatives of `y_train` with respect to `x_train`, `dc_train`. I would like to train a CNN on this data such that, during training, the first derivative of the network output with respect to its input is computed and used in the loss function alongside the usual loss on the outputs.
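Written out per training sample, the combined loss I am after is roughly the following. The symbols are introduced here only for illustration: $f_\theta$ is the CNN, $\tilde y$ the normalized target, $\sigma$ = `sd_y_train` and $\mu$ = `mean_y_train` the normalization constants from the code below, and $P = 10 \times 30 = 300$ the number of pixels per image. The factor $\tfrac12$ corresponds to the `0.5` weight in `my_loss`.

$$
\mathcal{L}(\theta; x, \tilde y, dc) = \bigl(f_\theta(x) - \tilde y\bigr)^2 + \frac{1}{2}\cdot\frac{1}{P}\sum_{p=1}^{P}\left(\frac{\partial\bigl[\sigma f_\theta(x) + \mu\bigr]}{\partial x_p} - dc_p\right)^2,
\qquad
\frac{\partial\bigl[\sigma f_\theta(x) + \mu\bigr]}{\partial x_p} = \sigma\,\frac{\partial f_\theta(x)}{\partial x_p}
$$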

I did this as follows, using `ForwardDiff`:

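```julia
using Flux
using ForwardDiff

# Random stand-ins for the real data
x_train = Float32.(rand(10, 30, 1, 50))   # 50 single-channel 10x30 images
norm_y_train = Float32.(rand(1, 50))      # normalized objective values
dc_train = Float32.(rand(10, 30, 50))     # d(objective)/d(image) per sample

mean_y_train = 300.0
sd_y_train = 1000.0

# Conv sizes here are illustrative, chosen so that flatten yields 48 features
model = Chain(
    Conv((3, 3), 1 => 8, relu; pad = 1),
    MaxPool((3, 3)),
    Conv((3, 3), 8 => 16, relu; pad = 1),
    MaxPool((3, 3)),
    Flux.flatten,
    Dense(48 => 1),
)

# Wraps the network so it returns the de-normalized objective as a scalar
struct Model_struct{MT,T}
    model::MT
    sd::T
    mean::T
end

# Only valid for a single sample, since `[]` requires a 1-element output
function (ma::Model_struct)(x::AbstractArray)
    return (ma.model(x) .* ma.sd .+ ma.mean)[]
end

function my_loss(model, x, y, dc)
    nn = Model_struct(model, sd_y_train, mean_y_train)
    y_hat = model(x)
    dc_hat = ForwardDiff.gradient(nn, x)   # d(output)/d(input), same shape as x
    return Flux.mse(y_hat, y) + 0.5 * Flux.mse(dc_hat, dc)
end

# batchsize = 1 because Model_struct returns a scalar for one image
loader = Flux.DataLoader((x_train, norm_y_train, dc_train); batchsize = 1, shuffle = true)
opt_state = Flux.setup(Adam(), model)

epochs = []
train_loss = []

for epoch in 1:100
    for (x, y, dc) in loader
        grads = Flux.gradient(m -> my_loss(m, x, y, dc), model)
        Flux.update!(opt_state, model, grads[1])
    end
    push!(train_loss, Flux.mse(model(x_train), norm_y_train))
    push!(epochs, epoch)
end
```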

I have used random arrays in place of the real data in the example above. Also, `y_train` is normalized for training (`norm_y_train`) using mean `mean_y_train` and standard deviation `sd_y_train`.
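Because `Model_struct` rescales the network output, the derivative it returns should be the derivative of the normalized network multiplied by `sd_y_train`. A small sanity check of that chain-rule step might look like the sketch below. It assumes a hypothetical linear layer `lin = Dense(3 => 1)` and stand-in constants `sd` and `mu`; none of these come from my actual setup.

```julia
using Flux
using ForwardDiff

# Hypothetical linear layer (identity activation), so the exact gradient is its weight row
lin = Dense(3 => 1)
sd, mu = 1000.0, 300.0   # stand-ins for sd_y_train and mean_y_train
x = rand(Float32, 3)

# De-normalized scalar output, mirroring Model_struct
denorm(x) = (lin(x) .* sd .+ mu)[]
# Normalized scalar output
norm_out(x) = lin(x)[]

g_denorm = ForwardDiff.gradient(denorm, x)
g_norm = ForwardDiff.gradient(norm_out, x)

# Chain rule: d/dx (sd * f(x) + mu) = sd * df/dx
println(g_denorm ≈ sd .* g_norm)
```

This suggests `dc_train` should be in the original (de-normalized) units of `y_train` when it is compared against `ForwardDiff.gradient(nn, x)`.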

This implementation gives me the following warning:

```
┌ Warning: `ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = Model_struct{Chain{Tuple{Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Float64}
└ @ Zygote C:\Users\User\.julia\packages\Zygote\SuKWp\src\lib\forward.jl:142
```

I can't figure out why this warning appears, and I'm also not sure whether this is the best way to implement this kind of training.
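To isolate the warning, here is a stripped-down sketch without the CNN. A parameter vector `p` (hypothetical, standing in for the network weights) is captured by the closure passed to `ForwardDiff.gradient`, much like the weights inside `Model_struct`. It assumes nested `ForwardDiff` gives the correct reference derivative, here $2\sum_j x_j = 6$, and uses `Flux.gradient` (Zygote by default) for the comparison.

```julia
using Flux
using ForwardDiff

# Inner derivative: d/dx (p * sum(x.^2)) = 2p .* x, so its sum is 2p * sum(x)
# `p` is a hypothetical parameter vector standing in for the CNN weights
inner_grad_sum(p) = sum(ForwardDiff.gradient(x -> p[1] * sum(abs2, x), [1.0, 2.0]))

p = [2.0]

# Reference: ForwardDiff over ForwardDiff; d/dp = 2 * (1.0 + 2.0) = 6
g_ref = ForwardDiff.gradient(inner_grad_sum, p)

# Zygote over ForwardDiff: the closure captures p, which is the case the warning describes
g_zyg = Flux.gradient(inner_grad_sum, p)[1]

println("reference: ", g_ref, "  Zygote: ", g_zyg)
```

Comparing `g_ref` with `g_zyg` shows whether the gradient with respect to `p` survives when Zygote differentiates through `ForwardDiff.gradient`.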

This is based on the Sobolev training method. Is there a better way to implement it in Flux?