dataset: C++ API TypeCast: change data_type parm from string to DataType

This commit is contained in:
Cathy Wong 2021-03-24 12:00:46 -04:00
parent bdc5a9c88b
commit 82f4b3f757
8 changed files with 36 additions and 26 deletions

View File

@ -19,6 +19,8 @@
#include <algorithm>
#include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
namespace mindspore {
@ -211,11 +213,12 @@ std::shared_ptr<TensorOperation> Slice::Parse() { return std::make_shared<SliceO
// Constructor to TypeCast
struct TypeCast::Data {
explicit Data(const std::vector<char> &data_type) : data_type_(CharToString(data_type)) {}
std::string data_type_;
dataset::DataType data_type_;
};
TypeCast::TypeCast(const std::vector<char> &data_type) : data_(std::make_shared<Data>(data_type)) {}
TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>()) {
data_->data_type_ = dataset::MSTypeToDEType(static_cast<TypeId>(data_type));
}
std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }

View File

@ -24,6 +24,7 @@
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/api/types.h"
#include "minddata/dataset/include/constants.h"
namespace mindspore {
@ -349,10 +350,8 @@ class Slice final : public TensorTransform {
class TypeCast final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] data_type mindspore.dtype to be cast to.
explicit TypeCast(std::string data_type) : TypeCast(StringToChar(data_type)) {}
explicit TypeCast(const std::vector<char> &data_type);
/// \param[in] data_type mindspore::DataType to be cast to.
explicit TypeCast(mindspore::DataType data_type);
/// \brief Destructor
~TypeCast() = default;

View File

@ -15,6 +15,7 @@
*/
#include <algorithm>
#include <typeinfo>
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
@ -213,19 +214,22 @@ std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<Slic
#endif
// TypeCastOperation
TypeCastOperation::TypeCastOperation(std::string data_type) : data_type_(data_type) {}
// DataType data_type - required for C++ API
TypeCastOperation::TypeCastOperation(DataType data_type) : data_type_(data_type) {}
// std::string data_type - required for Pybind
TypeCastOperation::TypeCastOperation(std::string data_type) {
// Convert from string to DEType
DataType temp_data_type(data_type);
data_type_ = temp_data_type;
}
Status TypeCastOperation::ValidateParams() {
std::vector<std::string> predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32",
"int64", "uint64", "float16", "float32", "float64", "string"};
auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_);
if (itr == predefine_type.end()) {
std::string err_msg = "TypeCast: Invalid data type: " + data_type_;
MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, "
<< "int64, uint64, float16, float32, float64, string, but got: " << data_type_;
if (data_type_ == DataType::DE_UNKNOWN) {
std::string err_msg = "TypeCast: Invalid data type";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -233,7 +237,7 @@ std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<T
Status TypeCastOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["data_type"] = data_type_;
args["data_type"] = data_type_.ToString();
*out_json = args;
return Status::OK();
}

View File

@ -22,6 +22,7 @@
#include <string>
#include <vector>
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
@ -214,7 +215,8 @@ class SliceOperation : public TensorOperation {
class TypeCastOperation : public TensorOperation {
public:
explicit TypeCastOperation(std::string data_type);
explicit TypeCastOperation(DataType data_type); // Used for C++ API
explicit TypeCastOperation(std::string data_type); // Used for Pybind
~TypeCastOperation() = default;
@ -227,7 +229,7 @@ class TypeCastOperation : public TensorOperation {
Status to_json(nlohmann::json *out_json) override;
private:
std::string data_type_;
DataType data_type_;
};
#ifndef ENABLE_ANDROID

View File

@ -75,7 +75,8 @@ TEST_F(MindDataTestPipeline, TestSaveCifar10AndLoad) {
// Create objects for the tensor ops
// uint32 will be casted to int64 implicitly in mindrecord file, so we have to cast it back to uint32
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint32");
std::shared_ptr<TensorTransform> type_cast =
std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt32);
EXPECT_NE(type_cast, nullptr);
// Create a Map operation on ds

View File

@ -825,7 +825,8 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) {
iter->Stop();
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint16");
std::shared_ptr<TensorTransform> type_cast =
std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt16);
// Create a Map operation on ds
std::shared_ptr<Dataset> ds2 = ds->Map({type_cast}, {"image"});
@ -848,7 +849,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) {
}
TEST_F(MindDataTestPipeline, TestTypeCastFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid params.";
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid param.";
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
@ -856,7 +857,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) {
EXPECT_NE(ds, nullptr);
// incorrect data type
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("char");
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>(mindspore::DataType::kTypeUnknown);
// Create a Map operation on ds
ds = ds->Map({type_cast}, {"image", "label"});
@ -865,4 +866,4 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid TypeCast input
EXPECT_EQ(iter, nullptr);
}
}

View File

@ -49,7 +49,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess1) {
// Note: No need to check for output after calling API class constructor
// Convert to the same type
std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast("uint8"));
std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8));
// Note: No need to check for output after calling API class constructor
ds = ds->Map({rescale, type_cast}, {"image"});

View File

@ -332,7 +332,7 @@ TEST_F(MindDataTestCallback, TestCAPICallback) {
ASSERT_OK(schema->add_column("label", mindspore::DataType::kNumberTypeUInt32, {}));
std::shared_ptr<Dataset> ds = RandomData(44, schema);
ASSERT_NE(ds, nullptr);
ds = ds->Map({std::make_shared<transforms::TypeCast>("uint64")}, {"label"}, {}, {}, nullptr, {cb1});
ds = ds->Map({std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt64)}, {"label"}, {}, {}, nullptr, {cb1});
ASSERT_NE(ds, nullptr);
ds = ds->Repeat(2);
ASSERT_NE(ds, nullptr);