forked from mindspore-Ecosystem/mindspore
ops CountNonZero supports dynamic shape feature
type: feature reason: add code to support dynamic shape for CountNonZero. ------ Signed-off-by: wang_ziqi <wangziqi4@huawei.com>
This commit is contained in:
parent
6f51d7e82b
commit
0f3e9f4c56
|
@ -115,14 +115,9 @@ int64_t IsNonZero(T val, std::false_type) {
|
|||
return val != static_cast<T>(0) ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
|
||||
}
|
||||
|
||||
bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
x_shape_ = inputs[0]->GetShapeVector();
|
||||
y_shape_ = outputs[0]->GetShapeVector();
|
||||
|
||||
void CountNonZeroCpuKernelMod::ComputeCountParameter(void) {
|
||||
int64_t input_rank = x_shape_.size();
|
||||
std::vector<int64_t> dims = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dims"));
|
||||
std::vector<int64_t> dims = dims_;
|
||||
|
||||
if (dims.size() == 0) {
|
||||
for (int64_t i = 0; i < input_rank; ++i) {
|
||||
|
@ -160,7 +155,12 @@ bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
|
|||
}
|
||||
// Assign values.
|
||||
cnz_stride = stride_, cnz_transposed_shape = transposed_shape_, cnz_dims = axes_;
|
||||
}
|
||||
|
||||
bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
dims_ = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dims"));
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
|
@ -171,6 +171,17 @@ bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
|
|||
return true;
|
||||
}
|
||||
|
||||
int CountNonZeroCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
x_shape_ = inputs[0]->GetShapeVector();
|
||||
y_shape_ = outputs[0]->GetShapeVector();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CountNonZeroCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
|
@ -209,6 +220,7 @@ bool CountNonZeroCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
|
|||
if (data_nums == 1) {
|
||||
ParallelLaunchAutoSearch(count_nonzero_scalar_shard, input_nums, this, &parallel_search_info_);
|
||||
} else {
|
||||
ComputeCountParameter();
|
||||
ParallelLaunchAutoSearch(count_nonzero_shard, output_size, this, &parallel_search_info_);
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -43,18 +43,22 @@ class CountNonZeroCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
void ComputeCountParameter(void);
|
||||
using CountNonZeroLaunchFunc = std::function<bool(CountNonZeroCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
static std::vector<std::pair<KernelAttr, CountNonZeroLaunchFunc>> func_list_;
|
||||
CountNonZeroLaunchFunc kernel_func_;
|
||||
std::vector<int64_t> dims_;
|
||||
float value_;
|
||||
ShapeVector x_shape_;
|
||||
ShapeVector y_shape_;
|
||||
|
|
|
@ -51,6 +51,10 @@ abstract::ShapePtr CountNonZeroInferShape(const PrimitivePtr &primitive,
|
|||
auto input_rank = SizeToLong(input_shape.size());
|
||||
std::vector<int64_t> dims = CheckAttrIntOrTuple(primitive->GetAttr("dims"));
|
||||
|
||||
if (IsDynamicRank(input_shape)) {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||
}
|
||||
|
||||
if (dims.size() == 0) {
|
||||
output_shape = std::vector<int64_t>{};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
|
|
Loading…
Reference in New Issue