!44545 SparseSegmentSqrtNGrad动态shape支持
Merge pull request !44545 from haozhang/sparse_segment_sqrt_n_grad
This commit is contained in:
commit
9f43350d65
|
@ -32,22 +32,30 @@ constexpr size_t kSparseSegmentSqrtNGradOutputsNum = 1;
|
|||
.AddOutputAttr(kNumberType##t5)
|
||||
} // namespace
|
||||
|
||||
void SparseSegmentSqrtNGradCpuKernelMod::CheckParam(const CNodePtr &kernel_node) const {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSqrtNGradInputsNum, kernel_name_);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSqrtNGradOutputsNum, kernel_name_);
|
||||
bool SparseSegmentSqrtNGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
x_dtype_ = inputs.at(kIndex0)->GetDtype();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSegmentSqrtNGradInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSegmentSqrtNGradOutputsNum, kernel_name_);
|
||||
return true;
|
||||
}
|
||||
|
||||
void SparseSegmentSqrtNGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
|
||||
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
|
||||
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
|
||||
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
|
||||
output_dim0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex3);
|
||||
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
|
||||
int SparseSegmentSqrtNGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
x_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
indices_shape_ = inputs.at(kIndex1)->GetDeviceShapeAdaptively();
|
||||
segment_ids_shape_ = inputs.at(kIndex2)->GetDeviceShapeAdaptively();
|
||||
output_dim0_shape_ = inputs.at(kIndex3)->GetDeviceShapeAdaptively();
|
||||
y_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool SparseSegmentSqrtNGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
|
|
@ -19,18 +19,24 @@
|
|||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SparseSegmentSqrtNGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class SparseSegmentSqrtNGradCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
SparseSegmentSqrtNGradCpuKernelMod() = default;
|
||||
~SparseSegmentSqrtNGradCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
|
@ -41,7 +47,6 @@ class SparseSegmentSqrtNGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node) const;
|
||||
ShapeVector x_shape_;
|
||||
ShapeVector indices_shape_;
|
||||
ShapeVector segment_ids_shape_;
|
||||
|
|
|
@ -43,7 +43,7 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
|
|||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "tensor x's rank must be greater than 1, but got [" << x_shape.size() << "].";
|
||||
}
|
||||
if (!IsDynamicRank(output_dim0_shape) && output_dim0_shape.size() != kInputIndex0) {
|
||||
if (!IsDynamic(output_dim0_shape) && output_dim0_shape.size() != kInputIndex0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
|
||||
<< "but got [" << output_dim0_shape.size() << "].";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.ops.operations.sparse_ops as P
|
||||
from mindspore import nn, context, Tensor
|
||||
from .test_grad_of_dynamic import TestDynamicGrad
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class NetSparseSegmentSqrtN(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(NetSparseSegmentSqrtN, self).__init__()
|
||||
self.op = P.SparseSegmentSqrtN()
|
||||
|
||||
def construct(self, x, indices, segment_ids):
|
||||
return self.op(x, indices, segment_ids)
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_test(is_dyn_rank):
|
||||
x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=ms.float32)
|
||||
indices = Tensor([0, 1, 2], dtype=ms.int32)
|
||||
segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
|
||||
tester = TestDynamicGrad(NetSparseSegmentSqrtN())
|
||||
tester.test_dynamic_grad_net([x, indices, segment_ids], is_dyn_rank)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_cpu
|
||||
def test_sparse_segment_sqrt_n_dyn_shape():
|
||||
"""
|
||||
Feature: SparseSegmentMean Grad Dynamic Shape.
|
||||
Description: Test case of dynamic shape for SparseSegmentMean grad operator.
|
||||
Expectation: success.
|
||||
"""
|
||||
sparse_segment_sqrt_n_test(False)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_cpu
|
||||
def test_sparse_segment_sqrt_n_dyn_rank():
|
||||
"""
|
||||
Feature: SparseSegmentMean Grad Dynamic Rank.
|
||||
Description: Test case of dynamic rank for SparseSegmentMean grad operator.
|
||||
Expectation: success.
|
||||
"""
|
||||
sparse_segment_sqrt_n_test(True)
|
Loading…
Reference in New Issue