
RFC: add topk and / or argpartition #629

Open
@ogrisel

numpy provides an indirect way to compute the indices of the smallest (or largest) values of an array using numpy.argpartition.
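Here is a minimal sketch of that indirect approach, assuming a 1-D array and 1 <= k <= len(x). The helper names `arg_smallest_k` and `arg_largest_k` are hypothetical and are not part of numpy or of the array API standard. A single `np.argpartition` call is enough, but the selected indices come back in arbitrary order:

```python
import numpy as np


def arg_smallest_k(x, k):
    """Indices of the k smallest values of a 1-D array, in arbitrary order.

    Hypothetical helper; assumes 1 <= k <= len(x).
    """
    # With kth=k-1, the first k slots hold the indices of the k smallest values.
    return np.argpartition(x, k - 1)[:k]


def arg_largest_k(x, k):
    """Indices of the k largest values of a 1-D array, in arbitrary order.

    Hypothetical helper; assumes 1 <= k <= len(x).
    """
    # A negative kth counts from the end, so the last k slots hold the k largest.
    return np.argpartition(x, -k)[-k:]


x = np.array([7.0, 2.0, 9.0, 1.0, 5.0, 3.0])
print(sorted(x[arg_smallest_k(x, 3)].tolist()))  # [1.0, 2.0, 3.0]
print(sorted(x[arg_largest_k(x, 2)].tolist()))   # [7.0, 9.0]
```

Note that `np.argpartition` always returns an index array as large as the input, even when only k entries are kept.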

There is also a proposal to provide a higher-level API in numpy, namely (arg)topk:

That PR relies on numpy.argpartition internally, but it could probably be optimized later to avoid allocating a result array the size of the input array when k is small.
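One illustrative way to avoid that full-size index array for small k is to scan the input in blocks and keep only a running set of k candidates, so the scratch memory is O(k + chunk_size) rather than O(n). This is a sketch under that assumption, not necessarily what the numpy PR would do; `arg_smallest_k_chunked` and its `chunk_size` parameter are hypothetical:

```python
import numpy as np


def arg_smallest_k_chunked(x, k, chunk_size=4096):
    """Indices of the k smallest values of 1-D x, in arbitrary order.

    Hypothetical helper: scratch arrays are at most k + chunk_size long,
    instead of one index array as long as x.
    """
    x = np.asarray(x)
    n = x.shape[0]
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= len(x)")
    best_idx = np.empty(0, dtype=np.intp)
    for start in range(0, n, chunk_size):
        # Candidates: the current best (at most k) plus the next block of indices.
        block = np.arange(start, min(start + chunk_size, n), dtype=np.intp)
        cand = np.concatenate([best_idx, block])
        if cand.size <= k:
            # Not enough candidates yet to discard anything.
            best_idx = cand
            continue
        # Keep only the k smallest candidates seen so far.
        keep = np.argpartition(x[cand], k - 1)[:k]
        best_idx = cand[keep]
    return best_idx
```

In pure numpy this trades one large allocation for a Python-level loop, so it only pays off when the input is large and k is much smaller than `chunk_size`; a native implementation could apply the same idea without the loop overhead.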

Here is a quick review of some available implementations in related libraries:

  • torch.topk (there is no torch.argpartition)
    • returns a tuple of values and indices
  • jax.lax.top_k
    • returns a tuple of values and indices
    • apparently it is quite slow on GPU
  • dask.array.topk
    • returns only the values; I did not find a way to get the indices :( (dask.array.argtopk appears to return them, though)
  • cupy.argpartition exists, but it internally computes a full cupy.argsort, which makes it very inefficient for large arrays and small k: O(n log n) instead of O(n).
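To make the comparison concrete, here is a sketch of a `topk(x, k, axis=-1, largest=True)` helper that mimics the `(values, indices)` return convention of torch.topk and jax.lax.top_k on top of numpy.argpartition. The name and signature are hypothetical assumptions for illustration, not a proposed standard API. It performs an O(n) selection along `axis`, then sorts only the k selected entries:

```python
import numpy as np


def topk(x, k, axis=-1, largest=True):
    """Hypothetical (values, indices) top-k helper built on np.argpartition.

    Results are sorted: descending if largest=True, ascending otherwise.
    """
    x = np.asarray(x)
    n = x.shape[axis]
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= x.shape[axis]")
    # Negation turns "largest" into "smallest"; only valid for signed integer
    # and floating-point dtypes in this sketch.
    key = -x if largest else x
    # O(n) selection: the first k slots along `axis` hold the k best positions.
    part = np.argpartition(key, k - 1, axis=axis)
    part = np.take(part, np.arange(k), axis=axis)
    # Sort only the k selected entries: O(k log k).
    order = np.argsort(np.take_along_axis(key, part, axis=axis), axis=axis)
    indices = np.take_along_axis(part, order, axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices
```

For example, `topk(np.array([3, 1, 4, 1, 5]), 2)` gives values `[5, 4]` and indices `[4, 2]`. Unsigned integer and boolean inputs would need a different ordering key than negation.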

Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs the fastest implementations are based on some kind of partial radix sort, while CPU implementations typically use more traditional partial sorting and selection algorithms (as implemented in std::partial_sort or std::nth_element).
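For reference on the CPU side, std::nth_element is commonly implemented with introselect (quickselect with a worst-case fallback). The sketch below is a plain randomized quickselect, written in pure Python for readability rather than speed; the function name is hypothetical. It returns the indices of the k smallest values in expected O(n) time without fully sorting them:

```python
import random


def quickselect_k_smallest_indices(values, k, rng=None):
    """Indices of the k smallest entries of `values`, in arbitrary order.

    Randomized quickselect with a three-way split. Assumes totally ordered
    values (e.g. no NaN). Builds new lists instead of partitioning in place.
    """
    if not 0 <= k <= len(values):
        raise ValueError("k must satisfy 0 <= k <= len(values)")
    rng = rng if rng is not None else random.Random(0)
    candidates = list(range(len(values)))
    result = []
    while k > 0:
        pivot = values[rng.choice(candidates)]
        lower = [i for i in candidates if values[i] < pivot]
        equal = [i for i in candidates if values[i] == pivot]
        if k < len(lower):
            # All k answers lie strictly below the pivot.
            candidates = lower
        elif k <= len(lower) + len(equal):
            # Take everything below the pivot, then fill up with pivot ties.
            result.extend(lower)
            result.extend(equal[: k - len(lower)])
            k = 0
        else:
            # Everything up to the pivot is in; keep searching above it.
            result.extend(lower)
            result.extend(equal)
            k -= len(lower) + len(equal)
            candidates = [i for i in candidates if values[i] > pivot]
    return result


values = [7, 2, 9, 1, 5, 3, 8]
idx = quickselect_k_smallest_indices(values, 3)
print(sorted(values[i] for i in idx))  # [1, 2, 3]
```

Sorting the k returned indices by value afterwards costs O(k log k), which yields the ordered result that std::partial_sort provides.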

Labels: API extension, Needs Discussion, RFC
