diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 856fdbc9893..295d568dd9c 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -374,14 +374,12 @@ def searchsorted(a, v, side="left", sorter=None): """ usm_a = dpnp.get_usm_ndarray(a) - if dpnp.isscalar(v): - usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type) - else: - usm_v = dpnp.get_usm_ndarray(v) + if not dpnp.isscalar(v): + v = dpnp.get_usm_ndarray(v) usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter) return dpnp_array._create_from_usm_ndarray( - dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter) + dpt.searchsorted(usm_a, v, side=side, sorter=usm_sorter) ) diff --git a/dpnp/tensor/_searchsorted.py b/dpnp/tensor/_searchsorted.py index 4c9b54cb63f..f74e4902065 100644 --- a/dpnp/tensor/_searchsorted.py +++ b/dpnp/tensor/_searchsorted.py @@ -26,36 +26,42 @@ # THE POSSIBILITY OF SUCH DAMAGE. # ***************************************************************************** - -from typing import Literal, Union +from typing import Literal import dpctl import dpctl.utils as du +import dpnp.tensor as dpt + from ._compute_follows_data import ( ExecutionPlacementError, get_coerced_usm_type, get_execution_queue, ) from ._copy_utils import _empty_like_orderK -from ._ctors import empty +from ._ctors import empty_like +from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy from ._tensor_impl import _take as ti_take from ._tensor_impl import ( default_device_index_type as ti_default_device_index_type, ) from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right -from ._type_utils import isdtype, result_type +from ._type_utils import ( + _resolve_weak_types_all_py_ints, + _to_device_supported_dtype, + isdtype, +) from ._usmarray import usm_ndarray def searchsorted( x1: usm_ndarray, - x2: usm_ndarray, + x2: usm_ndarray | int | float | complex | bool, /, *, side: Literal["left", "right"] = "left", - sorter: Union[usm_ndarray, None] = None, + sorter: usm_ndarray | None = None, ) -> usm_ndarray: """searchsorted(x1, x2, side='left', sorter=None) @@ -68,8 +74,8 @@ def searchsorted( input array. Must be a one-dimensional array. If `sorter` is `None`, must be sorted in ascending order; otherwise, `sorter` must be an array of indices that sort `x1` in ascending order. - x2 (usm_ndarray): - array containing search values. + x2 (usm_ndarray | int | float | complex | bool): + search value or values. side (Literal["left", "right]): argument controlling which index is returned if a value lands exactly on an edge. If `x2` is an array of rank `N` where @@ -85,13 +91,11 @@ def searchsorted( array of indices that sort `x1` in ascending order. The array must have the same shape as `x1` and have an integral data type. Out of bound index values of `sorter` array are treated using - `"wrap"` mode documented in :py:func:`dpctl.tensor.take`. + `"wrap"` mode documented in :py:func:`dpnp.tensor.take`. Default: `None`. """ if not isinstance(x1, usm_ndarray): raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x1)}") - if not isinstance(x2, usm_ndarray): - raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x2)}") if sorter is not None and not isinstance(sorter, usm_ndarray): raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(sorter)}") @@ -101,23 +105,39 @@ def searchsorted( "Expected either 'left' or 'right'" ) - if sorter is None: - q = get_execution_queue([x1.sycl_queue, x2.sycl_queue]) - else: - q = get_execution_queue( - [x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue] - ) + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = _get_queue_usm_type(x2) + q3 = sorter.sycl_queue if sorter is not None else None + q = get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None)) if q is None: raise ExecutionPlacementError( "Execution placement can not be unambiguously " "inferred from input arguments." ) + res_usm_type = get_coerced_usm_type( + tuple( + ut + for ut in ( + x1_usm_type, + x2_usm_type, + ) + if ut is not None + ) + ) + dpt.validate_usm_type(res_usm_type, allow_none=False) + sycl_dev = q.sycl_device + if x1.ndim != 1: raise ValueError("First argument array must be one-dimensional") x1_dt = x1.dtype - x2_dt = x2.dtype + x2_dt = _get_dtype(x2, sycl_dev) + if not _validate_dtype(x2_dt): + raise ValueError( + "dpt.searchsorted search value argument has " + f"unsupported data type {x2_dt}" + ) _manager = du.SequentialOrderManager[q] dep_evs = _manager.submitted_events @@ -132,7 +152,7 @@ def searchsorted( "Sorter array must be one-dimension with the same " "shape as the first argument array" ) - res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q) + res = empty_like(x1) ind = (sorter,) axis = 0 wrap_out_of_bound_indices_mode = 0 @@ -148,29 +168,28 @@ def searchsorted( x1 = res _manager.add_event_pair(ht_ev, ev) - if x1_dt != x2_dt: - dt = result_type(x1, x2) - if x1_dt != dt: - x1_buf = _empty_like_orderK(x1, dt) - dep_evs = _manager.submitted_events - ht_ev, ev = ti_copy( - src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs - ) - _manager.add_event_pair(ht_ev, ev) - x1 = x1_buf - if x2_dt != dt: - x2_buf = _empty_like_orderK(x2, dt) - dep_evs = _manager.submitted_events - ht_ev, ev = ti_copy( - src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs - ) - _manager.add_event_pair(ht_ev, ev) - x2 = x2_buf + dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev) + dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev) + + # get submitted events again in case some were added by sorter handling + dep_evs = _manager.submitted_events + if x1_dt != dt: + x1_buf = _empty_like_orderK(x1, dt) + ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs) + _manager.add_event_pair(ht_ev, ev) + x1 = x1_buf + + if not isinstance(x2, usm_ndarray): + x2 = dpt.asarray(x2, dtype=dt2, usm_type=res_usm_type, sycl_queue=q) + elif x2_dt != dt: + x2_buf = _empty_like_orderK(x2, dt) + ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs) + _manager.add_event_pair(ht_ev, ev) + x2 = x2_buf - dst_usm_type = get_coerced_usm_type([x1.usm_type, x2.usm_type]) index_dt = ti_default_device_index_type(q) - dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type) + dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type) dep_evs = _manager.submitted_events if side == "left": diff --git a/dpnp/tests/tensor/test_usm_ndarray_searchsorted.py b/dpnp/tests/tensor/test_usm_ndarray_searchsorted.py index aef782f06f0..c0290ef7572 100644 --- a/dpnp/tests/tensor/test_usm_ndarray_searchsorted.py +++ b/dpnp/tests/tensor/test_usm_ndarray_searchsorted.py @@ -26,6 +26,8 @@ # THE POSSIBILITY OF SUCH DAMAGE. # ***************************************************************************** +import ctypes + import dpctl import numpy as np import pytest @@ -37,6 +39,30 @@ skip_if_dtype_not_supported, ) +_integer_dtypes = [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", +] + +_floating_dtypes = [ + "f2", + "f4", + "f8", +] + +_complex_dtypes = [ + "c8", + "c16", +] + +_all_dtypes = ["?"] + _integer_dtypes + _floating_dtypes + _complex_dtypes + def _check(hay_stack, needles, needles_np): assert hay_stack.dtype == needles.dtype @@ -103,19 +129,7 @@ def test_searchsorted_strided_bool(): ) -@pytest.mark.parametrize( - "idt", - [ - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, - ], -) +@pytest.mark.parametrize("idt", _integer_dtypes) def test_searchsorted_contig_int(idt): q = get_queue_or_skip() skip_if_dtype_not_supported(idt, q) @@ -135,19 +149,7 @@ def test_searchsorted_contig_int(idt): ) -@pytest.mark.parametrize( - "idt", - [ - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, - ], -) +@pytest.mark.parametrize("idt", _integer_dtypes) def test_searchsorted_strided_int(idt): q = get_queue_or_skip() skip_if_dtype_not_supported(idt, q) @@ -174,12 +176,12 @@ def _add_extended_fp(array): array[-1] = dpt.nan -@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64]) -def test_searchsorted_contig_fp(idt): +@pytest.mark.parametrize("fdt", _floating_dtypes) +def test_searchsorted_contig_fp(fdt): q = get_queue_or_skip() - skip_if_dtype_not_supported(idt, q) + skip_if_dtype_not_supported(fdt, q) - dt = dpt.dtype(idt) + dt = dpt.dtype(fdt) hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True) _add_extended_fp(hay_stack) @@ -195,12 +197,12 @@ def test_searchsorted_contig_fp(idt): ) -@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64]) -def test_searchsorted_strided_fp(idt): +@pytest.mark.parametrize("fdt", _floating_dtypes) +def test_searchsorted_strided_fp(fdt): q = get_queue_or_skip() - skip_if_dtype_not_supported(idt, q) + skip_if_dtype_not_supported(fdt, q) - dt = dpt.dtype(idt) + dt = dpt.dtype(fdt) hay_stack = dpt.repeat( dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4 @@ -243,12 +245,12 @@ def _add_extended_cfp(array): return dpt.sort(dpt.concat((ev, array))) -@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128]) -def test_searchsorted_contig_cfp(idt): +@pytest.mark.parametrize("cdt", _complex_dtypes) +def test_searchsorted_contig_cfp(cdt): q = get_queue_or_skip() - skip_if_dtype_not_supported(idt, q) + skip_if_dtype_not_supported(cdt, q) - dt = dpt.dtype(idt) + dt = dpt.dtype(cdt) hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True) hay_stack = _add_extended_cfp(hay_stack) @@ -263,12 +265,12 @@ def test_searchsorted_contig_cfp(idt): ) -@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128]) -def test_searchsorted_strided_cfp(idt): +@pytest.mark.parametrize("cdt", _complex_dtypes) +def test_searchsorted_strided_cfp(cdt): q = get_queue_or_skip() - skip_if_dtype_not_supported(idt, q) + skip_if_dtype_not_supported(cdt, q) - dt = dpt.dtype(idt) + dt = dpt.dtype(cdt) hay_stack = dpt.repeat( dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4 @@ -315,7 +317,7 @@ def test_searchsorted_validation(): x1 = dpt.arange(10, dtype="i4") except dpctl.SyclDeviceCreationError: pytest.skip("Default device could not be created") - with pytest.raises(TypeError): + with pytest.raises(ValueError): dpt.searchsorted(x1, None) with pytest.raises(TypeError): dpt.searchsorted(x1, x1, sorter=dict()) @@ -333,10 +335,10 @@ def test_searchsorted_validation2(): q2 = dpctl.SyclQueue(d, property="in_order") x2 = dpt.ones(5, dtype=x1.dtype, sycl_queue=q2) - with pytest.raises(dpt.ExecutionPlacementError): + with pytest.raises(dpu.ExecutionPlacementError): dpt.searchsorted(x1, x2) - with pytest.raises(dpt.ExecutionPlacementError): + with pytest.raises(dpu.ExecutionPlacementError): dpt.searchsorted(x1, x2, sorter=sorter) sorter = dpt.ones(x1.shape, dtype=dpt.bool) @@ -405,3 +407,19 @@ def test_searchsorted_strided_scalar_needle(): needles = dpt.asarray(needles_np) _check(hay_stack, needles, needles_np) + + +@pytest.mark.parametrize( + "py_zero", + [bool(0), int(0), float(0), complex(0), np.float32(0), ctypes.c_int(0)], +) +@pytest.mark.parametrize("dt", _all_dtypes) +def test_searchsorted_py_scalars(py_zero, dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + + x = dpt.zeros(10, dtype=dt, sycl_queue=q) + + r1 = dpt.searchsorted(x, py_zero) + assert isinstance(r1, dpt.usm_ndarray) + assert r1.shape == ()