forked from mindspore-Ecosystem/mindspore
!2300 Implementation of ToNumberOp
Merge pull request !2300 from Peilin/ToNumberOp
This commit is contained in:
commit
ec8f541325
|
@ -16,9 +16,37 @@
|
|||
#include <exception>
|
||||
|
||||
#include "dataset/api/de_pipeline.h"
|
||||
#include "dataset/kernels/no_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/kernels/data/concatenate_op.h"
|
||||
#include "dataset/kernels/data/duplicate_op.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
#include "dataset/kernels/data/one_hot_op.h"
|
||||
#include "dataset/kernels/data/pad_end_op.h"
|
||||
#include "dataset/kernels/data/slice_op.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "dataset/kernels/image/center_crop_op.h"
|
||||
#include "dataset/kernels/image/cut_out_op.h"
|
||||
#include "dataset/kernels/image/decode_op.h"
|
||||
|
@ -27,11 +55,11 @@
|
|||
#include "dataset/kernels/image/normalize_op.h"
|
||||
#include "dataset/kernels/image/pad_op.h"
|
||||
#include "dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_and_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_op.h"
|
||||
#include "dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h"
|
||||
#include "dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "dataset/kernels/image/random_resize_op.h"
|
||||
#include "dataset/kernels/image/random_rotation_op.h"
|
||||
#include "dataset/kernels/image/random_vertical_flip_op.h"
|
||||
|
@ -39,42 +67,24 @@
|
|||
#include "dataset/kernels/image/resize_bilinear_op.h"
|
||||
#include "dataset/kernels/image/resize_op.h"
|
||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "dataset/kernels/data/duplicate_op.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
#include "dataset/kernels/data/pad_end_op.h"
|
||||
#include "dataset/kernels/data/slice_op.h"
|
||||
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/kernels/no_op.h"
|
||||
#include "dataset/text/kernels/jieba_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/lookup_op.h"
|
||||
#include "dataset/text/kernels/ngram_op.h"
|
||||
#include "dataset/text/kernels/to_number_op.h"
|
||||
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||
#include "dataset/text/vocab.h"
|
||||
#include "dataset/text/kernels/lookup_op.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
#include "mindrecord/include/shard_pk_sample.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_sequential_sample.h"
|
||||
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#ifdef ENABLE_ICU4C
|
||||
#include "dataset/text/kernels/basic_tokenizer_op.h"
|
||||
|
@ -87,16 +97,6 @@
|
|||
#include "dataset/text/kernels/whitespace_tokenizer_op.h"
|
||||
#endif
|
||||
|
||||
#include "dataset/util/random.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
#include "mindrecord/include/shard_pk_sample.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_sequential_sample.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -542,6 +542,10 @@ void bindTensorOps4(py::module *m) {
|
|||
.def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
|
||||
py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
|
||||
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
|
||||
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
|
||||
"TensorOp to convert strings to numbers.")
|
||||
.def(py::init<DataType>(), py::arg("data_type"))
|
||||
.def(py::init<std::string>(), py::arg("data_type"));
|
||||
}
|
||||
|
||||
void bindTokenizerOps(py::module *m) {
|
||||
|
|
|
@ -15,15 +15,19 @@
|
|||
*/
|
||||
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/pybind_support.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -330,7 +334,18 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
auto in_itr = input->begin<float>();
|
||||
auto out_itr = (*output)->begin<float16>();
|
||||
auto out_end = (*output)->end<float16>();
|
||||
for (; out_itr != out_end; in_itr++, out_itr++) *out_itr = Eigen::half(*in_itr);
|
||||
|
||||
for (; out_itr != out_end; in_itr++, out_itr++) {
|
||||
float element = *in_itr;
|
||||
float float16_max = static_cast<float>(std::numeric_limits<Eigen::half>::max());
|
||||
float float16_min = static_cast<float>(std::numeric_limits<Eigen::half>::lowest());
|
||||
if (element > float16_max || element < float16_min) {
|
||||
RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
|
||||
std::to_string(float16_max) + ", " + std::to_string(float16_min) + "].");
|
||||
}
|
||||
|
||||
*out_itr = Eigen::half(*in_itr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -18,5 +18,6 @@ add_library(text-kernels OBJECT
|
|||
ngram_op.cc
|
||||
wordpiece_tokenizer_op.cc
|
||||
truncate_sequence_pair_op.cc
|
||||
to_number_op.cc
|
||||
${ICU_DEPEND_FILES}
|
||||
)
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataset/text/kernels/to_number_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Stores the target DataType; Compute() later switches on it to pick the conversion routine.
ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {}
|
||||
|
||||
// Builds the target DataType directly from its string form and stores it for Compute().
ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(cast_to_type) {}
|
||||
|
||||
// Convert every string element of the input tensor to the type stored in cast_to_type_.
// @param const std::shared_ptr<Tensor> &input - tensor whose type must be DE_STRING
// @param std::shared_ptr<Tensor> *output - receives the tensor produced by the selected conversion helper
// @return Status - OK on success; an error if the input is not a string tensor, if the target type is not one
//     of the numeric types handled below, or if the selected helper reports a failure
Status ToNumberOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensors should have type string.");

  switch (cast_to_type_.value()) {
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(ToSignedIntegral<int8_t>(input, output));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(ToSignedIntegral<int16_t>(input, output));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(ToSignedIntegral<int32_t>(input, output));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(ToSignedIntegral<int64_t>(input, output));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint8_t>(input, output));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint16_t>(input, output));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint32_t>(input, output));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint64_t>(input, output));
      break;
    case DataType::DE_FLOAT16:
      // Qualified with this-> to select the member helper rather than the free function of the same name.
      RETURN_IF_NOT_OK(this->ToFloat16(input, output));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(ToFloat(input, output));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(ToDouble(input, output));
      break;
    default:
      // Without this branch an unhandled target type would return OK while leaving *output unset.
      RETURN_STATUS_UNEXPECTED("ToNumberOp cannot convert strings to type " + cast_to_type_.ToString() + ".");
  }

  return Status::OK();
}
|
||||
|
||||
void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; }
|
||||
|
||||
// Report the output shapes for the given input shapes.
// @param const std::vector<TensorShape> &input_shapes - shapes of the input tensors
// @param std::vector<TensorShape> &output_shapes - every input shape is appended here unchanged
// @return Status - always OK
Status ToNumberOp::OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) {
  output_shapes.insert(output_shapes.end(), input_shapes.begin(), input_shapes.end());
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status ToNumberOp::ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
std::vector<T> casted;
|
||||
|
||||
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||
bool is_cast_out_of_range = false;
|
||||
int64_t result = 0;
|
||||
|
||||
try {
|
||||
result = std::stoll(std::string(*it));
|
||||
} catch (const std::out_of_range &) {
|
||||
is_cast_out_of_range = true;
|
||||
} catch (const std::invalid_argument &) {
|
||||
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
|
||||
}
|
||||
|
||||
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||
|
||||
RETURN_STATUS_UNEXPECTED(error_message);
|
||||
}
|
||||
|
||||
T casted_result = static_cast<T>(result);
|
||||
casted.push_back(casted_result);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
std::vector<T> casted;
|
||||
|
||||
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||
bool is_cast_out_of_range = false;
|
||||
uint64_t result = 0;
|
||||
|
||||
// If there is a - at the start of the string, it is considered by us to
|
||||
// be out of bounds. If the - is somewhere else in the string, it is
|
||||
// deemed invalid by std::stoull and will throw std::invalid_argument
|
||||
for (int i = 0; i < (*it).size(); i++) {
|
||||
if ((*it)[i] == '-') {
|
||||
is_cast_out_of_range = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
result = std::stoull(std::string(*it));
|
||||
} catch (const std::out_of_range &) {
|
||||
is_cast_out_of_range = true;
|
||||
} catch (const std::invalid_argument &) {
|
||||
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
|
||||
}
|
||||
|
||||
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||
|
||||
RETURN_STATUS_UNEXPECTED(error_message);
|
||||
}
|
||||
|
||||
T casted_result = static_cast<T>(result);
|
||||
casted.push_back(casted_result);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convert every string element of the input to float16.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert
// @param std::shared_ptr<Tensor> *output - receives the float16 tensor
// @return Status - error if parsing to float32 fails or the float32 -> float16 conversion reports a failure
Status ToNumberOp::ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // C++ has no native float16 type, so the strings are parsed to float32 first and then narrowed by the
  // free function mindspore::dataset::ToFloat16.
  // ToFloat() creates the intermediate tensor itself, so it is not pre-allocated here.
  std::shared_ptr<Tensor> float32_tensor;
  RETURN_IF_NOT_OK(ToFloat(input, &float32_tensor));
  RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(float32_tensor, output));
  return Status::OK();
}
|
||||
|
||||
// Parse every string element of the input as a float.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert
// @param std::shared_ptr<Tensor> *output - receives a tensor created from the parsed values and the input's shape
// @return Status - error if an element cannot be parsed or lies outside the finite float range
Status ToNumberOp::ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  std::vector<float> casted;

  for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
    bool is_cast_out_of_range = false;
    float result = 0;

    try {
      result = std::stof(std::string(*it));
    } catch (const std::out_of_range &) {
      is_cast_out_of_range = true;
    } catch (const std::invalid_argument &) {
      // This helper parses floating-point text, so the message must not talk about unsigned integers.
      RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
    }

    // The explicit comparison also rejects a parsed value beyond the finite range, such as infinity.
    if (result > std::numeric_limits<float>::max() || result < std::numeric_limits<float>::lowest() ||
        is_cast_out_of_range) {
      std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
                                  cast_to_type_.ToString() + ". The valid range is: [" +
                                  std::to_string(std::numeric_limits<float>::lowest()) + ", " +
                                  std::to_string(std::numeric_limits<float>::max()) + "].";

      RETURN_STATUS_UNEXPECTED(error_message);
    }

    casted.push_back(result);
  }

  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
  return Status::OK();
}
|
||||
|
||||
// Parse every string element of the input as a double.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert
// @param std::shared_ptr<Tensor> *output - receives a tensor created from the parsed values and the input's shape
// @return Status - error if an element cannot be parsed or lies outside the finite double range
Status ToNumberOp::ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  std::vector<double> casted;

  for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
    bool is_cast_out_of_range = false;
    double result = 0;

    try {
      result = std::stod(std::string(*it));
    } catch (const std::out_of_range &) {
      is_cast_out_of_range = true;
    } catch (const std::invalid_argument &) {
      // This helper parses floating-point text, so the message must not talk about unsigned integers.
      RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
    }

    // The explicit comparison also rejects a parsed value beyond the finite range, such as infinity.
    if (result > std::numeric_limits<double>::max() || result < std::numeric_limits<double>::lowest() ||
        is_cast_out_of_range) {
      std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
                                  cast_to_type_.ToString() + ". The valid range is: [" +
                                  std::to_string(std::numeric_limits<double>::lowest()) + ", " +
                                  std::to_string(std::numeric_limits<double>::max()) + "].";

      RETURN_STATUS_UNEXPECTED(error_message);
    }

    casted.push_back(result);
  }

  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||
