Hi there,
I'm still somewhat of a rookie in Julia and Flux, and I'm having trouble understanding what is going on when I switch between the crossentropy and binarycrossentropy loss functions.
I coded the following simple denoising autoencoder:
using Flux, Random

# 2000 features × 100 samples
data = rand(2000, 100)
data_corrupted = copy(data)

# Corrupt the data: zero out a random subset of entries in each sample
rng = MersenneTwister(1234)  # seed once, outside the loop, so each sample gets a different mask
for sample_index in 1:size(data, 2)
    indices = findall(bitrand(rng, 2000))
    data_corrupted[indices, sample_index] .= 0
end

# Partition into batches of 10 samples
data = [data[:, i:min(i+10-1, size(data, 2))] for i in 1:10:size(data, 2)]
data_corrupted = [data_corrupted[:, i:min(i+10-1, size(data_corrupted, 2))] for i in 1:10:size(data_corrupted, 2)]

# Define the model
encoder = Dense(2000, 50, σ)
decoder = Dense(50, 2000, σ)
m = Chain(encoder, decoder)

# Define the loss function
loss(x, y) = Flux.crossentropy(m(x), y)

# Define the optimiser
opt = ADAM()

# Train
Flux.train!(loss, params(m), zip(data_corrupted, data), opt)
This runs fine.
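For reference, crossentropy does seem happy with whole matrices; here is a minimal check I ran (the shapes are made up, just to confirm the behaviour):

```julia
using Flux

# Hypothetical small shapes: 5 "classes" × 3 samples
ŷ = softmax(rand(Float32, 5, 3))
y = softmax(rand(Float32, 5, 3))

# crossentropy accepts matrices and returns a single scalar loss
l = Flux.crossentropy(ŷ, y)
```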
But if I then change the loss function to:
loss(x, y) = Flux.binarycrossentropy(m(x), y)
I get the following error:
ERROR: LoadError: MethodError: no method matching eps(::Array{Float32,2})
Closest candidates are:
eps(!Matched::Dates.Time) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/Dates/src/types.jl:387
eps(!Matched::Dates.Date) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/Dates/src/types.jl:386
eps(!Matched::Dates.DateTime) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/Dates/src/types.jl:385
...
However, if I change the loss to what was suggested here:
loss(x, y) = Flux.binarycrossentropy(m(x)[1], y[1])
The model trains without any problem.
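As far as I can tell, indexing with [1] just grabs the first element of the matrix via linear indexing, not the first batch (quick check with a stand-in array, since m(x) is a 2000×10 matrix here):

```julia
A = rand(Float32, 2000, 10)  # stand-in for m(x)

# Linear indexing: A[1] is the single scalar A[1, 1]
first_elem = A[1]
@assert first_elem == A[1, 1]
@assert first_elem isa Float32
```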
I have a hard time understanding why I need this indexing for binarycrossentropy but not for crossentropy. I understand that the eps function expects a scalar rather than an array, but I am confused: won't the loss now be calculated only on the first element of each batch instead of on all the data?
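My current guess (untested beyond the snippet below, so I may well be wrong) is that binarycrossentropy is defined elementwise on scalars, so perhaps the intended usage is to broadcast it over the arrays and then reduce, something like:

```julia
using Flux

m = Chain(Dense(2000, 50, σ), Dense(50, 2000, σ))
x = rand(Float32, 2000, 10)  # one corrupted batch
y = rand(Float32, 2000, 10)  # the corresponding clean batch

# Broadcast the scalar loss over every element, then average
loss(x, y) = sum(Flux.binarycrossentropy.(m(x), y)) / length(y)
l = loss(x, y)
```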
Any insights are very welcome!
Many thanks,
Sander