!45259 [MS][OP]bincount cpu support dynamic shape

Merge pull request !45259 from mengyuanli/ds_bincount
This commit is contained in:
i-robot 2022-11-22 12:44:27 +00:00 committed by Gitee
commit 0cf56efb16
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 49 additions and 22 deletions

View File

@@ -17,22 +17,38 @@
#include "mindspore/core/ops/op_utils.h" #include "mindspore/core/ops/op_utils.h"
namespace {
const size_t kOutputNum = 1;
const size_t kInputNum = 3;
} // namespace
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void BincountCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { bool BincountCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
MS_EXCEPTION_IF_NULL(kernel_node); const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); MS_EXCEPTION_IF_NULL(base_operator);
input_arr_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0); constexpr size_t input_num = 3;
input_size_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1); constexpr size_t output_num = 1;
input_weights_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2); kernel_name_ = base_operator->name();
dt_arr_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
dt_weights_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
output_sizes_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
dt_arr_ = inputs[kIndex0]->GetDtype();
dt_weights_ = inputs[kIndex2]->GetDtype();
return true;
}
int BincountCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
input_arr_sizes_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
input_size_sizes_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
input_weights_sizes_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
output_sizes_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
return KRET_OK;
} }
template <typename T_in, typename T_out> template <typename T_in, typename T_out>
@@ -72,8 +88,6 @@ void BincountCpuKernelMod::SetMap() {
bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
const size_t array_num = SizeOf(input_arr_sizes_); const size_t array_num = SizeOf(input_arr_sizes_);
const size_t weights_num = SizeOf(input_weights_sizes_); const size_t weights_num = SizeOf(input_weights_sizes_);
if (array_num != weights_num) { if (array_num != weights_num) {

View File

@@ -21,6 +21,7 @@
#include <memory> #include <memory>
#include <map> #include <map>
#include <string> #include <string>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h" #include "plugin/factory/ms_factory.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
@@ -29,11 +30,17 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class BincountCpuKernelMod : public DeprecatedNativeCpuKernelMod { class BincountCpuKernelMod : public NativeCpuKernelMod {
public: public:
BincountCpuKernelMod() = default; BincountCpuKernelMod() = default;
~BincountCpuKernelMod() override = default; ~BincountCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;

View File

@@ -30,8 +30,15 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape]; auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape]; auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape]; auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
// support dynamic rank
if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) { if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2}); return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
// support dynamic shape
if (IsDynamic(arr_shape) || IsDynamic(size_shape) || IsDynamic(w_shape)) {
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
return std::make_shared<abstract::Shape>(shape_out);
} }
CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name()); CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name());
auto size_value_ptr = input_args[kInputIndex1]->BuildValue(); auto size_value_ptr = input_args[kInputIndex1]->BuildValue();
@@ -45,9 +52,8 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
(void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name()); (void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name());
return std::make_shared<abstract::Shape>(out_shape); return std::make_shared<abstract::Shape>(out_shape);
} else { } else {
std::vector<int64_t> out_shape; ShapeVector shape_out{abstract::Shape::kShapeDimAny};
(void)out_shape.emplace_back(-1); return std::make_shared<abstract::Shape>(shape_out);
return std::make_shared<abstract::Shape>(out_shape);
} }
} }
TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {