!20683 deserializer 2nd part

Merge pull request !20683 from zetongzhao/deserialize_2
This commit is contained in:
i-robot 2021-08-17 19:50:43 +00:00 committed by Gitee
commit 1e4dace193
70 changed files with 966 additions and 614 deletions

View File

@ -2,4 +2,5 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-ir-cache OBJECT
pre_built_dataset_cache.cc
dataset_cache_impl.cc)
dataset_cache_impl.cc
dataset_cache.cc)

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include <memory>
#include <string>
#include <optional>
#include <vector>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif
namespace mindspore::dataset {
#ifndef ENABLE_ANDROID
// Reconstruct a DatasetCache from its serialized JSON form.
// If json_obj carries no "cache" entry, *cache is left untouched and OK is returned.
Status DatasetCache::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache) {
  auto cache_it = json_obj.find("cache");
  if (cache_it == json_obj.end()) {
    return Status::OK();
  }
  nlohmann::json cache_json = *cache_it;
  CHECK_FAIL_RETURN_UNEXPECTED(cache_json.find("session_id") != cache_json.end(), "Failed to find session_id");
  CHECK_FAIL_RETURN_UNEXPECTED(cache_json.find("cache_memory_size") != cache_json.end(),
                               "Failed to find cache_memory_size");
  CHECK_FAIL_RETURN_UNEXPECTED(cache_json.find("spill") != cache_json.end(), "Failed to find spill");
  auto session_id = static_cast<session_id_type>(cache_json["session_id"]);
  uint64_t cache_mem_size = cache_json["cache_memory_size"];
  bool allow_spill = cache_json["spill"];
  // Client-connection settings are optional; keys absent from the JSON stay std::nullopt.
  std::optional<std::vector<char>> host_chars = std::nullopt;
  std::optional<int32_t> port_num = std::nullopt;
  std::optional<int32_t> conn_num = std::nullopt;
  std::optional<int32_t> prefetch_size = std::nullopt;
  if (cache_json.find("hostname") != cache_json.end()) {
    std::optional<std::string> host_str = cache_json["hostname"];
    host_chars = std::vector<char>(host_str->begin(), host_str->end());
  }
  if (cache_json.find("port") != cache_json.end()) {
    port_num = cache_json["port"];
  }
  if (cache_json.find("num_connections") != cache_json.end()) {
    conn_num = cache_json["num_connections"];
  }
  if (cache_json.find("prefetch_size") != cache_json.end()) {
    prefetch_size = cache_json["prefetch_size"];
  }
  *cache =
    std::make_shared<DatasetCacheImpl>(session_id, cache_mem_size, allow_spill, host_chars, port_num, conn_num,
                                       prefetch_size);
  return Status::OK();
}
#endif
} // namespace mindspore::dataset

View File

@ -35,6 +35,10 @@ class DatasetCache {
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
std::shared_ptr<DatasetOp> *ds) = 0;
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
#ifndef ENABLE_ANDROID
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache);
#endif
};
} // namespace mindspore::dataset

View File

@ -169,5 +169,19 @@ Status BatchNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Deserialize a BatchNode and attach it on top of the given child node `ds`.
Status BatchNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                            std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("batch_size") != json_obj.end(), "Failed to find batch_size");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("drop_remainder") != json_obj.end(), "Failed to find drop_remainder");
  int32_t batch_sz = json_obj["batch_size"];
  bool drop_last = json_obj["drop_remainder"];
  auto node = std::make_shared<BatchNode>(ds, batch_sz, drop_last);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -105,6 +105,14 @@ class BatchNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
int32_t batch_size_;
bool drop_remainder_;

View File

@ -22,6 +22,9 @@
#include <utility>
#include <vector>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
@ -154,7 +157,6 @@ Status MapNode::to_json(nlohmann::json *out_json) {
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
std::vector<nlohmann::json> ops;
std::vector<int32_t> cbs;
for (auto op : operations_) {
@ -177,6 +179,26 @@ Status MapNode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a MapNode: restores the column mappings and reconstructs the tensor
// operations via Serdes, then attaches the node on top of `ds`.
Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                          std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("project_columns") != json_obj.end(), "Failed to find project_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("operations") != json_obj.end(), "Failed to find operations");
  std::vector<std::string> in_cols = json_obj["input_columns"];
  std::vector<std::string> out_cols = json_obj["output_columns"];
  std::vector<std::string> proj_cols = json_obj["project_columns"];
  // Tensor ops are themselves serialized objects; delegate their reconstruction to Serdes.
  std::vector<std::shared_ptr<TensorOperation>> ops;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(json_obj["operations"], &ops));
  auto map_node = std::make_shared<MapNode>(ds, ops, in_cols, out_cols, proj_cols);
  map_node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *result = map_node;
  return Status::OK();
}
#endif
// Gets the dataset size
Status MapNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {

View File

@ -93,6 +93,16 @@ class MapNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
#endif
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting

View File

@ -66,5 +66,13 @@ Status ProjectNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Deserialize a ProjectNode: restores the projected column list and stacks it on `ds`.
Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                              std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns");
  std::vector<std::string> column_list = json_obj["columns"];
  auto node = std::make_shared<ProjectNode>(ds, column_list);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -63,6 +63,14 @@ class ProjectNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
std::vector<std::string> columns_;
};

View File

@ -72,5 +72,16 @@ Status RenameNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Deserialize a RenameNode: restores the input->output column name mapping.
Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                             std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");
  std::vector<std::string> in_cols = json_obj["input_columns"];
  std::vector<std::string> out_cols = json_obj["output_columns"];
  auto node = std::make_shared<RenameNode>(ds, in_cols, out_cols);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -65,6 +65,14 @@ class RenameNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;

View File

@ -104,5 +104,14 @@ Status RepeatNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Deserialize a RepeatNode: restores the repeat count and stacks the node on `ds`.
Status RepeatNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                             std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  int32_t repeat_count = json_obj["count"];
  auto node = std::make_shared<RepeatNode>(ds, repeat_count);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -123,6 +123,14 @@ class RepeatNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
protected:
std::shared_ptr<RepeatOp> op_; // keep its corresponding run-time op of EpochCtrlNode and RepeatNode
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass

View File

@ -66,9 +66,19 @@ Status ShuffleNode::ValidateParams() {
// Serialize this ShuffleNode to JSON.
// Emits "buffer_size" and "reset_each_epoch" — the latter key must stay in sync with
// ShuffleNode::from_json, which requires it. (The stale duplicate assignment to the old
// key "reshuffle_each_epoch" — a diff-merge artifact — has been removed.)
Status ShuffleNode::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["buffer_size"] = shuffle_size_;
  args["reset_each_epoch"] = reset_every_epoch_;
  *out_json = args;
  return Status::OK();
}
// Deserialize a ShuffleNode: restores buffer size and the per-epoch reshuffle flag.
Status ShuffleNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                              std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("buffer_size") != json_obj.end(), "Failed to find buffer_size");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reset_each_epoch") != json_obj.end(), "Failed to find reset_each_epoch");
  int32_t shuffle_size = json_obj["buffer_size"];
  bool reshuffle = json_obj["reset_each_epoch"];
  auto node = std::make_shared<ShuffleNode>(ds, shuffle_size, reshuffle);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -63,6 +63,14 @@ class ShuffleNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
int32_t shuffle_size_;
uint32_t shuffle_seed_;

View File

