!36993 [feat][assistant][I48O4X, I48O7H]Add SparseSegmentMeanWithNumSegments, SparseSegmentMeanGrad

Merge pull request !36993 from 桂宁馨/SparseSegmentMeanWithNumSegments
This commit is contained in:
i-robot 2022-08-24 09:01:25 +00:00 committed by Gitee
commit 6806bf42cd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 929 additions and 0 deletions

View File

@ -0,0 +1,127 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_segment_mean_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseSegmentMeanGradInputsNum = 4;
constexpr size_t kSparseSegmentMeanGradOutputsNum = 1;
#define ADD_KERNEL(t1, t2, t3, t4, t5) \
KernelAttr() \
.AddInputAttr(kNumberType##t1) \
.AddInputAttr(kNumberType##t2) \
.AddInputAttr(kNumberType##t3) \
.AddInputAttr(kNumberType##t4) \
.AddOutputAttr(kNumberType##t5)
} // namespace
void SparseSegmentMeanGradCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentMeanGradInputsNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentMeanGradOutputsNum, kernel_name_);
}
void SparseSegmentMeanGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
}
bool SparseSegmentMeanGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &workspace,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  // Dispatch on the dtype of x; only float32 and float64 are supported.
  switch (x_dtype_) {
    case kNumberTypeFloat32:
      LaunchKernel<float>(inputs, outputs);
      break;
    case kNumberTypeFloat64:
      LaunchKernel<double>(inputs, outputs);
      break;
    default:
      MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
                              << " which is not supported.";
  }
  return true;
}
template <typename T>
void SparseSegmentMeanGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                     const std::vector<kernel::AddressPtr> &outputs) {
  constexpr size_t kMultiply = 1;
  // n: elements per row of x (product of all dims except the first);
  // m: number of (indices, segment_ids) entries; num_elements: total size of y.
  size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kMultiply, std::multiplies<int>()) / x_shape_[kIndex0];
  size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kMultiply, std::multiplies<int>());
  size_t num_elements = std::accumulate(y_shape_.begin(), y_shape_.end(), kMultiply, std::multiplies<int>());
  // k is output_dim0: the first dimension of the produced gradient y.
  int32_t k = *reinterpret_cast<int32_t *>(inputs[kIndex3]->addr);
  auto x_shape_0 = static_cast<int32_t>(x_shape_[kIndex0]);
  auto x_addr = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  auto indices_addr = reinterpret_cast<int32_t *>(inputs[kIndex1]->addr);
  auto segment_ids_addr = reinterpret_cast<int32_t *>(inputs[kIndex2]->addr);
  auto y_addr = reinterpret_cast<T *>(outputs[kIndex0]->addr);
  // y is accumulated into below, so it must start zeroed.
  for (size_t i = 0; i < num_elements; i++) {
    y_addr[i] = static_cast<T>(0);
  }
  if (m == 0) {
    // Empty segment_ids: gradient is all zeros. Also avoids reading segment_ids_addr[0] below.
    return;
  }
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] < segment_ids_addr[i - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
    }
  }
  // In the accumulation below, indices select rows of y (first dim k == output_dim0)
  // and segment_ids select rows of x (first dim x_shape_0). Bug fix: the previous
  // checks were swapped (indices vs x_shape_0, segment_ids vs k), which let
  // out-of-bounds reads/writes through whenever the two dims differed. Negative
  // values are rejected as well.
  for (size_t i = 0; i < m; i++) {
    if (indices_addr[i] < 0 || indices_addr[i] >= k) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices is out of range of output_dim0.";
    }
    if (segment_ids_addr[i] < 0 || segment_ids_addr[i] >= x_shape_0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids is out of range of x's first dimension.";
    }
  }
  // Walk runs of equal segment ids. A run of length countnum with id beginindex
  // contributes x[beginindex] / countnum to y[indices[j]] for every j in the run.
  int beginindex = segment_ids_addr[0];
  size_t countnum = 1;
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] != beginindex) {
      for (size_t j = 1; j <= countnum; j++) {
        for (size_t l = 0; l < n; l++) {
          y_addr[indices_addr[i - j] * n + l] += x_addr[beginindex * n + l] / static_cast<T>(countnum);
        }
      }
      beginindex = segment_ids_addr[i];
      countnum = 1;
    } else {
      countnum++;
    }
  }
  // Flush the final run, which is never closed inside the loop.
  // (Also fixes the narrowing `int i = m;` of the previous version.)
  for (size_t j = 1; j <= countnum; j++) {
    for (size_t l = 0; l < n; l++) {
      y_addr[indices_addr[m - j] * n + l] += x_addr[beginindex * n + l] / static_cast<T>(countnum);
    }
  }
}
std::vector<KernelAttr> SparseSegmentMeanGradCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {ADD_KERNEL(Float32, Int32, Int32, Int32, Float32),
ADD_KERNEL(Float64, Int32, Int32, Int32, Float64)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentMeanGrad, SparseSegmentMeanGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_GRAD_CPU_KERNEL_H_
#include <functional>
#include <numeric>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for SparseSegmentMeanGrad. Given the incoming gradient x (dout of a
// SparseSegmentMean), indices, sorted segment_ids and the scalar output_dim0, it
// scatters averaged gradient rows into an output whose first dimension is
// output_dim0. Implemented against the deprecated CNode-based kernel interface.
class SparseSegmentMeanGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseSegmentMeanGradCpuKernelMod() = default;
  ~SparseSegmentMeanGradCpuKernelMod() override = default;
  // Caches the dtype of x and the shapes of x / segment_ids / y from the node.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to LaunchKernel<T> based on the dtype of x (float32 or float64).
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  // Typed implementation; T is the value type of both x and y.
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates the numbers of inputs (4) and outputs (1) on the node.
  void CheckParam(const CNodePtr &kernel_node);
  ShapeVector x_shape_;            // shape of input x (the incoming gradient)
  ShapeVector segment_ids_shape_;  // shape of segment_ids
  ShapeVector y_shape_;            // shape of output y
  TypeId x_dtype_{kTypeUnknown};   // dtype of x, set in InitKernel
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_GRAD_CPU_KERNEL_H_

View File

@ -0,0 +1,155 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_segment_mean_with_num_segments_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseSegmentMeanWithNumSegmentsInputsNum = 4;
constexpr size_t kSparseSegmentMeanWithNumSegmentsOutputsNum = 1;
#define ADD_KERNEL(t1, t2, t3, t4, t5) \
KernelAttr() \
.AddInputAttr(kNumberType##t1) \
.AddInputAttr(kNumberType##t2) \
.AddInputAttr(kNumberType##t3) \
.AddInputAttr(kNumberType##t4) \
.AddOutputAttr(kNumberType##t5)
} // namespace
void SparseSegmentMeanWithNumSegmentsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  // Bug fix: CheckParam used to run first, i.e. before the null check (it
  // dereferences kernel_node) and before kernel_name_ was assigned, so its error
  // messages reported an empty kernel name. Null-check and name first, then validate.
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  CheckParam(kernel_node);
  // Cache dtypes and shapes used by LaunchKernel.
  x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
  indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
  y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
}
bool SparseSegmentMeanWithNumSegmentsCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                          const std::vector<kernel::AddressPtr> &,
                                                          const std::vector<kernel::AddressPtr> &outputs) {
  // Dispatch on the pair (x dtype, indices dtype): x is float16/float32/float64,
  // indices are int32 or int64; any other x dtype is rejected.
  const bool use_int32 = (indices_dtype_ == kNumberTypeInt32);
  if (x_dtype_ == kNumberTypeFloat16) {
    if (use_int32) {
      LaunchKernel<float16, int32_t>(inputs, outputs);
    } else {
      LaunchKernel<float16, int64_t>(inputs, outputs);
    }
  } else if (x_dtype_ == kNumberTypeFloat32) {
    if (use_int32) {
      LaunchKernel<float, int32_t>(inputs, outputs);
    } else {
      LaunchKernel<float, int64_t>(inputs, outputs);
    }
  } else if (x_dtype_ == kNumberTypeFloat64) {
    if (use_int32) {
      LaunchKernel<double, int32_t>(inputs, outputs);
    } else {
      LaunchKernel<double, int64_t>(inputs, outputs);
    }
  } else {
    MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
                            << " which is not supported.";
  }
  return true;
}
template <typename T1, typename T2>
void SparseSegmentMeanWithNumSegmentsCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                                const std::vector<kernel::AddressPtr> &outputs) {
  constexpr size_t kMultiply = 1;
  // n: elements per row of x; m: number of (indices, segment_ids) entries;
  // num_elements: total size of y.
  size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kMultiply, std::multiplies<int>()) / x_shape_[kIndex0];
  size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kMultiply, std::multiplies<int>());
  size_t num_elements = std::accumulate(y_shape_.begin(), y_shape_.end(), kMultiply, std::multiplies<int>());
  auto x_shape_0 = static_cast<T2>(x_shape_[kIndex0]);
  auto x_addr = reinterpret_cast<T1 *>(inputs[kIndex0]->addr);
  auto indices_addr = reinterpret_cast<T2 *>(inputs[kIndex1]->addr);
  auto segment_ids_addr = reinterpret_cast<T2 *>(inputs[kIndex2]->addr);
  auto num_segments_addr = reinterpret_cast<T2 *>(inputs[kIndex3]->addr);
  auto y_addr = reinterpret_cast<T1 *>(outputs[kIndex0]->addr);
  // y is accumulated into below, so zero it first.
  for (size_t i = 0; i < num_elements; i++) {
    y_addr[i] = static_cast<T1>(0);
  }
  if (m == 0) {
    // Bug fix: with empty segment_ids the old code read segment_ids_addr[m - 1],
    // i.e. segment_ids_addr[SIZE_MAX]. An empty input simply yields all-zero means.
    return;
  }
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] < segment_ids_addr[i - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input segment_ids should be sorted.";
    }
  }
  // segment_ids are sorted, so the first element bounds all negatives and the
  // last element is the largest id.
  if (segment_ids_addr[0] < static_cast<T2>(0)) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input segment_ids must be non-negative.";
  }
  if (segment_ids_addr[m - 1] >= num_segments_addr[0]) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                             << "', num_segments must be bigger than the largest id of segment_ids.";
  }
  for (size_t i = 0; i < m; i++) {
    // Indices select rows of x; reject negatives too (previously unchecked).
    if (indices_addr[i] < static_cast<T2>(0) || indices_addr[i] >= x_shape_0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input indices is out of range of x's first dimension.";
    }
  }
  // Accumulate x rows into their segment's y row; when a segment's run ends,
  // divide that row by the run length to turn the sum into a mean. Segments that
  // receive no ids keep their zero rows.
  int oldindex = -1;
  int countnum = 0;
  for (size_t i = 0; i < m; i++) {
    if (oldindex == segment_ids_addr[i]) {
      countnum++;
    } else {
      if (countnum != 0) {
        for (size_t j = 0; j < n; j++) {
          y_addr[j + oldindex * n] /= static_cast<T1>(countnum);
        }
      }
      countnum = 1;
      oldindex = segment_ids_addr[i];
    }
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] += x_addr[j + indices_addr[i] * n];
    }
  }
  // Finish the last segment, which is never closed inside the loop.
  if (countnum != 0) {
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] /= static_cast<T1>(countnum);
    }
  }
}
void SparseSegmentMeanWithNumSegmentsCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentMeanWithNumSegmentsInputsNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentMeanWithNumSegmentsOutputsNum, kernel_name_);
}
std::vector<KernelAttr> SparseSegmentMeanWithNumSegmentsCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
ADD_KERNEL(Float16, Int32, Int32, Int32, Float16), ADD_KERNEL(Float16, Int64, Int64, Int64, Float16),
ADD_KERNEL(Float32, Int32, Int32, Int32, Float32), ADD_KERNEL(Float32, Int64, Int64, Int64, Float32),
ADD_KERNEL(Float64, Int32, Int32, Int32, Float64), ADD_KERNEL(Float64, Int64, Int64, Int64, Float64)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentMeanWithNumSegments,
SparseSegmentMeanWithNumSegmentsCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
#include <functional>
#include <numeric>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for SparseSegmentMeanWithNumSegments: averages selected rows of x
// (picked by indices) per sorted segment id into an output with num_segments rows;
// segments without ids stay zero. Deprecated (CNode-based) kernel interface.
class SparseSegmentMeanWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseSegmentMeanWithNumSegmentsCpuKernelMod() = default;
  ~SparseSegmentMeanWithNumSegmentsCpuKernelMod() override = default;
  // Caches the dtypes of x/indices and the shapes of x / segment_ids / y.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to LaunchKernel<T1, T2> based on the dtypes of x and indices.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  // Typed implementation; T1 is the value type of x/y, T2 the index type.
  template <typename T1, typename T2>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates the numbers of inputs (4) and outputs (1) on the node.
  void CheckParam(const CNodePtr &kernel_node);
  ShapeVector x_shape_;                 // shape of input x
  ShapeVector segment_ids_shape_;       // shape of segment_ids
  ShapeVector y_shape_;                 // shape of output y
  TypeId x_dtype_{kTypeUnknown};        // dtype of x, set in InitKernel
  TypeId indices_dtype_{kTypeUnknown};  // dtype of indices, set in InitKernel
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_MEAN_WITH_NUM_SGEMENTS_CPU_KERNEL_H_

