remove graphengine changes

remove graphengine changes

concat op

Truncate Pair

concat_op

remove graph engine changes

ToNumberOp implementation almost done

ToNumberOp complete

ci fix

ci fix

ci fix

ci fix

ci fix

ci fix

ci fix

ci fix

ci fix

ci fix

merge conflicts
peilinwang 2020-06-18 14:33:08 -04:00 committed by Peilin Wang
parent 685c8cec68
commit 1e36b0649f
10 changed files with 642 additions and 53 deletions

View File

@ -16,9 +16,37 @@
#include <exception>
#include "dataset/api/de_pipeline.h"
#include "dataset/kernels/no_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/concatenate_op.h"
#include "dataset/kernels/data/duplicate_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/one_hot_op.h"
#include "dataset/kernels/data/pad_end_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/image/bounding_box_augment_op.h"
#include "dataset/kernels/image/center_crop_op.h"
#include "dataset/kernels/image/cut_out_op.h"
#include "dataset/kernels/image/decode_op.h"
@ -27,11 +55,11 @@
#include "dataset/kernels/image/normalize_op.h"
#include "dataset/kernels/image/pad_op.h"
#include "dataset/kernels/image/random_color_adjust_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
#include "dataset/kernels/image/random_crop_and_resize_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
#include "dataset/kernels/image/random_crop_op.h"
#include "dataset/kernels/image/random_horizontal_flip_op.h"
#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h"
#include "dataset/kernels/image/random_horizontal_flip_op.h"
#include "dataset/kernels/image/random_resize_op.h"
#include "dataset/kernels/image/random_rotation_op.h"
#include "dataset/kernels/image/random_vertical_flip_op.h"
@ -39,42 +67,24 @@
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/image/bounding_box_augment_op.h"
#include "dataset/kernels/data/duplicate_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/pad_end_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/kernels/no_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
#include "dataset/text/kernels/lookup_op.h"
#include "dataset/text/kernels/ngram_op.h"
#include "dataset/text/kernels/to_number_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "dataset/text/kernels/wordpiece_tokenizer_op.h"
#include "dataset/text/vocab.h"
#include "dataset/text/kernels/lookup_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_sequential_sample.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#ifdef ENABLE_ICU4C
#include "dataset/text/kernels/basic_tokenizer_op.h"
@ -87,16 +97,6 @@
#include "dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_sequential_sample.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
namespace py = pybind11;
namespace mindspore {
@ -542,6 +542,10 @@ void bindTensorOps4(py::module *m) {
.def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
"TensorOp to convert strings to numbers.")
.def(py::init<DataType>(), py::arg("data_type"))
.def(py::init<std::string>(), py::arg("data_type"));
}
void bindTokenizerOps(py::module *m) {

View File

@ -15,15 +15,19 @@
*/
#include "dataset/kernels/data/data_utils.h"
#include <algorithm>
#include <limits>
#include <string>
#include <vector>
#include "dataset/core/constants.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/core/data_type.h"
#include "dataset/core/pybind_support.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -330,7 +334,18 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
auto in_itr = input->begin<float>();
auto out_itr = (*output)->begin<float16>();
auto out_end = (*output)->end<float16>();
for (; out_itr != out_end; in_itr++, out_itr++) *out_itr = Eigen::half(*in_itr);
for (; out_itr != out_end; in_itr++, out_itr++) {
float element = *in_itr;
float float16_max = static_cast<float>(std::numeric_limits<Eigen::half>::max());
float float16_min = static_cast<float>(std::numeric_limits<Eigen::half>::lowest());
if (element > float16_max || element < float16_min) {
RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
std::to_string(float16_min) + ", " + std::to_string(float16_max) + "].");
}
*out_itr = Eigen::half(*in_itr);
}
return Status::OK();
}
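
The new loop above rejects values that cannot be represented in half precision instead of letting them overflow to infinity. A rough illustration of the limits involved (not part of the patch; it assumes numpy's float16, which uses the same IEEE binary16 format as Eigen::half):

import numpy as np

FLOAT16_MAX = float(np.finfo(np.float16).max)   # 65504.0
FLOAT16_MIN = float(np.finfo(np.float16).min)   # -65504.0

def checked_to_float16(value):
    # Mirror of the guarded narrowing above: reject out-of-range values instead of
    # letting them silently overflow to +/- infinity.
    if value > FLOAT16_MAX or value < FLOAT16_MIN:
        raise RuntimeError("Value {} is outside of valid float16 range [{}, {}]."
                           .format(value, FLOAT16_MIN, FLOAT16_MAX))
    return np.float16(value)

assert np.isinf(np.float16(1e5))                # silent overflow without the guard
assert checked_to_float16(65504.0) == 65504.0   # the boundary value converts exactly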

View File

@ -18,5 +18,6 @@ add_library(text-kernels OBJECT
ngram_op.cc
wordpiece_tokenizer_op.cc
truncate_sequence_pair_op.cc
to_number_op.cc
${ICU_DEPEND_FILES}
)

View File

@ -0,0 +1,241 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/text/kernels/to_number_op.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "dataset/core/data_type.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {}
ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {}
Status ToNumberOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensors should have type string.");
switch (cast_to_type_.value()) {
case DataType::DE_INT8:
RETURN_IF_NOT_OK(ToSignedIntegral<int8_t>(input, output));
break;
case DataType::DE_INT16:
RETURN_IF_NOT_OK(ToSignedIntegral<int16_t>(input, output));
break;
case DataType::DE_INT32:
RETURN_IF_NOT_OK(ToSignedIntegral<int32_t>(input, output));
break;
case DataType::DE_INT64:
RETURN_IF_NOT_OK(ToSignedIntegral<int64_t>(input, output));
break;
case DataType::DE_UINT8:
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint8_t>(input, output));
break;
case DataType::DE_UINT16:
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint16_t>(input, output));
break;
case DataType::DE_UINT32:
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint32_t>(input, output));
break;
case DataType::DE_UINT64:
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint64_t>(input, output));
break;
case DataType::DE_FLOAT16:
RETURN_IF_NOT_OK(this->ToFloat16(input, output));
break;
case DataType::DE_FLOAT32:
RETURN_IF_NOT_OK(ToFloat(input, output));
break;
case DataType::DE_FLOAT64:
RETURN_IF_NOT_OK(ToDouble(input, output));
break;
}
return Status::OK();
}
void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << cast_to_type_.ToString() << '\n'; }
Status ToNumberOp::OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) {
(void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes));
return Status::OK();
}
template <typename T>
Status ToNumberOp::ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
std::vector<T> casted;
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
bool is_cast_out_of_range = false;
int64_t result = 0;
try {
result = std::stoll(std::string(*it));
} catch (const std::out_of_range &) {
is_cast_out_of_range = true;
} catch (const std::invalid_argument &) {
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
}
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
cast_to_type_.ToString() + ". The valid range is: [" +
std::to_string(std::numeric_limits<T>::min()) + ", " +
std::to_string(std::numeric_limits<T>::max()) + "].";
RETURN_STATUS_UNEXPECTED(error_message);
}
T casted_result = static_cast<T>(result);
casted.push_back(casted_result);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
return Status::OK();
}
template <typename T>
Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
std::vector<T> casted;
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
bool is_cast_out_of_range = false;
uint64_t result = 0;
// A '-' anywhere in the string is treated as out of bounds here: std::stoull does not
// fail on negative input, it wraps the value into the unsigned range, so it cannot be
// relied on to reject negative numbers.
for (size_t i = 0; i < (*it).size(); i++) {
if ((*it)[i] == '-') {
is_cast_out_of_range = true;
break;
}
}
try {
result = std::stoull(std::string(*it));
} catch (const std::out_of_range &) {
is_cast_out_of_range = true;
} catch (const std::invalid_argument &) {
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
}
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
cast_to_type_.ToString() + ". The valid range is: [" +
std::to_string(std::numeric_limits<T>::min()) + ", " +
std::to_string(std::numeric_limits<T>::max()) + "].";
RETURN_STATUS_UNEXPECTED(error_message);
}
T casted_result = static_cast<T>(result);
casted.push_back(casted_result);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
return Status::OK();
}
Status ToNumberOp::ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
// Special case: float16 does not exist natively in C++ and there is no native support
// for casting to it, so cast to float first, then use this method, which uses Eigen.
std::shared_ptr<Tensor> temp;
RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32")));
RETURN_IF_NOT_OK(ToFloat(input, &temp));
RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output));
return Status::OK();
}
Status ToNumberOp::ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
std::vector<float> casted;
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
bool is_cast_out_of_range = false;
float result = 0;
try {
result = std::stof(std::string(*it));
} catch (const std::out_of_range &) {
is_cast_out_of_range = true;
} catch (const std::invalid_argument &) {
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
}
if (result > std::numeric_limits<float>::max() || result < std::numeric_limits<float>::lowest() ||
is_cast_out_of_range) {
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
cast_to_type_.ToString() + ". The valid range is: [" +
std::to_string(std::numeric_limits<float>::lowest()) + ", " +
std::to_string(std::numeric_limits<float>::max()) + "].";
RETURN_STATUS_UNEXPECTED(error_message);
}
float casted_result = static_cast<float>(result);
casted.push_back(casted_result);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
return Status::OK();
}
Status ToNumberOp::ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
std::vector<double> casted;
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
bool is_cast_out_of_range = false;
double result = 0;
try {
result = std::stod(std::string(*it));
} catch (const std::out_of_range &) {
is_cast_out_of_range = true;
} catch (const std::invalid_argument &) {
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
}
if (result > std::numeric_limits<double>::max() || result < std::numeric_limits<double>::lowest() ||
is_cast_out_of_range) {
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
cast_to_type_.ToString() + ". The valid range is: [" +
std::to_string(std::numeric_limits<double>::lowest()) + ", " +
std::to_string(std::numeric_limits<double>::max()) + "].";
RETURN_STATUS_UNEXPECTED(error_message);
}
double casted_result = static_cast<double>(result);
casted.push_back(casted_result);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
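
The unsigned path above screens for '-' before calling std::stoull because stoull wraps negative input into the unsigned range rather than failing; the float16 path simply reuses ToFloat plus the range-checked float32-to-float16 narrowing added in data_utils. A rough Python analogue of the unsigned logic (illustrative only, not MindSpore code):

def to_unsigned(token, bits):
    if "-" in token:
        raise RuntimeError("String input " + token + " will be out of bounds if casted to uint" + str(bits) + ".")
    try:
        result = int(token)   # stand-in for std::stoull
    except ValueError:
        raise RuntimeError("It is invalid to convert " + token + " to an unsigned integer.")
    if result > 2 ** bits - 1:
        raise RuntimeError("String input " + token + " will be out of bounds if casted to uint" +
                           str(bits) + ". The valid range is: [0, " + str(2 ** bits - 1) + "].")
    return result

assert to_unsigned("255", 8) == 255
assert (-1) % 2 ** 64 == 18446744073709551615   # the wrapped value std::stoull("-1") would yield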

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/core/data_type.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class ToNumberOp : public TensorOp {
public:
// Constructor of ToNumberOp
// @param const DataType &cast_to_type - the type to convert string inputs to.
explicit ToNumberOp(const DataType &cast_to_type);
// Constructor of ToNumberOp
// @param const std::string &cast_to_type - the type in string form to convert string inputs to.
explicit ToNumberOp(const std::string &cast_to_type);
~ToNumberOp() override = default;
// Perform numeric conversion on each string in each tensor.
// @param const std::shared_ptr<Tensor> &input
// @param std::shared_ptr<Tensor> *output
// @return error code
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// For each input shape, find the output shape
// @param std::vector<TensorShape> &inputs - shape of input tensors
// @param std::vector<TensorShape> &outputs - shape of output tensors
// @return error code
Status OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) override;
// print arg for debugging
// @param std::ostream &out
void Print(std::ostream &out) const override;
private:
template <typename T>
Status ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
template <typename T>
Status ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
Status ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
Status ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
DataType cast_to_type_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_

View File

@ -16,12 +16,13 @@
mindspore.dataset.text
"""
import platform
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
ToNumber
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
__all__ = [
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
"to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair"
"to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber"
]
if platform.system().lower() != 'windows':

View File

@ -23,7 +23,9 @@ import mindspore._c_dataengine as cde
from .utils import JiebaMode, NormalizeForm
from .validators import check_lookup, check_jieba_add_dict, \
check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate
check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \
check_to_number
from ..core.datatypes import mstype_to_detype
class Lookup(cde.LookupOp):
@ -379,3 +381,28 @@ class TruncateSequencePair(cde.TruncateSequencePairOp):
@check_pair_truncate
def __init__(self, max_length):
super().__init__(max_length)
class ToNumber(cde.ToNumberOp):
"""
Tensor operation to convert every element of a string tensor to a number.
Strings are cast according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any string which represents a negative number cannot be cast to an
unsigned integer type.
Args:
data_type (mindspore.dtype): mindspore.dtype to be cast to. Must be
a numeric type.
Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being cast.
"""
@check_to_number
def __init__(self, data_type):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)
super().__init__(data_type)
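
A minimal pipeline using the new transform, following the pattern of the tests added later in this commit (sketch only):

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text

def string_generator():
    for s in ["-121", "14"]:
        yield (np.array(s, dtype='S'),)

dataset = ds.GeneratorDataset(string_generator(), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32))
for row in dataset.create_dict_iterator():
    print(row["strings"])   # -121, then 14, now as int32 values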

View File

@ -19,7 +19,9 @@ validators for text ops
from functools import wraps
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
from mindspore._c_expression import typing
from ..transforms.validators import check_uint32, check_pos_int64
@ -384,3 +386,28 @@ def check_pair_truncate(method):
return method(self, **kwargs)
return new_method
def check_to_number(method):
"""A wrapper that wraps a parameter check to the original function (ToNumber)."""
@wraps(method)
def new_method(self, *args, **kwargs):
data_type = (list(args) + [None])[0]
if "data_type" in kwargs:
data_type = kwargs.get("data_type")
if data_type is None:
raise ValueError("data_type is a mandatory parameter but was not provided.")
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
if data_type not in mstype.number_type:
raise TypeError("data_type is not a numeric data type.")
kwargs["data_type"] = data_type
return method(self, **kwargs)
return new_method
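
To make the contract concrete, a short sketch of what check_to_number accepts and rejects, based on the checks above (illustrative only):

import mindspore.common.dtype as mstype
import mindspore.dataset.text as text

text.ToNumber(mstype.float32)       # accepted: a numeric mindspore.dtype

try:
    text.ToNumber("float32")        # rejected: a plain string is not a MindSpore type object
except TypeError as error:
    print(error)                    # data_type is not a MindSpore data type.

try:
    text.ToNumber(mstype.string)    # rejected: string is not a numeric type
except TypeError as error:
    print(error)                    # data_type is not a numeric data type.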

View File

@ -88,15 +88,15 @@ if __name__ == '__main__':
if not os.path.isdir(args_opt.mindrecord_dir):
os.makedirs(args_opt.mindrecord_dir)
prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
yolo_prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(args_opt.mindrecord_dir, yolo_prefix + "0")
if not os.path.exists(mindrecord_file):
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
print("Create Mindrecord")
data_to_mindrecord_byte_image(args_opt.image_dir,
args_opt.anno_path,
args_opt.mindrecord_dir,
prefix=prefix,
prefix=yolo_prefix,
file_num=8)
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
else:

View File

@ -0,0 +1,194 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
np.uint32, np.uint64]
ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8,
mstype.uint16, mstype.uint32, mstype.uint64]
np_non_integral_types = [np.float16, np.float32, np.float64]
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]
def string_dataset_generator(strings):
for string in strings:
yield (np.array(string, dtype='S'),)
def test_to_number_typical_case_integral():
input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
["-1726483716", "98921728421"]]
for ms_type, inputs in zip(ms_integral_types, input_strings):
dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
expected_output = [int(string) for string in inputs]
output = []
for data in dataset.create_dict_iterator():
output.append(data["strings"])
assert output == expected_output
def test_to_number_typical_case_non_integral():
input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001]
for ms_type, inputs in zip(ms_non_integral_types, input_strings):
dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
expected_output = [float(string) for string in inputs]
output = []
for data in dataset.create_dict_iterator():
output.append(data["strings"])
for expected, actual, epsilon in zip(expected_output, output, epsilons):
assert abs(expected - actual) < epsilon
def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
type_info = np.iinfo(np_type)
type_max = str(type_info.max)
type_min = str(type_info.min)
type_name = str(np.dtype(np_type))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "String input " + value_to_cast + " will be out of bounds if casted to " + type_name in str(info.value)
assert "valid range is: [" + type_min + ", " + type_max + "]" in str(info.value)
def test_to_number_out_of_bounds_integral():
for np_type, ms_type in zip(np_integral_types, ms_integral_types):
type_info = np.iinfo(np_type)
input_strings = [str(type_info.max + 10)]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
input_strings = [str(type_info.min - 10)]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
def test_to_number_out_of_bounds_non_integral():
above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"]
input_strings = [above_range[0]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "outside of valid float16 range" in str(info.value)
input_strings = [above_range[1]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value)
input_strings = [above_range[2]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value)
below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"]
input_strings = [below_range[0]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "outside of valid float16 range" in str(info.value)
input_strings = [below_range[1]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value)
input_strings = [below_range[2]]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2]))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value)
def test_to_number_boundaries_integral():
for np_type, ms_type in zip(np_integral_types, ms_integral_types):
type_info = np.iinfo(np_type)
input_strings = [str(type_info.max)]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
for data in dataset.create_dict_iterator():
assert data["strings"] == int(input_strings[0])
input_strings = [str(type_info.min)]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
for data in dataset.create_dict_iterator():
assert data["strings"] == int(input_strings[0])
input_strings = [str(0)]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
for data in dataset.create_dict_iterator():
assert data["strings"] == int(input_strings[0])
def test_to_number_invalid_input():
input_strings = ["a8fa9ds8fa"]
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32))
with pytest.raises(RuntimeError) as info:
for _ in dataset.create_dict_iterator():
pass
assert "It is invalid to convert " + input_strings[0] + " to a number" in str(info.value)
if __name__ == '__main__':
test_to_number_typical_case_integral()
test_to_number_typical_case_non_integral()
test_to_number_boundaries_integral()
test_to_number_out_of_bounds_integral()
test_to_number_out_of_bounds_non_integral()
test_to_number_invalid_input()