diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.cc new file mode 100644 index 00000000000..e335dfe7d05 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/coalesce_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kCoalesceInputsNum = 3; +constexpr size_t kCoalesceOutputsNum = 3; +constexpr char kKernelName[] = "Coalesce"; +} // namespace + +void CoalesceCpuKernelMod::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + CHECK_KERNEL_INPUTS_NUM(input_num, kCoalesceInputsNum, kKernelName); + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kCoalesceOutputsNum, kKernelName); +} + +bool CoalesceCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << " which is not supported."; + } + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } + size_t output_nm = AnfAlgo::GetOutputTensorNum(node_); + std::vector dtypes(output_nm); + for (size_t i = 0; i < output_nm; i++) { + dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i); + } + std::vector dims; + (void)dims.emplace_back(shape_size_); + (void)dims.emplace_back(jump + 1); + std::vector dim; + (void)dim.emplace_back(jump + 1); + AnfAlgo::SetOutputInferTypeAndShape(dtypes, {dims, dim, AnfAlgo::GetOutputInferShape(node_, 2)}, node_.get()); + return true; +} + +void CoalesceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + CheckParam(kernel_node); + node_wpt_ = kernel_node; + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); + auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + values_size_ = indices_shape[1]; + shape_size_ = indices_shape[0]; +} + +void CoalesceCpuKernelMod::Check(const std::vector &inputs) { + auto x_indices_addr = reinterpret_cast(inputs[0]->addr); + auto x_shape_addr = reinterpret_cast(inputs[2]->addr); + for (size_t i = 0; i < values_size_; i++) { + for (size_t j = 0; j < shape_size_; j++) { + if (x_indices_addr[j * values_size_ + i] < 0) { + MS_EXCEPTION(ValueError) << "For Coalesce, values of elements of x_indices should be non-negative" + << ", but got x_indices[" << j << "][" << i + << "] = " << x_indices_addr[j * values_size_ + i]; + } + if (x_indices_addr[j * values_size_ + i] >= x_shape_addr[j]) { + MS_EXCEPTION(ValueError) + << "For Coalesce, values of elements of x_indices should not exceed the limit set by x_shape" + << ", but got x_indices[" << j << "][" << i << "] = " << x_indices_addr[j * values_size_ + i] + << ", got x_shape[" << j << "] = " << x_shape_addr[j]; + } + } + } +} + +template +void CoalesceCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto x_indices_addr = reinterpret_cast(inputs[0]->addr); + auto x_values_addr = reinterpret_cast(inputs[1]->addr); + auto x_shape_addr = reinterpret_cast(inputs[2]->addr); + auto y_indices_addr = reinterpret_cast(outputs[0]->addr); + auto y_values_addr = reinterpret_cast(outputs[1]->addr); + auto y_shape_addr = reinterpret_cast(outputs[2]->addr); + Check(inputs); + + std::vector reorder(values_size_); + std::iota(reorder.begin(), reorder.end(), 0); + + size_t shape_size = shape_size_; + size_t values_size = values_size_; + auto sorter = [x_indices_addr, shape_size, values_size](size_t i, size_t j) -> bool { + for (size_t n = 0; n < shape_size; n++) { + if (x_indices_addr[n * values_size + i] < x_indices_addr[n * values_size + j]) { + return true; + } + if (x_indices_addr[n * values_size + i] > x_indices_addr[n * values_size + j]) { + return false; + } + } + return true; + }; + std::sort(reorder.begin(), reorder.end(), sorter); + + std::vector del(values_size_); + del[0] = false; + y_values_addr[0] = x_values_addr[reorder[0]]; + for (size_t i = 1; i < values_size_; i++) { + del[i] = true; + for (size_t j = 0; j < shape_size_; j++) { + if (x_indices_addr[j * values_size_ + reorder[i]] != x_indices_addr[j * values_size_ + reorder[i - 1]]) { + del[i] = false; + break; + } + } + if (del[i]) { + y_values_addr[jump] += x_values_addr[reorder[i]]; + } else { + jump++; + y_values_addr[jump] = x_values_addr[reorder[i]]; + } + } + + size_t up = 0; + for (size_t i = 0; i < values_size_; i++) { + if (!del[i]) { + for (size_t j = 0; j < shape_size_; j++) { + y_indices_addr[j * (jump + 1) + up] = x_indices_addr[j * values_size_ + reorder[i]]; + } + up++; + } + } + + for (size_t i = 0; i < shape_size_; i++) { + y_shape_addr[i] = x_shape_addr[i]; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.h new file mode 100644 index 00000000000..40a46bd5ffe --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/coalesce_cpu_kernel.h @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_COALESCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_COALESCE_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class CoalesceCpuKernelMod : public NativeCpuKernelMod { + public: + CoalesceCpuKernelMod() = default; + ~CoalesceCpuKernelMod() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + void CheckParam(const CNodePtr &kernel_node); + void Check(const std::vector &inputs); + TypeId dtype_{kTypeUnknown}; + size_t values_size_{0}; + size_t shape_size_{0}; + size_t jump = 0; + CNodeWeakPtr node_wpt_; +}; + +MS_REG_CPU_KERNEL(Coalesce, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt64), + CoalesceCpuKernelMod); + +MS_REG_CPU_KERNEL(Coalesce, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt64), + CoalesceCpuKernelMod); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_COALESCE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 7621dd57fc7..17886f98a26 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -86,6 +86,7 @@ constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad"; constexpr auto kMeanGradOpName = "MeanGrad"; constexpr auto kSliceOpName = "Slice"; constexpr auto kSliceGradOpName = "SliceGrad"; +constexpr auto kCoalesceOpName = "Coalesce"; constexpr auto kTileOpName = "Tile"; constexpr auto kScatterNdOpName = "ScatterNd"; constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign"; @@ -741,7 +742,7 @@ const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat const std::set kComputeDepend = { kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName, - kGetNextOpName, kNonMaxSuppressionV3OpName}; + kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName}; const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D, kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC}; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 659dc64dea1..db3dd25d248 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -105,6 +105,7 @@ constexpr auto kFastGeLU = "FastGeLU"; constexpr auto kFastGeLUGrad = "FastGeLUGrad"; constexpr auto kStridedSlice = "StridedSlice"; constexpr auto kStridedSliceGrad = "StridedSliceGrad"; +constexpr auto kCoalesce = "Coalesce"; constexpr auto kZerosLike = "ZerosLike"; constexpr auto kOnes = "Ones"; constexpr auto kOnesLike = "OnesLike"; @@ -250,6 +251,7 @@ MS_CORE_API inline const PrimitivePtr kPrimGatherD = std::make_shared MS_CORE_API inline const PrimitivePtr kPrimGather = std::make_shared("Gather"); MS_CORE_API inline const PrimitivePtr kPrimGatherNd = std::make_shared("GatherNd"); MS_CORE_API inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared("SparseGatherV2"); +MS_CORE_API inline const PrimitivePtr kPrimCoalesce = std::make_shared(kCoalesce); MS_CORE_API inline const PrimitivePtr kPrimSparseToDense = std::make_shared("SparseToDense"); MS_CORE_API inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); MS_CORE_API inline const PrimitivePtr kPrimStridedSlice = std::make_shared(kStridedSlice); diff --git a/mindspore/core/ops/coalesce.cc b/mindspore/core/ops/coalesce.cc new file mode 100644 index 00000000000..327cca879dc --- /dev/null +++ b/mindspore/core/ops/coalesce.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/coalesce.h" + +#include "abstract/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +namespace { +TuplePtr CoalesceInferType(const PrimitivePtr &prim, const std::vector &input_args) { + const std::set valid_types = {kFloat16, kFloat32}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("x_values", input_args[kInputIndex1]->BuildType(), valid_types, + prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x_indices", input_args[kInputIndex0]->BuildType(), {kInt64}, + prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x_shape", input_args[kInputIndex2]->BuildType(), {kInt64}, + prim->name()); + std::vector types_list = {input_args[0]->BuildType(), input_args[1]->BuildType(), + input_args[2]->BuildType()}; + return std::make_shared(types_list); +} + +abstract::TupleShapePtr CoalesceInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + constexpr int x_indices_shape_size = 2; + auto x_indices_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto x_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto x_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + if (x_indices_shape.size() != x_indices_shape_size || x_values_shape.size() != 1 || x_shape_shape.size() != 1) { + MS_EXCEPTION(ValueError) << "For Coalesce, x_indices should be a 2-D tensor" + << ", x_values should be a 1-D tensor" + << ", x_shape should be a 1-D tensor" + << ", but got x_indices is a " << x_indices_shape.size() << "-D tensor" + << ", got x_values is a " << x_values_shape.size() << "-D tensor" + << ", got x_shape is a " << x_shape_shape.size() << "-D tensor"; + } + if (x_indices_shape[0] != x_shape_shape[0]) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", sizes of dim0 of x_indices and dim0 of x_shape should be the same" + << ", but size of dim0 of got x_indices is " << x_indices_shape[0] + << ", size of dim0 of got x_shape is " << x_shape_shape[0]; + } + if (x_indices_shape[1] != x_values_shape[0]) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", sizes of dim1 of x_indices and dim0 of x_values should be the same" + << ", but size of dim1 of got x_indices is " << x_indices_shape[1] + << ", size of dim0 of got x_values is " << x_values_shape[0]; + } + ShapeVector y_indices_shape = {x_indices_shape[0], -1}; + ShapeVector y_indices_min_shape = {x_indices_shape[0], 1}; + ShapeVector y_indices_max_shape = {x_indices_shape[0], x_indices_shape[1]}; + ShapeVector y_values_shape = {-1}; + ShapeVector y_values_min_shape = {1}; + ShapeVector y_values_max_shape = {x_indices_shape[1]}; + auto y_shape = input_args[2]->BuildShape(); + MS_EXCEPTION_IF_NULL(y_shape); + abstract::ShapePtr y_shape_shape_list = y_shape->cast(); + MS_EXCEPTION_IF_NULL(y_shape_shape_list); + abstract::ShapePtr y_indices_shape_list = + std::make_shared(y_indices_shape, y_indices_min_shape, y_indices_max_shape); + abstract::ShapePtr y_values_shape_list = + std::make_shared(y_values_shape, y_values_min_shape, y_values_max_shape); + return std::make_shared( + std::vector{y_indices_shape_list, y_values_shape_list, y_shape_shape_list}); +} +} // namespace +AbstractBasePtr CoalesceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + auto infer_type = CoalesceInferType(primitive, input_args); + auto infer_shape = CoalesceInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Coalesce, prim::kPrimCoalesce, CoalesceInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/coalesce.h b/mindspore/core/ops/coalesce.h new file mode 100644 index 00000000000..2c2a077e124 --- /dev/null +++ b/mindspore/core/ops/coalesce.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_COALESCE_H_ +#define MINDSPORE_CORE_OPS_COALESCE_H_ +#include +#include +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "abstract/dshape.h" +#include "ops/primitive_c.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCoalesce = "Coalesce"; +class Coalesce : public PrimitiveC { + public: + Coalesce() : PrimitiveC(kNameCoalesce) { + InitIOName({"x_indices", "x_values", "x_shape"}, {"y_indices", "y_values", "y_shape"}); + } + ~Coalesce() = default; + MS_DECLARE_PARENT(Coalesce, PrimitiveC); +}; + +AbstractBasePtr CoalesceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimCoalescePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_COALESCE_H_ diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index 60d79e72a48..34636780ee2 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -95,6 +95,16 @@ def get_bprop_tensor_scatter_min(self): return bprop +@bprop_getters.register(P.Coalesce) +def get_bprop_coalesce(self): + """Grad definition for `Coalesce` operation.""" + + def bprop(x_indices, x_values, x_shape, out, dout): + return dout + + return bprop + + @bprop_getters.register(P.SplitV) def get_bprop_split_v(self): """Generate bprop for SplitV""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index 99e5c93031b..488e3374093 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -65,6 +65,7 @@ from .acosh_grad import _acosh_grad_aicpu from .rnnt_loss import _rnnt_loss_aicpu from .random_categorical import _random_categorical_aicpu from .cast import _cast_aicpu +from .coalesce import _coalesce_aicpu from .mirror_pad import _mirror_pad_aicpu from .masked_select import _masked_select_aicpu from .masked_select_grad import _masked_select_grad_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/coalesce.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/coalesce.py new file mode 100644 index 00000000000..c1a50103c10 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/coalesce.py @@ -0,0 +1,35 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Coalesce op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +coalesce_op_info = AiCPURegOp("Coalesce") \ + .fusion_type("OPAQUE") \ + .input(0, "x_indices", "required") \ + .input(1, "x_values", "required") \ + .input(2, "x_shape", "required") \ + .output(0, "y_indices", "required") \ + .output(1, "y_values", "required") \ + .output(2, "y_shape", "required") \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, + DataType.F32_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, + DataType.F16_Default, DataType.I64_Default) \ + .get_op_info() + +@op_info_register(coalesce_op_info) +def _coalesce_aicpu(): + """Coalesce aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index e86f8038bee..10d2f9486a6 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -23,7 +23,7 @@ from .image_ops import (CropAndResize, NonMaxSuppressionV3, HSVToRGB) from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation, - IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, + IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Coalesce, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, Lstsq, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, @@ -148,6 +148,7 @@ __all__ = [ 'AddN', 'AccumulateNV2', 'Sub', + 'Coalesce', 'CumSum', 'MatMul', 'BatchMatMul', diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index afa38b71759..b10300994f7 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -2933,6 +2933,62 @@ class Slice(PrimitiveWithInfer): 'value': None} +class Coalesce(Primitive): + """ + Returns the coalesced sparse tensor of the input. + + Inputs: + - **x_indices** (Tensor) - A 2-D Tensor, represents the indices of the nonzero elements of the sparse tensor. + Supported data type is int64. It's elements should be non-negative. The shape is :math:`(y, x)`. + - **x_values** (Tensor) - A 1-D Tensor, represents the values corresponding to the indices in `x_indices`. + Supported data types are float16 and float32. The shape is :math:`(x,)`. + - **x_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. + Supported data type is int64. The shape is :math:`(y,)`. + + Outputs: + - **y_indices** (Tensor) - A 2-D Tensor, represents the indices of the nonzero elements of the sparse tensor. + Data type is int64. It's elements are non-negative. The shape is :math:`(y, z)`. + `z` represents the number of different indices in `x_indices`. + - **y_values** (Tensor) - A 1-D Tensor, represents the values corresponding to the indices in `y_indices`. + Data type is the same as `x_values`'s. The shape is :math:`(z,)`. + - **y_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. + Data type is int64. The shape is :math:`(y,)`. + + Raises: + TypeError: If the data type of `x_values` is neither float32 nor float16. + TypeError: If any of the data types of `x_indices` and `x_shape` is not int64. + ValueError: If any of `x_values` and `x_shape` is not a 1-D tensor. + ValueError: If `x_indices` is not a 2-D tensor. + ValueError: If sizes of second dimension of `x_indices` and first dimension of `x_values` are not the same. + ValueError: If sizes of first dimension of `x_indices` and first dimension of `x_shape` are not the same. + ValueError: If any of the values of elements of `x_indices` is negative. + ValueError: If any of the values of elements of `x_indices` exceed the limit set by `x_shape`. + + Supported Platforms: + ``CPU`` + + Examples: + >>> x_indices = Tensor([[0, 0, 1], [1, 1, 2]], dtype=ms.int64) + >>> x_values = Tensor([1, 5, 4], dtype=ms.float32) + >>> x_shape = Tensor([3, 3], dtype=ms.int64) + >>> coalesce = ops.Coalesce() + >>> y_indices, y_values, y_shape = coalesce(x_indices, x_values, x_shape) + >>> print(y_indices) + [[0 1] + [1 2]] + >>> print(y_values) + [6. 4.] + >>> print(y_shape) + [3 3] + """ + + @prim_attr_register + def __init__(self): + """Initialize Coalesce.""" + self.init_prim_io_names(inputs=['x_indices', 'x_values', 'x_shape'], + outputs=['y_indices', 'y_values', 'y_shape']) + + class ReverseV2(PrimitiveWithInfer): """ Reverses specific dimensions of a tensor. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 3c44bce77bd..ef0ae07c0b4 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -2012,6 +2012,16 @@ test_case_nn_ops = [ 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], 'skip': ['backward']}), + ('Coalesce', { + 'block': P.Coalesce(), + 'desc_inputs': [ + Tensor(np.array([[0, 0], [1, 1]]).astype(np.int64)), + Tensor(np.array([1, 2]).astype(np.float32)), + Tensor(np.array([2, 2]).astype(np.int64))], + 'desc_bprop': [ + Tensor(np.array([[0], [1]]).astype(np.int64)), + Tensor(np.array([3]).astype(np.float32)), + Tensor(np.array([2, 2]).astype(np.int64))]}), ('TopK', { 'block': P.TopK(), 'desc_const': [5],