add CountNonZero
This commit is contained in:
parent 810019e13c
commit 41742595c3
@@ -0,0 +1,256 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/count_nonzero_cpu_kernel.h"

#include <string>
#include <vector>
#include <complex>
#include <memory>
#include <map>
#include <algorithm>
#include <utility>
#include <numeric>

#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/mul_fp32.h"

namespace mindspore {
namespace kernel {
namespace {
const size_t kCountNonZeroInputsNum = 1;
const size_t kCountNonZeroOutputsNum = 1;

std::vector<int64_t> cnz_dims;
std::vector<int64_t> cnz_transposed_shape;
int64_t cnz_stride;

using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

// Class def of ParallelIterator.
class ParallelIterator {
 public:
  ParallelIterator(const std::vector<int64_t> &transposed_shape, const std::vector<int64_t> &dims,
                   const std::vector<int64_t> &input_shape);
  ~ParallelIterator() = default;
  void Next();
  void Set(int64_t pos);
  inline int64_t Get() const { return _pos; }

 private:
  int64_t _dimension{0};
  std::vector<int64_t> _coord;
  std::vector<int64_t> _shape;
  std::vector<int64_t> _strides;
  std::vector<int64_t> _back_strides;
  std::vector<int64_t> _dims;
  int64_t _pos{0};
};

ParallelIterator::ParallelIterator(const std::vector<int64_t> &transposed_shape, const std::vector<int64_t> &dims,
                                   const std::vector<int64_t> &input_shape)
    : _dimension(transposed_shape.size()),
      _coord(transposed_shape.size(), 0),
      _shape(transposed_shape),
      _strides(transposed_shape.size(), 1),
      _back_strides(transposed_shape.size(), 1),
      _dims(dims),
      _pos(0) {
  std::vector<int64_t> strides(_dimension, 1);
  for (int64_t i = _dimension - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * input_shape[i + 1];
  }
  for (int64_t i = _dimension - 1; i >= 0; --i) {
    _strides[i] = strides[_dims[i]];
    _back_strides[i] = (_shape[i] - 1) * _strides[i];
  }
}
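
// Set() decomposes a linear index over the transposed shape into per-dimension
// coordinates and accumulates the corresponding flat offset into the original
// (untransposed) input buffer.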
void ParallelIterator::Set(int64_t pos) {
  for (int64_t i = _dimension - 1; i >= 0 && pos != 0; --i) {
    _coord[i] = pos % _shape[i];
    _pos += _coord[i] * _strides[i];
    pos /= _shape[i];
  }
}
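
// Next() advances one element in transposed order, odometer style: a dimension
// that wraps back to zero subtracts its precomputed back-stride, so the flat
// offset never has to be recomputed from scratch.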
void ParallelIterator::Next() {
  for (int64_t i = _dimension - 1; i >= 0; --i) {
    if (_coord[i] + 1 == _shape[i]) {
      _coord[i] = 0;
      _pos -= _back_strides[i];
    } else {
      _coord[i]++;
      _pos += _strides[i];
      break;
    }
  }
}
}  // namespace
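
// Tag dispatch: is_complex_t selects the matching IsNonZero overload at compile
// time, so complex inputs are tested on both their real and imaginary parts.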
template <class T>
struct is_complex_t : std::false_type {};
template <class T>
struct is_complex_t<std::complex<T>> : std::true_type {};

template <class T>
int64_t IsNonZero(T val, std::true_type) {
  return val.real() != 0 || val.imag() != 0 ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
}
template <class T>
int64_t IsNonZero(T val, std::false_type) {
  return val != static_cast<T>(0) ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
}

bool CountNonZeroCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  x_shape_ = inputs[0]->GetShapeVector();
  y_shape_ = outputs[0]->GetShapeVector();

  int64_t input_rank = x_shape_.size();
  std::vector<int64_t> dims = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dims"));

  if (dims.empty()) {
    for (int64_t i = 0; i < input_rank; ++i) {
      dims.push_back(i);
    }
  }
  // Map negative dims from [-x_rank, x_rank) into [0, x_rank), then sort and deduplicate.
  std::for_each(dims.begin(), dims.end(), [input_rank](auto &dim) { dim = dim < 0 ? dim + input_rank : dim; });
  std::sort(dims.begin(), dims.end());
  dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
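
  // Partition the axes so that kept (non-reduced) axes come first and the reduced
  // axes are moved to the back; stride_ accumulates the product of the reduced
  // dimension sizes, i.e. how many input elements fold into each output element.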
  int64_t stride_ = static_cast<int64_t>(1);
  std::vector<int64_t> axes_(input_rank);
  int64_t j = static_cast<int64_t>(0), k = static_cast<int64_t>(0);
  for (int64_t i = 0; i < input_rank; i++) {
    if (j == static_cast<int64_t>(dims.size()) || i != dims[j]) {
      axes_[k] = i;
      ++k;
    } else {
      stride_ *= x_shape_[i];
      ++j;
    }
  }

  for (auto &dim : dims) {
    axes_[k] = dim;
    ++k;
  }
  // Calculate transposed_shape using axes.
  // For example, if input_shape = (3, 4, 5, 6, 7), axes = [0, 2, 4, 1, 3],
  // then transposed_shape = (3, 5, 7) + (4, 6).
  std::vector<int64_t> transposed_shape_(input_rank);
  for (int64_t i = 0; i < input_rank; ++i) {
    transposed_shape_[i] = x_shape_[axes_[i]];
  }
  // Publish to the file-scope variables read by LaunchKernel.
  cnz_stride = stride_;
  cnz_transposed_shape = transposed_shape_;
  cnz_dims = axes_;

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

template <typename T>
bool CountNonZeroCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCountNonZeroInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCountNonZeroOutputsNum, kernel_name_);
  auto *x = reinterpret_cast<T *>(inputs[0]->addr);
  auto *y = reinterpret_cast<int64_t *>(outputs[0]->addr);
  auto input_shape = x_shape_;
  int64_t input_nums = static_cast<int64_t>(inputs[0]->size / sizeof(T));
  int64_t data_nums = static_cast<int64_t>(outputs[0]->size / sizeof(int64_t));

  if (y_shape_.size() == 0) {
    (void)y_shape_.insert(y_shape_.begin(), 1);
  }
  auto output_size = SizeOf(y_shape_);
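
  // Two execution paths: a full reduction writes a single scalar and parallelizes over
  // the input elements, while a partial reduction parallelizes over output elements,
  // each of which sums IsNonZero over cnz_stride consecutive positions in transposed
  // order. Note that the scalar path resets and accumulates into y[0] inside every
  // shard, which is only correct when the launcher hands the whole range to one shard.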
  auto count_nonzero_scalar_shard = [&](int64_t start, int64_t end) {
    y[0] = static_cast<int64_t>(0);
    for (int64_t i = start; i < end; ++i) {
      y[0] += IsNonZero<T>(x[i], is_complex_t<T>{});
    }
  };

  auto count_nonzero_shard = [&](int64_t start, int64_t end) {
    ParallelIterator iter(cnz_transposed_shape, cnz_dims, input_shape);
    iter.Set(start * cnz_stride);
    for (int64_t i = start; i < end; ++i) {
      int64_t reduce_initial = static_cast<int64_t>(0);
      for (int64_t j = 0; j < cnz_stride; ++j) {
        reduce_initial += IsNonZero<T>(x[iter.Get()], is_complex_t<T>{});
        iter.Next();
      }
      y[i] = reduce_initial;
    }
  };
  if (data_nums == 1) {
    ParallelLaunchAutoSearch(count_nonzero_scalar_shard, input_nums, this, &parallel_search_info_);
  } else {
    ParallelLaunchAutoSearch(count_nonzero_shard, output_size, this, &parallel_search_info_);
  }
  return true;
}

std::vector<std::pair<KernelAttr, CountNonZeroCpuKernelMod::CountNonZeroLaunchFunc>>
  CountNonZeroCpuKernelMod::func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<float16>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<double>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<uint8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<uint16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<uint32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<uint64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<complex64>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64),
     &CountNonZeroCpuKernelMod::LaunchKernel<complex128>}};

std::vector<KernelAttr> CountNonZeroCpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, CountNonZeroLaunchFunc> &pair) { return pair.first; });

  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CountNonZero, CountNonZeroCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore

@@ -0,0 +1,65 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_COUNT_NONZERO_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_COUNT_NONZERO_CPU_KERNEL_H_

#include <memory>
#include <vector>
#include <iostream>
#include <string>
#include <map>
#include <utility>

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/nnacl/arithmetic.h"

namespace mindspore {
namespace kernel {
class CountNonZeroCpuKernelMod : public NativeCpuKernelMod {
 public:
  CountNonZeroCpuKernelMod() = default;
  ~CountNonZeroCpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
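
  // Init() matches the runtime dtype against func_list_ and caches the matching
  // LaunchKernel<T> instantiation in kernel_func_, so Launch() is one indirect call.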
  using CountNonZeroLaunchFunc = std::function<bool(CountNonZeroCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                                    const std::vector<kernel::AddressPtr> &)>;

  static std::vector<std::pair<KernelAttr, CountNonZeroLaunchFunc>> func_list_;
  CountNonZeroLaunchFunc kernel_func_;
  float value_;
  ShapeVector x_shape_;
  ShapeVector y_shape_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_COUNT_NONZERO_CPU_KERNEL_H_

@@ -156,6 +156,7 @@ constexpr auto kSelfAdjointEig = "SelfAdjointEig";

// Arrays
constexpr auto kLeftShift = "LeftShift";
constexpr auto kCountNonZero = "CountNonZero";
constexpr auto kFillDiagonal = "FillDiagonal";
constexpr auto kSegmentMax = "SegmentMax";
constexpr auto kSegmentSum = "SegmentSum";

@@ -536,6 +537,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
GVAR_DEF(PrimitivePtr, kPrimParallelConcat, std::make_shared<Primitive>(kParallelConcat));
GVAR_DEF(PrimitivePtr, kPrimCountNonZero, std::make_shared<Primitive>("CountNonZero"));
GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
GVAR_DEF(PrimitivePtr, kPrimSqueezeV3, std::make_shared<Primitive>("SqueezeV3"));

@@ -0,0 +1,126 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <set>
#include <map>
#include <string>

#include "ops/count_nonzero.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
namespace {
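// The "dims" attribute may arrive as a single int or as a tuple/list of ints;
// normalize both forms into a std::vector<int64_t>.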
std::vector<int64_t> CheckAttrIntOrTuple(const ValuePtr &attr) {
  std::vector<int64_t> result{};
  MS_EXCEPTION_IF_NULL(attr);
  if (attr->isa<ValueTuple>() || attr->isa<ValueList>()) {
    result = GetValue<std::vector<int64_t>>(attr);
  } else {
    auto attr_val = GetValue<int64_t>(attr);
    (void)result.insert(result.begin(), 1, attr_val);
  }
  return result;
}

abstract::ShapePtr CountNonZeroInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto output_shape = input_shape;
  auto input_rank = SizeToLong(input_shape.size());
  std::vector<int64_t> dims = CheckAttrIntOrTuple(primitive->GetAttr("dims"));

  if (dims.empty()) {
    output_shape = std::vector<int64_t>{};
    return std::make_shared<abstract::Shape>(output_shape);
  }
  for (size_t i = 0; i < dims.size(); ++i) {
    int64_t origin_dims = dims[i];
    if (dims[i] < 0) {
      dims[i] += input_rank;
    }
    std::string dims_name = "dims[" + std::to_string(i) + "]";
    if (input_rank == 0) {
      if (dims[i] != 0 && dims[i] != -1) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the dims[" << i << "] is out of range [-1, 0].";
      }
    } else {
      CheckAndConvertUtils::CheckInRange(dims_name, origin_dims, kIncludeLeft, {-input_rank, input_rank},
                                         "CountNonZero");
    }
  }
  if (input_rank == 0) {
    output_shape = std::vector<int64_t>{};
    primitive->EraseAttr("dims");
    primitive->set_attr("dims", MakeValue(std::vector<int64_t>{}));
    return std::make_shared<abstract::Shape>(output_shape);
  }
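
  // Mark every reduced dimension with the sentinel -1, then erase those entries
  // to produce the reduced output shape.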
  for (size_t i = 0; i < dims.size(); ++i) {
    output_shape[dims[i]] = -1;
  }

  // Erase-remove is equivalent to the hand-written erase loop but avoids stepping
  // an iterator back past begin() when the first element is erased.
  (void)output_shape.erase(std::remove(output_shape.begin(), output_shape.end(), -1), output_shape.end());
  std::set<int64_t> dim_set(dims.begin(), dims.end());
  if (dim_set.size() != dims.size()) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the dims contain duplicates.";
  } else {
    std::vector<int64_t> dims_processed(dim_set.begin(), dim_set.end());
    primitive->EraseAttr("dims");
    primitive->set_attr("dims", MakeValue(dims_processed));
  }

  return std::make_shared<abstract::Shape>(output_shape);
}

TypePtr CountNonZeroInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  TypePtr input_x_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(input_x_type);
  const std::set<TypePtr> valid_types = {kInt8,   kInt16,   kInt32,   kInt64,   kUInt8,     kUInt16,    kUInt32,
                                         kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, prim->name());
  auto y_type = std::make_shared<TensorType>(kInt64);
  return y_type;
}
}  // namespace

AbstractBasePtr CountNonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 1;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto infer_shape = CountNonZeroInferShape(primitive, input_args);
  auto infer_type = CountNonZeroInferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(CountNonZero, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CountNonZero, prim::kPrimCountNonZero, CountNonZeroInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore

@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_COUNT_NONZERO_H_
#define MINDSPORE_CORE_OPS_COUNT_NONZERO_H_
#include <memory>
#include <vector>

#include "abstract/abstract_value.h"
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameCountNonZero = "CountNonZero";
class MIND_API CountNonZero : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(CountNonZero);
  CountNonZero() : BaseOperator(kNameCountNonZero) { InitIOName({"x"}, {"y"}); }
};

abstract::AbstractBasePtr CountNonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_COUNT_NONZERO_H_

@@ -47,6 +47,7 @@ from mindspore.ops.operations.array_ops import Im2Col
from mindspore.ops.operations.array_ops import Col2Im
from mindspore.ops.operations.array_ops import StridedSliceV2
from mindspore.ops.operations.array_ops import MaskedScatter
from mindspore.ops.operations.array_ops import CountNonZero
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
from mindspore.ops.operations.random_ops import LogNormalReverse
from mindspore.ops.operations import _inner_ops as inner

@@ -163,6 +164,16 @@ def get_bprop_masked_scatter(self):
    return bprop


@bprop_getters.register(CountNonZero)
def get_bprop_countnonzero(self):
    """Grad definition for CountNonZero"""
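    # The count of non-zero values is piecewise constant in x, so the gradient
    # with respect to x is zero everywhere.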
    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(Mvlgamma)
def get_bprop_mvlgamma(self):
    """Grad definition for Mvlgamma"""

@@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""CountNonZero op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

count_nonzero_op_info = AiCPURegOp("CountNonZero") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .attr("dims", "listInt") \
    .dtype_format(DataType.I8_Default, DataType.I64_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I64_Default) \
    .get_op_info()


@op_info_register(count_nonzero_op_info)
def _count_nonzero_aicpu():
    """CountNonZero AiCPU register"""
    return

@@ -43,6 +43,7 @@ from mindspore.ops.operations.array_ops import (
    Expand,
    Lstsq,
    Mvlgamma,
    CountNonZero,
)
from mindspore.ops.operations.array_ops import TensorScatterElements
from mindspore.common import Tensor

@@ -5160,6 +5161,45 @@ def mvlgamma(input, p):
    return mvlgamma_op(input)


def count_nonzero(x, dims=None):
    """
    Counts the number of non-zero values in the input tensor along the given dims.
    If no dims are given, all the non-zero values in the tensor are counted.

    Note:
        The value range of `dims` is [-x_dims, x_dims), where `x_dims` is the rank of the input `x`.

    Args:
        x (Tensor): Input to be computed, an N-D Tensor of any dimension. Denote the shape of `x` as
            :math:`(x_1, x_2, ..., x_N)` .
        dims (int, list[int], tuple[int]): The dimensions to count the non-zero values along.
            Default: None, which counts over all dimensions.

    Returns:
        An N-D Tensor holding the number of non-zero elements of the input tensor along `dims`.
        The shape of `x` is reduced along the dimensions given in `dims`. For example, if the shape
        of `x` is (2, 3, 4) and `dims` is [0, 1], the output shape is (4,).

    Raises:
        TypeError: If the data type of `x` is not supported.
        TypeError: If the data type of `dims` is not int.
        ValueError: If any of the values of `dims` is not in [-x_dims, x_dims).

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x = Tensor([[0, 0, 1], [1, 1, 2], [0, 0, 1]], mindspore.int64)
        >>> y = ops.count_nonzero(x, dims=[1])
        >>> print(y)
        [1 3 1]
    """
    dims = [] if dims is None else dims
    count_nonzero_ = CountNonZero(dims)
    return count_nonzero_(x)


__all__ = [
    'unique',
    'unique_with_pad',

@@ -7841,3 +7841,33 @@ class Bincount(Primitive):
    def __init__(self):
        """Initialize Bincount"""
        self.init_prim_io_names(inputs=['array', 'size', 'weights'], outputs=['bins'])


class CountNonZero(Primitive):
    """
    Counts the number of non-zero values in the input tensor along the given dims.
    If no dims are given, all the non-zero values in the tensor are counted.

    Refer to :func:`mindspore.ops.count_nonzero` for more details.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x = Tensor([[0, 0, 1], [1, 1, 2], [0, 0, 1]], dtype=mindspore.int64)
        >>> countnonzero = ops.CountNonZero(dims=[1])
        >>> y = countnonzero(x)
        >>> print(y)
        [1 3 1]
    """

    @prim_attr_register
    def __init__(self, dims=None):
        dims = [] if dims is None else dims
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        validator.check_value_type('dims', dims, [int, list, tuple], "CountNonZero")
        if isinstance(dims, (list, tuple)):
            for i, each in enumerate(dims):
                validator.check_value_type(f'dims[{i}]', each, [int], "CountNonZero")
        self.dims = dims
        self.add_prim_attr("dims", self.dims)

@@ -405,6 +405,17 @@ class MaskedFillFunc(Cell):
        return y


class CountNonZeroFunc(Cell):
    def __init__(self, dims):
        super(CountNonZeroFunc, self).__init__()
        self.countnonzero_ = ops.function.array_func.count_nonzero
        self.dims = dims

    def construct(self, x):
        y = self.countnonzero_(x, self.dims)
        return y


test_case_array_ops = [
    ('CustNet1', {
        'block': CustNet1(),

@@ -476,6 +487,10 @@ test_case_array_ops = [
        Tensor(np.array([[True, True, False]]), mstype.bool_),
        Tensor(5.0, mstype.float32)],
        'desc_bprop': [Tensor(np.array([[3.0, 2.0, 1.0]]), mstype.float32)]}),
    ('CountNonZero', {
        'block': CountNonZeroFunc(dims=()),
        'desc_inputs': [Tensor(np.array([[3.0, 2.0, 1.0]]), mstype.float32)],
        'desc_bprop': [Tensor(np.array([[3.0, 2.0, 1.0]]), mstype.float32)]}),
    ('TensorShapeNet', {'block': TensorShapeNet(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]})
]