This commit is contained in:
Arielxyx 2022-03-29 15:37:14 +08:00
parent dad943d6f6
commit 8f9226666f
10 changed files with 367 additions and 0 deletions

View File

@ -139,6 +139,24 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const
input_shape_attr_list->add_i(shape);
}
(*node_attr)[attr_name] = input_shape_attr;
} else if (type == "listFloat") {
std::vector<float> attr_value;
auto value_type = value->type();
MS_EXCEPTION_IF_NULL(value_type);
auto value_type_str = value_type->ToString();
if (value_type_str == "float") {
auto data = GetValue<float>(value);
attr_value.push_back(data);
} else {
attr_value = GetValue<std::vector<float>>(value);
}
mindspore::AttrValue input_shape_attr;
mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array();
MS_EXCEPTION_IF_NULL(input_shape_attr_list);
for (const auto shape : attr_value) {
input_shape_attr_list->add_f(shape);
}
(*node_attr)[attr_name] = input_shape_attr;
} else {
MS_LOG(EXCEPTION) << "type: " << type << "not support";
}

View File

@ -0,0 +1,108 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/bucketize_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
const size_t kOutputNum = 1;
const size_t kInputNum = 1;
const size_t kParallelDataNumSameShape = 64 * 1024;
const size_t kParallelDataNumSameShapeMid = 35 * 1024;
} // namespace
void BucketizeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
boundaries_ = common::AnfAlgo::GetNodeAttr<std::vector<float>>(kernel_node, "boundaries");
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
bool BucketizeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (dtype_ != kNumberTypeInt32 && dtype_ != kNumberTypeInt64 && dtype_ != kNumberTypeFloat32 &&
dtype_ != kNumberTypeFloat64) {
MS_LOG(EXCEPTION) << "Input data type must int32 or int64 or float32 or float64, but got data type." << dtype_;
return false;
}
size_t input_sizes = input_shape_.size();
size_t output_sizes = output_shape_.size();
if (input_sizes != output_sizes) {
MS_LOG(EXCEPTION) << "The tensor shape of input need be same with output.";
return false;
}
// BucketizeCompute(inputs, outputs);
switch (dtype_) {
case kNumberTypeInt32:
return BucketizeCompute<int32_t>(inputs, outputs);
case kNumberTypeInt64:
return BucketizeCompute<int64_t>(inputs, outputs);
case kNumberTypeFloat32:
return BucketizeCompute<float>(inputs, outputs);
case kNumberTypeFloat64:
return BucketizeCompute<double>(inputs, outputs);
default:
MS_LOG(ERROR) << "Unsupported data type.";
}
return true;
}
template <typename T>
bool BucketizeCpuKernelMod::BucketizeCompute(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
auto output_data = reinterpret_cast<int32_t *>(outputs[0]->addr);
size_t data_num_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
std::vector<float> boundaries_data = boundaries_;
std::sort(boundaries_data.begin(), boundaries_data.end());
if (data_num_ >= kParallelDataNumSameShape) {
auto sharder_bucketize = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
output_data[i] = first_bigger_it - boundaries_data.begin();
}
};
ParallelLaunchAutoSearch(sharder_bucketize, data_num_, this, &parallel_search_info_);
} else {
for (size_t i = 0; i < data_num_; i++) {
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
output_data[i] = first_bigger_it - boundaries_data.begin();
}
}
return true;
}
std::vector<KernelAttr> BucketizeCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Bucketize, BucketizeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class BucketizeCpuKernelMod : public NativeCpuKernelMod {
public:
BucketizeCpuKernelMod() = default;
~BucketizeCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
bool BucketizeCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
std::vector<float> boundaries_;
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -739,6 +739,7 @@ GVAR_DEF(PrimitivePtr, kPrimInv, std::make_shared<Primitive>("Inv"));
GVAR_DEF(PrimitivePtr, kPrimBitwiseOr, std::make_shared<Primitive>("BitwiseOr"));
GVAR_DEF(PrimitivePtr, kPrimBitwiseAnd, std::make_shared<Primitive>("BitwiseAnd"));
GVAR_DEF(PrimitivePtr, kPrimBitwiseXor, std::make_shared<Primitive>("BitwiseXor"));
GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared<Primitive>("Bucketize"));
GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared<Primitive>("Einsum"));
GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"));

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/bucketize.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr BucketizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto out_shape = x_shape;
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr BucketizeInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input", x_type, common_valid_types, prim_name);
return std::make_shared<TensorType>(kInt32);
}
} // namespace
MIND_API_BASE_IMPL(Bucketize, PrimitiveC, BaseOperator);
AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const size_t input_num = 1;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = BucketizeInferType(primitive, input_args);
auto infer_shape = BucketizeInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Bucketize, prim::kPrimBucketize, BucketizeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_BUCKETIZE_H_
#define MINDSPORE_CORE_OPS_BUCKETIZE_H_
#include <memory>
#include <vector>
#include <algorithm>
#include <set>
#include <map>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBucketize = "Bucketize";
/// \brief Bucketizes 'input' based on 'boundaries'.
/// Refer to Python API @ref mindspore.ops.Bucketize for more details.
class MIND_API Bucketize : public BaseOperator {
public:
/// \brief Constructor.
Bucketize() : BaseOperator(kNameBucketize) { InitIOName({"input"}, {"output"}); }
// /// \brief Destructor.
// ~Bucketize() = default;
MIND_API_BASE_MEMBER(Bucketize);
};
AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_BUCKETIZE_H_

