From dd9bf09f0a38076fb28d44382ba3cedb4ebfedbd Mon Sep 17 00:00:00 2001
From: nhussain
Date: Thu, 11 Jun 2020 11:50:38 -0400
Subject: [PATCH] added FillOp for #119 - special Ops

---
 .../ccsrc/dataset/api/python_bindings.cc      |   5 +
 .../ccsrc/dataset/kernels/data/CMakeLists.txt |   2 +-
 .../ccsrc/dataset/kernels/data/data_utils.cc  | 112 +++++++++++
 .../ccsrc/dataset/kernels/data/data_utils.h   |   7 +
 .../ccsrc/dataset/kernels/data/fill_op.cc     |  31 +++
 .../ccsrc/dataset/kernels/data/fill_op.h      |  47 +++++
 mindspore/dataset/transforms/c_transforms.py  |  20 +-
 mindspore/dataset/transforms/validators.py    |  20 +-
 tests/ut/cpp/dataset/CMakeLists.txt           |   1 +
 tests/ut/cpp/dataset/fill_op_test.cc          | 183 ++++++++++++++++++
 tests/ut/cpp/dataset/queue_test.cc            |  47 ++---
 tests/ut/python/dataset/test_fill_op.py       |  95 +++++++++
 12 files changed, 540 insertions(+), 30 deletions(-)
 create mode 100644 mindspore/ccsrc/dataset/kernels/data/fill_op.cc
 create mode 100644 mindspore/ccsrc/dataset/kernels/data/fill_op.h
 create mode 100644 tests/ut/cpp/dataset/fill_op_test.cc
 create mode 100644 tests/ut/python/dataset/test_fill_op.py

diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 5c574844b90..0dc6f2b1f4a 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -38,6 +38,7 @@
 #include "dataset/kernels/image/resize_op.h"
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
+#include "dataset/kernels/data/fill_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) {
     *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
     .def(py::init<int32_t>());
 
+  (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
+    *m, "FillOp", "Tensor operation to return a tensor of the input's shape, filled with the given fill value.")
+    .def(py::init<std::shared_ptr<Tensor>>());
+
   (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
     *m, "RandomRotationOp",
     "Tensor operation to apply RandomRotation."
diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
index 8472ab51929..8c03b300ee7 100644
--- a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
@@ -5,4 +5,4 @@ add_library(kernels-data OBJECT
         one_hot_op.cc
         type_cast_op.cc
         to_float16_op.cc
-        )
+        fill_op.cc)
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
index c20d9a4c757..85c4cfc67ce 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
@@ -23,6 +23,7 @@
 #include "dataset/core/tensor_shape.h"
 #include "dataset/core/data_type.h"
 #include "dataset/core/pybind_support.h"
+#include "dataset/kernels/data/type_cast_op.h"
 
 namespace mindspore {
 namespace dataset {
@@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt
 Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
   input->Squeeze();
+
   if (input->Rank() > 1) {  // We expect the input to be in the first dimension
     RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
   }
 
@@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
   }
 }
 
+Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    !((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)), "Types do not match");
+
+  CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
+
+  std::shared_ptr<Tensor> out;
+
+  const DataType &to = input->type();
+  std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
+
+  // Cast the fill value to the input type up front. The cast result is only consumed by the
+  // numeric branches below; the string branch reads fill_value directly.
+  std::shared_ptr<Tensor> fill_output;
+  (void)op->Compute(fill_value, &fill_output);
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
+
+  switch (input->type().value()) {
+    case DataType::DE_BOOL: {
+      bool value = false;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT8: {
+      int8_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT8: {
+      uint8_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT16: {
+      uint16_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT16: {
+      int16_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT32: {
+      uint32_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT32: {
+      int32_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT64: {
+      uint64_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT64: {
+      int64_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_FLOAT16: {
+      float16 value(0.0f);
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_FLOAT32: {
+      float value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_FLOAT64: {
+      double value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_STRING: {
+      std::vector<std::string> strings;
+      std::string_view fill_string_view;
+      RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
+      std::string fill_string = std::string(fill_string_view);
+      for (int i = 0; i < input->shape().NumOfElements(); i++) {
+        strings.emplace_back(fill_string);
+      }
+      RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
+      break;
+    }
+    case DataType::DE_UNKNOWN: {
+      RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
+      break;
+    }
+  }
+
+  *output = out;
+  return Status::OK();
+}
+
 template <typename FROM, typename TO>
 void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   auto in_itr = input->begin<FROM>();
   auto out_itr = (*output)->begin<TO>();
   auto out_end = (*output)->end<TO>();
+
   for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
     *out_itr = static_cast<TO>(*in_itr);
 }
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
index bfd51412278..f2faee02dc8 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.h
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
@@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_
 Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                             dsize_t num_classes, int64_t index);
 
+// Returns a tensor of the input's shape filled with the passed fill_value
+// @param input Tensor
+// @param output Tensor. The shape and type of the output tensor are the same as the input
+// @param fill_value Tensor. A scalar tensor used to fill the output tensor
+
+Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value);
+
 // Returns a type changed input tensor.
 // Example: if input tensor is float64, the output will be the specified dataType. See DataTypes.cpp
 // @param input Tensor
diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.cc b/mindspore/ccsrc/dataset/kernels/data/fill_op.cc
new file mode 100644
index 00000000000..b0a9a370fb9
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/data/fill_op.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/kernels/data/fill_op.h"
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/data/data_utils.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status FillOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  Status s = Fill(input, output, fill_value_);
+  return s;
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/dataset/kernels/data/fill_op.h
new file mode 100644
index 00000000000..b5333b2367b
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/data/fill_op.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_DATA_FILL_OP_H_
+#define DATASET_KERNELS_DATA_FILL_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+class FillOp : public TensorOp {
+ public:
+  explicit FillOp(std::shared_ptr<Tensor> value) : fill_value_(value) {}
+
+  ~FillOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "FillOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+ private:
+  std::shared_ptr<Tensor> fill_value_;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_DATA_FILL_OP_H_
diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py
index 91fb4865317..8f301f196e2 100644
--- a/mindspore/dataset/transforms/c_transforms.py
+++ b/mindspore/dataset/transforms/c_transforms.py
@@ -15,9 +15,9 @@
 """
 This module c_transforms provides common operations, including OneHotOp and TypeCast.
 """
+import numpy as np
 import mindspore._c_dataengine as cde
-
-from .validators import check_num_classes, check_de_type
+from .validators import check_num_classes, check_de_type, check_fill_value
 from ..core.datatypes import mstype_to_detype
 
 
@@ -35,6 +35,21 @@ class OneHot(cde.OneHotOp):
         super().__init__(num_classes)
 
 
+class Fill(cde.FillOp):
+    """
+    Tensor operation to create a tensor filled with the passed scalar value.
+    The output tensor will have the same shape and type as the input tensor.
+
+    Args:
+        fill_value (Union[str, int, float, bool]): scalar value
+            to fill the created tensor with.
+    """
+
+    @check_fill_value
+    def __init__(self, fill_value):
+        super().__init__(cde.Tensor(np.array(fill_value)))
+
+
 class TypeCast(cde.TypeCastOp):
     """
     Tensor operation to cast to a given MindSpore data type.
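Note: a minimal usage sketch of the new Fill transform (illustrative only, not part of the patch; the generator and column name are made up here, mirroring the unit tests added later in this patch):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.transforms.c_transforms as c_transforms

    def gen():
        # one uint8 column; Fill keeps the input tensor's shape and dtype
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    # every element of "col" is replaced by the scalar 3, cast to uint8
    data = data.map(input_columns=["col"], operations=c_transforms.Fill(3))
    for row in data:
        print(row[0])  # expected: [3 3 3 3]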
diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py
index 5572e5285e3..a7eb589cd7d 100644
--- a/mindspore/dataset/transforms/validators.py
+++ b/mindspore/dataset/transforms/validators.py
@@ -17,7 +17,6 @@ from functools import wraps
 
 from mindspore._c_expression import typing
 
-
 # POS_INT_MIN is used to limit values from starting from 0
 POS_INT_MIN = 1
 UINT8_MAX = 255
@@ -159,6 +158,25 @@ def check_num_classes(method):
     return new_method
 
 
+def check_fill_value(method):
+    """Wrapper method to check the parameters of fill_value."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        fill_value = (list(args) + [None])[0]
+        if "fill_value" in kwargs:
+            fill_value = kwargs.get("fill_value")
+        if fill_value is None:
+            raise ValueError("fill_value is not provided.")
+        if not isinstance(fill_value, (str, float, bool, int)):
+            raise TypeError("fill_value must be either a primitive python str, float, bool or int.")
+        kwargs["fill_value"] = fill_value
+
+        return method(self, **kwargs)
+
+    return new_method
+
+
 def check_de_type(method):
     """Wrapper method to check the parameters of data type."""
diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt
index 1691aa3de58..6c126903233 100644
--- a/tests/ut/cpp/dataset/CMakeLists.txt
+++ b/tests/ut/cpp/dataset/CMakeLists.txt
@@ -72,6 +72,7 @@ SET(DE_UT_SRCS
     tokenizer_op_test.cc
     gnn_graph_test.cc
     coco_op_test.cc
+    fill_op_test.cc
     )
 
 add_executable(de_ut_tests ${DE_UT_SRCS})
diff --git a/tests/ut/cpp/dataset/fill_op_test.cc b/tests/ut/cpp/dataset/fill_op_test.cc
new file mode 100644
index 00000000000..d43b7d75489
--- /dev/null
+++ b/tests/ut/cpp/dataset/fill_op_test.cc
@@ -0,0 +1,183 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "common/common.h"
+#include "dataset/kernels/data/fill_op.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::INFO;
+
+class MindDataTestFillOp : public UT::Common {
+ protected:
+  MindDataTestFillOp() {}
+};
+
+TEST_F(MindDataTestFillOp, TestOp) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp.";
+  uint64_t labels[3] = {1, 1, 2};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
+
+  TensorShape fill_shape({});
+  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64));
+  fill_tensor->SetItemAt<uint64_t>({}, 4);
+
+  std::shared_ptr<Tensor> output;
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  uint64_t out[3] = {4, 4, 4};
+
+  std::shared_ptr<Tensor> expected =
+    std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
+
+  EXPECT_TRUE(s.IsOk());
+  ASSERT_TRUE(output->shape() == expected->shape());
+  ASSERT_TRUE(output->type() == expected->type());
+  MS_LOG(DEBUG) << *output << std::endl;
+  MS_LOG(DEBUG) << *expected << std::endl;
+
+  ASSERT_TRUE(*output == *expected);
+  MS_LOG(INFO) << "MindDataTestFillOp-TestOp end.";
+}
+
+TEST_F(MindDataTestFillOp, TestCasting) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting.";
+  uint64_t labels[3] = {0, 1, 2};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
+
+  TensorShape fill_shape({});
+  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
+  fill_tensor->SetItemAt<float>({}, 2.0);
+
+  std::shared_ptr<Tensor> output;
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  uint64_t out[3] = {2, 2, 2};
+
+  std::shared_ptr<Tensor> expected =
+    std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
+
+  ASSERT_TRUE(output->shape() == expected->shape());
+  ASSERT_TRUE(output->type() == expected->type());
+
+  EXPECT_TRUE(s.IsOk());
+  MS_LOG(DEBUG) << *output << std::endl;
+  MS_LOG(DEBUG) << *expected << std::endl;
+  ASSERT_TRUE(*output == *expected);
+
+  MS_LOG(INFO) << "MindDataTestFillOp-TestCasting end.";
+}
+
+TEST_F(MindDataTestFillOp, ScalarFill) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill.";
+  uint64_t labels[3] = {0, 1, 2};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
+
+  // A non-scalar fill value must be rejected by FillOp.
+  TensorShape fill_shape({2});
+  uint64_t fill_labels[2] = {0, 1};
+  std::shared_ptr<Tensor> fill_tensor =
+    std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(fill_labels));
+  std::shared_ptr<Tensor> output;
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  EXPECT_TRUE(s.IsError());
+  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
+
+  MS_LOG(INFO) << "MindDataTestFillOp-ScalarFill end.";
+}
+
+TEST_F(MindDataTestFillOp, StringFill) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill.";
+  std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
+
+  TensorShape fill_shape({});
+  std::string fill_string = "hello";
+  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
+
+  std::shared_ptr<Tensor> output;
+
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  std::vector<std::string> expected_strings = {"hello", "hello", "hello"};
+  TensorShape expected_shape({3});
+  std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(expected_strings, expected_shape);
+
+  EXPECT_TRUE(s.IsOk());
+  ASSERT_TRUE(output->shape() == expected->shape());
+  ASSERT_TRUE(output->type() == expected->type());
+  MS_LOG(DEBUG) << *output << std::endl;
+  MS_LOG(DEBUG) << *expected << std::endl;
+
+  ASSERT_TRUE(*output == *expected);
+
+  MS_LOG(INFO) << "MindDataTestFillOp-StringFill end.";
+}
+
+TEST_F(MindDataTestFillOp, NumericToString) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString.";
+  std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
+
+  TensorShape fill_shape({});
+  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
+  fill_tensor->SetItemAt<float>({}, 2.0);
+
+  std::shared_ptr<Tensor> output;
+
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  EXPECT_TRUE(s.IsError());
+  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
+
+  MS_LOG(INFO) << "MindDataTestFillOp-NumericToString end.";
+}
+
+TEST_F(MindDataTestFillOp, StringToNumeric) {
+  MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric.";
+  uint64_t labels[3] = {0, 1, 2};
+  TensorShape shape({3});
+  std::shared_ptr<Tensor> input =
+    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
+
+  TensorShape fill_shape({});
+  std::string fill_string = "hello";
+  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
+
+  std::shared_ptr<Tensor> output;
+
+  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
+  Status s = op->Compute(input, &output);
+
+  EXPECT_TRUE(s.IsError());
+  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
+
+  MS_LOG(INFO) << "MindDataTestFillOp-StringToNumeric end.";
+}
diff --git a/tests/ut/cpp/dataset/queue_test.cc b/tests/ut/cpp/dataset/queue_test.cc
index 00366fcafdc..578405e5370 100644
--- a/tests/ut/cpp/dataset/queue_test.cc
+++ b/tests/ut/cpp/dataset/queue_test.cc
@@ -13,9 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-//
-// Created by jesse on 10/3/19.
-//
 
 #include "common/common.h"
 #include "gtest/gtest.h"
@@ -25,32 +22,32 @@
 #include "utils/log_adapter.h"
 
 using namespace mindspore::dataset;
-using mindspore::MsLogLevel::INFO;
-using mindspore::ExceptionType::NoExceptionType;
 using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::INFO;
 
 class MindDataTestQueue : public UT::Common {
  public:
-    MindDataTestQueue() {}
+  MindDataTestQueue() {}
 
-    void SetUp() {}
+  void SetUp() {}
 };
 
 int gRefCountDestructorCalled;
 
 class RefCount {
  public:
-    RefCount() : v_(nullptr) {}
-    explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
-    explicit RefCount(const RefCount &o) : v_(o.v_) {}
-    ~RefCount() {
-      MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
-      gRefCountDestructorCalled++;
-    }
-    RefCount& operator=(const RefCount &o) {
-      v_ = o.v_;
-      return *this;
-    }
+  RefCount() : v_(nullptr) {}
+  explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
+  explicit RefCount(const RefCount &o) : v_(o.v_) {}
+  ~RefCount() {
+    MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
+    gRefCountDestructorCalled++;
+  }
+  RefCount &operator=(const RefCount &o) {
+    v_ = o.v_;
+    return *this;
+  }
 
   std::shared_ptr<int> v_;
 };
@@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) {
   // Use count should remain 2. a and b. No copy in the queue.
   ASSERT_EQ(a.use_count(), 2);
   a.reset(new int(5));
-  ASSERT_EQ(a.use_count(),1);
+  ASSERT_EQ(a.use_count(), 1);
   // Push again but expect a is nullptr after push
   rc = que.Add(std::move(a));
   ASSERT_TRUE(rc.IsOk());
-  ASSERT_EQ(a.use_count(),0);
+  ASSERT_EQ(a.use_count(), 0);
   rc = que.PopFront(&b);
   ASSERT_TRUE(rc.IsOk());
   ASSERT_EQ(*b, 5);
-  ASSERT_EQ(b.use_count(),1);
+  ASSERT_EQ(b.use_count(), 1);
   // Test construct in place
   rc = que.EmplaceBack(std::make_shared<int>(100));
   ASSERT_TRUE(rc.IsOk());
   rc = que.PopFront(&b);
   ASSERT_TRUE(rc.IsOk());
   ASSERT_EQ(*b, 100);
-  ASSERT_EQ(b.use_count(),1);
+  ASSERT_EQ(b.use_count(), 1);
   // Test the destructor of the Queue by add an element in the queue without popping it and let the queue go
   // out of scope.
   rc = que.EmplaceBack(std::make_shared<int>(2000));
@@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) {
   ASSERT_EQ(*b, 40);
 }
 
-void test4(){
+void test4() {
   gRefCountDestructorCalled = 0;
   // Pass a structure along the queue.
   Queue<RefCount> que(3);
@@ -144,9 +141,7 @@ void test4(){
   ASSERT_TRUE(rc.IsOk());
 }
 
-TEST_F(MindDataTestQueue, Test4) {
-  test4();
-}
+TEST_F(MindDataTestQueue, Test4) { test4(); }
 
 TEST_F(MindDataTestQueue, Test5) {
   test4();
diff --git a/tests/ut/python/dataset/test_fill_op.py b/tests/ut/python/dataset/test_fill_op.py
new file mode 100644
index 00000000000..f138dd15ec9
--- /dev/null
+++ b/tests/ut/python/dataset/test_fill_op.py
@@ -0,0 +1,95 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing fill op
+"""
+import numpy as np
+import pytest
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as data_trans
+
+
+def test_fillop_basic():
+    def gen():
+        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    fill_op = data_trans.Fill(3)
+
+    data = data.map(input_columns=["col"], operations=fill_op)
+    expected = np.array([3, 3, 3, 3], dtype=np.uint8)
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_fillop_down_type_cast():
+    def gen():
+        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    fill_op = data_trans.Fill(-3)
+
+    data = data.map(input_columns=["col"], operations=fill_op)
+    expected = np.array([253, 253, 253, 253], dtype=np.uint8)
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_fillop_up_type_cast():
+    def gen():
+        yield (np.array([4, 5, 6, 7], dtype=np.float),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    fill_op = data_trans.Fill(3)
+
+    data = data.map(input_columns=["col"], operations=fill_op)
+    expected = np.array([3., 3., 3., 3.], dtype=np.float)
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_fillop_string():
+    def gen():
+        yield (np.array(["45555", "45555"], dtype='S'),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    fill_op = data_trans.Fill("error")
+
+    data = data.map(input_columns=["col"], operations=fill_op)
+    expected = np.array(['error', 'error'], dtype='S')
+    for data_row in data:
+        np.testing.assert_array_equal(data_row[0], expected)
+
+
+def test_fillop_error_handling():
+    def gen():
+        yield (np.array([4, 4, 4, 4]),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    fill_op = data_trans.Fill("words")
+    data = data.map(input_columns=["col"], operations=fill_op)
+
+    with pytest.raises(RuntimeError) as error_info:
+        for data_row in data:
+            print(data_row)
+    assert "Types do not match" in repr(error_info.value)
+
+
+if __name__ == "__main__":
+    test_fillop_basic()
+    test_fillop_up_type_cast()
+    test_fillop_down_type_cast()
+    test_fillop_string()
+    test_fillop_error_handling()
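A note on the casting semantics exercised by test_fillop_down_type_cast and test_fillop_error_handling: the fill value is cast to the input tensor's type before filling, so a negative value wraps when the input is unsigned, and a string fill value on a numeric column is rejected with "Types do not match". A rough NumPy analogy of what the C++ Fill kernel does (illustrative only; the helper name fill_like is made up here):

    import numpy as np

    def fill_like(input_array, fill_value):
        # Cast the scalar to the input's dtype, keep the input's shape,
        # and reject string/numeric mixes up front.
        value = np.array(fill_value)
        if (value.dtype.kind in "US") != (input_array.dtype.kind in "US"):
            raise RuntimeError("Types do not match")
        return np.full_like(input_array, value)

    print(fill_like(np.array([4, 5, 6, 7], dtype=np.uint8), -3))     # [253 253 253 253]
    print(fill_like(np.array(["45555", "45555"]), "error"))          # ['error' 'error']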