forked from mindspore-Ecosystem/mindspore
added FillOp for #119 - special Ops
This commit is contained in:
parent
2005ecc284
commit
dd9bf09f0a
|
@ -38,6 +38,7 @@
|
|||
#include "dataset/kernels/image/resize_op.h"
|
||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
|
@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) {
|
|||
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
|
||||
.def(py::init<int32_t>());
|
||||
|
||||
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
|
||||
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
|
||||
.def(py::init<std::shared_ptr<Tensor>>());
|
||||
|
||||
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
|
||||
*m, "RandomRotationOp",
|
||||
"Tensor operation to apply RandomRotation."
|
||||
|
|
|
@ -5,4 +5,4 @@ add_library(kernels-data OBJECT
|
|||
one_hot_op.cc
|
||||
type_cast_op.cc
|
||||
to_float16_op.cc
|
||||
)
|
||||
fill_op.cc)
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/pybind_support.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt
|
|||
|
||||
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
|
||||
input->Squeeze();
|
||||
|
||||
if (input->Rank() > 1) { // We expect the input to be int he first dimension
|
||||
RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
|
||||
}
|
||||
|
@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
|
|||
}
|
||||
}
|
||||
|
||||
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
|
||||
"Types do not match");
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
|
||||
|
||||
std::shared_ptr<Tensor> out;
|
||||
|
||||
const DataType &to = input->type();
|
||||
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
|
||||
|
||||
std::shared_ptr<Tensor> fill_output;
|
||||
op->Compute(fill_value, &fill_output);
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
|
||||
|
||||
switch (input->type().value()) {
|
||||
case DataType::DE_BOOL: {
|
||||
bool value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<bool>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_INT8: {
|
||||
int8_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<int8_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UINT8: {
|
||||
uint8_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<uint8_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UINT16: {
|
||||
uint16_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<uint16_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_INT16: {
|
||||
int16_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<int16_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UINT32: {
|
||||
uint32_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<uint32_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_INT32: {
|
||||
int32_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<int32_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UINT64: {
|
||||
uint64_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<uint64_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_INT64: {
|
||||
int64_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<int64_t>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_FLOAT16: {
|
||||
int64_t value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<float>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_FLOAT32: {
|
||||
float value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<float>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_FLOAT64: {
|
||||
double value = 0;
|
||||
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
|
||||
out->Fill<double>(value);
|
||||
break;
|
||||
}
|
||||
case DataType::DE_STRING: {
|
||||
std::vector<std::string> strings;
|
||||
std::string_view fill_string_view;
|
||||
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
|
||||
std::string fill_string = std::string(fill_string_view);
|
||||
for (int i = 0; i < input->shape().NumOfElements(); i++) {
|
||||
strings.emplace_back(fill_string);
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
|
||||
break;
|
||||
}
|
||||
case DataType::DE_UNKNOWN: {
|
||||
RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
*output = out;
|
||||
return Status::OK();
|
||||
}
|
||||
template <typename FROM, typename TO>
|
||||
void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
auto in_itr = input->begin<FROM>();
|
||||
auto out_itr = (*output)->begin<TO>();
|
||||
auto out_end = (*output)->end<TO>();
|
||||
|
||||
for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
|
||||
*out_itr = static_cast<TO>(*in_itr);
|
||||
}
|
||||
|
|
|
@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_
|
|||
Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
|
||||
int64_t index);
|
||||
|
||||
// Returns a tensor of shape input filled with the passed fill_value
|
||||
// @param input Tensor
|
||||
// @param output Tensor. The shape and type of the output tensor is same as input
|
||||
// @param fill_value Tensor. A scalar tensor used to fill the output tensor
|
||||
|
||||
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value);
|
||||
|
||||
// Returns a type changed input tensor.
|
||||
// Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp
|
||||
// @param input Tensor
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status FillOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
Status s = Fill(input, output, fill_value_);
|
||||
return s;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_KERNELS_DATA_FILL_OP_H_
|
||||
#define DATASET_KERNELS_DATA_FILL_OP_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FillOp : public TensorOp {
|
||||
public:
|
||||
explicit FillOp(std::shared_ptr<Tensor> value) : fill_value_(value) {}
|
||||
|
||||
~FillOp() override = default;
|
||||
void Print(std::ostream &out) const override { out << "FillOp"; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Tensor> fill_value_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_FILL_OP_H
|
|
@ -15,9 +15,9 @@
|
|||
"""
|
||||
This module c_transforms provides common operations, including OneHotOp and TypeCast.
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .validators import check_num_classes, check_de_type
|
||||
from .validators import check_num_classes, check_de_type, check_fill_value
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
|
||||
|
||||
|
@ -35,6 +35,22 @@ class OneHot(cde.OneHotOp):
|
|||
super().__init__(num_classes)
|
||||
|
||||
|
||||
class Fill(cde.FillOp):
|
||||
"""
|
||||
Tensor operation to create a tensor filled with passed scalar value.
|
||||
The output tensor will have the same shape and type as the input tensor.
|
||||
|
||||
Args:
|
||||
fill_value (python types (str, int, float, or bool)) : scalar value
|
||||
to fill created tensor with.
|
||||
"""
|
||||
|
||||
@check_fill_value
|
||||
def __init__(self, fill_value):
|
||||
print(fill_value)
|
||||
super().__init__(cde.Tensor(np.array(fill_value)))
|
||||
|
||||
|
||||
class TypeCast(cde.TypeCastOp):
|
||||
"""
|
||||
Tensor operation to cast to a given MindSpore data type.
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
from functools import wraps
|
||||
from mindspore._c_expression import typing
|
||||
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
UINT8_MAX = 255
|
||||
|
@ -159,6 +158,25 @@ def check_num_classes(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_fill_value(method):
|
||||
"""Wrapper method to check the parameters of fill value."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
fill_value = (list(args) + [None])[0]
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
if fill_value is None:
|
||||
raise ValueError("fill_value is not provided.")
|
||||
if not isinstance(fill_value, (str, float, bool, int)):
|
||||
raise TypeError("fill_value must be either a primitive python str, float, bool, or int")
|
||||
kwargs["fill_value"] = fill_value
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_de_type(method):
|
||||
"""Wrapper method to check the parameters of data type."""
|
||||
|
||||
|
|
|
@ -72,6 +72,7 @@ SET(DE_UT_SRCS
|
|||
tokenizer_op_test.cc
|
||||
gnn_graph_test.cc
|
||||
coco_op_test.cc
|
||||
fill_op_test.cc
|
||||
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestFillOp : public UT::Common {
|
||||
protected:
|
||||
MindDataTestFillOp() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestFillOp, TestOp) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp.";
|
||||
uint64_t labels[3] = {1, 1, 2};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input =
|
||||
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
|
||||
|
||||
TensorShape fill_shape({});
|
||||
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64));
|
||||
fill_tensor->SetItemAt<uint64_t>({}, 4);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
uint64_t out[3] = {4, 4, 4};
|
||||
|
||||
std::shared_ptr<Tensor> expected =
|
||||
std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output->shape() == expected->shape());
|
||||
ASSERT_TRUE(output->type() == expected->type());
|
||||
MS_LOG(DEBUG) << *output << std::endl;
|
||||
MS_LOG(DEBUG) << *expected << std::endl;
|
||||
|
||||
ASSERT_TRUE(*output == *expected);
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-TestOp end.";
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestFillOp, TestCasting) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting.";
|
||||
uint64_t labels[3] = {0, 1, 2};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input =
|
||||
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
|
||||
|
||||
TensorShape fill_shape({});
|
||||
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
|
||||
fill_tensor->SetItemAt<float>({}, 2.0);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
uint64_t out[3] = {2, 2, 2};
|
||||
|
||||
std::shared_ptr<Tensor> expected =
|
||||
std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
|
||||
|
||||
ASSERT_TRUE(output->shape() == expected->shape());
|
||||
ASSERT_TRUE(output->type() == expected->type());
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
MS_LOG(DEBUG) << *output << std::endl;
|
||||
MS_LOG(DEBUG) << *expected << std::endl;
|
||||
ASSERT_TRUE(*output == *expected);
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-TestCasting end.";
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestFillOp, ScalarFill) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill.";
|
||||
uint64_t labels[3] = {0, 1, 2};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input =
|
||||
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
|
||||
|
||||
TensorShape fill_shape({2});
|
||||
uint64_t fill_labels[3] = {0, 1};
|
||||
std::shared_ptr<Tensor> fill_tensor =
|
||||
std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(fill_labels));
|
||||
std::shared_ptr<Tensor> output;
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
EXPECT_TRUE(s.IsError());
|
||||
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-ScalarFill end.";
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestFillOp, StringFill) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill.";
|
||||
std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
|
||||
|
||||
TensorShape fill_shape({});
|
||||
std::string fill_string = "hello";
|
||||
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
std::vector<std::string> expected_strings = {"hello", "hello", "hello"};
|
||||
TensorShape expected_shape({3});
|
||||
std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(expected_strings, expected_shape);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output->shape() == expected->shape());
|
||||
ASSERT_TRUE(output->type() == expected->type());
|
||||
MS_LOG(DEBUG) << *output << std::endl;
|
||||
MS_LOG(DEBUG) << *expected << std::endl;
|
||||
|
||||
ASSERT_TRUE(*output == *expected);
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-StringFill end.";
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestFillOp, NumericToString) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString.";
|
||||
std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
|
||||
|
||||
TensorShape fill_shape({});
|
||||
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
|
||||
fill_tensor->SetItemAt<float>({}, 2.0);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
EXPECT_TRUE(s.IsError());
|
||||
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-NumericToString end.";
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestFillOp, StringToNumeric) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric.";
|
||||
uint64_t labels[3] = {0, 1, 2};
|
||||
TensorShape shape({3});
|
||||
std::shared_ptr<Tensor> input =
|
||||
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
|
||||
|
||||
TensorShape fill_shape({});
|
||||
std::string fill_string = "hello";
|
||||
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
|
||||
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
|
||||
Status s = op->Compute(input, &output);
|
||||
|
||||
EXPECT_TRUE(s.IsError());
|
||||
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestFillOp-StringToNumeric end.";
|
||||
}
|
|
@ -13,9 +13,6 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
//
|
||||
// Created by jesse on 10/3/19.
|
||||
//
|
||||
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
@ -25,32 +22,32 @@
|
|||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestQueue : public UT::Common {
|
||||
public:
|
||||
MindDataTestQueue() {}
|
||||
MindDataTestQueue() {}
|
||||
|
||||
void SetUp() {}
|
||||
void SetUp() {}
|
||||
};
|
||||
|
||||
int gRefCountDestructorCalled;
|
||||
|
||||
class RefCount {
|
||||
public:
|
||||
RefCount() : v_(nullptr) {}
|
||||
explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
|
||||
explicit RefCount(const RefCount &o) : v_(o.v_) {}
|
||||
~RefCount() {
|
||||
MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
|
||||
gRefCountDestructorCalled++;
|
||||
}
|
||||
RefCount& operator=(const RefCount &o) {
|
||||
v_ = o.v_;
|
||||
return *this;
|
||||
}
|
||||
RefCount() : v_(nullptr) {}
|
||||
explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
|
||||
explicit RefCount(const RefCount &o) : v_(o.v_) {}
|
||||
~RefCount() {
|
||||
MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
|
||||
gRefCountDestructorCalled++;
|
||||
}
|
||||
RefCount &operator=(const RefCount &o) {
|
||||
v_ = o.v_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::shared_ptr<int> v_;
|
||||
};
|
||||
|
@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) {
|
|||
// Use count should remain 2. a and b. No copy in the queue.
|
||||
ASSERT_EQ(a.use_count(), 2);
|
||||
a.reset(new int(5));
|
||||
ASSERT_EQ(a.use_count(),1);
|
||||
ASSERT_EQ(a.use_count(), 1);
|
||||
// Push again but expect a is nullptr after push
|
||||
rc = que.Add(std::move(a));
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_EQ(a.use_count(),0);
|
||||
ASSERT_EQ(a.use_count(), 0);
|
||||
rc = que.PopFront(&b);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_EQ(*b, 5);
|
||||
ASSERT_EQ(b.use_count(),1);
|
||||
ASSERT_EQ(b.use_count(), 1);
|
||||
// Test construct in place
|
||||
rc = que.EmplaceBack(std::make_shared<int>(100));
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = que.PopFront(&b);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_EQ(*b, 100);
|
||||
ASSERT_EQ(b.use_count(),1);
|
||||
ASSERT_EQ(b.use_count(), 1);
|
||||
// Test the destructor of the Queue by add an element in the queue without popping it and let the queue go
|
||||
// out of scope.
|
||||
rc = que.EmplaceBack(std::make_shared<int>(2000));
|
||||
|
@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) {
|
|||
ASSERT_EQ(*b, 40);
|
||||
}
|
||||
|
||||
void test4(){
|
||||
void test4() {
|
||||
gRefCountDestructorCalled = 0;
|
||||
// Pass a structure along the queue.
|
||||
Queue<RefCount> que(3);
|
||||
|
@ -144,9 +141,7 @@ void test4(){
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestQueue, Test4) {
|
||||
test4();
|
||||
}
|
||||
TEST_F(MindDataTestQueue, Test4) { test4(); }
|
||||
|
||||
TEST_F(MindDataTestQueue, Test5) {
|
||||
test4();
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing fill op
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
|
||||
|
||||
def test_fillop_basic():
|
||||
def gen():
|
||||
yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
fill_op = data_trans.Fill(3)
|
||||
|
||||
data = data.map(input_columns=["col"], operations=fill_op)
|
||||
expected = np.array([3, 3, 3, 3], dtype=np.uint8)
|
||||
for data_row in data:
|
||||
np.testing.assert_array_equal(data_row[0], expected)
|
||||
|
||||
|
||||
def test_fillop_down_type_cast():
|
||||
def gen():
|
||||
yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
fill_op = data_trans.Fill(-3)
|
||||
|
||||
data = data.map(input_columns=["col"], operations=fill_op)
|
||||
expected = np.array([253, 253, 253, 253], dtype=np.uint8)
|
||||
for data_row in data:
|
||||
np.testing.assert_array_equal(data_row[0], expected)
|
||||
|
||||
|
||||
def test_fillop_up_type_cast():
|
||||
def gen():
|
||||
yield (np.array([4, 5, 6, 7], dtype=np.float),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
fill_op = data_trans.Fill(3)
|
||||
|
||||
data = data.map(input_columns=["col"], operations=fill_op)
|
||||
expected = np.array([3., 3., 3., 3.], dtype=np.float)
|
||||
for data_row in data:
|
||||
np.testing.assert_array_equal(data_row[0], expected)
|
||||
|
||||
|
||||
def test_fillop_string():
|
||||
def gen():
|
||||
yield (np.array(["45555", "45555"], dtype='S'),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
fill_op = data_trans.Fill("error")
|
||||
|
||||
data = data.map(input_columns=["col"], operations=fill_op)
|
||||
expected = np.array(['error', 'error'], dtype='S')
|
||||
for data_row in data:
|
||||
np.testing.assert_array_equal(data_row[0], expected)
|
||||
|
||||
|
||||
def test_fillop_error_handling():
|
||||
def gen():
|
||||
yield (np.array([4, 4, 4, 4]),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
fill_op = data_trans.Fill("words")
|
||||
data = data.map(input_columns=["col"], operations=fill_op)
|
||||
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
for data_row in data:
|
||||
print(data_row)
|
||||
assert "Types do not match" in repr(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_fillop_basic()
|
||||
test_fillop_up_type_cast()
|
||||
test_fillop_down_type_cast()
|
||||
test_fillop_string()
|
||||
test_fillop_error_handling()
|
Loading…
Reference in New Issue