Data-formatting in and out of Flux ML model

Hi, I am still learning the ropes of the Flux framework. I have experience with TensorFlow and a bit of PyTorch, but I am interested in Flux mostly because of its flexibility and uniformity (no more to_numpy and back!).

I wrote a very simple example that fits a 2D function, but I am stuck querying one point at a time, as opposed to the `model.predict(x)` convenience of TF.
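Flux layers such as `Dense` treat each *column* of a matrix as one sample (features × observations), so there is no need to loop over points: stack the inputs as a 2×N matrix and call the model once. The following is a minimal sketch, assuming a recent Flux release that accepts the `Dense(in => out, σ)` pair syntax (the question uses the older positional `Dense(2, 10, tanh)` form). The names `X`, `Yhat` and `yhat` are only for illustration.

```julia
using Flux

x = range(-2f0, 2f0; length=401)     # Float32, the element type of Flux's default parameters
y = range(-2f0, 2f0; length=401)

# Features × observations: row 1 is x, row 2 is y, column j is the j-th query point
X = permutedims(hcat(x, y))          # 2×401 Matrix{Float32}

nn = Chain(Dense(2 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 1))

Yhat = nn(X)                         # one call evaluates every column: a 1×401 matrix
yhat = vec(Yhat)                     # plain 401-element vector, if that is more convenient
```

The output keeps the same layout as the input, one column per point, which is the closest analogue of `model.predict(x)`.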

Below is the snippet I wrote. Can anyone please comment on more efficient ways of organizing the code, or better ways of writing the functions (especially the loss function, which feels like a bit of a hack right now)? Thank you in advance.
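With the columns-as-samples layout, the loss needs neither a loop nor a separate callback version, because one broadcasted expression covers a whole batch. The following is a minimal sketch, assuming a recent Flux release. `loss_sum` and `loss_mse` are hypothetical names: `loss_sum` computes the same sum of squared errors as the loops below, and `Flux.Losses.mse` is the mean-squared-error variant. Taking the model as the first argument matches the explicit `Flux.train!(loss, model, data, opt_state)` style.

```julia
using Flux

f(x, y) = x^2 + y^3                          # scalar version; apply to arrays with f.(x, y)

x = range(-2f0, 2f0; length=401)
y = range(-2f0, 2f0; length=401)
X = permutedims(hcat(x, y))                  # 2×401, one column per sample
Z = reshape(f.(x, y), 1, :)                  # 1×401, same layout as the model output

nn = Chain(Dense(2 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 1))

# Sum of squared errors over the whole batch, no explicit loop
loss_sum(m, X, Z) = sum(abs2, m(X) .- Z)

# Mean squared error from Flux's loss module
loss_mse(m, X, Z) = Flux.Losses.mse(m(X), Z)

# The two differ only by the number of samples
loss_sum(nn, X, Z) ≈ 401 * loss_mse(nn, X, Z)
```

The same function works for the callback and for training, whether it receives the full data set or a single batch.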

```julia
using Flux

function f(x, y)
    z = x.^2 + y.^3
    return z
end

x = collect(-2.:0.01:2.)
y = collect(-2.:0.01:2.)
z = f(x, y)
xy = hcat(x, y)

nn = Chain(Dense(2, 10, tanh),
           Dense(10, 10, tanh),
           Dense(10, 1))

# Two loss functions, one for the callback, as I could not make
# a single function general enough.

function loss_cb(x, y)
    l = 0.0
    for i = 1:size(x, 1)
        l += (nn(x[i, :])[1] - y[i])^2
    end
    return l
end

function loss(x, y)
    l = 0.0
    for i = 1:length(x)
        l += (nn(x[i])[1] - y[i])^2
    end
    return l
end

cb() = @show(loss_cb(xy, z))

# Arrange the data as individual vectors; otherwise DataLoader gives an error
# saying two subsequent queries return arrays of different length (802 vs 401).
dataset = Vector(undef, 401)
for i = 1:401
    dataset[i] = xy[i, :]
end

in_data = Flux.Data.DataLoader((dataset, z), batchsize=401)
opt = ADAM(0.1)

Flux.@epochs 200 Flux.train!(loss, Flux.params(nn), in_data, opt, cb=cb)

# How can I get predictions for all of the query points at once?
prediction = nn(dataset[1])
```
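The vector-of-vectors `dataset` can be turned back into a matrix, and that matrix is also what `DataLoader` expects. `DataLoader` counts observations along the *last* dimension. For `xy` (401×2) that dimension has length 2, while `z` has 401 elements. Its transpose lines up with `z`. The following is a minimal sketch, assuming a recent Flux release where `Flux.DataLoader` is available (the question uses the older `Flux.Data.DataLoader` path). The batch size of 100 is arbitrary.

```julia
using Flux

x = collect(-2.0:0.01:2.0)
y = collect(-2.0:0.01:2.0)
z = x .^ 2 .+ y .^ 3
xy = hcat(x, y)                                  # 401×2: the last dimension has length 2

dataset = [xy[i, :] for i in 1:size(xy, 1)]      # the question's vector of 2-element vectors
X = reduce(hcat, dataset)                        # back to a 2×401 matrix, one column per point

# The last dimension is now 401 for both X and z, so they batch together
loader = Flux.DataLoader((X, z); batchsize=100)
xb, zb = first(loader)                           # xb is 2×100, zb has 100 elements
```

In practice it is simpler to build `permutedims(xy)` directly and skip the intermediate `dataset`.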

Welcome! I think you’ll find that the APIs in Flux are a lot closer to what you’re familiar with in TensorFlow (or, more specifically, Keras). I would suggest working through the tutorials on Flux – Tutorials and the intro docs at Overview · Flux to get a better idea of how the library works. I know for a fact those will answer all of the questions you left as comments in the code snippet above :slight_smile:
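For reference, here is a minimal sketch of the whole workflow, assuming a recent Flux release (0.14 or later). It uses the explicit `Flux.setup` / `Flux.train!(loss, model, data, opt_state)` style that recent releases document in place of the implicit `Flux.params` / `ADAM` / `Flux.@epochs` style in the question. It also samples a grid instead of two identical ranges, because with `x` and `y` drawn from the same range every training point lies on the line x = y. The grid size, batch size, learning rate and epoch count are arbitrary choices, not tuned values.

```julia
using Flux

f(x, y) = x^2 + y^3

# Sample a grid so the inputs cover the square [-2, 2]², not just the line x == y
g = range(-2f0, 2f0; length=41)
pts = [(a, b) for a in g for b in g]              # 1681 (x, y) pairs
X = Float32[p[i] for i in 1:2, p in pts]          # 2×1681, one column per point
Z = reshape([f(p...) for p in pts], 1, :)         # 1×1681 targets

model = Chain(Dense(2 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 1))

# The last dimension of each array indexes observations, so X and Z batch together
loader = Flux.DataLoader((X, Z); batchsize=64, shuffle=true)

loss(m, x, z) = Flux.Losses.mse(m(x), z)          # one loss for training and monitoring
opt_state = Flux.setup(Adam(0.01), model)

loss_before = loss(model, X, Z)                   # full-data loss before training
for epoch in 1:200
    Flux.train!(loss, model, loader, opt_state)   # one pass over all batches
    epoch % 50 == 0 && @show epoch loss(model, X, Z)
end

Zhat = model(X)     # predictions for every grid point in one call, a 1×1681 matrix
```

Because `loss` takes the model explicitly, the same function serves as both the training objective and the progress report, so a separate callback loss is no longer needed.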
