diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 5391ad7cb3e..1314e2c09ec 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -16,9 +16,37 @@ #include #include "dataset/api/de_pipeline.h" -#include "dataset/kernels/no_op.h" +#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/source/clue_op.h" +#include "dataset/engine/datasetops/source/coco_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/engine/datasetops/source/manifest_op.h" +#include "dataset/engine/datasetops/source/mindrecord_op.h" +#include "dataset/engine/datasetops/source/mnist_op.h" +#include "dataset/engine/datasetops/source/random_data_op.h" +#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" +#include "dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/engine/gnn/graph.h" +#include "dataset/engine/jagged_connector.h" #include "dataset/kernels/data/concatenate_op.h" +#include "dataset/kernels/data/duplicate_op.h" +#include "dataset/kernels/data/fill_op.h" +#include "dataset/kernels/data/mask_op.h" #include "dataset/kernels/data/one_hot_op.h" +#include "dataset/kernels/data/pad_end_op.h" +#include "dataset/kernels/data/slice_op.h" +#include "dataset/kernels/data/to_float16_op.h" +#include "dataset/kernels/data/type_cast_op.h" +#include "dataset/kernels/image/bounding_box_augment_op.h" #include "dataset/kernels/image/center_crop_op.h" #include "dataset/kernels/image/cut_out_op.h" #include "dataset/kernels/image/decode_op.h" @@ -27,11 +55,11 @@ #include "dataset/kernels/image/normalize_op.h" #include "dataset/kernels/image/pad_op.h" #include "dataset/kernels/image/random_color_adjust_op.h" -#include "dataset/kernels/image/random_crop_decode_resize_op.h" #include "dataset/kernels/image/random_crop_and_resize_op.h" +#include "dataset/kernels/image/random_crop_decode_resize_op.h" #include "dataset/kernels/image/random_crop_op.h" -#include "dataset/kernels/image/random_horizontal_flip_op.h" #include "dataset/kernels/image/random_horizontal_flip_bbox_op.h" +#include "dataset/kernels/image/random_horizontal_flip_op.h" #include "dataset/kernels/image/random_resize_op.h" #include "dataset/kernels/image/random_rotation_op.h" #include "dataset/kernels/image/random_vertical_flip_op.h" @@ -39,42 +67,24 @@ #include "dataset/kernels/image/resize_bilinear_op.h" #include "dataset/kernels/image/resize_op.h" #include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/image/bounding_box_augment_op.h" -#include "dataset/kernels/data/duplicate_op.h" -#include "dataset/kernels/data/fill_op.h" -#include "dataset/kernels/data/mask_op.h" -#include "dataset/kernels/data/pad_end_op.h" -#include "dataset/kernels/data/slice_op.h" -#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/python_sampler.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/gnn/graph.h" -#include "dataset/kernels/data/to_float16_op.h" +#include "dataset/kernels/no_op.h" #include "dataset/text/kernels/jieba_tokenizer_op.h" +#include "dataset/text/kernels/lookup_op.h" #include "dataset/text/kernels/ngram_op.h" +#include "dataset/text/kernels/to_number_op.h" #include "dataset/text/kernels/unicode_char_tokenizer_op.h" #include "dataset/text/kernels/wordpiece_tokenizer_op.h" #include "dataset/text/vocab.h" -#include "dataset/text/kernels/lookup_op.h" +#include "dataset/util/random.h" +#include "mindrecord/include/shard_distributed_sample.h" +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_pk_sample.h" +#include "mindrecord/include/shard_sample.h" +#include "mindrecord/include/shard_sequential_sample.h" +#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" #ifdef ENABLE_ICU4C #include "dataset/text/kernels/basic_tokenizer_op.h" @@ -87,16 +97,6 @@ #include "dataset/text/kernels/whitespace_tokenizer_op.h" #endif -#include "dataset/util/random.h" -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_pk_sample.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_sequential_sample.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "pybind11/stl_bind.h" - namespace py = pybind11; namespace mindspore { @@ -542,6 +542,10 @@ void bindTensorOps4(py::module *m) { .def(py::init(), py::arg("padTop"), py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); + (void)py::class_>(*m, "ToNumberOp", + "TensorOp to convert strings to numbers.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); } void bindTokenizerOps(py::module *m) { diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc index 5a20926618e..103748b0cf8 100644 --- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc @@ -15,15 +15,19 @@ */ #include "dataset/kernels/data/data_utils.h" + #include +#include #include #include + #include "dataset/core/constants.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" #include "dataset/core/data_type.h" #include "dataset/core/pybind_support.h" +#include "dataset/core/tensor.h" +#include "dataset/core/tensor_shape.h" #include "dataset/kernels/data/type_cast_op.h" +#include "dataset/util/status.h" namespace mindspore { namespace dataset { @@ -330,7 +334,18 @@ Status ToFloat16(const std::shared_ptr &input, std::shared_ptr * auto in_itr = input->begin(); auto out_itr = (*output)->begin(); auto out_end = (*output)->end(); - for (; out_itr != out_end; in_itr++, out_itr++) *out_itr = Eigen::half(*in_itr); + + for (; out_itr != out_end; in_itr++, out_itr++) { + float element = *in_itr; + float float16_max = static_cast(std::numeric_limits::max()); + float float16_min = static_cast(std::numeric_limits::lowest()); + if (element > float16_max || element < float16_min) { + RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" + + std::to_string(float16_max) + ", " + std::to_string(float16_min) + "]."); + } + + *out_itr = Eigen::half(*in_itr); + } return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt index 396d03fe44c..449bb93d8b9 100644 --- a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt @@ -18,5 +18,6 @@ add_library(text-kernels OBJECT ngram_op.cc wordpiece_tokenizer_op.cc truncate_sequence_pair_op.cc + to_number_op.cc ${ICU_DEPEND_FILES} ) diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc new file mode 100644 index 00000000000..1368684daff --- /dev/null +++ b/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dataset/text/kernels/to_number_op.h" + +#include +#include +#include +#include +#include +#include + +#include "dataset/core/data_type.h" +#include "dataset/core/tensor.h" +#include "dataset/core/tensor_shape.h" +#include "dataset/kernels/data/data_utils.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {} + +ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {} + +Status ToNumberOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tenosrs should have type string."); + + switch (cast_to_type_.value()) { + case DataType::DE_INT8: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT16: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT32: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT64: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_UINT8: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT16: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT32: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT64: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_FLOAT16: + RETURN_IF_NOT_OK(this->ToFloat16(input, output)); + break; + case DataType::DE_FLOAT32: + RETURN_IF_NOT_OK(ToFloat(input, output)); + break; + case DataType::DE_FLOAT64: + RETURN_IF_NOT_OK(ToDouble(input, output)); + break; + } + + return Status::OK(); +} + +void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; } + +Status ToNumberOp::OutputShape(const std::vector &input_shapes, std::vector &output_shapes) { + (void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes)); + return Status::OK(); +} + +template +Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + int64_t result = 0; + + try { + result = std::stoll(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +template +Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + uint64_t result = 0; + + // If there is a - at the start of the string, it is considered by us to + // be out of bounds. If the - is somewhere else in the string, it is + // deemed invalid by std::stoull and will throw std::invalid_argument + for (int i = 0; i < (*it).size(); i++) { + if ((*it)[i] == '-') { + is_cast_out_of_range = true; + break; + } + } + + try { + result = std::stoull(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { + // special case, float16 does not exist in c++, no native support for + // casting, so cast to float first then use this method, which use Eigen. + std::shared_ptr temp; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); + RETURN_IF_NOT_OK(ToFloat(input, &temp)); + RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); + return Status::OK(); +} + +Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + float result = 0; + + try { + result = std::stof(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + float casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + double result = 0; + + try { + result = std::stod(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + double casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/dataset/text/kernels/to_number_op.h new file mode 100644 index 00000000000..1346ce2f474 --- /dev/null +++ b/mindspore/ccsrc/dataset/text/kernels/to_number_op.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ + +#include +#include +#include + +#include "dataset/core/data_type.h" +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class ToNumberOp : public TensorOp { + public: + // Constructor of ToNumberOp + // @param const DataType &cast_to_type - the type to convert string inputs to. + explicit ToNumberOp(const DataType &cast_to_type); + + // Constructor of ToNumberOp + // @param const std::string &cast_to_type - the type in string form to convert string inputs to. + explicit ToNumberOp(const std::string &cast_to_type); + + ~ToNumberOp() override = default; + + // Perform numeric conversion on each string in each tensor. + // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // For each input shape, find the output shape + // @param std::vector &inputs - shape of input tensors + // @param std::vector &outputs - shape of output tensors + // @return error code + Status OutputShape(const std::vector &input_shapes, std::vector &output_shapes) override; + + // print arg for debugging + // @param std::ostream &out + void Print(std::ostream &out) const override; + + private: + template + Status ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + template + Status ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToDouble(const std::shared_ptr &input, std::shared_ptr *output); + + DataType cast_to_type_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py index 364ea75d574..4f2ecdc5403 100644 --- a/mindspore/dataset/text/__init__.py +++ b/mindspore/dataset/text/__init__.py @@ -16,12 +16,13 @@ mindspore.dataset.text """ import platform -from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair +from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ + ToNumber from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm __all__ = [ "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", - "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair" + "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber" ] if platform.system().lower() != 'windows': diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 306e08a4dcf..ad4c12ad982 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -23,7 +23,9 @@ import mindspore._c_dataengine as cde from .utils import JiebaMode, NormalizeForm from .validators import check_lookup, check_jieba_add_dict, \ - check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate + check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \ + check_to_number +from ..core.datatypes import mstype_to_detype class Lookup(cde.LookupOp): @@ -379,3 +381,28 @@ class TruncateSequencePair(cde.TruncateSequencePairOp): @check_pair_truncate def __init__(self, max_length): super().__init__(max_length) + + +class ToNumber(cde.ToNumberOp): + """ + Tensor operation to convert every element of a string tensor to a number. + + Strings are casted according to the rules specified in the following links: + https://en.cppreference.com/w/cpp/string/basic_string/stof, + https://en.cppreference.com/w/cpp/string/basic_string/stoul, + except that any strings which represent negative numbers cannot be casted to an + unsigned integer type. + + Args: + data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be + a numeric type. + + Raises: + RuntimeError: If strings are invalid to cast, or are out of range after being casted. + """ + + @check_to_number + def __init__(self, data_type): + data_type = mstype_to_detype(data_type) + self.data_type = str(data_type) + super().__init__(data_type) diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index e288c09b080..74ff31dd7a6 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -19,7 +19,9 @@ validators for text ops from functools import wraps import mindspore._c_dataengine as cde +import mindspore.common.dtype as mstype +from mindspore._c_expression import typing from ..transforms.validators import check_uint32, check_pos_int64 @@ -384,3 +386,28 @@ def check_pair_truncate(method): return method(self, **kwargs) return new_method + + +def check_to_number(method): + """A wrapper that wraps a parameter check to the original function (ToNumber).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + data_type = (list(args) + [None])[0] + if "data_type" in kwargs: + data_type = kwargs.get("data_type") + + if data_type is None: + raise ValueError("data_type is a mandatory parameter but was not provided.") + + if not isinstance(data_type, typing.Type): + raise TypeError("data_type is not a MindSpore data type.") + + if not data_type in mstype.number_type: + raise TypeError("data_type is not numeric data type.") + + kwargs["data_type"] = data_type + + return method(self, **kwargs) + + return new_method diff --git a/model_zoo/yolov3/eval.py b/model_zoo/yolov3/eval.py index 433ae834ba3..65dc408a150 100644 --- a/model_zoo/yolov3/eval.py +++ b/model_zoo/yolov3/eval.py @@ -88,15 +88,15 @@ if __name__ == '__main__': if not os.path.isdir(args_opt.mindrecord_dir): os.makedirs(args_opt.mindrecord_dir) - prefix = "yolo.mindrecord" - mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0") + yolo_prefix = "yolo.mindrecord" + mindrecord_file = os.path.join(args_opt.mindrecord_dir, yolo_prefix + "0") if not os.path.exists(mindrecord_file): if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path): print("Create Mindrecord") data_to_mindrecord_byte_image(args_opt.image_dir, args_opt.anno_path, args_opt.mindrecord_dir, - prefix=prefix, + prefix=yolo_prefix, file_num=8) print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir)) else: diff --git a/tests/ut/python/dataset/test_to_number_op.py b/tests/ut/python/dataset/test_to_number_op.py new file mode 100644 index 00000000000..47b39e7a682 --- /dev/null +++ b/tests/ut/python/dataset/test_to_number_op.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.text as text + +np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, + np.uint32, np.uint64] +ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, + mstype.uint16, mstype.uint32, mstype.uint64] + +np_non_integral_types = [np.float16, np.float32, np.float64] +ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64] + +def string_dataset_generator(strings): + for string in strings: + yield (np.array(string, dtype='S'),) + + +def test_to_number_typical_case_integral(): + input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"], + ["-1726483716", "98921728421"]] + + for ms_type, inputs in zip(ms_integral_types, input_strings): + dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + + expected_output = [int(string) for string in inputs] + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["strings"]) + + assert output == expected_output + + +def test_to_number_typical_case_non_integral(): + input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]] + epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001] + + for ms_type, inputs in zip(ms_non_integral_types, input_strings): + dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + + expected_output = [float(string) for string in inputs] + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["strings"]) + + for expected, actual, epsilon in zip(expected_output, output, epsilons): + assert abs(expected - actual) < epsilon + + +def out_of_bounds_error_message_check(dataset, np_type, value_to_cast): + type_info = np.iinfo(np_type) + type_max = str(type_info.max) + type_min = str(type_info.min) + type_name = str(np.dtype(np_type)) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "String input " + value_to_cast + " will be out of bounds if casted to " + type_name in str(info.value) + assert "valid range is: [" + type_min + ", " + type_max + "]" in str(info.value) + + +def test_to_number_out_of_bounds_integral(): + for np_type, ms_type in zip(np_integral_types, ms_integral_types): + type_info = np.iinfo(np_type) + input_strings = [str(type_info.max + 10)] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + out_of_bounds_error_message_check(dataset, np_type, input_strings[0]) + + input_strings = [str(type_info.min - 10)] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + out_of_bounds_error_message_check(dataset, np_type, input_strings[0]) + + +def test_to_number_out_of_bounds_non_integral(): + above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"] + + input_strings = [above_range[0]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "outside of valid float16 range" in str(info.value) + + input_strings = [above_range[1]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value) + + input_strings = [above_range[2]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value) + + below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"] + + input_strings = [below_range[0]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "outside of valid float16 range" in str(info.value) + + input_strings = [below_range[1]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value) + + input_strings = [below_range[2]] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2])) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value) + + +def test_to_number_boundaries_integral(): + for np_type, ms_type in zip(np_integral_types, ms_integral_types): + type_info = np.iinfo(np_type) + input_strings = [str(type_info.max)] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + for data in dataset.create_dict_iterator(): + assert data["strings"] == int(input_strings[0]) + + input_strings = [str(type_info.min)] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + for data in dataset.create_dict_iterator(): + assert data["strings"] == int(input_strings[0]) + + input_strings = [str(0)] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type)) + for data in dataset.create_dict_iterator(): + assert data["strings"] == int(input_strings[0]) + + +def test_to_number_invalid_input(): + input_strings = ["a8fa9ds8fa"] + dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings") + dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32)) + + with pytest.raises(RuntimeError) as info: + for _ in dataset.create_dict_iterator(): + pass + assert "It is invalid to convert " + input_strings[0] + " to a number" in str(info.value) + + +if __name__ == '__main__': + test_to_number_typical_case_integral() + test_to_number_typical_case_non_integral() + test_to_number_boundaries_integral() + test_to_number_out_of_bounds_integral() + test_to_number_out_of_bounds_non_integral() + test_to_number_invalid_input()