forked from mindspore-Ecosystem/mindspore
dataset: C++ API TypeCast: change data_type parm from string to DataType
This commit is contained in:
parent
bdc5a9c88b
commit
82f4b3f757
|
@ -19,6 +19,8 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
#include "minddata/dataset/core/type_id.h"
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -211,11 +213,12 @@ std::shared_ptr<TensorOperation> Slice::Parse() { return std::make_shared<SliceO
|
|||
|
||||
// Constructor to TypeCast
|
||||
struct TypeCast::Data {
|
||||
explicit Data(const std::vector<char> &data_type) : data_type_(CharToString(data_type)) {}
|
||||
std::string data_type_;
|
||||
dataset::DataType data_type_;
|
||||
};
|
||||
|
||||
TypeCast::TypeCast(const std::vector<char> &data_type) : data_(std::make_shared<Data>(data_type)) {}
|
||||
TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>()) {
|
||||
data_->data_type_ = dataset::MSTypeToDEType(static_cast<TypeId>(data_type));
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "minddata/dataset/include/constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -349,10 +350,8 @@ class Slice final : public TensorTransform {
|
|||
class TypeCast final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] data_type mindspore.dtype to be cast to.
|
||||
explicit TypeCast(std::string data_type) : TypeCast(StringToChar(data_type)) {}
|
||||
|
||||
explicit TypeCast(const std::vector<char> &data_type);
|
||||
/// \param[in] data_type mindspore::DataType to be cast to.
|
||||
explicit TypeCast(mindspore::DataType data_type);
|
||||
|
||||
/// \brief Destructor
|
||||
~TypeCast() = default;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
|
||||
|
@ -213,19 +214,22 @@ std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<Slic
|
|||
#endif
|
||||
|
||||
// TypeCastOperation
|
||||
TypeCastOperation::TypeCastOperation(std::string data_type) : data_type_(data_type) {}
|
||||
// DataType data_type - required for C++ API
|
||||
TypeCastOperation::TypeCastOperation(DataType data_type) : data_type_(data_type) {}
|
||||
|
||||
// std::string data_type - required for Pybind
|
||||
TypeCastOperation::TypeCastOperation(std::string data_type) {
|
||||
// Convert from string to DEType
|
||||
DataType temp_data_type(data_type);
|
||||
data_type_ = temp_data_type;
|
||||
}
|
||||
|
||||
Status TypeCastOperation::ValidateParams() {
|
||||
std::vector<std::string> predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32",
|
||||
"int64", "uint64", "float16", "float32", "float64", "string"};
|
||||
auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_);
|
||||
if (itr == predefine_type.end()) {
|
||||
std::string err_msg = "TypeCast: Invalid data type: " + data_type_;
|
||||
MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, "
|
||||
<< "int64, uint64, float16, float32, float64, string, but got: " << data_type_;
|
||||
if (data_type_ == DataType::DE_UNKNOWN) {
|
||||
std::string err_msg = "TypeCast: Invalid data type";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -233,7 +237,7 @@ std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<T
|
|||
|
||||
Status TypeCastOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["data_type"] = data_type_;
|
||||
args["data_type"] = data_type_.ToString();
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -214,7 +215,8 @@ class SliceOperation : public TensorOperation {
|
|||
|
||||
class TypeCastOperation : public TensorOperation {
|
||||
public:
|
||||
explicit TypeCastOperation(std::string data_type);
|
||||
explicit TypeCastOperation(DataType data_type); // Used for C++ API
|
||||
explicit TypeCastOperation(std::string data_type); // Used for Pybind
|
||||
|
||||
~TypeCastOperation() = default;
|
||||
|
||||
|
@ -227,7 +229,7 @@ class TypeCastOperation : public TensorOperation {
|
|||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
std::string data_type_;
|
||||
DataType data_type_;
|
||||
};
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -75,7 +75,8 @@ TEST_F(MindDataTestPipeline, TestSaveCifar10AndLoad) {
|
|||
|
||||
// Create objects for the tensor ops
|
||||
// uint32 will be casted to int64 implicitly in mindrecord file, so we have to cast it back to uint32
|
||||
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint32");
|
||||
std::shared_ptr<TensorTransform> type_cast =
|
||||
std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt32);
|
||||
EXPECT_NE(type_cast, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
|
|
|
@ -825,7 +825,8 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) {
|
|||
iter->Stop();
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint16");
|
||||
std::shared_ptr<TensorTransform> type_cast =
|
||||
std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt16);
|
||||
|
||||
// Create a Map operation on ds
|
||||
std::shared_ptr<Dataset> ds2 = ds->Map({type_cast}, {"image"});
|
||||
|
@ -848,7 +849,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTypeCastFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid params.";
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid param.";
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
|
@ -856,7 +857,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) {
|
|||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// incorrect data type
|
||||
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("char");
|
||||
std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>(mindspore::DataType::kTypeUnknown);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({type_cast}, {"image", "label"});
|
||||
|
|
|
@ -49,7 +49,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess1) {
|
|||
// Note: No need to check for output after calling API class constructor
|
||||
|
||||
// Convert to the same type
|
||||
std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast("uint8"));
|
||||
std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8));
|
||||
// Note: No need to check for output after calling API class constructor
|
||||
|
||||
ds = ds->Map({rescale, type_cast}, {"image"});
|
||||
|
|
|
@ -332,7 +332,7 @@ TEST_F(MindDataTestCallback, TestCAPICallback) {
|
|||
ASSERT_OK(schema->add_column("label", mindspore::DataType::kNumberTypeUInt32, {}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(44, schema);
|
||||
ASSERT_NE(ds, nullptr);
|
||||
ds = ds->Map({std::make_shared<transforms::TypeCast>("uint64")}, {"label"}, {}, {}, nullptr, {cb1});
|
||||
ds = ds->Map({std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt64)}, {"label"}, {}, {}, nullptr, {cb1});
|
||||
ASSERT_NE(ds, nullptr);
|
||||
ds = ds->Repeat(2);
|
||||
ASSERT_NE(ds, nullptr);
|
||||
|
|
Loading…
Reference in New Issue