How could I create a loss function with three inputs, like `loss(w, x, y)`?
Could someone show me an example?
Have you run into an issue using a three-input loss function? `loss(w, x, y)` should work just fine. For example:

```julia
loss(m, x, y) = Flux.mse(m(x), y)
```
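Here is a minimal runnable sketch of the `loss(w, x, y)` shape from the question. It is written in plain Julia so it runs without Flux. It assumes a hypothetical linear model whose weights `w` are passed explicitly. `predict` and the hand-written `mse` are illustrative names, and `mse` computes the same mean squared error as `Flux.mse`.

```julia
# Mean squared error written out by hand (same quantity Flux.mse computes)
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Hypothetical linear model: each row of x is one sample, w holds the weights
predict(w, x) = x * w

# Three-input loss: weights, inputs, targets
loss(w, x, y) = mse(predict(w, x), y)

w = [1.0, 2.0]
x = [1.0 0.0;
     0.0 1.0;
     1.0 1.0]
y = [1.0, 2.0, 4.0]

loss(w, x, y)   # 0.3333333333333333
```

Nothing about the number of arguments is special here. A Julia function can take as many positional arguments as you like, and the loss is just an ordinary function.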
Since you mentioned Flux: the docstring for `Flux.train!` says

If `d` is a tuple of arguments to `loss`, call `loss(d...)`, else call `loss(d)`.

So I imagine you can have as many inputs as you want, as long as each element of `data` is a tuple of arguments, e.g. `data = [(w, x, y), (w1, x1, y1)]` and so on.
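The quoted rule can be made concrete with a small plain-Julia sketch of just that dispatch step. `call_loss` and `loss3` are hypothetical names written for illustration. This is not Flux's actual source.

```julia
# Hypothetical helper mirroring the rule quoted from the Flux docstring
# (an illustration only, not Flux's implementation):
# a tuple is splatted into the loss, anything else is passed as one argument.
call_loss(loss, d::Tuple) = loss(d...)
call_loss(loss, d) = loss(d)

# Toy three-input loss on scalars: weight w, input x, target y
loss3(w, x, y) = (w * x - y)^2

# Each datapoint is a 3-tuple, so loss3 receives three arguments
data = [(2.0, 1.0, 2.0), (2.0, 3.0, 5.0)]
losses = [call_loss(loss3, d) for d in data]   # [0.0, 1.0]

# A non-tuple datapoint is passed through as a single argument
call_loss(x -> x^2, 3.0)                       # 9.0
```

Because splatting works for tuples of any length, a four- or five-input loss would be handled the same way, provided every tuple in `data` matches the loss's arity.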
The full docstring:
```julia
train!(loss, params, data, opt; cb)
```

For each datapoint `d` in `data`, compute the gradient of `loss` with respect to `params` through backpropagation and call the optimizer `opt`.

**If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.**

A callback is given with the keyword argument `cb`. For example, this will print "training" every 10 seconds (using `Flux.throttle`):

```julia
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
```

The callback can call `Flux.stop` to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
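For intuition about what this loop does with a multi-input loss, here is a dependency-free sketch. It is not Flux. `train_sketch!` and `grad_w` are hypothetical names. The gradient is derived by hand for a linear model instead of by backpropagation, and a fixed step size `η` stands in for the optimizer `opt`. Following the `loss(m, x, y)` style from earlier in the thread, the trained weights `w` are passed explicitly, and each datapoint is an `(x, y)` tuple.

```julia
# Mean squared error, written out by hand
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Three-input loss for a hypothetical linear model x * w
loss(w, x, y) = mse(x * w, y)

# Hand-derived gradient of loss with respect to w: (2/n) * x' * (x*w - y)
grad_w(w, x, y) = 2 .* (x' * (x * w .- y)) ./ length(y)

# Hypothetical stand-in for train!: one gradient step per datapoint,
# with a callback invoked after every step
function train_sketch!(w, data; η = 0.1, cb = () -> nothing)
    for (x, y) in data          # each datapoint is an (x, y) tuple
        w .-= η .* grad_w(w, x, y)
        cb()
    end
    return w
end

x = [1.0 0.0;
     0.0 1.0;
     1.0 1.0]
y = x * [1.0, 2.0]              # targets generated from weights [1.0, 2.0]
w = zeros(2)
data = fill((x, y), 200)        # reuse the same batch for 200 steps
steps = Ref(0)
train_sketch!(w, data; cb = () -> steps[] += 1)
```

After the 200 steps, `w` should be very close to `[1.0, 2.0]`. Note that the docstring above describes the implicit-parameters form `train!(loss, params, data, opt; cb)`. As far as I know, more recent Flux releases switched to an explicit form, `train!(loss, model, data, opt_state)`, which calls `loss(model, d...)` for tuple datapoints. That matches the `loss(m, x, y)` example directly, but check the docs for the version you have installed.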