@ -93,5 +93,13 @@ Status SkipNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Deserialize a SkipNode: restores the number of rows to skip.
Status SkipNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                           std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  int32_t skip_count = json_obj["count"];
  auto node = std::make_shared<SkipNode>(ds, skip_count);
  *result = node;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -88,6 +88,14 @@ class SkipNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
int32_t skip_count_;
};

View File

@ -25,6 +25,9 @@
#include "debug/common.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -182,5 +185,28 @@ Status CelebANode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a CelebANode from JSON.
// Rebuilds the sampler, the optional cache, and all constructor arguments,
// then restores the configured worker count.
Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  // Message fixed to match the JSON key being checked ("extensions", not "extension").
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extensions");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  bool decode = json_obj["decode"];
  std::set<std::string> extension = json_obj["extensions"];
  // The cache entry is optional; DatasetCache::from_json leaves `cache` null when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -82,6 +82,14 @@ class CelebANode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -22,6 +22,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -117,5 +120,24 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a Cifar100Node from JSON: rebuilds the sampler and the optional cache.
Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  std::string data_dir = json_obj["dataset_dir"];
  std::string data_usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache entry is optional; the pointer stays null when the JSON has no "cache" key.
  std::shared_ptr<DatasetCache> dataset_cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));
  auto node = std::make_shared<Cifar100Node>(data_dir, data_usage, sampler_obj, dataset_cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -78,6 +78,14 @@ class Cifar100Node : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -22,6 +22,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -118,5 +121,24 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a Cifar10Node from JSON: rebuilds the sampler and the optional cache.
Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  std::string data_dir = json_obj["dataset_dir"];
  std::string data_usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache entry is optional; the pointer stays null when the JSON has no "cache" key.
  std::shared_ptr<DatasetCache> dataset_cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));
  auto node = std::make_shared<Cifar10Node>(data_dir, data_usage, sampler_obj, dataset_cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -78,6 +78,14 @@ class Cifar10Node : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -249,6 +249,29 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
// Deserialize a CLUENode from JSON: restores the file list, task/usage, sample count,
// shuffle mode, and sharding settings.
Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
  // NOTE(review): the "dataset_dir" key carries the CLUE file list here — confirm the
  // key name against CLUENode::to_json (other nodes serialize file lists as "dataset_files").
  std::vector<std::string> clue_files = json_obj["dataset_dir"];
  std::string task_name = json_obj["task"];
  std::string usage_name = json_obj["usage"];
  int64_t sample_count = json_obj["num_samples"];
  auto shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t shard_count = json_obj["num_shards"];
  int32_t shard_index = json_obj["shard_id"];
  std::shared_ptr<DatasetCache> dataset_cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));
  auto node = std::make_shared<CLUENode>(clue_files, task_name, usage_name, sample_count, shuffle_mode, shard_count,
                                         shard_index, dataset_cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is
// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing

View File

@ -86,6 +86,12 @@ class CLUENode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief CLUE by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.

View File

@ -22,6 +22,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -181,6 +184,7 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
args["annotation_file"] = annotation_file_;
args["task"] = task_;
args["decode"] = decode_;
args["extra_metadata"] = extra_metadata_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
@ -189,5 +193,30 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize a CocoNode from JSON: rebuilds the sampler, the optional cache,
// and every constructor argument, then restores the configured worker count.
Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata");
  std::string data_dir = json_obj["dataset_dir"];
  std::string annotation_path = json_obj["annotation_file"];
  std::string task_type = json_obj["task"];
  bool do_decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler_obj));
  std::shared_ptr<DatasetCache> dataset_cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));
  bool extra_metadata = json_obj["extra_metadata"];
  auto node = std::make_shared<CocoNode>(data_dir, annotation_path, task_type, do_decode, sampler_obj, dataset_cache,
                                         extra_metadata);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -81,6 +81,14 @@ class CocoNode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -187,6 +187,32 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
// Deserialize a CSVNode from JSON.
// NOTE(review): column_defaults are not part of the serialized form, so they are
// restored as an empty list — confirm against CSVNode::to_json.
Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("field_delim") != json_obj.end(), "Failed to find field_delim");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
  std::vector<std::string> dataset_files = json_obj["dataset_files"];
  std::string field_delim = json_obj["field_delim"];
  // Reject a malformed empty delimiter instead of silently taking '\0' via c_str()[0].
  CHECK_FAIL_RETURN_UNEXPECTED(!field_delim.empty(), "field_delim is empty");
  std::vector<std::shared_ptr<CsvBase>> column_defaults = {};
  std::vector<std::string> column_names = json_obj["column_names"];
  int64_t num_samples = json_obj["num_samples"];
  ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t num_shards = json_obj["num_shards"];
  int32_t shard_id = json_obj["shard_id"];
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<CSVNode>(dataset_files, field_delim[0], column_defaults, column_names, num_samples, shuffle,
                                  num_shards, shard_id, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// CSV by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can

View File

@ -107,6 +107,12 @@ class CSVNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief CSV by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.

View File

@ -24,6 +24,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -113,6 +116,7 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["recursive"] = recursive_;
args["decode"] = decode_;
args["extensions"] = exts_;
args["class_indexing"] = class_indexing_;
@ -124,5 +128,36 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Deserialize an ImageFolderNode from JSON.
// Rebuilds the sampler, the optional cache, the extension filter, and the explicit
// class-name -> label mapping, then restores the configured worker count.
Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("recursive") != json_obj.end(), "Failed to find recursive");
  // Message fixed to match the JSON key being checked ("extensions", not "extension").
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extensions");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
  std::string dataset_dir = json_obj["dataset_dir"];
  bool decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  bool recursive = json_obj["recursive"];
  std::set<std::string> extension = json_obj["extensions"];
  // class_indexing is serialized as an array of [name, index] pairs.
  std::map<std::string, int32_t> class_indexing;
  nlohmann::json class_map = json_obj["class_indexing"];
  for (const auto &class_map_child : class_map) {
    std::string class_name = class_map_child[0];
    int32_t label = class_map_child[1];
    class_indexing.emplace(class_name, label);
  }
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -87,6 +87,14 @@ class ImageFolderNode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -23,6 +23,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -152,5 +155,34 @@ Status ManifestNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a ManifestNode from its JSON representation (inverse of to_json).
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  // Validate that every field written by to_json is present before reading it.
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_file"), "Failed to find dataset_file");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("class_indexing"), "Failed to find class_indexing");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("decode"), "Failed to find decode");
  std::string dataset_file = json_obj["dataset_file"];
  std::string usage = json_obj["usage"];
  bool decode = json_obj["decode"];
  // Rebuild the class-name -> label-index mapping; each serialized entry is a [name, index] pair.
  std::map<std::string, int32_t> class_indexing;
  for (const auto &entry : json_obj["class_indexing"]) {
    class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
  }
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  // Cache entry is optional in the JSON; stays nullptr when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -78,9 +78,18 @@ class ManifestNode : public MappableSourceNode {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \param[in] cache Dataset cache for constructor input
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -22,6 +22,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -111,5 +114,24 @@ Status MnistNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct an MnistNode from its JSON representation (inverse of to_json).
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_dir"), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  // Cache entry is optional in the JSON; stays nullptr when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  auto node = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -78,6 +78,14 @@ class MnistNode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -106,6 +106,30 @@ Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a DistributedSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                        std::shared_ptr<SamplerObj> *sampler) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("seed"), "Failed to find seed");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("offset"), "Failed to find offset");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("even_dist"), "Failed to find even_dist");
  *sampler = std::make_shared<DistributedSamplerObj>(
    json_obj["num_shards"].get<int64_t>(), json_obj["shard_id"].get<int64_t>(), json_obj["shuffle"].get<bool>(),
    num_samples, json_obj["seed"].get<uint32_t>(), json_obj["offset"].get<int64_t>(),
    json_obj["even_dist"].get<bool>());
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() {
auto sampler =
std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);

