[feat] [assistant] [I4XJGW] Add UniqueConsecutive

This commit is contained in:
XinWang2021 2022-10-31 21:57:10 +08:00
parent 1b7fc0aa16
commit 660b630c8a
7 changed files with 487 additions and 21 deletions

View File

@ -7,7 +7,7 @@ mindspore.ops.unique_consecutive
参数:
- **x** (Tensor) - 输入Tensor。
- **return_idx** (bool, 可选) - 是否返回每个去重元素在输入中所在的连续序列的末尾位置的索引。默认值False。
- **return_idx** (bool, 可选) - 是否返回每个输入中元素映射到输出中位置的索引。默认值False。
- **return_counts** (bool, 可选) - 是否返回每个去重元素在输入所在的连续序列的计数。默认值False。
- **axis** (int, 可选) - 维度。如果为None则对输入进行展平操作。如果指定必须是int32或int64类型。默认值None。
@ -15,9 +15,13 @@ mindspore.ops.unique_consecutive
Tensor或包含Tensor对象的元组 `output``idx``counts` )。
- `output` 为去重后的输出,与 `x` 具有相同的数据类型。
- 如果 `return_idx` 为 True则返回张量 `idx` shape与 `x` 相同,表示每个去重元素在输入中所在的连续序列的末尾位置的索引。
- 如果 `return_idx` 为 True则返回张量 `idx` shape与 `x` 相同,表示每个输入中元素映射到输出中位置的索引。
- 如果 `return_counts` 为 True则返回张量 `counts` ,表示每个去重元素在输入中所在的连续序列的计数。
异常:
- **TypeError** - `x` 不是Tensor。
- **RuntimeError** - `axis` 不在 `[-ndim, ndim-1]` 范围内。
- **TypeError** - `x`的数据类型不支持。
- **TypeError** - `return_idx` 不是bool。
- **TypeError** - `return_counts` 不是bool。
- **TypeError** - `axis` 不是int。
- **ValueError** - `axis` 不在 `[-ndim, ndim-1]` 范围内。

View File