#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ToNumberOp : public TensorOp {
|
||||
public:
|
||||
// Constructor of ToNumberOp
|
||||
// @param const DataType &cast_to_type - the type to convert string inputs to.
|
||||
explicit ToNumberOp(const DataType &cast_to_type);
|
||||
|
||||
// Constructor of ToNumberOp
|
||||
// @param const std::string &cast_to_type - the type in string form to convert string inputs to.
|
||||
explicit ToNumberOp(const std::string &cast_to_type);
|
||||
|
||||
~ToNumberOp() override = default;
|
||||
|
||||
// Perform numeric conversion on each string in each tensor.
|
||||
// @param const std::shared_ptr<Tensor> &input
|
||||
// @param std::shared_ptr<Tensor> *output
|
||||
// @return error code
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
// For each input shape, find the output shape
|
||||
// @param std::vector<TensorShape> &inputs - shape of input tensors
|
||||
// @param std::vector<TensorShape> &outputs - shape of output tensors
|
||||
// @return error code
|
||||
Status OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) override;
|
||||
|
||||
// print arg for debugging
|
||||
// @param std::ostream &out
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
template <typename T>
|
||||
Status ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
DataType cast_to_type_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
|
@ -16,12 +16,13 @@
|
|||
mindspore.dataset.text
|
||||
"""
|
||||
import platform
|
||||
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair
|
||||
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
|
||||
ToNumber
|
||||
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
|
||||
|
||||
__all__ = [
|
||||
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
|
||||
"to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair"
|
||||
"to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber"
|
||||
]
|
||||
|
||||
if platform.system().lower() != 'windows':
|
||||
|
|
|
@ -23,7 +23,9 @@ import mindspore._c_dataengine as cde
|
|||
|
||||
from .utils import JiebaMode, NormalizeForm
|
||||
from .validators import check_lookup, check_jieba_add_dict, \
|
||||
check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate
|
||||
check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \
|
||||
check_to_number
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
|
||||
|
||||
class Lookup(cde.LookupOp):
|
||||
|
@ -379,3 +381,28 @@ class TruncateSequencePair(cde.TruncateSequencePairOp):
|
|||
@check_pair_truncate
|
||||
def __init__(self, max_length):
|
||||
super().__init__(max_length)
|
||||
|
||||
|
||||
class ToNumber(cde.ToNumberOp):
    """
    Tensor operation that converts each element of a string tensor to a number.

    Strings are cast following the rules described at
    https://en.cppreference.com/w/cpp/string/basic_string/stof and
    https://en.cppreference.com/w/cpp/string/basic_string/stoul,
    with one exception: a string that represents a negative number cannot be
    cast to an unsigned integer type.

    Args:
        data_type (mindspore.dtype): The mindspore.dtype to cast to. Must be
            a numeric type.

    Raises:
        RuntimeError: If a string is invalid to cast, or is out of range after being cast.
    """

    @check_to_number
    def __init__(self, data_type):
        # Translate the MindSpore dtype into the data-engine type taken by the C++ op.
        de_type = mstype_to_detype(data_type)
        self.data_type = str(de_type)
        super().__init__(de_type)
|
||||
|
|
|
@ -19,7 +19,9 @@ validators for text ops
|
|||
from functools import wraps
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from mindspore._c_expression import typing
|
||||
from ..transforms.validators import check_uint32, check_pos_int64
|
||||
|
||||
|
||||
|
@ -384,3 +386,28 @@ def check_pair_truncate(method):
|
|||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_to_number(method):
    """Wrap ToNumber's constructor with validation of its ``data_type`` argument."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # A keyword argument takes precedence over the first positional argument.
        if "data_type" in kwargs:
            data_type = kwargs.get("data_type")
        else:
            data_type = args[0] if args else None

        if data_type is None:
            raise ValueError("data_type is a mandatory parameter but was not provided.")

        if not isinstance(data_type, typing.Type):
            raise TypeError("data_type is not a MindSpore data type.")

        if data_type not in mstype.number_type:
            raise TypeError("data_type is not numeric data type.")

        # The wrapped method always receives the validated value as a keyword argument.
        kwargs["data_type"] = data_type
        return method(self, **kwargs)

    return new_method