View File

@ -56,6 +56,15 @@ class DistributedSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler

View File

@ -60,6 +60,19 @@ Status PKSamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a PKSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status PKSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_val"), "Failed to find num_val");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  *sampler =
    std::make_shared<PKSamplerObj>(json_obj["num_val"].get<int64_t>(), json_obj["shuffle"].get<bool>(), num_samples);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
Status PKSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
*sampler = std::make_shared<dataset::PKSamplerRT>(num_val_, shuffle_, num_samples_);

View File

@ -55,6 +55,15 @@ class PKSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
private:

View File

@ -56,6 +56,20 @@ Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a RandomSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status RandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("replacement"), "Failed to find replacement");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("reshuffle_each_epoch"), "Failed to find reshuffle_each_epoch");
  bool with_replacement = json_obj["replacement"];
  bool reshuffle = json_obj["reshuffle_each_epoch"];
  *sampler = std::make_shared<RandomSamplerObj>(with_replacement, num_samples, reshuffle);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
*sampler = std::make_shared<dataset::RandomSamplerRT>(replacement_, num_samples_, reshuffle_each_epoch_);

View File

@ -55,6 +55,15 @@ class RandomSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
private:

View File

@ -16,6 +16,9 @@
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/core/config_manager.h"
@ -73,5 +76,15 @@ Status SamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Rebuild each serialized child sampler and attach it to the given parent sampler.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] parent_sampler given parent sampler, output with children samplers added
/// \return Status The status code returned
Status SamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler) {
  // Each element of "child_sampler" is itself a serialized sampler; reconstruct recursively.
  for (const auto &child_json : json_obj["child_sampler"]) {
    std::shared_ptr<SamplerObj> child;
    RETURN_IF_NOT_OK(Serdes::ConstructSampler(child_json, &child));
    (*parent_sampler)->AddChildSampler(child);
  }
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -67,6 +67,14 @@ class SamplerObj {
virtual Status to_json(nlohmann::json *const out_json);
#ifndef ENABLE_ANDROID
/// \brief Function to construct children samplers
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] parent_sampler given parent sampler, output constructed parent sampler with children samplers added
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler);
#endif
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID

View File

@ -61,6 +61,18 @@ Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a SequentialSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status SequentialSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                       std::shared_ptr<SamplerObj> *sampler) {
  auto it = json_obj.find("start_index");
  CHECK_FAIL_RETURN_UNEXPECTED(it != json_obj.end(), "Failed to find start_index");
  *sampler = std::make_shared<SequentialSamplerObj>(it->get<int64_t>(), num_samples);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
*sampler = std::make_shared<dataset::SequentialSamplerRT>(start_index_, num_samples_);

View File

@ -55,6 +55,15 @@ class SequentialSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
private:

View File

@ -63,6 +63,19 @@ Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a SubsetRandomSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status SubsetRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                         std::shared_ptr<SamplerObj> *sampler) {
  auto it = json_obj.find("indices");
  CHECK_FAIL_RETURN_UNEXPECTED(it != json_obj.end(), "Failed to find indices");
  std::vector<int64_t> subset_indices = it->get<std::vector<int64_t>>();
  *sampler = std::make_shared<SubsetRandomSamplerObj>(subset_indices, num_samples);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
std::shared_ptr<SamplerObj> SubsetRandomSamplerObj::SamplerCopy() {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (const auto &child : children_) {

View File

@ -45,6 +45,10 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override;

View File

@ -72,6 +72,17 @@ Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a SubsetSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status SubsetSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  auto it = json_obj.find("indices");
  CHECK_FAIL_RETURN_UNEXPECTED(it != json_obj.end(), "Failed to find indices");
  std::vector<int64_t> subset_indices = it->get<std::vector<int64_t>>();
  *sampler = std::make_shared<SubsetSamplerObj>(subset_indices, num_samples);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
std::shared_ptr<SamplerObj> SubsetSamplerObj::SamplerCopy() {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
for (const auto &child : children_) {

View File

@ -55,6 +55,15 @@ class SubsetSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
protected:

View File

@ -63,6 +63,20 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a WeightedRandomSamplerObj from JSON (inverse of to_json).
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                           std::shared_ptr<SamplerObj> *sampler) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("weights"), "Failed to find weights");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("replacement"), "Failed to find replacement");
  std::vector<double> sample_weights = json_obj["weights"].get<std::vector<double>>();
  bool with_replacement = json_obj["replacement"];
  *sampler = std::make_shared<WeightedRandomSamplerObj>(sample_weights, num_samples, with_replacement);
  // Re-attach any serialized child samplers; common logic lives in the SamplerObj base class.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
#endif
Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
*sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
Status s = BuildChildren(sampler);

View File

@ -51,6 +51,15 @@ class WeightedRandomSamplerObj : public SamplerObj {
/// \return Status of the function
Status to_json(nlohmann::json *const out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function for read sampler from JSON object
/// \param[in] json_obj JSON object to be read
/// \param[in] num_samples number of samples in the sampler
/// \param[out] sampler Sampler constructed from parameters in JSON object
/// \return Status of the function
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
#endif
Status ValidateParams() override;
private:

View File

@ -153,6 +153,26 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
/// \brief Reconstruct a TextFileNode from its JSON representation (inverse of to_json).
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_files"), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_samples"), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");
  std::vector<std::string> dataset_files = json_obj["dataset_files"];
  int64_t num_samples = json_obj["num_samples"];
  // "shuffle" is serialized as the integral value of the ShuffleMode enum.
  ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t num_shards = json_obj["num_shards"];
  int32_t shard_id = json_obj["shard_id"];
  // Cache entry is optional in the JSON; stays nullptr when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  auto node = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// TextFile by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can

View File

@ -83,6 +83,12 @@ class TextFileNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief TextFile by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.

View File

@ -229,6 +229,33 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
/// \brief Reconstruct a TFRecordNode from its JSON representation (inverse of to_json).
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_files"), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("schema"), "Failed to find schema");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("columns_list"), "Failed to find columns_list");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_samples"), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_equal_rows"), "Failed to find shard_equal_rows");
  std::vector<std::string> dataset_files = json_obj["dataset_files"];
  std::string schema = json_obj["schema"];
  std::vector<std::string> columns_list = json_obj["columns_list"];
  int64_t num_samples = json_obj["num_samples"];
  // "shuffle" is serialized as the integral value of the ShuffleMode enum.
  ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t num_shards = json_obj["num_shards"];
  int32_t shard_id = json_obj["shard_id"];
  bool shard_equal_rows = json_obj["shard_equal_rows"];
  // Cache entry is optional in the JSON; stays nullptr when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  auto node = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
                                             shard_id, shard_equal_rows, cache);
  node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *ds = node;
  return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// TFRecord by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can

View File

@ -126,6 +126,12 @@ class TFRecordNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief TFRecord by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.

View File

@ -23,6 +23,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -169,6 +172,7 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
args["usage"] = usage_;
args["class_indexing"] = class_index_;
args["decode"] = decode_;
args["extra_metadata"] = extra_metadata_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
@ -177,5 +181,38 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
/// \brief Reconstruct a VOCNode from its JSON representation (inverse of to_json).
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status VOCNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_dir"), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("task"), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("class_indexing"), "Failed to find class_indexing");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("decode"), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("extra_metadata"), "Failed to find extra_metadata");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string task = json_obj["task"];
  std::string usage = json_obj["usage"];
  bool decode = json_obj["decode"];
  bool extra_metadata = json_obj["extra_metadata"];
  // Rebuild the class-name -> label-index mapping; each serialized entry is a [name, index] pair.
  std::map<std::string, int32_t> class_indexing;
  for (const auto &entry : json_obj["class_indexing"]) {
    class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
  }
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  // Cache entry is optional in the JSON; stays nullptr when absent.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache, extra_metadata);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -83,6 +83,14 @@ class VOCNode : public MappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
#ifndef ENABLE_ANDROID
/// \brief Function to read dataset in json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
#endif
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

View File

@ -91,5 +91,13 @@ Status TakeNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
/// \brief Reconstruct a TakeNode operation on top of an existing dataset node.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node the take operation is applied to
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
Status TakeNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                           std::shared_ptr<DatasetNode> *result) {
  auto it = json_obj.find("count");
  CHECK_FAIL_RETURN_UNEXPECTED(it != json_obj.end(), "Failed to find count");
  *result = std::make_shared<TakeNode>(ds, it->get<int32_t>());
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -88,6 +88,14 @@ class TakeNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
int32_t take_count_;
};

