!45259 [MS][OP]bincount cpu support dynamic shape
Merge pull request !45259 from mengyuanli/ds_bincount
This commit is contained in:
commit
0cf56efb16
|
@ -17,22 +17,38 @@
|
||||||
|
|
||||||
#include "mindspore/core/ops/op_utils.h"
|
#include "mindspore/core/ops/op_utils.h"
|
||||||
|
|
||||||
namespace {
|
|
||||||
const size_t kOutputNum = 1;
|
|
||||||
const size_t kInputNum = 3;
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void BincountCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
bool BincountCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
input_arr_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
|
constexpr size_t input_num = 3;
|
||||||
input_size_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
|
constexpr size_t output_num = 1;
|
||||||
input_weights_sizes_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
|
kernel_name_ = base_operator->name();
|
||||||
dt_arr_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
|
||||||
dt_weights_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
|
||||||
output_sizes_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||||
|
if (!is_match.first) {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
dt_arr_ = inputs[kIndex0]->GetDtype();
|
||||||
|
dt_weights_ = inputs[kIndex2]->GetDtype();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BincountCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
input_arr_sizes_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||||
|
input_size_sizes_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
|
||||||
|
input_weights_sizes_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
|
||||||
|
output_sizes_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||||
|
return KRET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T_in, typename T_out>
|
template <typename T_in, typename T_out>
|
||||||
|
@ -72,8 +88,6 @@ void BincountCpuKernelMod::SetMap() {
|
||||||
|
|
||||||
bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
bool BincountCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
|
||||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
|
||||||
const size_t array_num = SizeOf(input_arr_sizes_);
|
const size_t array_num = SizeOf(input_arr_sizes_);
|
||||||
const size_t weights_num = SizeOf(input_weights_sizes_);
|
const size_t weights_num = SizeOf(input_weights_sizes_);
|
||||||
if (array_num != weights_num) {
|
if (array_num != weights_num) {
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||||
#include "plugin/factory/ms_factory.h"
|
#include "plugin/factory/ms_factory.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
|
@ -29,11 +30,17 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class BincountCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
class BincountCpuKernelMod : public NativeCpuKernelMod {
|
||||||
public:
|
public:
|
||||||
BincountCpuKernelMod() = default;
|
BincountCpuKernelMod() = default;
|
||||||
~BincountCpuKernelMod() override = default;
|
~BincountCpuKernelMod() override = default;
|
||||||
void InitKernel(const CNodePtr &kernel_node) override;
|
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
|
||||||
|
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
|
|
|
@ -30,8 +30,15 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
|
||||||
auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
|
auto arr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
|
||||||
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
|
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
|
||||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
|
||||||
|
// support dynamic rank
|
||||||
if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) {
|
if (IsDynamicRank(arr_shape) || IsDynamicRank(size_shape) || IsDynamicRank(w_shape)) {
|
||||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// support dynamic shape
|
||||||
|
if (IsDynamic(arr_shape) || IsDynamic(size_shape) || IsDynamic(w_shape)) {
|
||||||
|
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
|
||||||
|
return std::make_shared<abstract::Shape>(shape_out);
|
||||||
}
|
}
|
||||||
CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name());
|
CheckAndConvertUtils::CheckInteger("size", size_shape.size(), kEqual, 0, primitive->name());
|
||||||
auto size_value_ptr = input_args[kInputIndex1]->BuildValue();
|
auto size_value_ptr = input_args[kInputIndex1]->BuildValue();
|
||||||
|
@ -45,9 +52,8 @@ abstract::ShapePtr BincountInferShape(const PrimitivePtr &primitive, const std::
|
||||||
(void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name());
|
(void)CheckAndConvertUtils::CheckPositiveVectorExcludeZero("size", out_shape, primitive->name());
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
} else {
|
} else {
|
||||||
std::vector<int64_t> out_shape;
|
ShapeVector shape_out{abstract::Shape::kShapeDimAny};
|
||||||
(void)out_shape.emplace_back(-1);
|
return std::make_shared<abstract::Shape>(shape_out);
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr BincountInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
|
Loading…
Reference in New Issue