Fast Differentiable Sorting and Ranking

(arxiv.org)

263 points | by etaioinshrdlu 1523 days ago

13 comments

  • kxyvr 1523 days ago
    Something that I didn't know, and may help others understand this paper better, is that there's a way to define the sorting of a vector through creative use of the Birkhoff–von Neumann theorem:

    https://en.wikipedia.org/wiki/Doubly_stochastic_matrix

    which is better explained here:

    https://cs.stackexchange.com/questions/4805/sorting-as-a-lin...

    where the sorting operation is defined as a linear program. Evidently, this has been known for at least half a century. That said, if a solution to a linear program can be computed in a differentiable way, then the operation of sorting can be made differentiable as well. This appears to be the trick in the paper, and they also appear to have a relatively fast way to compute the solution, which I think is interesting.
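
    For concreteness, here is a rough sketch of that LP view (my own toy code, not the paper's algorithm): sorting theta in descending order amounts to maximizing rho^T P theta over doubly stochastic matrices P, with rho = (n, n-1, ..., 1); the optimum lands on a permutation matrix.

        # Sketch: sorting as a linear program over the Birkhoff polytope.
        # Uses only numpy and scipy; illustrative, not the paper's method.
        import numpy as np
        from scipy.optimize import linprog

        def sort_via_lp(theta):
            n = len(theta)
            rho = np.arange(n, 0, -1)                 # n, n-1, ..., 1
            # Maximize rho^T P theta, i.e. minimize the entrywise product
            # of P with -outer(rho, theta).
            c = -np.outer(rho, theta).ravel()
            # Doubly stochastic constraints: every row and column of P sums to 1.
            A_eq = np.zeros((2 * n, n * n))
            for i in range(n):
                A_eq[i, i * n:(i + 1) * n] = 1.0      # row i sums to 1
                A_eq[n + i, i::n] = 1.0               # column i sums to 1
            b_eq = np.ones(2 * n)
            res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=[(0, 1)] * (n * n))
            P = res.x.reshape(n, n)                   # optimal vertex: a permutation matrix
            return P @ theta                          # theta sorted in descending order

        print(sort_via_lp(np.array([0.3, 1.7, -0.2, 0.9])))  # approx [ 1.7  0.9  0.3 -0.2]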

    • eru 1523 days ago
      We did sorting-via-linear-programming as a simple exercise when I was studying maths in the mid 2000s. It was certainly not seen as a research grade problem.

      In general, if memory serves, linear programming can solve every problem that's in P with some suitable linear-time preprocessing. See https://en.wikipedia.org/wiki/P-complete for some background.

      Interior point methods for solving linear programs should then give you the bridge to continuous domains.

    • Mmrnmhrm 1523 days ago
      While the roots have been known for a long time, my impression is that the key paper that started this line of thought was Marco Cuturi's NIPS 2013 paper "Sinkhorn Distances", which is, IMHO, a very nice read.
      • kxyvr 1523 days ago
        Certainly I may be missing something, but it seems like the advance in this series of papers is that they figured out a way to calculate a differentiable solution to the sorting problem quickly, whereas it was already known that a differentiable solution existed, no?
    • corebit 1514 days ago
      Ugh, I tried to read the Stack Exchange answer; it immediately devolved into gobbledygook. Incredibly frustrating.
  • GistNoesis 1523 days ago
    >"While our soft operators are several hours faster than OT, they are slower than All-pairs, despite its O(n^2)complexity. This is due the fact that, with n= 100, All-pairs is very efficient on GPUs, while our PAV implementation runs on CPU"

    The paper is interesting, but I'm not yet sure of its practical uses.

    The trick I use in practice when I need a differentiable sort is usually a pre-sort step that involves thresholding (i.e. sparsely selecting only the values greater than a certain score, usually either a constant, a fraction of the best score, or the Kth score). Then pay the quadratic price with n = 10 or 20.
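
    For reference, a minimal numpy sketch of that recipe (my own toy version of the O(n^2) "all-pairs" soft rank, not the paper's operator); tau controls how smooth the ranks are:

        import numpy as np

        def soft_rank_all_pairs(x, tau=0.1):
            # Smooth descending rank: r_i ~ 1 + #{j : x_j > x_i}, with the hard
            # comparison replaced by a sigmoid (the self-comparison contributes 0.5).
            diff = (x[None, :] - x[:, None]) / tau        # diff[i, j] = x_j - x_i
            return 0.5 + (1.0 / (1.0 + np.exp(-diff))).sum(axis=1)

        scores = np.array([0.1, 2.3, 0.05, 1.1, 0.8, 0.02, 1.9])
        survivors = scores[scores > 0.3 * scores.max()]    # threshold step: n stays small
        print(soft_rank_all_pairs(survivors))              # approx ranks: [1, 3, 4, 2]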

    I don't see when the relative rank between garbage results would really matter. When n gets bigger and you don't want to ignore bad results, quantile approximations usually suffice.

    In the applications they cite:

    The smart use (via cross-validation) of a threshold by the Huber loss in section 6.4 works better 2 out of 3 times in their own graphs.

    The other use cases are when order matters (for example in section 6.3, where the rankings are given as input). If n is low, you pay the quadratic cost; if n is high, you usually need to process samples a subset at a time for memory reasons and use comparison losses (triplet loss, ...). So this is relevant only in the sweet spot in between, if you need exact calculations.

  • etaioinshrdlu 1523 days ago
    I find this paper super cool, and it's highly unintuitive that an operation as discrete as sorting can be done entirely with smooth functions, and efficiently to boot.

    However, I must admit that I do not fully grasp the implications of this paper. Why do we really need differentiable sorting for deep learning in the first place? What new possibilities open up as a result? My best uneducated guess is that the gradients produced by differentiable sorting are more informative than those of regular piecewise sorting, which allows gradient descent to progress faster and therefore train faster. (Think about how an entire analytic function can be completely determined from any small neighborhood. Are these sorting functions analytic too?) My intuition tells me that the derivatives produced with this technique allow the optimizer to see true gradients across classes.

    Are higher order derivatives also meaningful here?

    • huac 1523 days ago
      I think this means that certain rank-based metrics, e.g. NDCG or AUC, which are very commonly used in ranking or classification, can be represented in a differentiable form. Practically, that means neural networks can easily optimize directly for those metrics, instead of (most commonly) minimizing log-likelihood and monitoring results on NDCG (as in the original formulation of RankNet). My (limited) understanding of LambdaRank is that their model empirically optimizes NDCG, but does not have a strong theoretical backing for why it should work.

      The experiments section is pretty clear in what potential applications can be, e.g. "optimizing directly for top-k classification loss" or "label ranking via soft Spearman’s rank correlation coefficient". In Google's case, there are pretty clear applications towards web search (top-k results, classification = do you click or not), and things like entity labeling (e.g. what label should we assign to a news story).
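
      To make that concrete, here is a hedged sketch of what a "soft Spearman" training loss could look like, using a toy O(n^2) soft rank as a stand-in for the paper's operator (soft_rank and the temperature tau below are my own illustrative choices, not the paper's code):

          import torch

          def soft_rank(x, tau=1.0):
              # O(n^2) smooth descending ranks (1 = largest), differentiable in x.
              diff = (x.unsqueeze(0) - x.unsqueeze(1)) / tau   # diff[i, j] = x_j - x_i
              return 0.5 + torch.sigmoid(diff).sum(dim=1)

          def soft_spearman_loss(scores, target_ranks, tau=1.0):
              # 1 minus (roughly) the Pearson correlation between the soft ranks of
              # the scores and the given target ranks; minimizing rewards agreement.
              r = soft_rank(scores, tau)
              r = (r - r.mean()) / r.std()
              t = (target_ranks - target_ranks.mean()) / target_ranks.std()
              return 1.0 - (r * t).mean()

          scores = torch.tensor([0.2, 1.4, -0.3, 0.9], requires_grad=True)
          target = torch.tensor([3.0, 1.0, 4.0, 2.0])      # ground-truth ranks, 1 = best
          loss = soft_spearman_loss(scores, target)
          loss.backward()                                   # gradients flow through the ranks
          print(loss.item(), scores.grad)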

      • pheug 1523 days ago
        > My (limited) understanding of LambdaRank is that their model empirically optimizes NDCG, but does not have a strong theoretical backing for why it should work.

        That's akin to saying that minimizing cross entropy empirically maximizes accuracy, but there's no strong theoretical backing for that either.

        LambdaRank is one way of getting a smooth differentiable approximation to NDCG by slapping a sigmoid somewhere. The paper we're discussing now offers another way. Hard to say which way would turn out to be empirically better on problems of practical significance without actually experimenting.
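
        One common construction along those lines (a hedged sketch of the general "smooth the metric with sigmoids" idea, not necessarily LambdaRank's exact formulation): keep the usual DCG formula but replace the hard rank with a sigmoid-smoothed rank, which makes the whole metric differentiable in the scores.

            import torch

            def soft_rank(scores, tau=1.0):
                # Smooth descending ranks built from pairwise sigmoids.
                diff = (scores.unsqueeze(0) - scores.unsqueeze(1)) / tau
                return 0.5 + torch.sigmoid(diff).sum(dim=1)

            def soft_ndcg(scores, relevance, tau=1.0):
                # DCG with the hard rank replaced by a smooth one; dividing by the
                # ideal DCG gives a differentiable surrogate to maximize directly.
                ranks = soft_rank(scores, tau)
                dcg = (relevance / torch.log2(ranks + 1.0)).sum()
                ideal, _ = torch.sort(relevance, descending=True)
                idcg = (ideal / torch.log2(torch.arange(1.0, len(relevance) + 1.0) + 1.0)).sum()
                return dcg / idcg

            scores = torch.tensor([0.3, 1.2, -0.5], requires_grad=True)
            relevance = torch.tensor([3.0, 1.0, 0.0])      # graded relevance labels
            (-soft_ndcg(scores, relevance)).backward()     # ascend NDCG via its negative
            print(scores.grad)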

    • bo1024 1523 days ago
      You're right, the paper's messaging is very confusing.

      What they want to do (eventually) is propose loss functions for machine learning. NOT algorithms for sorting per se.

      They consider ML models of the form f: x --> r that take features x to some permutation or list of ranks. For example in their CIFAR experiments, I think (correct me if I'm wrong) that x is an image and there are n possible labels, e.g. "dog, cat, giraffe, ...", and given x, f(x) should rank the labels from most likely to least likely.

      Now, how do we train such a model f? We use empirical risk minimization over some labeled dataset, e.g. a collection of pairs (x,y) where x is the image and y is a label.

      So we train our neural net to become the f that minimizes average Loss(f(x), y) over pairs x,y.

      But what Loss function do we use? That's what this paper is ultimately about. And they claim theirs is efficient to compute and produces good hypotheses f.
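
      As a toy illustration of that setup (my own sketch, not the paper's loss), here is ERM with a rank-aware loss: the model is penalized whenever the true label's soft rank falls outside the top k. soft_rank is a quadratic stand-in for the paper's operator, and the linear "model" is just a placeholder:

          import torch

          def soft_rank(x, tau=1.0):
              diff = (x.unsqueeze(0) - x.unsqueeze(1)) / tau   # diff[i, j] = x_j - x_i
              return 0.5 + torch.sigmoid(diff).sum(dim=1)      # smooth descending ranks

          def soft_topk_loss(scores, label, k=1, tau=1.0):
              # Zero once the true label's soft rank is inside the top k, linear otherwise.
              return torch.relu(soft_rank(scores, tau)[label] - k)

          model = torch.nn.Linear(32, 10)                  # placeholder f: features -> scores
          opt = torch.optim.SGD(model.parameters(), lr=0.1)
          x, y = torch.randn(32), torch.tensor(3)          # one fake (features, label) pair
          for _ in range(200):
              opt.zero_grad()
              soft_topk_loss(model(x), y).backward()
              opt.step()
          print(soft_rank(model(x))[y])   # the true label's soft rank should move toward 1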

    • mpoteat 1523 days ago
      I view the usefulness as expanding the class of programmatic functions we can differentiate. One use of differentiating programmatic functions is the ability to perform gradient-descent optimization on that function... which has applications in operations research, for example.
    • billconan 1523 days ago
      A few months ago, I played with a Julia differentiable programming framework, and I thought: what if I make a differentiable virtual machine and use unsorted and sorted numbers as training data? Would it learn a sorting algorithm, something similar to a deep Turing machine? My conclusion was that I can't ....
      • mpoteat 1523 days ago
        The million-dollar question is whether it's possible to construct a theory of computation where the input language itself is automatically differentiable, and where the execution semantics are as well. Perhaps a continuous spatial automaton that has been proven to be Turing complete.
      • etaioinshrdlu 1523 days ago
        I see no reason why you couldn't make a differentiable virtual machine! Getting it to learn may be quite tricky tho.
        • stev0lution 1522 days ago
          "We extend the capabilities of neural networks by coupling them to external memory resources, which they can interact with by attentional processes. The combined system is analogous to a Turing Machine or Von Neumann architecture but is differentiable end-to-end, allowing it to be efficiently trained with gradient descent. Preliminary results demonstrate that Neural Turing Machines can infer simple algorithms such as copying, sorting, and associative recall from input and output examples." https://arxiv.org/pdf/1410.5401.pdf
      • sjg007 1523 days ago
        Seems like it should work. A fully connected network would encode all permutations. Interesting.
        • sdenton4 1523 days ago
          Yeah, the trick is choosing the right one...
    • currymj 1523 days ago
      they add a regularization term to smooth things out and make the derivative exist. in a sense it’s similar to how argmax is a very discrete operation, but softmax is a good approximation that is differentiable.
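
      A tiny illustration of that analogy (not from the paper): argmax is piecewise constant, while softmax with a temperature is smooth and approaches it as the temperature goes to zero.

          import numpy as np

          def softmax(x, tau=1.0):
              e = np.exp((x - x.max()) / tau)   # shift by the max for numerical stability
              return e / e.sum()

          x = np.array([1.0, 3.0, 2.0])
          print(np.eye(3)[np.argmax(x)])        # hard: [0. 1. 0.], gradient zero/undefined
          print(softmax(x, tau=1.0))            # smooth in x: ~[0.09 0.67 0.24]
          print(softmax(x, tau=0.1))            # lower temperature: ~[0. 1. 0.]
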
  • quotemstr 1523 days ago
    Huh. I've only read the paper superficially, but it definitely looks cool. I wouldn't have thought to implement sorting by geometric projection onto an unfathomably huge polytope, then optimizing the projection by transforming it into isotonic optimization ([1], apparently?). I'm not sure my geometry-fu is strong enough to properly understand the details of the approach.

    I do have one question though: what is the resulting algorithm actually "doing" when analyzed as a conventional sorting algorithm and not a geometric operation?

    [1] https://en.wikipedia.org/wiki/Isotonic_regression

  • formalsystem 1523 days ago
    Skimming papers like this makes me think maybe we spend too much time computing stuff in discrete spaces vs continuous ones.

    This textbook covers CS theory using real numbers instead of integers.

    https://www.amazon.com/Complexity-Real-Computation-Lenore-Bl...

    • nestorD 1523 days ago
      With the strong caveat that floating point numbers are technically discrete.

      Anecdotal consequence: 16-bit floating point numbers have so much non-linearity in their round-off error that you can use them to build neural networks with no activation functions (which are traditionally needed to introduce non-linearity).
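
      A quick sanity check of that claim's premise (my own snippet, not from the cited work): casting to float16 and back is itself a nonlinear map, since rounding breaks additivity, so a cast sandwiched between two linear layers is not equivalent to a single linear layer.

          import numpy as np

          def round16(x):
              # Simulate a float16 pass: cast down to half precision and back up.
              return np.float32(np.float16(x))

          a, b = np.float32(2048.0), np.float32(1.0)
          print(round16(a + b))            # 2048.0 -- the +1 is rounded away
          print(round16(a) + round16(b))   # 2049.0 -- so round16 is not additive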

      • GregarianChild 1523 days ago
        This is very interesting. I have been wondering about this for years. I have asked neural network people whether it is possible to build a neural network from the non-linearity induced by rounding. I have never received a convincing answer. Would you be able to point me towards a write-up of this fact? And why is this not used in practice?

        I imagine that how well this works also strongly depends on the kinds of rounding. I imagine that stochastic rounding, or the rounding used in Google's bfloat16 are different in this regard in comparison with standard IEEE floating point rounding.

        • unlikelymordant 1523 days ago
        • nestorD 1523 days ago
          Someone else gave a link to an OpenAI blog post; I personally first heard about it in a post by Facebook (on their experiments with low precision).

          I believe it is not used because it requires 16-bit precision which, nowadays, you only get on GPU. People usually train on GPU but then evaluate on CPU (in production), where the discontinuity would be much smaller (as you would use 32-bit precision).

          Furthermore I don't know if, in practice, that type of discontinuity trains as well as a classical activation function (the gradient propagation might be hindered by the limited precision).

        • formalsystem 1523 days ago
          See https://en.wikipedia.org/wiki/Analog_computer

          AFAIK analog computers are still the standard in radars, for example, and it sounds like neural networks would benefit from similar hardware.

    • gumby 1523 days ago
      > Skimming papers like this make me think maybe we spend too much time computing stuff in discrete spaces vs continuous ones.

      I would go one step further and argue that we shouldn't teach kids discrete math first, but rather continuous math instead.

      Sure, you have discrete digits and toys, but Piaget (and his student Papert) observed that kids begin pouring water between different containers in the bath before they can do integer counting, and from that they develop an understanding that objects of different shapes can have the same volume, along with concepts of partial filling, ratios, etc.

      The human scale world is continuous more than it is discrete.

      • threatofrain 1523 days ago
        Current K-12 curriculum is, from a perspective, all about preparing kids for 3 years of calculus. From this pedagogical perspective, discrete narratives are there to be a stepping stone into continuous narratives. Some people like Gilbert Strang believe that there's way too much emphasis on calculus and not enough on algebra.
      • eru 1523 days ago
        Humans have some intuitions for both discrete and continuous domains.

        (Euclidean) geometry is an interesting case. It has discrete arrangements that you can vary continuously.

        Of course, there are also areas of math without anything resembling numbers or the discrete-vs-continuous distinction.

    • versteegen 1523 days ago
      I would like to suggest "Concrete Mathematics: A Foundation for Computer Science", by Ronald Graham, Donald Knuth, and Oren Patashnik, 1994. "A blend of CONtinuous and disCRETE mathematics."

      "A textbook that is widely used in computer-science departments as a substantive but light-hearted treatment of the analysis of algorithms" --Wikipedia

  • cs702 1523 days ago
    Based on an initial read, this looks like a significant breakthrough to me.

    I wonder if the techniques developed by the authors could make it feasible to take other piecewise linear/constant algorithms (which until now have been considered "non-differentiable" for practical purposes) and turn them into differentiable algorithms.

    Think beyond sorting and ranking.

    • currymj 1523 days ago
      it’s not completely unprecedented because there were other ways of getting equivalent results before (the Sinkhorn based optimal transport approach cited, for one), which have been used for all kinds of interesting tasks. the contribution is that it does so more efficiently.
      • cs702 1523 days ago
        Agree. That's what I meant when I wrote "for practical purposes" above... although in hindsight I could have articulated it better. Thanks!
    • RrLbDpMo 1521 days ago
      There's this work from the same team (posted the same day), tackling more general problems with a different method:

      https://arxiv.org/abs/2002.08676

  • jeremysalwen 1523 days ago
    I have always thought of the LambdaRank objective (https://www.microsoft.com/en-us/research/publication/from-ra...) as mapping the scores to a probability distribution over rankings, or as they call it "projections onto the permutahedron".
  • goldenkey 1523 days ago
    This is immensely useful for my ongoing attempt to produce a compact, lookup-table-free perfect hash generator for big datasets using ML.

    Naively one might think: why not just use a standard loss, a point-to-point metric like mean squared error? But this is deeply flawed, because it requires assigning each sample to a specific natural number, effectively shrinking the solution space by a factor of n!. In practice, believe me, I have tried; the network never converges, because the mapping is entirely arbitrary and has nothing to do with the samples.

    To remedy this, we need an innovation in loss functions / mathematics. The loss function for the output of the net needs to be a set function [1]. This set function should measure the distance between the set of outputs of the net and the set {1, 2, ..., n-1, n}. This is different from KL and all the other standard loss metrics, because we do not care about the point-to-point mappings, and we have no ability to histogram or compute the probability distribution, since those are non-differentiable operations.

    On sorting: tensorflow has a differentiable sort, but it is a hack: it simply propagates the loss backwards to the position the original data was in before it ended up in its sorted position. This loss of dist(sort(Y_pred), [1, n]) provides better results but still fails for large datasets due to the fakeness of the sort derivative.

    I have a hunch that there is a mathematical way to uniquely measure some arithmetic quality to optimize for, one that is maximized when the output set is the discrete uniform distribution.

    The mean, standard deviation, and other statistical measures are terrible identifiers; actually, via statistical theory, we would need n moments for a dataset of size n to uniquely identify the distribution... so scratch those off the list [2].

    So there are two ways to achieve this milestone in ML:

    1) a truly differentiable distance metric between two sets d(S,T)

    2) a differentiable measure of ideal dispersion / density that forces the output set S to converge to the discrete uniform distribution (this is more problem specific to perfect hashes.)

    Perhaps this sort is the key to doing #1 generically, so we can have a new type of NN based on the output set instead of the specific points. It is late here, but I am excited to hear feedback.

    [1] https://en.wikipedia.org/wiki/Set_function

    [2] https://en.wikipedia.org/wiki/Hausdorff_moment_problem
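
    For what it's worth, here is a hedged sketch of option (1) for 1-D sets: softly sort the network's outputs and compare them elementwise to the sorted target set (here {1, ..., n}). The soft_rank / soft_sort_desc helpers below are my own toy O(n^2) constructions, not the paper's projection-based operators:

        import torch

        def soft_rank(x, tau=0.1):
            diff = (x.unsqueeze(0) - x.unsqueeze(1)) / tau      # diff[i, j] = x_j - x_i
            return 0.5 + torch.sigmoid(diff).sum(dim=1)         # smooth descending ranks

        def soft_sort_desc(x, tau=0.1):
            # Row i of W softly selects the element whose soft rank is closest to i+1,
            # so W @ x is an approximately descending-sorted, differentiable copy of x.
            r = soft_rank(x, tau)
            positions = torch.arange(1, len(x) + 1, dtype=x.dtype)
            W = torch.softmax(-(r.unsqueeze(0) - positions.unsqueeze(1)) ** 2 / tau, dim=1)
            return W @ x

        def set_distance(pred, target):
            # Order-free comparison: both sides are sorted before the elementwise MSE.
            return ((soft_sort_desc(pred) - torch.sort(target, descending=True).values) ** 2).mean()

        pred = torch.tensor([2.7, 0.9, 2.1, 4.2], requires_grad=True)
        target = torch.tensor([1.0, 2.0, 3.0, 4.0])             # the set {1..n}; order irrelevant
        set_distance(pred, target).backward()
        print(pred.grad)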

    • bionhoward 1523 days ago
      is the loss you're talking about like the fused Gromov-Wasserstein distance?

      we hit permutation-invariance issues like what you're talking about in some atomistic simulations, because the atoms need to be permutable if you want to use the same model for chemistry as for protein folding/docking. The FGW algo from e.g. https://arxiv.org/pdf/1811.02834.pdf https://tvayer.github.io/materials/Titouan_Marseille_2019.pd... relaxes the invariance issue by adding a feature distance to the euclidean distance.

      higher-order distance matrices are a neat trick, but they blow up VRAM past 10-50k atoms; if you did it in mixed precision with newer GPUs it could scale damn far. The problem is, the distance between distance matrices assumes the target and source items are matched, so you get into iterative-closest-point alignment, and pretty soon you're just reinventing RMSD

      it would be cool for molecular stuff to have fast permutation-invariant set-based loss functions using transport theory, but this might be better handled with a model-free approach (just let the AI figure out the loss function itself)

  • salty_biscuits 1523 days ago
    Reminds me of an older result by Brockett, which I always found really interesting, about using dynamical systems to solve these types of problems:

    https://ieeexplore.ieee.org/document/194420

  • rsp1984 1523 days ago
    I find this super interesting but unfortunately don't have enough background in Deep Learning to recognise what this would be used for. To train a model that knows how to sort stuff (probably not)? Would someone have mercy and ELI5?
    • nestorD 1523 days ago
      Deep learning models require that all of their components are differentiable in order to fit their parameters.

      This means that most building blocks for neural networks are basic linear algebra and not much else (I am simplifying, nowadays we have access to a surprisingly large array of operations).

      This paper gives you two new building blocks: a sorting function and a ranking function. The ranking function might have direct applications for recommender systems.

  • brianpgordon 1523 days ago
    I feel like this is over my head, but if the problem is that sorting a vector produces non-differentiable kinks in the output then why not just run a simple polynomial regression over it and differentiate that?
  • breatheoften 1523 days ago
    Just skimmed the abstract — on a practical level — can this be used to better train a global ranking function given subsets of example ranked data?
  • lala26in 1523 days ago
    Since when are Facebook posts becoming citations? LOL!