!49953 CPU后端算子Size
Merge pull request !49953 from shenwei41/size_operation
This commit is contained in:
commit
80581893ba
|
@ -0,0 +1,101 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "plugin/device/cpu/kernel/size_cpu_kernel.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <tuple>
|
||||||
|
#include "include/common/thread_pool.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
const size_t kSizeInputsNum = 1;
|
||||||
|
const size_t kSizeOutputsNum = 1;
|
||||||
|
}; // namespace
|
||||||
|
bool SizeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
kernel_name_ = base_operator->name();
|
||||||
|
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto is_match = MatchKernelAttr(tensor_attr, GetOpSupport()).first;
|
||||||
|
if (!is_match) {
|
||||||
|
MS_LOG_ERROR << "Can not match kernel based on given attr!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Resize(base_operator, inputs, outputs) == KRET_RESIZE_FAILED) {
|
||||||
|
MS_LOG_ERROR << "Resize failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int SizeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
|
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
auto shape_vector = inputs[kIndex0]->GetShapeVector();
|
||||||
|
int64_t elements = 1;
|
||||||
|
for (size_t i = 0; i < shape_vector.size(); i++) {
|
||||||
|
elements *= shape_vector[i];
|
||||||
|
}
|
||||||
|
input_elements = elements;
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SizeCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSizeInputsNum, kernel_name_);
|
||||||
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSizeOutputsNum, kernel_name_);
|
||||||
|
auto output_data = reinterpret_cast<int32_t *>(outputs[kIndex0]->addr);
|
||||||
|
MS_EXCEPTION_IF_NULL(output_data);
|
||||||
|
output_data[kIndex0] = input_elements;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<KernelAttr> SizeCpuKernelMod::GetOpSupport() {
|
||||||
|
return {
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeComplex).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt4).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeGLUInt).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Size, SizeCpuKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <limits>
|
||||||
|
#include <tuple>
|
||||||
|
#include <map>
|
||||||
|
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||||
|
#include "plugin/factory/ms_factory.h"
|
||||||
|
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class SizeCpuKernelMod : public NativeCpuKernelMod {
|
||||||
|
public:
|
||||||
|
SizeCpuKernelMod() = default;
|
||||||
|
~SizeCpuKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t input_elements;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_
|
|
@ -15,13 +15,68 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/size.h"
|
#include "ops/size.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "ops/op_utils.h"
|
||||||
#include "mindapi/src/helper.h"
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
constexpr int64_t input_num = 1;
|
||||||
|
} // namespace
|
||||||
|
class SizeInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||||
|
return abstract::kNoShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||||
|
TypePtr res = kInt64;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||||
|
auto input_type = input_args[0]->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_type);
|
||||||
|
if (!input_type->isa<TensorType>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||||
|
<< "', input must be a Tensor, but got: " << input_type->ToString() << ".";
|
||||||
|
}
|
||||||
|
auto input_shape_ptr = input_args[0]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_shape_ptr);
|
||||||
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_shape_ptr);
|
||||||
|
auto input_shape = shape_map[kShape];
|
||||||
|
if (IsDynamicRank(input_shape) || IsDynamicShape(input_shape)) {
|
||||||
|
return kAnyValue;
|
||||||
|
}
|
||||||
|
size_t elements = 1;
|
||||||
|
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||||
|
elements *= input_shape[i];
|
||||||
|
}
|
||||||
|
auto elements_value = SizeToLong(elements);
|
||||||
|
ValuePtr res = MakeValue(elements_value);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
auto type = InferType(primitive, input_args);
|
||||||
|
auto shape = InferShape(primitive, input_args);
|
||||||
|
auto value = InferValue(primitive, input_args);
|
||||||
|
auto res = MakeAbstract(shape, type);
|
||||||
|
res->set_value(value);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
};
|
||||||
MIND_API_OPERATOR_IMPL(Size, BaseOperator);
|
MIND_API_OPERATOR_IMPL(Size, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_C(kNameSize, Size);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(Size, prim::kPrimSize, SizeInfer, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
"""Operators for array."""
|
"""Operators for array."""
|
||||||
import copy
|
import copy
|
||||||
import functools
|
|
||||||
import itertools
|
import itertools
|
||||||
import numbers
|
import numbers
|
||||||
|
|
||||||
|
@ -1197,7 +1196,7 @@ class Rank(Primitive):
|
||||||
return len(x.shape)
|
return len(x.shape)
|
||||||
|
|
||||||
|
|
||||||
class Size(PrimitiveWithInfer):
|
class Size(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
|
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
|
||||||
Tensor.
|
Tensor.
|
||||||
|
@ -1219,19 +1218,6 @@ class Size(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Size"""
|
"""Initialize Size"""
|
||||||
|
|
||||||
def __infer__(self, x):
|
|
||||||
size = 1
|
|
||||||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
|
||||||
shp = x['shape']
|
|
||||||
if not shp:
|
|
||||||
size = 0
|
|
||||||
else:
|
|
||||||
size = functools.reduce(lambda x, y: x * y, x['shape'])
|
|
||||||
out = {'shape': None,
|
|
||||||
'dtype': mstype.int64,
|
|
||||||
'value': size}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixDiagV3(Primitive):
|
class MatrixDiagV3(Primitive):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.ops = P.Size()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.ops(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_size_1_dimension(mode):
|
||||||
|
"""
|
||||||
|
Feature: test pynative mode and graph mode
|
||||||
|
Description: Test 1-D Tensor
|
||||||
|
Expectation: the result match to expected value
|
||||||
|
"""
|
||||||
|
np_array = np.array([2, 3, 4]).astype(np.int32)
|
||||||
|
input_x = Tensor(np_array)
|
||||||
|
expect = 3
|
||||||
|
net = Net()
|
||||||
|
out = net(input_x)
|
||||||
|
assert out == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_size_2_dimension(mode):
|
||||||
|
"""
|
||||||
|
Feature: test pynative mode and graph mode
|
||||||
|
Description: Test 2-D Tensor
|
||||||
|
Expectation: the result match to expected value
|
||||||
|
"""
|
||||||
|
np_array = np.array([[2, 2], [2, 2], [3, 3]]).astype(np.int32)
|
||||||
|
input_x = Tensor(np_array)
|
||||||
|
expect = 6
|
||||||
|
net = Net()
|
||||||
|
out = net(input_x)
|
||||||
|
assert out == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_size_3_dimension(mode):
|
||||||
|
"""
|
||||||
|
Feature: test pynative mode and graph mode
|
||||||
|
Description: Test 3-D Tensor
|
||||||
|
Expectation: the result match to expected value
|
||||||
|
"""
|
||||||
|
np_array = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]], [[5, 5], [6, 6]]]).astype(np.int32)
|
||||||
|
input_x = Tensor(np_array)
|
||||||
|
expect = 12
|
||||||
|
net = Net()
|
||||||
|
out = net(input_x)
|
||||||
|
assert out == expect
|
Loading…
Reference in New Issue