diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.cc index 2fda3afcaa5..4b1814e5ad6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.cc @@ -17,22 +17,38 @@ #include "mindspore/core/ops/op_utils.h" -namespace { -const size_t kOutputNum = 1; -const size_t kInputNum = 3; -} // namespace - namespace mindspore { namespace kernel { -void BincountCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - input_arr_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0); - input_size_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1); - input_weights_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2); - dt_arr_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); - dt_weights_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2); - output_sizes_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); +bool BincountCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + constexpr size_t input_num = 3; + constexpr size_t output_num = 1; + kernel_name_ = base_operator->name(); + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + dt_arr_ = inputs[kIndex0]->GetDtype(); + dt_weights_ = inputs[kIndex2]->GetDtype(); + return true; +} + +int BincountCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) { + return ret; + } + input_arr_sizes_ = inputs[kIndex0]->GetDeviceShapeAdaptively(); + input_size_sizes_ = inputs[kIndex1]->GetDeviceShapeAdaptively(); + input_weights_sizes_ = inputs[kIndex2]->GetDeviceShapeAdaptively(); + output_sizes_ = outputs[kIndex0]->GetDeviceShapeAdaptively(); + return KRET_OK; } template @@ -72,8 +88,6 @@ void BincountCpuKernelMod::SetMap() { bool BincountCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspaces, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); const size_t array_num = SizeOf(input_arr_sizes_); const size_t weights_num = SizeOf(input_weights_sizes_); if (array_num != weights_num) { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.h index 1ef477ffdc7..35078a7c8dd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bincount_cpu_kernel.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" #include "utils/ms_utils.h" @@ -29,11 +30,17 @@ namespace mindspore { namespace kernel { -class BincountCpuKernelMod : public DeprecatedNativeCpuKernelMod { +class BincountCpuKernelMod : public NativeCpuKernelMod { public: BincountCpuKernelMod() = default; ~BincountCpuKernelMod() override = default; - void InitKernel(const CNodePtr &kernel_node) override; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + bool Launch(const std::vector &inputs, const std::vector &workspaces, const std::vector &outputs) override; diff --git a/mindspore/core/ops/bincount.cc b/mindspore/core/ops/bincount.cc index 61adf7a11ea..116ae8382e4 100644 --- a/mindspore/core/ops/bincount.cc +++ b/mindspore/core/ops/bincount.cc @@ -30,8 +30,15 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std:: auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape]; auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape]; auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape]; + // support dynamic rank if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) { - return std::make_shared(std::vector{-2}); + return std::make_shared(ShapeVector({abstract::Shape::kShapeRankAny})); + } + + // support dynamic shape + if (IsDynamic(arr_shape) || IsDynamic(size_shape) || IsDynamic(w_shape)) { + ShapeVector shape_out{abstract::Shape::kShapeDimAny}; + return std::make_shared(shape_out); } CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name()); auto size_value_ptr = input_args[kInputIndex1]->BuildValue(); @@ -45,9 +52,8 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std:: (void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name()); return std::make_shared(out_shape); } else { - std::vector out_shape; - (void)out_shape.emplace_back(-1); - return std::make_shared(out_shape); + ShapeVector shape_out{abstract::Shape::kShapeDimAny}; + return std::make_shared(shape_out); } } TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector &input_args) {