Do we have differentiable sort in Julia?

Python has support for fast differentiable sorting, which is an implementation of Blondel et al. (2020). I don't think we have this in Julia at the moment, do we?

I have a specific use case for ConformalPrediction.jl, but beyond that it seems there should be many applications in ML where this could be useful (just based on reading the abstract of the paper).

Does anyone know of an existing implementation? If not, where would be a good place to kindly ask for it?


Our package InferOpt.jl implements several SoTA methods for converting a combinatorial solver (anything that can be formulated as a MILP) into a differentiable layer. We have drawn a lot of inspiration from the work of Blondel and colleagues, so I invite you to check out the code, along with our paper Learning with Combinatorial Optimization Layers: a Probabilistic Approach.

The particular “fast differentiable sorting” algorithm you mentioned is not part of our package, because we offer to wrap arbitrary solvers instead of focusing on specific problems. But you can discuss it with my colleague @BatyLeo, since it is very close to our research interests. Perhaps there could even be a library of custom differentiable algorithms to complement the generic wrappers of InferOpt.jl?


Ohhh amazing! :clap:

From a brief look at it, this indeed looks like it should do the trick for me. Having wrappers would be amazing @BatyLeo, since it’s not immediately obvious to me from the docs how exactly I should build the layer that does “fast differentiable sorting”. But for now this gives me a place to start. Thank you very much!

Really cool! Sorry, this is a bit off-topic, but I was reading the paper and wondered whether there is an advantage, e.g. in the Warcraft example, to training end-to-end as your package allows, as opposed to separately training the ML algorithm to output paths and then running the CO algorithm as usual. I was hoping to see a comparison in the benchmarks but didn't find one.

I guess I would find it if I read the linked papers, but can someone who already knows briefly explain what it means for sorting to be differentiable? Naively, one would expect the derivatives to be mostly zero.


If all the values in your array are distinct, then there is a small neighbourhood around the array in which sorting is just the application of a fixed permutation matrix. Thus sorting is locally a linear map, and its derivative is that same linear map. In forward mode, you sort the primal array and apply the same permutation to the tangent array. For arrays with at least two identical elements, this permutation matrix is not uniquely determined, so sorting is not differentiable there.
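A minimal sketch of that forward-mode rule in plain Julia (the helper name sort_jvp is hypothetical, not from any package): sortperm gives the locally fixed permutation, which is applied to both the primal and the tangent. A central finite difference serves as a sanity check.

```julia
# Hypothetical helper: forward-mode rule for `sort` at a point with distinct entries.
# Returns the sorted primal and the tangent pushed through the same permutation.
function sort_jvp(x::AbstractVector{<:Real}, dx::AbstractVector{<:Real})
    allunique(x) || error("sort is not differentiable where entries tie")
    p = sortperm(x)          # the locally fixed permutation
    return x[p], dx[p]       # primal output, tangent output
end

# Check against a central finite difference along the direction dx
x  = [0.3, -1.0, 2.5, 0.7]
dx = [1.0, 0.0, 0.0, 0.0]
y, dy = sort_jvp(x, dx)
h  = 1e-6
fd = (sort(x .+ h .* dx) .- sort(x .- h .* dx)) ./ (2h)
println(dy)  # [0.0, 1.0, 0.0, 0.0]: x[1] is the second-smallest entry
println(fd)  # approximately the same vector
```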


As @gdalle mentioned, our package InferOpt.jl provides generic wrappers to turn any combinatorial optimization algorithm into a differentiable function.
Depending on your needs, it could be enough to use these wrappers directly.
Feel free to ask more questions if you face any difficulties; we know that the documentation is very minimalistic and lacks details and examples at the moment :slightly_smiling_face: (also do not hesitate to give us feedback so we can improve it for future users!)

However, you won't find the exact same "fast differentiable sorting" technique as in the paper you cited. From a brief look, it seems quite close to our Regularized layers, with some additions specifically tailored to sorting and ranking functions. It does not seem that difficult to implement a specific wrapper; I'll look into it in more detail when I'm back from holidays.


Training the ML model separately would mean learning to predict edge/cell costs on their own, in a supervised way. First, this assumes you have access to the true edge costs in your dataset, which our losses do not require (except for the SPO+ loss). Second, even when you do have the true costs, it's still better to learn in an integrated way, because a small error in terms of predicted costs can lead to a large error in terms of predicted path, and the other way around. For more details, you can check out the "Smart Predict then Optimize" paper, which explains this quite well.
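A toy illustration of that last point, with made-up numbers and a trivial "solver" that just picks the item with the highest score (a hypothetical stand-in for the shortest-path solver, not code from InferOpt): the prediction with the smaller squared error on the costs makes the worse decision.

```julia
# Toy "combinatorial solver": pick the single item with the largest score.
solver(θ) = argmax(θ)

# Regret: how much true value is lost by acting on the prediction θ_hat.
regret(θ_hat, θ) = maximum(θ) - θ[solver(θ_hat)]

# Plain mean squared error on the scores themselves.
mse(θ_hat, θ) = sum(abs2, θ_hat .- θ) / length(θ)

θ_true = [1.0, 0.9]    # made-up true scores
θ_a    = [1.0, 1.1]    # small score error, wrong decision
θ_b    = [1.3, 0.6]    # larger score error, right decision

for (name, θ_hat) in (("a", θ_a), ("b", θ_b))
    println(name, ": mse = ", mse(θ_hat, θ_true), ", regret = ", regret(θ_hat, θ_true))
end
```

A purely supervised loss on the scores would prefer prediction a, while the decision only cares about the ordering that the solver sees.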


You’re right, combinatorial optimization algorithms like sort are by definition piecewise constant, and therefore have zero derivative almost everywhere.

However, in many applications we still want to be able to compute meaningful gradients through combinatorial algorithms, for example when they are combined with machine learning models. In this case, there are ways to build a regularized version of the combinatorial algorithm that stays close to the original one and is differentiable with useful gradients.

Edit: sort is piecewise linear, and ranking is piecewise constant
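One standard example of such a regularization, on a simpler problem than sorting: the argmax over the probability simplex is a linear program whose solution is a one-hot vector (piecewise constant in θ), and adding a negative-entropy regularizer with weight ε turns the solution into softmax(θ/ε). A minimal sketch comparing finite-difference derivatives of both (all names are illustrative, not from InferOpt):

```julia
# Hard solver: one-hot argmax, i.e. the LP argmax over the simplex of ⟨θ, y⟩.
function hard_argmax(θ)
    y = zeros(length(θ))
    y[argmax(θ)] = 1.0
    return y
end

# Regularized solver: same LP plus an entropy term with weight ε; solution is softmax(θ/ε).
function soft_argmax(θ; ε=1.0)
    z = θ ./ ε
    e = exp.(z .- maximum(z))  # shift by the max for numerical stability
    return e ./ sum(e)
end

# Central finite difference of f along coordinate i
function fd_column(f, θ, i; h=1e-5)
    e_i = zeros(length(θ)); e_i[i] = 1.0
    return (f(θ .+ h .* e_i) .- f(θ .- h .* e_i)) ./ (2h)
end

θ = [1.0, 0.5, -0.2]
println(fd_column(hard_argmax, θ, 1))                 # all zeros: no training signal
println(fd_column(t -> soft_argmax(t; ε=0.5), θ, 1))  # nonzero, informative
```

Smaller ε brings soft_argmax closer to the hard solver but makes its derivatives sharper, which is the same trade-off as the ε parameter discussed further down the thread.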


Many papers, including ours, study combinatorial solvers that can be formulated as (Mixed Integer) Linear Programs:

$$\theta \longmapsto \arg\max_x \theta^\top x \quad \text{such that} \quad Ax \leq b$$

The mapping from $\theta$ to a solution is piecewise constant, so the real problem is not the jump points (which we almost surely never hit) but rather the regions of desperate flatness. Thus, "non-differentiable" is true in the literal sense (no gradients at some points), but mostly in the practical sense (all gradients are useless anyway). What we really want here is an approximation of the function with more informative gradients.


Which brings up an interesting question: I’m not sure whether sorting itself can be formulated as a linear program. Finding the sort permutation, sure (algorithms - Sorting as a linear program - Computer Science Stack Exchange), but the actual sorted values?
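For what it's worth, Blondel et al. (2020) do cast both operations as linear programs over a permutahedron $\mathcal{P}(w)$, the convex hull of all permutations of a vector $w$. With $\rho = (n, n-1, \dots, 1)$, the descending sort and the descending ranks are

$$s(\theta) = \arg\max_{y \in \mathcal{P}(\theta)} \langle y, \rho \rangle, \qquad r(\theta) = \arg\max_{y \in \mathcal{P}(\rho)} \langle y, -\theta \rangle.$$

For sorting, $\theta$ enters the constraints rather than the objective, which is consistent with sort being piecewise linear and ranking piecewise constant. Since an LP attains its optimum at a vertex, and the vertices of $\mathcal{P}(w)$ are the permutations of $w$, a brute-force check is possible for tiny $n$ (helper names are hypothetical; this is not how one would solve it in practice):

```julia
using LinearAlgebra: dot

# Hypothetical helper: all permutations of v (n! of them, so tiny n only).
function all_permutations(v::AbstractVector)
    length(v) <= 1 && return [collect(v)]
    out = Vector{Vector{eltype(v)}}()
    for i in eachindex(v)
        rest = [v[j] for j in eachindex(v) if j != i]
        for p in all_permutations(rest)
            push!(out, vcat(v[i], p))
        end
    end
    return out
end

# s(θ): maximize ⟨y, ρ⟩ over the vertices of P(θ), i.e. over permutations of θ.
function lp_sort(θ)
    ρ = collect(length(θ):-1:1)
    return argmax(y -> dot(y, ρ), all_permutations(θ))
end

# r(θ): maximize ⟨y, -θ⟩ over the vertices of P(ρ), i.e. over permutations of ρ.
function lp_rank(θ)
    ρ = collect(length(θ):-1:1)
    return argmax(y -> dot(y, -θ), all_permutations(ρ))
end

θ = [0.3, -1.0, 2.5, 0.7]
println(lp_sort(θ))  # [2.5, 0.7, 0.3, -1.0], same as sort(θ; rev=true)
println(lp_rank(θ))  # [3, 4, 1, 2]: rank 1 is the largest entry
```

The `argmax(f, itr)` method used here requires Julia 1.7 or later.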


As Léo said, you won't get the exact same layer as in the paper. However, this will give you similar behavior (differentiable generation of a sort permutation):

```julia
using InferOpt

# Ranks of θ (rank 1 = smallest entry), a piecewise constant function
ranking(θ::AbstractVector) = invperm(sortperm(θ))
layer = PerturbedAdditive(ranking; ε=1.0, nb_samples=10)
```

The parameter ε controls the amount of noise in the approximation, and nb_samples sets how many times the original function is evaluated to get a gradient approximation.
The object layer is a callable whose derivatives are automatically computed using ChainRules.jl.
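To make the roles of ε and nb_samples concrete, here is a self-contained Monte Carlo sketch of the additive-perturbation idea (Berthet et al., 2020) in plain Julia. It is not InferOpt's implementation, and the function name is hypothetical: the smoothed output is the average of ranking(θ + εZ) with Gaussian Z, and a Jacobian estimate follows from the identity ∇θ E[f(θ + εZ)] = E[f(θ + εZ) Zᵀ] / ε.

```julia
using Random

ranking(θ::AbstractVector) = invperm(sortperm(θ))   # rank 1 = smallest entry

# Hypothetical helper: Monte Carlo estimate of E[f(θ + εZ)] and of its Jacobian,
# with Z ~ N(0, I), using ∇θ E[f(θ + εZ)] = E[f(θ + εZ) Zᵀ] / ε.
function perturbed_mean_and_jacobian(f, θ::AbstractVector; ε=1.0, nb_samples=10,
                                     rng=Random.default_rng())
    n = length(θ)
    y_mean = zeros(n)
    J = zeros(n, n)
    for _ in 1:nb_samples
        Z = randn(rng, n)              # one Gaussian perturbation
        y = f(θ .+ ε .* Z)             # one call to the original (hard) function
        y_mean .+= y
        J .+= (y * Z') ./ ε            # outer product f(θ + εZ) Zᵀ / ε
    end
    return y_mean ./ nb_samples, J ./ nb_samples
end

θ = [0.1, 0.4, -0.3]
y, J = perturbed_mean_and_jacobian(ranking, θ; ε=1.0, nb_samples=1000,
                                   rng=MersenneTwister(0))
println(y)   # fractional "soft ranks"; compare with ranking(θ) == [2, 3, 1]
println(J)   # nonzero entries, unlike the (almost everywhere zero) true Jacobian
```

Larger nb_samples reduces the variance of both estimates at the cost of more solver calls; larger ε gives smoother outputs that are further from the original function.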


Fast differentiable sorting and ranking is now implemented in this new PR.

It implements the differentiable functions soft_sort and soft_rank.

You can test it by installing the branch version as follows:

```
pkg> add InferOpt#soft-sort
```

Reproducing one of the figures of the paper:

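```julia
using InferOpt, Plots, Zygote

# Second-largest entry of [0, x, 1, 2] as a function of x: hard sort vs soft sort
plot(x -> sort([0.0, x, 1.0, 2.0]; rev=true)[2], label="Sort")
plot!(x -> soft_sort([0.0, x, 1.0, 2.0]; rev=true)[2], label="Soft sort")
# Derivative of the soft sort output with respect to x, computed by Zygote
plot!(y -> gradient(x -> soft_sort([0.0, x, 1.0, 2.0]; ε=1.0, rev=true)[2], y)[1], label="Soft sort derivative")
```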
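For readers curious about what happens inside a soft sort, below is a minimal sketch of the quadratic-regularization variant from Blondel et al. (2020), written from the paper rather than from the InferOpt source. The soft descending sort is the Euclidean projection of ρ/ε onto the permutahedron of θ (with ρ = (n, …, 1)), which reduces to an isotonic regression solved by pool-adjacent-violators (PAV). All names are hypothetical, only the descending order is covered, and the code mutates arrays, so it is not meant to be differentiated with Zygote as written.

```julia
# Hypothetical helper: PAV for  min ½‖v - y‖²  subject to  v[1] ≥ v[2] ≥ … ≥ v[n].
function isotonic_decreasing(y::AbstractVector{<:Real})
    sums = Float64[]; counts = Int[]             # one entry per pooled block
    for yi in y
        push!(sums, yi); push!(counts, 1)
        # merge while the previous block mean is below the last one (order violated)
        while length(sums) > 1 && sums[end-1] / counts[end-1] < sums[end] / counts[end]
            sums[end-1] += sums[end]; counts[end-1] += counts[end]
            pop!(sums); pop!(counts)
        end
    end
    return reduce(vcat, [fill(s / c, c) for (s, c) in zip(sums, counts)])
end

# Hypothetical helper: projection of z onto the permutahedron P(w), following
# P(z, w) = z - v(z_σ, w_sorted)_{σ⁻¹}, with σ sorting z in descending order.
function project_permutahedron(z::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    σ = sortperm(z; rev=true)
    w_sorted = sort(w; rev=true)
    v = isotonic_decreasing(z[σ] .- w_sorted)
    out = float.(collect(z))
    out[σ] .-= v                                 # undo the permutation σ
    return out
end

# Soft descending sort with quadratic regularization of strength ε:
# ε → 0 recovers sort(θ; rev=true), ε → ∞ gives mean(θ) in every slot.
soft_sort_desc(θ; ε=1.0) = project_permutahedron(collect(length(θ):-1:1) ./ ε, θ)

θ = [0.0, 0.5, 1.0, 2.0]
println(soft_sort_desc(θ; ε=1.0))   # [2.0, 1.0, 0.5, 0.0]: still the hard sort
println(soft_sort_desc(θ; ε=10.0))  # ≈ [1.025, 0.925, 0.825, 0.725]: all entries pooled
```

The output always sums to sum(θ), since every point of the permutahedron does; ε only controls how strongly neighbouring entries are pooled together.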



This is amazing :fire: thanks @BatyLeo

The pull request is now merged and included in the new v0.6.0 release of InferOpt!