!45259 [MS][OP]bincount cpu support dynamic shape

Merge pull request !45259 from mengyuanli/ds_bincount
This commit is contained in:
i-robot 2022-11-22 12:44:27 +00:00 committed by Gitee
commit 0cf56efb16
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 49 additions and 22 deletions

View File

@ -17,22 +17,38 @@
#include "mindspore/core/ops/op_utils.h"
namespace {
const size_t kOutputNum = 1;
const size_t kInputNum = 3;
} // namespace
namespace mindspore {
namespace kernel {
void BincountCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_arr_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
input_size_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
input_weights_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
dt_arr_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
dt_weights_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
output_sizes_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
bool BincountCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
constexpr size_t input_num = 3;
constexpr size_t output_num = 1;
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
dt_arr_ = inputs[kIndex0]->GetDtype();
dt_weights_ = inputs[kIndex2]->GetDtype();
return true;
}
int BincountCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
input_arr_sizes_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
input_size_sizes_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
input_weights_sizes_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
output_sizes_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
return KRET_OK;
}
template <typename T_in, typename T_out>
@ -72,8 +88,6 @@ void BincountCpuKernelMod::SetMap() {
bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
const size_t array_num = SizeOf(input_arr_sizes_);
const size_t weights_num = SizeOf(input_weights_sizes_);
if (array_num != weights_num) {

View File

@ -21,6 +21,7 @@
#include <memory>
#include <map>
#include <string>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "utils/ms_utils.h"
@ -29,11 +30,17 @@
namespace mindspore {
namespace kernel {
class BincountCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class BincountCpuKernelMod : public NativeCpuKernelMod {
public:
BincountCpuKernelMod() = default;
~BincountCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs) override;

View File

@ -30,8 +30,15 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
// support dynamic rank
if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
// support dynamic shape
if (IsDynamic(arr_shape) || IsDynamic(size_shape) || IsDynamic(w_shape)) {
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
return std::make_shared<abstract::Shape>(shape_out);
}
CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name());
auto size_value_ptr = input_args[kInputIndex1]->BuildValue();
@ -45,9 +52,8 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
(void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name());
return std::make_shared<abstract::Shape>(out_shape);
} else {
std::vector<int64_t> out_shape;
(void)out_shape.emplace_back(-1);
return std::make_shared<abstract::Shape>(out_shape);
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
return std::make_shared<abstract::Shape>(shape_out);
}
}
TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {