!44545 SparseSegmentSqrtNGrad dynamic shape support

Merge pull request !44545 from haozhang/sparse_segment_sqrt_n_grad
i-robot 2022-11-02 17:33:26 +00:00 committed by Gitee
commit 9f43350d65
4 changed files with 96 additions and 18 deletions
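For context, the change lets the CPU backward kernel run when input shapes are only known at execution time. Below is a minimal, assumption-laden sketch of how that path might be exercised from Python; it relies on Cell.set_inputs for dynamic-shape placeholders and ops.GradOperation, is not part of this PR, and the names outside the diff are illustrative only.

import mindspore as ms
import mindspore.ops.operations.sparse_ops as P
from mindspore import nn, ops, context, Tensor

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.SparseSegmentSqrtN()

    def construct(self, x, indices, segment_ids):
        return self.op(x, indices, segment_ids)

class GradNet(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad_op = ops.GradOperation()  # gradient w.r.t. the first input (x)

    def construct(self, x, indices, segment_ids):
        # The backward of SparseSegmentSqrtN is expected to route through the
        # SparseSegmentSqrtNGrad CPU kernel touched by this PR.
        return self.grad_op(self.net)(x, indices, segment_ids)

grad_net = GradNet(Net())
# Mark the row count of x and the lengths of indices/segment_ids as unknown,
# so the kernel is compiled for dynamic shapes.
grad_net.set_inputs(Tensor(shape=[None, 4], dtype=ms.float32),
                    Tensor(shape=[None], dtype=ms.int32),
                    Tensor(shape=[None], dtype=ms.int32))
x = Tensor([[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]], dtype=ms.float32)
indices = Tensor([0, 1, 2], dtype=ms.int32)
segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
dx = grad_net(x, indices, segment_ids)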


@@ -32,22 +32,30 @@ constexpr size_t kSparseSegmentSqrtNGradOutputsNum = 1;
.AddOutputAttr(kNumberType##t5)
} // namespace
void SparseSegmentSqrtNGradCpuKernelMod::CheckParam(const CNodePtr &kernel_node) const {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSqrtNGradInputsNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSqrtNGradOutputsNum, kernel_name_);
bool SparseSegmentSqrtNGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
x_dtype_ = inputs.at(kIndex0)->GetDtype();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSegmentSqrtNGradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSegmentSqrtNGradOutputsNum, kernel_name_);
return true;
}
void SparseSegmentSqrtNGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
output_dim0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex3);
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
int SparseSegmentSqrtNGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
x_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
indices_shape_ = inputs.at(kIndex1)->GetDeviceShapeAdaptively();
segment_ids_shape_ = inputs.at(kIndex2)->GetDeviceShapeAdaptively();
output_dim0_shape_ = inputs.at(kIndex3)->GetDeviceShapeAdaptively();
y_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
return KRET_OK;
}
bool SparseSegmentSqrtNGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
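The hunk above replaces the deprecated CheckParam/InitKernel pair with the new-style contract: Init validates the input/output counts and records the dtype once, while Resize runs before every launch so the kernel can pick up whatever shapes the current step actually has. A rough Python-flavoured sketch of that contract (a toy illustration only, not MindSpore framework code):

class ToyDynShapeKernel:
    """Toy model of the Init/Resize split used by the kernel above."""

    def init(self, dtype, num_inputs, num_outputs):
        # One-time setup: record static info and validate the signature,
        # mirroring CHECK_KERNEL_INPUTS_NUM / CHECK_KERNEL_OUTPUTS_NUM.
        assert num_inputs == 4 and num_outputs == 1
        self.dtype = dtype

    def resize(self, input_shapes, output_shapes):
        # Per-launch setup: refresh the shapes, which may differ between
        # steps when the graph runs with dynamic shapes.
        (self.x_shape, self.indices_shape,
         self.segment_ids_shape, self.output_dim0_shape) = input_shapes
        self.y_shape = output_shapes[0]
        return 0  # KRET_OK

kernel = ToyDynShapeKernel()
kernel.init("float32", num_inputs=4, num_outputs=1)
kernel.resize([[3, 4], [3], [3], []], [[3, 4]])   # first step
kernel.resize([[5, 4], [4], [4], []], [[5, 4]])   # later step, new shapes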


@@ -19,18 +19,24 @@
#include <functional>
#include <vector>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SparseSegmentSqrtNGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseSegmentSqrtNGradCpuKernelMod : public NativeCpuKernelMod {
public:
SparseSegmentSqrtNGradCpuKernelMod() = default;
~SparseSegmentSqrtNGradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@@ -41,7 +47,6 @@ class SparseSegmentSqrtNGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
private:
void CheckParam(const CNodePtr &kernel_node) const;
ShapeVector x_shape_;
ShapeVector indices_shape_;
ShapeVector segment_ids_shape_;


@@ -43,7 +43,7 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "tensor x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (!IsDynamicRank(output_dim0_shape) && output_dim0_shape.size() != kInputIndex0) {
if (!IsDynamic(output_dim0_shape) && output_dim0_shape.size() != kInputIndex0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
<< "but got [" << output_dim0_shape.size() << "].";
}


@@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.ops.operations.sparse_ops as P
from mindspore import nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
class NetSparseSegmentSqrtN(nn.Cell):
def __init__(self):
super(NetSparseSegmentSqrtN, self).__init__()
self.op = P.SparseSegmentSqrtN()
def construct(self, x, indices, segment_ids):
return self.op(x, indices, segment_ids)
def sparse_segment_sqrt_n_test(is_dyn_rank):
x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=ms.float32)
indices = Tensor([0, 1, 2], dtype=ms.int32)
segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
tester = TestDynamicGrad(NetSparseSegmentSqrtN())
tester.test_dynamic_grad_net([x, indices, segment_ids], is_dyn_rank)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_sparse_segment_sqrt_n_dyn_shape():
"""
Feature: SparseSegmentSqrtN Grad Dynamic Shape.
Description: Test case of dynamic shape for the SparseSegmentSqrtN grad operator.
Expectation: success.
"""
sparse_segment_sqrt_n_test(False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_sparse_segment_sqrt_n_dyn_rank():
"""
Feature: SparseSegmentSqrtN Grad Dynamic Rank.
Description: Test case of dynamic rank for the SparseSegmentSqrtN grad operator.
Expectation: success.
"""
sparse_segment_sqrt_n_test(True)