diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
index 4dd08d665da..02aa4a10d07 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
@@ -40,7 +40,7 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   std::vector<TypeId> inputs_type{};
   auto op_name = AnfAlgo::GetCNodeName(kernel_node);
   if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
-      op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName) {
+      op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName ||
+      op_name == kDynamicStitch) {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     for (size_t input_index = 0; input_index < input_num; ++input_index) {
       inputs_format.emplace_back(kOpFormat_DEFAULT);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
index 156dc9690ae..b3eaefdc4d7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
@@ -60,7 +60,10 @@
 constexpr auto kDropout2D = "Dropout2D";
 constexpr auto kDropout3D = "Dropout3D";
 constexpr auto kMaskedSelect = "MaskedSelect";
 constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
-const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad};
+constexpr auto kDynamicStitch = "DynamicStitch";
+constexpr auto kSearchSorted = "SearchSorted";
+const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
+                                                kSearchSorted};
 const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
                                             kDropout2D};
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index f17d5c8d71b..c9f5ff1a47b 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -35,6 +35,7 @@ constexpr auto kUniqueOpName = "Unique";
 constexpr auto kMaskedSelectOpName = "MaskedSelect";
 constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
 constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
+constexpr auto kDynamicStitchOpName = "DynamicStitch";
 constexpr auto kFour2FiveOpName = "Four2Five";
 constexpr auto kFive2FourOpName = "Five2Four";
 constexpr auto kConv3DOpName = "Conv3D";
