-
Notifications
You must be signed in to change notification settings - Fork 54
feat: add API specification for returning the k largest elements
#722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a2e33f9
30900eb
e5d3189
76873d8
07e62e9
efb985d
96461fc
c72d334
54f1396
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,4 +25,5 @@ Objects in API | |
| count_nonzero | ||
| nonzero | ||
| searchsorted | ||
| top_k | ||
| where | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,15 @@ | ||
| __all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"] | ||
| __all__ = [ | ||
| "argmax", | ||
| "argmin", | ||
| "count_nonzero", | ||
| "nonzero", | ||
| "searchsorted", | ||
| "top_k", | ||
| "where", | ||
| ] | ||
|
|
||
|
|
||
| from ._types import Optional, Tuple, Literal, Union, array | ||
| from ._types import Optional, Literal, Tuple, Union, array | ||
|
|
||
|
|
||
| def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: | ||
|
|
@@ -177,6 +185,50 @@ def searchsorted( | |
| """ | ||
|
|
||
|
|
||
| def top_k( | ||
| x: array, | ||
| k: int, | ||
| /, | ||
| *, | ||
| axis: Optional[int] = None, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd second Olivier's comment (#722 (comment)):
As a user, I would definitely prefer to flatten myself, as opposed to getting a 1D array as an output for an nD input. The "axis=None means ravel" default IMO makes sense for reductions which return a scalar: "give me the sum of all elements of this array which happens to be nD". Returning a ravelled array for an nD input is not intuitive, unexpected, and I don't think it has much precedent even in NumPy? If anything, the default for And in the Array API spec,
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that the default should be -1, not None. This is what all array libraries do for their implementation, and A code search shows that it's also somewhat regularly used in practice for |
||
| mode: Literal["largest", "smallest"] = "largest", | ||
|
rgommers marked this conversation as resolved.
|
||
| ) -> Tuple[array, array]: | ||
| """ | ||
| Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x: array | ||
| input array. Should have a real-valued data type. | ||
| k: int | ||
| number of elements to find. Must be a positive integer value. | ||
| axis: Optional[int] | ||
| axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. | ||
| mode: Literal['largest', 'smallest'] | ||
| search mode. Must be one of the following modes: | ||
|
|
||
| - ``'largest'``: return the ``k`` largest elements. | ||
| - ``'smallest'``: return the ``k`` smallest elements. | ||
|
|
||
| Default: ``'largest'``. | ||
|
|
||
| Returns | ||
| ------- | ||
| out: Tuple[array, array] | ||
| a namedtuple ``(values, indices)`` whose | ||
|
|
||
| - first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. | ||
| - second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. | ||
|
|
||
| Notes | ||
| ----- | ||
|
|
||
| - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. | ||
| - The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. | ||
| - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). | ||
| """ | ||
|
|
||
|
|
||
| def where( | ||
| condition: array, | ||
| x1: Union[array, int, float, complex, bool], | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.