!49953 CPU后端算子Size

Merge pull request !49953 from shenwei41/size_operation
This commit is contained in:
i-robot 2023-03-08 11:40:00 +00:00 committed by Gitee
commit 80581893ba
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 298 additions and 18 deletions

View File

@ -0,0 +1,101 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/size_cpu_kernel.h"
#include <cmath>
#include <functional>
#include <map>
#include <type_traits>
#include <algorithm>
#include <tuple>
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
const size_t kSizeInputsNum = 1;
const size_t kSizeOutputsNum = 1;
}; // namespace
bool SizeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(tensor_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG_ERROR << "Can not match kernel based on given attr!";
return false;
}
if (Resize(base_operator, inputs, outputs) == KRET_RESIZE_FAILED) {
MS_LOG_ERROR << "Resize failed!";
return false;
}
return true;
}
int SizeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
MS_EXCEPTION_IF_NULL(base_operator);
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto shape_vector = inputs[kIndex0]->GetShapeVector();
int64_t elements = 1;
for (size_t i = 0; i < shape_vector.size(); i++) {
elements *= shape_vector[i];
}
input_elements = elements;
return KRET_OK;
}
bool SizeCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSizeInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSizeOutputsNum, kernel_name_);
auto output_data = reinterpret_cast<int32_t *>(outputs[kIndex0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
output_data[kIndex0] = input_elements;
return true;
}
std::vector<KernelAttr> SizeCpuKernelMod::GetOpSupport() {
return {
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeComplex).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt4).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeGLUInt).AddOutputAttr(kNumberTypeInt32),
};
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Size, SizeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_
#include <vector>
#include <string>
#include <limits>
#include <tuple>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
namespace mindspore {
namespace kernel {
class SizeCpuKernelMod : public NativeCpuKernelMod {
public:
SizeCpuKernelMod() = default;
~SizeCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
int32_t input_elements;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_

View File

@ -15,13 +15,68 @@
*/
#include "ops/size.h"
#include "ops/primitive_c.h"
#include "utils/log_adapter.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t input_num = 1;
} // namespace
class SizeInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
TypePtr res = kInt64;
return res;
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (!input_type->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', input must be a Tensor, but got: " << input_type->ToString() << ".";
}
auto input_shape_ptr = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(input_shape_ptr);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_shape_ptr);
auto input_shape = shape_map[kShape];
if (IsDynamicRank(input_shape) || IsDynamicShape(input_shape)) {
return kAnyValue;
}
size_t elements = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
elements *= input_shape[i];
}
auto elements_value = SizeToLong(elements);
ValuePtr res = MakeValue(elements_value);
return res;
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
auto value = InferValue(primitive, input_args);
auto res = MakeAbstract(shape, type);
res->set_value(value);
return res;
}
};
MIND_API_OPERATOR_IMPL(Size, BaseOperator);
REGISTER_PRIMITIVE_C(kNameSize, Size);
REGISTER_PRIMITIVE_OP_INFER_IMPL(Size, prim::kPrimSize, SizeInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -15,7 +15,6 @@
"""Operators for array."""
import copy
import functools
import itertools
import numbers
@ -1197,7 +1196,7 @@ class Rank(Primitive):
return len(x.shape)
class Size(PrimitiveWithInfer):
class Size(Primitive):
r"""
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
Tensor.
@ -1219,19 +1218,6 @@ class Size(PrimitiveWithInfer):
def __init__(self):
"""Initialize Size"""
def __infer__(self, x):
size = 1
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
shp = x['shape']
if not shp:
size = 0
else:
size = functools.reduce(lambda x, y: x * y, x['shape'])
out = {'shape': None,
'dtype': mstype.int64,
'value': size}
return out
class MatrixDiagV3(Primitive):
"""

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.ops = P.Size()
def construct(self, x):
return self.ops(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_size_1_dimension(mode):
"""
Feature: test pynative mode and graph mode
Description: Test 1-D Tensor
Expectation: the result match to expected value
"""
np_array = np.array([2, 3, 4]).astype(np.int32)
input_x = Tensor(np_array)
expect = 3
net = Net()
out = net(input_x)
assert out == expect
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_size_2_dimension(mode):
"""
Feature: test pynative mode and graph mode
Description: Test 2-D Tensor
Expectation: the result match to expected value
"""
np_array = np.array([[2, 2], [2, 2], [3, 3]]).astype(np.int32)
input_x = Tensor(np_array)
expect = 6
net = Net()
out = net(input_x)
assert out == expect
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_size_3_dimension(mode):
"""
Feature: test pynative mode and graph mode
Description: Test 3-D Tensor
Expectation: the result match to expected value
"""
np_array = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]], [[5, 5], [6, 6]]]).astype(np.int32)
input_x = Tensor(np_array)
expect = 12
net = Net()
out = net(input_x)
assert out == expect