@@ -560,9 +561,9 @@ const std::set<std::string> kHWSpecialFormatSet = {
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
-const std::set<std::string> kComputeDepend = {kUniqueOpName,       kComputeAccidentalHitsOpName, kSubAndFilterOpName,
-                                              kPadAndShiftOpName,  kCTCGreedyDecoderOpName,      kDropoutGenMaskOpName,
-                                              kMaskedSelectOpName};
+const std::set<std::string> kComputeDepend = {kUniqueOpName,       kComputeAccidentalHitsOpName, kSubAndFilterOpName,
+                                              kPadAndShiftOpName,  kCTCGreedyDecoderOpName,      kDropoutGenMaskOpName,
+                                              kMaskedSelectOpName, kDynamicStitchOpName};
 
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index a981a473e90..434529f1aa3 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -184,6 +184,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index cc8574d3609..930ec6d9be1 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -22,6 +22,7 @@
 #include "abstract/utils.h"
 #include "abstract/param_validator.h"
 #include "utils/shape_utils.h"
+#include "ops/op_utils.h"
 
 namespace mindspore {
 namespace abstract {
@@ -1140,5 +1141,59 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
 }
+AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInteger("input number", args_spec_list.size(), kEqual, 2, prim_name);
+  for (const auto &item : args_spec_list) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+
+  // input0: indices
+  auto input_tuple = args_spec_list[0]->cast<AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(input_tuple);
+  auto indices = input_tuple->elements();
+  auto indices0 = indices[0]->cast<AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(indices0);
+  auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
+
+  // input1: data
+  auto input_tuple_1 = args_spec_list[1]->cast<AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(input_tuple_1);
+  auto data = input_tuple_1->elements();
+  auto data0 = data[0]->cast<AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(data0);
+  auto data0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data0->BuildShape())[kShape];
+  if (indices.size() != data.size()) {
+    MS_LOG(EXCEPTION) << "The size of input[0] must be the same as the size of input[1]!";
+  }
+
+  int indices_total_size = 0;
+  std::map<std::string, TypePtr> types;
+  types.emplace("data0", data0->BuildType());
+  for (size_t i = 1; i < data.size(); ++i) {
+    auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
+    auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
+    if (indicesi_shape.size() >= datai_shape.size()) {
+      MS_LOG(EXCEPTION) << "The rank of indices[i] must be less than the rank of data[i]!";
+    }
+    indices_total_size += indicesi_shape.size();
+  }
+  std::set<TypePtr> valid_types = ops::common_valid_types;
+  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+
+  ShapeVector out_shape = {abstract::Shape::SHP_ANY};
+  for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
+    out_shape.push_back(data0_shape[i]);
+  }
+  const size_t EXPAND_MAX = 10;
+  ShapeVector min_shape = out_shape;
+  ShapeVector max_shape = out_shape;
+  min_shape[0] = 1;
+  max_shape[0] = indices_total_size * EXPAND_MAX;
+  return std::make_shared<AbstractTensor>(infer_type,
+                                          std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index b9b054f63d5..36055681330 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -96,6 +96,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
     {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
     {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
+    {prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}},
     {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
     {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
     {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 4e5d2f36381..4b8364ac244 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -278,6 +278,7 @@ inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter");
 inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize");
 inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
 inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
+inline const PrimitivePtr kPrimDynamicStitch = std::make_shared<Primitive>("DynamicStitch");
 inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
   std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
 inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py
index 41820bfa4d5..53c6b61a37b 100644
--- a/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -29,6 +29,7 @@ from .pad_and_shift import _pad_and_shift_aicpu
 from .dropout_genmask import _dropout_genmask_aicpu
 from .dropout2d import _dropout2d_aicpu
 from .dropout3d import _dropout3d_aicpu
+from .dynamic_stitch import _dynamic_stitch_aicpu
 from .get_next import _get_next_aicpu
 from .print_tensor import _print_aicpu
 from .topk import _top_k_aicpu
@@ -39,6 +40,7 @@ from .squeeze import _squeeze_aicpu
 from .expand_dims import _expand_dims_aicpu
 from .randperm import _randperm_aicpu
 from .random_choice_with_mask import _random_choice_with_mask_aicpu
+from .search_sorted import _search_sorted_aicpu
 from .stack import _stack_aicpu
 from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
 from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
diff --git a/mindspore/ops/_op_impl/aicpu/dynamic_stitch.py b/mindspore/ops/_op_impl/aicpu/dynamic_stitch.py
new file mode 100644
index 00000000000..9dc1482b21a
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/dynamic_stitch.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""DynamicStitch op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+dynamic_stitch_op_info = AiCPURegOp("DynamicStitch") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indices", "dynamic") \
+    .input(1, "data", "dynamic") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(dynamic_stitch_op_info)
+def _dynamic_stitch_aicpu():
+    """DynamicStitch AiCPU register"""
+    return
diff --git a/mindspore/ops/_op_impl/aicpu/search_sorted.py b/mindspore/ops/_op_impl/aicpu/search_sorted.py
new file mode 100644
index 00000000000..92460c3088a
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/search_sorted.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SearchSorted op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+search_sorted_op_info = AiCPURegOp("SearchSorted") \
+    .fusion_type("OPAQUE") \
+    .attr("out_int32", "bool") \
+    .attr("right", "bool") \
+    .input(0, "sequence", "required") \
+    .input(1, "values", "required") \
+    .output(0, "output", "required") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(search_sorted_op_info)
+def _search_sorted_aicpu():
+    """SearchSorted AiCPU register"""
+    return
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 7314a95ab9a..2c346a819a5 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -1022,3 +1022,92 @@ class StackDestroy(PrimitiveWithInfer):
     def __init__(self, index=1):
         """StackDestroy"""
         validator.check_value_type("index", index, [int], self.name)
+
+
+class DynamicStitch(PrimitiveWithCheck):
+    r"""
+    Interleaves the values from the `data` tensors into a single tensor.
+
+    Inputs:
+        - **indices** (Union[tuple, list]) - A tuple or list of int32 Tensors giving the positions in the output
+          for the corresponding tensors in `data`.
+        - **data** (Union[tuple, list]) - A tuple or list of Tensors with the same data type. The shape of each
+          `data[i]` must start with the shape of the corresponding `indices[i]`.
+
+    Outputs:
+        Tensor. The stitched Tensor, with the same data type as the tensors in `data`.
+
+    Raises:
+        TypeError: If any tensor in `indices` is not of type int32.
+        ValueError: If `indices` or `data` is empty, if their lengths are not equal, or if the tensors in `data`
+            do not all have the same data type.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> x1 = Tensor([6], mstype.int32)
+        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
+        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
+        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
+        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
+        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
+        >>> stitch = ops.DynamicStitch()
+        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
+        >>> print(output)
+        [[ 1  2]
+         [11 12]
+         [21 22]
+         [31 32]
+         [41 42]
+         [51 52]
+         [61 62]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize DynamicStitch"""
+
+    def check_shape(self, indices_shape, data_shape):
+        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
+        validator.check_int(len(indices_shape), 1, Rel.GE, "len of indices_shape", self.name)
+        indices_dim0 = len(indices_shape[0])
+        indices_num = len(indices_shape)
+
+        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
+        validator.check_int(len(data_shape), 1, Rel.GE, "len of data_shape", self.name)
+        data_dim0 = len(data_shape[0])
+        data_num = len(data_shape)
+
+        validator.check("size of indices", indices_num, 'size of data', data_num, Rel.EQ, self.name)
+
+        # shape of `data[i]` must start with shape of `indices[i]`
+        for i in range(0, indices_num):
+            indices_dim = len(indices_shape[i])
+            data_dim = len(data_shape[i])
+            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LT, self.name)
+            if data_shape[i][:indices_dim] != indices_shape[i]:
+                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with indices[{i}].shape: {indices_shape[i]}")
+
+        # the trailing (data_dim0 - indices_dim0) dims of every data[i] must match those of data[0]
+        base_extra = data_dim0 - indices_dim0
+        for i in range(0, data_num):
+            indices_dim = len(indices_shape[i])
+            data_dim = len(data_shape[i])
+            extra = data_dim - indices_dim
+            validator.check(f"extra dim of data[{i}]", extra,
+                            f"extra dim of data[0]", base_extra, Rel.EQ, self.name)
+            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
+                            f"data[{i}].shape[{indices_dim}:]",
+                            data_shape[i][indices_dim:], Rel.EQ, self.name)
+
+        out_shape = [-1] + data_shape[0][indices_dim0:]
+        return out_shape
+
+    def check_dtype(self, indices_type, data_type):
+        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor, self.name)
+        validator.check_subclass("data[0]", data_type[0], mstype.tensor, self.name)
+        indices_num = len(indices_type)
+        for i in range(0, indices_num):
+            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], [mstype.int32], self.name)
+            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
+                                               mstype.number_type + (mstype.bool_,), self.name)
+            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]", data_type[0], Rel.EQ, self.name)
+        return data_type[0]
diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_dynamic_stitch.py b/tests/st/ops/ascend/test_aicpu_ops/test_dynamic_stitch.py
new file mode 100644
index 00000000000..1d385b4b4e0
--- /dev/null
+++ b/tests/st/ops/ascend/test_aicpu_ops/test_dynamic_stitch.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _inner_ops as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.stitch = P.DynamicStitch()
+
+    def construct(self, indices, data):
+        return self.stitch(indices, data)
+
+
+def test_net_int32():
+    x1 = Tensor([6], mindspore.int32)
+    x2 = Tensor(np.array([4, 1]), mindspore.int32)
+    x3 = Tensor(np.array([[5, 2], [0, 3]]), mindspore.int32)
+    y1 = Tensor(np.array([[61, 62]]), mindspore.int32)
+    y2 = Tensor(np.array([[41, 42], [11, 12]]), mindspore.int32)
+    y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mindspore.int32)
+    expected = np.array([[1, 2], [11, 12], [21, 22],
+                         [31, 32], [41, 42], [51, 52], [61, 62]]).astype(np.int32)
+
+    print(x1.shape, x2.shape, x3.shape)
+    print(y1.shape, y2.shape, y3.shape)
+    indices = [x1, x2, x3]
+    data = [y1, y2, y3]
+    net = Net()
+    output = net(indices, data)
+    print(output.asnumpy())
+    assert np.array_equal(output.asnumpy(), expected)
diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_search_sorted.py b/tests/st/ops/ascend/test_aicpu_ops/test_search_sorted.py
new file mode 100644
index 00000000000..61705d0e257
--- /dev/null
+++ b/tests/st/ops/ascend/test_aicpu_ops/test_search_sorted.py
@@ -0,0 +1,42 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, right=False, out_int32=True):
+        super(Net, self).__init__()
+        self.search = P.SearchSorted(out_int32=out_int32, right=right)
+
+    def construct(self, sequence, values):
+        return self.search(sequence, values)
+
+
+def test_net_int32():
+    np.random.seed(1)
+    input1 = np.sort(np.array(np.random.randint(10, size=(2, 3, 9)), dtype=np.int32), axis=-1)
+    sequence = Tensor(input1, mindspore.int32)
+    input2 = np.array(np.random.randint(10, size=(2, 3, 1)), dtype=np.int32)
+    values = Tensor(input2, mindspore.int32)
+    net = Net()
+    output = net(sequence, values)
+    print(output)