@ -0,0 +1,378 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/unique_consecutive_cpu_kernel.h"
#include <algorithm>
#include <map>
#include <set>
#include <utility>
#include <complex>
#include <functional>
#include "include/common/thread_pool.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/unique_consecutive.h"
#include "mindspore/core/base/base.h"
namespace mindspore {
namespace kernel {
namespace {
// Value check constant
constexpr size_t kUniqueConsecutiveInputsNum = 1;
constexpr size_t kUniqueConsecutiveOutputsNum = 3;
// Attr default value constant
constexpr int64_t kNone = 1000;
template <typename T>
class PositionIterator {
public:
PositionIterator() {}
~PositionIterator() {}
PositionIterator(std::vector<T> stt, std::vector<T> sh) {
if (stt.size() != sh.size()) {
PositionIterator();
} else {
for (unsigned int i = 0; i < sh.size(); i++) {
if (stt[i] >= sh[i]) {
PositionIterator();
}
}
pos_ = stt;
shape_ = sh;
}
}
PositionIterator operator++() {
pos_[shape_.size() - static_cast<size_t>(1)] += 1;
for (size_t i = shape_.size() - static_cast<size_t>(1); i > static_cast<size_t>(0); i--) {
if (pos_[i] / shape_[i] != 0) {
pos_[i - 1] += pos_[i] / shape_[i];
pos_[i] = pos_[i] % shape_[i];
}
}
return *this;
}
bool End() {
if (pos_[0] != shape_[0]) {
return false;
}
return true;
}
std::vector<T> GetPos() { return pos_; }
std::vector<T> GetShape() { return shape_; }
private:
std::vector<T> pos_;
std::vector<T> shape_;
};
template <typename T>
std::vector<T> ConstructStride(std::vector<T> t_shape) {
std::vector<T> t_stride(t_shape.size(), 1);
int initial = 1;
for (size_t i = t_shape.size(); i > 0; i--) {
t_stride[i - 1] = initial;
initial = initial * static_cast<int>(t_shape[i - static_cast<size_t>(1)]);
}
return t_stride;
}
template <typename T>
T MulSum(std::vector<T> v1, std::vector<T> v2) {
T mul_sum = 0;
for (unsigned int i = 0; i < v1.size(); i++) {
mul_sum += v1[i] * v2[i];
}
return mul_sum;
}
template <typename T1>
std::vector<std::vector<T1>> ReshapeInput(std::vector<int64_t> input_shape_, int32_t axis, T1 *x_dataptr) {
int64_t dim0 = input_shape_[static_cast<size_t>(axis)];
std::vector<int64_t> input_stride = ConstructStride<int64_t>(input_shape_);
std::vector<int64_t> v_shape = input_shape_;
v_shape.erase(v_shape.begin() + axis);
std::vector<int64_t> v_start(v_shape.size(), 0);
std::vector<int64_t> v_stride = input_stride;
v_stride.erase(v_stride.begin() + axis);
std::vector<std::vector<T1>> data_;
for (int64_t i = 0; i < dim0; i++) {
std::vector<T1> tmp_v1;
for (PositionIterator<int64_t> mit(v_start, v_shape); !mit.End(); ++mit) {
auto pos = mit.GetPos();
tmp_v1.push_back(
x_dataptr[static_cast<size_t>(MulSum<int64_t>(pos, v_stride) + i * input_stride[static_cast<size_t>(axis)])]);
}
data_.push_back(tmp_v1);
}
return data_;
}
template <typename T1>
void OutputYSet(const std::vector<int64_t> &y_shape_, const std::vector<int64_t> &input_shape_, int32_t axis,
T1 *y_dataptr, std::vector<std::vector<T1>> out_data_) {
std::vector<int64_t> y_stride = ConstructStride<int64_t>(y_shape_);
std::vector<int64_t> y_v_shape = y_shape_;
y_v_shape.erase(y_v_shape.begin() + axis);
std::vector<int64_t> y_v_start(y_v_shape.size(), 0);
std::vector<int64_t> y_v_stride = y_stride;
y_v_stride.erase(y_v_stride.begin() + axis);
std::vector<int64_t> v_shape = input_shape_;
v_shape.erase(v_shape.begin() + axis);
std::vector<int64_t> trans_stride = ConstructStride<int64_t>(v_shape);
int64_t size0 = static_cast<int64_t>(out_data_.size());
for (int64_t i = 0; i < size0; i++) {
auto tmp_v = out_data_[static_cast<size_t>(i)];
for (PositionIterator<int64_t> mit(y_v_start, y_v_shape); !mit.End(); ++mit) {
auto pos = mit.GetPos();
y_dataptr[static_cast<size_t>(MulSum<int64_t>(pos, y_v_stride) + i * y_stride[axis])] =
tmp_v[static_cast<size_t>(MulSum<int64_t>(pos, trans_stride))];
}
}
}
} // namespace
bool UniqueConsecutiveCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
outputs_ = outputs;
auto kernel_ptr = std::dynamic_pointer_cast<ops::UniqueConsecutive>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast UniqueConsecutive ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
// Get attrs from primitive.
auto axis_ptr = base_operator->GetAttr("axis");
return_idx_ = GetValue<bool>(base_operator->GetAttr("return_idx"));
return_counts_ = GetValue<bool>(base_operator->GetAttr("return_counts"));
// Get input shape
if (axis_ptr == nullptr || GetValue<int64_t>(axis_ptr) == kNone) {
axis_ = kNone;
} else {
axis_ = GetValue<int64_t>(axis_ptr);
}
is_need_retrieve_output_shape_ = true;
return true;
}
int UniqueConsecutiveCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
outputs_ = outputs;
auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
input_shape_ = inputs[0]->GetShapeVector();
int64_t input_size = input_shape_.size();
axis_ = axis_ < 0 ? (axis_ + input_size) : axis_;
return ret;
}
template <typename T1, typename T2>
void UniqueConsecutiveCpuKernelMod::UniqueConsecutiveNone(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
// Get the input and output
const T1 *input_x = GetDeviceAddress<T1>(inputs, kIndex0);
T1 *output_y = GetDeviceAddress<T1>(outputs, kIndex0);
T2 *output_idx = GetDeviceAddress<T2>(outputs, kIndex1);
T2 *output_count = GetDeviceAddress<T2>(outputs, kIndex2);
int64_t input_total = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<int64_t>());
if (input_total > 0) {
*output_y = *input_x;
T1 *p = output_y;
T2 *q = output_count;
T2 last = 0;
for (T2 i = 0; i < input_total; i++) {
if (input_x[i] != *p) {
*(++p) = input_x[i];
if (return_counts_) {
*(q++) = i - last;
}
last = i;
}
if (return_idx_) {
output_idx[i] = static_cast<T2>(p - output_y);
}
}
if (return_counts_) {
*q = input_total - last;
}
// Set the shape of output and count, the idx has the same shape of input
output_shape_.push_back((p - output_y) + 1);
if (return_idx_) {
idx_shape_ = input_shape_;
} else {
idx_shape_.clear();
idx_shape_.push_back(0);
}
if (return_counts_) {
count_shape_ = output_shape_;
} else {
count_shape_.clear();
count_shape_.push_back(0);
}
} else {
output_shape_.push_back(0);
if (return_idx_) {
idx_shape_ = input_shape_;
} else {
idx_shape_.clear();
idx_shape_.push_back(0);
}
if (return_counts_) {
count_shape_ = input_shape_;
} else {
count_shape_.clear();
count_shape_.push_back(0);
}
}
}
template <typename T1, typename T2>
void UniqueConsecutiveCpuKernelMod::UniqueConsecutiveDim(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
// Get the inuput and output
T1 *input_x = GetDeviceAddress<T1>(inputs, kIndex0);
T1 *output_y = GetDeviceAddress<T1>(outputs, kIndex0);
T2 *output_idx = GetDeviceAddress<T2>(outputs, kIndex1);
T2 *output_count = GetDeviceAddress<T2>(outputs, kIndex2);
auto num_zero_dims = std::count(input_shape_.begin(), input_shape_.end(), 0);
int64_t dim0 = input_shape_[static_cast<size_t>(axis_)];
// Set the idx shape
if (return_idx_) {
idx_shape_.clear();
idx_shape_.push_back(dim0);
} else {
idx_shape_.clear();
idx_shape_.push_back(0);
}
// Some check
if (dim0 == 0) {
if (num_zero_dims != 1) {
MS_LOG(EXCEPTION)
<< "For 'UniqueConsecutive', the number of zero sized dimensions > 1, so unique cannot be applied.";
} else {
output_shape_.push_back(0);
count_shape_.push_back(0);
return;
}
}
if (num_zero_dims != 0) {
MS_LOG(EXCEPTION) << "For 'UniqueConsecutive', there are 0 sized dimensions, and they aren't selected by 'axis', "
"so unique cannot be applied.";
}
// If the input is 1D, return UniqueConsecutiveNone
if (input_shape_.size() != 1) {
std::vector<std::vector<T1>> data_ = ReshapeInput<T1>(input_shape_, axis_, input_x);
std::vector<std::vector<T1>> out_data_;
out_data_.push_back(data_[0]);
auto p = data_[0];
T2 *q = output_count;
T2 last = 0;
for (size_t i = 0; i < static_cast<size_t>(dim0); i++) {
if (!std::equal(data_[i].begin(), data_[i].end(), p.begin())) {
p = data_[i];
out_data_.push_back(data_[i]);
if (return_counts_) {
*(q++) = static_cast<T2>(i) - last;
}
last = static_cast<T2>(i);
}
if (return_idx_) {
output_idx[i] = static_cast<T2>(static_cast<int32_t>(out_data_.size()) - 1);
}
}
if (return_counts_) {
*q = static_cast<T2>(dim0) - last;
}
output_shape_ = input_shape_;
output_shape_[static_cast<size_t>(axis_)] = static_cast<int64_t>(out_data_.size());
OutputYSet(output_shape_, input_shape_, axis_, output_y, out_data_);
// Set the output and count shape
if (return_counts_) {
count_shape_.clear();
count_shape_.push_back(out_data_.size());
} else {
count_shape_.clear();
count_shape_.push_back(0);
}
} else {
return UniqueConsecutiveNone<T1, T2>(inputs, outputs);
}
}
template <typename T1, typename T2>
bool UniqueConsecutiveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kUniqueConsecutiveInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniqueConsecutiveOutputsNum, kernel_name_);
if (axis_ == kNone) {
UniqueConsecutiveNone<T1, T2>(inputs, outputs);
} else {
UniqueConsecutiveDim<T1, T2>(inputs, outputs);
}
// Update output shape and type
outputs_[kIndex0]->SetShapeVector(output_shape_);
outputs_[kIndex1]->SetShapeVector(idx_shape_);
outputs_[kIndex2]->SetShapeVector(count_shape_);
return true;
}
#define CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
{ \
KernelAttr() \
.AddInputAttr(ms_value_type) \
.AddOutputAttr(ms_value_type) \
.AddOutputAttr(ms_index_type) \
.AddOutputAttr(ms_index_type), \
&UniqueConsecutiveCpuKernelMod::LaunchKernel<value_type, index_type> \
}
using UCKernelRunFunc = UniqueConsecutiveCpuKernelMod::KernelRunFunc;
const std::vector<std::pair<KernelAttr, UCKernelRunFunc>> &UniqueConsecutiveCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, UCKernelRunFunc>> func_list = {
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex<float>),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, int64_t, std::complex<double>),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat16, int64_t, float16),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt8, int64_t, int8_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeUInt8, int64_t, uint8_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeUInt16, int64_t, uint16_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeUInt32, int64_t, uint32_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeUInt64, int64_t, uint64_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeComplex64, int32_t, std::complex<float>),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeComplex128, int32_t, std::complex<double>),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat16, int32_t, float16),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, int32_t, float),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, int32_t, double),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt8, int32_t, int8_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, int32_t, int16_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t, int32_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t, int64_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeUInt8, int32_t, uint8_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeUInt16, int32_t, uint16_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeUInt32, int32_t, uint32_t),
CPU_UNIQUE_CONSECUTIVE_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeUInt64, int32_t, uint64_t)};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, UniqueConsecutive, UniqueConsecutiveCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_CONSECUTIVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_CONSECUTIVE_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class UniqueConsecutiveCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<UniqueConsecutiveCpuKernelMod> {
public:
UniqueConsecutiveCpuKernelMod() = default;
~UniqueConsecutiveCpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T1, typename T2>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs);
template <typename T1, typename T2>
void UniqueConsecutiveDim(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T1, typename T2>
void UniqueConsecutiveNone(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
bool return_idx_;
bool return_counts_;
int64_t axis_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> output_shape_;
std::vector<int64_t> idx_shape_;
std::vector<int64_t> count_shape_;
std::vector<KernelTensorPtr> outputs_ = {};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_CONSECUTIVE_CPU_KERNEL_H_

View File

@ -75,8 +75,8 @@ abstract::BaseShapePtr UniqueConsecutiveInferShape(const PrimitivePtr &primitive
int64_t axis = GetValue<int64_t>(axis_ptr);
int64_t ndims = SizeToLong(input_shape_vec.size());
if (axis >= ndims || axis < -ndims) {
MS_LOG(EXCEPTION) << "For " << op_name << ", the axis must be in the range [-" << ndims << "," << ndims << ")"
<< "but got " << axis << ".";
MS_EXCEPTION(ValueError) << "For " << op_name << ", the axis must be in the range [-" << ndims << "," << ndims
<< "), but got " << axis << ".";
}
if (axis < 0) {
axis = axis + ndims;
@ -121,7 +121,8 @@ abstract::BaseShapePtr UniqueConsecutiveInferShape(const PrimitivePtr &primitive
TypePtr UniqueConsecutiveInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto name = primitive->name();
const std::set valid_types = {kInt32, kInt64, kFloat16, kFloat, kFloat64};
const std::set valid_types = {kComplex64, kComplex128, kFloat16, kFloat, kFloat64, kInt8, kInt16,
kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
auto input_type = CheckAndConvertUtils::CheckTypeValid("input", input_args[0]->BuildType(), valid_types, name);
std::vector<TypePtr> ret_type_vec = {input_type, std::make_shared<TensorType>(kInt32),
std::make_shared<TensorType>(kInt32)};

View File

@ -25,11 +25,19 @@ unique_consecutive_op_info = AiCPURegOp("UniqueConsecutive") \
.attr("return_idx", "bool") \
.attr("return_counts", "bool") \
.attr("axis", "int") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()

View File

@ -51,6 +51,7 @@ from mindspore.common import Tensor
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore._c_expression import Tensor as Tensor_
eye_ = P.Eye()
fill_ = P.Fill()
@ -1000,8 +1001,8 @@ def unique_consecutive(x, return_idx=False, return_counts=False, axis=None):
Args:
x (Tensor): The input tensor.
return_idx (bool, optional): Whether to return the indices of the end position of each element in the
original input list in the returned unique list. Default: False.
return_idx (bool, optional): Whether to return the index of where the element in the original input
maps to the position in the output. Default: False.
return_counts (bool, optional): Whether to return the counts of each unique element. Default: False.
axis (int, optional): The dimension to apply unique. If None, the unique of the flattened input is
returned. If specified, it must be int32 or int64. Default: None.
@ -1016,16 +1017,16 @@ def unique_consecutive(x, return_idx=False, return_counts=False, axis=None):
Raises:
TypeError: If `x` is not a Tensor.
RuntimeError: If `axis` is not in the range of :math:`[-ndim, ndim-1]`.
TypeError: If dtype of `x` is not supported.
TypeError: If `return_idx` is not a bool.
TypeError: If `return_counts` is not a bool.
TypeError: If `axis` is not an int.
ValueError: If `axis` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> x = Tensor(np.array([1, 1, 2, 2, 3, 1, 1, 2]), mstype.int32)
>>> output, idx, counts = ops.unique_consecutive(x, True, True, None)
>>> print(output)
@ -1035,6 +1036,9 @@ def unique_consecutive(x, return_idx=False, return_counts=False, axis=None):
>>> print(counts)
[2 2 1 2 1]
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError("For 'unique_consecutive', 'x' must be Tensor.")
unique_consecutive_op = _get_cache_prim(UniqueConsecutive)(return_idx, return_counts, axis)
output, idx, counts = unique_consecutive_op(x)
if return_idx and return_counts:

View File

@ -920,13 +920,9 @@ class UniqueConsecutive(Primitive):
Refer to :func:`mindspore.ops.unique_consecutive` for more details.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.ops import UniqueConsecutive
>>> x = Tensor(np.array([1, 1, 2, 2, 3, 1, 1, 2]), mstype.int32)
>>> unique_consecutive = UniqueConsecutive(True, True, None)
>>> output, idx, counts = unique_consecutive(x)
@ -940,6 +936,7 @@ class UniqueConsecutive(Primitive):
@prim_attr_register
def __init__(self, return_idx=False, return_counts=False, axis=None):
"""Initialize UniqueConsecutive"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("return_idx", return_idx, [bool], self.name)
validator.check_value_type("return_counts", return_counts, [bool], self.name)