View File

@ -135,4 +135,5 @@ from .priority_replay_buffer import _prb_sample_op_cpu
from .priority_replay_buffer import _prb_update_op_cpu
from .right_shift import _right_shift_aicpu
from .tril import _tril_aicpu
from .bucketize import _bucketize_aicpu
from .triu import _triu_aicpu

View File

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bucketize op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
bucketize_op_info = AiCPURegOp("Bucketize") \
.fusion_type("OPAQUE") \
.attr("boundaries", "listFloat") \
.input(0, "input", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(bucketize_op_info)
def _bucketize_aicpu():
"""Bucketize aicpu register"""
return

View File

@ -1054,6 +1054,51 @@ class ReduceMin(_Reduce):
"""
class Bucketize(Primitive):
"""
Bucketizes 'input' based on 'boundaries'.
Args:
boundaries (list_float): A sorted list of floats gives the boundary of the buckets, and no default value.
Inputs:
- **input** (Tensor) - A tensor containing the search value(s).
Outputs:
Tensor, with the same shape as the input, and data type is int32.
Raises:
TypeError: If `boundaries` is not a listFloat.
TypeError: If `input` is not a Tensor.
Supported Platforms:
``CPU``
Examples:
>>> class Bucketize(nn.Cell):
... def __init__(self, boundaries):
... super().__init__()
... self.bucketize = op.Bucketize(boundaries=boundaries)
... def construct(self, input):
... return self.bucketize(input)
>>> input = Tensor(np.array([[3, 6, 9], [3, 6, 9]]).astype(np.int32))
>>> boundaries = list(np.array([1., 3., 5., 7., 9.]))
>>> net = Bucketize(boundaries)
>>> output = net(input)
>>> print(output)
[[2 3 5]
[2 3 5]]
"""
@prim_attr_register
def __init__(self, boundaries):
"""Initialize Bucketize"""
validator.check_value_type("boundaries", boundaries, [list], self.name)
for index, one_boundaries in enumerate(boundaries):
validator.check_value_type('boundaries[%d]' % index, one_boundaries, [float], self.name)
self.init_prim_io_names(inputs=['input'], outputs=['output'])
class ReduceProd(_Reduce):
"""
Reduces a dimension of a tensor by multiplying all elements in the dimension, by default. And also can

View File

@ -30,6 +30,7 @@ from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations.math_ops import Bucketize
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
@ -1863,6 +1864,10 @@ test_case_math_ops = [
'block': P.Real(),
'desc_inputs': [[2, 2]],
'skip': ['backward']}),
('Bucketize', {
'block': Bucketize(boundaries=[1., 3., 5., 7., 9.]),
'desc_inputs': [Tensor(np.array([[-1, 6, 8], [3, 6, 9]]).astype(np.float))],
'skip': ['backward']}),
]
test_case_nn_ops = [