View File

@ -126,5 +126,25 @@ Status TransferNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
/// \brief Reconstruct a TransferNode (device queue transfer) on top of an existing dataset node.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node the transfer operation is applied to
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
Status TransferNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                               std::shared_ptr<DatasetNode> *result) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("queue_name"), "Failed to find queue_name");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("device_type"), "Failed to find device_type");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("device_id"), "Failed to find device_id");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("send_epoch_end"), "Failed to find send_epoch_end");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("total_batch"), "Failed to find total_batch");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("create_data_info_queue"), "Failed to find create_data_info_queue");
  std::string queue_name = json_obj["queue_name"];
  std::string device_type = json_obj["device_type"];
  *result = std::make_shared<TransferNode>(ds, queue_name, device_type, json_obj["device_id"].get<int32_t>(),
                                           json_obj["send_epoch_end"].get<bool>(),
                                           json_obj["total_batch"].get<int32_t>(),
                                           json_obj["create_data_info_queue"].get<bool>());
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -84,6 +84,14 @@ class TransferNode : public DatasetNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Function for read dataset operation from json
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] ds dataset node constructed
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
private:
std::string queue_name_;
int32_t device_id_;

View File

@ -124,584 +124,97 @@ Status Serdes::CreateNode(std::shared_ptr<DatasetNode> child_ds, nlohmann::json
return Status::OK();
}
/// \brief Deserialize a CelebANode from its JSON representation.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status Serdes::CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  // Fix: the message previously said "extension" although the key being checked is "extensions".
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extensions");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  bool decode = json_obj["decode"];
  std::set<std::string> extension = json_obj["extensions"];
  // default value for cache - to_json function does not have the output
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache);
  return Status::OK();
}
/// \brief Deserialize a Cifar10Node from its JSON representation.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status Serdes::CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_dir"), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  // default value for cache - to_json function does not have the output
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
  return Status::OK();
}
/// \brief Deserialize a Cifar100Node from its JSON representation.
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset node
/// \return Status The status code returned
Status Serdes::CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  auto has_key = [&json_obj](const std::string &key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_dir"), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  // default value for cache - to_json function does not have the output
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
  return Status::OK();
}
// Deserialize a CLUENode (CLUE benchmark dataset) from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before reading them.
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
  // NOTE(review): the list of dataset files is read from the "dataset_dir" key, not a
  // "dataset_files" key; presumably CLUENode::to_json serializes the file list under
  // "dataset_dir" -- confirm against the serializer before changing this key.
  std::vector<std::string> dataset_files = json_obj["dataset_dir"];
  std::string task = json_obj["task"];
  std::string usage = json_obj["usage"];
  int64_t num_samples = json_obj["num_samples"];
  // Shuffle mode is serialized as an integer enum value.
  ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t num_shards = json_obj["num_shards"];
  int32_t shard_id = json_obj["shard_id"];
  // default value for cache - to_json function does not have the output
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
  return Status::OK();
}
// Deserialize a CocoNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_dir", "annotation_file", "task", "decode", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string data_dir = json_obj["dataset_dir"];
  std::string annotation_path = json_obj["annotation_file"];
  std::string task_name = json_obj["task"];
  bool decode_flag = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // cache and extra_metadata are not serialized by to_json; restore their defaults.
  std::shared_ptr<DatasetCache> cache = nullptr;
  bool extra_metadata = false;
  *ds = std::make_shared<CocoNode>(data_dir, annotation_path, task_name, decode_flag, sampler_obj, cache,
                                   extra_metadata);
  return Status::OK();
}
// Deserialize a CSVNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key :
       {"dataset_files", "field_delim", "column_names", "num_samples", "shuffle", "num_shards", "shard_id"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<std::string> file_list = json_obj["dataset_files"];
  std::string delim_str = json_obj["field_delim"];
  // Only the first character of the serialized delimiter string is used
  // ('\0' when the string is empty, matching c_str()[0] semantics).
  const char delim_char = delim_str.empty() ? '\0' : delim_str.front();
  // Column defaults are not serialized by to_json; restore as empty.
  std::vector<std::shared_ptr<CsvBase>> column_defaults = {};
  std::vector<std::string> column_names = json_obj["column_names"];
  int64_t sample_count = json_obj["num_samples"];
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t shard_count = json_obj["num_shards"];
  int32_t shard_index = json_obj["shard_id"];
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<CSVNode>(file_list, delim_char, column_defaults, column_names, sample_count, shuffle_mode,
                                  shard_count, shard_index, cache);
  return Status::OK();
}
// Deserialize an ImageFolderNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before reading them (checks kept individual to
  // preserve the exact error texts, e.g. "extension" for the "extensions" key).
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
  std::string data_dir = json_obj["dataset_dir"];
  bool decode_flag = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
  bool recursive = false;
  std::set<std::string> extensions = json_obj["extensions"];
  // Rebuild the label -> index map from the serialized [label, index] pairs.
  std::map<std::string, int32_t> class_indexing;
  for (const auto &pair_entry : json_obj["class_indexing"]) {
    std::string label = pair_entry[0];
    int32_t index = pair_entry[1];
    class_indexing.insert({label, index});
  }
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<ImageFolderNode>(data_dir, decode_flag, sampler_obj, recursive, extensions, class_indexing,
                                          cache);
  return Status::OK();
}
// Deserialize a ManifestNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_file", "usage", "sampler", "class_indexing", "decode"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string manifest_file = json_obj["dataset_file"];
  std::string usage_str = json_obj["usage"];
  bool decode_flag = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // Rebuild the label -> index map from the serialized [label, index] pairs.
  std::map<std::string, int32_t> class_indexing;
  for (const auto &pair_entry : json_obj["class_indexing"]) {
    std::string label = pair_entry[0];
    int32_t index = pair_entry[1];
    class_indexing.insert({label, index});
  }
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<ManifestNode>(manifest_file, usage_str, sampler_obj, class_indexing, decode_flag, cache);
  return Status::OK();
}
// Deserialize a MnistNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_dir", "usage", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string data_dir = json_obj["dataset_dir"];
  std::string usage_str = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<MnistNode>(data_dir, usage_str, sampler_obj, cache);
  return Status::OK();
}
// Deserialize a TextFileNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_files", "num_samples", "shuffle", "num_shards", "shard_id"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<std::string> file_list = json_obj["dataset_files"];
  int64_t sample_count = json_obj["num_samples"];
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t shard_count = json_obj["num_shards"];
  int32_t shard_index = json_obj["shard_id"];
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<TextFileNode>(file_list, sample_count, shuffle_mode, shard_count, shard_index, cache);
  return Status::OK();
}
// Deserialize a TFRecordNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_files", "schema", "columns_list", "num_samples", "shuffle", "num_shards",
                          "shard_id", "shard_equal_rows"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<std::string> file_list = json_obj["dataset_files"];
  std::string schema_path = json_obj["schema"];
  std::vector<std::string> column_list = json_obj["columns_list"];
  int64_t sample_count = json_obj["num_samples"];
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  int32_t shard_count = json_obj["num_shards"];
  int32_t shard_index = json_obj["shard_id"];
  bool equal_rows_per_shard = json_obj["shard_equal_rows"];
  // The cache is never serialized by to_json, so it is always restored as null.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<TFRecordNode>(file_list, schema_path, column_list, sample_count, shuffle_mode, shard_count,
                                       shard_index, equal_rows_per_shard, cache);
  return Status::OK();
}
// Deserialize a VOCNode from its serialized JSON attributes.
// @param json_obj serialized node attributes
// @param ds output: the reconstructed dataset node
Status Serdes::CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"dataset_dir", "task", "usage", "class_indexing", "decode", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string data_dir = json_obj["dataset_dir"];
  std::string task_name = json_obj["task"];
  std::string usage_str = json_obj["usage"];
  bool decode_flag = json_obj["decode"];
  // Rebuild the label -> index map from the serialized [label, index] pairs.
  std::map<std::string, int32_t> class_indexing;
  for (const auto &pair_entry : json_obj["class_indexing"]) {
    std::string label = pair_entry[0];
    int32_t index = pair_entry[1];
    class_indexing.insert({label, index});
  }
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // cache and extra_metadata are not serialized by to_json; restore their defaults.
  std::shared_ptr<DatasetCache> cache = nullptr;
  bool extra_metadata = false;
  *ds = std::make_shared<VOCNode>(data_dir, task_name, usage_str, class_indexing, decode_flag, sampler_obj, cache,
                                  extra_metadata);
  return Status::OK();
}
// Deserialize a leaf dataset node by dispatching on its serialized op type.
// Each supported node type provides a static from_json that reconstructs it.
// Fix: every branch previously constructed the node twice -- once through the
// legacy Create*DatasetNode helper and once through ::from_json, the second
// immediately overwriting the first. Only the from_json path is kept.
// @param json_obj serialized node attributes
// @param op_type serialized node type name
// @param ds output: the reconstructed dataset node
Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, std::shared_ptr<DatasetNode> *ds) {
  if (op_type == kCelebANode) {
    RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds));
  } else if (op_type == kCifar10Node) {
    RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds));
  } else if (op_type == kCifar100Node) {
    RETURN_IF_NOT_OK(Cifar100Node::from_json(json_obj, ds));
  } else if (op_type == kCLUENode) {
    RETURN_IF_NOT_OK(CLUENode::from_json(json_obj, ds));
  } else if (op_type == kCocoNode) {
    RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds));
  } else if (op_type == kCSVNode) {
    RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds));
  } else if (op_type == kImageFolderNode) {
    RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds));
  } else if (op_type == kManifestNode) {
    RETURN_IF_NOT_OK(ManifestNode::from_json(json_obj, ds));
  } else if (op_type == kMnistNode) {
    RETURN_IF_NOT_OK(MnistNode::from_json(json_obj, ds));
  } else if (op_type == kTextFileNode) {
    RETURN_IF_NOT_OK(TextFileNode::from_json(json_obj, ds));
  } else if (op_type == kTFRecordNode) {
    RETURN_IF_NOT_OK(TFRecordNode::from_json(json_obj, ds));
  } else if (op_type == kVOCNode) {
    RETURN_IF_NOT_OK(VOCNode::from_json(json_obj, ds));
  } else {
    return Status(StatusCode::kMDUnexpectedError, op_type + " is not supported");
  }
  return Status::OK();
}
// Deserialize a BatchNode and attach it on top of the given input node.
// @param ds input node the new batch operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed batch node
Status Serdes::CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                        std::shared_ptr<DatasetNode> *result) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"batch_size", "drop_remainder"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  int32_t rows_per_batch = json_obj["batch_size"];
  bool drop_last = json_obj["drop_remainder"];
  *result = std::make_shared<BatchNode>(ds, rows_per_batch, drop_last);
  return Status::OK();
}
// Deserialize a MapNode (tensor-op pipeline stage) on top of the given input node.
// @param ds input node the new map operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed map node
Status Serdes::CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                      std::shared_ptr<DatasetNode> *result) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"num_parallel_workers", "input_columns", "output_columns", "project_columns", "operations"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<std::string> in_cols = json_obj["input_columns"];
  std::vector<std::string> out_cols = json_obj["output_columns"];
  std::vector<std::string> proj_cols = json_obj["project_columns"];
  // Rebuild the list of tensor operations applied by this map stage.
  std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
  RETURN_IF_NOT_OK(ConstructTensorOps(json_obj["operations"], &tensor_ops));
  auto map_node = std::make_shared<MapNode>(ds, tensor_ops, in_cols, out_cols, proj_cols);
  map_node->SetNumWorkers(json_obj["num_parallel_workers"]);
  *result = map_node;
  return Status::OK();
}
// Deserialize a ProjectNode (column selection) on top of the given input node.
// @param ds input node the new project operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed project node
Status Serdes::CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                          std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns");
  std::vector<std::string> selected_columns = json_obj["columns"];
  *result = std::make_shared<ProjectNode>(ds, selected_columns);
  return Status::OK();
}
// Deserialize a RenameNode (column renaming) on top of the given input node.
// @param ds input node the new rename operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed rename node
Status Serdes::CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                         std::shared_ptr<DatasetNode> *result) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"input_columns", "output_columns"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<std::string> from_cols = json_obj["input_columns"];
  std::vector<std::string> to_cols = json_obj["output_columns"];
  *result = std::make_shared<RenameNode>(ds, from_cols, to_cols);
  return Status::OK();
}
// Deserialize a RepeatNode on top of the given input node.
// @param ds input node the new repeat operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed repeat node
Status Serdes::CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                         std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  int32_t repeat_count = json_obj["count"];
  *result = std::make_shared<RepeatNode>(ds, repeat_count);
  return Status::OK();
}
// Deserialize a ShuffleNode on top of the given input node.
// @param ds input node the new shuffle operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed shuffle node
Status Serdes::CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                          std::shared_ptr<DatasetNode> *result) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"buffer_size", "reshuffle_each_epoch"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  int32_t shuffle_buffer_size = json_obj["buffer_size"];
  bool reshuffle_each_epoch = json_obj["reshuffle_each_epoch"];
  *result = std::make_shared<ShuffleNode>(ds, shuffle_buffer_size, reshuffle_each_epoch);
  return Status::OK();
}
// Deserialize a SkipNode on top of the given input node.
// @param ds input node the new skip operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed skip node
Status Serdes::CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                       std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  int32_t skip_count = json_obj["count"];
  *result = std::make_shared<SkipNode>(ds, skip_count);
  return Status::OK();
}
// Deserialize a TransferNode (device queue sender) on top of the given input node.
// @param ds input node the new transfer operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed transfer node
Status Serdes::CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                           std::shared_ptr<DatasetNode> *result) {
  // Validate all required fields before any of them is read.
  for (const auto &key :
       {"queue_name", "device_type", "device_id", "send_epoch_end", "total_batch", "create_data_info_queue"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string queue_name = json_obj["queue_name"];
  std::string device_type = json_obj["device_type"];
  int32_t device_id = json_obj["device_id"];
  bool send_epoch_end = json_obj["send_epoch_end"];
  int32_t total_batch = json_obj["total_batch"];
  bool create_data_info_queue = json_obj["create_data_info_queue"];
  *result = std::make_shared<TransferNode>(ds, queue_name, device_type, device_id, send_epoch_end, total_batch,
                                           create_data_info_queue);
  return Status::OK();
}
// Deserialize a TakeNode on top of the given input node.
// @param ds input node the new take operation consumes
// @param json_obj serialized operation attributes
// @param result output: the reconstructed take node
Status Serdes::CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                       std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  int32_t take_count = json_obj["count"];
  *result = std::make_shared<TakeNode>(ds, take_count);
  return Status::OK();
}
// Deserialize a non-leaf (operation) dataset node and attach it to its input node,
// dispatching on the serialized op type. Each node type provides a static from_json.
// Fix: every branch previously constructed the node twice -- once through the
// legacy Create*OperationNode helper and once through ::from_json, the second
// immediately overwriting the first. Only the from_json path is kept.
// @param ds input node the new operation consumes
// @param json_obj serialized operation attributes
// @param op_type serialized operation type name
// @param result output: the reconstructed operation node
Status Serdes::CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, std::string op_type,
                                          std::shared_ptr<DatasetNode> *result) {
  if (op_type == kBatchNode) {
    RETURN_IF_NOT_OK(BatchNode::from_json(json_obj, ds, result));
  } else if (op_type == kMapNode) {
    RETURN_IF_NOT_OK(MapNode::from_json(json_obj, ds, result));
  } else if (op_type == kProjectNode) {
    RETURN_IF_NOT_OK(ProjectNode::from_json(json_obj, ds, result));
  } else if (op_type == kRenameNode) {
    RETURN_IF_NOT_OK(RenameNode::from_json(json_obj, ds, result));
  } else if (op_type == kRepeatNode) {
    RETURN_IF_NOT_OK(RepeatNode::from_json(json_obj, ds, result));
  } else if (op_type == kShuffleNode) {
    RETURN_IF_NOT_OK(ShuffleNode::from_json(json_obj, ds, result));
  } else if (op_type == kSkipNode) {
    RETURN_IF_NOT_OK(SkipNode::from_json(json_obj, ds, result));
  } else if (op_type == kTransferNode) {
    RETURN_IF_NOT_OK(TransferNode::from_json(json_obj, ds, result));
  } else if (op_type == kTakeNode) {
    RETURN_IF_NOT_OK(TakeNode::from_json(json_obj, ds, result));
  } else {
    return Status(StatusCode::kMDUnexpectedError, op_type + " operation is not supported");
  }
  return Status::OK();
}
// Rebuild a DistributedSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples,
                                           std::shared_ptr<SamplerObj> *sampler) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"num_shards", "shard_id", "shuffle", "seed", "offset", "even_dist"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  int64_t shard_count = json_obj["num_shards"];
  int64_t shard_index = json_obj["shard_id"];
  bool shuffle_flag = json_obj["shuffle"];
  uint32_t seed_value = json_obj["seed"];
  int64_t sample_offset = json_obj["offset"];
  bool even_dist_flag = json_obj["even_dist"];
  *sampler = std::make_shared<DistributedSamplerObj>(shard_count, shard_index, shuffle_flag, num_samples, seed_value,
                                                     sample_offset, even_dist_flag);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Rebuild a PKSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"num_val", "shuffle"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  int64_t samples_per_class = json_obj["num_val"];
  bool shuffle_flag = json_obj["shuffle"];
  *sampler = std::make_shared<PKSamplerObj>(samples_per_class, shuffle_flag, num_samples);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Rebuild a RandomSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                      std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
  bool with_replacement = json_obj["replacement"];
  *sampler = std::make_shared<RandomSamplerObj>(with_replacement, num_samples);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Rebuild a SequentialSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples,
                                          std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("start_index") != json_obj.end(), "Failed to find start_index");
  int64_t first_index = json_obj["start_index"];
  *sampler = std::make_shared<SequentialSamplerObj>(first_index, num_samples);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Rebuild a SubsetRandomSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                            std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices");
  std::vector<int64_t> subset_indices = json_obj["indices"];
  *sampler = std::make_shared<SubsetRandomSamplerObj>(subset_indices, num_samples);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Rebuild a WeightedRandomSamplerObj from its serialized JSON attributes.
