ops CountNonZero supports dynamic shape feature

type: feature
reason: add codes to support dynamic shape for CountNonZero.

------

Signed-off-by: wang_ziqi <wangziqi4@huawei.com>
This commit is contained in:
wang_ziqi 2023-02-14 20:31:21 +08:00
parent 6f51d7e82b
commit 0f3e9f4c56
3 changed files with 28 additions and 8 deletions

View File

@ -115,14 +115,9 @@ int64_t IsNonZero(T val, std::false_type) {
return val != static_cast<T>(0) ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
}
bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
x_shape_ = inputs[0]->GetShapeVector();
y_shape_ = outputs[0]->GetShapeVector();
void CountNonZeroCpuKernelMod::ComputeCountParameter(void) {
int64_t input_rank = x_shape_.size();
std::vector<int64_t> dims = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dims"));
std::vector<int64_t> dims = dims_;
if (dims.size() == 0) {
for (int64_t i = 0; i < input_rank; ++i) {
@ -160,7 +155,12 @@ bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
}
// Assign values.
cnz_stride = stride_, cnz_transposed_shape = transposed_shape_, cnz_dims = axes_;
}
bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
dims_ = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dims"));
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
@ -171,6 +171,17 @@ bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
return true;
}
int CountNonZeroCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
x_shape_ = inputs[0]->GetShapeVector();
y_shape_ = outputs[0]->GetShapeVector();
return KRET_OK;
}
template <typename T>
bool CountNonZeroCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
@ -209,6 +220,7 @@ bool CountNonZeroCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
if (data_nums == 1) {
ParallelLaunchAutoSearch(count_nonzero_scalar_shard, input_nums, this, &parallel_search_info_);
} else {
ComputeCountParameter();
ParallelLaunchAutoSearch(count_nonzero_shard, output_size, this, &parallel_search_info_);
}
return true;

View File

@ -43,18 +43,22 @@ class CountNonZeroCpuKernelMod : public NativeCpuKernelMod {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void ComputeCountParameter(void);
using CountNonZeroLaunchFunc = std::function<bool(CountNonZeroCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CountNonZeroLaunchFunc>> func_list_;
CountNonZeroLaunchFunc kernel_func_;
std::vector<int64_t> dims_;
float value_;
ShapeVector x_shape_;
ShapeVector y_shape_;

View File

@ -51,6 +51,10 @@ abstract::ShapePtr CountNonZeroInferShape(const PrimitivePtr &primitive,
auto input_rank = SizeToLong(input_shape.size());
std::vector<int64_t> dims = CheckAttrIntOrTuple(primitive->GetAttr("dims"));
if (IsDynamicRank(input_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
if (dims.size() == 0) {
output_shape = std::vector<int64_t>{};
return std::make_shared<abstract::Shape>(output_shape);