Flux networks as arguments in functions


I recently opened an issue on Flux because I think something about using Flux models with Julia functions needs clarification.

To summarize: my project has various equations embedded in functions. Here is an example notebook with a toy problem that resembles my research problem.

I noticed that most functions in Flux code refer to global variables/Flux structures. In my case, however (particularly when using different scripts or looping over different networks, where globals would be ambiguous), I pass the Flux models directly as arguments to my functions. This works well, but a potential downside is that it forces me to also include the network in the dataset list during training. See the notebook for a (quickly written and probably messy) example.
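A minimal sketch of the model-as-argument style, under stated assumptions: to keep it runnable without Flux, a small callable struct (the hypothetical name Affine) stands in for a network. A Flux Chain is also callable and could be passed in the same way. The names predict, loss and models are illustrative and not taken from the notebook.

```julia
# Stand-in for a Flux network so the example runs without Flux;
# a Flux `Chain` is also callable and could be passed the same way.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end
(m::Affine)(x) = m.W * x .+ m.b   # forward pass

# Every function takes the model explicitly instead of reading a global.
predict(x, model) = model(x)
loss(x, y, model) = sum(abs2, predict(x, model) .- y) / length(y)

# Looping over several networks is unambiguous: each call names its model.
models = [Affine([1.0 0.0; 0.0 1.0], [0.0, 0.0]),
          Affine([2.0 0.0; 0.0 2.0], [1.0, 1.0])]
x = [1.0, 2.0]
y = [1.0, 2.0]
losses = [loss(x, y, m) for m in models]   # one loss value per network
```

Because every call names its model, the result never depends on which network happens to be bound to a global at that moment.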

My question relates to the latter point: is this a legitimate way of doing things, or is there a better way to provide models to functions?

Also, could this have any influence on training speed?

Thanks in advance for any insights!


Just to clarify, is your concern specifically about this line?

evalcb = () -> (push!(record_loss_train, loss(X1_training, X2_training, y_training,core).data))

which calls the functions in cell 11?

I don’t understand this. Could you make a smaller example than your notebook?

My concern is that, because my functions take the network as an argument, I need to pass the network as part of the dataset when Flux.train! is called:

Flux.train!(loss, params(core), [(X1_training, X2_training, y_training, core)], ADAM(0.001))

The Flux docs call [(X1_training, X2_training, y_training, core)] a dataset, but here it also contains the core network, because in my case the network has to be passed to loss() and all the other functions.

Is that a problem, or should we not care? It works and training is going fine for now, so maybe it is not important, but I’m concerned that it could become a problem later.

(@kristoffer.carlsson I think this should clarify things?)

You can just pass

(x1, x2, y) -> loss(x1, x2, y, core)

as the loss function to train!.
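To see why the closure is equivalent, note that train! (in the implicit-parameter Flux versions used in this thread) calls the loss once per dataset element with that element splatted in, roughly loss(d...). The sketch below reproduces only that call pattern with a hypothetical helper feed (no gradients, no optimiser) and uses a one-weight callable struct, the hypothetical Scale, as a stand-in for core. With Flux, Option B would correspond to Flux.train!((x1, x2, y) -> loss(x1, x2, y, core), params(core), [(X1_training, X2_training, y_training)], ADAM(0.001)).

```julia
# Stand-in for a network: one weight stored in a mutable array,
# mimicking how Flux updates parameter arrays in place during training.
struct Scale
    w::Vector{Float64}
end
(m::Scale)(z) = m.w[1] * z

# Loss with the model as an explicit last argument, as in the question.
loss(x1, x2, y, model) = (model(x1 + x2) - y)^2

# Hypothetical helper reproducing only train!'s data handling:
# every tuple in `data` is splatted into the loss (no gradients, no updates).
feed(lossfn, data) = [lossfn(d...) for d in data]

core = Scale([2.0])

# Option A: the network travels inside every dataset tuple.
a = feed(loss, [(1.0, 2.0, 5.0, core), (0.5, 1.0, 1.0, core)])

# Option B: a closure captures `core`; the dataset holds only data.
closure = (x1, x2, y) -> loss(x1, x2, y, core)
b = feed(closure, [(1.0, 2.0, 5.0), (0.5, 1.0, 1.0)])

# The closure refers to the same `core` object rather than a copy,
# so in-place parameter updates (what the optimiser does) are visible to it.
core.w[1] = 1.0
after_update = closure(1.0, 2.0, 5.0)
```

Both options yield the same loss values; the closure simply keeps the model out of the dataset.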

OK, thanks. I guess it ends up doing the same thing, so this is not really a problem.

I could also write my own train() function that works a bit differently and takes the network as an input. I will try it and see whether it makes any difference.

I’ve just started doing this (writing my own training function). You have to dig into the internals a bit, and it is most challenging if you’re not familiar with AD (automatic differentiation). If you’re just starting out, though, I’d stick to the built-in train! method while you build intuition.
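A minimal sketch of such a custom training function that takes the network as an input. It assumes plain gradient descent and, to stay dependency-free, uses a hand-derived gradient for a one-weight linear model in place of automatic differentiation. In a real Flux setup the gradient would come from Flux's AD (e.g. Flux.gradient in Zygote-based versions). LinearModel, sqloss, grad_w and my_train! are hypothetical names.

```julia
# Toy one-weight model; `w` is a mutable array so it can be updated in place.
struct LinearModel
    w::Vector{Float64}
end
(m::LinearModel)(x) = m.w[1] * x

# Squared-error loss with the model passed explicitly.
sqloss(x, y, model) = (model(x) - y)^2

# Hand-derived gradient w.r.t. w: d/dw (w*x - y)^2 = 2*(w*x - y)*x.
# (Assumption: stands in for what AD would compute in a real Flux setup.)
grad_w(x, y, model) = 2 * (model(x) - y) * x

# Hypothetical custom training function: the model is an explicit argument,
# the dataset holds only (x, y) pairs, and plain gradient descent is used.
function my_train!(model, data; lr = 0.1, epochs = 100)
    for _ in 1:epochs, (x, y) in data
        model.w[1] -= lr * grad_w(x, y, model)
    end
    return model
end

net = LinearModel([0.0])
data = [(1.0, 3.0), (2.0, 6.0)]   # points on y = 3x
my_train!(net, data)
final_loss = sum(sqloss(x, y, net) for (x, y) in data)
```

Replacing grad_w with a call into an AD library is the part that requires the familiarity with AD mentioned above.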