Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,12 @@ def searchsorted(a, v, side="left", sorter=None):
"""

usm_a = dpnp.get_usm_ndarray(a)
if dpnp.isscalar(v):
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
else:
usm_v = dpnp.get_usm_ndarray(v)
if not dpnp.isscalar(v):
v = dpnp.get_usm_ndarray(v)

usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
return dpnp_array._create_from_usm_ndarray(
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
dpt.searchsorted(usm_a, v, side=side, sorter=usm_sorter)
)


Expand Down
97 changes: 58 additions & 39 deletions dpnp/tensor/_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,42 @@
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************


from typing import Literal, Union
from typing import Literal

import dpctl
import dpctl.utils as du

import dpnp.tensor as dpt

from ._compute_follows_data import (
ExecutionPlacementError,
get_coerced_usm_type,
get_execution_queue,
)
from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._ctors import empty_like
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import isdtype, result_type
from ._type_utils import (
_resolve_weak_types_all_py_ints,
_to_device_supported_dtype,
isdtype,
)
from ._usmarray import usm_ndarray


def searchsorted(
x1: usm_ndarray,
x2: usm_ndarray,
x2: usm_ndarray | int | float | complex | bool,
/,
*,
side: Literal["left", "right"] = "left",
sorter: Union[usm_ndarray, None] = None,
sorter: usm_ndarray | None = None,
) -> usm_ndarray:
"""searchsorted(x1, x2, side='left', sorter=None)

