I have a dataset with input images `x_train`, an objective value corresponding to each image `y_train`, and the set of derivatives of `y_train` with respect to `x_train`, `dc_train`. I would like to train a CNN on this data, where during training I compute the first derivative of the output with respect to the input and include it in the loss function alongside the usual loss.
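In symbols, the objective I am after is the usual Sobolev-style loss (with a weighting factor $\lambda$, which I set to 0.5 in the code below):

$$
\mathcal{L} = \mathrm{MSE}\big(\hat{y},\, y\big) + \lambda \, \mathrm{MSE}\!\left(\frac{\partial \hat{y}}{\partial x},\, \frac{\partial y}{\partial x}\right)
$$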

I implemented this in the following way, using `ForwardDiff` to compute the derivative of the output with respect to the input:

```julia
using Flux
using ForwardDiff

# Random stand-ins for the real data: 50 images of size 10×30×1
x_train = Float32.(rand(10, 30, 1, 50))
norm_y_train = Float32.(rand(1, 50))
dc_train = Float32.(rand(10, 30, 50))
mean_y_train = 300.0f0   # Float32, to avoid promoting intermediates to Float64
sd_y_train = 1000.0f0

model = Chain(
    Conv((3, 3), 1 => 6, pad = (1, 1), relu),
    MaxPool((3, 3)),
    Conv((3, 3), 6 => 16, pad = (1, 1), relu),
    MaxPool((3, 3)),
    Flux.flatten,
    Dense(48 => 1),
)

loader = Flux.DataLoader((x_train, norm_y_train, dc_train))  # default batchsize = 1
opt = Flux.setup(Flux.RAdam(0.01), model)

# Wrapper that de-normalizes the prediction and returns a scalar, since
# ForwardDiff.gradient needs a scalar-valued function (valid here because
# each batch contains a single image).
struct Model_struct{MT,T}
    model::MT
    sd::T
    mean::T
end

function (ma::Model_struct)(x::Array)
    return (ma.model(x) .* ma.sd .+ ma.mean)[]
end

function my_loss(model, x, y, dc)
    nn = Model_struct(model, sd_y_train, mean_y_train)
    dc_hat = ForwardDiff.gradient(nn, x)  # derivative of the output wrt the input image
    y_hat = model(x)
    return Flux.mse(y_hat, y) + 0.5f0 * Flux.mse(dc_hat, dc)
end

epochs = []
train_loss = []
for epoch in 1:100
    for (x, y, dc) in loader
        val, grads = Flux.withgradient(model) do m
            my_loss(m, x, y, dc)
        end
        Flux.update!(opt, model, grads[1])
    end
    push!(train_loss, Flux.mse(model(x_train), norm_y_train))
    push!(epochs, epoch)
end
```

I have used random initializations above for illustration. Also, `y_train` is normalized for training using the mean `mean_y_train` and standard deviation `sd_y_train`.
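The normalization I am assuming is plain standardization; a minimal sketch with made-up target values:

```julia
# Standardize the targets before training (illustrative values only)
y_train = Float32[250.0, 800.0, 1350.0]
mean_y_train = 300.0f0
sd_y_train = 1000.0f0
norm_y_train = (y_train .- mean_y_train) ./ sd_y_train

# Predictions are mapped back the same way the Model_struct wrapper does:
# y_hat = model(x) .* sd_y_train .+ mean_y_train
```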

This implementation gives me the warning below:

```
┌ Warning: `ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = Model_struct{Chain{Tuple{Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 2, typeof(relu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Float64}
└ @ Zygote C:\Users\User\.julia\packages\Zygote\SuKWp\src\lib\forward.jl:142
```

I'm not able to figure out why this warning appears, and I'm also not sure whether this is the best way to implement the training.

The implementation is based on the Sobolev training method. Is there a better way to implement this in Flux?