Sobolev training in Flux - first derivative of output wrt input in loss function

I have a dataset with input images x_train, an objective value y_train for each image, and the derivatives of y_train wrt x_train, dc_train. I would like to train a CNN on this data: during training I compute the first derivative of the output wrt the input and use it in the loss function alongside the usual loss on the output.
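
Written out, the per-sample loss I am trying to compute looks roughly like the following. The symbols are my own notation for this post: f_θ is the CNN, σ and μ stand for sd_y_train and mean_y_train, N is the number of pixels in one image, and the weight 0.5 is the one used in my_loss further down.

```latex
\mathcal{L}(\theta)
  = \bigl(f_\theta(x) - y\bigr)^2
  + 0.5 \cdot \frac{1}{N} \sum_{i=1}^{N}
    \left(
      \frac{\partial \bigl(\sigma\, f_\theta(x) + \mu\bigr)}{\partial x_i}
      - dc_i
    \right)^2
```

Since μ is a constant, it drops out of the derivative term, so only σ rescales the network gradient before it is compared with dc.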

I did this in the following way using ForwardDiff:

using Flux
using ForwardDiff

x_train = Float32.(rand(10,30,1,50))
norm_y_train = Float32.(rand(1,50))
dc_train = Float32.(rand(10,30,50))

mean_y_train = 300.0
sd_y_train = 1000.0

model = Chain(
    Conv((3, 3), 1=>6, pad=(1,1), relu),   # 10×30×1 -> 10×30×6
    MaxPool((3,3)),                        # -> 3×10×6
    Conv((3, 3), 6=>16, pad=(1,1), relu),  # -> 3×10×16
    MaxPool((3,3)),                        # -> 1×3×16
    Flux.flatten,                          # -> 48
    Dense(48 => 1)
)

loader = Flux.DataLoader((x_train, norm_y_train, dc_train))
opt = Flux.setup(Flux.RAdam(0.01), model)

struct Model_struct{MT,T}
    model::MT
    sd::T
    mean::T
end

function (ma::Model_struct)(x::Array)
    return (ma.model(x).*ma.sd .+ ma.mean)[]
end

function my_loss(model, x, y, dc)
    nn = Model_struct(model, sd_y_train, mean_y_train)
    dc_hat =  ForwardDiff.gradient(nn, x)
    y_hat = model(x)
    return Flux.mse(y_hat, y) + (0.5 * Flux.mse(dc_hat, dc))
end

epochs = Int[]
train_loss = Float32[]

for epoch in 1:100
    for (x, y, dc) in loader
        val, grads = Flux.withgradient(model) do m
            my_loss(m, x, y, dc)
        end
        Flux.update!(opt, model, grads[1])
    end
    push!(train_loss, Flux.mse(model(x_train), norm_y_train))
    push!(epochs, epoch)
end

The example above uses random data as a stand-in for my real dataset. Also, y_train is normalized for training using the mean mean_y_train and the standard deviation sd_y_train.

This implementation gives me the following warning:

┌ Warning: ForwardDiff.gradient(f, x) within Zygote cannot track gradients with respect to f,
│ and f appears to be a closure, or a struct with fields (according to issingletontype(typeof(f))).
│ typeof(f) = Model_struct{Chain{Tuple{Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Float64}
└ @ Zygote C:\Users\User\.julia\packages\Zygote\SuKWp\src\lib\forward.jl:142

I'm not able to figure out why this warning appears. I'm also not sure whether this is the best way to implement this training.

This implementation is based on the Sobolev training method. Is there a better way to implement it in Flux?
