I don’t speak English well, so I apologize for any mistakes.
I have an exercise on differentiable programming.
I have an image taken from MNIST and a fully connected NN that classifies that image correctly.
I then want to modify that image with differentiable programming until it is classified as 0.
The program works with `ForwardDiff.gradient`:
```julia
using Flux
using BSON
using ForwardDiff

function load_model(path_to_model)
    m = BSON.load(path_to_model)
    m[:modello]
end

function load_data_mnist()
    images = hcat([Float32.(i)[:] for i in Flux.Data.MNIST.images()]...)   # 784×60_000
    labels = Flux.onehotbatch(Flux.Data.MNIST.labels(), 0:9) .|> Float32   # 10×60_000
    images, labels
end

function accuracy(a, b)
    prev(x) = argmax(x) - 1
    L = size(a, 2)   # number of samples (columns)
    right = 0
    wrong = 0
    for i in 1:L
        @inbounds if prev(a[:, i]) == prev(b[:, i])
            right += 1
        else
            wrong += 1
        end
    end
    100right / (wrong + right)
end

model = load_model("model.bson")
images, labels = load_data_mnist()   # 784×60_000, 10×60_000; no train/dev/test split needed

model_accuracy = accuracy(model(images), labels)   # 96.7%

index_image = 10
z = images[:, index_image:index_image]   # image that will be pushed toward class 0 (4 -> 0)

# check that the model classifies it correctly
labels[:, index_image] |> println   # 4
model(z) |> println                 # the model classifies the image as 4

zero_label = reshape([1, 0, 0, 0, 0, 0, 0, 0, 0, 0], (10, 1))

opt = ADAM()
loss(x) = Flux.mse(model(x), zero_label)
loss(z) |> println

for i in 1:100
    ∂loss = ForwardDiff.gradient(loss, z)   # this works
    Flux.update!(opt, z, ∂loss)
end

model(z) |> println   # the model now classifies the modified image as 0
```
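The gradient returned by `ForwardDiff.gradient` is a plain array with the same shape as `z`, which is what `Flux.update!` uses for the elementwise update. A quick sanity check of the shapes:

```julia
∂fd = ForwardDiff.gradient(loss, z)
println(size(∂fd) == size(z))   # true: the gradient is a 784×1 matrix, like z
```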
If I use `Flux.gradient` instead, the program does not work, and after the first call to `Flux.gradient(loss, z)` the model is broken:
```julia
for i in 1:100
    ∂loss = Flux.gradient(loss, z)   # this does not work and changes the model
    Flux.update!(opt, z, ∂loss)
end

# if I test the model accuracy afterwards I get:
accuracy(model(images), labels) |> println   # 9.8%
loss(z) |> println                           # NaN32
model(images) |> println                     # a matrix full of NaN; the model is broken
```
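One difference I noticed (I assume `Flux.gradient` is Zygote’s `gradient`, which returns a tuple with one entry per argument): the two functions do not return the same type, so the `∂loss` passed to `Flux.update!` is not the same kind of object as in the `ForwardDiff` version:

```julia
∂zy = Flux.gradient(loss, z)
typeof(∂zy) |> println      # a 1-tuple wrapping the gradient, e.g. Tuple{Matrix{Float32}}
∂zy[1] |> size |> println   # (784, 1): the gradient array itself is the first element
```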
Why does `Flux.gradient` change the model?
I did not expect this behavior.