View File

@ -324,6 +324,8 @@ constexpr auto kCSRDiv = "CSRDiv";
constexpr auto kDenseToDenseSetOperation = "DenseToDenseSetOperation";
constexpr auto kSparseMatrixAdd = "SparseMatrixAdd";
constexpr auto kSparseAdd = "SparseAdd";
constexpr auto kSparseSegmentMeanGrad = "SparseSegmentMeanGrad";
constexpr auto kSparseSegmentMeanWithNumSegments = "SparseSegmentMeanWithNumSegments";
constexpr auto kSparseConcat = "SparseConcat";
constexpr auto kSparseMatrixNNZ = "SparseMatrixNNZ";
constexpr auto kSparseMatrixTranspose = "SparseMatrixTranspose";
@ -994,6 +996,9 @@ GVAR_DEF(PrimitivePtr, kPrimSparseSplit, std::make_shared<Primitive>(kSparseSpli
GVAR_DEF(PrimitivePtr, kPrimDenseToDenseSetOperation, std::make_shared<Primitive>(kDenseToDenseSetOperation));
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixAdd, std::make_shared<Primitive>(kSparseMatrixAdd));
GVAR_DEF(PrimitivePtr, kPrimSparseAdd, std::make_shared<Primitive>(kSparseAdd));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMeanGrad, std::make_shared<Primitive>("SparseSegmentMeanGrad"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMeanWithNumSegments,
std::make_shared<Primitive>("SparseSegmentMeanWithNumSegments"));
GVAR_DEF(PrimitivePtr, kPrimDenseToCSRSparseMatrix, std::make_shared<Primitive>("DenseToCSRSparseMatrix"));
GVAR_DEF(PrimitivePtr, kPrimSparseTensorToCSRSparseMatrix, std::make_shared<Primitive>(kSparseTensorToCSRSparseMatrix));
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor, std::make_shared<Primitive>(kCSRSparseMatrixToSparseTensor));

View File

@ -0,0 +1,106 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/sparse_segment_mean_grad.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of SparseSegmentMeanGrad: x's shape with its first
// dimension replaced by the (compile-time) value of output_dim0, or a dynamic
// placeholder when output_dim0 is not yet known.
abstract::ShapePtr SparseSegmentMeanGradInferShape(const PrimitivePtr &prim,
                                                   const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  constexpr size_t kRankNum0 = 0;
  constexpr size_t kRankNum1 = 1;
  constexpr size_t kShapeNum0 = 0;
  constexpr int64_t kDimNum0 = 0;
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto output_dim0_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  if (x_shape.size() < kRankNum1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank cannot be less than 1.";
  }
  if (output_dim0_shape.size() != kRankNum0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar.";
  }
  // Bug fix: guard before indexing [0] — a rank-0 indices/segment_ids input used to
  // cause an out-of-range vector access below.
  if (indices_shape.size() < kRankNum1 || segment_ids_shape.size() < kRankNum1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids cannot be scalars.";
  }
  if (indices_shape[kShapeNum0] != segment_ids_shape[kShapeNum0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch.";
  }
  if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    // output_dim0 is a known constant: the output keeps x's shape except dim 0.
    auto output_dim0_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(output_dim0_value);
    auto output_dim0_value_ptr = output_dim0_value->BuildValue();
    MS_EXCEPTION_IF_NULL(output_dim0_value_ptr);
    auto output_dim0_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
    if (output_dim0_value_ptr_tensor.empty()) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', output_dim0 must contain a value.";
    }
    // Keep int64 here: shape values are int64 and narrowing to int could overflow.
    int64_t dim_zero = output_dim0_value_ptr_tensor[kShapeNum0];
    if (dim_zero <= kDimNum0) {
      MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!";
    }
    ShapeVector y_shape = x_shape;
    y_shape[kShapeNum0] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  }
  // output_dim0 unknown at compile time: dynamic-rank placeholder shape.
  std::vector<int64_t> output_shape = {-2};
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = {1};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
// Validates input dtypes: x must be float32/float64, while indices, segment_ids
// and output_dim0 must all be int32. The output inherits the dtype of x.
TypePtr SparseSegmentMeanGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto &prim_name = prim->name();
  auto x_type = input_args[kInputIndex0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, {kFloat32, kFloat64}, prim_name);
  std::map<std::string, TypePtr> index_types = {{"indices", input_args[kInputIndex1]->BuildType()},
                                                {"segment_ids", input_args[kInputIndex2]->BuildType()},
                                                {"output_dim0", input_args[kInputIndex3]->BuildType()}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(index_types, {kInt32}, prim_name);
  return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentMeanGrad, BaseOperator);