|
||||
|
|
|
@ -88,15 +88,15 @@ if __name__ == '__main__':
|
|||
if not os.path.isdir(args_opt.mindrecord_dir):
|
||||
os.makedirs(args_opt.mindrecord_dir)
|
||||
|
||||
prefix = "yolo.mindrecord"
|
||||
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
|
||||
yolo_prefix = "yolo.mindrecord"
|
||||
mindrecord_file = os.path.join(args_opt.mindrecord_dir, yolo_prefix + "0")
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
|
||||
print("Create Mindrecord")
|
||||
data_to_mindrecord_byte_image(args_opt.image_dir,
|
||||
args_opt.anno_path,
|
||||
args_opt.mindrecord_dir,
|
||||
prefix=prefix,
|
||||
prefix=yolo_prefix,
|
||||
file_num=8)
|
||||
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
|
||||
np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
|
||||
np.uint32, np.uint64]
|
||||
ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8,
|
||||
mstype.uint16, mstype.uint32, mstype.uint64]
|
||||
|
||||
np_non_integral_types = [np.float16, np.float32, np.float64]
|
||||
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]
|
||||
|
||||
def string_dataset_generator(strings):
    """Yield one single-column row per input, each holding the value as a NumPy byte-string array."""
    yield from ((np.array(item, dtype='S'),) for item in strings)
