!45259 [MS][OP]bincount cpu support dynamic shape
Merge pull request !45259 from mengyuanli/ds_bincount
This commit is contained in:
commit
0cf56efb16
|
@ -17,22 +17,38 @@
|
|||
|
||||
#include "mindspore/core/ops/op_utils.h"
|
||||
|
||||
namespace {
|
||||
const size_t kOutputNum = 1;
|
||||
const size_t kInputNum = 3;
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void BincountCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_arr_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
|
||||
input_size_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
|
||||
input_weights_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
|
||||
dt_arr_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
dt_weights_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
|
||||
output_sizes_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
bool BincountCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
constexpr size_t input_num = 3;
|
||||
constexpr size_t output_num = 1;
|
||||
kernel_name_ = base_operator->name();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match.first) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
dt_arr_ = inputs[kIndex0]->GetDtype();
|
||||
dt_weights_ = inputs[kIndex2]->GetDtype();
|
||||
return true;
|
||||
}
|
||||
|
||||
int BincountCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
input_arr_sizes_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||
input_size_sizes_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
|
||||
input_weights_sizes_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
|
||||
output_sizes_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T_in, typename T_out>
|
||||
|
@ -72,8 +88,6 @@ void BincountCpuKernelMod::SetMap() {
|
|||
|
||||
bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
const size_t array_num = SizeOf(input_arr_sizes_);
|
||||
const size_t weights_num = SizeOf(input_weights_sizes_);
|
||||
if (array_num != weights_num) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
@ -29,11 +30,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BincountCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class BincountCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
BincountCpuKernelMod() = default;
|
||||
~BincountCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
|
|
|
@ -30,8 +30,15 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
|
|||
auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
|
||||
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
|
||||
// support dynamic rank
|
||||
if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||
}
|
||||
|
||||
// support dynamic shape
|
||||
if (IsDynamic(arr_shape) || IsDynamic(size_shape) || IsDynamic(w_shape)) {
|
||||
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
|
||||
return std::make_shared<abstract::Shape>(shape_out);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name());
|
||||
auto size_value_ptr = input_args[kInputIndex1]->BuildValue();
|
||||
|
@ -45,9 +52,8 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
|
|||
(void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name());
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
std::vector<int64_t> out_shape;
|
||||
(void)out_shape.emplace_back(-1);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
|
||||
return std::make_shared<abstract::Shape>(shape_out);
|
||||
}
|
||||
}
|
||||
TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
Loading…
Reference in New Issue