Expand All @@ -68,8 +74,8 @@ def searchsorted(
input array. Must be a one-dimensional array. If `sorter` is
`None`, must be sorted in ascending order; otherwise, `sorter` must
be an array of indices that sort `x1` in ascending order.
x2 (usm_ndarray):
array containing search values.
x2 (usm_ndarray | int | float | complex | bool):
search value or values.
side (Literal["left", "right"]):
argument controlling which index is returned if a value lands
exactly on an edge. If `x2` is an array of rank `N` where
Expand All @@ -85,13 +91,11 @@ def searchsorted(
array of indices that sort `x1` in ascending order. The array must
have the same shape as `x1` and have an integral data type.
Out of bound index values of `sorter` array are treated using
`"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
`"wrap"` mode documented in :py:func:`dpnp.tensor.take`.
Default: `None`.
"""
if not isinstance(x1, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x1)}")
if not isinstance(x2, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x2)}")
if sorter is not None and not isinstance(sorter, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(sorter)}")

Expand All @@ -101,23 +105,39 @@ def searchsorted(
"Expected either 'left' or 'right'"
)

if sorter is None:
q = get_execution_queue([x1.sycl_queue, x2.sycl_queue])
else:
q = get_execution_queue(
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
)
q1, x1_usm_type = x1.sycl_queue, x1.usm_type
q2, x2_usm_type = _get_queue_usm_type(x2)
q3 = sorter.sycl_queue if sorter is not None else None
q = get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None))
if q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously "
"inferred from input arguments."
)

res_usm_type = get_coerced_usm_type(
tuple(
ut
for ut in (
x1_usm_type,
x2_usm_type,
)
if ut is not None
)
)
dpt.validate_usm_type(res_usm_type, allow_none=False)
sycl_dev = q.sycl_device

if x1.ndim != 1:
raise ValueError("First argument array must be one-dimensional")

x1_dt = x1.dtype
x2_dt = x2.dtype
x2_dt = _get_dtype(x2, sycl_dev)
if not _validate_dtype(x2_dt):
raise ValueError(
"dpt.searchsorted search value argument has "
f"unsupported data type {x2_dt}"
)

_manager = du.SequentialOrderManager[q]
dep_evs = _manager.submitted_events
Expand All @@ -132,7 +152,7 @@ def searchsorted(
"Sorter array must be one-dimension with the same "
"shape as the first argument array"
)
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
res = empty_like(x1)
ind = (sorter,)
axis = 0
wrap_out_of_bound_indices_mode = 0
Expand All @@ -148,29 +168,28 @@ def searchsorted(
x1 = res
_manager.add_event_pair(ht_ev, ev)

if x1_dt != x2_dt:
dt = result_type(x1, x2)
if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf
if x2_dt != dt:
x2_buf = _empty_like_orderK(x2, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf
dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev)
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)

# get submitted events again in case some were added by sorter handling
dep_evs = _manager.submitted_events
if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf

if not isinstance(x2, usm_ndarray):
x2 = dpt.asarray(x2, dtype=dt2, usm_type=res_usm_type, sycl_queue=q)
elif x2_dt != dt:
x2_buf = _empty_like_orderK(x2, dt)
ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs)
_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf

dst_usm_type = get_coerced_usm_type([x1.usm_type, x2.usm_type])
index_dt = ti_default_device_index_type(q)

dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type)

dep_evs = _manager.submitted_events
if side == "left":
Expand Down
108 changes: 63 additions & 45 deletions dpnp/tests/tensor/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

import ctypes

import dpctl
import numpy as np
import pytest
Expand All @@ -37,6 +39,30 @@
skip_if_dtype_not_supported,
)

# Data type specifier strings used to parametrize the tests in this module,
# grouped by kind. Signed/unsigned integer pairs are listed by increasing
# item size.
_integer_dtypes = ["i1", "u1", "i2", "u2", "i4", "u4", "i8", "u8"]

_floating_dtypes = ["f2", "f4", "f8"]

_complex_dtypes = ["c8", "c16"]

# Boolean first, followed by every integer, floating and complex specifier.
_all_dtypes = ["?", *_integer_dtypes, *_floating_dtypes, *_complex_dtypes]


def _check(hay_stack, needles, needles_np):
assert hay_stack.dtype == needles.dtype
Expand Down Expand Up @@ -103,19 +129,7 @@ def test_searchsorted_strided_bool():
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_contig_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -135,19 +149,7 @@ def test_searchsorted_contig_int(idt):
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_strided_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -174,12 +176,12 @@ def _add_extended_fp(array):
array[-1] = dpt.nan


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_contig_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_contig_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
_add_extended_fp(hay_stack)
Expand All @@ -195,12 +197,12 @@ def test_searchsorted_contig_fp(idt):
)


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_strided_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_strided_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -243,12 +245,12 @@ def _add_extended_cfp(array):
return dpt.sort(dpt.concat((ev, array)))


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_contig_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_contig_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
hay_stack = _add_extended_cfp(hay_stack)
Expand All @@ -263,12 +265,12 @@ def test_searchsorted_contig_cfp(idt):
)


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_strided_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_strided_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -315,7 +317,7 @@ def test_searchsorted_validation():
x1 = dpt.arange(10, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("Default device could not be created")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.searchsorted(x1, None)
with pytest.raises(TypeError):
dpt.searchsorted(x1, x1, sorter=dict())
Expand All @@ -333,10 +335,10 @@ def test_searchsorted_validation2():
q2 = dpctl.SyclQueue(d, property="in_order")
x2 = dpt.ones(5, dtype=x1.dtype, sycl_queue=q2)

with pytest.raises(dpt.ExecutionPlacementError):
with pytest.raises(dpu.ExecutionPlacementError):
dpt.searchsorted(x1, x2)

with pytest.raises(dpt.ExecutionPlacementError):
with pytest.raises(dpu.ExecutionPlacementError):
dpt.searchsorted(x1, x2, sorter=sorter)

sorter = dpt.ones(x1.shape, dtype=dpt.bool)
Expand Down Expand Up @@ -405,3 +407,19 @@ def test_searchsorted_strided_scalar_needle():
needles = dpt.asarray(needles_np)

_check(hay_stack, needles, needles_np)


@pytest.mark.parametrize(
    "py_zero",
    [bool(0), int(0), float(0), complex(0), np.float32(0), ctypes.c_int(0)],
)
@pytest.mark.parametrize("dt", _all_dtypes)
def test_searchsorted_py_scalars(py_zero, dt):
    """Check that a scalar needle yields a zero-dimensional usm_ndarray.

    Each scalar zero in ``py_zero`` is searched for in an all-zeros hay
    stack of every data type in ``_all_dtypes``; only the type and shape of
    the result are verified, not the returned index.
    """
    queue = get_queue_or_skip()
    skip_if_dtype_not_supported(dt, queue)

    hay_stack = dpt.zeros(10, dtype=dt, sycl_queue=queue)
    found = dpt.searchsorted(hay_stack, py_zero)

    assert isinstance(found, dpt.usm_ndarray)
    # a scalar needle must produce a 0-d result
    assert found.shape == ()
Loading