// Front-end infer entry: checks the argument count, then combines the inferred
// dtype and shape into one abstract value.
AbstractBasePtr SparseSegmentMeanGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  // Exactly four inputs: x, indices, segment_ids, output_dim0.
  const int64_t kExpectedInputNum = kInputIndex4;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kExpectedInputNum, prim->name());
  auto inferred_type = SparseSegmentMeanGradInferType(prim, input_args);
  auto inferred_shape = SparseSegmentMeanGradInferShape(prim, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentMeanGrad, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentMeanGrad, prim::kPrimSparseSegmentMeanGrad, SparseSegmentMeanGradInfer,
nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSegmentMeanGrad = "SparseSegmentMeanGrad";
/// \brief Computes the gradient of SparseSegmentMean.
/// Inputs: x (incoming gradient), indices, segment_ids, output_dim0; output: y.
class MIND_API SparseSegmentMeanGrad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseSegmentMeanGrad);
  /// \brief Constructor. Registers the operator's input and output names.
  SparseSegmentMeanGrad() : BaseOperator(kNameSparseSegmentMeanGrad) {
    InitIOName({"x", "indices", "segment_ids", "output_dim0"}, {"y"});
  }
};
abstract::AbstractBasePtr SparseSegmentMeanGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSparseSegmentMeanGradPtr = std::shared_ptr<SparseSegmentMeanGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_GRAD_H_

