I don’t speak English well, so I apologize for any mistakes.
I have an exercise on differentiable programming.
I have an image taken from MNIST and a fully connected NN that classifies that image correctly.
I then want to modify that image with differentiable programming until it is classified as 0.
The program works with ForwardDiff.gradient:
using Flux
using BSON
using ForwardDiff
# Load the trained Flux model saved with BSON (stored under the :modello key).
function load_model(path_to_model)
    m = BSON.load(path_to_model)
    m[:modello]
end
function load_data_mnist()
    images = hcat([Float32.(i)[:] for i in Flux.Data.MNIST.images()]...) # 784*60_000
    labels = Flux.onehotbatch(Flux.Data.MNIST.labels(), 0:9) .|> Float32 # 10 *60_000
    images, labels
end
# Percentage of columns where the predicted digit matches the label.
function accuracy(a, b)
    prev(x) = argmax(x) - 1 # index of the largest score -> digit 0-9
    L = size(a)[2]
    right = 0
    wrong = 0
    for i in 1:L
        @inbounds if prev(a[:, i]) == prev(b[:, i])
            right += 1
        else
            wrong += 1
        end
    end
    100right / (wrong + right)
end
model = load_model("model.bson")
images, labels = load_data_mnist() # 784*60_000, 10 *60_000, don't need train/dev/test
model_accuracy = accuracy(model(images), labels) # 96.7%
index_image = 10
z = images[:, index_image:index_image] # image that will be classified as zero (4->0)
# check if the model classifies it correctly
labels[:, index_image] |> println # 4
model(z) |> println # the model classifies the image as 4
zero_label = reshape([1, 0, 0, 0, 0, 0, 0, 0, 0, 0], (10, 1))
opt = ADAM()
loss(x) = Flux.mse(model(x), zero_label)
loss(z) |> println
for i in 1:100
    ∂loss = ForwardDiff.gradient(loss, z) # this works
    Flux.update!(opt, z, ∂loss)           # updates the image z in place, not the model
end
model(z) |> println # the model classifies the modified image as 0
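For comparison with what happens below, a quick sanity check that this ForwardDiff loop only touched the image z and not the model would look like this (just a sketch reusing the helpers defined above; I would expect the accuracy to stay at the original 96.7%):
# Sanity check: only z should have changed, the model weights should be untouched.
accuracy(model(images), labels) |> println # expected to still be ≈ 96.7%
loss(z) |> println                         # expected to be small and finite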
If I use Flux.gradient instead, the program does not work: after the first call to Flux.gradient(loss, z)[1] the model is broken.
for i in 1:100
    ∂loss = Flux.gradient(loss, z)[1] # this does not work and changes the model
    Flux.update!(opt, z, ∂loss)
end
# If I test the model accuracy now I get
accuracy(model(images), labels) |> println # 9.8%
loss(z) |> println # NaN32
model(images) |> println # a matrix full of NaN, the model is broken
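To see whether the weights themselves were overwritten with NaN (and not just the outputs), a small diagnostic sketch using Flux.params would be:
# Diagnostic sketch: check each parameter array of the model for NaN entries.
for p in Flux.params(model)
    any(isnan, p) && println("NaN found in a parameter array of size ", size(p))
end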
Why does Flux.gradient change the model?
I did not expect that.