|
||||
|
||||
|
||||
def test_to_number_typical_case_integral():
    """Convert in-range integer strings and compare the results with Python's int() parsing."""
    input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
                     ["-1726483716", "98921728421"]]

    # zip() stops at the shorter sequence, so only the first four entries of
    # ms_integral_types (the signed types) are exercised here.
    for ms_type, inputs in zip(ms_integral_types, input_strings):
        dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
        dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))

        expected_output = list(map(int, inputs))
        output = [data["strings"] for data in dataset.create_dict_iterator()]

        assert output == expected_output
|
||||
|
||||
|
||||
def test_to_number_typical_case_non_integral():
    """Convert in-range decimal strings and compare the results with float() within a tolerance."""
    input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
    # NOTE(review): the zip() below restarts at epsilons[0] for every type and each
    # type has two values, so only the first two tolerances are ever applied.
    epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001]

    for ms_type, inputs in zip(ms_non_integral_types, input_strings):
        dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
        dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))

        expected_output = list(map(float, inputs))
        output = [data["strings"] for data in dataset.create_dict_iterator()]

        for expected, actual, epsilon in zip(expected_output, output, epsilons):
            assert abs(expected - actual) < epsilon
|
||||
|
||||
|
||||
def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
    """Iterate the dataset, expect a RuntimeError, and check its out-of-bounds message for np_type."""
    limits = np.iinfo(np_type)
    type_name = str(np.dtype(np_type))
    expected_prefix = "String input " + value_to_cast + " will be out of bounds if casted to " + type_name
    expected_range = "valid range is: [" + str(limits.min) + ", " + str(limits.max) + "]"

    with pytest.raises(RuntimeError) as info:
        for _ in dataset.create_dict_iterator():
            pass

    message = str(info.value)
    assert expected_prefix in message
    assert expected_range in message
