I am trying to write loss functions for training a simple neural network with Flux.
As I understand the documentation, loss functions should have the signature `loss(ŷ, y)`, so I expected something like
```julia
loss(ŷ, y) = mean(abs.(ŷ .- y))
```
to work, but it does not: no error is thrown, yet the fit never improves. The only version I can get to work is one where the loss function refers to global variables rather than to the arguments that `train!` passes to it. This cannot be how it is meant to work, so how can I modify my example to get it running properly?
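As a quick sanity check, the `loss(ŷ, y)` definition above can be evaluated on its own in plain Julia. The sketch below only uses `Statistics`; it shows that the function behaves as a mean absolute error on whole arrays and on single numbers alike, and that its body never mentions a model, so its value depends only on the two arguments it is given.

```julia
using Statistics

# Same definition as above: mean absolute error between predictions ŷ and targets y
loss(ŷ, y) = mean(abs.(ŷ .- y))

# Arrays are compared elementwise, then averaged: (0.5 + 0.5) / 2
a = loss([1.0, 2.0], [1.5, 2.5])

# Scalars work too: abs.(3.0 .- 1.0) is the number 2.0, and the mean of a number is itself
b = loss(3.0, 1.0)

# Nothing in `loss` refers to a model, so no model parameter can influence a or b
println(a, " ", b)
```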
My MWE is:
```julia
using Statistics
using Flux

# Making dummy data
obs = 1000
x = rand(Float64, 5, obs)
y = mean(x, dims=1) + sum(x, dims=1)
y[findall(x[4,:] .< 0.3)] .= 17 # Making it slightly harder.

# Making model
m = Chain(
    Dense(5, 5, σ),
    Dense(5, 1))

dataset = zip(x,y)
opt = Descent()

# Attempt 1: Fit does not improve
mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y)) # Copied from here https://github.com/FluxML/Flux.jl/blob/0fa97759367227ced0bde28f39ba5d2abc08e8c7/src/losses/functions.jl#L1-L7
Flux.train!(mae, params(m), dataset, opt)
Flux.mae(m(x),y)
Flux.train!(mae, params(m), dataset, opt)
Flux.mae(m(x),y)

# Attempt 2: Fit does not improve
loss2 = Flux.mae
Flux.train!(loss2, params(m), dataset, opt)
Flux.mae(m(x),y)
Flux.train!(loss2, params(m), dataset, opt)
Flux.mae(m(x),y)

# Attempt 3: This throws an error.
loss3(A, B) = Flux.mae(m(A),B)
Flux.train!(loss3, params(m), dataset, opt)
Flux.mae(m(x),y)
Flux.train!(loss3, params(m), dataset, opt)
Flux.mae(m(x),y)

# Attempt 4: This works but it is terrible (the loss4 function uses global variables rather than anything passed in)
loss4(A, B) = Flux.mae(m(x),y)
Flux.train!(loss4, params(m), dataset, opt)
Flux.mae(m(x),y)
Flux.train!(loss4, params(m), dataset, opt)
Flux.mae(m(x),y)
```
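To see what `dataset = zip(x,y)` actually hands to the loss function, the same dummy data can be inspected in plain Julia without Flux. The sketch below (the name `batches` is only for illustration) shows that `zip` walks through `x` one number at a time in column-major order, so each element is a pair of single `Float64`s, and feature values end up paired with targets from other observations.

```julia
using Statistics

# Same dummy data as in the MWE
obs = 1000
x = rand(Float64, 5, obs)
y = mean(x, dims=1) + sum(x, dims=1)

batches = collect(zip(x, y))

# zip stops when the shorter iterator runs out: y has 1000 entries, x has 5000
println(length(batches))

# Each element is a pair of single numbers, not a (column, target) pair
println(eltype(batches))

# x is iterated element by element in column-major order, so the second pair
# combines x[2, 1] (a feature of observation 1) with y[1, 2] (the target of observation 2)
println(batches[2] == (x[2, 1], y[1, 2]))
```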
Using the function below, I can see that `Flux.train!` is passing `Float64`s (i.e. `ŷ` and `y`) to the loss function. So Attempt 1 should have worked.
```julia
function loss5(x,y)
    println("x is a ", typeof(x), " and y is a ", typeof(y))
    error("Error To stop training with a pointless loss function")
end
Flux.train!(loss5, params(m), dataset, opt)
```
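For comparison, here is a minimal sketch in which the loss receives the raw inputs and calls the model itself, and the whole dataset is passed to `train!` as a single `(x, y)` tuple. It assumes a Flux version that still provides the implicit-parameter `Flux.train!(loss, params, data, opt)` used above (removed in Flux 0.15) and that `train!` splats each element of `data` into the loss, as the `loss5` output suggests. The names `loss`, `before` and `after`, and the 200-epoch loop, are illustrative choices, not part of the original code.

```julia
using Statistics
using Flux

# Dummy data as in the MWE: 5 features × 1000 observations, targets are 1 × 1000
obs = 1000
x = rand(Float64, 5, obs)
y = mean(x, dims=1) + sum(x, dims=1)
y[findall(x[4, :] .< 0.3)] .= 17

m = Chain(
    Dense(5, 5, σ),
    Dense(5, 1))

# The loss receives raw inputs A and targets B and runs the model itself,
# so its value (and gradient) depends on the parameters in Flux.params(m)
loss(A, B) = Flux.mae(m(A), B)

# One element = one (features, targets) tuple; train! calls loss(x, y) for it
dataset = [(x, y)]
opt = Descent()

before = loss(x, y)
for epoch in 1:200
    # Assumes the implicit-parameter train! API used in the question (Flux ≤ 0.14)
    Flux.train!(loss, Flux.params(m), dataset, opt)
end
after = loss(x, y)
println("MAE before: ", before, ", after: ", after)
```

Passing `[(x, y)]` trains on the full batch each epoch; mini-batches of columns would be the usual alternative for larger data.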