!6645 [MD] add unique op

Merge pull request !6645 from liyong126/md_unique_op
This commit is contained in:
mindspore-ci-bot 2020-10-22 14:43:37 +08:00 committed by Gitee
commit 1ee9c4d014
9 changed files with 275 additions and 0 deletions

View File

@ -28,6 +28,7 @@
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/to_float16_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/unique_op.h"
namespace mindspore {
namespace dataset {
@ -42,6 +43,10 @@ PYBIND_REGISTER(
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp").def(py::init<>());
}));
PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) {
(void)py::class_<UniqueOp, TensorOp, std::shared_ptr<UniqueOp>>(*m, "UniqueOp").def(py::init<>());
}));
PYBIND_REGISTER(
FillOp, 1, ([](const py::module *m) {
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>());

View File

@ -11,4 +11,5 @@ add_library(kernels-data OBJECT
mask_op.cc
concatenate_op.cc
duplicate_op.cc
unique_op.cc
)

View File

@ -706,6 +706,78 @@ Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &inp
}
return Status::OK();
}
template <typename T>
struct UniqueOpHashMap {
using map_type = std::unordered_map<T, int32_t>;
};
template <typename T>
Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
const dsize_t N = input->Size();
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output_idx));
typename UniqueOpHashMap<T>::map_type uniq;
uniq.reserve(2 * N);
auto in_iter = input->begin<T>();
auto out_idx_iter = (*output_idx)->begin<int32_t>();
int32_t i = 0;
for (; in_iter != input->end<T>(); ++in_iter, ++out_idx_iter) {
auto it = uniq.emplace(*in_iter, i);
*out_idx_iter = it.first->second;
if (it.second) {
++i;
}
}
auto uniq_size = uniq.size();
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
auto out_iter = (*output)->begin<T>();
for (const auto &it : uniq) {
*(out_iter + static_cast<ptrdiff_t>(it.second)) = it.first;
}
RETURN_IF_NOT_OK(
Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));
RETURN_IF_NOT_OK((*output_cnt)->Zero());
auto out_cnt_iter = (*output_cnt)->begin<int32_t>();
out_idx_iter = (*output_idx)->begin<int32_t>();
for (int32_t j = 0; j < N; ++j) {
auto idx = *(out_idx_iter + static_cast<ptrdiff_t>(j));
++*(out_cnt_iter + static_cast<ptrdiff_t>(idx));
}
return Status::OK();
}
Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "Only 1D tensors supported.");
if (input->type() == DataType::DE_INT64) {
RETURN_IF_NOT_OK(UniqueHelper<int64_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_INT32) {
RETURN_IF_NOT_OK(UniqueHelper<int32_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_INT16) {
RETURN_IF_NOT_OK(UniqueHelper<int16_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_INT8) {
RETURN_IF_NOT_OK(UniqueHelper<int8_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_UINT64) {
RETURN_IF_NOT_OK(UniqueHelper<uint64_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_UINT32) {
RETURN_IF_NOT_OK(UniqueHelper<uint32_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_UINT16) {
RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_UINT8) {
RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT16) {
RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT32) {
RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT64) {
RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));
} else {
RETURN_STATUS_UNEXPECTED("Unique op only supports numeric input.");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/data_type.h"
@ -176,6 +177,27 @@ Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vect
/// \return Status ok/error
Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output);
/// Helper method that uniques the input tensor
/// @tparam T type of the tensor
/// \param input[in] input 1d tensor
/// \param output[out] output tensor
/// \param output[out] output tensor of item index
/// \param output[out] output tensor of item count
/// \return Status ok/error
template <typename T>
Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt);
/// Unique the input tensor
/// @tparam T type of the tensor
/// \param input[in] input 1d tensor
/// \param output[out] output tensor
/// \param output[out] output tensor of item index
/// \param output[out] output tensor of item count
/// \return Status ok/error
Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt);
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/data/unique_op.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
Status UniqueOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor");
auto in_tensor = input[0];
auto in_tensor_shape = in_tensor->shape();
auto in_tensor_type = in_tensor->type();
CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_type.IsNumeric(), "Tensor type must be numeric.");
CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_shape.Rank() >= 2, "Tensor must be at least 2-D in order to do unique op.");
CHECK_FAIL_RETURN_UNEXPECTED(
in_tensor->Size() <= std::numeric_limits<int32_t>::max(),
"UniqueOp does not support input tensor large than " + std::to_string(std::numeric_limits<int32_t>::max()));
RETURN_IF_NOT_OK(in_tensor->Reshape(TensorShape({in_tensor->Size()})));
std::shared_ptr<Tensor> out;
std::shared_ptr<Tensor> out_idx;
std::shared_ptr<Tensor> out_cnt;
RETURN_IF_NOT_OK(Unique(in_tensor, &out, &out_idx, &out_cnt));
output->push_back(out);
output->push_back(out_idx);
output->push_back(out_cnt);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_
#include <limits>
#include <vector>
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
namespace mindspore {
namespace dataset {
class UniqueOp : public TensorOp {
public:
UniqueOp() = default;
~UniqueOp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
uint32_t NumOutput() override { return 0; }
std::string Name() const override { return kUniqueOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_UNIQUE_OP_H_

View File

@ -125,6 +125,7 @@ constexpr char kPadEndOp[] = "PadEndOp";
constexpr char kSliceOp[] = "SliceOp";
constexpr char kToFloat16Op[] = "ToFloat16Op";
constexpr char kTypeCastOp[] = "TypeCastOp";
constexpr char kUniqueOp[] = "UniqueOp";
// other
constexpr char kCFuncOp[] = "CFuncOp";

View File

@ -296,6 +296,37 @@ class Duplicate(cde.DuplicateOp):
"""
class Unique(cde.UniqueOp):
"""
Return an output tensor containing all the unique elements of the input tensor in
the same order that they occur in the input tensor.
Also return an index tensor that contains the index of each element of the
input tensor in the Unique output tensor.
Finally, return a count tensor that constains the count of each element of
the output tensor in the input tensor.
Note:
Call batch op before calling this function.
Examples:
>>> import mindspore.dataset.transforms.c_transforms as c_transforms
>>>
>>> # Data before
>>> # | x |
>>> # +--------------------+
>>> # | [[0,1,2], [1,2,3]] |
>>> # +--------------------+
>>> data1 = data1.map(operations=c_transforms.Unique(), input_columns=["x"],
>>> output_columns=["x", "y", "z"], column_order=["x", "y", "z"])
>>> # Data after
>>> # | x | y |z |
>>> # +---------+-----------------+---------+
>>> # | [0,1,2,3] | [0,1,2,1,2,3] | [1,2,2,1]
>>> # +---------+-----------------+---------+
"""
class Compose(cde.ComposeOp):
"""
Compose a list of transforms into a single transform.

View File

@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing unique op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def compare(array, res, idx, cnt):
data = ds.NumpySlicesDataset([array], column_names="x")
data = data.batch(2)
data = data.map(operations=ops.Unique(), input_columns=["x"], output_columns=["x", "y", "z"],
column_order=["x", "y", "z"])
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(res, d["x"])
np.testing.assert_array_equal(idx, d["y"])
np.testing.assert_array_equal(cnt, d["z"])
def test_duplicate_basics():
compare([0, 1, 2, 1, 2, 3], np.array([0, 1, 2, 3]),
np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
compare([0.0, 1.0, 2.0, 1.0, 2.0, 3.0], np.array([0.0, 1.0, 2.0, 3.0]),
np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
compare([1, 1, 1, 1, 1, 1], np.array([1]),
np.array([0, 0, 0, 0, 0, 0]), np.array([6]))
if __name__ == "__main__":
test_duplicate_basics()