Description
NumPy provides an indirect way to compute the indices of the smallest (or largest) values of an array: numpy.argpartition.
There is also a proposal to provide a higher level API, namely (arg)topk in numpy:
This PR relies on numpy.argpartition internally, but it can probably later be optimized to avoid allocating a result array of the size of the input array when k is small.
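To make the "indirect" route concrete, here is a minimal sketch of how an argtopk for 1-D arrays could be built on numpy.argpartition. The function name `argtopk` and the `largest` parameter are hypothetical and chosen only for illustration; this is not the code from the NumPy proposal.

```python
import numpy as np


def argtopk(x, k, largest=True):
    """Hypothetical helper: indices of the k largest (or smallest) values
    of a 1-D array, ordered from best to worst."""
    x = np.asarray(x)
    n = x.shape[0]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= len(x)")
    if largest:
        # Partitioning at position n - k puts the k largest values
        # (in arbitrary order) into the last k slots.
        part = np.argpartition(x, n - k)[-k:]
        order = np.argsort(x[part])[::-1]
    else:
        # Partitioning at position k - 1 puts the k smallest values
        # (in arbitrary order) into the first k slots.
        part = np.argpartition(x, k - 1)[:k]
        order = np.argsort(x[part])
    # Only the k selected entries are sorted, which costs O(k log k).
    return part[order]
```

Note that numpy.argpartition still returns an index array as large as the input, which is the allocation mentioned above; the slice only discards most of it afterwards.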
Here is a quick review of some available implementations in related libraries:
- torch.topk (there is no torch.argpartition)
  - returns a tuple of values and indices
- jax.lax.top_k
  - returns a tuple of values and indices
  - apparently it is quite slow on GPU
- dask.array.topk
  - returns only the values; I did not find a way to get the indices :(
- cupy.argpartition exists, but internally it computes a full cupy.argsort, which makes it very inefficient for large arrays and small k: O(n log(n)) instead of O(n).
Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs the fastest implementations are based on some kind of partial radix sort, while CPU implementations would use more traditional partial sorting algorithms (as implemented in std::partial_sort or std::nth_element).
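As an illustration of the k-nearest neighbors use case, the following sketch does a brute-force neighbor search with numpy.argpartition applied along an axis. The function name `knn_indices` and its arguments are hypothetical; this is not how scikit-learn implements it, only an example of where a top-k primitive would fit.

```python
import numpy as np


def knn_indices(X_train, X_query, k):
    """Hypothetical helper: for each query row, indices of the k closest
    training rows (squared Euclidean distance), nearest first."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    # Pairwise squared distances, shape (n_query, n_train).
    diff = X_query[:, None, :] - X_train[None, :, :]
    dist2 = np.einsum("qtd,qtd->qt", diff, diff)
    # Select the k smallest distances per row without a full sort.
    part = np.argpartition(dist2, k - 1, axis=1)[:, :k]
    # Order only the k selected candidates by distance.
    part_dist = np.take_along_axis(dist2, part, axis=1)
    order = np.argsort(part_dist, axis=1)
    return np.take_along_axis(part, order, axis=1)
```

With a dedicated argtopk, the partition, gather, and sort steps would collapse into a single call.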