deserialize and tests

This commit is contained in:
zetongzhao 2021-06-07 16:45:29 -04:00
parent 1e4dace193
commit b17464e30c
16 changed files with 237 additions and 20 deletions

View File

@@ -656,7 +656,7 @@ Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vect
return Status::OK();
}
const std::vector<char> SchemaObj::to_json_char() {
Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
nlohmann::json json_file;
json_file["columns"] = data_->columns_;
std::string str_dataset_type_(data_->dataset_type_);
@@ -667,7 +667,13 @@ const std::vector<char> SchemaObj::to_json_char() {
if (data_->num_rows_ > 0) {
json_file["numRows"] = data_->num_rows_;
}
*out_json = json_file;
return Status::OK();
}
const std::vector<char> SchemaObj::to_json_char() {
nlohmann::json json_file;
this->schema_to_json(&json_file);
return StringToChar(json_file.dump(2));
}
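A minimal usage sketch of the new split, assuming an already-populated std::shared_ptr<SchemaObj> named schema and a surrounding Status-returning function (both illustrative): callers can now consume the schema as a JSON object instead of re-parsing a dumped string.

nlohmann::json schema_json;
RETURN_IF_NOT_OK(schema->schema_to_json(&schema_json));  // new object-level API
std::string dumped = schema->to_json();                  // unchanged string API, now layered on schema_to_json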

View File

@@ -22,6 +22,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/concat_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@@ -148,5 +151,32 @@ Status ConcatNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
}
Status ConcatNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["children_flag_and_nums"] = children_flag_and_nums_;
args["children_start_end_index"] = children_start_end_index_;
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
Status ConcatNode::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<DatasetNode>> datasets,
std::shared_ptr<DatasetNode> *result) {
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_flag_and_nums") != json_obj.end(),
"Failed to find children_flag_and_nums");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_start_end_index") != json_obj.end(),
"Failed to find children_start_end_index");
std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
std::vector<std::pair<int, int>> children_flag_and_nums = json_obj["children_flag_and_nums"];
std::vector<std::pair<int, int>> children_start_end_index = json_obj["children_start_end_index"];
*result = std::make_shared<ConcatNode>(datasets, sampler, children_flag_and_nums, children_start_end_index);
return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore
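A hedged sketch of the intended call site, mirroring the Serdes::ConstructPipeline change later in this commit; child0 and child1 are hypothetical, already-deserialized child nodes, and the snippet assumes a Status-returning context.

std::vector<std::shared_ptr<DatasetNode>> datasets = {child0, child1};
std::shared_ptr<DatasetNode> concat_ds;
// json_obj must carry the sampler, children_flag_and_nums, and children_start_end_index keys
RETURN_IF_NOT_OK(ConcatNode::from_json(json_obj, datasets, &concat_ds));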

View File

@@ -70,6 +70,21 @@ class ConcatNode : public DatasetNode {
bool IsSizeDefined() override { return false; }
/// \brief Get the arguments of node
/// \param[out] out_json JSON object containing all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read the dataset from a JSON object
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] datasets A vector of datasets for Concat input
/// \param[out] result Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<DatasetNode>> datasets,
std::shared_ptr<DatasetNode> *result);
#endif
/// \brief Getter functions
const std::vector<std::pair<int, int>> &ChildrenFlagAndNums() const { return children_flag_and_nums_; }
const std::vector<std::pair<int, int>> &ChildrenStartEndIndex() const { return children_start_end_index_; }

View File

@@ -23,6 +23,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@@ -125,5 +128,45 @@ Status AlbumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
return Status::OK();
}
Status AlbumNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["decode"] = decode_;
args["data_schema"] = schema_path_;
args["column_names"] = column_names_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
"Failed to find num_parallel_workers");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("data_schema") != json_obj.end(), "Failed to find data_schema");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
std::string dataset_dir = json_obj["dataset_dir"];
std::string data_schema = json_obj["data_schema"];
std::vector<std::string> column_names = json_obj["column_names"];
bool decode = json_obj["decode"];
std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore
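A hedged round-trip sketch: to_json above writes exactly the keys from_json validates, so a node can be rebuilt from its own serialized form. album is an assumed, already-constructed std::shared_ptr<AlbumNode>.

nlohmann::json json_obj;
RETURN_IF_NOT_OK(album->to_json(&json_obj));                  // serialize
std::shared_ptr<DatasetNode> restored;
RETURN_IF_NOT_OK(AlbumNode::from_json(json_obj, &restored));  // deserialize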

View File

@@ -82,6 +82,19 @@ class AlbumNode : public MappableSourceNode {
/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
/// \brief Get the arguments of node
/// \param[out] out_json JSON object containing all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read the dataset from a JSON object
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
private:
std::string dataset_dir_;
std::string schema_path_;

View File

@@ -22,6 +22,10 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/flickr_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@@ -148,5 +152,23 @@ Status FlickrNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
"Failed to find num_parallel_workers");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
std::string dataset_dir = json_obj["dataset_dir"];
std::string annotation_file = json_obj["annotation_file"];
bool decode = json_obj["decode"];
std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
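A hedged sketch of the minimal JSON FlickrNode::from_json accepts, with paths borrowed from the test added at the end of this commit; sampler_json stands in for the output of SamplerObj::to_json.

nlohmann::json flickr_json = {{"num_parallel_workers", 1},
                              {"dataset_dir", "./data/dataset/testFlickrData/flickr30k/flickr30k-images"},
                              {"annotation_file", "./data/dataset/testFlickrData/flickr30k/test1.token"},
                              {"decode", false},
                              {"sampler", sampler_json}};
std::shared_ptr<DatasetNode> ds;
RETURN_IF_NOT_OK(FlickrNode::from_json(flickr_json, &ds));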

View File

@@ -83,6 +83,12 @@ class FlickrNode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function to read the dataset from a JSON object
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@@ -204,7 +204,6 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_files"] = dataset_files_;
args["schema"] = schema_path_;
args["columns_list"] = columns_list_;
args["num_samples"] = num_samples_;
args["shuffle_global"] = (shuffle_ == ShuffleMode::kGlobal);
@@ -221,7 +220,9 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
if (schema_obj_ != nullptr) {
schema_obj_->set_dataset_type("TF");
schema_obj_->set_num_rows(num_samples_);
args["schema_json_string"] = schema_obj_->to_json();
nlohmann::json schema_json_string;
RETURN_IF_NOT_OK(schema_obj_->schema_to_json(&schema_json_string));
args["schema_json_string"] = schema_json_string;
} else {
args["schema_file_path"] = schema_path_;
}
@@ -233,7 +234,6 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
"Failed to find num_parallel_workers");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema") != json_obj.end(), "Failed to find schema");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns_list") != json_obj.end(), "Failed to find columns_list");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
@@ -241,7 +241,6 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_equal_rows") != json_obj.end(), "Failed to find shard_equal_rows");
std::vector<std::string> dataset_files = json_obj["dataset_files"];
std::string schema = json_obj["schema"];
std::vector<std::string> columns_list = json_obj["columns_list"];
int64_t num_samples = json_obj["num_samples"];
ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
@@ -250,8 +249,18 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
bool shard_equal_rows = json_obj["shard_equal_rows"];
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards, shard_id,
shard_equal_rows, cache);
if (json_obj.find("schema_file_path") != json_obj.end()) {
std::string schema_file_path = json_obj["schema_file_path"];
*ds = std::make_shared<TFRecordNode>(dataset_files, schema_file_path, columns_list, num_samples, shuffle,
num_shards, shard_id, shard_equal_rows, cache);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema_json_string") != json_obj.end(),
"Failed to find either schema_file_path or schema_json_string");
std::shared_ptr<SchemaObj> schema_obj = Schema();
RETURN_IF_NOT_OK(schema_obj->from_json(json_obj["schema_json_string"]));
*ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows, cache);
}
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
return Status::OK();
}
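A hedged sketch of the two schema encodings the branch above accepts; json_base stands in for a JSON object that already carries the other required keys (dataset_files, columns_list, num_samples, shuffle, num_shards, shard_id, shard_equal_rows), and the path is illustrative.

nlohmann::json with_file = json_base;
with_file["schema_file_path"] = "./schema.json";    // schema read from disk
nlohmann::json with_inline = json_base;
with_inline["schema_json_string"] = inline_schema;  // inline_schema: output of SchemaObj::schema_to_json
std::shared_ptr<DatasetNode> ds;
RETURN_IF_NOT_OK(TFRecordNode::from_json(with_inline, &ds));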

View File

@@ -97,5 +97,9 @@ Status ZipNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
return p->VisitAfter(shared_from_base<ZipNode>(), modified);
}
Status ZipNode::from_json(std::vector<std::shared_ptr<DatasetNode>> datasets, std::shared_ptr<DatasetNode> *result) {
*result = std::make_shared<ZipNode>(datasets);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -76,6 +76,12 @@ class ZipNode : public DatasetNode {
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
/// \brief Function to rebuild the node from already-deserialized children
/// \param[in] datasets A vector of datasets for Zip input
/// \param[out] result Deserialized dataset
/// \return Status The status code returned
static Status from_json(std::vector<std::shared_ptr<DatasetNode>> datasets, std::shared_ptr<DatasetNode> *result);
private:
std::vector<std::shared_ptr<DatasetNode>> datasets_;
};

View File

@@ -97,15 +97,20 @@ Status Serdes::ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<Datase
RETURN_IF_NOT_OK(ConstructPipeline(json_obj["children"][0], &child_ds));
RETURN_IF_NOT_OK(CreateNode(child_ds, json_obj, ds));
} else {
// if the json object has more than 1 child, the operation must be zip.
CHECK_FAIL_RETURN_UNEXPECTED((json_obj["op_type"] == "Zip"), "Failed to find right op_type - zip");
std::vector<std::shared_ptr<DatasetNode>> datasets;
for (auto child_json_obj : json_obj["children"]) {
RETURN_IF_NOT_OK(ConstructPipeline(child_json_obj, &child_ds));
datasets.push_back(child_ds);
}
if (json_obj["op_type"] == "Zip") {
CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should zip more than 1 dataset");
*ds = std::make_shared<ZipNode>(datasets);
RETURN_IF_NOT_OK(ZipNode::from_json(datasets, ds));
} else if (json_obj["op_type"] == "Concat") {
CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should concat more than 1 dataset");
RETURN_IF_NOT_OK(ConcatNode::from_json(json_obj, datasets, ds));
} else {
return Status(StatusCode::kMDUnexpectedError, "Operation is not supported");
}
}
return Status::OK();
}
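A hedged sketch of the dispatch this branch now performs; pipeline_json is an assumed, already-parsed pipeline whose root has multiple children (child entries elided).

// {"op_type": "Zip",    "children": [<ds1>, <ds2>]}                        -> ZipNode::from_json
// {"op_type": "Concat", "children": [<ds1>, <ds2>], "sampler": {...}, ...} -> ConcatNode::from_json
std::shared_ptr<DatasetNode> ds;
RETURN_IF_NOT_OK(Serdes::ConstructPipeline(pipeline_json, &ds));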
@@ -125,7 +130,9 @@ Status Serdes::CreateNode(std::shared_ptr<DatasetNode> child_ds, nlohmann::json
}
Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, std::shared_ptr<DatasetNode> *ds) {
if (op_type == kCelebANode) {
if (op_type == kAlbumNode) {
RETURN_IF_NOT_OK(AlbumNode::from_json(json_obj, ds));
} else if (op_type == kCelebANode) {
RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds));
} else if (op_type == kCifar10Node) {
RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds));
@@ -137,6 +144,8 @@ Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, s
RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds));
} else if (op_type == kCSVNode) {
RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds));
} else if (op_type == kFlickrNode) {
RETURN_IF_NOT_OK(FlickrNode::from_json(json_obj, ds));
} else if (op_type == kImageFolderNode) {
RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds));
} else if (op_type == kManifestNode) {
@@ -227,6 +236,7 @@ Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shar
std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
Serdes::InitializeFuncPtr() {
std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr;
ops_ptr[vision::kAdjustGammaOperation] = &(vision::AdjustGammaOperation::from_json);
ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json);
ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json);
ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json);
@@ -235,6 +245,13 @@ Serdes::InitializeFuncPtr() {
ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json);
ops_ptr[vision::kCutOutOperation] = &(vision::CutOutOperation::from_json);
ops_ptr[vision::kDecodeOperation] = &(vision::DecodeOperation::from_json);
#ifdef ENABLE_ACL
ops_ptr[vision::kDvppCropJpegOperation] = &(vision::DvppCropJpegOperation::from_json);
ops_ptr[vision::kDvppDecodeResizeOperation] = &(vision::DvppDecodeResizeOperation::from_json);
ops_ptr[vision::kDvppDecodeResizeCropOperation] = &(vision::DvppDecodeResizeCropOperation::from_json);
ops_ptr[vision::kDvppNormalizeOperation] = &(vision::DvppNormalizeOperation::from_json);
ops_ptr[vision::kDvppResizeJpegOperation] = &(vision::DvppResizeJpegOperation::from_json);
#endif
ops_ptr[vision::kEqualizeOperation] = &(vision::EqualizeOperation::from_json);
ops_ptr[vision::kGaussianBlurOperation] = &(vision::GaussianBlurOperation::from_json);
ops_ptr[vision::kHorizontalFlipOperation] = &(vision::HorizontalFlipOperation::from_json);
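A hedged sketch of how a registered operation is dispatched through the map built above; op_name and op_params are assumed to come from one op entry in the serialized pipeline, inside a Status-returning context.

auto entry = ops_ptr.find(op_name);
CHECK_FAIL_RETURN_UNEXPECTED(entry != ops_ptr.end(), "Operation is not supported");
std::shared_ptr<TensorOperation> operation;
RETURN_IF_NOT_OK(entry->second(op_params, &operation));  // e.g. AdjustGammaOperation::from_json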

View File

@@ -32,6 +32,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
@@ -43,12 +44,14 @@
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
@@ -74,7 +77,9 @@
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"

View File

@@ -528,6 +528,8 @@ class SchemaObj {
/// \return JSON string of the schema
std::string to_json() { return CharToString(to_json_char()); }
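/// \brief Serialize the schema into a JSON object
/// \param[out] out_json JSON object of the schema
/// \return Status code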
Status schema_to_json(nlohmann::json *out_json);
/// \brief Get a JSON string of the schema
std::string to_string() { return to_json(); }
@@ -540,6 +542,11 @@
/// \brief Get the current num_rows
int32_t get_num_rows() const;
/// \brief Get schema file from JSON file
/// \param[in] json_obj parsed JSON object
/// \return Status code
Status from_json(nlohmann::json json_obj);
/// \brief Get schema file from JSON file
/// \param[in] json_string Name of JSON file to be parsed.
/// \return Status code
@@ -559,11 +566,6 @@
/// \return Status code
Status parse_column(nlohmann::json columns);
/// \brief Get schema file from JSON file
/// \param[in] json_obj parsed JSON object
/// \return Status code
Status from_json(nlohmann::json json_obj);
// Char constructor of SchemaObj
explicit SchemaObj(const std::vector<char> &schema_file);

View File

@@ -49,6 +49,15 @@ Status AdjustGammaOperation::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
Status AdjustGammaOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("gamma") != op_params.end(), "Failed to find gamma");
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("gain") != op_params.end(), "Failed to find gain");
float gamma = op_params["gamma"];
float gain = op_params["gain"];
*operation = std::make_shared<vision::AdjustGammaOperation>(gamma, gain);
return Status::OK();
}
#endif
} // namespace vision
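A hedged round-trip sketch for the new deserializer, reusing the gamma and gain values from the test added at the end of this commit.

nlohmann::json op_params = {{"gamma", 0.5}, {"gain", 0.5}};
std::shared_ptr<TensorOperation> op;
RETURN_IF_NOT_OK(vision::AdjustGammaOperation::from_json(op_params, &op));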

View File

@@ -49,6 +49,8 @@ class AdjustGammaOperation : public TensorOperation {
Status to_json(nlohmann::json *out_json) override;
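/// \brief Function to read the operation from a JSON object
/// \param[in] op_params The JSON object to be deserialized
/// \param[out] operation Deserialized TensorOperation
/// \return Status The status code returned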
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
float gamma_;
float gain_;

View File

@@ -397,7 +397,6 @@ TEST_F(MindDataTestDeserialize, TestDeserializeCoco) {
TEST_F(MindDataTestDeserialize, TestDeserializeTFRecord) {
MS_LOG(INFO) << "Doing MindDataTestDeserialize-TFRecord.";
std::string schema = "./data/dataset/testTFTestAllTypes/datasetSchema.json";
int num_samples = 12;
int32_t num_shards = 1;
int32_t shard_id = 0;
@@ -405,6 +404,11 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTFRecord) {
std::shared_ptr<DatasetCache> cache = nullptr;
std::vector<std::string> columns_list = {};
std::vector<std::string> dataset_files = {"./data/dataset/testTFTestAllTypes/test.data"};
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeInt32, {4}));
ASSERT_OK(schema->add_column("col2", mindspore::DataType::kNumberTypeInt64, {4}));
std::shared_ptr<DatasetNode> ds =
std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, ShuffleMode::kFiles, num_shards,
shard_id, shard_equal_rows, cache);
@@ -425,7 +429,7 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTFRecord) {
std::vector<std::string> dataset_files2 = {"./data/dataset/testTextFileDataset/1.txt"};
std::shared_ptr<DatasetNode> ds_child2 =
std::make_shared<TextFileNode>(dataset_files2, 2, ShuffleMode::kFiles, 1, 0, cache);
std::vector<std::shared_ptr<DatasetNode>> datasets = {ds_child1, ds_child2};
std::vector<std::shared_ptr<DatasetNode>> datasets = {ds, ds_child1, ds_child2};
ds = std::make_shared<ZipNode>(datasets);
compare_dataset(ds);
}
@@ -499,3 +503,27 @@ TEST_F(MindDataTestDeserialize, DISABLED_TestDeserializeCache) {
std::shared_ptr<DatasetNode> ds = std::make_shared<Cifar10Node>(data_dir, usage, sampler, some_cache);
compare_dataset(ds);
}
TEST_F(MindDataTestDeserialize, TestDeserializeConcatAlbumFlickr) {
MS_LOG(INFO) << "Doing MindDataTestDeserialize-ConcatAlbumFlickr.";
std::string dataset_dir = "./data/dataset/testAlbum";
std::vector<std::string> column_names = {"col1", "col2", "col3"};
bool decode = false;
std::shared_ptr<SamplerObj> sampler = std::make_shared<SequentialSamplerObj>(0, 10);
std::string data_schema = "./data/dataset/testAlbum/datasetSchema.json";
std::shared_ptr<DatasetNode> ds =
std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, nullptr);
std::shared_ptr<TensorOperation> operation = std::make_shared<vision::AdjustGammaOperation>(0.5, 0.5);
std::vector<std::shared_ptr<TensorOperation>> ops = {operation};
ds = std::make_shared<MapNode>(ds, ops);
std::string dataset_path = "./data/dataset/testFlickrData/flickr30k/flickr30k-images";
std::string annotation_file = "./data/dataset/testFlickrData/flickr30k/test1.token";
std::shared_ptr<DatasetNode> ds_child1 =
std::make_shared<FlickrNode>(dataset_path, annotation_file, decode, sampler, nullptr);
std::vector<std::shared_ptr<DatasetNode>> datasets = {ds, ds_child1};
std::pair<int, int> pair = std::make_pair(1, 1);
std::vector<std::pair<int, int>> children_flag_and_nums = {pair};
std::vector<std::pair<int, int>> children_start_end_index = {pair};
ds = std::make_shared<ConcatNode>(datasets, sampler, children_flag_and_nums, children_start_end_index);
compare_dataset(ds);
}