fix unique_with_pad docs and add unique_with_pad gpu op

This commit is contained in:
chaijun 2022-06-22 14:49:27 +08:00 committed by gongdaguo1
parent a624ca11e0
commit a011326430
25 changed files with 606 additions and 17 deletions

View File

@ -9,6 +9,7 @@ mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mi
mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveToEvalImplMap
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetHostDependsMap
mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::OptimizeIRPassLib::OptimizeIRPassLib
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic

View File

@ -1554,7 +1554,7 @@ mindspore.Tensor
对当前一维张量中元素去重返回一维张量中的唯一元素使用pad_num填充和相对索引。
基本操作与unique相同但unique_with_pad多了pad操作。
unique运算符对张量处理后所返回的元组`y``idx``y``idx` 的shape通常会有差别因此为了解决上述情况
unique运算符对张量处理后所返回的元组 `y` `idx` `y``idx` 的shape通常会有差别因此为了解决上述情况
unique_with_pad操作符将用用户指定的 `pad_num` 填充 `y` 张量,使其具有与张量 `idx` 相同的形状。
**参数:**

View File

@ -6,7 +6,7 @@ mindspore.ops.UniqueWithPad
对输入一维张量中元素去重返回一维张量中的唯一元素使用pad_num填充和相对索引。
基本操作与Unique相同但UniqueWithPad多了Pad操作。
Unique运算符处理输入张量 `x` 后所返回的元组(`y``idx``y``idx` 的shape通常会有差别因此为了解决上述情况
Unique运算符处理输入张量 `x` 后所返回的元组( `y` `idx` `y``idx` 的shape通常会有差别因此为了解决上述情况
UniqueWithPad操作符将用用户指定的 `pad_num` 填充 `y` 张量,使其具有与张量 `idx` 相同的形状。
更多参考详见 :func:`mindspore.ops.unique_with_pad`

View File

@ -6,7 +6,7 @@ mindspore.ops.unique_with_pad
对输入一维张量中元素去重返回一维张量中的唯一元素使用pad_num填充和相对索引。
基本操作与unique相同但unique_with_pad多了pad操作。
unique运算符对张量处理后所返回的元组`y``idx``y``idx` 的shape通常会有差别因此为了解决上述情况
unique运算符对张量处理后所返回的元组 `y` `idx` `y``idx` 的shape通常会有差别因此为了解决上述情况
unique_with_pad操作符将用用户指定的 `pad_num` 填充 `y` 张量,使其具有与张量 `idx` 相同的形状。
**参数:**
@ -20,5 +20,5 @@ mindspore.ops.unique_with_pad
**异常:**
- **TypeError** - 'x'的数据类型既不是int32也不是int64。
- **ValueError** - 'x'不是一维张量。
- **TypeError** - `x` 的数据类型既不是int32也不是int64。
- **ValueError** - `x` 不是一维张量。

View File

@ -449,6 +449,7 @@ constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char DROPOUT[] = "Dropout";
constexpr char KStridedSlice[] = "StridedSlice";
constexpr char UNIQUE[] = "Unique";
constexpr char UNIQUE_WITH_PAD[] = "UniqueWithPad";
constexpr char UNIQUE_CONSECUTIVE[] = "UniqueConsecutive";
constexpr char GATHERND[] = "GatherNd";
constexpr char SCATTER_UPDATE[] = "ScatterUpdate";

View File

@ -685,7 +685,7 @@ static bool IsSplittableOperator(const std::string &op_name) {
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
POPULATION_COUNT};
UNIQUE_WITH_PAD, POPULATION_COUNT};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -39,6 +39,7 @@ constexpr auto kSegmentSumOpName = "SegmentSum";
constexpr auto kConcatOpName = "Concat";
constexpr auto kListDiffOpName = "ListDiff";
constexpr auto kUniqueOpName = "Unique";
constexpr auto kUniqueWithPadOpName = "UniqueWithPad";
constexpr auto kUniqueConsecutiveOpName = "UniqueConsecutive";
constexpr auto kMaskedSelectOpName = "MaskedSelect";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
@ -836,6 +837,7 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName,
kUniqueWithPadOpName,
kUniqueConsecutiveOpName,
kComputeAccidentalHitsOpName,
kSubAndFilterOpName,

View File

@ -77,11 +77,11 @@ class UniqueGpuKernelMod : public NativeGpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
private:
void *stream_ptr_;
bool is_null_input_;
protected:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_ = nullptr;
std::optional<bool> is_input_dynamic_shape_ = {};
bool is_null_input_;
void *stream_ptr_;
BaseOperatorPtr base_operator_ = nullptr;
std::vector<KernelTensorPtr> inputs_ = {};
std::vector<KernelTensorPtr> outputs_ = {};

View File

@ -0,0 +1,102 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/unique_with_pad_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include "runtime/device/ms_device_shape_transfer.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateUniqueWithPadKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::UniqueWithPadHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using UniqueWithPadPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, UniqueWithPadPtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateUniqueWithPadKernelPtr<int, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateUniqueWithPadKernelPtr<int64_t, int64_t>}};
} // namespace
bool UniqueWithPadGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
base_operator_ = base_operator;
inputs_ = inputs;
outputs_ = outputs;
auto [is_match, index] = MatchKernelAttr(GetKernelAttrFromTensors(inputs, outputs), GetOpSupport());
if (!is_match) {
return false;
}
helper_ptr_ = kernel_attr[index].second(kernel_name_, device_id_);
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
constexpr size_t kUniqueWithPadInputNum = 2;
if (inputs.size() != kUniqueWithPadInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << inputs.size();
}
std::vector<size_t> shape =
std::vector<size_t>(inputs[0]->GetDeviceShapeAdaptively().begin(), inputs[0]->GetDeviceShapeAdaptively().end());
is_null_input_ = CHECK_SHAPE_NULL(shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
input_shapes.emplace_back(inputs[0]->GetDeviceShapeAdaptively());
input_shapes.emplace_back(inputs[1]->GetDeviceShapeAdaptively());
helper_ptr_->CalMemSize(input_shapes, output_shapes);
InitSizeLists();
if (!is_input_dynamic_shape_.has_value()) {
bool is_input_dynamic_shape = false;
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t dim) { return dim < 0; })) {
is_input_dynamic_shape = true;
break;
}
}
is_input_dynamic_shape_ = is_input_dynamic_shape;
}
return true;
}
std::vector<KernelAttr> UniqueWithPadGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, UniqueWithPadPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, UniqueWithPad, UniqueWithPadGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_UNIQUE_WITH_PAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_UNIQUE_WITH_PAD_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/arrays/unique_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/unique_with_pad_helper.h"
namespace mindspore {
namespace kernel {
class UniqueWithPadGpuKernelMod : public UniqueGpuKernelMod {
public:
UniqueWithPadGpuKernelMod() {
KernelMod::kernel_name_ = "UniqueWithPad";
ResetResource();
}
~UniqueWithPadGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_UNIQUE_WITH_PAD_GPU_KERNEL_H_

View File

@ -0,0 +1,98 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNIQUE_WITH_PAD_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNIQUE_WITH_PAD_HELPER_H_
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_with_pad_impl.cuh"
namespace mindspore {
namespace cukernel {
template <typename T, typename S>
class UniqueWithPadHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit UniqueWithPadHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
num_elements_ = 1;
}
virtual ~UniqueWithPadHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
ResetResource();
constexpr size_t INPUT_NUM = 2;
int flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
if (flag == -1) {
return flag;
}
num_elements_ = input_size_list_[0] / sizeof(T);
size_t workspace_size = num_elements_ * sizeof(S);
work_size_list_.emplace_back(workspace_size);
work_size_list_.emplace_back(workspace_size);
output_size_list_.emplace_back(input_size_list_[0]);
output_size_list_.emplace_back(num_elements_ * sizeof(S));
return 0;
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
T *t_input_ptr = nullptr;
T *t_pad_num_ptr = nullptr;
S *s_input_index = nullptr;
S *s_sorted_index = nullptr;
T *t_output_ptr = nullptr;
S *s_output_index = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &t_input_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &t_pad_num_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(work_ptrs, 0, kernel_name_, &s_input_index);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(work_ptrs, 1, kernel_name_, &s_sorted_index);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &t_output_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(output_ptrs, 1, kernel_name_, &s_output_index);
if (flag != 0) {
return flag;
}
CalUniqueWithPad(t_input_ptr, num_elements_, s_input_index, s_sorted_index, t_output_ptr, s_output_index,
reinterpret_cast<cudaStream_t>(cuda_stream), t_pad_num_ptr);
return 0;
}
private:
int num_elements_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNIQUE_WITH_PAD_HELPER_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/adjacent_difference.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_vector.h>
#include <algorithm>
#include "unique_impl.cuh"
#include "unique_with_pad_impl.cuh"
template <typename T, typename S>
void CalUniqueWithPad(const T *input, int num_elements, S *input_index, S *sorted_index, T *output, S *index,
cudaStream_t cuda_stream, T *pad_num) {
int post_output_size = CalUnique(input, num_elements, input_index, sorted_index, output, index, cuda_stream);
auto policy = thrust::cuda::par.on(cuda_stream);
if (num_elements > post_output_size) {
thrust::device_reference<T> pad_ref(thrust::device_pointer_cast(pad_num));
thrust::fill_n(policy,
thrust::device_pointer_cast(output) + post_output_size,
num_elements - post_output_size,
pad_ref);
}
}
template CUDA_LIB_EXPORT void CalUniqueWithPad<int, int>(const int *input, int num_elements, int *input_index,
int *sorted_index, int *output, int *index,
cudaStream_t cuda_stream, int *pad_num);
template CUDA_LIB_EXPORT void CalUniqueWithPad<int64_t, int64_t>(const int64_t *input, int num_elements,
int64_t *input_index, int64_t *sorted_index,
int64_t *output, int64_t *index,
cudaStream_t cuda_stream, int64_t *pad_num);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNIQUE_WITH_PAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNIQUE_WITH_PAD_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalUniqueWithPad(const T *input, int num_elements, S *input_index, S *sorted_index, T *output,
S *index, cudaStream_t cuda_stream, T *pad_num);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNIQUE_WITH_PAD_IMPL_CUH_

View File

@ -195,6 +195,8 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueWithPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueConsecutive(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplOCRRecognitionPreHandle(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -222,6 +222,47 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplUniqueWithPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_LOG(WARNING) << "InferImplUniqueWithPad.";
// inputs: a 1-d Tensor
const std::string op_name = primitive->name();
constexpr size_t kUniqueWithPadInputNum = 2;
CheckArgsSize(op_name, args_spec_list, kUniqueWithPadInputNum);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = input->shape();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
}
// Currently we choose the same data type as input for the idx.
TypePtr ids_idx_type = kInt32;
MS_EXCEPTION_IF_NULL(input->element());
MS_EXCEPTION_IF_NULL(input->element()->GetTypeTrack());
if (input->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
ids_idx_type = kInt64;
}
ShapeVector idx_shape = shape->shape();
ShapeVector idx_min_shape = shape->min_shape();
if (idx_min_shape.empty()) {
idx_min_shape = shape->shape();
}
ShapeVector idx_max_shape = shape->max_shape();
if (idx_max_shape.empty()) {
idx_max_shape = shape->shape();
}
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
auto ids = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
ids->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
// outputs: ids, ids_idx
AbstractBasePtrList elements = {ids, ids_idx};
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplUniqueConsecutive(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

View File

@ -216,6 +216,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimArrayToScalar, R{InferImplArrayToScalar, nullptr, true}},
{prim::kPrimBroadcastShape, R{InferImplBroadCastShape, nullptr, true}},
{prim::kPrimUnique, R{InferImplUnique, nullptr, true}},
{prim::kPrimUniqueWithPad, R{InferImplUniqueWithPad, nullptr, true}},
{prim::kPrimUniqueGrad, R{InferImplUniqueGrad, nullptr, true}},
{prim::kPrimUniqueConsecutive, R{InferImplUniqueConsecutive, nullptr, true}},
{prim::kPrimEmbeddingLookup, R{InferImplEmbeddingLookup, nullptr, true}},

View File

@ -408,6 +408,7 @@ GVAR_DEF(PrimitivePtr, kPrimPadding, std::make_shared<Primitive>(kPadding));
GVAR_DEF(PrimitivePtr, kPrimMirrorPad, std::make_shared<Primitive>(kMirrorPad));
GVAR_DEF(PrimitivePtr, kPrimArgMaxWithValue, std::make_shared<Primitive>("ArgMaxWithValue"));
GVAR_DEF(PrimitivePtr, kPrimUnique, std::make_shared<Primitive>("Unique"));
GVAR_DEF(PrimitivePtr, kPrimUniqueWithPad, std::make_shared<Primitive>("UniqueWithPad"));
GVAR_DEF(PrimitivePtr, kPrimUniqueGrad, std::make_shared<Primitive>("UniqueGrad"));
GVAR_DEF(PrimitivePtr, kPrimUniqueConsecutive, std::make_shared<Primitive>("UniqueConsecutive"));
GVAR_DEF(PrimitivePtr, kPrimExtractImagePatches, std::make_shared<Primitive>("ExtractImagePatches"));

View File

@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/unique_with_pad.h"
#include "ops/primitive_c.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(UniqueWithPad, BaseOperator);
REGISTER_PRIMITIVE_C(kNameUniqueWithPad, UniqueWithPad);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_UNIQUE_WITH_PAD_H_
#define MINDSPORE_CORE_OPS_UNIQUE_WITH_PAD_H_
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ops {
constexpr auto kNameUniqueWithPad = "UniqueWithPad";
/// \brief Returns the unique elements of input tensor with pad and also returns a tensor containing
/// the index of each value of input tensor corresponding to the output unique tensor.
/// Refer to Python API @ref mindspore.ops.UniqueWithPad for more details.
class MIND_API UniqueWithPad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(UniqueWithPad);
/// \brief Constructor.
UniqueWithPad() : BaseOperator(kNameUniqueWithPad) { InitIOName({"x", "pad_num"}, {"y", "idx"}); }
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_UNIQUE_WITH_PAD_H_

View File

@ -3967,9 +3967,10 @@ class Tensor(Tensor_):
Returns:
tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
- y (Tensor) - The unique elements filled with pad_num, the shape and data type same as self tensor.
- idx (Tensor) - The index of each value of self tensor in the unique output `y`,
the shape and data type same as self tensor.
the shape and data type same as self tensor.
Raises:
TypeError: If dtype of self tensor is neither int32 nor int64.

View File

@ -32,6 +32,7 @@ from ..operations.array_ops import UniqueConsecutive
from ..operations.array_ops import Col2Im
from ..operations.array_ops import NonZero
from ..operations.array_ops import IndexFill
from ..composite import _VmapGeneralRule
@vmap_rules_getters.register("Cast")
@ -706,6 +707,20 @@ def get_range_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.UniqueWithPad)
def get_unique_with_pad_vmap_rule(prim, axis_size):
"""VmapRule for `UniqueWithPad` operations."""
if isinstance(prim, str):
prim = P.UniqueWithPad()
prim_vmap = _VmapGeneralRule(prim, axis_size)
def vmap_rule(x_bdim, pad_num_bdim):
return prim_vmap(x_bdim, pad_num_bdim)
return vmap_rule
@vmap_rules_getters.register(P.array_ops.MatrixDiagV3)
def get_matrix_diag_v3_vmap_rule(prim, axis_size):
"""VmapRule for `MatrixDiagV3` operation."""

View File

@ -582,6 +582,7 @@ def unique_with_pad(x, pad_num):
Returns:
tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
- y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
- idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.
@ -3504,6 +3505,7 @@ def max(input_x, axis=0, keep_dims=False):
__all__ = [
'unique',
'unique_with_pad',
'unique_consecutive',
'eye',
'matrix_band_part',

View File

@ -1250,7 +1250,7 @@ class Padding(Primitive):
self.pad_dim_size = pad_dim_size
class UniqueWithPad(PrimitiveWithInfer):
class UniqueWithPad(PrimitiveWithCheck):
"""
Returns unique elements and relative indexes in 1-D tensor, filled with padding num.
@ -1278,16 +1278,11 @@ class UniqueWithPad(PrimitiveWithInfer):
def __init__(self):
"""init UniqueWithPad"""
def __infer__(self, x, pad_num):
def __check__(self, x, pad_num):
validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
out_shape = x_shape
out = {'shape': (out_shape, out_shape),
'dtype': (x['dtype'], x['dtype']),
'value': None}
return out
class Split(PrimitiveWithCheck):

View File

@ -0,0 +1,71 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.uniq = P.UniqueWithPad()
def construct(self, x, pad):
return self.uniq(x, pad)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_int32():
"""
Feature: test uniquewithpad in cpu.
Description: test uniquewithpad forward with int32 dtype.
Expectation: expect correct forward result.
"""
x = Tensor(np.array([1, 2, 5, 2]), mstype.int32)
uniq = Net()
output = uniq(x, 0)
expect_y_result = [1, 2, 5, 0]
expect_idx_result = [0, 1, 2, 1]
assert (output[0].asnumpy() == expect_y_result).all()
assert (output[1].asnumpy() == expect_idx_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_int64():
"""
Feature: test uniquewithpad in cpu.
Description: test uniquewithpad forward with int64 dtype.
Expectation: expect correct forward result.
"""
x = Tensor(np.array([1, 2, 5, 2]), mstype.int64)
uniq = Net()
output = uniq(x, 0)
expect_y_result = [1, 2, 5, 0]
expect_idx_result = [0, 1, 2, 1]
assert (output[0].asnumpy() == expect_y_result).all()
assert (output[1].asnumpy() == expect_idx_result).all()

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class NetUniqueWithPad(nn.Cell):
def __init__(self):
super(NetUniqueWithPad, self).__init__()
self.unique = P.UniqueWithPad()
def construct(self, x, pad_num):
x_unique, x_idx = self.unique(x, pad_num)
return x_unique, x_idx
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_unique_with_pad_int32():
"""
Feature: test uniquewithpad in gpu.
Description: test uniquewithpad forward with int32 dtype.
Expectation: expect correct forward result.
"""
x = Tensor(np.array([1, 2, 2, 3, 3, 3, 4, 5]).astype(np.int32))
exp_output = np.array([1, 2, 3, 4, 5, 99, 99, 99]).astype(np.int32)
exp_idx = np.array([0, 1, 1, 2, 2, 2, 3, 4]).astype(np.int32)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = NetUniqueWithPad()
x_unique, x_idx = net(x, 99)
assert (x_unique.asnumpy() == exp_output).all()
assert (x_idx.asnumpy() == exp_idx).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_unique_with_pad_int64():
"""
Feature: test uniquewithpad in gpu.
Description: test uniquewithpad forward with int64 dtype.
Expectation: expect correct forward result.
"""
x = Tensor(np.array([1, 2, 2, 3, 3, 3, 4, 5]).astype(np.int64))
exp_output = np.array([1, 2, 3, 4, 5, 99, 99, 99]).astype(np.int64)
exp_idx = np.array([0, 1, 1, 2, 2, 2, 3, 4]).astype(np.int64)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = NetUniqueWithPad()
x_unique, x_idx = net(x, 99)
assert (x_unique.asnumpy() == exp_output).all()
assert (x_idx.asnumpy() == exp_idx).all()