What is a model in Julia Flux 0.13 and higher?

Cross-posted from Stack Overflow (flux.jl - What is a model in Julia Flux 0.13 and higher? - Stack Overflow).

I want to use Julia Flux for machine learning with custom models (not neural networks, so I won't be using or combining the models provided by Flux). I want to do it with Flux because I want access to its various advanced gradient descent algorithms.

To do so, I intend to use the training API (Training API · Flux).

The problem is that the documentation is not very detailed. For example, there are functions like

Flux.setup(rule, model)

and

Flux.train!(loss, model, data, opt_state)

However, nowhere in the API documentation is there a description of what the model is or what form it should take.
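For orientation: in the newer training API, `Flux.setup` delegates to Optimisers.jl, which (as far as I understand it) does not require any special model type. A model can be any nested container, such as a Tuple, a NamedTuple, or a struct registered with Functors.jl, whose leaves include the numeric arrays to be trained. `setup` returns optimiser state with the same shape as the model, and `Flux.gradient` returns a gradient with the same shape. A minimal sketch, where the NamedTuple and its field names `x` and `y` are purely illustrative:

```julia
using Flux

# Hypothetical "model": a NamedTuple whose leaves are the arrays to train.
# No layers and no special type are needed; Functors.jl already knows how to
# walk Tuples and NamedTuples.
model = (x = rand(2, 2), y = rand(2, 2))

# setup returns optimiser state with the same (x = ..., y = ...) structure,
# and does not warn, because it finds two trainable arrays.
opt_state = Flux.setup(AdaGrad(), model)

# The gradient of a scalar function of the model mirrors that structure too.
grad = Flux.gradient(m -> sum(m.x * m.y'), model)[1]
```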

As a test problem, consider matrix factorization. That is, given a matrix A

using Random

dim = 2

A = rand(dim, dim)

find x and y such that

x * y' ≈ A

I would guess that the model should be defined as model(x, y) = x * y', but then Flux.setup(AdaGrad(), model) produces the following warning:

Warning: setup found no trainable parameters in this model
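The warning seems to follow from how `model(x, y) = x * y'` is written: it is a plain function, and `x` and `y` are its arguments rather than arrays stored inside it, so `setup` finds nothing to train. A minimal sketch of one way around it, assuming the factors are stored in a NamedTuple that is passed as the model, and that `A` is fed in through `data` (the `train!` docstring describes calling `loss(model, d...)` for each element `d` of `data`). The `dim×dim` factor shape is an assumption: with vectors, `x * y'` is rank one and cannot match a generic random `A` exactly.

```julia
using Flux

dim = 2
A = rand(dim, dim)

# Assumption: x and y are dim×dim factors, so x * y' can match a generic
# (full-rank) A; with vectors, x * y' would only be rank one.
model = (x = rand(dim, dim), y = rand(dim, dim))

# train! calls loss(model, d...) for each element d of data: the trainable
# arrays are read from the model, and A arrives as (non-trainable) data.
loss(m, A) = sum(abs2, m.x * m.y' .- A)

opt_state = Flux.setup(AdaGrad(), model)   # no warning: x and y are found

initial_loss = loss(model, A)
for epoch in 1:1_000
    Flux.train!(loss, model, [(A,)], opt_state)  # one AdaGrad step per call
end
final_loss = loss(model, A)
```

Since `data` holds a single element here, each `train!` call is one optimiser step, and the loop performs 1000 AdaGrad steps on the same `A`.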

But isn't there an example right in the first code box on the page you linked?

julia> model = Dense(2=>1, leakyrelu; init=ones);

julia> opt_state = Flux.setup(Momentum(0.1), model)

julia> x1, y1 = [0.2, -0.3], [0.4];

julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end
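The `do` block in that example is Julia syntax for passing an anonymous function as the first argument, so it uses the same `Flux.train!(loss, model, data, opt_state)` signature quoted in the question. A small check of that reading, training two identical copies of the documentation's `Dense` model once with the `do` block and once with a named loss (the names `model_a`, `model_b` and `loss` are just for illustration):

```julia
using Flux

# Two identical copies of the documentation's model, one per calling style.
model_a = Dense(2 => 1, leakyrelu; init = ones)
model_b = Dense(2 => 1, leakyrelu; init = ones)
state_a = Flux.setup(Momentum(0.1), model_a)
state_b = Flux.setup(Momentum(0.1), model_b)

x1, y1 = [0.2, -0.3], [0.4]
data = [(x1, y1), (x1, y1)]

# Style 1: do-block, as in the documentation example.
Flux.train!(model_a, data, state_a) do m, x, y
    sum(abs.(m(x) .- y)) * 100
end

# Style 2: a named loss passed explicitly as the first argument.
loss(m, x, y) = sum(abs.(m(x) .- y)) * 100
Flux.train!(loss, model_b, data, state_b)

# Both calls perform the same updates, so the parameters agree.
same_weights = model_a.weight ≈ model_b.weight && model_a.bias ≈ model_b.bias
```

Here `Dense` plays the role of the model only because it is a struct whose `weight` and `bias` fields are arrays; nothing in `train!` itself is specific to neural-network layers.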

If you want to create custom models that don't use the built-in layers, there is some information about that in the documentation as well:
https://fluxml.ai/Flux.jl/stable/models/advanced/#Custom-Model-Example
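Applied to the matrix-factorization problem, a minimal sketch along those lines could store the two factors in a custom struct and register it with `Flux.@functor` (the Flux 0.13 way; newer Flux releases document `Flux.@layer` for the same purpose). The struct name `Factorization`, the `dim×dim` factor shape, and the choice of `Adam(0.01)` are illustrative assumptions:

```julia
using Flux

# Hypothetical custom model: a struct whose fields are the trainable arrays.
struct Factorization
    x::Matrix{Float64}
    y::Matrix{Float64}
end

# Tell Flux (via Functors.jl) to look inside the struct for parameters.
Flux.@functor Factorization

# Optional: make the struct callable so it returns the reconstruction x * y'.
(f::Factorization)() = f.x * f.y'

dim = 2
A = rand(dim, dim)
model = Factorization(rand(dim, dim), rand(dim, dim))

# The model comes first, then the data (here just A).
loss(m, A) = sum(abs2, m() .- A)

opt_state = Flux.setup(Adam(0.01), model)

initial_loss = loss(model, A)
for epoch in 1:500
    Flux.train!(loss, model, [(A,)], opt_state)
end
final_loss = loss(model, A)
```

Making the struct callable is optional; the loss could just as well read `m.x` and `m.y` directly. What matters for `setup` is that the arrays are fields of the model it receives.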
