# PPL connection to MLJ.jl

I’d like to get Soss.jl connected to MLJ.jl. I can imagine a few ways this might work, and thought I could get some input here on what could be most useful as first steps.

I’ll start with a short overview, since Soss works differently than most PPLs. All of this might actually run, but I’m not testing as I go, so for now just consider it pseudocode.

Here’s how one could write a simple linear model:

```julia
m = @model (αPrior, βPrior, σPrior, xPrior) begin
    α ~ αPrior
    β ~ βPrior
    σ ~ σPrior
    x ~ xPrior
    yhat = α .+ β .* x
    y ~ Normal(yhat, σ)
end
```

In `m`, `(αPrior, βPrior, σPrior, xPrior)` are free variables, which for our purposes can be treated as hyperparameters.

One very different thing here is that no data are observed in the definition of a model; conditioning on data happens separately, as part of inference. All the model knows how to do is reason about the relationships between parameters and generate data.
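For anyone without Soss at hand, here is a rough plain-Julia sketch of what "generating data" from the model above amounts to: draw each `~` variable in order, compute the assignments, and collect the `~` values. This is not how Soss is implemented; the function name `simulate_linear` is hypothetical, and each prior is assumed to be a function `rng -> draw` rather than a `Distribution`, just to keep the sketch dependency-free.

```julia
using Random

# Hypothetical, Soss-free sketch of one draw from m(αPrior, βPrior, σPrior, xPrior).
# Assumption: each "prior" is a function rng -> draw, not a Distribution.
function simulate_linear(rng::AbstractRNG, αPrior, βPrior, σPrior, xPrior)
    α = αPrior(rng)                  # α ~ αPrior
    β = βPrior(rng)                  # β ~ βPrior
    σ = σPrior(rng)                  # σ ~ σPrior
    x = xPrior(rng)                  # x ~ xPrior
    yhat = α .+ β .* x               # plain assignment, no randomness
    y = yhat .+ σ .* randn(rng, size(yhat)...)   # y ~ Normal(yhat, σ), elementwise
    return (α = α, β = β, σ = σ, x = x, y = y)   # every `~` variable, not yhat
end

rng = MersenneTwister(1)
draw = simulate_linear(rng, r -> randn(r), r -> randn(r), r -> abs(randn(r)), r -> randn(r, 5))
```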

Models are “function-like”, so `m(αPrior, βPrior, σPrior, xPrior)` gives us a thing I’m currently calling a `BoundModel`, but I’ll probably change that name at some point. It’s really more like a joint distribution. Inference methods in Soss take a joint distribution, values for a subset of the variables, and a sampling algorithm. There are different kinds of these, but for now I’ll focus on the ones that return a sample from the posterior.

A “sample” for me will be an iterator (it’s not this yet, but that’s where things are going). So for example something like

```julia
joint = m(αPrior, βPrior, σPrior, xPrior)
post = sample(joint, (x=x0,))
```

Soss makes it easy (or easier, anyway) to reason about models, so for example it should be easy to turn the above into something like

```julia
mPred = @model (α, β, σ, x) begin
    yhat = α .+ β .* x
    return Normal(yhat, σ)
end
```

From there it’s just a matter of piping the inference results into this prediction.
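As a rough illustration of that piping step, assume the posterior sample is a vector of `NamedTuple`s with fields `α`, `β`, `σ`: each predictive draw picks one posterior draw and then samples `y` the way `mPred` would. The helper name and the stand-in "posterior" below are made up for illustration; no real inference happens here.

```julia
using Random, Statistics

# Hypothetical sketch: `post` is a vector of posterior draws, each a NamedTuple (α, β, σ).
# Each predictive draw picks one posterior draw uniformly, then samples
# y ~ Normal(α + β*x, σ), i.e. the role played by `mPred`.
function posterior_predictive_draws(rng::AbstractRNG, post::AbstractVector{<:NamedTuple},
                                    x::Real, n::Integer)
    ys = Vector{Float64}(undef, n)
    for i in 1:n
        θ = rand(rng, post)                        # one posterior draw
        ys[i] = θ.α + θ.β * x + θ.σ * randn(rng)   # one predictive draw
    end
    return ys
end

rng = MersenneTwister(2)
# Stand-in "posterior" (not the output of real inference): tightly concentrated draws.
post = [(α = 1.0 + 0.01 * randn(rng), β = 2.0 + 0.01 * randn(rng), σ = 0.1) for _ in 1:500]
ys = posterior_predictive_draws(rng, post, 3.0, 10_000)
```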

A lot of this could change. Maybe `m` should be a closure over hyperparameters with an inner model taking `x` input? Lots and lots of possibilities. And for a given model, I’d probably have a macro

```julia
@mlj_supervised m x y
```

to set up the predictive distribution and the MLJ types and methods in the right way.

@oxinabox, you had mentioned a need for this sort of thing. What’s a simple use case that would be useful to you?


More generally, here are the constraints on `Soss` models:

• arguments can be any Julia object: `Float64`, `Distribution`, `Function`, Soss/Turing/Gen `Model`, you name it
• For now the semantics are very simple. Each line is just `x = rhs` or `x ~ rhs`
• For assignments, `rhs` can be anything. For `~`, the type of `rhs` should support whatever methods are needed by the inference, typically `rand` and `logpdf`
• There’s an optional `return` value. If it’s not there, the return is the `NamedTuple` of all `~` values
• Inference methods may change or remove the return value. For example, `logpdf` of a model is always a `Real`
• Yes, `rhs` can be a `Model`! Check it out:
```julia
julia> m = @model begin
           a ~ @model begin
               x ~ Normal()
               y ~ Normal()
               return x/y
           end
       end;

julia> rand(m)
(a = 0.03461737882946257,)
```

OTOH some of the semantics of this still need to be pinned down better.
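To make the rules listed above concrete, here is a hand expansion of `rand(m)` for the nested model, in plain Julia with no Soss: the inner model's explicit `return` replaces its default output, while the outer model, having no `return`, yields the `NamedTuple` of its `~` variables. The function names are hypothetical, and this is only one reading of the rules as stated, not how Soss generates code.

```julia
using Random

# Inner model: x ~ Normal(), y ~ Normal(), with an explicit `return x/y`.
function inner_rand(rng::AbstractRNG)
    x = randn(rng)   # x ~ Normal()
    y = randn(rng)   # y ~ Normal()
    return x / y     # explicit return replaces the default NamedTuple
end

# Outer model: a ~ inner. No `return` line, so the result is the
# NamedTuple of all `~` variables, here just `a`.
outer_rand(rng::AbstractRNG) = (a = inner_rand(rng),)

rng = MersenneTwister(3)
draw = outer_rand(rng)
```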

Finally, it may help to think of a `Model` as a generalization of a `Distribution`. We will have ways to `convert` between the two.

A simple example that might be useful to me: take the Iris dataset, with fields petal length, petal width, sepal length, sepal width, and species. Let's say I would like to provide the last four as input and get back a distribution estimate for the petal length.

I would be willing to provide whatever kinds of statements about naive guesses at the distributions you need, e.g. that petal length, petal width, sepal length, and sepal width are multivariate normal when conditioned on species.

Thanks! This is a fun example.

We could have the measurements as a single `MvNormal`, which would require passing observed data as including `missing` values. I have ideas for this, but we’re not there yet.
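Once the per-species mean vector and covariance are fixed (say, estimated from the data), treating petal length as `missing` boils down to conditioning a multivariate normal on its observed coordinates. Below is a sketch of that standard Gaussian conditioning step; the helper name `conditional_first` and the toy numbers are made up, and a fully Bayesian treatment would average this over posterior draws of μ and Σ rather than plugging in one value.

```julia
using LinearAlgebra

# Hypothetical helper: given a multivariate normal N(μ, Σ) over, say,
# (petal length, petal width, sepal length, sepal width) for one species,
# return the mean and variance of component 1 conditioned on the others.
function conditional_first(μ::AbstractVector, Σ::AbstractMatrix, x_rest::AbstractVector)
    Σ12 = Σ[1, 2:end]              # covariances of component 1 with the rest
    Σ22 = Σ[2:end, 2:end]          # covariance of the observed components
    w = Symmetric(Σ22) \ Σ12       # Σ22⁻¹ Σ21 (Σ21 == Σ12 by symmetry)
    m = μ[1] + dot(w, x_rest .- μ[2:end])
    v = Σ[1, 1] - dot(Σ12, w)
    return (mean = m, var = v)
end

# Toy two-dimensional check with made-up numbers (not Iris estimates):
# unit variances, correlation 0.5, observed second component = 2.
r = conditional_first([0.0, 0.0], [1.0 0.5; 0.5 1.0], [2.0])
```

Here the conditional mean is 0.5 · 2 = 1 and the variance shrinks from 1 to 1 − 0.5² = 0.75.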

There’s also a much easier problem that we just haven’t gotten to yet, which is implementing transforms for array-valued distributions. Both `TransformVariables.jl` and `Bijectors.jl` can do this; it’s just a matter of being able to infer the support at compile time and getting everything connected properly. I expect we’ll be there within a week or so. For example, you’d normally use an LKJ prior for the correlation part of a covariance matrix, but to sample it we need to be able to parameterize it with ℝⁿ.
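For a sense of what such a transform involves, here is a standalone sketch of one standard construction: map n(n−1)/2 unconstrained reals through `tanh` into (−1, 1) and assemble them into the Cholesky factor of a correlation matrix, roughly as described in the Stan reference manual. This is not the `TransformVariables.jl` or `Bijectors.jl` API; the name `corr_cholesky` is hypothetical, and the log-Jacobian term that HMC would also need is left out.

```julia
using LinearAlgebra

# Hypothetical sketch: map k = n(n-1)/2 unconstrained reals to the lower-triangular
# Cholesky factor L of an n×n correlation matrix (so L*L' has unit diagonal).
# The log-Jacobian needed for HMC is omitted.
function corr_cholesky(y::AbstractVector{<:Real}, n::Integer)
    length(y) == n * (n - 1) ÷ 2 || throw(ArgumentError("need n(n-1)/2 inputs"))
    z = tanh.(y)                           # each z in (-1, 1)
    L = zeros(float(eltype(y)), n, n)
    L[1, 1] = 1.0
    k = 0
    for i in 2:n
        remaining = 1.0                    # 1 - (sum of squares so far in row i)
        for j in 1:(i - 1)
            k += 1
            L[i, j] = z[k] * sqrt(remaining)
            remaining *= 1 - z[k]^2        # stays ≥ 0, avoids cancellation
        end
        L[i, i] = sqrt(remaining)          # makes row i have unit norm
    end
    return LowerTriangular(L)
end

L = corr_cholesky([0.3, -1.2, 2.0], 3)
C = L * L'                                  # a valid 3×3 correlation matrix
```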

Anyway, I think I wasn’t very clear in my question. Soss is pretty flexible, maybe think of it as

Soss : Stan :: Flux : Keras

Models in Soss are first-class and “function-like”, so there are lots of ways we could set this up. But for MLJ we need things pinned down, at least a little more.

I just added a new blog post about the interface. I think `predictive` will be useful for MLJ connectivity; it’s just a matter of making it a little more concrete.

The MLJ interface looks like this:

```
MLJBase.fit(model::SomeSupervisedModel, verbosity::Integer, X, y) -> fitresult, cache, report
MLJBase.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat
```

We’re moving toward an iterator interface, but the more I think about it, the more I think it will be easier to start with a fixed sample from the posterior.

So…

• I guess we’ll need to wrap things in a mutable struct
• `fitresult` is “the learned parameters”, so that’s the posterior sample of the parameters
• Not sure about `cache`
• `report` can have sampler diagnostics

A lot of supervised models look something like this:

```julia
julia> m = @model X begin
           β ~ foo()
           y ~ bar(X,β)
       end;
```

I think we could have something like

```julia
mutable struct MLJmodel <: MLJ.Probabilistic
    sampler          # function that can do MCMC
    m :: Soss.Model
    X :: Symbol
    y :: Symbol
end
```

Then `fit` would be

```julia
function fit(mljModel, verbosity, X, y)
    # This line is wrong, we'll need to grab the right NamedTuple keys
    (samples, report) = mljModel.sampler(mljModel.m(X=X), (y=y,))
    cache = ...
    (samples, cache, report)
end
```

For prediction, I think we’ll need another wrapper:

```julia
struct MLJpredictor <: Distributions.Sampleable
    pred :: Soss.Model
    Xrow :: Matrix
    βs   :: Vector{NamedTuple}
end

function Base.rand(p::MLJpredictor)
    β = rand(p.βs)

    # This line is wrong, we'll need to grab the right NamedTuple keys
    rand(p.pred(X=p.Xrow, β=β))
end
```

Then prediction would be

```julia
function predict(mljModel, βs, Xnew)
    pred = predictive(mljModel.m, mljModel.X)
    [MLJpredictor(pred, Xrow, βs) for Xrow in eachrow(Xnew)]
end
```

I might be missing some fine points, but I think something along these lines can work. Any thoughts?

Thanks for that.

Regarding the mysterious “cache”.

This is needed if you are going to overload the `update` method, which you typically do for iterative models, and in other situations where you do not want certain hyperparameter changes to trigger retraining from scratch, when refitting an MLJ “machine”. The `cache` is how the first `fit` call passes any required information not in `fitresult` to subsequent `update` calls. So, if you want to implement a Soss model with MCMC sampling (which I don’t suggest as a first model, see below) then eventually you will want to have an `update` method and will want to worry about what to pass to `update` in the `cache` output of `fit`. For now, just set it to `nothing`.

If you are interested in seeing how `update` works, see the WIP MLJFlux.jl, where something very similar is happening, with the `optimiser` parameter playing the role of your `sampler` above. Or look at MLJ/src/ensembles.jl.

BTW, by having the `update` method, you will one day be able to externalise the control of the MCMC iteration to MLJ using a common “iterative methods control” model wrapper not yet written.
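A minimal sketch of the `fit`/`update`/`cache` pattern described above, assuming a toy random-walk Metropolis sampler for the mean of normal data (flat prior, unit noise). The names `ToyMCMC`, `toy_fit`, and `toy_update` are hypothetical stand-ins that only mimic the shape of MLJ's methods; they do not extend MLJBase. The point is that `cache` carries the RNG and the last chain state, so raising `n_samples` resumes the chain instead of starting over.

```julia
using Random, Statistics

# Hypothetical iterative "model": n_samples draws of μ from the posterior of
# y ~ Normal(μ, 1) under a flat prior, via random-walk Metropolis.
mutable struct ToyMCMC
    n_samples::Int
    stepsize::Float64
end

loglik(μ, y) = -0.5 * sum(abs2, y .- μ)

# Advance the chain `n` steps from state μ, appending draws; return the last state.
function run_chain!(draws, rng, μ, y, n, stepsize)
    lp = loglik(μ, y)
    for _ in 1:n
        prop = μ + stepsize * randn(rng)
        lp_prop = loglik(prop, y)
        if log(rand(rng)) < lp_prop - lp    # Metropolis accept/reject
            μ, lp = prop, lp_prop
        end
        push!(draws, μ)
    end
    return μ
end

function toy_fit(model::ToyMCMC, verbosity::Integer, y; rng = MersenneTwister(0))
    draws = Float64[]
    μ = run_chain!(draws, rng, mean(y), y, model.n_samples, model.stepsize)
    cache = (rng = rng, state = μ)          # what update needs to resume
    report = (n_draws = length(draws),)     # stand-in for sampler diagnostics
    return draws, cache, report
end

# If n_samples grew, resume from the cached state instead of refitting.
function toy_update(model::ToyMCMC, verbosity::Integer, old_fitresult, old_cache, y)
    extra = model.n_samples - length(old_fitresult)
    extra <= 0 && return toy_fit(model, verbosity, y)
    draws = copy(old_fitresult)
    μ = run_chain!(draws, old_cache.rng, old_cache.state, y, extra, model.stepsize)
    return draws, (rng = old_cache.rng, state = μ), (n_draws = length(draws),)
end

y = [0.8, 1.1, 1.3, 0.9, 1.2]
model = ToyMCMC(1_000, 0.5)
fitresult, cache, report = toy_fit(model, 0, y)
model.n_samples = 3_000
fitresult2, cache2, report2 = toy_update(model, 0, fitresult, cache, y)
```

A real `update` would also check whether any other hyperparameter (here `stepsize`) changed before reusing the cache.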

However, I suggest we first implement something very simple to start with, which predicts an ordinary Distribution, rather than a weaker Sampler object. Perhaps a suitable Bayesian linear regressor?
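As one possible shape for such a model, here is a plain-Julia sketch of conjugate Bayesian linear regression, assuming a known noise scale σ and an independent N(0, τ²) prior on each coefficient. The posterior over β is then Gaussian, so each prediction is an ordinary Normal, returned here as a `(μ, σ)` `NamedTuple` to avoid dependencies (with Distributions.jl it would be `Normal(μ, σ)`). The names `ConjugateBLR`, `blr_fit`, and `blr_predict` are hypothetical.

```julia
using LinearAlgebra

# Hypothetical conjugate Bayesian linear regression:
# y ~ Normal(Xβ, σ²I) with σ known, prior β ~ N(0, τ²I).
struct ConjugateBLR
    σ::Float64   # noise standard deviation (assumed known)
    τ::Float64   # prior standard deviation of each coefficient
end

function blr_fit(model::ConjugateBLR, X::AbstractMatrix, y::AbstractVector)
    precision = X' * X / model.σ^2 + I / model.τ^2   # posterior precision of β
    Σpost = inv(Symmetric(precision))
    μpost = Σpost * (X' * y) / model.σ^2
    return (μ = μpost, Σ = Σpost)                    # the fitresult
end

# One Normal per row: mean x'μ, variance σ² + x'Σx (noise + parameter uncertainty).
function blr_predict(model::ConjugateBLR, fitresult, Xnew::AbstractMatrix)
    map(eachrow(Xnew)) do x
        m = dot(x, fitresult.μ)
        v = model.σ^2 + dot(x, fitresult.Σ * x)
        (μ = m, σ = sqrt(v))
    end
end

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
model = ConjugateBLR(1.0, 1e6)          # very wide prior
fr = blr_fit(model, X, y)
preds = blr_predict(model, fr, [1.0 1.0; 0.0 0.0])
```

With a very large τ, as here, the posterior mean approaches the least-squares fit; in general it equals the ridge estimate with penalty σ²/τ².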

Minor points:

• In MLJ it is conventional (at least at present) for `X` coming into `fit` to be allowed to be any Tables.jl compatible table, rather than a matrix. You can convert to matrix with `MLJ.matrix` in your `fit` and `predict`.

• The `Integer` type annotation on `verbosity` in `fit` cannot be dropped, for dispatch reasons.
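On the first point: a `NamedTuple` of equal-length vectors is already a valid Tables.jl column table. Outside MLJ, a dependency-free way to turn one into an n×p matrix looks like this (`coltable_to_matrix` is a hypothetical name; inside MLJ you would just call `MLJ.matrix(X)`).

```julia
# Columns of the NamedTuple become columns of the matrix, in field order.
coltable_to_matrix(X::NamedTuple) = reduce(hcat, collect(values(X)))

X = (sepal_length = [5.1, 4.9], sepal_width = [3.5, 3.0])
A = coltable_to_matrix(X)
```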

Thanks @ablaom!

In most cases we’ll have a Distribution, but it won’t be one of the canned ones. We could write ridge regression, but there’s nothing special about it. Our focus is on non-conjugate models, so we don’t have anything in place to exploit conjugacy.

I agree starting with something very simple is the best way to go. I’ll try to get going with a simple model when I get a chance and let you know how it goes.

Hi @ablaom @tlienart, I got a very simple example of this working.

Let’s start with a very simple linear model:

```julia
julia> using Soss

julia> m = @model X begin
           β ~ Normal() |> iid(size(X,2))
           y ~ For(eachrow(X)) do x
               Normal(x' * β, 1)
           end
       end;
```

Soss can easily build the predictive model:

```julia
julia> pred = predictive(m,:β)
@model (X, β) begin
    y ~ For(eachrow(X)) do x
        Normal(x' * β, 1)
    end
end
```

So the `fit` method can look something like this:

```julia
function fit(m::Model, verbosity::Integer, X, y)
    fitresult = dynamicHMC(m(X=X), (y=y,))
    cache = nothing
    report = nothing
    (fitresult, cache, report)
end
```

Note that this will change: the `report` will contain sampler diagnostics.

MLJ wants something `Sampleable` for each row, so that can be one of these:

```julia
struct MLJpredictor <: Distributions.Sampleable{Univariate, Continuous}
    pred :: Soss.Model
    Xrow :: Matrix
    βs   :: Vector{NamedTuple}
end
```

We can arrive at one of these using

```julia
function predict(m, fitresult, Xnew)
    pred = predictive(m, setdiff(variables(m), [:X, :y])...)
    map(eachrow(Xnew)) do x
        X = reshape(x, 1, :)
        MLJpredictor(pred, X, fitresult)
    end
end
```

and sample from it with

```julia
function Base.rand(p::MLJpredictor)
    args = merge(rand(p.βs), (X=p.Xrow,))
    rand(p.pred(args)).y
end
```

Here’s an example:

```julia
julia> m = @model X begin
           β ~ Normal() |> iid(size(X,2))
           y ~ For(eachrow(X)) do x
               Normal(x' * β, 1)
           end
       end;

julia> truth = rand(m(X=randn(6,3)));

julia> truth.X
6×3 Array{Float64,2}:
 -1.73363    2.22775    -0.219192
 -0.316971   1.04228    -0.832411
 -1.10279    0.0216187   0.600457
  1.16898    0.516836    0.745137
 -0.485346  -0.638888   -0.00444749
 -0.348279   0.260488    0.00282337

julia> truth.β
3-element Array{Float64,1}:
 -0.5956909231810055
  0.35237481802388154
 -0.5657109779896079

julia> truth.y
6-element Array{Float64,1}:
  3.3822351232108376
  0.5386418959858568
 -1.2993070301632343
 -1.5603629159553476
  0.7523110806435932
  0.2615344878495407

julia> fitresult = fit(m, 0, truth.X, truth.y);

julia> p = predict(m, fitresult, rand(4,3));

julia> map(rand, p)
4-element Array{Array{Float64,1},1}:
 [-1.208077428909729]
 [0.8595594796246107]
 [-2.2840206878815446]
 [0.7762001376025607]
```
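For comparison, a Soss-free sketch of the same predictor idea: one object per input row, holding the row and the posterior draws of β, whose `rand` picks a draw and samples `y ~ Normal(x'β, 1)`. `RowPredictor` and `predict_rows` are hypothetical names, and the β draws below are a made-up stand-in rather than real MCMC output. Unlike the output above, each draw here is a scalar rather than a 1-element vector, because the row is kept as a vector instead of a 1×n matrix.

```julia
using LinearAlgebra, Random, Statistics

struct RowPredictor
    xrow::Vector{Float64}          # one row of the new inputs
    βs::Vector{Vector{Float64}}    # posterior draws of β
end

# Pick a posterior draw, then sample y ~ Normal(x'β, 1).
Base.rand(rng::AbstractRNG, p::RowPredictor) = dot(p.xrow, rand(rng, p.βs)) + randn(rng)
Base.rand(p::RowPredictor) = rand(Random.default_rng(), p)

# One predictor per row of Xnew, all sharing the same posterior draws.
predict_rows(βs, Xnew::AbstractMatrix) = [RowPredictor(collect(x), βs) for x in eachrow(Xnew)]

rng = MersenneTwister(5)
βs = [[1.0, -1.0] .+ 0.01 .* randn(rng, 2) for _ in 1:200]   # stand-in posterior draws
ps = predict_rows(βs, [1.0 0.0; 0.0 1.0; 2.0 1.0])
draws = [rand(rng, ps[3]) for _ in 1:5000]
```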

That looks great to me, thanks for putting this together!

One thing: here you use `dynamicHMC`, and I could see that being the fallback choice in general, but in this case it makes more sense to return a Normal. Does Soss provide an easy way to tell whether the posterior can be represented by a simple Distribution object?

There may only be relatively few of those examples (conjugate and normal stuff essentially) so those could take special constructors in MLJ I guess. What do you think?

In any case your example could become something like

```julia
@mlj_model mutable struct BayesianLinearRegression{P<:Distributions.Distribution} <: MLJBase.Probabilistic
    prior::P = Normal()
    fit_intercept::Bool = true
end

function MLJBase.fit(m::BayesianLinearRegression, verbosity::Integer, X, y)
    # the model like you defined it in Soss
end

function MLJBase.predict(m::BayesianLinearRegression, fitresult, X)
    # check whether the prior is a Normal; if so, return a Normal
    # if the prior is not Normal, run the body of your `fit` function here
    # first case returns a `Normal`, second case returns samples or a sampler
end
```

PS: I guess we could argue over whether the call to `dynamicHMC` should happen in `fit` or in `predict`. IMO it should happen in `predict`, as in the code above, but I have no strong opinion so far.

PPS: there probably should be two distributions in the object, one for the data and one for the prior.


Nothing yet. We had a connection to `SymPy.jl`, and I think there’s opportunity to have symbolic analysis again in the future. But currently the focus is on making everything easily composable for non-conjugate models.

Maybe. If it’s very specialized, there’s a point where Soss isn’t such a benefit, and these could be hard-coded. In that case, we’d just want (I think) to be sure there’s a clean transition, and/or provide Soss workalikes so people can easily extend.

Side note: `DynamicHMC` was my starting point, but I’ll have `AdvancedHMC`, Gibbs, etc (all the usual stuff). We probably need a parameter somewhere to specify a sampling algorithm.

Otherwise, I don’t see the benefit of moving inference to predict. The approach here is that `fit` determines the parameter samples, and `predict` uses those samples together with a new input. This parallels the approach taken by MLE-based algorithms, where “samples” is replaced with “estimate”.
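To make the parallel concrete, a plain-Julia sketch under simple assumptions (linear model, β draws stored as vectors): an MLE-style `fit` returns one estimate, a sampling-based `fit` returns many draws, and `predict` combines whatever `fit` returned with the new input. All names are hypothetical, and averaging over draws gives only a point summary; a probabilistic `predict` would keep the spread as well.

```julia
using LinearAlgebra

# MLE-style: fit returns one estimate, predict plugs it in.
mle_fit(X, y) = X \ y                       # least-squares estimate of β
mle_predict(βhat, Xnew) = Xnew * βhat

# Sample-based: fit would return many draws of β; predict averages over them.
samples_predict(βs, Xnew) = sum(Xnew * β for β in βs) / length(βs)

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
βhat = mle_fit(X, y)
Xnew = [2.0 1.0]
```

With a single "draw" equal to the estimate, `samples_predict([βhat], Xnew)` coincides with `mle_predict(βhat, Xnew)`.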

Another wrinkle: We need `fit` to take as input some more information, like

• Which algorithm to use
• Options that can vary by algorithm:
    • How many samples
    • How many chains
    • Sometimes things like the proposal distribution, etc.

So we probably need a way for a user to specify an inference algorithm, and then (inference in hand) a way to combine that with a `Soss.Model`, finally passing the combined result to `fit`.

Having `Inference` as a separate entity would allow users to easily just always use their preferred methods across a range of models.
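One way the separate `Inference` entity could look, as a hedged design sketch only: a value bundling an algorithm and its options, combined with a model (here just a log-density function) before anything runs. None of these names (`Inference`, `RandomWalkMH`, `Fittable`, `combine`, `sample_chains`) are existing Soss or MLJ API, and the random-walk Metropolis sampler stands in for whatever algorithms Soss actually offers.

```julia
using Random

abstract type Algorithm end

struct RandomWalkMH <: Algorithm
    step::Float64        # proposal scale, an algorithm-specific option
end

# Algorithm plus general options, independent of any model.
struct Inference{A<:Algorithm}
    alg::A
    nsamples::Int
    nchains::Int
end

# A "model" is just a log-density here; combining it with an Inference
# gives something a `fit` could run.
struct Fittable{F,I<:Inference}
    logdensity::F
    inference::I
end

combine(logdensity, inf::Inference) = Fittable(logdensity, inf)

# Dispatch on the algorithm type to pick the sampler.
function sample_chains(alg::RandomWalkMH, f::Fittable, rng::AbstractRNG)
    inf = f.inference
    chains = Vector{Vector{Float64}}(undef, inf.nchains)
    for c in 1:inf.nchains
        θ = 0.0
        lp = f.logdensity(θ)
        draws = Float64[]
        for _ in 1:inf.nsamples
            prop = θ + alg.step * randn(rng)
            lp_prop = f.logdensity(prop)
            if log(rand(rng)) < lp_prop - lp     # Metropolis accept/reject
                θ, lp = prop, lp_prop
            end
            push!(draws, θ)
        end
        chains[c] = draws
    end
    return chains
end
sample_chains(f::Fittable, rng::AbstractRNG) = sample_chains(f.inference.alg, f, rng)

myinf = Inference(RandomWalkMH(0.8), 2_000, 4)    # one preference...
f1 = combine(θ -> -0.5 * θ^2, myinf)              # ...reused across models
f2 = combine(θ -> -0.5 * (θ - 3)^2, myinf)
chains = sample_chains(f2, MersenneTwister(4))
```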

Yes, I hadn’t taken into account your `predictive(m,:β)`.

Ok, I’ve started a branch `soss` on MLJModels. I’ll try to incorporate what you’ve suggested into the standard-ish way we currently interface with external packages, and will ping here once it’s gone somewhere. Exciting stuff, thanks!


That sounds great, thank you!

As you’re building this out, let me know if you hit places where we need to adapt in order to have consistency between this and unsupervised models (which I guess we’ll need to tackle next), or other aspects of MLJ we’ll need to connect with.

@tlienart One suggestion. Instead of doing this in MLJModels, could we put MLJ model implementation code in Soss.jl? (MLJBase could be an optional dependency using Requires, so the only impact on Soss.jl would be some extra tests.)

Works for me. What can I do to help make this easier?