|
||||
|
||||
|
||||
def test_to_number_out_of_bounds_integral():
    """Check that values just outside each integral type's range raise the out-of-bounds error."""
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        type_info = np.iinfo(np_type)

        # First a value above the maximum, then a value below the minimum.
        for out_of_range_value in (type_info.max + 10, type_info.min - 10):
            input_strings = [str(out_of_range_value)]
            dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
            dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
            out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
|
||||
|
||||
|
||||
def test_to_number_out_of_bounds_non_integral():
    """Check that ToNumber rejects strings beyond each floating-point type's range.

    For every entry of `ms_non_integral_types` (indices 0, 1, 2) a value above
    the range is tried first, then a value below it. Index 0 is expected to
    report "outside of valid float16 range"; indices 1 and 2 are expected to
    report the generic out-of-bounds message naming float32 and float64.
    """
    # Convert the limits to Python floats BEFORE scaling. Multiplying a
    # float16/float32 NumPy scalar by a Python int keeps the narrow type under
    # NEP 50 promotion (NumPy 2) and overflows to inf, which would turn the
    # input string into "inf" instead of a finite out-of-range literal.
    float16_info = np.finfo(np.float16)
    float32_info = np.finfo(np.float32)
    above_range = [str(float(float16_info.max) * 10), str(float(float32_info.max) * 10), "1.8e+308"]
    below_range = [str(float(float16_info.min) * 10), str(float(float32_info.min) * 10), "-1.8e+308"]

    # Type name used in the generic message; index 0 (float16) has its own message.
    generic_type_names = {1: "float32", 2: "float64"}

    for out_of_range_values in (above_range, below_range):
        for index, value in enumerate(out_of_range_values):
            input_strings = [value]
            dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
            dataset = dataset.map(input_columns=["strings"],
                                  operations=text.ToNumber(ms_non_integral_types[index]))

            # Draining the iterator triggers the conversion and hence the error.
            with pytest.raises(RuntimeError) as info:
                for _ in dataset.create_dict_iterator():
                    pass

            if index == 0:
                expected_message = "outside of valid float16 range"
            else:
                expected_message = ("String input " + input_strings[0] +
                                    " will be out of bounds if casted to " + generic_type_names[index])
            assert expected_message in str(info.value)
|
||||
|
||||
|
||||
def test_to_number_boundaries_integral():
    """Check that ToNumber converts each integral type's max, min, and zero exactly."""
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        limits = np.iinfo(np_type)
        # Same order as before: upper limit, lower limit, then zero.
        for boundary in (limits.max, limits.min, 0):
            input_strings = [str(boundary)]
            dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
            dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
            for data in dataset.create_dict_iterator():
                assert data["strings"] == int(input_strings[0])
|
||||
|
||||
|
||||
def test_to_number_invalid_input():
    """Check that ToNumber raises a RuntimeError naming a non-numeric input string."""
    invalid_string = "a8fa9ds8fa"
    source = ds.GeneratorDataset(string_dataset_generator([invalid_string]), "strings")
    pipeline = source.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32))

    # Draining the iterator is what triggers the conversion and hence the error.
    with pytest.raises(RuntimeError) as info:
        for _ in pipeline.create_dict_iterator():
            pass

    assert "It is invalid to convert " + invalid_string + " to a number" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every ToNumber test in sequence when the file is executed directly.
    for run_test in (
            test_to_number_typical_case_integral,
            test_to_number_typical_case_non_integral,
            test_to_number_boundaries_integral,
            test_to_number_out_of_bounds_integral,
            test_to_number_out_of_bounds_non_integral,
            test_to_number_invalid_input,
    ):
        run_test()
|
Loading…
Reference in New Issue