From 357cd941293935162e6576310d06dead5922d3e8 Mon Sep 17 00:00:00 2001 From: peiing Date: Sat, 23 Jul 2022 11:07:59 +0800 Subject: [PATCH] [feat][assistant][I48O4X]Add SparseSegmentMeanWithNumSegments --- .../sparse_segment_mean_grad_cpu_kernel.cc | 127 ++++++++++++++ .../sparse_segment_mean_grad_cpu_kernel.h | 56 +++++++ ...gment_mean_with_num_segments_cpu_kernel.cc | 155 ++++++++++++++++++ ...egment_mean_with_num_segments_cpu_kernel.h | 58 +++++++ mindspore/core/ops/core_ops.h | 5 + .../core/ops/grad/sparse_segment_mean_grad.cc | 106 ++++++++++++ .../core/ops/grad/sparse_segment_mean_grad.h | 46 ++++++ .../sparse_segment_mean_with_num_segments.cc | 117 +++++++++++++ .../sparse_segment_mean_with_num_segments.h | 49 ++++++ .../ops/_grad_experimental/grad_sparse_ops.py | 18 ++ .../mindspore/ops/_op_impl/aicpu/__init__.py | 2 + .../aicpu/sparse_segment_mean_grad.py | 36 ++++ .../sparse_segment_mean_with_num_segments.py | 44 +++++ .../mindspore/ops/operations/_grad_ops.py | 45 ++++- .../mindspore/ops/operations/sparse_ops.py | 56 +++++++ tests/ut/python/ops/test_ops.py | 15 ++ 16 files changed, 932 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h create mode 100644 mindspore/core/ops/grad/sparse_segment_mean_grad.cc create mode 100644 mindspore/core/ops/grad/sparse_segment_mean_grad.h create mode 100644 mindspore/core/ops/sparse_segment_mean_with_num_segments.cc create mode 100644 mindspore/core/ops/sparse_segment_mean_with_num_segments.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.cc new file mode 100644 index 00000000000..b7a5abe91d3 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseSegmentMeanGradInputsNum = 4; +constexpr size_t kSparseSegmentMeanGradOutputsNum = 1; + +#define ADD_KERNEL(t1, t2, t3, t4, t5) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddInputAttr(kNumberType##t4) \ + .AddOutputAttr(kNumberType##t5) +} // namespace + +void SparseSegmentMeanGradCpuKernelMod::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); + CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentMeanGradInputsNum, kernel_name_); + size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentMeanGradOutputsNum, kernel_name_); +} + +void SparseSegmentMeanGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0); + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0); + segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2); + y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0); +} + +bool SparseSegmentMeanGradCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (x_dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (x_dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else { + MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_) + << " which is not supported."; + } + return true; +} + +template +void SparseSegmentMeanGradCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t kMultiply = 1; + size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kMultiply, std::multiplies()) / x_shape_[kIndex0]; + size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kMultiply, std::multiplies()); + size_t num_elements = std::accumulate(y_shape_.begin(), y_shape_.end(), kMultiply, std::multiplies()); + int32_t k = *reinterpret_cast(inputs[kIndex3]->addr); + auto x_shape_0 = static_cast(x_shape_[kIndex0]); + auto x_addr = reinterpret_cast(inputs[kIndex0]->addr); + auto indices_addr = reinterpret_cast(inputs[kIndex1]->addr); + auto segment_ids_addr = reinterpret_cast(inputs[kIndex2]->addr); + auto y_addr = reinterpret_cast(outputs[kIndex0]->addr); + + for (size_t i = 0; i < num_elements; i++) { + y_addr[i] = (T)0; + } + for (size_t i = 1; i < m; i++) { + if (segment_ids_addr[i] < segment_ids_addr[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; + } + } + for (size_t i = 0; i < m; i++) { + if (indices_addr[i] >= x_shape_0) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices is out of range of x's first dimension."; + } + if (segment_ids_addr[i] >= k) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids is out of range of output_dim0."; + } + } + int beginindex = segment_ids_addr[0]; + size_t countnum = 1; + for (size_t i = 1; i < m; i++) { + if (segment_ids_addr[i] != beginindex) { + for (size_t j = 1; j <= countnum; j++) { + for (size_t l = 0; l < n; l++) { + y_addr[indices_addr[i - j] * n + l] += x_addr[beginindex * n + l] / (T)(countnum); + } + } + beginindex = segment_ids_addr[i]; + countnum = 1; + } else { + countnum++; + } + } + + int i = m; + for (size_t j = 1; j <= countnum; j++) { + for (size_t l = 0; l < n; l++) { + y_addr[indices_addr[i - j] * n + l] += x_addr[beginindex * n + l] / (T)(countnum); + } + } +} + +std::vector SparseSegmentMeanGradCpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = {ADD_KERNEL(Float32, Int32, Int32, Int32, Float32), + ADD_KERNEL(Float64, Int32, Int32, Int32, Float64)}; + + return kernel_attr_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentMeanGrad, SparseSegmentMeanGradCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h new file mode 100644 index 00000000000..a729171cd5e --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentMeanGradCpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + SparseSegmentMeanGradCpuKernelMod() = default; + + ~SparseSegmentMeanGradCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + protected: + std::vector GetOpSupport() override; + + private: + void CheckParam(const CNodePtr &kernel_node); + ShapeVector x_shape_; + ShapeVector segment_ids_shape_; + ShapeVector y_shape_; + TypeId x_dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.cc new file mode 100644 index 00000000000..6b84bfa0e97 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseSegmentMeanWithNumSegmentsInputsNum = 4; +constexpr size_t kSparseSegmentMeanWithNumSegmentsOutputsNum = 1; + +#define ADD_KERNEL(t1, t2, t3, t4, t5) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddInputAttr(kNumberType##t4) \ + .AddOutputAttr(kNumberType##t5) +} // namespace + +void SparseSegmentMeanWithNumSegmentsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0); + indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1); + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0); + segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2); + y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0); +} + +bool SparseSegmentMeanWithNumSegmentsCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + switch (x_dtype_) { + case (kNumberTypeFloat16): + if (indices_dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + break; + } else { + LaunchKernel(inputs, outputs); + break; + } + case (kNumberTypeFloat32): + if (indices_dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + break; + } else { + LaunchKernel(inputs, outputs); + break; + } + case (kNumberTypeFloat64): + if (indices_dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + break; + } else { + LaunchKernel(inputs, outputs); + break; + } + default: + MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_) + << " which is not supported."; + } + return true; +} + +template +void SparseSegmentMeanWithNumSegmentsCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t kMultiply = 1; + size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kMultiply, std::multiplies()) / x_shape_[kIndex0]; + size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kMultiply, std::multiplies()); + size_t num_elements = std::accumulate(y_shape_.begin(), y_shape_.end(), kMultiply, std::multiplies()); + auto x_shape_0 = static_cast(x_shape_[kIndex0]); + auto x_addr = reinterpret_cast(inputs[kIndex0]->addr); + auto indices_addr = reinterpret_cast(inputs[kIndex1]->addr); + auto segment_ids_addr = reinterpret_cast(inputs[kIndex2]->addr); + auto num_segments_addr = reinterpret_cast(inputs[kIndex3]->addr); + auto y_addr = reinterpret_cast(outputs[kIndex0]->addr); + for (size_t i = 1; i < m; i++) { + if (segment_ids_addr[i] < segment_ids_addr[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input segment_ids should be sorted."; + } + } + if (segment_ids_addr[m - 1] >= num_segments_addr[0]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', num_segments must be bigger than the largest id of segment_ids."; + } + for (size_t i = 0; i < m; i++) { + if (indices_addr[i] >= x_shape_0) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input indices is out of range of x's first dimension."; + } + } + for (size_t i = 0; i < num_elements; i++) { + y_addr[i] = (T1)0; + } + int oldindex = -1; + int countnum = 0; + for (size_t i = 0; i < m; i++) { + if (oldindex == segment_ids_addr[i]) { + countnum++; + } else { + if (countnum != 0) { + for (size_t j = 0; j < n; j++) { + y_addr[j + oldindex * n] /= (T1)countnum; + } + } + countnum = 1; + oldindex = segment_ids_addr[i]; + } + for (size_t j = 0; j < n; j++) { + y_addr[j + oldindex * n] += x_addr[j + indices_addr[i] * n]; + } + } + if (countnum != 0) { + for (size_t j = 0; j < n; j++) { + y_addr[j + oldindex * n] /= (T1)countnum; + } + } +} + +void SparseSegmentMeanWithNumSegmentsCpuKernelMod::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); + CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentMeanWithNumSegmentsInputsNum, kernel_name_); + size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentMeanWithNumSegmentsOutputsNum, kernel_name_); +} + +std::vector SparseSegmentMeanWithNumSegmentsCpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = { + ADD_KERNEL(Float16, Int32, Int32, Int32, Float16), ADD_KERNEL(Float16, Int64, Int64, Int64, Float16), + ADD_KERNEL(Float32, Int32, Int32, Int32, Float32), ADD_KERNEL(Float32, Int64, Int64, Int64, Float32), + ADD_KERNEL(Float64, Int32, Int32, Int32, Float64), ADD_KERNEL(Float64, Int64, Int64, Int64, Float64)}; + + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentMeanWithNumSegments, + SparseSegmentMeanWithNumSegmentsCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h new file mode 100644 index 00000000000..839c7b23fb0 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentMeanWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + SparseSegmentMeanWithNumSegmentsCpuKernelMod() = default; + ~SparseSegmentMeanWithNumSegmentsCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + protected: + std::vector GetOpSupport() override; + + private: + void CheckParam(const CNodePtr &kernel_node); + ShapeVector x_shape_; + ShapeVector segment_ids_shape_; + ShapeVector y_shape_; + TypeId x_dtype_{kTypeUnknown}; + TypeId indices_dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 875ba0406f9..031af0b10de 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -324,6 +324,8 @@ constexpr auto kCSRDiv = "CSRDiv"; constexpr auto kDenseToDenseSetOperation = "DenseToDenseSetOperation"; constexpr auto kSparseMatrixAdd = "SparseMatrixAdd"; constexpr auto kSparseAdd = "SparseAdd"; +constexpr auto kSparseSegmentMeanGrad = "SparseSegmentMeanGrad"; +constexpr auto kSparseSegmentMeanWithNumSegments = "SparseSegmentMeanWithNumSegments"; constexpr auto kSparseConcat = "SparseConcat"; constexpr auto kSparseMatrixNNZ = "SparseMatrixNNZ"; constexpr auto kSparseMatrixTranspose = "SparseMatrixTranspose"; @@ -994,6 +996,9 @@ GVAR_DEF(PrimitivePtr, kPrimSparseSplit, std::make_shared(kSparseSpli GVAR_DEF(PrimitivePtr, kPrimDenseToDenseSetOperation, std::make_shared(kDenseToDenseSetOperation)); GVAR_DEF(PrimitivePtr, kPrimSparseMatrixAdd, std::make_shared(kSparseMatrixAdd)); GVAR_DEF(PrimitivePtr, kPrimSparseAdd, std::make_shared(kSparseAdd)); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMeanGrad, std::make_shared("SparseSegmentMeanGrad")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMeanWithNumSegments, + std::make_shared("SparseSegmentMeanWithNumSegments")); GVAR_DEF(PrimitivePtr, kPrimDenseToCSRSparseMatrix, std::make_shared("DenseToCSRSparseMatrix")); GVAR_DEF(PrimitivePtr, kPrimSparseTensorToCSRSparseMatrix, std::make_shared(kSparseTensorToCSRSparseMatrix)); GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor, std::make_shared(kCSRSparseMatrixToSparseTensor)); diff --git a/mindspore/core/ops/grad/sparse_segment_mean_grad.cc b/mindspore/core/ops/grad/sparse_segment_mean_grad.cc new file mode 100644 index 00000000000..085aab32776 --- /dev/null +++ b/mindspore/core/ops/grad/sparse_segment_mean_grad.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/sparse_segment_mean_grad.h" +#include "abstract/dshape.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentMeanGradInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + constexpr size_t kRankNum0 = 0; + constexpr size_t kRankNum1 = 1; + constexpr size_t kShapeNum0 = 0; + constexpr int kDimNum0 = 0; + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto output_dim0_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + if (x_shape.size() < kRankNum1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank cannot be less than 1."; + } + if (output_dim0_shape.size() != kRankNum0) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar."; + } + if (indices_shape[kShapeNum0] != segment_ids_shape[kShapeNum0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch."; + } + if (!input_args[kInputIndex3]->BuildValue()->isa() && + !input_args[kInputIndex3]->BuildValue()->isa()) { + auto output_dim0_value = input_args[kInputIndex3]->cast(); + MS_EXCEPTION_IF_NULL(output_dim0_value); + auto output_dim0_value_ptr = output_dim0_value->BuildValue(); + MS_EXCEPTION_IF_NULL(output_dim0_value_ptr); + auto output_dim0_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name); + int dim_zero = output_dim0_value_ptr_tensor[kShapeNum0]; + if (dim_zero <= kDimNum0) { + MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!"; + } else { + ShapeVector y_shape = x_shape; + y_shape[kShapeNum0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentMeanGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + auto output_dim0_type = input_args[kInputIndex3]->BuildType(); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, {kFloat32, kFloat64}, prim->name()); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + (void)types.emplace("output_dim0", output_dim0_type); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32}, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentMeanGrad, BaseOperator); +AbstractBasePtr SparseSegmentMeanGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const int64_t input_num = kInputIndex4; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto types = SparseSegmentMeanGradInferType(prim, input_args); + auto shapes = SparseSegmentMeanGradInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentMeanGrad, {3}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentMeanGrad, prim::kPrimSparseSegmentMeanGrad, SparseSegmentMeanGradInfer, + nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/sparse_segment_mean_grad.h b/mindspore/core/ops/grad/sparse_segment_mean_grad.h new file mode 100644 index 00000000000..df43fd7d363 --- /dev/null +++ b/mindspore/core/ops/grad/sparse_segment_mean_grad.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_ + +#include +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentMeanGrad = "SparseSegmentMeanGrad"; +class MIND_API SparseSegmentMeanGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentMeanGrad); + SparseSegmentMeanGrad() : BaseOperator(kNameSparseSegmentMeanGrad) { + InitIOName({"x", "indices", "segment_ids", "output_dim0"}, {"y"}); + } +}; + +abstract::AbstractBasePtr SparseSegmentMeanGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentMeanGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_ diff --git a/mindspore/core/ops/sparse_segment_mean_with_num_segments.cc b/mindspore/core/ops/sparse_segment_mean_with_num_segments.cc new file mode 100644 index 00000000000..782f6625581 --- /dev/null +++ b/mindspore/core/ops/sparse_segment_mean_with_num_segments.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "ops/sparse_segment_mean_with_num_segments.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentMeanWithNumSegmentsInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + constexpr size_t kRankOne = 1; + constexpr size_t kDimOne = 1; + constexpr size_t kShapeZero = 0; + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto num_segments_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + if (indices_shape.size() != kRankOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1."; + } + if (segment_ids_shape.size() != kRankOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1."; + } + if (x_shape.size() < kRankOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot be less than 1."; + } + if (indices_shape[kShapeZero] != segment_ids_shape[kShapeZero]) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", indices and segment_ids's ranks mismatch."; + } + if (num_segments_shape.size() > kRankOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of num_segments should be 0 or 1."; + } + if (num_segments_shape.size() == kRankOne && num_segments_shape[kShapeZero] != kDimOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1."; + } + if (!input_args[kInputIndex3]->BuildValue()->isa() && + !input_args[kInputIndex3]->BuildValue()->isa()) { + auto num_segments_value = input_args[kInputIndex3]->cast(); + MS_EXCEPTION_IF_NULL(num_segments_value); + auto num_segments_value_ptr = num_segments_value->BuildValue(); + MS_EXCEPTION_IF_NULL(num_segments_value_ptr); + auto num_segments_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim_name); + size_t dim_zero = num_segments_value_ptr_tensor.back(); + if (dim_zero < kDimOne) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", num_segments must be bigger than the largest id of segment_ids."; + } else { + ShapeVector y_shape = x_shape; + y_shape[kShapeZero] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentMeanWithNumSegmentsInferType(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + auto num_segments_type = input_args[kInputIndex3]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + (void)types.emplace("num_segments", num_segments_type); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +AbstractBasePtr SparseSegmentMeanWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto types = SparseSegmentMeanWithNumSegmentsInferType(prim, input_args); + auto shapes = SparseSegmentMeanWithNumSegmentsInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} + +MIND_API_OPERATOR_IMPL(SparseSegmentMeanWithNumSegments, BaseOperator); +REGISTER_HOST_DEPENDS(kNameSparseSegmentMeanWithNumSegments, {3}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentMeanWithNumSegments, prim::kPrimSparseSegmentMeanWithNumSegments, + SparseSegmentMeanWithNumSegmentsInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sparse_segment_mean_with_num_segments.h b/mindspore/core/ops/sparse_segment_mean_with_num_segments.h new file mode 100644 index 00000000000..05beb7dad7d --- /dev/null +++ b/mindspore/core/ops/sparse_segment_mean_with_num_segments.h @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_ +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentMeanWithNumSegments = "SparseSegmentMeanWithNumSegments"; +/// \brief Computes the mean along sparse segments of a tensor, but it is allowed to miss id in segment_ids. +/// Refer to Python API @ref mindspore.ops.SparseSegmentMeanWithNumSegments for more details. +class MIND_API SparseSegmentMeanWithNumSegments : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentMeanWithNumSegments); + /// \brief Constructor. + SparseSegmentMeanWithNumSegments() : BaseOperator(kNameSparseSegmentMeanWithNumSegments) { + InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"}); + } +}; + +abstract::AbstractBasePtr SparseSegmentMeanWithNumSegmentsInfer( + const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentMeanWithNumSegmentsPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_ diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py index 5e5a7a4966f..0b341ff2f48 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py @@ -19,6 +19,7 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix from mindspore.ops.operations.sparse_ops import SparseToDenseV2 from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments +from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments from mindspore.common import dtype as mstype from .. import functional as F from .. import operations as P @@ -99,3 +100,20 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self): return all_d return bprop + + +@bprop_getters.register(SparseSegmentMeanWithNumSegments) +def get_bprop_sparse_segment_mean_with_num_segments(self): + """Grad definition for `SparseSegmentMeanWithNumSegments` operation.""" + input_grad = G.SparseSegmentMeanGrad() + shape = P.Shape() + + def bprop(x, indices, segment_ids, num_segments, out, dout): + output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32) + indices = F.cast(indices, mstype.int32) + segment_ids = F.cast(segment_ids, mstype.int32) + dx = input_grad(dout, indices, segment_ids, output_dim0) + all_d = (dx, zeros_like(indices), zeros_like(segment_ids), zeros_like(num_segments)) + return all_d + + return bprop diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index b5369b69508..6e282ce5a2b 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -292,6 +292,8 @@ from .segment_sum import _segment_sum_aicpu from .sparse_segment_sqrt_n import _sparse_segment_sqrt_n_aicpu from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu from .sparse_segment_sqrt_n_with_num_segments import _sparse_segment_sqrt_n_with_num_segments_aicpu +from .sparse_segment_mean_grad import _sparse_segment_mean_grad_aicpu +from .sparse_segment_mean_with_num_segments import _sparse_segment_mean_with_num_segments_aicpu from .scatter_nd_max import _scatter_nd_max_aicpu from .conj import _conj_aicpu from .ctc_loss_v2 import _ctc_loss_v2_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py new file mode 100644 index 00000000000..72f7b9fb0b3 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py @@ -0,0 +1,36 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseSegmentMeanGrad op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_segment_mean_grad_op_info = AiCPURegOp("SparseSegmentMeanGrad") \ + .fusion_type("OPAQUE") \ + .input(0, "dout", "required") \ + .input(1, "indices", "required") \ + .input(2, "segment_ids", "required") \ + .input(3, "output_dim0", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(sparse_segment_mean_grad_op_info) +def _sparse_segment_mean_grad_aicpu(): + """SparseSegmentMeanGrad aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py new file mode 100644 index 00000000000..7f06d2a1138 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py @@ -0,0 +1,44 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseSegmentMeanWithNumSegments op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_segment_mean_with_num_segments_op_info = AiCPURegOp("SparseSegmentMeanWithNumSegments") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "indices", "required") \ + .input(2, "segment_ids", "required") \ + .input(3, "num_segments", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(sparse_segment_mean_with_num_segments_op_info) +def _sparse_segment_mean_with_num_segments_aicpu(): + """SparseSegmentMeanWithNumSegments aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py index 1d83dc3f233..4652bae51f8 100644 --- a/mindspore/python/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py @@ -2047,7 +2047,7 @@ class UpsampleNearest3DGrad(Primitive): One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified. Inputs: - **grad_output** (Tensor) - Tensor of shape [N, C, D, H, W], Must be one of the following types: - float16, float32, float64. + float16, float32, float64. Outputs: Tensor, A 5-D tensor. Has the same type as input grad_output, shape depends on x and output_size/scales. @@ -3161,6 +3161,45 @@ class GridSampler3DGrad(Primitive): self.add_prim_attr('align_corners', align_corners) +class SparseSegmentMeanGrad(Primitive): + """ + Compute gradients for SparseSegmentMeanGrad operation. + + Inputs: + - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanGrad. + - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following + types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`. + - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one of the + following types: int32, int64. Values should be sorted and can be repeated. The shape should be :math:`(N,)`. + - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentMean op. + + Outputs: + A Tensor. Has the same type as `x` . + Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`. + + Raises: + TypeError: If `x` or `indices` or `segment_ids` is not a tensor. + TypeError: If the dtype of `x` is not any of the following data types: {float32, float64}. + TypeError: If the dtype of `indices` is not int32. + TypeError: If the dtype of `segment_ids` is not int32. + TypeError: If the dtype of `output_dim0` is not int32. + ValueError: If dimension size of `x` less than 1. + ValueError: If rank of `indices` or `segment_ids` is not 1. + ValueError: If dimension size of `output_dim0` is not 0. + ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`. + ValueError: If `segment_ids` is not sorted. + ValueError: If `indices` is out of range of x's first dimension. + + Supported Platforms: + ``Ascend`` ``CPU`` + """ + + @prim_attr_register + def __init__(self): + """Initialize SparseSegmentMeanGrad""" + self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y']) + + class FractionalMaxPoolGrad(Primitive): """Computes gradients for FractionalMaxPool operation.""" @@ -3494,9 +3533,9 @@ class GridSampler2DGrad(Primitive): - **grad** (Tensor) - A 4-D tensor whose dtype is float16 or float32 and whose shape is :math:`(N, C, H_{out}, W_{out})`. The shape is inconsistent with the shape of the output result of forward calculation. - **input_x** (Tensor) - A 4-D tensor whose dtype is the same as `grad` and whose shape is :math:`(N, C, - H_{in}, W_{in})`. + H_{in}, W_{in})`. - **grid** (Tensor) - A 4-D tensor whose dtype is the same as `grad` and whose - shape is :math:`(N, H_{out}, W_{out}, 2)`. + shape is :math:`(N, H_{out}, W_{out}, 2)`. Outputs: - **dx** (Tensor) - A 4-D tensor whose dtype and shape are the same as `input_x`. diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py index 265664ee469..a95e5c243e6 100644 --- a/mindspore/python/mindspore/ops/operations/sparse_ops.py +++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py @@ -1085,6 +1085,62 @@ class SparseMatrixNNZ(Primitive): inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'], outputs=['y']) +class SparseSegmentMeanWithNumSegments(Primitive): + """ + Compute the mean along sparse segments of a tensor. It is allowed to have missing id in segment_ids. + + Inputs: + - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanWithNumSegments. + - **indices** (Tensor) - 1-D Tensor with indices into `x`. Must be one of the following + types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`. + - **segment_ids** (Tensor) - 1-D Tensor with indices into the output `y`. Must be one of the + following types: int32, int64. Values should be sorted and can be repeated. The shape should + be :math:`(N,)`. + - **num_segments** (Tensor) - Num_segments indicates the size of the output. + It should be bigger than the largest id of `segment_ids`. + + Outputs: + A Tensor. Has the same type as `x` . + Has same shape as `x`, except for dimension 0 which is the value of `num_segments`. + + Raises: + TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor. + TypeError: If dtype of `x` is not in [float16, float32, float64]. + TypeError: If dtype of `indices` is not int32 or int64. + TypeError: If dtype of `segment_ids` and `indices` mismatch. + TypeError: If dtype of `num_segments` and `indices` mismatch. + ValueError: If rank of `x` less than 1. + ValueError: If rank of `indices` or `segment_ids` is not 1. + ValueError: If rank of `num_segments` is bigger than 1. + ValueError: If numelements of `num_segments` is not 1. + ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`. + ValueError: If `segment_ids` is not sorted. + ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`. + ValueError: If `indices` is out of range of x's first dimension. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=ms.float16) + >>> indices = Tensor([0, 2, 1], dtype=ms.int32) + >>> segment_ids = Tensor([0, 0, 2], dtype=ms.int32) + >>> num_segments = Tensor([4], dtype=ms.int32) + >>> sparse_segment_mean_with_num_segments = ops.SparseSegmentMeanWithNumSegments() + >>> output = sparse_segment_mean_with_num_segments(x, indices, segment_ids, num_segments) + >>> print(output) + [[1. 1. 1. 0.] + [0. 0. 0. 0.] + [0. 1. 1. 0.] + [0. 0. 0. 0.]] + """ + + @prim_attr_register + def __init__(self): + """Initialize SparseSegmentMeanWithNumSegments""" + self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y']) + + class SparseAdd(Primitive): """ Computes the sum of a COOTensor and another COOTensor. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 42d0afcd074..3d0bff5bdbb 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -146,6 +146,7 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix from mindspore.ops.operations.sparse_ops import SparseSparseMinimum from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments +from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments from mindspore.ops.operations.other_ops import BlackmanWindow from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent @@ -2177,6 +2178,13 @@ test_case_math_ops = [ 'block': P.Sign(), 'desc_inputs': [[3]], 'desc_bprop': [[3]]}), + ('SparseSegmentMeanGrad', { + 'block': G.SparseSegmentMeanGrad(), + 'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)), + Tensor(np.array([0, 1]).astype(np.int32)), + Tensor(np.array([0, 1]).astype(np.int32)), + Tensor(np.array(4).astype(np.int32))], + 'skip': ['backward']}), ('Round', { 'block': P.Round(), 'desc_inputs': [[3]], @@ -4426,6 +4434,13 @@ test_case_sparse_ops = [ Tensor(np.array([1, 1]), mstype.int64), Tensor(np.array([[1, 2], [3, 4]]), mstype.int64)], 'skip': ['backward']}), + ('SparseSegmentMeanWithNumSegments', { + 'block': SparseSegmentMeanWithNumSegments(), + 'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)), + Tensor(np.array([0, 1]).astype(np.int32)), + Tensor(np.array([0, 1]).astype(np.int32)), + Tensor(np.array([2]).astype(np.int32))], + 'desc_bprop': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32))]}), ('SparseTensorDenseAdd', { 'block': SparseTensorDenseAdd(), 'desc_inputs': [Tensor([[0]], mstype.int32),