From fb2bd156c96a88a92f460741ed2ed653f7025c60 Mon Sep 17 00:00:00 2001 From: liyong Date: Mon, 21 Sep 2020 10:50:58 +0800 Subject: [PATCH] add unique op --- .../bindings/dataset/kernels/data/bindings.cc | 5 ++ .../dataset/kernels/data/CMakeLists.txt | 1 + .../dataset/kernels/data/data_utils.cc | 72 +++++++++++++++++++ .../dataset/kernels/data/data_utils.h | 22 ++++++ .../dataset/kernels/data/unique_op.cc | 53 ++++++++++++++ .../minddata/dataset/kernels/data/unique_op.h | 45 ++++++++++++ .../minddata/dataset/kernels/tensor_op.h | 1 + mindspore/dataset/transforms/c_transforms.py | 31 ++++++++ tests/ut/python/dataset/test_unique_op.py | 45 ++++++++++++ 9 files changed, 275 insertions(+) create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.h create mode 100644 tests/ut/python/dataset/test_unique_op.py diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc index 084ceaefb9b..e58557a5112 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc @@ -28,6 +28,7 @@ #include "minddata/dataset/kernels/data/slice_op.h" #include "minddata/dataset/kernels/data/to_float16_op.h" #include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/unique_op.h" namespace mindspore { namespace dataset { @@ -42,6 +43,10 @@ PYBIND_REGISTER( (void)py::class_>(*m, "DuplicateOp").def(py::init<>()); })); +PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "UniqueOp").def(py::init<>()); + })); + PYBIND_REGISTER( FillOp, 1, ([](const py::module *m) { (void)py::class_>(*m, "FillOp").def(py::init>()); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt index 9131c9c6676..9a8f0b88180 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt @@ -11,4 +11,5 @@ add_library(kernels-data OBJECT mask_op.cc concatenate_op.cc duplicate_op.cc + unique_op.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index 332d9d9e040..99fa46ee07a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -706,6 +706,78 @@ Status TensorVectorToBatchTensor(const std::vector> &inp } return Status::OK(); } +template +struct UniqueOpHashMap { + using map_type = std::unordered_map; +}; + +template +Status UniqueHelper(const std::shared_ptr &input, std::shared_ptr *output, + std::shared_ptr *output_idx, std::shared_ptr *output_cnt) { + const dsize_t N = input->Size(); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output_idx)); + + typename UniqueOpHashMap::map_type uniq; + uniq.reserve(2 * N); + auto in_iter = input->begin(); + auto out_idx_iter = (*output_idx)->begin(); + int32_t i = 0; + for (; in_iter != input->end(); ++in_iter, ++out_idx_iter) { + auto it = uniq.emplace(*in_iter, i); + *out_idx_iter = it.first->second; + if (it.second) { + ++i; + } + } + auto uniq_size = uniq.size(); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast(uniq_size)}), input->type(), output)); + auto out_iter = (*output)->begin(); + for (const auto &it : uniq) { + *(out_iter + static_cast(it.second)) = it.first; + } + RETURN_IF_NOT_OK( + Tensor::CreateEmpty(TensorShape({static_cast(uniq_size)}), DataType(DataType::DE_INT32), output_cnt)); + RETURN_IF_NOT_OK((*output_cnt)->Zero()); + + auto out_cnt_iter = (*output_cnt)->begin(); + out_idx_iter = (*output_idx)->begin(); + for (int32_t j = 0; j < N; ++j) { + auto idx = *(out_idx_iter + static_cast(j)); + ++*(out_cnt_iter + static_cast(idx)); + } + return Status::OK(); +} + +Status Unique(const std::shared_ptr &input, std::shared_ptr *output, + std::shared_ptr *output_idx, std::shared_ptr *output_cnt) { + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "Only 1D tensors supported."); + if (input->type() == DataType::DE_INT64) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_INT32) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_INT16) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_INT8) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_UINT64) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_UINT32) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_UINT16) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_UINT8) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_FLOAT16) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_FLOAT32) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else if (input->type() == DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK(UniqueHelper(input, output, output_idx, output_cnt)); + } else { + RETURN_STATUS_UNEXPECTED("Unique op only supports numeric input."); + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h index fd66d4def3b..2350c93588e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/data_type.h" @@ -176,6 +177,27 @@ Status BatchTensorToTensorVector(const std::shared_ptr &input, std::vect /// \return Status ok/error Status TensorVectorToBatchTensor(const std::vector> &input, std::shared_ptr *output); +/// Helper method that uniques the input tensor +/// @tparam T type of the tensor +/// \param input[in] input 1d tensor +/// \param output[out] output tensor +/// \param output[out] output tensor of item index +/// \param output[out] output tensor of item count +/// \return Status ok/error +template +Status UniqueHelper(const std::shared_ptr &input, std::shared_ptr *output, + std::shared_ptr *output_idx, std::shared_ptr *output_cnt); + +/// Unique the input tensor +/// @tparam T type of the tensor +/// \param input[in] input 1d tensor +/// \param output[out] output tensor +/// \param output[out] output tensor of item index +/// \param output[out] output tensor of item count +/// \return Status ok/error +Status Unique(const std::shared_ptr &input, std::shared_ptr *output, + std::shared_ptr *output_idx, std::shared_ptr *output_cnt); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.cc new file mode 100644 index 00000000000..af260e6106c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/data/unique_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status UniqueOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + + auto in_tensor = input[0]; + auto in_tensor_shape = in_tensor->shape(); + auto in_tensor_type = in_tensor->type(); + + CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_type.IsNumeric(), "Tensor type must be numeric."); + CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_shape.Rank() >= 2, "Tensor must be at least 2-D in order to do unique op."); + CHECK_FAIL_RETURN_UNEXPECTED( + in_tensor->Size() <= std::numeric_limits::max(), + "UniqueOp does not support input tensor large than " + std::to_string(std::numeric_limits::max())); + + RETURN_IF_NOT_OK(in_tensor->Reshape(TensorShape({in_tensor->Size()}))); + + std::shared_ptr out; + std::shared_ptr out_idx; + std::shared_ptr out_cnt; + + RETURN_IF_NOT_OK(Unique(in_tensor, &out, &out_idx, &out_cnt)); + + output->push_back(out); + output->push_back(out_idx); + output->push_back(out_cnt); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.h new file mode 100644 index 00000000000..ecae9801aa3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/unique_op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class UniqueOp : public TensorOp { + public: + UniqueOp() = default; + + ~UniqueOp() override = default; + + Status Compute(const TensorRow &input, TensorRow *output) override; + + uint32_t NumOutput() override { return 0; } + + std::string Name() const override { return kUniqueOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_UNIQUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 2294e72bb81..779a5e6f262 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -125,6 +125,7 @@ constexpr char kPadEndOp[] = "PadEndOp"; constexpr char kSliceOp[] = "SliceOp"; constexpr char kToFloat16Op[] = "ToFloat16Op"; constexpr char kTypeCastOp[] = "TypeCastOp"; +constexpr char kUniqueOp[] = "UniqueOp"; // other constexpr char kCFuncOp[] = "CFuncOp"; diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 7c19bfd4295..3c1503b0301 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -296,6 +296,37 @@ class Duplicate(cde.DuplicateOp): """ +class Unique(cde.UniqueOp): + """ + Return an output tensor containing all the unique elements of the input tensor in + the same order that they occur in the input tensor. + + Also return an index tensor that contains the index of each element of the + input tensor in the Unique output tensor. + + Finally, return a count tensor that constains the count of each element of + the output tensor in the input tensor. + + Note: + Call batch op before calling this function. + + Examples: + >>> import mindspore.dataset.transforms.c_transforms as c_transforms + >>> + >>> # Data before + >>> # | x | + >>> # +--------------------+ + >>> # | [[0,1,2], [1,2,3]] | + >>> # +--------------------+ + >>> data1 = data1.map(operations=c_transforms.Unique(), input_columns=["x"], + >>> output_columns=["x", "y", "z"], column_order=["x", "y", "z"]) + >>> # Data after + >>> # | x | y |z | + >>> # +---------+-----------------+---------+ + >>> # | [0,1,2,3] | [0,1,2,1,2,3] | [1,2,2,1] + >>> # +---------+-----------------+---------+ + + """ class Compose(cde.ComposeOp): """ Compose a list of transforms into a single transform. diff --git a/tests/ut/python/dataset/test_unique_op.py b/tests/ut/python/dataset/test_unique_op.py new file mode 100644 index 00000000000..9efa2ac38e6 --- /dev/null +++ b/tests/ut/python/dataset/test_unique_op.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing unique op in DE +""" +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops + + +def compare(array, res, idx, cnt): + data = ds.NumpySlicesDataset([array], column_names="x") + data = data.batch(2) + data = data.map(operations=ops.Unique(), input_columns=["x"], output_columns=["x", "y", "z"], + column_order=["x", "y", "z"]) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + np.testing.assert_array_equal(res, d["x"]) + np.testing.assert_array_equal(idx, d["y"]) + np.testing.assert_array_equal(cnt, d["z"]) + + +def test_duplicate_basics(): + compare([0, 1, 2, 1, 2, 3], np.array([0, 1, 2, 3]), + np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1])) + compare([0.0, 1.0, 2.0, 1.0, 2.0, 3.0], np.array([0.0, 1.0, 2.0, 3.0]), + np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1])) + compare([1, 1, 1, 1, 1, 1], np.array([1]), + np.array([0, 0, 0, 0, 0, 0]), np.array([6])) + + +if __name__ == "__main__": + test_duplicate_basics()