[ANN] ONNX export for Flux models

Hi,

After not being able to find a sufficiently persistent serialization format, I whipped together a small package which translates models into an ONNX graph:

https://github.com/DrChainsaw/ONNXmutable.jl

Simple example:

```julia
using ONNXmutable, Flux
using Statistics  # for mean

l1 = Conv((3,3), 2=>3, relu)
l2 = Dense(3, 4, elu)

f = function(x,y)
    x = l1(x)
    # Home-brewed global average pool
    x = dropdims(mean(x, dims=(1,2)), dims=(1,2))
    x = l2(x)
    return x + y
end

# Input shapes can be combinations of numbers, strings, symbols or missing
x_shape = (:W, :H, 2, :Batch)
y_shape = (4, :Batch)
# Translate f to an ONNX graph and save it to a file.
onnx("model.onnx", f, x_shape, y_shape)
```

It is basically just traversing the function and creating protos and types from ONNX.jl whenever it hits a “mapped” operation.
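
To illustrate the general idea (this is not ONNXmutable's actual code), here is a minimal Base-only sketch of tracing: a hypothetical `Traced` placeholder stands in for each input, and every "mapped" operation applied to it appends a node to a graph instead of computing a value. `MatMul`, `Add` and `Tanh` are real ONNX operator names, but mapping `*`, `+` and `tanh` onto them this way is an assumption made for the example.

```julia
# Hypothetical sketch of tracing; not ONNXmutable's implementation.
# A node is recorded as (ONNX op type, input names, output name).
const TNode = Tuple{String,Vector{String},String}

# Placeholder value: a name plus a reference to the shared graph being built
struct Traced
    name::String
    graph::Vector{TNode}
end

# Append a node for `op` applied to `inputs` and return a placeholder for its output
function record!(op::String, inputs::Traced...)
    graph = first(inputs).graph
    out = "t$(length(graph) + 1)"
    push!(graph, (op, [t.name for t in inputs], out))
    return Traced(out, graph)
end

# "Mapped" operations: calling them on placeholders emits nodes
Base.:*(a::Traced, b::Traced) = record!("MatMul", a, b)
Base.:+(a::Traced, b::Traced) = record!("Add", a, b)
Base.tanh(a::Traced) = record!("Tanh", a)

# Call `f` with one placeholder per input name; return the recorded graph and output name
function trace(f, inputnames::String...)
    graph = TNode[]
    out = f((Traced(n, graph) for n in inputnames)...)
    return graph, out.name
end

layer(x, W, b) = tanh(W * x + b)   # an ordinary Julia function, nothing ONNX-specific
graph, output = trace(layer, "x", "W", "b")
# graph  == [("MatMul", ["W", "x"], "t1"), ("Add", ["t1", "b"], "t2"), ("Tanh", ["t2"], "t3")]
# output == "t3"
```

Since the placeholder only needs methods for the mapped operations, the traced function itself can be arbitrary Julia code.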

The list of supported operations is currently nothing to brag about, but I have put some effort into making it straightforward to add mappings between Julia functions and ONNX operations without modifying any code in ONNXmutable (and I’m of course more than happy to receive pull requests with more OPs).
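
One common way to get this kind of extensibility in Julia is multiple dispatch: the exporter looks up the ONNX op for a Julia function through a generic function, and anyone can add a method for their own function from outside the package. The sketch below is hypothetical; `MiniExporter` and `onnxop` are made-up names, not ONNXmutable's API.

```julia
# Hypothetical mapping from Julia functions to ONNX operator names; not ONNXmutable's API.
module MiniExporter
    # Fallback: functions without a mapping cannot be exported
    onnxop(f) = error("No ONNX mapping for $f")
    # Mappings that ship with the (hypothetical) exporter
    onnxop(::typeof(+)) = "Add"
    onnxop(::typeof(tanh)) = "Tanh"
end

# User code: a new activation function mapped to the ONNX Softplus op,
# without touching any code inside MiniExporter
softplus(x) = log1p(exp(x))
MiniExporter.onnxop(::typeof(softplus)) = "Softplus"

MiniExporter.onnxop(tanh)      # "Tanh"
MiniExporter.onnxop(softplus)  # "Softplus"
```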

ONNXmutable can also deserialize ONNX models into NaiveNASflux.jl models, something which could be quite handy in a transfer learning context when one also wants to make adjustments to the model (e.g. change the output size or remove/add layers).

The reason for the name of the package and the perhaps unnecessary dependency on NaiveNASflux is simply that I initially had the ambition to only do serialization/deserialization for NaiveNASflux models, as I thought writing a general model translator would be too hard. However, when the dust settled I realized that the method I used also works for generic functions, as long as they are not too picky with types high up in the call hierarchy.
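
To make the caveat about types concrete: a tracing exporter feeds placeholder values through the function, so a method signature that restricts its arguments to something the placeholder is not (for example `AbstractArray`) refuses the placeholder with a `MethodError`. The `Tracer` type below is a hypothetical, minimal placeholder, not a type from ONNXmutable.

```julia
# Hypothetical minimal placeholder that records which ops were applied to it
struct Tracer
    ops::Vector{String}
end
# Mapped op: record it (simplified: returns the same placeholder)
Base.:+(a::Tracer, b::Tracer) = (push!(a.ops, "Add"); a)

generic(x, y) = x + y                              # no type restrictions: traceable
picky(x::AbstractArray, y::AbstractArray) = x + y  # restricted: rejects the placeholder

t = Tracer(String[])
generic(t, t)     # records "Add"
t.ops             # ["Add"]
# picky(t, t)     # would throw a MethodError, since Tracer is not an AbstractArray
```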

If the serialization functionality is useful to other people than myself but the dependency on NaiveNASflux is not, I would of course be happy to move it out of this package.

So far I have not tried to load a model which was serialized using ONNXmutable in some other (non-Julia) library. Please file an issue if you try this and it doesn’t work.

Nice work. I see you are still actively developing your package.
On the other hand, development of the ONNX.jl package seems abandoned.
What’s your take on that?
There is also an old thread on ONNX here from which it seems that there are no alternatives yet for using ONNX in Flux.

Hi and thanks for showing interest,

I also saw that ONNX.jl has seen very few updates, but I don’t know why that is. I basically ended up using only the protos from ONNX.jl, so the only consequence is some annoying deprecation and overload warnings due to DataFlow (which is deprecated).

In other words, all mapping between deserialized protos and ops is basically reimplemented, so if ONNX.jl fails to deserialize a model there is a chance ONNXmutable will succeed. FWIW, all supported ops in ONNXmutable are tested against onnxruntime to verify that serialized models produce the same result as the original and deserialized model. This also means that the last statement in the original post, about exported models not having been tested with another framework, is no longer true.

I have been thinking about just generating new proto files using ProtoBuf, coming up with a way to select spec and opset version, and putting some primitives from Base in a separate package, but I haven’t found the energy to do so (this issue makes it seem like a bit more work). I don’t really know much more about ONNX than I needed to put ONNXmutable together, so it’s not like I have the complete design in my head either.

ONNXmutable should ideally also be split up into an exporter which does not depend on NaiveNASflux, and I guess the exporter in turn could have a simpler dependency which just defines the primitives for Flux. On top of this, I guess the tracing mechanism should be outsourced to something like Mjolnir so that it can map control statements to their ONNX equivalents and work even if non-primitives have type bounds.

Sigh, so much coding to do and so little time and energy. I was kinda hoping that there would be some more interest and contributions from the community but I guess there are more important and fun problems to work on.

It may be harder than you think 🙂 I ended up generating most of the code and adding TypeProto_Tensor manually (see the result in onnx_pb.jl).

Somehow I missed your post in January and thought there was no actually working ONNX implementation in Julia. Currently I’m working on my own implementation for Yota, so it may be advantageous to collaborate on this. Note, however, that my stack is by no means based on Flux. I intended to use Yota.Tape to represent the computational graph; if we decide to combine our work, the first question would be what representation is suitable for all desired use cases.

Yeah I saw that issue a few weeks ago and it did not do wonders for my motivation to kick off any work with this.

I was definitely thinking of putting generic stuff like protos and base OPs in a framework-agnostic package (one whose name is not prefixed by ONNX, so it does not confuse people when tab-completing in the REPL), but I haven’t really given it any more thought than that, except that it should be extensible for both new and old ONNX versions, perhaps with methods like this:

```julia
proto(::OnnxVersion{XX}, ::OpSetVersion{YY}, op::typeof(somebasefunction), args) = OnnxVersionXX.OpSetVersionYY.NodeProto(opname(op), somethingwith(args...))
```

or whatever makes sense to fulfill the above requirements.
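
A runnable version of that dispatch idea might look like the following sketch. Everything in it is hypothetical (`OpSetVersion`, `ProtoNode` and `proto` are made-up stand-ins, not ONNX.jl types), and the ONNX-version parameter is left out for brevity. `Clip` is a real example of an op whose signature changed between opsets: `min`/`max` were attributes in opset 6 and became inputs in opset 11.

```julia
# Hypothetical opset-version tag, used only for dispatch
struct OpSetVersion{N} end

# Hypothetical stand-in for an ONNX NodeProto
struct ProtoNode
    op_type::String
    inputs::Vector{String}
    attributes::Dict{String,Any}
end

# Ops that did not change between opsets need only one method for all versions
proto(::OpSetVersion, ::typeof(+), a::String, b::String) =
    ProtoNode("Add", [a, b], Dict{String,Any}())

# Clip in opset 6: min and max are attributes
proto(::OpSetVersion{6}, ::typeof(clamp), x::String, lo::Real, hi::Real) =
    ProtoNode("Clip", [x], Dict{String,Any}("min" => lo, "max" => hi))

# Clip in opset 11: min and max are (named) inputs instead
proto(::OpSetVersion{11}, ::typeof(clamp), x::String, lo::String, hi::String) =
    ProtoNode("Clip", [x, lo, hi], Dict{String,Any}())

proto(OpSetVersion{6}(), clamp, "x", 0.0, 6.0)     # attributes hold min = 0.0 and max = 6.0
proto(OpSetVersion{11}(), clamp, "x", "lo", "hi")  # inputs are ["x", "lo", "hi"]
```

A real implementation would also need rules for the opsets in between and after these two, for example a method that forwards a range of versions to the nearest definition.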

Framework-specific packages then just implement methods for their own primitives. Collaborating across two frameworks could be a great way to battle-test the base structure (or perhaps a way to make sure nothing is ever completed, I don’t really know).

Maybe the tracing mechanism can/should also be made independent; ideally it should be a lot less code than all the primitives. I guess the tape approach (which I also use a poor man’s version of) has the limitation that it can’t handle control statements. This should be no biggie for traditional deep learning, but it would be nice to be able to cater to at least some parts of the SciML stuff (although I suspect getting this to work in a satisfactory manner will be quite an uphill battle against the ONNX spec).
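
The limitation with control statements shows up even in a tiny tape-based sketch (hypothetical `Tape`/`Tracked` types, Base only): a loop whose trip count is an ordinary Julia value is simply unrolled during tracing, so the recorded graph has the count baked in rather than containing something like an ONNX `Loop` node. With this approach a branch on a traced value cannot be recorded at all, since the placeholder cannot be turned into a `Bool`.

```julia
# Hypothetical tape: records the op type of every mapped operation, in call order
struct Tape
    ops::Vector{String}
end

# Hypothetical placeholder value tied to a tape
struct Tracked
    tape::Tape
end

# Mapped operation: record it on the tape and return a new placeholder
Base.:+(a::Tracked, b::Tracked) = (push!(a.tape.ops, "Add"); Tracked(a.tape))

function repeated_double(x, n)
    for _ in 1:n
        x = x + x
    end
    return x
end

tape = Tape(String[])
repeated_double(Tracked(tape), 3)
tape.ops   # ["Add", "Add", "Add"]: the loop was unrolled, n = 3 is baked into the graph
```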

Would it make sense to use something like Mjolnir here? XLA.jl more or less does this, and ONNX has (reducible) control flow ops IIRC.

Yes, Mjolnir looks like it would be pretty useful for an ONNX exporter. I still have a feeling that making a diffeq solver work in ONNX will be about as pleasant as making it work in powerpoint animations (and maybe as efficient too)…

Do you have any thoughts regarding importing ONNX models? I’m interested in importing even more than exporting because it will make it easy to use thousands of pretrained models in Julia. An ONNX graph should be imported into some graph representation. And if you have such a representation, it might be easier to use it for exporting as well (although I’m not certain about it) and just let frameworks convert their graphs into this representation.

Doesn’t ONNX.jl already do this? The recent discussion in the other thread was that the package seems to have bitrotted and folks are no longer able to import models…

@dfdx I guess importing is the somewhat easier part as it translates from a smaller spec to a larger. I don’t have any more thoughts than the obvious that one should be able to pick their framework which in my mind means one package per framework for import as well.

One thing that made the import in ONNXmutable a bit more complicated is that I want to use it for long-term model serialization, and for this to work in a satisfactory manner for all use cases, a serialized model should deserialize into the exact same thing. This requirement led to some awkward code which looks ahead a few OPs to see if things shall be merged. The most prominent example of this is activation functions, which ONNX puts as separate nodes in the graph but which most ML frameworks allow merging into the layer.
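
A simplified sketch of that kind of look-ahead, using hypothetical `OnnxNode`/`Layer` types over a linear chain of nodes (not ONNXmutable's code): an activation node whose only input is the output of the immediately preceding `Conv` or `Gemm` node is folded into that layer, and everything else becomes a layer of its own.

```julia
# Hypothetical, simplified node and layer types (not ONNX.jl or NaiveNASflux types)
struct OnnxNode
    op_type::String
    inputs::Vector{String}
    output::String
end

struct Layer
    op_type::String
    activation::Union{Nothing,String}   # nothing means identity
    output::String
end

const ACTIVATIONS = Set(["Relu", "Elu", "Tanh", "Sigmoid"])
const MERGEABLE = Set(["Conv", "Gemm"])   # layer types that can carry an activation

# Walk a chain of nodes, looking one node ahead to fold activations into layers.
# Assumes a simple chain: a real importer must also check that the layer's output
# is not consumed by any other node before merging.
function merge_activations(nodes::Vector{OnnxNode})
    layers = Layer[]
    i = 1
    while i <= length(nodes)
        node = nodes[i]
        nxt = i < length(nodes) ? nodes[i + 1] : nothing
        if node.op_type in MERGEABLE && nxt !== nothing &&
           nxt.op_type in ACTIVATIONS && nxt.inputs == [node.output]
            push!(layers, Layer(node.op_type, nxt.op_type, nxt.output))
            i += 2
        else
            push!(layers, Layer(node.op_type, nothing, node.output))
            i += 1
        end
    end
    return layers
end

nodes = [OnnxNode("Conv", ["x", "w1"], "c1"), OnnxNode("Relu", ["c1"], "r1"),
         OnnxNode("Flatten", ["r1"], "f1"),
         OnnxNode("Gemm", ["f1", "w2"], "g1"), OnnxNode("Elu", ["g1"], "e1")]
[(l.op_type, l.activation) for l in merge_activations(nodes)]
# [("Conv", "Relu"), ("Flatten", nothing), ("Gemm", "Elu")]
```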

Wrt representation: ONNXmutable uses the dirt-simple CompGraph from NaiveNASlib (which is framework agnostic). I initially thought that the graph representation would be helpful, but with the tracing approach it was not necessary at all. The only code in ONNXmutable’s export part which uses the CompGraph does so to give the ONNX vertices the same names as the vertices in the CompGraph (in case they are named). Not sure if this conclusion carries over to a code transformation approach (e.g. Mjolnir).

What I don’t like so much about the CompGraph is that it goes a bit against the (imo very attractive) Julia narrative of “everything is first class so you don’t need a dsl for machine learning, just write whatever code you want”.

@ToucheSir ONNX.jl does this (barring some bitrot), but not in a very extensible way. For instance, if you want to build upon it to import models into Knet or Lilith, you also have to import Flux, and you can’t really reuse much except for the protos.

Deserializing into a function is quite Julian ofc, but it is also quite hard to manipulate imported models which I think is not an uncommon thing to do e.g. in transfer learning.

I also think that the eval approach forces you to load the model at the top level (right?) and I find this quite limiting.
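
Loading at the top level sidesteps Julia's world-age rule: a function created with `eval` while another function is running belongs to a newer "world" than the running code, so calling it directly from that code throws a `MethodError`. `Base.invokelatest` is the standard workaround (at the cost of a dynamic call), so loading inside a function is possible, if a bit awkward. A minimal sketch, where `x -> 2x` is a hypothetical stand-in for code generated from an ONNX graph:

```julia
# Stand-in for a model deserialized by generating Julia code and eval-ing it
# (the expression is hypothetical; a real importer would build it from the ONNX graph)
function load_and_run(x)
    model = Core.eval(Main, :(x -> 2x))
    # Calling model(x) directly here would throw a MethodError: the new method is
    # newer than the world age this function is running in.
    return Base.invokelatest(model, x)
end

load_and_run(3)   # 6
```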

FWIW, I have now registered the package under the name ONNXNaiveNASflux so that breakage can be handled with compat. While the name is probably even uglier than the previous one, it should make it clear that it is a bit of a niche package when/if more canonical ONNX packages materialize.

I’m still trying to accumulate motivation/energy to extract Flux-primitives to a separate package so that they can be reused in the future. Any help with that is of course appreciated, even if it is just a signal that it would be useful to someone.

I’m definitely interested and willing to contribute. Does it make sense to prepare a rough plan from which I or other contributors will be able to take pieces and implement one by one?

It absolutely does! I will file an issue in the repo and ping you from it.

There is also this issue in ONNX.jl which I think is a separate, more holistic take.