View File

@ -0,0 +1,117 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "ops/sparse_segment_mean_with_num_segments.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of SparseSegmentMeanWithNumSegments: x's shape with its
// first dimension replaced by the (compile-time) value of num_segments, or a
// dynamic placeholder when num_segments is not yet known.
abstract::ShapePtr SparseSegmentMeanWithNumSegmentsInferShape(const PrimitivePtr &prim,
                                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  constexpr size_t kRankOne = 1;
  constexpr size_t kDimOne = 1;
  constexpr size_t kShapeZero = 0;
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto num_segments_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  if (indices_shape.size() != kRankOne) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1.";
  }
  if (segment_ids_shape.size() != kRankOne) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1.";
  }
  if (x_shape.size() < kRankOne) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot be less than 1.";
  }
  if (indices_shape[kShapeZero] != segment_ids_shape[kShapeZero]) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", indices and segment_ids's ranks mismatch.";
  }
  if (num_segments_shape.size() > kRankOne) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of num_segments should be 0 or 1.";
  }
  if (num_segments_shape.size() == kRankOne && num_segments_shape[kShapeZero] != kDimOne) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1.";
  }
  if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    // num_segments is a known constant: the output keeps x's shape except dim 0.
    auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments_value);
    auto num_segments_value_ptr = num_segments_value->BuildValue();
    MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
    auto num_segments_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim_name);
    // Bug fix: this used to be `size_t`, so a negative num_segments wrapped around
    // to a huge unsigned value, slipped past the `< 1` check below, and produced a
    // bogus output shape. Keep the int64 value.
    int64_t dim_zero = num_segments_value_ptr_tensor.back();
    if (dim_zero < static_cast<int64_t>(kDimOne)) {
      MS_EXCEPTION(ValueError) << "For " << prim_name
                               << ", num_segments must be bigger than the largest id of segment_ids.";
    }
    ShapeVector y_shape = x_shape;
    y_shape[kShapeZero] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  }
  // num_segments unknown at compile time: dynamic-rank placeholder shape.
  std::vector<int64_t> output_shape = {-2};
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = {1};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
// Validates input dtypes: indices, segment_ids and num_segments must share one
// integer type (int32/int64); x must be float16/float32/float64. The output
// inherits the dtype of x.
TypePtr SparseSegmentMeanWithNumSegmentsInferType(const PrimitivePtr &prim,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto &prim_name = prim->name();
  auto x_type = input_args[kInputIndex0]->BuildType();
  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
  std::map<std::string, TypePtr> index_types = {{"indices", input_args[kInputIndex1]->BuildType()},
                                                {"segment_ids", input_args[kInputIndex2]->BuildType()},
                                                {"num_segments", input_args[kInputIndex3]->BuildType()}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(index_types, index_valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, x_valid_types, prim_name);
  return x_type;
}
} // namespace
// Front-end infer entry: combines the inferred dtype and shape into one abstract value.
AbstractBasePtr SparseSegmentMeanWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
                                                      const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto inferred_type = SparseSegmentMeanWithNumSegmentsInferType(prim, input_args);
  auto inferred_shape = SparseSegmentMeanWithNumSegmentsInferShape(prim, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
MIND_API_OPERATOR_IMPL(SparseSegmentMeanWithNumSegments, BaseOperator);
REGISTER_HOST_DEPENDS(kNameSparseSegmentMeanWithNumSegments, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentMeanWithNumSegments, prim::kPrimSparseSegmentMeanWithNumSegments,
SparseSegmentMeanWithNumSegmentsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSegmentMeanWithNumSegments = "SparseSegmentMeanWithNumSegments";
/// \brief Computes the mean along sparse segments of a tensor, but it is allowed to miss id in segment_ids.
/// Refer to Python API @ref mindspore.ops.SparseSegmentMeanWithNumSegments for more details.
class MIND_API SparseSegmentMeanWithNumSegments : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseSegmentMeanWithNumSegments);
  /// \brief Constructor. Registers the input names
  /// (x, indices, segment_ids, num_segments) and the single output name (y).
  SparseSegmentMeanWithNumSegments() : BaseOperator(kNameSparseSegmentMeanWithNumSegments) {
    InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"});
  }
};
abstract::AbstractBasePtr SparseSegmentMeanWithNumSegmentsInfer(
const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSparseSegmentMeanWithNumSegmentsPtr = std::shared_ptr<SparseSegmentMeanWithNumSegments>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_H_

View File

@ -19,6 +19,7 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import SparseToDenseV2
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
from mindspore.common import dtype as mstype
from .. import functional as F
from .. import operations as P
@ -99,3 +100,20 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
return all_d
return bprop
@bprop_getters.register(SparseSegmentMeanWithNumSegments)
def get_bprop_sparse_segment_mean_with_num_segments(self):
    """Grad definition for `SparseSegmentMeanWithNumSegments` operation."""
    mean_grad_op = G.SparseSegmentMeanGrad()
    shape_op = P.Shape()

    def bprop(x, indices, segment_ids, num_segments, out, dout):
        # SparseSegmentMeanGrad only accepts int32 ids and needs x's first
        # dimension as a scalar tensor (output_dim0).
        output_dim0 = F.scalar_to_tensor(shape_op(x)[0], mstype.int32)
        indices = F.cast(indices, mstype.int32)
        segment_ids = F.cast(segment_ids, mstype.int32)
        dx = mean_grad_op(dout, indices, segment_ids, output_dim0)
        # Index-like inputs get zero gradients (zeros_like of the int32-cast ids,
        # matching the original behavior).
        return dx, zeros_like(indices), zeros_like(segment_ids), zeros_like(num_segments)

    return bprop

View File

@ -294,6 +294,8 @@ from .segment_sum import _segment_sum_aicpu
from .sparse_segment_sqrt_n import _sparse_segment_sqrt_n_aicpu
from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu
from .sparse_segment_sqrt_n_with_num_segments import _sparse_segment_sqrt_n_with_num_segments_aicpu
from .sparse_segment_mean_grad import _sparse_segment_mean_grad_aicpu
from .sparse_segment_mean_with_num_segments import _sparse_segment_mean_with_num_segments_aicpu
from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu
from .ctc_loss_v2 import _ctc_loss_v2_aicpu

View File

@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseSegmentMeanGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU kernel registration for the SparseSegmentMeanGrad operator.
# Signature: (dout, indices, segment_ids, output_dim0) -> y.
# The gradient dtype may be float32 or float64; all index-like inputs
# (indices, segment_ids, output_dim0) are int32 only.
sparse_segment_mean_grad_op_info = AiCPURegOp("SparseSegmentMeanGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "dout", "required") \
    .input(1, "indices", "required") \
    .input(2, "segment_ids", "required") \
    .input(3, "output_dim0", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, \
                  DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, \
                  DataType.I32_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(sparse_segment_mean_grad_op_info)
def _sparse_segment_mean_grad_aicpu():
    """SparseSegmentMeanGrad aicpu register"""
    return

View File

@ -0,0 +1,44 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseSegmentMeanWithNumSegments op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU kernel registration for the SparseSegmentMeanWithNumSegments operator.
# Signature: (x, indices, segment_ids, num_segments) -> y.
# Data dtype may be float16/float32/float64; the three index-like inputs
# must share one integer dtype, either all int32 or all int64.
sparse_segment_mean_with_num_segments_op_info = AiCPURegOp("SparseSegmentMeanWithNumSegments") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "indices", "required") \
    .input(2, "segment_ids", "required") \
    .input(3, "num_segments", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, \
                  DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, \
                  DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, \
                  DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, \
                  DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, \
                  DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, \
                  DataType.I64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(sparse_segment_mean_with_num_segments_op_info)
def _sparse_segment_mean_with_num_segments_aicpu():
    """SparseSegmentMeanWithNumSegments aicpu register"""
    return

View File

@ -3153,6 +3153,45 @@ class GridSampler3DGrad(Primitive):
self.add_prim_attr('align_corners', align_corners)
class SparseSegmentMeanGrad(Primitive):
    """
    Compute gradients for the SparseSegmentMean operation.

    Inputs:
        - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanGrad
          (the incoming gradient with respect to the SparseSegmentMean output).
        - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be of type
          int32. Has same rank as `segment_ids`. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be of
          type int32. Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentMean op.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
        TypeError: If the dtype of `x` is not any of the following data types: {float32, float64}.
        TypeError: If the dtype of `indices` is not int32.
        TypeError: If the dtype of `segment_ids` is not int32.
        TypeError: If the dtype of `output_dim0` is not int32.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If rank of `indices` or `segment_ids` is not 1.
        ValueError: If dimension size of `output_dim0` is not 0.
        ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If `indices` is out of range of x's first dimension.

    Supported Platforms:
        ``Ascend`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentMeanGrad"""
        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
class FractionalMaxPoolGrad(Primitive):
"""Computes gradients for FractionalMaxPool operation."""

View File

@ -1085,6 +1085,62 @@ class SparseMatrixNNZ(Primitive):
inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'], outputs=['y'])
class SparseSegmentMeanWithNumSegments(Primitive):
    """
    Compute the mean along sparse segments of a tensor. It is allowed to have missing ids in `segment_ids`;
    rows of the output corresponding to missing ids are filled with zeros.

    Inputs:
        - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanWithNumSegments.
        - **indices** (Tensor) - 1-D Tensor with indices into `x`. Must be one of the following
          types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - 1-D Tensor with indices into the output `y`. Must be one of the
          following types: int32, int64. Values should be sorted and can be repeated. The shape should
          be :math:`(N,)`.
        - **num_segments** (Tensor) - Num_segments indicates the size of the output.
          It should be bigger than the largest id of `segment_ids`.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the value of `num_segments`.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor.
        TypeError: If dtype of `x` is not in [float16, float32, float64].
        TypeError: If dtype of `indices` is not int32 or int64.
        TypeError: If dtype of `segment_ids` and `indices` mismatch.
        TypeError: If dtype of `num_segments` and `indices` mismatch.
        ValueError: If rank of `x` is less than 1.
        ValueError: If rank of `indices` or `segment_ids` is not 1.
        ValueError: If rank of `num_segments` is bigger than 1.
        ValueError: If the number of elements of `num_segments` is not 1.
        ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`.
        ValueError: If `indices` is out of range of x's first dimension.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=ms.float16)
        >>> indices = Tensor([0, 2, 1], dtype=ms.int32)
        >>> segment_ids = Tensor([0, 0, 2], dtype=ms.int32)
        >>> num_segments = Tensor([4], dtype=ms.int32)
        >>> sparse_segment_mean_with_num_segments = ops.SparseSegmentMeanWithNumSegments()
        >>> output = sparse_segment_mean_with_num_segments(x, indices, segment_ids, num_segments)
        >>> print(output)
        [[1. 1. 1. 0.]
         [0. 0. 0. 0.]
         [0. 1. 1. 0.]
         [0. 0. 0. 0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentMeanWithNumSegments"""
        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y'])
class SparseAdd(Primitive):
"""
Computes the sum of a COOTensor and another COOTensor.

View File

@ -146,6 +146,7 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import SparseSparseMinimum
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
from mindspore.ops.operations.other_ops import BlackmanWindow
from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent
@ -2177,6 +2178,13 @@ test_case_math_ops = [
'block': P.Sign(),
'desc_inputs': [[3]],
'desc_bprop': [[3]]}),
('SparseSegmentMeanGrad', {
'block': G.SparseSegmentMeanGrad(),
'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)),
Tensor(np.array([0, 1]).astype(np.int32)),
Tensor(np.array([0, 1]).astype(np.int32)),
Tensor(np.array(4).astype(np.int32))],
'skip': ['backward']}),
('Round', {
'block': P.Round(),
'desc_inputs': [[3]],
@ -4426,6 +4434,13 @@ test_case_sparse_ops = [
Tensor(np.array([1, 1]), mstype.int64),
Tensor(np.array([[1, 2], [3, 4]]), mstype.int64)],
'skip': ['backward']}),
('SparseSegmentMeanWithNumSegments', {
'block': SparseSegmentMeanWithNumSegments(),
'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)),
Tensor(np.array([0, 1]).astype(np.int32)),
Tensor(np.array([0, 1]).astype(np.int32)),
Tensor(np.array([2]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32))]}),
('SparseTensorDenseAdd', {
'block': SparseTensorDenseAdd(),
'desc_inputs': [Tensor([[0]], mstype.int32),