Sobolev training in Flux: first derivative of output wrt input in the loss function

I have a dataset with input images x_train, an objective value corresponding to each image y_train, and the derivatives of y_train wrt x_train, dc_train. I would like to train a CNN on this data, computing the first derivative of the output wrt the input during training and using it in the loss function alongside the usual loss.
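In symbols, the per-sample loss I have in mind is roughly the following. The notation is my own shorthand and does not come from any library: f_θ is the CNN, σ and μ stand for sd_y_train and mean_y_train, and the weight 0.5 is simply the value used in my code.

```latex
\mathcal{L}(\theta)
  = \mathrm{MSE}\bigl(f_\theta(x),\, y\bigr)
  + 0.5 \cdot \mathrm{MSE}\bigl(\nabla_x \left[\sigma\, f_\theta(x) + \mu\right],\, dc\bigr)
```

The first term compares normalized predictions with normalized targets, while the second compares the input gradient of the de-normalized prediction with the given derivatives.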

I implemented this as follows, using ForwardDiff:

using Flux
using ForwardDiff

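# Random stand-in data: 50 single-channel 10x30 images (WHCN layout)
x_train = rand(Float32, 10, 30, 1, 50)
norm_y_train = rand(Float32, 1, 50)     # normalized targets
dc_train = rand(Float32, 10, 30, 50)    # derivative of y_train wrt x_train, per image

# Statistics that were used to normalize y_train
mean_y_train = 300.0
sd_y_train = 1000.0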

model = Chain(
    Conv((3, 3), 1=>6, pad=(1,1), relu),    # 10x30x1 -> 10x30x6
    MaxPool((3,3)),                         # -> 3x10x6
    Conv((3, 3), 6=>16, pad=(1,1), relu),   # -> 3x10x16
    MaxPool((3,3)),                         # -> 1x3x16
    Flux.flatten,                           # -> 48
    Dense(48 => 1)
)

# Default batchsize is 1, so each iteration yields a single sample
loader = Flux.DataLoader((x_train, norm_y_train, dc_train))
opt = Flux.setup(Flux.RAdam(0.01), model)

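# Wraps the network so that calling it returns the de-normalized scalar prediction
struct Model_struct{MT,T}
    model::MT
    sd::T
    mean::T
end

function (ma::Model_struct)(x::Array)
    # [] extracts the only element, so this assumes a batch of one sample
    return (ma.model(x).*ma.sd .+ ma.mean)[]
end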

function my_loss(model, x, y, dc)
    nn = Model_struct(model, sd_y_train, mean_y_train)
    dc_hat = ForwardDiff.gradient(nn, x)   # derivative of de-normalized output wrt input
    y_hat = model(x)                       # normalized prediction
    return Flux.mse(y_hat, y) + (0.5 * Flux.mse(dc_hat, dc))
end

epochs = Int[]
train_loss = Float32[]

for epoch in 1:100
    for (x, y, dc) in loader
        val, grads = Flux.withgradient(model) do m
            my_loss(m, x, y, dc)
        end
        Flux.update!(opt, model, grads[1])
    end
    push!(train_loss, Flux.mse(model(x_train), norm_y_train))
    push!(epochs, epoch)
end

I have used random data for illustration in the example above. Also, y_train is normalized for training using the mean mean_y_train and the standard deviation sd_y_train.

This implementation gives me the following warning:

┌ Warning: ForwardDiff.gradient(f, x) within Zygote cannot track gradients with respect to f,
│ and f appears to be a closure, or a struct with fields (according to issingletontype(typeof(f))).
│ typeof(f) = Model_struct{Chain{Tuple{Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Float64}
└ @ Zygote C:\Users\User\.julia\packages\Zygote\SuKWp\src\lib\forward.jl:142
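A much smaller example that I believe exercises the same code path is sketched below. The `Scale` struct is hypothetical and only stands in for Model_struct: it is a callable struct with one field, and I take a Zygote gradient with respect to that field through a ForwardDiff.gradient call. I am assuming the behaviour here mirrors what happens inside my_loss.

```julia
using Zygote, ForwardDiff

# Hypothetical callable struct with one field, standing in for Model_struct
struct Scale{T}
    w::T
end
(s::Scale)(x) = sum(s.w .* x .^ 2)

x = [1.0, 2.0]

# Inner gradient on its own: d/dx sum(w * x^2) = 2 * w * x
direct = ForwardDiff.gradient(Scale(3.0), x)

# Outer gradient wrt the field w, taken through the ForwardDiff call.
# Analytically, d/dw sum(2 * w * x) = 2 * sum(x) = 6.
g = Zygote.gradient(w -> sum(ForwardDiff.gradient(Scale(w), x)), 3.0)

println("inner gradient: ", direct)
println("outer gradient wrt w: ", g[1])
```

If the second line prints `nothing` rather than 6.0, that would suggest the derivative term in my_loss does not contribute to the parameter update at all, which is how I read the warning text.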

I’m not able to figure out why this warning appears. I’m also not sure if this is the best way to implement this training.

This implementation is based on the Sobolev training method. Is there a better way to implement it in Flux?