forked from mindspore-Ecosystem/mindspore

deserialize and tests

parent 1e4dace193
commit b17464e30c
@@ -656,7 +656,7 @@ Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vect
   return Status::OK();
 }
 
-const std::vector<char> SchemaObj::to_json_char() {
+Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
   nlohmann::json json_file;
   json_file["columns"] = data_->columns_;
   std::string str_dataset_type_(data_->dataset_type_);
@@ -667,7 +667,13 @@ const std::vector<char> SchemaObj::to_json_char() {
   if (data_->num_rows_ > 0) {
     json_file["numRows"] = data_->num_rows_;
   }
+  *out_json = json_file;
+  return Status::OK();
+}
+
+const std::vector<char> SchemaObj::to_json_char() {
+  nlohmann::json json_file;
+  this->schema_to_json(&json_file);
   return StringToChar(json_file.dump(2));
 }
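Note: the refactor splits JSON construction from text dumping: schema_to_json fills an nlohmann::json object that other nodes can embed directly, while to_json_char only serializes it. A minimal usage sketch, assuming the surrounding Schema() factory and Status macros (this sketch is not part of the commit):

// Hedged sketch: obtain the schema as structured JSON and as text.
std::shared_ptr<SchemaObj> schema = Schema();
nlohmann::json schema_json;
RETURN_IF_NOT_OK(schema->schema_to_json(&schema_json));  // structured form, reused by TFRecordNode::to_json below
std::string schema_text = schema->to_json();             // same content, dumped with 2-space indentation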
@@ -22,6 +22,9 @@
 #include <vector>
 
 #include "minddata/dataset/engine/datasetops/concat_op.h"
+#ifndef ENABLE_ANDROID
+#include "minddata/dataset/engine/serdes.h"
+#endif
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/util/status.h"
 namespace mindspore {
@@ -148,5 +151,32 @@ Status ConcatNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
   // Downcast shared pointer then call visitor
   return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
 }
+
+Status ConcatNode::to_json(nlohmann::json *out_json) {
+  nlohmann::json args, sampler_args;
+  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
+  args["sampler"] = sampler_args;
+  args["children_flag_and_nums"] = children_flag_and_nums_;
+  args["children_start_end_index"] = children_start_end_index_;
+  *out_json = args;
+  return Status::OK();
+}
+
+#ifndef ENABLE_ANDROID
+Status ConcatNode::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<DatasetNode>> datasets,
+                             std::shared_ptr<DatasetNode> *result) {
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_flag_and_nums") != json_obj.end(),
+                               "Failed to find children_flag_and_nums");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_start_end_index") != json_obj.end(),
+                               "Failed to find children_start_end_index");
+  std::shared_ptr<SamplerObj> sampler;
+  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
+  std::vector<std::pair<int, int>> children_flag_and_nums = json_obj["children_flag_and_nums"];
+  std::vector<std::pair<int, int>> children_start_end_index = json_obj["children_start_end_index"];
+  *result = std::make_shared<ConcatNode>(datasets, sampler, children_flag_and_nums, children_start_end_index);
+  return Status::OK();
+}
+#endif
 } // namespace dataset
 } // namespace mindspore
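Note: ConcatNode::from_json receives the already-rebuilt children as a separate argument because Serdes::ConstructPipeline (further down in this commit) deserializes child pipelines first. Schematically, the object it expects looks like this (field values are illustrative, not taken from the commit):

// {
//   "sampler": { ... },                    // written by sampler_->to_json()
//   "children_flag_and_nums": [[1, 10]],   // nlohmann::json maps std::pair<int, int> to a two-element array
//   "children_start_end_index": [[0, 5]]
// }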
@@ -70,6 +70,21 @@ class ConcatNode : public DatasetNode {
 
   bool IsSizeDefined() override { return false; }
 
+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
+#ifndef ENABLE_ANDROID
+  /// \brief Function to read dataset in json
+  /// \param[in] json_obj The JSON object to be deserialized
+  /// \param[in] datasets A vector of datasets for Concat input
+  /// \param[out] result Deserialized dataset
+  /// \return Status The status code returned
+  static Status from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<DatasetNode>> datasets,
+                          std::shared_ptr<DatasetNode> *result);
+#endif
+
   /// \brief Getter functions
   const std::vector<std::pair<int, int>> &ChildrenFlagAndNums() const { return children_flag_and_nums_; }
   const std::vector<std::pair<int, int>> &ChildrenStartEndIndex() const { return children_start_end_index_; }
@@ -23,6 +23,9 @@
 #include <vector>
 
 #include "minddata/dataset/engine/datasetops/source/album_op.h"
+#ifndef ENABLE_ANDROID
+#include "minddata/dataset/engine/serdes.h"
+#endif
 
 #include "minddata/dataset/util/status.h"
 namespace mindspore {
@@ -125,5 +128,45 @@ Status AlbumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
   return Status::OK();
 }
+
+Status AlbumNode::to_json(nlohmann::json *out_json) {
+  nlohmann::json args, sampler_args;
+  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
+  args["sampler"] = sampler_args;
+  args["num_parallel_workers"] = num_workers_;
+  args["dataset_dir"] = dataset_dir_;
+  args["decode"] = decode_;
+  args["data_schema"] = schema_path_;
+  args["column_names"] = column_names_;
+  if (cache_ != nullptr) {
+    nlohmann::json cache_args;
+    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
+    args["cache"] = cache_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
+#ifndef ENABLE_ANDROID
+Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
+                               "Failed to find num_parallel_workers");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("data_schema") != json_obj.end(), "Failed to find data_schema");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
+  std::string dataset_dir = json_obj["dataset_dir"];
+  std::string data_schema = json_obj["data_schema"];
+  std::vector<std::string> column_names = json_obj["column_names"];
+  bool decode = json_obj["decode"];
+  std::shared_ptr<SamplerObj> sampler;
+  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
+  std::shared_ptr<DatasetCache> cache = nullptr;
+  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
+  *ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
+  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  return Status::OK();
+}
+#endif
 } // namespace dataset
 } // namespace mindspore
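Note: AlbumNode::from_json validates and reads back exactly the fields to_json writes, then restores the worker count via SetNumWorkers because num_parallel_workers is not a constructor argument. A schematic of the expected object (values illustrative only):

// {
//   "sampler": { ... },
//   "num_parallel_workers": 4,
//   "dataset_dir": "/path/to/testAlbum",
//   "decode": false,
//   "data_schema": "/path/to/datasetSchema.json",
//   "column_names": ["col1", "col2", "col3"],
//   "cache": { ... }   // optional; written only when cache_ != nullptr
// }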
@@ -82,6 +82,19 @@ class AlbumNode : public MappableSourceNode {
   /// \brief Sampler setter
   void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
 
+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
+#ifndef ENABLE_ANDROID
+  /// \brief Function to read dataset in json
+  /// \param[in] json_obj The JSON object to be deserialized
+  /// \param[out] ds Deserialized dataset
+  /// \return Status The status code returned
+  static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
+#endif
+
  private:
   std::string dataset_dir_;
   std::string schema_path_;
@@ -22,6 +22,10 @@
 #include <vector>
 
 #include "minddata/dataset/engine/datasetops/source/flickr_op.h"
+#ifndef ENABLE_ANDROID
+#include "minddata/dataset/engine/serdes.h"
+#endif
+
 #include "minddata/dataset/util/status.h"
 
 namespace mindspore {
@@ -148,5 +152,23 @@ Status FlickrNode::to_json(nlohmann::json *out_json) {
   *out_json = args;
   return Status::OK();
 }
+
+Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
+                               "Failed to find num_parallel_workers");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
+  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
+  std::string dataset_dir = json_obj["dataset_dir"];
+  std::string annotation_file = json_obj["annotation_file"];
+  bool decode = json_obj["decode"];
+  std::shared_ptr<SamplerObj> sampler;
+  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
+  std::shared_ptr<DatasetCache> cache = nullptr;
+  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
+  *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
+  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
@@ -83,6 +83,12 @@ class FlickrNode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;
 
+  /// \brief Function to read dataset in json
+  /// \param[in] json_obj The JSON object to be deserialized
+  /// \param[out] ds Deserialized dataset
+  /// \return Status The status code returned
+  static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
+
   /// \brief Sampler getter
   /// \return SamplerObj of the current node
   std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
@@ -204,7 +204,6 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
   args["dataset_files"] = dataset_files_;
-  args["schema"] = schema_path_;
   args["columns_list"] = columns_list_;
   args["num_samples"] = num_samples_;
   args["shuffle_global"] = (shuffle_ == ShuffleMode::kGlobal);
@@ -221,7 +220,9 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
   if (schema_obj_ != nullptr) {
     schema_obj_->set_dataset_type("TF");
     schema_obj_->set_num_rows(num_samples_);
-    args["schema_json_string"] = schema_obj_->to_json();
+    nlohmann::json schema_json_string;
+    schema_obj_->schema_to_json(&schema_json_string);
+    args["schema_json_string"] = schema_json_string;
   } else {
     args["schema_file_path"] = schema_path_;
   }
@@ -233,7 +234,6 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                                "Failed to find num_parallel_workers");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
-  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema") != json_obj.end(), "Failed to find schema");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns_list") != json_obj.end(), "Failed to find columns_list");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
@@ -241,7 +241,6 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_equal_rows") != json_obj.end(), "Failed to find shard_equal_rows");
   std::vector<std::string> dataset_files = json_obj["dataset_files"];
-  std::string schema = json_obj["schema"];
   std::vector<std::string> columns_list = json_obj["columns_list"];
   int64_t num_samples = json_obj["num_samples"];
   ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
@@ -250,8 +249,18 @@ Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
   bool shard_equal_rows = json_obj["shard_equal_rows"];
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
-  *ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards, shard_id,
-                                       shard_equal_rows, cache);
+  if (json_obj.find("schema_file_path") != json_obj.end()) {
+    std::string schema_file_path = json_obj["schema_file_path"];
+    *ds = std::make_shared<TFRecordNode>(dataset_files, schema_file_path, columns_list, num_samples, shuffle,
+                                         num_shards, shard_id, shard_equal_rows, cache);
+  } else {
+    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("schema_json_string") != json_obj.end(),
+                                 "Failed to find either schema_file_path or schema_json_string");
+    std::shared_ptr<SchemaObj> schema_obj = Schema();
+    RETURN_IF_NOT_OK(schema_obj->from_json(json_obj["schema_json_string"]));
+    *ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
+                                         shard_id, shard_equal_rows, cache);
+  }
   (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
   return Status::OK();
 }
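Note: deserialization now mirrors the two paths to_json can take: a schema given as a file path round-trips through "schema_file_path", while an in-memory SchemaObj round-trips through "schema_json_string" and the new SchemaObj::from_json. Schematically (values illustrative):

// file-based schema:  { "schema_file_path": "datasetSchema.json", ... }
// inline schema:      { "schema_json_string": { "columns": { ... }, "datasetType": "TF", "numRows": 12 }, ... }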
@@ -97,5 +97,9 @@ Status ZipNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
   return p->VisitAfter(shared_from_base<ZipNode>(), modified);
 }
 
+Status ZipNode::from_json(std::vector<std::shared_ptr<DatasetNode>> datasets, std::shared_ptr<DatasetNode> *result) {
+  *result = std::make_shared<ZipNode>(datasets);
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
@@ -76,6 +76,12 @@ class ZipNode : public DatasetNode {
   /// \return Status of the node visit
   Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
 
+  /// \brief Function to read dataset in json
+  /// \param[in] datasets A vector of datasets for Zip input
+  /// \param[out] result Deserialized dataset
+  /// \return Status The status code returned
+  static Status from_json(std::vector<std::shared_ptr<DatasetNode>> datasets, std::shared_ptr<DatasetNode> *result);
+
  private:
   std::vector<std::shared_ptr<DatasetNode>> datasets_;
 };
@@ -97,15 +97,20 @@ Status Serdes::ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<Datase
     RETURN_IF_NOT_OK(ConstructPipeline(json_obj["children"][0], &child_ds));
     RETURN_IF_NOT_OK(CreateNode(child_ds, json_obj, ds));
   } else {
-    // if json object has more than 1 children, the operation must be zip.
-    CHECK_FAIL_RETURN_UNEXPECTED((json_obj["op_type"] == "Zip"), "Failed to find right op_type - zip");
     std::vector<std::shared_ptr<DatasetNode>> datasets;
     for (auto child_json_obj : json_obj["children"]) {
       RETURN_IF_NOT_OK(ConstructPipeline(child_json_obj, &child_ds));
       datasets.push_back(child_ds);
     }
-    CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should zip more than 1 dataset");
-    *ds = std::make_shared<ZipNode>(datasets);
+    if (json_obj["op_type"] == "Zip") {
+      CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should zip more than 1 dataset");
+      RETURN_IF_NOT_OK(ZipNode::from_json(datasets, ds));
+    } else if (json_obj["op_type"] == "Concat") {
+      CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should concat more than 1 dataset");
+      RETURN_IF_NOT_OK(ConcatNode::from_json(json_obj, datasets, ds));
+    } else {
+      return Status(StatusCode::kMDUnexpectedError, "Operation is not supported");
+    }
   }
   return Status::OK();
 }
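Note: a multi-child subtree previously had to be a Zip; dispatch is now on op_type so Concat subtrees round-trip as well. A schematic two-child Concat pipeline object (illustrative values):

// {
//   "op_type": "Concat",
//   "sampler": { ... },
//   "children_flag_and_nums": [[1, 10], [1, 10]],
//   "children_start_end_index": [[0, 5], [0, 5]],
//   "children": [ { "op_type": "Album", ... }, { "op_type": "Flickr", ... } ]
// }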
@@ -125,7 +130,9 @@ Status Serdes::CreateNode(std::shared_ptr<DatasetNode> child_ds, nlohmann::json
 }
 
 Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, std::shared_ptr<DatasetNode> *ds) {
-  if (op_type == kCelebANode) {
+  if (op_type == kAlbumNode) {
+    RETURN_IF_NOT_OK(AlbumNode::from_json(json_obj, ds));
+  } else if (op_type == kCelebANode) {
     RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds));
   } else if (op_type == kCifar10Node) {
     RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds));
@@ -137,6 +144,8 @@ Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, s
     RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds));
   } else if (op_type == kCSVNode) {
     RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds));
+  } else if (op_type == kFlickrNode) {
+    RETURN_IF_NOT_OK(FlickrNode::from_json(json_obj, ds));
   } else if (op_type == kImageFolderNode) {
     RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds));
   } else if (op_type == kManifestNode) {
@@ -227,6 +236,7 @@ Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shar
 std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
 Serdes::InitializeFuncPtr() {
   std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr;
+  ops_ptr[vision::kAdjustGammaOperation] = &(vision::AdjustGammaOperation::from_json);
   ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json);
   ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json);
   ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json);
@@ -235,6 +245,13 @@ Serdes::InitializeFuncPtr() {
   ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json);
   ops_ptr[vision::kCutOutOperation] = &(vision::CutOutOperation::from_json);
   ops_ptr[vision::kDecodeOperation] = &(vision::DecodeOperation::from_json);
+#ifdef ENABLE_ACL
+  ops_ptr[vision::kDvppCropJpegOperation] = &(vision::DvppCropJpegOperation::from_json);
+  ops_ptr[vision::kDvppDecodeResizeOperation] = &(vision::DvppDecodeResizeOperation::from_json);
+  ops_ptr[vision::kDvppDecodeResizeCropOperation] = &(vision::DvppDecodeResizeCropOperation::from_json);
+  ops_ptr[vision::kDvppNormalizeOperation] = &(vision::DvppNormalizeOperation::from_json);
+  ops_ptr[vision::kDvppResizeJpegOperation] = &(vision::DvppResizeJpegOperation::from_json);
+#endif
   ops_ptr[vision::kEqualizeOperation] = &(vision::EqualizeOperation::from_json);
   ops_ptr[vision::kGaussianBlurOperation] = &(vision::GaussianBlurOperation::from_json);
   ops_ptr[vision::kHorizontalFlipOperation] = &(vision::HorizontalFlipOperation::from_json);
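Note: the map built here is a name-to-factory registry, so adding a deserializable operation such as AdjustGamma costs one registration line plus a from_json. A hedged sketch of the dispatch side (the ConstructTensorOps body is outside this hunk; the func_ptr_ member name is assumed):

auto iter = func_ptr_.find(op_name);
CHECK_FAIL_RETURN_UNEXPECTED(iter != func_ptr_.end(), "Operation is not supported: " + op_name);
std::shared_ptr<TensorOperation> operation;
RETURN_IF_NOT_OK(iter->second(op_params, &operation));  // dispatches to e.g. vision::AdjustGammaOperation::from_json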
@@ -32,6 +32,7 @@
 #include "minddata/dataset/core/tensor.h"
 
 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
+#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
@@ -43,12 +44,14 @@
 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
 
+#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
+#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
@@ -74,7 +77,9 @@
 #include "minddata/dataset/include/dataset/vision.h"
 
 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
+#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
 #include "minddata/dataset/kernels/ir/vision/affine_ir.h"
+#include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h"
 #include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
 #include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
 #include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
@@ -528,6 +528,8 @@ class SchemaObj {
   /// \return JSON string of the schema
   std::string to_json() { return CharToString(to_json_char()); }
 
+  Status schema_to_json(nlohmann::json *out_json);
+
   /// \brief Get a JSON string of the schema
   std::string to_string() { return to_json(); }
 
@@ -540,6 +542,11 @@ class SchemaObj {
   /// \brief Get the current num_rows
   int32_t get_num_rows() const;
 
+  /// \brief Get schema file from JSON file
+  /// \param[in] json_obj parsed JSON object
+  /// \return Status code
+  Status from_json(nlohmann::json json_obj);
+
   /// \brief Get schema file from JSON file
   /// \param[in] json_string Name of JSON file to be parsed.
   /// \return Status code
@@ -559,11 +566,6 @@ class SchemaObj {
   /// \return Status code
   Status parse_column(nlohmann::json columns);
 
-  /// \brief Get schema file from JSON file
-  /// \param[in] json_obj parsed JSON object
-  /// \return Status code
-  Status from_json(nlohmann::json json_obj);
-
   // Char constructor of SchemaObj
   explicit SchemaObj(const std::vector<char> &schema_file);
 
@@ -49,6 +49,15 @@ Status AdjustGammaOperation::to_json(nlohmann::json *out_json) {
   *out_json = args;
   return Status::OK();
 }
+
+Status AdjustGammaOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
+  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("gamma") != op_params.end(), "Failed to find gamma");
+  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("gain") != op_params.end(), "Failed to find gain");
+  float gamma = op_params["gamma"];
+  float gain = op_params["gain"];
+  *operation = std::make_shared<vision::AdjustGammaOperation>(gamma, gain);
+  return Status::OK();
+}
 #endif
 
 } // namespace vision
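Note: together with to_json above, this makes AdjustGamma round-trip through JSON, which the new TestDeserializeConcatAlbumFlickr test exercises through a MapNode. A hedged round-trip sketch (not code from the commit):

auto op = std::make_shared<vision::AdjustGammaOperation>(0.5, 0.5);  // gamma, gain
nlohmann::json op_params;
RETURN_IF_NOT_OK(op->to_json(&op_params));  // expected to hold {"gamma": 0.5, "gain": 0.5}
std::shared_ptr<TensorOperation> rebuilt;
RETURN_IF_NOT_OK(vision::AdjustGammaOperation::from_json(op_params, &rebuilt));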
@@ -49,6 +49,8 @@ class AdjustGammaOperation : public TensorOperation {
 
   Status to_json(nlohmann::json *out_json) override;
 
+  static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
+
  private:
   float gamma_;
   float gain_;
@@ -397,7 +397,6 @@ TEST_F(MindDataTestDeserialize, TestDeserializeCoco) {
 
 TEST_F(MindDataTestDeserialize, TestDeserializeTFRecord) {
   MS_LOG(INFO) << "Doing MindDataTestDeserialize-TFRecord.";
-  std::string schema = "./data/dataset/testTFTestAllTypes/datasetSchema.json";
   int num_samples = 12;
   int32_t num_shards = 1;
   int32_t shard_id = 0;
@@ -405,6 +404,11 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTFRecord) {
   std::shared_ptr<DatasetCache> cache = nullptr;
   std::vector<std::string> columns_list = {};
   std::vector<std::string> dataset_files = {"./data/dataset/testTFTestAllTypes/test.data"};
+
+  std::shared_ptr<SchemaObj> schema = Schema();
+  ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeInt32, {4}));
+  ASSERT_OK(schema->add_column("col2", mindspore::DataType::kNumberTypeInt64, {4}));
+
   std::shared_ptr<DatasetNode> ds =
     std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, ShuffleMode::kFiles, num_shards,
                                    shard_id, shard_equal_rows, cache);
|
|||
std::vector<std::string> dataset_files2 = {"./data/dataset/testTextFileDataset/1.txt"};
|
||||
std::shared_ptr<DatasetNode> ds_child2 =
|
||||
std::make_shared<TextFileNode>(dataset_files2, 2, ShuffleMode::kFiles, 1, 0, cache);
|
||||
std::vector<std::shared_ptr<DatasetNode>> datasets = {ds_child1, ds_child2};
|
||||
std::vector<std::shared_ptr<DatasetNode>> datasets = {ds, ds_child1, ds_child2};
|
||||
ds = std::make_shared<ZipNode>(datasets);
|
||||
compare_dataset(ds);
|
||||
}
|
||||
|
@@ -499,3 +503,27 @@ TEST_F(MindDataTestDeserialize, DISABLED_TestDeserializeCache) {
   std::shared_ptr<DatasetNode> ds = std::make_shared<Cifar10Node>(data_dir, usage, sampler, some_cache);
   compare_dataset(ds);
 }
+
+TEST_F(MindDataTestDeserialize, TestDeserializeConcatAlbumFlickr) {
+  MS_LOG(INFO) << "Doing MindDataTestDeserialize-ConcatAlbumFlickr.";
+  std::string dataset_dir = "./data/dataset/testAlbum";
+  std::vector<std::string> column_names = {"col1", "col2", "col3"};
+  bool decode = false;
+  std::shared_ptr<SamplerObj> sampler = std::make_shared<SequentialSamplerObj>(0, 10);
+  std::string data_schema = "./data/dataset/testAlbum/datasetSchema.json";
+  std::shared_ptr<DatasetNode> ds =
+    std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, nullptr);
+  std::shared_ptr<TensorOperation> operation = std::make_shared<vision::AdjustGammaOperation>(0.5, 0.5);
+  std::vector<std::shared_ptr<TensorOperation>> ops = {operation};
+  ds = std::make_shared<MapNode>(ds, ops);
+  std::string dataset_path = "./data/dataset/testFlickrData/flickr30k/flickr30k-images";
+  std::string annotation_file = "./data/dataset/testFlickrData/flickr30k/test1.token";
+  std::shared_ptr<DatasetNode> ds_child1 =
+    std::make_shared<FlickrNode>(dataset_path, annotation_file, decode, sampler, nullptr);
+  std::vector<std::shared_ptr<DatasetNode>> datasets = {ds, ds_child1};
+  std::pair<int, int> pair = std::make_pair(1, 1);
+  std::vector<std::pair<int, int>> children_flag_and_nums = {pair};
+  std::vector<std::pair<int, int>> children_start_end_index = {pair};
+  ds = std::make_shared<ConcatNode>(datasets, sampler, children_flag_and_nums, children_start_end_index);
+  compare_dataset(ds);
+}