diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index d3a0280f404..c9670e71252 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -19,6 +19,8 @@ #include #include "mindspore/ccsrc/minddata/dataset/core/type_id.h" +#include "mindspore/core/ir/dtype/type_id.h" +#include "minddata/dataset/core/type_id.h" #include "minddata/dataset/kernels/ir/data/transforms_ir.h" namespace mindspore { @@ -211,11 +213,12 @@ std::shared_ptr Slice::Parse() { return std::make_shared &data_type) : data_type_(CharToString(data_type)) {} - std::string data_type_; + dataset::DataType data_type_; }; -TypeCast::TypeCast(const std::vector &data_type) : data_(std::make_shared(data_type)) {} +TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared()) { + data_->data_type_ = dataset::MSTypeToDEType(static_cast(data_type)); +} std::shared_ptr TypeCast::Parse() { return std::make_shared(data_->data_type_); } diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 1114a99980e..c41a687e3da 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -24,6 +24,7 @@ #include "include/api/dual_abi_helper.h" #include "include/api/status.h" +#include "include/api/types.h" #include "minddata/dataset/include/constants.h" namespace mindspore { @@ -349,10 +350,8 @@ class Slice final : public TensorTransform { class TypeCast final : public TensorTransform { public: /// \brief Constructor. - /// \param[in] data_type mindspore.dtype to be cast to. - explicit TypeCast(std::string data_type) : TypeCast(StringToChar(data_type)) {} - - explicit TypeCast(const std::vector &data_type); + /// \param[in] data_type mindspore::DataType to be cast to. + explicit TypeCast(mindspore::DataType data_type); /// \brief Destructor ~TypeCast() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/data/transforms_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/data/transforms_ir.cc index bbe09ff94e8..952d89d8f0f 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/data/transforms_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/data/transforms_ir.cc @@ -15,6 +15,7 @@ */ #include +#include #include "minddata/dataset/kernels/ir/data/transforms_ir.h" @@ -213,19 +214,22 @@ std::shared_ptr SliceOperation::Build() { return std::make_shared predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32", - "int64", "uint64", "float16", "float32", "float64", "string"}; - auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_); - if (itr == predefine_type.end()) { - std::string err_msg = "TypeCast: Invalid data type: " + data_type_; - MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, " - << "int64, uint64, float16, float32, float64, string, but got: " << data_type_; + if (data_type_ == DataType::DE_UNKNOWN) { + std::string err_msg = "TypeCast: Invalid data type"; + MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return Status::OK(); } @@ -233,7 +237,7 @@ std::shared_ptr TypeCastOperation::Build() { return std::make_shared #include +#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/kernels/ir/tensor_operation.h" namespace mindspore { @@ -214,7 +215,8 @@ class SliceOperation : public TensorOperation { class TypeCastOperation : public TensorOperation { public: - explicit TypeCastOperation(std::string data_type); + explicit TypeCastOperation(DataType data_type); // Used for C++ API + explicit TypeCastOperation(std::string data_type); // Used for Pybind ~TypeCastOperation() = default; @@ -227,7 +229,7 @@ class TypeCastOperation : public TensorOperation { Status to_json(nlohmann::json *out_json) override; private: - std::string data_type_; + DataType data_type_; }; #ifndef ENABLE_ANDROID diff --git a/tests/ut/cpp/dataset/c_api_dataset_save.cc b/tests/ut/cpp/dataset/c_api_dataset_save.cc index 5bd7e409f9c..2c981e42a7e 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_save.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_save.cc @@ -75,7 +75,8 @@ TEST_F(MindDataTestPipeline, TestSaveCifar10AndLoad) { // Create objects for the tensor ops // uint32 will be casted to int64 implicitly in mindrecord file, so we have to cast it back to uint32 - std::shared_ptr type_cast = std::make_shared("uint32"); + std::shared_ptr type_cast = + std::make_shared(mindspore::DataType::kNumberTypeUInt32); EXPECT_NE(type_cast, nullptr); // Create a Map operation on ds diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 5c45f684905..d45ed50bd47 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -825,7 +825,8 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) { iter->Stop(); // Create objects for the tensor ops - std::shared_ptr type_cast = std::make_shared("uint16"); + std::shared_ptr type_cast = + std::make_shared(mindspore::DataType::kNumberTypeUInt16); // Create a Map operation on ds std::shared_ptr ds2 = ds->Map({type_cast}, {"image"}); @@ -848,7 +849,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) { } TEST_F(MindDataTestPipeline, TestTypeCastFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid params."; + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid param."; // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; @@ -856,7 +857,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) { EXPECT_NE(ds, nullptr); // incorrect data type - std::shared_ptr type_cast = std::make_shared("char"); + std::shared_ptr type_cast = std::make_shared(mindspore::DataType::kTypeUnknown); // Create a Map operation on ds ds = ds->Map({type_cast}, {"image", "label"}); @@ -865,4 +866,4 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) { std::shared_ptr iter = ds->CreateIterator(); // Expect failure: invalid TypeCast input EXPECT_EQ(iter, nullptr); -} +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc index 9bd69b616a6..c377ddb1cc5 100644 --- a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc @@ -49,7 +49,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess1) { // Note: No need to check for output after calling API class constructor // Convert to the same type - std::shared_ptr type_cast(new transforms::TypeCast("uint8")); + std::shared_ptr type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8)); // Note: No need to check for output after calling API class constructor ds = ds->Map({rescale, type_cast}, {"image"}); diff --git a/tests/ut/cpp/dataset/ir_callback_test.cc b/tests/ut/cpp/dataset/ir_callback_test.cc index 1fe30e493cc..112dc26df7b 100644 --- a/tests/ut/cpp/dataset/ir_callback_test.cc +++ b/tests/ut/cpp/dataset/ir_callback_test.cc @@ -332,7 +332,7 @@ TEST_F(MindDataTestCallback, TestCAPICallback) { ASSERT_OK(schema->add_column("label", mindspore::DataType::kNumberTypeUInt32, {})); std::shared_ptr ds = RandomData(44, schema); ASSERT_NE(ds, nullptr); - ds = ds->Map({std::make_shared("uint64")}, {"label"}, {}, {}, nullptr, {cb1}); + ds = ds->Map({std::make_shared(mindspore::DataType::kNumberTypeUInt64)}, {"label"}, {}, {}, nullptr, {cb1}); ASSERT_NE(ds, nullptr); ds = ds->Repeat(2); ASSERT_NE(ds, nullptr);