Hi, I am still learning the ropes of the Flux framework. I have experience with TensorFlow and a bit of PyTorch, but I am interested in Flux, mostly because of its flexibility and uniformity (no more to_numpy and back!).
I wrote a very simple example that fits a 2D function, but I am stuck querying the model one point at a time, as opposed to the model.predict(x) usability of TF.
Below is the snippet I wrote. Can anyone please comment on more efficient ways of organizing the code, or on better ways of writing the functions? (The loss function especially is a bit of a hack right now, I feel.) Thank you in advance.
using Flux
function f(x, y)
    z = x.^2 + y.^3
    return z
end
x = collect(-2.:0.01:2.)
y = collect(-2.:0.01:2.)
z = f(x,y)
xy = hcat(x,y)
nn = Chain(Dense(2, 10, tanh),
           Dense(10, 10, tanh),
           Dense(10, 1))
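# Two loss functions, one just for the callback, as I could not make a
# single function general enough for both.
# loss_cb takes the 401×2 matrix xy and indexes it row by row.
function loss_cb(x, y)
    l = 0.0
    for i = 1:size(x, 1)
        l += (nn(x[i, :])[1] - y[i])^2
    end
    return l
end
# loss takes a vector of 2-element vectors, as delivered by the DataLoader.
function loss(x, y)
    l = 0.0
    for i = 1:length(x)
        l += (nn(x[i])[1] - y[i])^2
    end
    return l
end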
cb() = @show(loss_cb(xy,z))
# Store each sample as its own 2-element vector; otherwise the DataLoader
# errors, complaining that the two arrays have different lengths (802 vs 401).
dataset = Vector(undef, 401)
for i = 1:401
    dataset[i] = xy[i, :]
end
in_data = Flux.Data.DataLoader((dataset,z),batchsize=401)
opt = ADAM(0.1)
Flux.@epochs 200 Flux.train!(loss,Flux.params(nn),in_data,opt,cb=cb)
# How do I get predictions for all of the query points at once?
# This only evaluates the first point.
prediction = nn(dataset[1])
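
To make the last question concrete, here is a minimal sketch of the batched form I am hoping for. It assumes that Flux layers treat each column of a matrix as one sample, so the inputs would have to be 2×401 rather than 401×2. The names `X` and `predictions` are mine, and whether this is the recommended pattern is exactly what I would like confirmed.

```julia
using Flux

# Same target function and sample points as in the snippet above.
f(x, y) = x .^ 2 .+ y .^ 3
x = collect(-2.0:0.01:2.0)
y = collect(-2.0:0.01:2.0)
z = f(x, y)

# Assumption: one column per sample, i.e. a 2×401 matrix instead of 401×2.
X = permutedims(hcat(x, y))

nn = Chain(Dense(2, 10, tanh),
           Dense(10, 10, tanh),
           Dense(10, 1))

# One forward pass over all points; the raw output is a 1×401 matrix,
# flattened here to a vector of 401 predictions.
predictions = vec(nn(X))

# A single sum-of-squares loss that could serve both training and the callback.
loss(x, y) = sum(abs2, vec(nn(x)) .- y)

@show size(nn(X))
@show loss(X, z)
```

If this is right, the separate `loss_cb` and the per-sample `dataset` vector would no longer be needed, since the same `loss(X, z)` call works on the whole matrix.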