// @param json_obj serialized sampler attributes
// @param num_samples sample cap already extracted by the caller
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                              std::shared_ptr<SamplerObj> *sampler) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"replacement", "weights"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::vector<double> sample_weights = json_obj["weights"];
  bool with_replacement = json_obj["replacement"];
  *sampler = std::make_shared<WeightedRandomSamplerObj>(sample_weights, num_samples, with_replacement);
  // Recursively rebuild and attach any serialized child samplers.
  if (json_obj.find("child_sampler") != json_obj.end()) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
// Reconstruct a SamplerObj from JSON by dispatching on the serialized sampler name.
// Each sampler type provides a static from_json that rebuilds it (including children).
// Fix: every branch previously constructed the sampler twice -- once through the
// legacy Construct* helper and once through ::from_json, the second immediately
// overwriting the first. Only the from_json path is kept.
// @param json_obj serialized sampler attributes
// @param sampler output: the reconstructed sampler
Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler_name") != json_obj.end(), "Failed to find sampler_name");
  int64_t num_samples = json_obj["num_samples"];
  std::string sampler_name = json_obj["sampler_name"];
  if (sampler_name == "DistributedSampler") {
    RETURN_IF_NOT_OK(DistributedSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "PKSampler") {
    RETURN_IF_NOT_OK(PKSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "RandomSampler") {
    RETURN_IF_NOT_OK(RandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SequentialSampler") {
    RETURN_IF_NOT_OK(SequentialSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SubsetSampler") {
    RETURN_IF_NOT_OK(SubsetSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SubsetRandomSampler") {
    RETURN_IF_NOT_OK(SubsetRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "WeightedRandomSampler") {
    RETURN_IF_NOT_OK(WeightedRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else {
    // NOTE(review): this appends "Sampler" to a name that already ends in "Sampler";
    // message kept byte-for-byte to preserve the existing error text.
    return Status(StatusCode::kMDUnexpectedError, sampler_name + "Sampler is not supported");
  }
  return Status::OK();
}
// Rebuild every serialized child sampler and attach it to the given parent sampler.
// The "sampler" out-parameter is part of the signature for caller symmetry; this
// function only mutates the parent, it does not reassign *sampler.
// @param json_obj serialized sampler attributes containing a "child_sampler" array
// @param parent_sampler sampler the reconstructed children are attached to
// @param sampler unused output slot (kept for interface compatibility)
Status Serdes::ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler,
                                    std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("child_sampler") != json_obj.end(), "Failed to find child_sampler");
  // const-reference iteration avoids copying each child's JSON subtree.
  for (const auto &child : json_obj["child_sampler"]) {
    std::shared_ptr<SamplerObj> child_sampler;
    RETURN_IF_NOT_OK(ConstructSampler(child, &child_sampler));
    // operator-> is equivalent to .get()-> on shared_ptr and is the idiomatic form.
    parent_sampler->AddChildSampler(child_sampler);
  }
  return Status::OK();
}
// Rebuild a BoundingBoxAugment operation from its serialized parameters.
// @param op_params serialized operation parameters
// @param operation output: the reconstructed tensor operation
Status Serdes::BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"transform", "ratio"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(op_params.find(key) != op_params.end(), std::string("Failed to find ") + key);
  }
  float apply_ratio = op_params["ratio"];
  // The single wrapped transform is deserialized through the shared tensor-op path.
  std::vector<nlohmann::json> serialized_ops = {op_params["transform"]};
  std::vector<std::shared_ptr<TensorOperation>> transforms;
  RETURN_IF_NOT_OK(ConstructTensorOps(serialized_ops, &transforms));
  CHECK_FAIL_RETURN_UNEXPECTED(transforms.size() == 1,
                               "Expect size one of transforms parameter, but got:" + std::to_string(transforms.size()));
  *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(transforms[0], apply_ratio);
  return Status::OK();
}
// Rebuild a RandomSelectSubpolicy operation from its serialized parameters.
// Fix: policy_items was previously declared outside the outer loop and never
// cleared, so every subpolicy after the first wrongly accumulated all (op, prob)
// pairs of the preceding subpolicies. It is now scoped per subpolicy.
// @param op_params serialized operation parameters containing the "policy" array
// @param operation output: the reconstructed tensor operation
Status Serdes::RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy");
  nlohmann::json policy_json = op_params["policy"];
  std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy;
  for (nlohmann::json item : policy_json) {
    // Fresh item list per subpolicy so entries do not leak across subpolicies.
    std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items;
    for (nlohmann::json item_pair : item) {
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob");
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op");
      std::vector<std::shared_ptr<TensorOperation>> operations;
      nlohmann::json tensor_op_json;
      double prob = item_pair["prob"];
      tensor_op_json.push_back(item_pair["tensor_op"]);
      RETURN_IF_NOT_OK(ConstructTensorOps(tensor_op_json, &operations));
      CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation");
      policy_items.emplace_back(operations[0], prob);
    }
    policy.push_back(policy_items);
  }
  *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy);
  return Status::OK();
}
// Rebuild a UniformAug operation from its serialized parameters.
// @param op_params serialized operation parameters
// @param operation output: the reconstructed tensor operation
Status Serdes::UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  // Validate all required fields before any of them is read.
  for (const auto &key : {"transforms", "num_ops"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(op_params.find(key) != op_params.end(), std::string("Failed to find ") + key);
  }
  int32_t ops_to_apply = op_params["num_ops"];
  std::vector<std::shared_ptr<TensorOperation>> candidate_ops = {};
  RETURN_IF_NOT_OK(ConstructTensorOps(op_params["transforms"], &candidate_ops));
  *operation = std::make_shared<vision::UniformAugOperation>(candidate_ops, ops_to_apply);
  return Status::OK();
}
Status Serdes::ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result) {
Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
std::vector<std::shared_ptr<TensorOperation>> output;
for (auto op : operations) {
CHECK_FAIL_RETURN_UNEXPECTED(op.find("is_python_front_end_op") == op.end(),
for (nlohmann::json item : json_obj) {
CHECK_FAIL_RETURN_UNEXPECTED(item.find("is_python_front_end_op") == item.end(),
"python operation is not yet supported");
CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_name") != op.end(), "Failed to find tensor_op_name");
CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_params") != op.end(), "Failed to find tensor_op_params");
std::string op_name = op["tensor_op_name"];
nlohmann::json op_params = op["tensor_op_params"];
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
std::string op_name = item["tensor_op_name"];
nlohmann::json op_params = item["tensor_op_params"];
std::shared_ptr<TensorOperation> operation = nullptr;
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
@ -716,7 +229,7 @@ Serdes::InitializeFuncPtr() {
std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr;
ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json);
ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json);
ops_ptr[vision::kBoundingBoxAugmentOperation] = &(BoundingBoxAugmentFromJson);
ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json);
ops_ptr[vision::kCenterCropOperation] = &(vision::CenterCropOperation::from_json);
ops_ptr[vision::kCropOperation] = &(vision::CropOperation::from_json);
ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json);
@ -745,7 +258,7 @@ Serdes::InitializeFuncPtr() {
ops_ptr[vision::kRandomResizedCropOperation] = &(vision::RandomResizedCropOperation::from_json);
ops_ptr[vision::kRandomResizedCropWithBBoxOperation] = &(vision::RandomResizedCropWithBBoxOperation::from_json);
ops_ptr[vision::kRandomRotationOperation] = &(vision::RandomRotationOperation::from_json);
ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(RandomSelectSubpolicyFromJson);
ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(vision::RandomSelectSubpolicyOperation::from_json);
ops_ptr[vision::kRandomSharpnessOperation] = &(vision::RandomSharpnessOperation::from_json);
ops_ptr[vision::kRandomSolarizeOperation] = &(vision::RandomSolarizeOperation::from_json);
ops_ptr[vision::kRandomVerticalFlipOperation] = &(vision::RandomVerticalFlipOperation::from_json);
@ -766,7 +279,7 @@ Serdes::InitializeFuncPtr() {
&(vision::SoftDvppDecodeRandomCropResizeJpegOperation::from_json);
ops_ptr[vision::kSoftDvppDecodeResizeJpegOperation] = &(vision::SoftDvppDecodeResizeJpegOperation::from_json);
ops_ptr[vision::kSwapRedBlueOperation] = &(vision::SwapRedBlueOperation::from_json);
ops_ptr[vision::kUniformAugOperation] = &(UniformAugFromJson);
ops_ptr[vision::kUniformAugOperation] = &(vision::UniformAugOperation::from_json);
ops_ptr[vision::kVerticalFlipOperation] = &(vision::VerticalFlipOperation::from_json);
ops_ptr[transforms::kFillOperation] = &(transforms::FillOperation::from_json);
ops_ptr[transforms::kOneHotOperation] = &(transforms::OneHotOperation::from_json);

View File

@ -159,6 +159,18 @@ class Serdes {
/// \return Status The status code returned
static Status ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief Helper functions for creating sampler, separate different samplers and call the related function
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] sampler Deserialized sampler
/// \return Status The status code returned
static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
/// \brief helper function to construct tensor operations
/// \param[in] json_obj json object of operations to be deserilized
/// \param[out] vector of tensor operation pointer
/// \return Status The status code returned
static Status ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result);
protected:
/// \brief Helper function to save JSON to a file
/// \param[in] json_string The JSON string to be saved to the file
@ -189,91 +201,6 @@ class Serdes {
static Status CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::string op_type, std::shared_ptr<DatasetNode> *result);
/// \brief Helper functions for creating sampler, separate different samplers and call the related function
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] sampler Deserialized sampler
/// \return Status The status code returned
static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
/// \brief helper function to construct tensor operations
/// \param[in] operations operations to be deserilized
/// \param[out] vector of tensor operation pointer
/// \return Status The status code returned
static Status ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result);
/// \brief Helper functions for different datasets
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] ds Deserialized dataset
/// \return Status The status code returned
static Status CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
static Status CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
/// \brief Helper functions for different operations
/// \param[in] ds dataset node constructed
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] result Deserialized dataset after the operation
/// \return Status The status code returned
static Status CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
static Status CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
std::shared_ptr<DatasetNode> *result);
/// \brief Helper functions for different samplers
/// \param[in] json_obj The JSON object to be deserialized
/// \param[out] sampler Deserialized sampler
/// \return Status The status code returned
static Status ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples,
std::shared_ptr<SamplerObj> *sampler);
static Status ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
static Status ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples,
std::shared_ptr<SamplerObj> *sampler);
static Status ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples,
std::shared_ptr<SamplerObj> *sampler);
static Status ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples,
std::shared_ptr<SamplerObj> *sampler);
static Status ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples,
std::shared_ptr<SamplerObj> *sampler);
/// \brief Helper functions to construct children samplers
/// \param[in] json_obj The JSON object to be deserialized
/// \param[in] parent_sampler given parent sampler
/// \param[out] sampler sampler constructed - parent sampler with children samplers added
/// \return Status The status code returned
static Status ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler,
std::shared_ptr<SamplerObj> *sampler);
/// \brief Helper functions for vision operations, which requires tensor operations as input
/// \param[in] op_params operation parameters for the operation
/// \param[out] operation deserialized operation
/// \return Status The status code returned
static Status BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
static Status RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
static Status UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
/// \brief Helper function to map the function pointers
/// \return map of key to function pointer
static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
#endif
@ -56,6 +57,20 @@ Status BoundingBoxAugmentOperation::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
/// \brief Deserialize a BoundingBoxAugment tensor operation from its JSON form.
/// \param[in] op_params JSON object holding "transform" (single op JSON) and "ratio".
/// \param[out] operation Reconstructed BoundingBoxAugmentOperation.
/// \return Status The status code returned.
Status BoundingBoxAugmentOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transform") != op_params.end(), "Failed to find transform");
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("ratio") != op_params.end(), "Failed to find ratio");
  // ConstructTensorOps takes a list, so wrap the single serialized transform in one.
  std::vector<nlohmann::json> transform_json{op_params["transform"]};
  std::vector<std::shared_ptr<TensorOperation>> ops;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(transform_json, &ops));
  float ratio = op_params["ratio"];
  CHECK_FAIL_RETURN_UNEXPECTED(ops.size() == 1,
                               "Expect size one of transforms parameter, but got:" + std::to_string(ops.size()));
  *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(ops[0], ratio);
  return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset

View File

@ -49,6 +49,8 @@ class BoundingBoxAugmentOperation : public TensorOperation {
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
std::shared_ptr<TensorOperation> transform_;
float ratio_;

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/kernels/ir/vision/random_select_subpolicy_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
#endif
@ -100,6 +101,33 @@ Status RandomSelectSubpolicyOperation::to_json(nlohmann::json *out_json) {
(*out_json)["policy"] = policy_tensor_ops;
return Status::OK();
}
/// \brief Deserialize a RandomSelectSubpolicy tensor operation from its JSON form.
/// \param[in] op_params JSON object holding "policy": a list of sub-policies, each a list of
///     {"tensor_op": <op JSON>, "prob": <double>} pairs.
/// \param[out] operation Reconstructed RandomSelectSubpolicyOperation.
/// \return Status The status code returned.
Status RandomSelectSubpolicyOperation::from_json(nlohmann::json op_params,
                                                 std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy");
  nlohmann::json policy_json = op_params["policy"];
  std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy;
  for (nlohmann::json item : policy_json) {
    // BUGFIX: policy_items must be fresh for each sub-policy. It was previously declared
    // outside this loop and never cleared, so every sub-policy after the first also
    // carried all (op, prob) pairs of the earlier sub-policies.
    std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items;
    for (nlohmann::json item_pair : item) {
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob");
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op");
      std::vector<std::shared_ptr<TensorOperation>> operations;
      nlohmann::json tensor_op_json;
      double prob = item_pair["prob"];
      // ConstructTensorOps takes a list, so wrap the single serialized op in one.
      tensor_op_json.push_back(item_pair["tensor_op"]);
      RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(tensor_op_json, &operations));
      CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation");
      policy_items.emplace_back(operations[0], prob);
    }
    policy.push_back(policy_items);
  }
  *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy);
  return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset

View File

@ -50,6 +50,8 @@ class RandomSelectSubpolicyOperation : public TensorOperation {
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy_;
};

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
#endif
@ -74,6 +75,16 @@ Status UniformAugOperation::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
/// \brief Deserialize a UniformAug tensor operation from its JSON form.
/// \param[in] op_params JSON object holding "transforms" (list of op JSON) and "num_ops".
/// \param[out] operation Reconstructed UniformAugOperation.
/// \return Status The status code returned.
Status UniformAugOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("transforms") != op_params.end(), "Failed to find transforms");
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_ops") != op_params.end(), "Failed to find num_ops");
  // Rebuild the serialized transform list into live TensorOperations.
  std::vector<std::shared_ptr<TensorOperation>> transform_list;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(op_params["transforms"], &transform_list));
  int32_t num_ops = op_params["num_ops"];
  *operation = std::make_shared<vision::UniformAugOperation>(transform_list, num_ops);
  return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset

View File

@ -49,6 +49,8 @@ class UniformAugOperation : public TensorOperation {
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
std::vector<std::shared_ptr<TensorOperation>> transforms_;
int32_t num_ops_;

View File

@ -462,6 +462,7 @@ TEST_F(MindDataTestDeserialize, TestDeserializeFill) {
std::shared_ptr<TensorOperation> operation2 = std::make_shared<text::ToNumberOperation>("int32_t");
std::vector<std::shared_ptr<TensorOperation>> ops = {operation1, operation2};
ds = std::make_shared<MapNode>(ds, ops);
ds = std::make_shared<TransferNode>(ds, "queue", "type", 1, true, 10, true);
compare_dataset(ds);
}
@ -482,3 +483,19 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTensor) {
json_ss1 << json_obj1;
EXPECT_EQ(json_ss.str(), json_ss1.str());
}
// Helper function to get the session id from SESSION_ID env variable
Status GetSessionFromEnv(session_id_type *session_id);
// Round-trip test for a cached pipeline: builds a Cifar10 node with a dataset cache
// attached and runs compare_dataset on it (which serializes and deserializes the tree).
// DISABLED_ prefix: presumably requires a running cache server and a SESSION_ID env
// variable, so it cannot run in a plain CI environment — confirm before enabling.
TEST_F(MindDataTestDeserialize, DISABLED_TestDeserializeCache) {
MS_LOG(INFO) << "Doing MindDataTestDeserialize-Cache.";
std::string data_dir = "./data/dataset/testCache";
std::string usage = "all";
session_id_type env_session;
// NOTE(review): GetSessionFromEnv returns Status (declared above); ASSERT_TRUE on it
// relies on Status being truthy when OK — verify, otherwise use .IsOk() here.
ASSERT_TRUE(GetSessionFromEnv(&env_session));
// Cache pointing at 127.0.0.1:50052; the meaning of the numeric arguments (memory
// size 0, spill=false, trailing 1, 1) follows CreateDatasetCache's signature —
// assumed to be (session, mem_size, spill, host, port, num_connections, prefetch);
// TODO confirm against the factory declaration.
std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, false, "127.0.0.1", 50052, 1, 1);
std::shared_ptr<SamplerObj> sampler = std::make_shared<SequentialSamplerObj>(0, 10);
std::shared_ptr<DatasetNode> ds = std::make_shared<Cifar10Node>(data_dir, usage, sampler, some_cache);
compare_dataset(ds);
}