I am trying to write code that interpolates a function using a neural network in Flux.jl, but the loss function does not decrease at all. Am I doing something wrong here? I apologize if the problem is trivial, as I am new to Flux.jl. The same code works with slight modifications for a 1d function.
using Flux: train!
using Flux
using Statistics
x_train = rand(Float64, (2, 100))
y_train = x_train[1,:] + x_train[2,:]
predict = f64(Chain(Dense(2, 64, leakyrelu),
                    Dense(64, 64, leakyrelu),
                    Dense(64, 1, leakyrelu)))
loss(model, x, y) = mean(abs2.(model(x) .- y))
data = [(x_train, y_train)]
opt = Flux.setup(Adam(0.01), predict)
for epoch in 1:5000
train!(loss, predict, data, opt)
println(loss(predict,x_train,y_train))
end
y_predict=predict(x_train)
for i in 1:100
println(x_train[1,i]," ",x_train[2,i]," ",y_predict[i]," ",y_train[i])
end
Replace y_train = x_train[1,:] + x_train[2,:] with y_train = reshape(x_train[1,:] + x_train[2,:], 1, :) so that the layout of y_train matches that of x_train: the last dimension is the batch dimension.
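A quick shape check illustrates the point (a minimal sketch using the same x_train as above):

julia> x_train = rand(Float64, (2, 100));

julia> y_train = reshape(x_train[1,:] + x_train[2,:], 1, :);

julia> size(x_train)   # features × batch
(2, 100)

julia> size(y_train)   # outputs × batch, matching the Dense(64, 1) output
(1, 100)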
Also, use mean(abs2, model(x) .- y) to avoid an unnecessary intermediate allocation.
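With both changes, the loss definition becomes (same names as in your code):

loss(model, x, y) = mean(abs2, model(x) .- y)

The two-argument mean(abs2, xs) applies abs2 elementwise while reducing, so the temporary array abs2.(xs) is never materialized.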
In your original code, predict(x_train) .- y_train produces
julia> predict(x_train).-y_train
100×100 Matrix{Float64}:
-0.000221335 0.165909 0.562235 1.10108 0.422628 0.764869 … -0.0057506 0.900018 0.491736 0.215371 0.467521
-0.16728 -0.00114932 0.395176 0.934021 0.25557 0.597811 -0.172809 0.732959 0.324677 0.0483125 0.300463
-0.561125 -0.394994 0.00133184 0.540176 -0.138275 0.203966 -0.566654 0.339115 -0.0691673 -0.345532 -0.093382
-1.10098 -0.934844 -0.538519 0.00032562 -0.678125 -0.335884 -1.1065 -0.200736 -0.609018 -0.885383 -0.633232
[...]
due to broadcasting: the 1×100 model output minus the length-100 target vector yields a 100×100 matrix. Clearly not what you intended!
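Here is the broadcasting rule at work in isolation (a standalone sketch, not your exact data):

julia> row = rand(1, 3);   # like model(x_train): 1 × batch

julia> col = rand(3);      # like the un-reshaped y_train: a plain Vector

julia> size(row .- col)    # the Vector is treated as 3×1, so the result is 3×3
(3, 3)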
Thanks a lot! That made the code work. I was working under the wrong assumption that I would get an error if there were a dimension mismatch.
You do if you use Flux.Losses.mse in your loss function. That said, I agree that it would be prudent for a high-level function such as train! to give a warning if the input shapes are likely not as intended.
https://fluxml.ai/Flux.jl/stable/models/losses/
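For instance, swapping mse into the loss from above (a minimal sketch, same model and data) raises an error instead of silently broadcasting:

loss(model, x, y) = Flux.Losses.mse(model(x), y)
# With the un-reshaped y_train this throws a DimensionMismatch,
# because mse checks that prediction and target sizes agree.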