forked from mindspore-Ecosystem/mindspore
cpu_0329
This commit is contained in:
parent
dad943d6f6
commit
8f9226666f
|
@ -139,6 +139,24 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const
|
|||
input_shape_attr_list->add_i(shape);
|
||||
}
|
||||
(*node_attr)[attr_name] = input_shape_attr;
|
||||
} else if (type == "listFloat") {
|
||||
std::vector<float> attr_value;
|
||||
auto value_type = value->type();
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
auto value_type_str = value_type->ToString();
|
||||
if (value_type_str == "float") {
|
||||
auto data = GetValue<float>(value);
|
||||
attr_value.push_back(data);
|
||||
} else {
|
||||
attr_value = GetValue<std::vector<float>>(value);
|
||||
}
|
||||
mindspore::AttrValue input_shape_attr;
|
||||
mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_attr_list);
|
||||
for (const auto shape : attr_value) {
|
||||
input_shape_attr_list->add_f(shape);
|
||||
}
|
||||
(*node_attr)[attr_name] = input_shape_attr;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "type: " << type << "not support";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/bucketize_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const size_t kOutputNum = 1;
|
||||
const size_t kInputNum = 1;
|
||||
const size_t kParallelDataNumSameShape = 64 * 1024;
|
||||
const size_t kParallelDataNumSameShapeMid = 35 * 1024;
|
||||
} // namespace
|
||||
|
||||
void BucketizeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
boundaries_ = common::AnfAlgo::GetNodeAttr<std::vector<float>>(kernel_node, "boundaries");
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool BucketizeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /* workspace */,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
if (dtype_ != kNumberTypeInt32 && dtype_ != kNumberTypeInt64 && dtype_ != kNumberTypeFloat32 &&
|
||||
dtype_ != kNumberTypeFloat64) {
|
||||
MS_LOG(EXCEPTION) << "Input data type must int32 or int64 or float32 or float64, but got data type." << dtype_;
|
||||
return false;
|
||||
}
|
||||
size_t input_sizes = input_shape_.size();
|
||||
size_t output_sizes = output_shape_.size();
|
||||
if (input_sizes != output_sizes) {
|
||||
MS_LOG(EXCEPTION) << "The tensor shape of input need be same with output.";
|
||||
return false;
|
||||
}
|
||||
// BucketizeCompute(inputs, outputs);
|
||||
switch (dtype_) {
|
||||
case kNumberTypeInt32:
|
||||
return BucketizeCompute<int32_t>(inputs, outputs);
|
||||
case kNumberTypeInt64:
|
||||
return BucketizeCompute<int64_t>(inputs, outputs);
|
||||
case kNumberTypeFloat32:
|
||||
return BucketizeCompute<float>(inputs, outputs);
|
||||
case kNumberTypeFloat64:
|
||||
return BucketizeCompute<double>(inputs, outputs);
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported data type.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BucketizeCpuKernelMod::BucketizeCompute(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_data = reinterpret_cast<int32_t *>(outputs[0]->addr);
|
||||
size_t data_num_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
|
||||
std::vector<float> boundaries_data = boundaries_;
|
||||
std::sort(boundaries_data.begin(), boundaries_data.end());
|
||||
if (data_num_ >= kParallelDataNumSameShape) {
|
||||
auto sharder_bucketize = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
|
||||
output_data[i] = first_bigger_it - boundaries_data.begin();
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(sharder_bucketize, data_num_, this, ¶llel_search_info_);
|
||||
} else {
|
||||
for (size_t i = 0; i < data_num_; i++) {
|
||||
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
|
||||
output_data[i] = first_bigger_it - boundaries_data.begin();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BucketizeCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Bucketize, BucketizeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BucketizeCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
BucketizeCpuKernelMod() = default;
|
||||
~BucketizeCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
template <typename T>
|
||||
bool BucketizeCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<float> boundaries_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -739,6 +739,7 @@ GVAR_DEF(PrimitivePtr, kPrimInv, std::make_shared<Primitive>("Inv"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimBitwiseOr, std::make_shared<Primitive>("BitwiseOr"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimBitwiseAnd, std::make_shared<Primitive>("BitwiseAnd"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimBitwiseXor, std::make_shared<Primitive>("BitwiseXor"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared<Primitive>("Bucketize"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared<Primitive>("Einsum"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"));
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/bucketize.h"
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr BucketizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto out_shape = x_shape;
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr BucketizeInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input", x_type, common_valid_types, prim_name);
|
||||
return std::make_shared<TensorType>(kInt32);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(Bucketize, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const size_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = BucketizeInferType(primitive, input_args);
|
||||
auto infer_shape = BucketizeInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Bucketize, prim::kPrimBucketize, BucketizeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_BUCKETIZE_H_
|
||||
#define MINDSPORE_CORE_OPS_BUCKETIZE_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBucketize = "Bucketize";
|
||||
/// \brief Bucketizes 'input' based on 'boundaries'.
|
||||
/// Refer to Python API @ref mindspore.ops.Bucketize for more details.
|
||||
class MIND_API Bucketize : public BaseOperator {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Bucketize() : BaseOperator(kNameBucketize) { InitIOName({"input"}, {"output"}); }
|
||||
// /// \brief Destructor.
|
||||
// ~Bucketize() = default;
|
||||
MIND_API_BASE_MEMBER(Bucketize);
|
||||
};
|
||||
|
||||
AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_BUCKETIZE_H_
|
|
@ -135,4 +135,5 @@ from .priority_replay_buffer import _prb_sample_op_cpu
|
|||
from .priority_replay_buffer import _prb_update_op_cpu
|
||||
from .right_shift import _right_shift_aicpu
|
||||
from .tril import _tril_aicpu
|
||||
from .bucketize import _bucketize_aicpu
|
||||
from .triu import _triu_aicpu
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Bucketize op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
bucketize_op_info = AiCPURegOp("Bucketize") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("boundaries", "listFloat") \
|
||||
.input(0, "input", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bucketize_op_info)
|
||||
def _bucketize_aicpu():
|
||||
"""Bucketize aicpu register"""
|
||||
return
|
|
@ -1054,6 +1054,51 @@ class ReduceMin(_Reduce):
|
|||
"""
|
||||
|
||||
|
||||
class Bucketize(Primitive):
|
||||
"""
|
||||
Bucketizes 'input' based on 'boundaries'.
|
||||
|
||||
Args:
|
||||
boundaries (list_float): A sorted list of floats gives the boundary of the buckets, and no default value.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - A tensor containing the search value(s).
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same shape as the input, and data type is int32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `boundaries` is not a listFloat.
|
||||
TypeError: If `input` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> class Bucketize(nn.Cell):
|
||||
... def __init__(self, boundaries):
|
||||
... super().__init__()
|
||||
... self.bucketize = op.Bucketize(boundaries=boundaries)
|
||||
... def construct(self, input):
|
||||
... return self.bucketize(input)
|
||||
>>> input = Tensor(np.array([[3, 6, 9], [3, 6, 9]]).astype(np.int32))
|
||||
>>> boundaries = list(np.array([1., 3., 5., 7., 9.]))
|
||||
>>> net = Bucketize(boundaries)
|
||||
>>> output = net(input)
|
||||
>>> print(output)
|
||||
[[2 3 5]
|
||||
[2 3 5]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, boundaries):
|
||||
"""Initialize Bucketize"""
|
||||
validator.check_value_type("boundaries", boundaries, [list], self.name)
|
||||
for index, one_boundaries in enumerate(boundaries):
|
||||
validator.check_value_type('boundaries[%d]' % index, one_boundaries, [float], self.name)
|
||||
self.init_prim_io_names(inputs=['input'], outputs=['output'])
|
||||
|
||||
|
||||
class ReduceProd(_Reduce):
|
||||
"""
|
||||
Reduces a dimension of a tensor by multiplying all elements in the dimension, by default. And also can
|
||||
|
|
|
@ -30,6 +30,7 @@ from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
|
|||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations.math_ops import Bucketize
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import Tril
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
|
@ -1863,6 +1864,10 @@ test_case_math_ops = [
|
|||
'block': P.Real(),
|
||||
'desc_inputs': [[2, 2]],
|
||||
'skip': ['backward']}),
|
||||
('Bucketize', {
|
||||
'block': Bucketize(boundaries=[1., 3., 5., 7., 9.]),
|
||||
'desc_inputs': [Tensor(np.array([[-1, 6, 8], [3, 6, 9]]).astype(np.float))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
test_case_nn_ops = [
|
||||
|
|
Loading…
Reference in New Issue