PPL connection to MLJ.jl

I’d like to get Soss.jl connected to MLJ.jl. I can imagine a few ways this might work, and thought I could get some input here on what could be most useful as first steps.

I’ll start with a short overview, since Soss works differently than most PPLs. All of this might actually run, but I’m not testing as I go, so for now just consider it pseudocode.

Here’s how one could write a simple linear model:

m = @model (αPrior, βPrior, σPrior, xPrior) begin
    α ~ αPrior
    β ~ βPrior
    σ ~ σPrior
    x ~ xPrior
    yhat = α .+ β .* x
    y ~ Normal(yhat, σ)
end

In m, (αPrior, βPrior, σPrior, xPrior) are free variables, which for this purpose can be treated as hyperparameters.

One very different thing here is that no data are observed in the definition of a model. That’s separate, as part of inference. All the model knows how to do is reason about relationships between parameters and generate data.

Models are “function-like”, so m(αPrior, βPrior, σPrior, xPrior) gives us a thing I’m currently calling a BoundModel, but I’ll probably change that name at some point. It’s really more like a joint distribution. Inference methods in Soss take a joint distribution, values for a subset of the variables, and a sampling algorithm. There are different kinds of these, but for now I’ll focus on the ones that return a sample from the posterior.

A “sample” for me will be an iterator (it’s not this yet, but that’s where things are going). So for example something like

joint = m(αPrior, βPrior, σPrior, xPrior)
post = sample(joint, (x=x0,))

Soss makes it easy (or easier, anyway) to reason about models, so for example it should be easy to turn the above into something like

mPred = @model (α, β, σ, x) begin
    yhat = α .+ β .* x
    return Normal(yhat, σ)
end

From there it’s just a matter of piping the inference results into this prediction.
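A plain-Julia sketch of that piping step, with no Soss dependency: the posterior is assumed to be a vector of NamedTuples (made-up draws here, standing in for what `sample` would return), and `pred_draw` / `posterior_predictive` are hypothetical helpers playing the role of `mPred`. Each posterior draw produces one predictive draw of `y` at the new inputs.

```julia
using Random

rng = MersenneTwister(1)

# Stand-in for inference output: a vector of NamedTuples, one per posterior
# draw of (α, β, σ). The values are made up purely for illustration.
posterior = [(α = 1.0 + 0.1 * randn(rng), β = 2.0 + 0.1 * randn(rng), σ = 0.5)
             for _ in 1:1000]

# Plain-Julia analogue of mPred: given one parameter draw θ and inputs x,
# draw y ~ Normal(α .+ β .* x, σ) elementwise.
pred_draw(rng, θ, x) = θ.α .+ θ.β .* x .+ θ.σ .* randn(rng, length(x))

# "Piping" the posterior into the predictive: one predictive draw per posterior draw.
posterior_predictive(rng, posterior, x) = [pred_draw(rng, θ, x) for θ in posterior]

x0 = [0.0, 1.0, 2.0]
ys = posterior_predictive(rng, posterior, x0)
sum(ys) / length(ys)   # elementwise average over draws, roughly α .+ β .* x0
```

The result is a collection of draws rather than a closed-form distribution, which is the same shape of output the MLJ discussion below has to deal with.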

A lot of this could change. Maybe m should be a closure over hyperparameters with an inner model taking x as input? Lots and lots of possibilities. And for a given model, I’d probably have a macro

@mlj_supervised m x y

to set up the predictive distribution, the types, and the MLJ methods in the right way.

@oxinabox, you had mentioned a need for this sort of thing. What’s a simple use case that would be useful to you?


More generally, here are the constraints on Soss models:

  • arguments can be any Julia object: Float64, Distribution, Function, Soss/Turing/Gen Model, you name it
  • For now the semantics are very simple. Each line is just x = rhs or x ~ rhs
  • For assignments, rhs can be anything. For ~, the type of rhs should support whatever methods are needed by the inference, typically rand and logpdf
  • There’s an optional return value. If it’s not there, the return is the NamedTuple of all ~ values
  • Inference methods may change or remove the return value. For example, logpdf of a model is always a Real
  • Yes, rhs can be a Model! Check it out:
julia> m = @model begin
           a ~ @model begin
               x ~ Normal()
               y ~ Normal()
               return x/y
           end
       end

julia> rand(m)
(a = 0.03461737882946257,)

OTOH some of the semantics of this still need to be pinned down better.

Finally, it may help to think of a Model as a generalization of a Distribution. We will have ways to convert between the two.

A simple example that might be useful to me: take the Iris dataset, with fields petal length, petal width, sepal length, sepal width, and species. Let’s say I would like to provide the last four as input and get back a distribution estimate for the petal length.

I would be willing to provide whatever naive guesses about the distributions you need, e.g. that petal length, petal width, sepal length, and sepal width are multivariate normal when conditioned on species.

Thanks! This is a fun example.

We could have the measurements as a single MvNormal, which would require passing observed data as including missing values. I have ideas for this, but we’re not there yet.

There’s also a much easier problem that we just haven’t gotten to yet, which is implementing transforms for array-valued distributions. Both TransformVariables.jl and Bijectors.jl can do this; it’s just a matter of being able to infer the support at compile time and getting everything connected properly. I expect we’ll be there within a week or so. For example, you’d normally use an LKJ prior for covariance, but for that we need to be able to parameterize it with ℝⁿ.
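To make the ℝⁿ point concrete, here is a small sketch of one standard way to map an unconstrained vector to a covariance matrix: the log-Cholesky parameterization. This is a swapped-in illustration, not the LKJ correlation transform that TransformVariables.jl or Bijectors.jl would provide, and `vec_to_cov` is a hypothetical helper name.

```julia
using LinearAlgebra

# Map an unconstrained vector v ∈ ℝⁿ, n = d(d+1)/2, to a d×d covariance matrix
# via the log-Cholesky parameterization: fill a lower-triangular L column by
# column, exponentiate the diagonal so it is strictly positive, return L * L'.
function vec_to_cov(v::AbstractVector{<:Real}, d::Integer)
    length(v) == d * (d + 1) ÷ 2 || throw(ArgumentError("need d(d+1)/2 entries"))
    L = zeros(float(eltype(v)), d, d)
    k = 1
    for j in 1:d, i in j:d
        L[i, j] = i == j ? exp(v[k]) : v[k]
        k += 1
    end
    return Symmetric(L * L')   # symmetric positive definite for any real v
end
```

Because every real vector maps to a valid covariance matrix, a gradient-based sampler can move freely in ℝⁿ; the log-determinant of the Jacobian of this map would also be needed for correct inference, which is omitted here.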

Anyway, I think I wasn’t very clear in my question. Soss is pretty flexible, maybe think of it as

Soss : Stan :: Flux : Keras

Models in Soss are first-class and “function-like”, so there are lots of ways we could set this up. But for MLJ we need things pinned down, at least a little more.

I just added a new blog post about the interface. I think predictive will be useful for MLJ connectivity, just a matter of making it a little more concrete.

The MLJ interface looks like this:

MLJBase.fit(model::SomeSupervisedModel, verbosity::Integer, X, y) -> fitresult, cache, report
MLJBase.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat

We’re moving toward an iterator interface, but the more I think about it, the more I think it will be easier to start with a fixed sample from the posterior.


  • I guess we’ll need to wrap things in a mutable struct
  • fitresult is “the learned parameters”, so that’s the posterior sample of the parameters
  • Not sure about cache
  • report can have sampler diagnostics
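To make the fit contract concrete, here is a dependency-free sketch, assuming the simplest linear model (β ~ Normal(0, 1) iid, y ~ Normal(Xβ, 1)). A plain random-walk Metropolis sampler stands in for DynamicHMC, and `RWMetropolisModel` and `fit_sketch` are hypothetical names; the point is only the shape of the return value: fitresult holds the posterior draws, cache is `nothing`, and report carries a sampler diagnostic.

```julia
using Random

# Unnormalized log posterior for the running linear example:
# β ~ Normal(0, 1) iid, y ~ Normal(X * β, 1).
logpost(β, X, y) = -0.5 * sum(abs2, β) - 0.5 * sum(abs2, y .- X * β)

# Hypothetical model wrapper: here it only carries sampler settings.
struct RWMetropolisModel
    nsamples::Int
    stepsize::Float64
end

# Shaped like MLJBase.fit(model, verbosity, X, y) -> (fitresult, cache, report).
function fit_sketch(m::RWMetropolisModel, verbosity::Integer, X, y;
                    rng = Random.default_rng())
    β = zeros(size(X, 2))
    lp = logpost(β, X, y)
    draws = Vector{NamedTuple{(:β,), Tuple{Vector{Float64}}}}(undef, m.nsamples)
    accepted = 0
    for t in 1:m.nsamples
        prop = β .+ m.stepsize .* randn(rng, length(β))   # random-walk proposal
        lp_prop = logpost(prop, X, y)
        if log(rand(rng)) < lp_prop - lp                  # Metropolis accept/reject
            β, lp = prop, lp_prop
            accepted += 1
        end
        draws[t] = (β = β,)   # β is rebound, never mutated, so no copy is needed
    end
    fitresult = draws                                     # posterior draws of the parameters
    cache = nothing                                       # nothing to pass on until `update` exists
    report = (acceptance_rate = accepted / m.nsamples,)   # sampler diagnostics
    verbosity > 0 && @info "Sampling finished" acceptance_rate = report.acceptance_rate
    return fitresult, cache, report
end
```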

A lot of supervised models look something like this:

julia> m = @model X begin
           β ~ foo()
           y ~ bar(X,β)
       end

I think we could have something like

mutable struct MLJmodel <: MLJ.Probabilistic
    sampler        # function that can do MCMC
    m :: Soss.Model
    X :: Symbol
    y :: Symbol
end

Then fit would be

function fit(mljModel, verbosity, X, y)
    # This line is wrong, we'll need to grab the right NamedTuple keys
    (samples, report) = mljModel.sampler(mljModel.m(X=X), (y=y,))
    cache = ...
    (samples, cache, report)
end

For prediction, I think we’ll need another wrapper:

struct MLJpredictor <: Distributions.Sampleable
    pred :: Soss.Model
    Xrow :: Matrix
    βs   :: Vector{NamedTuple}
end

function Base.rand(p::MLJpredictor)
    β = rand(p.βs)

    # This line is wrong, we'll need to grab the right NamedTuple keys
    rand(p.pred(X=p.Xrow, β=β))
end

Then prediction would be

function predict(mljModel,βs,Xnew)
    pred = predictive(mljModel.m, mljModel.X)
    [MLJpredictor(pred, Xrow, βs) for Xrow in eachrow(Xnew)]
end
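A self-contained sketch of this per-row idea, without Soss or Distributions.jl: `RowPredictor` and `predict_sketch` are hypothetical stand-ins for `MLJpredictor` and `predict`, and the likelihood y ~ Normal(x'β, 1) is assumed. Each call to `rand` first picks a posterior draw of β and then samples y given that draw, which is a draw from the posterior predictive for that row.

```julia
using Random

# Hypothetical plain-Julia analogue of MLJpredictor: one object per new row,
# holding that row as a 1×p matrix plus the posterior draws of β from fit.
struct RowPredictor
    Xrow::Matrix{Float64}
    βs::Vector{Vector{Float64}}
end

# One predictive draw: pick a posterior draw of β uniformly at random, then
# sample y from the assumed likelihood, y ~ Normal(x'β, 1).
function Random.rand(rng::AbstractRNG, p::RowPredictor)
    β = rand(rng, p.βs)
    return (p.Xrow * β)[1] + randn(rng)
end
Random.rand(p::RowPredictor) = rand(Random.default_rng(), p)

# predict: one RowPredictor per row of Xnew (the fitresult here is just the βs).
predict_sketch(βs, Xnew) = [RowPredictor(reshape(collect(x), 1, :), βs) for x in eachrow(Xnew)]
```

Usage: `map(rand, predict_sketch(βs, Xnew))` gives one predictive draw per new row, mirroring the `map(rand, p)` call later in the thread.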

I might be missing some fine points, but I think something along these lines can work. Any thoughts?

Thanks for that.

Regarding the mysterious “cache”.

This is needed if you are going to overload the update method, which you typically do for iterative models, and in other situations where you do not want certain hyperparameter changes to trigger retraining from scratch, when refitting an MLJ “machine”. The cache is how the first fit call passes any required information not in fitresult to subsequent update calls. So, if you want to implement a Soss model with MCMC sampling (which I don’t suggest as a first model, see below) then eventually you will want to have an update method and will want to worry about what to pass to update in the cache output of fit. For now, just set it to nothing.

If you are interested in seeing how this works, see the WIP MLJFlux.jl, where something very similar is happening, with the optimiser parameter playing the role of your sampler above. Or look at MLJ/src/ensembles.jl.

BTW, by having the update method, you will one day be able to externalise the control of the MCMC iteration to MLJ using a common “iterative methods control” model wrapper not yet written.
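To illustrate the cache/update mechanism for an MCMC model, here is a minimal sketch, assuming a toy model (a single mean μ with a Normal prior) and a random-walk Metropolis chain. `MeanSampler`, `fit_sketch` and `update_sketch` are hypothetical names shaped like MLJ's `fit` and `update(model, verbosity, old_fitresult, old_cache, X, y)`: the cache carries the chain state, so increasing `nsamples` extends the existing chain instead of restarting it.

```julia
using Random

# Hypothetical iterative model: the only hyperparameter is the chain length.
mutable struct MeanSampler
    nsamples::Int
end

# Unnormalized log posterior for μ ~ Normal(0, 10) (sd 10), y .~ Normal(μ, 1).
logpost_mean(μ, y) = -0.5 * (μ / 10)^2 - 0.5 * sum(abs2, y .- μ)

# Extend `draws` by n random-walk Metropolis steps starting from state (μ, lp).
function advance!(draws, rng, μ, lp, y, n; step = 0.5)
    for _ in 1:n
        prop = μ + step * randn(rng)
        lp_prop = logpost_mean(prop, y)
        if log(rand(rng)) < lp_prop - lp
            μ, lp = prop, lp_prop
        end
        push!(draws, μ)
    end
    return μ, lp
end

# Shaped like MLJBase.fit: the cache holds the chain state needed to resume.
function fit_sketch(m::MeanSampler, verbosity::Integer, X, y)
    rng = MersenneTwister(1)
    draws = Float64[]
    μ, lp = advance!(draws, rng, 0.0, logpost_mean(0.0, y), y, m.nsamples)
    return draws, (μ = μ, lp = lp, rng = rng), nothing
end

# Shaped like MLJBase.update(model, verbosity, old_fitresult, old_cache, X, y):
# if nsamples grew, continue the existing chain instead of starting over.
function update_sketch(m::MeanSampler, verbosity::Integer, old_draws, old_cache, X, y)
    n_more = m.nsamples - length(old_draws)
    n_more < 0 && return fit_sketch(m, verbosity, X, y)   # chain got shorter: refit
    draws = copy(old_draws)
    # Note: this advances the RNG stored in old_cache, so the old cache is consumed.
    μ, lp = advance!(draws, old_cache.rng, old_cache.μ, old_cache.lp, y, n_more)
    return draws, (μ = μ, lp = lp, rng = old_cache.rng), nothing
end
```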

However, I suggest we first implement something very simple to start with, which predicts an ordinary Distribution, rather than a weaker Sampler object. Perhaps a suitable Bayesian linear regressor?
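One candidate for that first model is conjugate Bayesian linear regression, where the posterior and the predictive are Gaussian in closed form. A sketch, assuming known noise standard deviation σ and prior β ~ Normal(0, τ²I); `ConjugateBLR`, `fit_blr` and `predict_blr` are hypothetical names, and each prediction is returned as a (mean, sd) pair where an MLJ implementation would return a `Distributions.Normal`.

```julia
using LinearAlgebra

# Conjugate Bayesian linear regression with known noise sd σ and prior
# β ~ Normal(0, τ² I). The posterior is Gaussian in closed form:
#   Σₙ = (X'X/σ² + I/τ²)⁻¹,   μₙ = Σₙ X'y / σ²
# and the predictive for a new row x is Normal(x'μₙ, sqrt(σ² + x'Σₙx)).
struct ConjugateBLR
    σ::Float64   # noise standard deviation (assumed known)
    τ::Float64   # prior standard deviation on each coefficient
end

function fit_blr(m::ConjugateBLR, X::AbstractMatrix, y::AbstractVector)
    Σn = inv(Symmetric(X' * X / m.σ^2 + I / m.τ^2))
    μn = Σn * (X' * y) / m.σ^2
    return (μ = μn, Σ = Σn)
end

# One (mean, sd) pair per row of Xnew; with Distributions.jl these would be Normal(mean, sd).
function predict_blr(m::ConjugateBLR, post, Xnew::AbstractMatrix)
    [(mean = dot(x, post.μ), sd = sqrt(m.σ^2 + dot(x, post.Σ * x))) for x in eachrow(Xnew)]
end
```

With a very wide prior (large τ) the posterior mean approaches the least-squares fit, which is a handy sanity check.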

Minor points:

  • In MLJ it is conventional (at least at present) for X coming into fit to be allowed to be any Tables.jl compatible table, rather than a matrix. You can convert to matrix with MLJ.matrix in your fit and predict.

  • The Integer type annotation on verbosity in fit cannot be dropped, for dispatch reasons.
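For the table point, MLJ.matrix is the tool to use in practice. Purely to show what it does in the simplest case, here is a stdlib-only sketch for a NamedTuple of equal-length column vectors (one valid Tables.jl table type); `columns_to_matrix` is a hypothetical helper and does not handle other table types.

```julia
# A NamedTuple of equal-length column vectors is one of the simplest
# Tables.jl-compatible tables. For that special case only, the matrix is
# just the columns concatenated side by side.
function columns_to_matrix(t::NamedTuple)
    cols = collect(values(t))
    n = length(first(cols))
    all(length(c) == n for c in cols) || throw(ArgumentError("columns must have equal length"))
    return reduce(hcat, cols)
end
```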

Thanks @ablaom!

In most cases we’ll have a Distribution, but it won’t be one of the canned ones. We could write ridge regression, but there’s nothing special about it. Our focus is on non-conjugate models, so we don’t have anything in place to exploit conjugacy.

I agree starting with something very simple is the best way to go. I’ll try to get going with a simple model when I get a chance and let you know how it goes.

Hi @ablaom @tlienart, I got a very simple example of this working.

Let’s start with a very simple linear model:

julia> using Soss

julia> m = @model X begin
           β ~ Normal() |> iid(size(X,2))
           y ~ For(eachrow(X)) do x
               Normal(x' * β, 1)
           end
       end

Soss can easily build the predictive model:

julia> pred = predictive(m,:β)
@model (X, β) begin
        y ~ For(eachrow(X)) do x
                Normal(x' * β, 1)
            end
    end

So the fit method can look something like this:

function fit(m::Model, verbosity::Integer, X, y)
    fitresult = dynamicHMC(m(X=X), (y=y,))
    cache = nothing
    report = nothing
    (fitresult, cache, report)
end

Note that this will change: the report will contain sampler diagnostics.

MLJ wants something Sampleable for each row, so that can be one of these:

struct MLJpredictor <: Distributions.Sampleable{Univariate, Continuous}
    pred :: Soss.Model
    Xrow :: Matrix
    βs   :: Vector{NamedTuple}
end

We can arrive at one of these using

function predict(m, fitresult, Xnew)
    pred = predictive(m, setdiff(variables(m),[:X,:y])...)
    map(eachrow(Xnew)) do x
        X = reshape(x, 1, :)
        MLJpredictor(pred, X, fitresult)
    end
end

and sample from it with

function Base.rand(p::MLJpredictor)
    args = merge(rand(p.βs), (X=p.Xrow,))
    # draw from the predictive model given (X, β) and keep y
    rand(p.pred(; args...)).y
end

Here’s an example:

julia> m = @model X begin
           β ~ Normal() |> iid(size(X,2))
           y ~ For(eachrow(X)) do x
               Normal(x' * β, 1)
           end
       end

julia> truth = rand(m(X=randn(6,3)));

julia> truth.X
6×3 Array{Float64,2}:
 -1.73363    2.22775    -0.219192  
 -0.316971   1.04228    -0.832411  
 -1.10279    0.0216187   0.600457  
  1.16898    0.516836    0.745137  
 -0.485346  -0.638888   -0.00444749
 -0.348279   0.260488    0.00282337

julia> truth.β
3-element Array{Float64,1}:

julia> truth.y
6-element Array{Float64,1}:

julia> fitresult = fit(m, 0, truth.X, truth.y)[1];

julia> p = predict(m,fitresult, rand(4,3));

julia> map(rand, p)
4-element Array{Array{Float64,1},1}:

That looks great to me, thanks for putting this together!

One thing: here you use dynamicHMC, and I could see that being a general fallback, but in this case it makes more sense to return a Normal. Does Soss provide an easy way to tell whether the posterior can be represented simply by a Distribution object?

There may only be relatively few of those examples (conjugate and normal stuff essentially) so those could take special constructors in MLJ I guess. What do you think?

In any case your example could become something like

@mlj_model mutable struct BayesianLinearRegression{P<:Distributions.Distribution} <: MLJBase.Probabilistic
  prior::P = Normal()
  fit_intercept::Bool = true
end

function MLJBase.fit(m::BayesianLinearRegression, verbose::Int, X, y)
  # the model like you defined it in Soss
end

function MLJBase.predict(m::BayesianLinearRegression, fitresult, X)
  # here condition to see if the prior is a Normal in which case return a Normal
  # if the prior is not Normal, here the body of your `fit` function
  # in first case return a `Normal`, second case return samples or sampler
end
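The `if` in that predict comment could also be expressed with dispatch on the prior's type. A minimal sketch, assuming hypothetical stand-in prior types (`NormalPrior`, `LaplacePrior`) in place of Distributions.jl types; `predictive_strategy` only reports which path would be taken.

```julia
# Hypothetical prior types standing in for Distributions.jl's Normal, Laplace, etc.
abstract type Prior end
struct NormalPrior <: Prior
    τ::Float64
end
struct LaplacePrior <: Prior
    b::Float64
end

# Conjugate case: a closed-form Normal predictive is available.
predictive_strategy(::NormalPrior) = :closed_form_normal
# Everything else: fall back to posterior sampling (e.g. HMC) and return samples or a sampler.
predictive_strategy(::Prior) = :posterior_samples
```

Adding another conjugate special case would then be one more method, without touching the generic fallback.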

PS: I guess we could argue over whether the call to dynamicHMC should happen in fit or in predict. IMO it should happen in predict, as in the code above, but I have no strong opinion so far.

PPS: there should probably be two distributions in the object, one for the data and one for the prior.


Nothing yet. We had a connection to SymPy.jl, and I think there’s opportunity to have symbolic analysis again in the future. But currently the focus is on making everything easily composable for non-conjugate models.

Maybe. If it’s very specialized, there’s a point where Soss isn’t such a benefit, and these could be hard-coded. In that case, we’d just want (I think) to be sure there’s a clean transition, and/or provide Soss workalikes so people can easily extend.

Side note: DynamicHMC was my starting point, but I’ll have AdvancedHMC, Gibbs, etc. (all the usual stuff). We probably need a parameter somewhere to specify a sampling algorithm.

Otherwise, I don’t see the benefit of moving inference to predict. The approach here is that fit determines the parameter samples, and predict uses those samples together with a new input. This parallels the approach taken by MLE-based algorithms, where “samples” is replaced with “estimate”.
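The parallel with MLE can be sketched by treating fitresult as a collection of parameter values: MCMC gives many draws, and a least-squares (Gaussian MLE) fit is the degenerate case of a single draw. `fit_mle` and `predict_mean` are hypothetical helpers, and only the predictive mean is computed here.

```julia
using LinearAlgebra

# Treat `fitresult` as a collection of coefficient vectors. MCMC gives many
# draws; a least-squares/MLE fit is the degenerate case of a single "draw".
fit_mle(X, y) = [X \ y]

# Predictive mean for a new row x, averaged over whatever draws fitresult holds.
predict_mean(fitresult::AbstractVector{<:AbstractVector}, x) =
    sum(β -> dot(x, β), fitresult) / length(fitresult)
```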

Another wrinkle: We need fit to take as input some more information, like

  • Which algorithm to use
  • Options that can vary by algorithm:
    • How many samples
    • How many chains
    • Sometimes things like proposal distribution, etc

So we probably need a way for a user to specify an inference algorithm, and then (inference in hand) a way to combine that with a Soss.Model, finally passing the combined result to fit.

Having Inference as a separate entity would allow users to easily just always use their preferred methods across a range of models.
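A sketch of what a separate inference entity might look like, assuming hypothetical names throughout (`InferenceSpec`, `ModelWithInference`, `with_inference`) and illustrative algorithm symbols; none of this is an existing Soss or MLJ API. The same spec can be attached to any number of models, and the combined object is what `fit` would receive.

```julia
# Hypothetical inference specification, independent of any particular model.
Base.@kwdef struct InferenceSpec
    algorithm::Symbol = :hmc               # illustrative names, e.g. :hmc, :gibbs
    nsamples::Int = 1000
    nchains::Int = 4
    options::NamedTuple = NamedTuple()     # algorithm-specific extras, e.g. a proposal scale
end

# A model bundled with its inference settings; this is what `fit` would take.
struct ModelWithInference{M}
    model::M
    inference::InferenceSpec
end

with_inference(model, inf::InferenceSpec) = ModelWithInference(model, inf)
```

Usage: define `mine = InferenceSpec(nsamples = 2000)` once and call `with_inference(m, mine)` for each model, so a user's preferred settings carry across models.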

Yes, I hadn’t taken into account your predictive(m,:b).

Ok, I’ve started a branch soss on MLJModels, will try to incorporate what you’ve suggested into the standard-ish way we currently interface with external packages and will ping here once that’s gone somewhere.

Exciting stuff, thanks!


That sounds great, thank you!

As you’re building this out, let me know if you hit places where we need to adapt in order to have consistency between this and unsupervised models, which I guess we’ll need to hit next, or other aspects of MLJ we’ll need to connect with.


@tlienart One suggestion. Instead of doing this in MLJModels, could we put MLJ model implementation code in Soss.jl? (MLJBase could be an optional dependency using Requires, so the only impact on Soss.jl would be some extra tests.)

Works for me. What can I do to help make this easier?