forked from mindspore-Ecosystem/mindspore
!20683 deserializer 2nd part
Merge pull request !20683 from zetongzhao/deserialize_2
This commit is contained in:
commit
1e4dace193
|
@ -2,4 +2,5 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-ir-cache OBJECT
|
||||
pre_built_dataset_cache.cc
|
||||
dataset_cache_impl.cc)
|
||||
dataset_cache_impl.cc
|
||||
dataset_cache.cc)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::dataset {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Builds a DatasetCacheImpl from the "cache" entry of a serialized dataset node, if one exists.
// \param[in] json_obj JSON object of a dataset node; it may or may not contain a "cache" entry.
// \param[out] cache Receives the constructed DatasetCacheImpl. It is left untouched when json_obj has no
//     "cache" entry.
// \return Status::OK() when the cache was built or no "cache" entry exists; an error Status when cache is
//     null or one of the mandatory fields (session_id, cache_memory_size, spill) is missing.
Status DatasetCache::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache) {
  const auto cache_it = json_obj.find("cache");
  if (cache_it == json_obj.end()) {
    // Nothing to deserialize; the caller's pointer keeps whatever value it had.
    return Status::OK();
  }
  // The out-parameter is only written on this path, so it is validated here.
  CHECK_FAIL_RETURN_UNEXPECTED(cache != nullptr, "Output parameter cache is null");

  nlohmann::json &json_cache = *cache_it;
  const auto session_it = json_cache.find("session_id");
  CHECK_FAIL_RETURN_UNEXPECTED(session_it != json_cache.end(), "Failed to find session_id");
  const auto mem_it = json_cache.find("cache_memory_size");
  CHECK_FAIL_RETURN_UNEXPECTED(mem_it != json_cache.end(), "Failed to find cache_memory_size");
  const auto spill_it = json_cache.find("spill");
  CHECK_FAIL_RETURN_UNEXPECTED(spill_it != json_cache.end(), "Failed to find spill");

  const session_id_type id = session_it->get<session_id_type>();
  const uint64_t mem_sz = mem_it->get<uint64_t>();
  const bool spill = spill_it->get<bool>();

  // Optional fields stay std::nullopt when the corresponding key is absent.
  std::optional<std::vector<char>> hostname_c = std::nullopt;
  if (const auto host_it = json_cache.find("hostname"); host_it != json_cache.end()) {
    const std::string hostname = host_it->get<std::string>();
    hostname_c = std::vector<char>(hostname.begin(), hostname.end());
  }
  auto read_optional_int32 = [&json_cache](const char *key) -> std::optional<int32_t> {
    const auto it = json_cache.find(key);
    if (it == json_cache.end()) {
      return std::nullopt;
    }
    return it->get<int32_t>();
  };
  const std::optional<int32_t> port = read_optional_int32("port");
  const std::optional<int32_t> num_connections = read_optional_int32("num_connections");
  const std::optional<int32_t> prefetch_sz = read_optional_int32("prefetch_size");

  *cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname_c, port, num_connections, prefetch_sz);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace mindspore::dataset
|
|
@ -35,6 +35,10 @@ class DatasetCache {
|
|||
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetCache> *cache);
|
||||
#endif
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
|
||||
|
|
|
@ -169,5 +169,19 @@ Status BatchNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a BatchNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "batch_size" and "drop_remainder".
// \param[in] ds Dataset node handed to the BatchNode constructor.
// \param[out] result Receives the newly created BatchNode.
// \return Status::OK() on success; an error Status when result is null or a required key is missing.
Status BatchNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                            std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("batch_size") != json_obj.end(), "Failed to find batch_size");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("drop_remainder") != json_obj.end(), "Failed to find drop_remainder");

  const int32_t batch_size = json_obj["batch_size"].get<int32_t>();
  const bool drop_remainder = json_obj["drop_remainder"].get<bool>();

  *result = std::make_shared<BatchNode>(ds, batch_size, drop_remainder);
  (*result)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -105,6 +105,14 @@ class BatchNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
int32_t batch_size_;
|
||||
bool drop_remainder_;
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
@ -154,7 +157,6 @@ Status MapNode::to_json(nlohmann::json *out_json) {
|
|||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
|
||||
std::vector<nlohmann::json> ops;
|
||||
std::vector<int32_t> cbs;
|
||||
for (auto op : operations_) {
|
||||
|
@ -177,6 +179,26 @@ Status MapNode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuilds a MapNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "input_columns",
//     "output_columns", "project_columns" and "operations".
// \param[in] ds Dataset node handed to the MapNode constructor.
// \param[out] result Receives the newly created MapNode.
// \return Status::OK() on success; an error Status when result is null, a required key is missing, or
//     Serdes::ConstructTensorOps fails.
Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                          std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("project_columns") != json_obj.end(), "Failed to find project_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("operations") != json_obj.end(), "Failed to find operations");

  const auto input_columns = json_obj["input_columns"].get<std::vector<std::string>>();
  const auto output_columns = json_obj["output_columns"].get<std::vector<std::string>>();
  const auto project_columns = json_obj["project_columns"].get<std::vector<std::string>>();

  // The tensor operations are themselves serialized objects; Serdes turns them back into IR operations.
  std::vector<std::shared_ptr<TensorOperation>> operations;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(json_obj["operations"], &operations));

  *result = std::make_shared<MapNode>(ds, operations, input_columns, output_columns, project_columns);
  (*result)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
// Gets the dataset size
|
||||
Status MapNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
|
|
|
@ -93,6 +93,16 @@ class MapNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
#endif
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
|
|
|
@ -66,5 +66,13 @@ Status ProjectNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a ProjectNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "columns".
// \param[in] ds Dataset node handed to the ProjectNode constructor.
// \param[out] result Receives the newly created ProjectNode.
// \return Status::OK() on success; an error Status when result is null or "columns" is missing.
Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                              std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns");

  const auto columns = json_obj["columns"].get<std::vector<std::string>>();
  *result = std::make_shared<ProjectNode>(ds, columns);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,6 +63,14 @@ class ProjectNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
|
|
@ -72,5 +72,16 @@ Status RenameNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a RenameNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "input_columns" and "output_columns".
// \param[in] ds Dataset node handed to the RenameNode constructor.
// \param[out] result Receives the newly created RenameNode.
// \return Status::OK() on success; an error Status when result is null or a required key is missing.
Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                             std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");

  const auto input_columns = json_obj["input_columns"].get<std::vector<std::string>>();
  const auto output_columns = json_obj["output_columns"].get<std::vector<std::string>>();
  *result = std::make_shared<RenameNode>(ds, input_columns, output_columns);
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -65,6 +65,14 @@ class RenameNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
|
|
|
@ -104,5 +104,14 @@ Status RepeatNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a RepeatNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "count".
// \param[in] ds Dataset node handed to the RepeatNode constructor.
// \param[out] result Receives the newly created RepeatNode.
// \return Status::OK() on success; an error Status when result is null or "count" is missing.
Status RepeatNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                             std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");

  const int32_t count = json_obj["count"].get<int32_t>();
  *result = std::make_shared<RepeatNode>(ds, count);
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -123,6 +123,14 @@ class RepeatNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
protected:
|
||||
std::shared_ptr<RepeatOp> op_; // keep its corresponding run-time op of EpochCtrlNode and RepeatNode
|
||||
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
|
||||
|
|
|
@ -66,9 +66,19 @@ Status ShuffleNode::ValidateParams() {
|
|||
// Serializes the shuffle arguments of this node into a JSON object.
// \param[out] out_json Overwritten with the serialized arguments.
// \return Status::OK()
Status ShuffleNode::to_json(nlohmann::json *out_json) {
  // reset_every_epoch_ is written under both "reshuffle_each_epoch" and "reset_each_epoch".
  // NOTE(review): ShuffleNode::from_json only reads "reset_each_epoch" - confirm the other key is still needed.
  nlohmann::json serialized = {{"buffer_size", shuffle_size_},
                               {"reshuffle_each_epoch", reset_every_epoch_},
                               {"reset_each_epoch", reset_every_epoch_}};
  *out_json = serialized;
  return Status::OK();
}
|
||||
|
||||
// Rebuilds a ShuffleNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "buffer_size" and "reset_each_epoch".
// \param[in] ds Dataset node handed to the ShuffleNode constructor.
// \param[out] result Receives the newly created ShuffleNode.
// \return Status::OK() on success; an error Status when result is null or a required key is missing.
Status ShuffleNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                              std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("buffer_size") != json_obj.end(), "Failed to find buffer_size");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reset_each_epoch") != json_obj.end(), "Failed to find reset_each_epoch");

  const int32_t buffer_size = json_obj["buffer_size"].get<int32_t>();
  const bool reset_every_epoch = json_obj["reset_each_epoch"].get<bool>();
  *result = std::make_shared<ShuffleNode>(ds, buffer_size, reset_every_epoch);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,6 +63,14 @@ class ShuffleNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
int32_t shuffle_size_;
|
||||
uint32_t shuffle_seed_;
|
||||
|
|
|
@ -93,5 +93,13 @@ Status SkipNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a SkipNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "count".
// \param[in] ds Dataset node handed to the SkipNode constructor.
// \param[out] result Receives the newly created SkipNode.
// \return Status::OK() on success; an error Status when result is null or "count" is missing.
Status SkipNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                           std::shared_ptr<DatasetNode> *result) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "Output parameter result is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");

  const int32_t count = json_obj["count"].get<int32_t>();
  *result = std::make_shared<SkipNode>(ds, count);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,6 +88,14 @@ class SkipNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset operation node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
};
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
|
||||
#include "debug/common.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -182,5 +185,28 @@ Status CelebANode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuilds a CelebANode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_dir", "usage",
//     "sampler", "decode" and "extensions". It is also passed to DatasetCache::from_json.
// \param[out] ds Receives the newly created CelebANode.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     sampler/cache reconstruction fails.
Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  // The message now names the key that is actually checked ("extensions").
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extensions");

  const std::string dataset_dir = json_obj["dataset_dir"].get<std::string>();
  const std::string usage = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  const bool decode = json_obj["decode"].get<bool>();
  const auto extension = json_obj["extensions"].get<std::set<std::string>>();

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -82,6 +82,14 @@ class CelebANode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the dataset node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -117,5 +120,24 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuilds a Cifar100Node from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_dir", "usage" and
//     "sampler". It is also passed to DatasetCache::from_json.
// \param[out] ds Receives the newly created Cifar100Node.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     sampler/cache reconstruction fails.
Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");

  const std::string dataset_dir = json_obj["dataset_dir"].get<std::string>();
  const std::string usage = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -78,6 +78,14 @@ class Cifar100Node : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the dataset node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -118,5 +121,24 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuilds a Cifar10Node from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_dir", "usage" and
//     "sampler". It is also passed to DatasetCache::from_json.
// \param[out] ds Receives the newly created Cifar10Node.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     sampler/cache reconstruction fails.
Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");

  const std::string dataset_dir = json_obj["dataset_dir"].get<std::string>();
  const std::string usage = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -78,6 +78,14 @@ class Cifar10Node : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the dataset node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -249,6 +249,29 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a CLUENode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_dir", "task", "usage",
//     "num_samples", "shuffle", "num_shards" and "shard_id". It is also passed to DatasetCache::from_json.
// \param[out] ds Receives the newly created CLUENode.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     cache reconstruction fails.
Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");

  // The "dataset_dir" entry is read as a list of strings and used as the file list.
  const auto dataset_files = json_obj["dataset_dir"].get<std::vector<std::string>>();
  const std::string task = json_obj["task"].get<std::string>();
  const std::string usage = json_obj["usage"].get<std::string>();
  const int64_t num_samples = json_obj["num_samples"].get<int64_t>();
  const ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  const int32_t num_shards = json_obj["num_shards"].get<int32_t>();
  const int32_t shard_id = json_obj["shard_id"].get<int32_t>();

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
|
||||
// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is
|
||||
// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing
|
||||
|
|
|
@ -86,6 +86,12 @@ class CLUENode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct the dataset node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief CLUE by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -181,6 +184,7 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
|
|||
args["annotation_file"] = annotation_file_;
|
||||
args["task"] = task_;
|
||||
args["decode"] = decode_;
|
||||
args["extra_metadata"] = extra_metadata_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
|
@ -189,5 +193,30 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuilds a CocoNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_dir",
//     "annotation_file", "task", "decode", "sampler" and "extra_metadata". It is also passed to
//     DatasetCache::from_json.
// \param[out] ds Receives the newly created CocoNode.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     sampler/cache reconstruction fails.
Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata");

  const std::string dataset_dir = json_obj["dataset_dir"].get<std::string>();
  const std::string annotation_file = json_obj["annotation_file"].get<std::string>();
  const std::string task = json_obj["task"].get<std::string>();
  const bool decode = json_obj["decode"].get<bool>();
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  const bool extra_metadata = json_obj["extra_metadata"].get<bool>();

  *ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache, extra_metadata);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -81,6 +81,14 @@ class CocoNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the dataset node from its serialized JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -187,6 +187,32 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rebuilds a CSVNode from its serialized JSON arguments.
// \param[in] json_obj JSON object that must contain "num_parallel_workers", "dataset_files", "field_delim",
//     "column_names", "num_samples", "shuffle", "num_shards" and "shard_id". It is also passed to
//     DatasetCache::from_json.
// \param[out] ds Receives the newly created CSVNode.
// \return Status::OK() on success; an error Status when ds is null, a required key is missing, or the
//     cache reconstruction fails.
Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Validate the out-parameter before it is dereferenced further down.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output parameter ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("field_delim") != json_obj.end(), "Failed to find field_delim");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");

  const auto dataset_files = json_obj["dataset_files"].get<std::vector<std::string>>();
  const std::string field_delim = json_obj["field_delim"].get<std::string>();
  // Column defaults are not read from json_obj; an empty list is always passed to the constructor.
  const std::vector<std::shared_ptr<CsvBase>> column_defaults = {};
  const auto column_names = json_obj["column_names"].get<std::vector<std::string>>();
  const int64_t num_samples = json_obj["num_samples"].get<int64_t>();
  const ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
  const int32_t num_shards = json_obj["num_shards"].get<int32_t>();
  const int32_t shard_id = json_obj["shard_id"].get<int32_t>();

  // cache stays nullptr unless DatasetCache::from_json assigns it.
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  // Only the first character of field_delim is used as the delimiter ('\0' for an empty string).
  *ds = std::make_shared<CSVNode>(dataset_files, field_delim.c_str()[0], column_defaults, column_names, num_samples,
                                  shuffle, num_shards, shard_id, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// CSV by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
|
|
|
@ -107,6 +107,12 @@ class CSVNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct a CSVNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief CSV by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -113,6 +116,7 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
|
|||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["recursive"] = recursive_;
|
||||
args["decode"] = decode_;
|
||||
args["extensions"] = exts_;
|
||||
args["class_indexing"] = class_indexing_;
|
||||
|
@ -124,5 +128,36 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Build an ImageFolderNode from a JSON object.
// \param[in] json_obj JSON object holding the node arguments
// \param[out] ds Receives the newly constructed ImageFolderNode
// \return Status::OK() on success; an error Status if ds is null, a required key is missing,
//     a class_indexing entry is malformed, or the sampler/cache cannot be rebuilt
Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Guard the output pointer before it is dereferenced at the end of the function.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr,
                               "Failed to deserialize ImageFolderNode: output pointer ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("recursive") != json_obj.end(), "Failed to find recursive");
  // The message names the key that is actually looked up ("extensions").
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extensions");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");

  std::string dataset_dir = json_obj["dataset_dir"];
  bool decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  bool recursive = json_obj["recursive"];
  std::set<std::string> extension = json_obj["extensions"];

  // NOTE(review): to_json assigns class_indexing_ directly; if that member is a std::map keyed by
  // string, nlohmann emits a JSON object rather than [name, index] pairs - confirm the round trip.
  std::map<std::string, int32_t> class_indexing;
  nlohmann::json class_map = json_obj["class_indexing"];
  for (const auto &class_map_child : class_map) {
    // Each entry is read as a [class_name, index] pair. The const operator[] performs no bounds
    // checking, so validate the shape before indexing.
    CHECK_FAIL_RETURN_UNEXPECTED(class_map_child.is_array() && class_map_child.size() >= 2,
                                 "Invalid class_indexing entry: expected [class_name, index]");
    std::string class_ = class_map_child[0];
    int32_t indexing = class_map_child[1];
    class_indexing.insert({class_, indexing});
  }

  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -87,6 +87,14 @@ class ImageFolderNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct an ImageFolderNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -152,5 +155,34 @@ Status ManifestNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Build a ManifestNode from a JSON object.
// \param[in] json_obj JSON object holding the node arguments
// \param[out] ds Receives the newly constructed ManifestNode
// \return Status::OK() on success; an error Status if ds is null, a required key is missing,
//     a class_indexing entry is malformed, or the sampler/cache cannot be rebuilt
Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Guard the output pointer before it is dereferenced at the end of the function.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Failed to deserialize ManifestNode: output pointer ds is null");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
                               "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_file") != json_obj.end(), "Failed to find dataset_file");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");

  std::string dataset_file = json_obj["dataset_file"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));

  std::map<std::string, int32_t> class_indexing;
  nlohmann::json class_map = json_obj["class_indexing"];
  for (const auto &class_map_child : class_map) {
    // Each entry is read as a [class_name, index] pair. The const operator[] performs no bounds
    // checking, so validate the shape before indexing.
    CHECK_FAIL_RETURN_UNEXPECTED(class_map_child.is_array() && class_map_child.size() >= 2,
                                 "Invalid class_indexing entry: expected [class_name, index]");
    std::string class_ = class_map_child[0];
    int32_t indexing = class_map_child[1];
    class_indexing.insert({class_, indexing});
  }

  bool decode = json_obj["decode"];
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));

  *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -78,9 +78,18 @@ class ManifestNode : public MappableSourceNode {
|
|||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON object that receives all attributes of the node
|
||||
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a ManifestNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -111,5 +114,24 @@ Status MnistNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a MnistNode: verify the required keys, rebuild the sampler and cache,
// construct the node, then apply the stored number of parallel workers.
Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_dir"), "Failed to find dataset_dir");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("usage"), "Failed to find usage");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("sampler"), "Failed to find sampler");

  auto data_dir = json_obj["dataset_dir"].get<std::string>();
  auto data_usage = json_obj["usage"].get<std::string>();

  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler_obj));

  std::shared_ptr<DatasetCache> dataset_cache;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));

  *ds = std::make_shared<MnistNode>(data_dir, data_usage, sampler_obj, dataset_cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -78,6 +78,14 @@ class MnistNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a MnistNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -106,6 +106,30 @@ Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a DistributedSamplerObj: verify the required keys, construct the sampler
// with the given num_samples, then attach any child samplers via the base class.
Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                        std::shared_ptr<SamplerObj> *sampler) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("seed"), "Failed to find seed");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("offset"), "Failed to find offset");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("even_dist"), "Failed to find even_dist");

  auto shard_count = json_obj["num_shards"].get<int64_t>();
  auto shard_index = json_obj["shard_id"].get<int64_t>();
  auto do_shuffle = json_obj["shuffle"].get<bool>();
  auto rng_seed = json_obj["seed"].get<uint32_t>();
  auto shard_offset = json_obj["offset"].get<int64_t>();
  auto even_distribution = json_obj["even_dist"].get<bool>();

  *sampler = std::make_shared<DistributedSamplerObj>(shard_count, shard_index, do_shuffle, num_samples, rng_seed,
                                                     shard_offset, even_distribution);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() {
|
||||
auto sampler =
|
||||
std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);
|
||||
|
|
|
@ -56,6 +56,15 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a DistributedSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Function to get the shard id of sampler
|
||||
|
|
|
@ -60,6 +60,19 @@ Status PKSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a PKSamplerObj: verify the required keys, construct the sampler with the
// given num_samples, then attach any child samplers via the base class.
Status PKSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_val"), "Failed to find num_val");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");

  auto values_per_class = json_obj["num_val"].get<int64_t>();
  auto do_shuffle = json_obj["shuffle"].get<bool>();

  *sampler = std::make_shared<PKSamplerObj>(values_per_class, do_shuffle, num_samples);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
Status PKSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
|
||||
// runtime sampler object
|
||||
*sampler = std::make_shared<dataset::PKSamplerRT>(num_val_, shuffle_, num_samples_);
|
||||
|
|
|
@ -55,6 +55,15 @@ class PKSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a PKSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -56,6 +56,20 @@ Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a RandomSamplerObj: verify the required keys, construct the sampler with the
// given num_samples, then attach any child samplers via the base class.
Status RandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("replacement"), "Failed to find replacement");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("reshuffle_each_epoch"), "Failed to find reshuffle_each_epoch");

  auto with_replacement = json_obj["replacement"].get<bool>();
  auto reshuffle = json_obj["reshuffle_each_epoch"].get<bool>();

  *sampler = std::make_shared<RandomSamplerObj>(with_replacement, num_samples, reshuffle);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
|
||||
// runtime sampler object
|
||||
*sampler = std::make_shared<dataset::RandomSamplerRT>(replacement_, num_samples_, reshuffle_each_epoch_);
|
||||
|
|
|
@ -55,6 +55,15 @@ class RandomSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a RandomSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
|
||||
|
@ -73,5 +76,15 @@ Status SamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Construct the children samplers listed under "child_sampler" and add them to the parent.
// \param[in] json_obj JSON object of the parent sampler
// \param[in,out] parent_sampler Already-constructed parent sampler; each child is added to it
// \return Status::OK() on success; an error Status if parent_sampler (or the sampler it points
//     to) is null, or if constructing a child sampler fails
Status SamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler) {
  // The parent is dereferenced below, so reject a null pointer or an empty shared_ptr up front.
  CHECK_FAIL_RETURN_UNEXPECTED(parent_sampler != nullptr && *parent_sampler != nullptr,
                               "Failed to add child samplers: parent sampler is null");
  // Iterate by const reference so each child JSON is not copied just for the loop variable.
  for (const auto &child : json_obj["child_sampler"]) {
    std::shared_ptr<SamplerObj> child_sampler;
    RETURN_IF_NOT_OK(Serdes::ConstructSampler(child, &child_sampler));
    (*parent_sampler)->AddChildSampler(child_sampler);
  }
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -67,6 +67,14 @@ class SamplerObj {
|
|||
|
||||
virtual Status to_json(nlohmann::json *const out_json);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct the children samplers listed under "child_sampler" and add them to the parent
|
||||
/// \param[in] json_obj The JSON object of the parent sampler to be deserialized
|
||||
/// \param[out] parent_sampler Given parent sampler; on return it has the constructed children samplers added
|
||||
/// \return Status::OK() on success, or the error Status from constructing a child sampler
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *parent_sampler);
|
||||
#endif
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -61,6 +61,18 @@ Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a SequentialSamplerObj: verify the required key, construct the sampler with the
// given num_samples, then attach any child samplers via the base class.
Status SequentialSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                       std::shared_ptr<SamplerObj> *sampler) {
  const bool has_start_index = json_obj.find("start_index") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_start_index, "Failed to find start_index");

  auto first_index = json_obj["start_index"].get<int64_t>();

  *sampler = std::make_shared<SequentialSamplerObj>(first_index, num_samples);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
|
||||
// runtime sampler object
|
||||
*sampler = std::make_shared<dataset::SequentialSamplerRT>(start_index_, num_samples_);
|
||||
|
|
|
@ -55,6 +55,15 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a SequentialSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -63,6 +63,19 @@ Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a SubsetRandomSamplerObj: verify the required key, construct the sampler with the
// given num_samples, then attach any child samplers via the base class.
Status SubsetRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                         std::shared_ptr<SamplerObj> *sampler) {
  const bool has_indices = json_obj.find("indices") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_indices, "Failed to find indices");

  auto subset_indices = json_obj["indices"].get<std::vector<int64_t>>();

  *sampler = std::make_shared<SubsetRandomSamplerObj>(subset_indices, num_samples);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> SubsetRandomSamplerObj::SamplerCopy() {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
|
|
|
@ -45,6 +45,10 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
|
|||
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a SubsetRandomSamplerObj from a JSON object
/// \param[in] json_obj JSON object holding the sampler arguments
/// \param[in] num_samples Number of samples, passed to the sampler constructor
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
/// \return Status::OK() on success, or an error Status if a required key is missing
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
|
|
@ -72,6 +72,17 @@ Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a SubsetSamplerObj: verify the required key, construct the sampler with the
// given num_samples, then attach any child samplers via the base class.
Status SubsetSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  const bool has_indices = json_obj.find("indices") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_indices, "Failed to find indices");

  auto subset_indices = json_obj["indices"].get<std::vector<int64_t>>();

  *sampler = std::make_shared<SubsetSamplerObj>(subset_indices, num_samples);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> SubsetSamplerObj::SamplerCopy() {
|
||||
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
|
|
|
@ -55,6 +55,15 @@ class SubsetSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a SubsetSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -63,6 +63,20 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Deserialize a WeightedRandomSamplerObj: verify the required keys, construct the sampler with
// the given num_samples, then attach any child samplers via the base class.
Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
                                           std::shared_ptr<SamplerObj> *sampler) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("weights"), "Failed to find weights");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("replacement"), "Failed to find replacement");

  auto sample_weights = json_obj["weights"].get<std::vector<double>>();
  auto with_replacement = json_obj["replacement"].get<bool>();

  *sampler = std::make_shared<WeightedRandomSamplerObj>(sample_weights, num_samples, with_replacement);
  // The shared base-class routine adds the children samplers.
  RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
|
||||
*sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
|
||||
Status s = BuildChildren(sampler);
|
||||
|
|
|
@ -51,6 +51,15 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Construct a WeightedRandomSamplerObj from a JSON object
|
||||
/// \param[in] json_obj JSON object holding the sampler arguments
|
||||
/// \param[in] num_samples Number of samples, passed to the sampler constructor
|
||||
/// \param[out] sampler Receives the sampler constructed from the parameters in the JSON object
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -153,6 +153,26 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Deserialize a TextFileNode: verify the required keys, read the arguments, rebuild the cache,
// construct the node, then apply the stored number of parallel workers.
Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_files"), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_samples"), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");

  auto file_list = json_obj["dataset_files"].get<std::vector<std::string>>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();

  std::shared_ptr<DatasetCache> dataset_cache;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));

  *ds = std::make_shared<TextFileNode>(file_list, sample_count, shuffle_mode, shard_count, shard_index,
                                       dataset_cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// TextFile by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
|
|
|
@ -83,6 +83,12 @@ class TextFileNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct a TextFileNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief TextFile by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
|
|
|
@ -229,6 +229,33 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Deserialize a TFRecordNode: verify the required keys, read the arguments, rebuild the cache,
// construct the node, then apply the stored number of parallel workers.
Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // True when `key` is present in json_obj.
  const auto has_key = [&json_obj](const char *key) { return json_obj.find(key) != json_obj.end(); };
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_parallel_workers"), "Failed to find num_parallel_workers");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("dataset_files"), "Failed to find dataset_files");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("schema"), "Failed to find schema");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("columns_list"), "Failed to find columns_list");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_samples"), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shuffle"), "Failed to find shuffle");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("num_shards"), "Failed to find num_shards");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_id"), "Failed to find shard_id");
  CHECK_FAIL_RETURN_UNEXPECTED(has_key("shard_equal_rows"), "Failed to find shard_equal_rows");

  auto file_list = json_obj["dataset_files"].get<std::vector<std::string>>();
  auto schema_str = json_obj["schema"].get<std::string>();
  auto column_list = json_obj["columns_list"].get<std::vector<std::string>>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();
  auto equal_rows = json_obj["shard_equal_rows"].get<bool>();

  std::shared_ptr<DatasetCache> dataset_cache;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &dataset_cache));

  *ds = std::make_shared<TFRecordNode>(file_list, schema_str, column_list, sample_count, shuffle_mode, shard_count,
                                       shard_index, equal_rows, dataset_cache);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// TFRecord by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
|
|
|
@ -126,6 +126,12 @@ class TFRecordNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Construct a TFRecordNode from a JSON object
|
||||
/// \param[in] json_obj The JSON object holding the node arguments to be deserialized
|
||||
/// \param[out] ds Receives the deserialized dataset node
|
||||
/// \return Status::OK() on success, or an error Status if a required key is missing
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief TFRecord by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -169,6 +172,7 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
|
|||
args["usage"] = usage_;
|
||||
args["class_indexing"] = class_index_;
|
||||
args["decode"] = decode_;
|
||||
args["extra_metadata"] = extra_metadata_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
|
@ -177,5 +181,38 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Rebuild a VOCNode from the JSON arguments written by VOCNode::to_json.
// \param[in] json_obj JSON object holding the serialized node arguments
// \param[out] ds Receives the reconstructed node
// \return Status error if a required field is missing or a class_indexing entry is malformed
Status VOCNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // Verify every required field is present before reading any of them.
  for (const char *key : {"num_parallel_workers", "dataset_dir", "task", "usage", "class_indexing", "decode",
                          "sampler", "extra_metadata"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string task = json_obj["task"];
  std::string usage = json_obj["usage"];

  std::map<std::string, int32_t> class_indexing;
  const nlohmann::json &class_map = json_obj["class_indexing"];
  if (class_map.is_object()) {
    // nlohmann serializes a std::map<std::string, T> as a JSON object, so accept {"name": index, ...} too.
    for (auto it = class_map.begin(); it != class_map.end(); ++it) {
      class_indexing.emplace(it.key(), it.value().get<int32_t>());
    }
  } else {
    for (const auto &entry : class_map) {
      // Indexing a const json array out of range is undefined, so validate the [name, index] shape first.
      CHECK_FAIL_RETURN_UNEXPECTED(entry.is_array() && entry.size() == 2,
                                   "Invalid class_indexing entry, expected [name, index]");
      class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
    }
  }

  bool decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
  bool extra_metadata = json_obj["extra_metadata"];
  std::shared_ptr<DatasetCache> cache = nullptr;
  RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
  *ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache, extra_metadata);
  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -83,6 +83,14 @@ class VOCNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to deserialize a dataset node from a JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
#endif
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
|
|
@ -91,5 +91,13 @@ Status TakeNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a TakeNode from its JSON arguments.
// \param[in] json_obj JSON object that must contain the "count" field
// \param[in] ds Dataset node handed to the TakeNode constructor as its first argument
// \param[out] result Receives the newly created TakeNode
// \return Status error when "count" is absent, otherwise OK
Status TakeNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                           std::shared_ptr<DatasetNode> *result) {
  const bool has_count = json_obj.find("count") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_count, "Failed to find count");
  auto take_count = json_obj["count"].get<int32_t>();
  *result = std::make_shared<TakeNode>(ds, take_count);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,6 +88,14 @@ class TakeNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Function to deserialize a dataset operation from a JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
int32_t take_count_;
|
||||
};
|
||||
|
|
|
@ -126,5 +126,25 @@ Status TransferNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a TransferNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized TransferNode arguments
// \param[in] ds Dataset node handed to the TransferNode constructor as its first argument
// \param[out] result Receives the newly created TransferNode
// \return Status error naming the first missing field, otherwise OK
Status TransferNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                               std::shared_ptr<DatasetNode> *result) {
  // Check the required fields in a fixed order so the first missing one is reported.
  for (const char *key : {"queue_name", "device_type", "device_id", "send_epoch_end", "total_batch",
                          "create_data_info_queue"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto queue = json_obj["queue_name"].get<std::string>();
  auto dev_type = json_obj["device_type"].get<std::string>();
  auto dev_id = json_obj["device_id"].get<int32_t>();
  auto epoch_end_flag = json_obj["send_epoch_end"].get<bool>();
  auto batch_total = json_obj["total_batch"].get<int32_t>();
  auto info_queue_flag = json_obj["create_data_info_queue"].get<bool>();
  *result =
    std::make_shared<TransferNode>(ds, queue, dev_type, dev_id, epoch_end_flag, batch_total, info_queue_flag);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -84,6 +84,14 @@ class TransferNode : public DatasetNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Function to deserialize a dataset operation from a JSON object
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
private:
|
||||
std::string queue_name_;
|
||||
int32_t device_id_;
|
||||
|
|
|
@ -124,584 +124,97 @@ Status Serdes::CreateNode(std::shared_ptr<DatasetNode> child_ds, nlohmann::json
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Build a CelebANode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized CelebANode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // The reported name is the key itself, so a missing "extensions" field is no longer reported as "extension".
  for (const char *key : {"dataset_dir", "usage", "sampler", "decode", "extensions"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  bool decode = json_obj["decode"];
  std::set<std::string> extension = json_obj["extensions"];
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache);
  return Status::OK();
}
|
||||
|
||||
// Build a Cifar10Node from its JSON arguments.
// \param[in] json_obj JSON object holding "dataset_dir", "usage" and "sampler"
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "usage", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto data_dir = json_obj["dataset_dir"].get<std::string>();
  auto split = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<Cifar10Node>(data_dir, split, sampler_obj, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a Cifar100Node from its JSON arguments.
// \param[in] json_obj JSON object holding "dataset_dir", "usage" and "sampler"
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "usage", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto data_dir = json_obj["dataset_dir"].get<std::string>();
  auto split = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<Cifar100Node>(data_dir, split, sampler_obj, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a CLUENode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized CLUENode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "task", "usage", "num_samples", "shuffle", "num_shards", "shard_id"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  // The list of files is stored under the "dataset_dir" key.
  auto files = json_obj["dataset_dir"].get<std::vector<std::string>>();
  auto task_name = json_obj["task"].get<std::string>();
  auto split = json_obj["usage"].get<std::string>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<CLUENode>(files, task_name, split, sample_count, shuffle_mode, shard_count, shard_index,
                                   no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a CocoNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized CocoNode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "annotation_file", "task", "decode", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto data_dir = json_obj["dataset_dir"].get<std::string>();
  auto annotations = json_obj["annotation_file"].get<std::string>();
  auto task_name = json_obj["task"].get<std::string>();
  auto decode_flag = json_obj["decode"].get<bool>();
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // Neither the cache nor extra_metadata is read from the JSON; fixed defaults are used instead.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  bool with_extra_metadata = false;
  *ds = std::make_shared<CocoNode>(data_dir, annotations, task_name, decode_flag, sampler_obj, no_cache,
                                   with_extra_metadata);
  return Status::OK();
}
|
||||
|
||||
// Build a CSVNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized CSVNode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key :
       {"dataset_files", "field_delim", "column_names", "num_samples", "shuffle", "num_shards", "shard_id"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto files = json_obj["dataset_files"].get<std::vector<std::string>>();
  auto delimiter_text = json_obj["field_delim"].get<std::string>();
  // Column defaults are not read from the JSON; an empty list is always passed.
  std::vector<std::shared_ptr<CsvBase>> defaults = {};
  auto names = json_obj["column_names"].get<std::vector<std::string>>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  // Only the first character of the stored delimiter string is used.
  *ds = std::make_shared<CSVNode>(files, delimiter_text.c_str()[0], defaults, names, sample_count, shuffle_mode,
                                  shard_count, shard_index, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build an ImageFolderNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized ImageFolderNode arguments
// \param[out] ds Receives the newly created node
// \return Status error if a required field is missing or a class_indexing entry is malformed
Status Serdes::CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  // The reported name is the key itself, so a missing "extensions" field is no longer reported as "extension".
  for (const char *key : {"dataset_dir", "decode", "sampler", "extensions", "class_indexing"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string dataset_dir = json_obj["dataset_dir"];
  bool decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  // This arg exists in ImageFolderOp, but is not externalized (in the Python API). The default value is false.
  bool recursive = false;
  std::set<std::string> extension = json_obj["extensions"];

  std::map<std::string, int32_t> class_indexing;
  const nlohmann::json &class_map = json_obj["class_indexing"];
  if (class_map.is_object()) {
    // nlohmann serializes a std::map<std::string, T> as a JSON object, so accept {"name": index, ...} too.
    for (auto it = class_map.begin(); it != class_map.end(); ++it) {
      class_indexing.emplace(it.key(), it.value().get<int32_t>());
    }
  } else {
    for (const auto &entry : class_map) {
      // Indexing a const json array out of range is undefined, so validate the [name, index] shape first.
      CHECK_FAIL_RETURN_UNEXPECTED(entry.is_array() && entry.size() == 2,
                                   "Invalid class_indexing entry, expected [name, index]");
      class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
    }
  }

  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache);
  return Status::OK();
}
|
||||
|
||||
// Build a ManifestNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized ManifestNode arguments
// \param[out] ds Receives the newly created node
// \return Status error if a required field is missing or a class_indexing entry is malformed
Status Serdes::CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_file", "usage", "sampler", "class_indexing", "decode"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string dataset_file = json_obj["dataset_file"];
  std::string usage = json_obj["usage"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));

  std::map<std::string, int32_t> class_indexing;
  const nlohmann::json &class_map = json_obj["class_indexing"];
  if (class_map.is_object()) {
    // nlohmann serializes a std::map<std::string, T> as a JSON object, so accept {"name": index, ...} too.
    for (auto it = class_map.begin(); it != class_map.end(); ++it) {
      class_indexing.emplace(it.key(), it.value().get<int32_t>());
    }
  } else {
    for (const auto &entry : class_map) {
      // Indexing a const json array out of range is undefined, so validate the [name, index] shape first.
      CHECK_FAIL_RETURN_UNEXPECTED(entry.is_array() && entry.size() == 2,
                                   "Invalid class_indexing entry, expected [name, index]");
      class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
    }
  }

  bool decode = json_obj["decode"];
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> cache = nullptr;
  *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
  return Status::OK();
}
|
||||
|
||||
// Build a MnistNode from its JSON arguments.
// \param[in] json_obj JSON object holding "dataset_dir", "usage" and "sampler"
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "usage", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto data_dir = json_obj["dataset_dir"].get<std::string>();
  auto split = json_obj["usage"].get<std::string>();
  std::shared_ptr<SamplerObj> sampler_obj;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler_obj));
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<MnistNode>(data_dir, split, sampler_obj, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a TextFileNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized TextFileNode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_files", "num_samples", "shuffle", "num_shards", "shard_id"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto files = json_obj["dataset_files"].get<std::vector<std::string>>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<TextFileNode>(files, sample_count, shuffle_mode, shard_count, shard_index, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a TFRecordNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized TFRecordNode arguments
// \param[out] ds Receives the newly created node
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_files", "schema", "columns_list", "num_samples", "shuffle", "num_shards",
                          "shard_id", "shard_equal_rows"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto files = json_obj["dataset_files"].get<std::vector<std::string>>();
  // The schema is read as a string value.
  auto schema_text = json_obj["schema"].get<std::string>();
  auto columns = json_obj["columns_list"].get<std::vector<std::string>>();
  auto sample_count = json_obj["num_samples"].get<int64_t>();
  ShuffleMode shuffle_mode = static_cast<ShuffleMode>(json_obj["shuffle"]);
  auto shard_count = json_obj["num_shards"].get<int32_t>();
  auto shard_index = json_obj["shard_id"].get<int32_t>();
  auto equal_rows = json_obj["shard_equal_rows"].get<bool>();
  // The cache is not restored by this function; the node is created with a null cache.
  std::shared_ptr<DatasetCache> no_cache = nullptr;
  *ds = std::make_shared<TFRecordNode>(files, schema_text, columns, sample_count, shuffle_mode, shard_count,
                                       shard_index, equal_rows, no_cache);
  return Status::OK();
}
|
||||
|
||||
// Build a VOCNode from its JSON arguments.
// \param[in] json_obj JSON object holding the serialized VOCNode arguments
// \param[out] ds Receives the newly created node
// \return Status error if a required field is missing or a class_indexing entry is malformed
Status Serdes::CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  for (const char *key : {"dataset_dir", "task", "usage", "class_indexing", "decode", "sampler"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  std::string dataset_dir = json_obj["dataset_dir"];
  std::string task = json_obj["task"];
  std::string usage = json_obj["usage"];

  std::map<std::string, int32_t> class_indexing;
  const nlohmann::json &class_map = json_obj["class_indexing"];
  if (class_map.is_object()) {
    // nlohmann serializes a std::map<std::string, T> as a JSON object, so accept {"name": index, ...} too.
    for (auto it = class_map.begin(); it != class_map.end(); ++it) {
      class_indexing.emplace(it.key(), it.value().get<int32_t>());
    }
  } else {
    for (const auto &entry : class_map) {
      // Indexing a const json array out of range is undefined, so validate the [name, index] shape first.
      CHECK_FAIL_RETURN_UNEXPECTED(entry.is_array() && entry.size() == 2,
                                   "Invalid class_indexing entry, expected [name, index]");
      class_indexing.emplace(entry[0].get<std::string>(), entry[1].get<int32_t>());
    }
  }

  bool decode = json_obj["decode"];
  std::shared_ptr<SamplerObj> sampler;
  RETURN_IF_NOT_OK(ConstructSampler(json_obj["sampler"], &sampler));
  // Neither the cache nor extra_metadata is read from the JSON; fixed defaults are used instead.
  std::shared_ptr<DatasetCache> cache = nullptr;
  bool extra_metadata = false;
  *ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache, extra_metadata);
  return Status::OK();
}
|
||||
|
||||
// Dispatch on the operator type and let the matching node class rebuild itself from JSON.
// Each branch builds the node exactly once through XNode::from_json; the legacy Create*DatasetNode call that
// used to run first has been dropped because its result was immediately overwritten.
// \param[in] json_obj JSON object holding the serialized node arguments
// \param[in] op_type Name of the dataset node type to create
// \param[out] ds Receives the newly created node
// \return Status error from the node's from_json, or kMDUnexpectedError when op_type is not supported
Status Serdes::CreateDatasetNode(nlohmann::json json_obj, std::string op_type, std::shared_ptr<DatasetNode> *ds) {
  if (op_type == kCelebANode) {
    RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds));
  } else if (op_type == kCifar10Node) {
    RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds));
  } else if (op_type == kCifar100Node) {
    RETURN_IF_NOT_OK(Cifar100Node::from_json(json_obj, ds));
  } else if (op_type == kCLUENode) {
    RETURN_IF_NOT_OK(CLUENode::from_json(json_obj, ds));
  } else if (op_type == kCocoNode) {
    RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds));
  } else if (op_type == kCSVNode) {
    RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds));
  } else if (op_type == kImageFolderNode) {
    RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds));
  } else if (op_type == kManifestNode) {
    RETURN_IF_NOT_OK(ManifestNode::from_json(json_obj, ds));
  } else if (op_type == kMnistNode) {
    RETURN_IF_NOT_OK(MnistNode::from_json(json_obj, ds));
  } else if (op_type == kTextFileNode) {
    RETURN_IF_NOT_OK(TextFileNode::from_json(json_obj, ds));
  } else if (op_type == kTFRecordNode) {
    RETURN_IF_NOT_OK(TFRecordNode::from_json(json_obj, ds));
  } else if (op_type == kVOCNode) {
    RETURN_IF_NOT_OK(VOCNode::from_json(json_obj, ds));
  } else {
    return Status(StatusCode::kMDUnexpectedError, op_type + " is not supported");
  }
  return Status::OK();
}
|
||||
|
||||
// Build a BatchNode from its JSON arguments.
// \param[in] ds Dataset node handed to the BatchNode constructor as its first argument
// \param[in] json_obj JSON object holding "batch_size" and "drop_remainder"
// \param[out] result Receives the newly created BatchNode
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                        std::shared_ptr<DatasetNode> *result) {
  for (const char *key : {"batch_size", "drop_remainder"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto rows_per_batch = json_obj["batch_size"].get<int32_t>();
  auto drop_tail = json_obj["drop_remainder"].get<bool>();
  *result = std::make_shared<BatchNode>(ds, rows_per_batch, drop_tail);
  return Status::OK();
}
|
||||
|
||||
// Build a MapNode from its JSON arguments and restore its worker count.
// \param[in] ds Dataset node handed to the MapNode constructor as its first argument
// \param[in] json_obj JSON object holding the serialized MapNode arguments
// \param[out] result Receives the newly created MapNode
// \return Status error if a field is missing or the tensor operations cannot be constructed
Status Serdes::CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                      std::shared_ptr<DatasetNode> *result) {
  for (const char *key :
       {"num_parallel_workers", "input_columns", "output_columns", "project_columns", "operations"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto in_cols = json_obj["input_columns"].get<std::vector<std::string>>();
  auto out_cols = json_obj["output_columns"].get<std::vector<std::string>>();
  auto projected_cols = json_obj["project_columns"].get<std::vector<std::string>>();
  std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
  RETURN_IF_NOT_OK(ConstructTensorOps(json_obj["operations"], &tensor_ops));
  *result = std::make_shared<MapNode>(ds, tensor_ops, in_cols, out_cols, projected_cols);
  // The worker count is applied after construction.
  (*result)->SetNumWorkers(json_obj["num_parallel_workers"]);
  return Status::OK();
}
|
||||
|
||||
// Build a ProjectNode from its JSON arguments.
// \param[in] ds Dataset node handed to the ProjectNode constructor as its first argument
// \param[in] json_obj JSON object that must contain the "columns" field
// \param[out] result Receives the newly created ProjectNode
// \return Status error when "columns" is absent, otherwise OK
Status Serdes::CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                          std::shared_ptr<DatasetNode> *result) {
  const bool has_columns = json_obj.find("columns") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_columns, "Failed to find columns");
  auto kept_columns = json_obj["columns"].get<std::vector<std::string>>();
  *result = std::make_shared<ProjectNode>(ds, kept_columns);
  return Status::OK();
}
|
||||
|
||||
// Build a RenameNode from its JSON arguments.
// \param[in] ds Dataset node handed to the RenameNode constructor as its first argument
// \param[in] json_obj JSON object holding "input_columns" and "output_columns"
// \param[out] result Receives the newly created RenameNode
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                         std::shared_ptr<DatasetNode> *result) {
  for (const char *key : {"input_columns", "output_columns"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto old_names = json_obj["input_columns"].get<std::vector<std::string>>();
  auto new_names = json_obj["output_columns"].get<std::vector<std::string>>();
  *result = std::make_shared<RenameNode>(ds, old_names, new_names);
  return Status::OK();
}
|
||||
|
||||
// Build a RepeatNode from its JSON arguments.
// \param[in] ds Dataset node handed to the RepeatNode constructor as its first argument
// \param[in] json_obj JSON object that must contain the "count" field
// \param[out] result Receives the newly created RepeatNode
// \return Status error when "count" is absent, otherwise OK
Status Serdes::CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                         std::shared_ptr<DatasetNode> *result) {
  const bool has_count = json_obj.find("count") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_count, "Failed to find count");
  auto repeat_count = json_obj["count"].get<int32_t>();
  *result = std::make_shared<RepeatNode>(ds, repeat_count);
  return Status::OK();
}
|
||||
|
||||
// Build a ShuffleNode from its JSON arguments.
// \param[in] ds Dataset node handed to the ShuffleNode constructor as its first argument
// \param[in] json_obj JSON object holding "buffer_size" and "reshuffle_each_epoch"
// \param[out] result Receives the newly created ShuffleNode
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                          std::shared_ptr<DatasetNode> *result) {
  for (const char *key : {"buffer_size", "reshuffle_each_epoch"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto shuffle_buffer = json_obj["buffer_size"].get<int32_t>();
  // The JSON field "reshuffle_each_epoch" supplies the constructor's third argument.
  auto reshuffle = json_obj["reshuffle_each_epoch"].get<bool>();
  *result = std::make_shared<ShuffleNode>(ds, shuffle_buffer, reshuffle);
  return Status::OK();
}
|
||||
|
||||
// Build a SkipNode from its JSON arguments.
// \param[in] ds Dataset node handed to the SkipNode constructor as its first argument
// \param[in] json_obj JSON object that must contain the "count" field
// \param[out] result Receives the newly created SkipNode
// \return Status error when "count" is absent, otherwise OK
Status Serdes::CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                       std::shared_ptr<DatasetNode> *result) {
  const bool has_count = json_obj.find("count") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_count, "Failed to find count");
  auto skip_count = json_obj["count"].get<int32_t>();
  *result = std::make_shared<SkipNode>(ds, skip_count);
  return Status::OK();
}
|
||||
|
||||
// Build a TransferNode from its JSON arguments.
// \param[in] ds Dataset node handed to the TransferNode constructor as its first argument
// \param[in] json_obj JSON object holding the serialized TransferNode arguments
// \param[out] result Receives the newly created TransferNode
// \return Status error naming the first missing field, otherwise OK
Status Serdes::CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                           std::shared_ptr<DatasetNode> *result) {
  // Check the required fields in a fixed order so the first missing one is reported.
  for (const char *key : {"queue_name", "device_type", "device_id", "send_epoch_end", "total_batch",
                          "create_data_info_queue"}) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }
  auto queue = json_obj["queue_name"].get<std::string>();
  auto dev_type = json_obj["device_type"].get<std::string>();
  auto dev_id = json_obj["device_id"].get<int32_t>();
  auto epoch_end_flag = json_obj["send_epoch_end"].get<bool>();
  auto batch_total = json_obj["total_batch"].get<int32_t>();
  auto info_queue_flag = json_obj["create_data_info_queue"].get<bool>();
  *result =
    std::make_shared<TransferNode>(ds, queue, dev_type, dev_id, epoch_end_flag, batch_total, info_queue_flag);
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a TakeNode on top of `ds` from its JSON description.
/// \param[in] ds Dataset node the take operation is attached to
/// \param[in] json_obj JSON object expected to carry the "count" key
/// \param[out] result Receives the newly created TakeNode
/// \return Failure reported by CHECK_FAIL_RETURN_UNEXPECTED when "count" is absent, otherwise Status::OK()
Status Serdes::CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
                                       std::shared_ptr<DatasetNode> *result) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
  // Read the element count and construct the node in one step.
  *result = std::make_shared<TakeNode>(ds, json_obj["count"].get<int32_t>());
  return Status::OK();
}
|
||||
|
||||
/// \brief Dispatch deserialization of a dataset operation to the matching node's from_json.
/// \param[in] ds Dataset node the operation is attached to
/// \param[in] json_obj JSON object describing the operation
/// \param[in] op_type Operation type name, compared against the k*Node constants
/// \param[out] result Receives the deserialized operation node
/// \return Status of the selected from_json call, or a kMDUnexpectedError status for an unknown op_type
Status Serdes::CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj, std::string op_type,
                                          std::shared_ptr<DatasetNode> *result) {
  // Each node type owns its deserialization; calling the legacy Create*OperationNode helper first
  // only built a node that the from_json call immediately overwrote in *result.
  if (op_type == kBatchNode) {
    RETURN_IF_NOT_OK(BatchNode::from_json(json_obj, ds, result));
  } else if (op_type == kMapNode) {
    RETURN_IF_NOT_OK(MapNode::from_json(json_obj, ds, result));
  } else if (op_type == kProjectNode) {
    RETURN_IF_NOT_OK(ProjectNode::from_json(json_obj, ds, result));
  } else if (op_type == kRenameNode) {
    RETURN_IF_NOT_OK(RenameNode::from_json(json_obj, ds, result));
  } else if (op_type == kRepeatNode) {
    RETURN_IF_NOT_OK(RepeatNode::from_json(json_obj, ds, result));
  } else if (op_type == kShuffleNode) {
    RETURN_IF_NOT_OK(ShuffleNode::from_json(json_obj, ds, result));
  } else if (op_type == kSkipNode) {
    RETURN_IF_NOT_OK(SkipNode::from_json(json_obj, ds, result));
  } else if (op_type == kTransferNode) {
    RETURN_IF_NOT_OK(TransferNode::from_json(json_obj, ds, result));
  } else if (op_type == kTakeNode) {
    RETURN_IF_NOT_OK(TakeNode::from_json(json_obj, ds, result));
  } else {
    return Status(StatusCode::kMDUnexpectedError, op_type + " operation is not supported");
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a DistributedSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "num_shards", "shard_id", "shuffle", "seed",
///     "offset" and "even_dist"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for the first missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples,
                                           std::shared_ptr<SamplerObj> *sampler) {
  // Validation order matches the order in which the fields are consumed below.
  const char *const required_keys[] = {"num_shards", "shard_id", "shuffle", "seed", "offset", "even_dist"};
  for (const char *key : required_keys) {
    CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find(key) != json_obj.end(), std::string("Failed to find ") + key);
  }

  const int64_t shard_count = json_obj["num_shards"];
  const int64_t shard_index = json_obj["shard_id"];
  const bool do_shuffle = json_obj["shuffle"];
  const uint32_t rng_seed = json_obj["seed"];
  const int64_t start_offset = json_obj["offset"];
  const bool distribute_evenly = json_obj["even_dist"];

  *sampler = std::make_shared<DistributedSamplerObj>(shard_count, shard_index, do_shuffle, num_samples, rng_seed,
                                                     start_offset, distribute_evenly);

  const bool has_children = json_obj.find("child_sampler") != json_obj.end();
  if (has_children) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a PKSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "num_val" and "shuffle"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for a missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
  const bool has_num_val = json_obj.find("num_val") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_num_val, "Failed to find num_val");
  const bool has_shuffle = json_obj.find("shuffle") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_shuffle, "Failed to find shuffle");

  const int64_t values_per_class = json_obj["num_val"];
  const bool do_shuffle = json_obj["shuffle"];
  *sampler = std::make_shared<PKSamplerObj>(values_per_class, do_shuffle, num_samples);

  // No children serialized: the sampler is complete as is.
  if (json_obj.find("child_sampler") == json_obj.end()) {
    return Status::OK();
  }
  std::shared_ptr<SamplerObj> parent = *sampler;
  RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a RandomSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "replacement"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for a missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                      std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");

  const bool with_replacement = json_obj["replacement"];
  *sampler = std::make_shared<RandomSamplerObj>(with_replacement, num_samples);

  // No children serialized: the sampler is complete as is.
  if (json_obj.find("child_sampler") == json_obj.end()) {
    return Status::OK();
  }
  std::shared_ptr<SamplerObj> parent = *sampler;
  RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a SequentialSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "start_index"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for a missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples,
                                          std::shared_ptr<SamplerObj> *sampler) {
  const bool has_start = json_obj.find("start_index") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_start, "Failed to find start_index");

  const int64_t first_index = json_obj["start_index"];
  *sampler = std::make_shared<SequentialSamplerObj>(first_index, num_samples);

  const bool has_children = json_obj.find("child_sampler") != json_obj.end();
  if (has_children) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a SubsetRandomSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "indices"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for a missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                            std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices");

  const std::vector<int64_t> subset_indices = json_obj["indices"];
  *sampler = std::make_shared<SubsetRandomSamplerObj>(subset_indices, num_samples);

  // No children serialized: the sampler is complete as is.
  if (json_obj.find("child_sampler") == json_obj.end()) {
    return Status::OK();
  }
  std::shared_ptr<SamplerObj> parent = *sampler;
  RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a WeightedRandomSamplerObj (and any child samplers) from JSON.
/// \param[in] json_obj JSON object expected to carry "replacement" and "weights"; "child_sampler" is optional
/// \param[in] num_samples Number of samples forwarded to the sampler constructor
/// \param[out] sampler Receives the constructed sampler
/// \return Failure for a missing key or from child construction, otherwise Status::OK()
Status Serdes::ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples,
                                              std::shared_ptr<SamplerObj> *sampler) {
  const bool has_replacement = json_obj.find("replacement") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_replacement, "Failed to find replacement");
  const bool has_weights = json_obj.find("weights") != json_obj.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_weights, "Failed to find weights");

  const bool with_replacement = json_obj["replacement"];
  const std::vector<double> sample_weights = json_obj["weights"];
  // Note the constructor argument order: weights, num_samples, replacement.
  *sampler = std::make_shared<WeightedRandomSamplerObj>(sample_weights, num_samples, with_replacement);

  const bool has_children = json_obj.find("child_sampler") != json_obj.end();
  if (has_children) {
    std::shared_ptr<SamplerObj> parent = *sampler;
    RETURN_IF_NOT_OK(ChildSamplerFromJson(json_obj, parent, sampler));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Deserialize a sampler by dispatching on its "sampler_name" to the matching from_json.
/// \param[in] json_obj JSON object expected to carry "num_samples" and "sampler_name" plus the
///     fields required by the selected sampler type
/// \param[out] sampler Receives the deserialized sampler
/// \return Failure for a missing key, from the selected from_json, or a kMDUnexpectedError status
///     for an unknown sampler name; otherwise Status::OK()
Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler_name") != json_obj.end(), "Failed to find sampler_name");
  int64_t num_samples = json_obj["num_samples"];
  std::string sampler_name = json_obj["sampler_name"];

  // Each sampler type owns its deserialization; calling the legacy Construct*Sampler helper first
  // only built a sampler that the from_json call immediately overwrote in *sampler.
  if (sampler_name == "DistributedSampler") {
    RETURN_IF_NOT_OK(DistributedSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "PKSampler") {
    RETURN_IF_NOT_OK(PKSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "RandomSampler") {
    RETURN_IF_NOT_OK(RandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SequentialSampler") {
    RETURN_IF_NOT_OK(SequentialSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SubsetSampler") {
    RETURN_IF_NOT_OK(SubsetSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "SubsetRandomSampler") {
    RETURN_IF_NOT_OK(SubsetRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else if (sampler_name == "WeightedRandomSampler") {
    RETURN_IF_NOT_OK(WeightedRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  } else {
    // The name is reported verbatim; the previous text glued "Sampler" onto it without a separator.
    return Status(StatusCode::kMDUnexpectedError, sampler_name + " is not supported");
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Deserialize every entry of "child_sampler" and attach it to the given parent sampler.
/// \param[in] json_obj JSON object expected to carry the "child_sampler" array
/// \param[in] parent_sampler Sampler that receives the children; must not be null
/// \param[out] sampler Receives the parent sampler once the children were added (ignored if null)
/// \return Failure for a missing key, a null parent or a child that cannot be constructed;
///     otherwise Status::OK()
Status Serdes::ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler,
                                    std::shared_ptr<SamplerObj> *sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("child_sampler") != json_obj.end(), "Failed to find child_sampler");
  // Guard the dereference below instead of crashing on a null parent.
  CHECK_FAIL_RETURN_UNEXPECTED(parent_sampler != nullptr, "Parent sampler is null");

  // Iterate by reference: ConstructSampler takes its own copy of each child.
  for (const nlohmann::json &child : json_obj["child_sampler"]) {
    std::shared_ptr<SamplerObj> child_sampler;
    RETURN_IF_NOT_OK(ConstructSampler(child, &child_sampler));
    parent_sampler->AddChildSampler(child_sampler);
  }

  // Honour the out-parameter. Callers in this file already pass *sampler == parent_sampler,
  // so this does not change their result.
  if (sampler != nullptr) {
    *sampler = parent_sampler;
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a vision::BoundingBoxAugmentOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "transform" (one tensor operation) and "ratio"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key, from ConstructTensorOps, or when the wrapped transform does
///     not yield exactly one operation; otherwise Status::OK()
Status Serdes::BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  const bool has_transform = op_params.find("transform") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_transform, "Failed to find transform");
  const bool has_ratio = op_params.find("ratio") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_ratio, "Failed to find ratio");

  // ConstructTensorOps consumes a list, so the single transform is wrapped in a one-element vector.
  std::vector<nlohmann::json> wrapped_transform;
  wrapped_transform.push_back(op_params["transform"]);

  std::vector<std::shared_ptr<TensorOperation>> parsed_ops;
  RETURN_IF_NOT_OK(ConstructTensorOps(wrapped_transform, &parsed_ops));

  const float ratio = op_params["ratio"];
  const bool exactly_one = parsed_ops.size() == 1;
  CHECK_FAIL_RETURN_UNEXPECTED(exactly_one,
                               "Expect size one of transforms parameter, but got:" + std::to_string(parsed_ops.size()));

  *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(parsed_ops[0], ratio);
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a vision::RandomSelectSubpolicyOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "policy": a list of sub-policies, each a list
///     of entries with "prob" and "tensor_op"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key, from ConstructTensorOps, or when an entry does not yield
///     exactly one tensor operation; otherwise Status::OK()
Status Serdes::RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy");
  nlohmann::json policy_json = op_params["policy"];
  std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy;

  for (nlohmann::json &sub_policy_json : policy_json) {
    // A fresh list per sub-policy: a list shared across iterations would make every sub-policy
    // also contain the pairs of all sub-policies before it.
    std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items;
    for (nlohmann::json &item_pair : sub_policy_json) {
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob");
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op");
      double prob = item_pair["prob"];

      // ConstructTensorOps consumes a list, so the single op is wrapped in a one-element array.
      nlohmann::json tensor_op_json;
      tensor_op_json.push_back(item_pair["tensor_op"]);
      std::vector<std::shared_ptr<TensorOperation>> operations;
      RETURN_IF_NOT_OK(ConstructTensorOps(tensor_op_json, &operations));
      CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation");

      policy_items.emplace_back(operations[0], prob);
    }
    policy.push_back(policy_items);
  }

  *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy);
  return Status::OK();
}
|
||||
|
||||
/// \brief Rebuild a vision::UniformAugOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "transforms" (list of tensor operations) and "num_ops"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key or from ConstructTensorOps, otherwise Status::OK()
Status Serdes::UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  const bool has_transforms = op_params.find("transforms") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_transforms, "Failed to find transforms");
  const bool has_num_ops = op_params.find("num_ops") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_num_ops, "Failed to find num_ops");

  std::vector<std::shared_ptr<TensorOperation>> candidate_ops;
  RETURN_IF_NOT_OK(ConstructTensorOps(op_params["transforms"], &candidate_ops));

  const int32_t ops_to_apply = op_params["num_ops"];
  *operation = std::make_shared<vision::UniformAugOperation>(candidate_ops, ops_to_apply);
  return Status::OK();
}
|
||||
|
||||
Status Serdes::ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result) {
|
||||
Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> output;
|
||||
for (auto op : operations) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op.find("is_python_front_end_op") == op.end(),
|
||||
for (nlohmann::json item : json_obj) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("is_python_front_end_op") == item.end(),
|
||||
"python operation is not yet supported");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_name") != op.end(), "Failed to find tensor_op_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op.find("tensor_op_params") != op.end(), "Failed to find tensor_op_params");
|
||||
std::string op_name = op["tensor_op_name"];
|
||||
nlohmann::json op_params = op["tensor_op_params"];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
|
||||
std::string op_name = item["tensor_op_name"];
|
||||
nlohmann::json op_params = item["tensor_op_params"];
|
||||
std::shared_ptr<TensorOperation> operation = nullptr;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
|
||||
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
|
||||
|
@ -716,7 +229,7 @@ Serdes::InitializeFuncPtr() {
|
|||
std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr;
|
||||
ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json);
|
||||
ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json);
|
||||
ops_ptr[vision::kBoundingBoxAugmentOperation] = &(BoundingBoxAugmentFromJson);
|
||||
ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json);
|
||||
ops_ptr[vision::kCenterCropOperation] = &(vision::CenterCropOperation::from_json);
|
||||
ops_ptr[vision::kCropOperation] = &(vision::CropOperation::from_json);
|
||||
ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json);
|
||||
|
@ -745,7 +258,7 @@ Serdes::InitializeFuncPtr() {
|
|||
ops_ptr[vision::kRandomResizedCropOperation] = &(vision::RandomResizedCropOperation::from_json);
|
||||
ops_ptr[vision::kRandomResizedCropWithBBoxOperation] = &(vision::RandomResizedCropWithBBoxOperation::from_json);
|
||||
ops_ptr[vision::kRandomRotationOperation] = &(vision::RandomRotationOperation::from_json);
|
||||
ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(RandomSelectSubpolicyFromJson);
|
||||
ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(vision::RandomSelectSubpolicyOperation::from_json);
|
||||
ops_ptr[vision::kRandomSharpnessOperation] = &(vision::RandomSharpnessOperation::from_json);
|
||||
ops_ptr[vision::kRandomSolarizeOperation] = &(vision::RandomSolarizeOperation::from_json);
|
||||
ops_ptr[vision::kRandomVerticalFlipOperation] = &(vision::RandomVerticalFlipOperation::from_json);
|
||||
|
@ -766,7 +279,7 @@ Serdes::InitializeFuncPtr() {
|
|||
&(vision::SoftDvppDecodeRandomCropResizeJpegOperation::from_json);
|
||||
ops_ptr[vision::kSoftDvppDecodeResizeJpegOperation] = &(vision::SoftDvppDecodeResizeJpegOperation::from_json);
|
||||
ops_ptr[vision::kSwapRedBlueOperation] = &(vision::SwapRedBlueOperation::from_json);
|
||||
ops_ptr[vision::kUniformAugOperation] = &(UniformAugFromJson);
|
||||
ops_ptr[vision::kUniformAugOperation] = &(vision::UniformAugOperation::from_json);
|
||||
ops_ptr[vision::kVerticalFlipOperation] = &(vision::VerticalFlipOperation::from_json);
|
||||
ops_ptr[transforms::kFillOperation] = &(transforms::FillOperation::from_json);
|
||||
ops_ptr[transforms::kOneHotOperation] = &(transforms::OneHotOperation::from_json);
|
||||
|
|
|
@ -159,6 +159,18 @@ class Serdes {
|
|||
/// \return Status The status code returned
|
||||
static Status ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief Helper functions for creating sampler, separate different samplers and call the related function
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] sampler Deserialized sampler
|
||||
/// \return Status The status code returned
|
||||
static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
|
||||
|
||||
/// \brief helper function to construct tensor operations
|
||||
/// \param[in] json_obj JSON object of operations to be deserialized
|
||||
/// \param[out] result Vector of tensor operation pointers
|
||||
/// \return Status The status code returned
|
||||
static Status ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result);
|
||||
|
||||
protected:
|
||||
/// \brief Helper function to save JSON to a file
|
||||
/// \param[in] json_string The JSON string to be saved to the file
|
||||
|
@ -189,91 +201,6 @@ class Serdes {
|
|||
static Status CreateDatasetOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::string op_type, std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
/// \brief Helper functions for creating sampler, separate different samplers and call the related function
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] sampler Deserialized sampler
|
||||
/// \return Status The status code returned
|
||||
static Status ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
|
||||
|
||||
/// \brief helper function to construct tensor operations
|
||||
/// \param[in] operations operations to be deserialized
|
||||
/// \param[out] result Vector of tensor operation pointers
|
||||
/// \return Status The status code returned
|
||||
static Status ConstructTensorOps(nlohmann::json operations, std::vector<std::shared_ptr<TensorOperation>> *result);
|
||||
|
||||
/// \brief Helper functions for different datasets
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] ds Deserialized dataset
|
||||
/// \return Status The status code returned
|
||||
static Status CreateCelebADatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateCifar10DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateCifar100DatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateCLUEDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateCocoDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateCSVDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateImageFolderDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateManifestDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateMnistDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateTextFileDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateTFRecordDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
static Status CreateVOCDatasetNode(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds);
|
||||
|
||||
/// \brief Helper functions for different operations
|
||||
/// \param[in] ds dataset node constructed
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] result Deserialized dataset after the operation
|
||||
/// \return Status The status code returned
|
||||
static Status CreateBatchOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateMapOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateProjectOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateRenameOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateRepeatOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateShuffleOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateSkipOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateTransferOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
static Status CreateTakeOperationNode(std::shared_ptr<DatasetNode> ds, nlohmann::json json_obj,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
/// \brief Helper functions for different samplers
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[out] sampler Deserialized sampler
|
||||
/// \return Status The status code returned
|
||||
static Status ConstructDistributedSampler(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
static Status ConstructPKSampler(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler);
|
||||
static Status ConstructRandomSampler(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
static Status ConstructSequentialSampler(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
static Status ConstructSubsetRandomSampler(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
static Status ConstructWeightedRandomSampler(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
|
||||
/// \brief Helper functions to construct children samplers
|
||||
/// \param[in] json_obj The JSON object to be deserialized
|
||||
/// \param[in] parent_sampler given parent sampler
|
||||
/// \param[out] sampler sampler constructed - parent sampler with children samplers added
|
||||
/// \return Status The status code returned
|
||||
static Status ChildSamplerFromJson(nlohmann::json json_obj, std::shared_ptr<SamplerObj> parent_sampler,
|
||||
std::shared_ptr<SamplerObj> *sampler);
|
||||
|
||||
/// \brief Helper functions for vision operations, which requires tensor operations as input
|
||||
/// \param[in] op_params operation parameters for the operation
|
||||
/// \param[out] operation deserialized operation
|
||||
/// \return Status The status code returned
|
||||
static Status BoundingBoxAugmentFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
static Status RandomSelectSubpolicyFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
static Status UniformAugFromJson(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
/// \brief Helper function to map the function pointers
|
||||
/// \return map of key to function pointer
|
||||
static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#endif
|
||||
|
||||
|
@ -56,6 +57,20 @@ Status BoundingBoxAugmentOperation::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Rebuild a BoundingBoxAugmentOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "transform" (one tensor operation) and "ratio"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key, from Serdes::ConstructTensorOps, or when the wrapped transform
///     does not yield exactly one operation; otherwise Status::OK()
Status BoundingBoxAugmentOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  const bool has_transform = op_params.find("transform") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_transform, "Failed to find transform");
  const bool has_ratio = op_params.find("ratio") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_ratio, "Failed to find ratio");

  // Serdes::ConstructTensorOps consumes a list, so the single transform is wrapped in a vector.
  std::vector<nlohmann::json> wrapped_transform;
  wrapped_transform.push_back(op_params["transform"]);

  std::vector<std::shared_ptr<TensorOperation>> parsed_ops;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(wrapped_transform, &parsed_ops));

  const float ratio = op_params["ratio"];
  const bool exactly_one = parsed_ops.size() == 1;
  CHECK_FAIL_RETURN_UNEXPECTED(exactly_one,
                               "Expect size one of transforms parameter, but got:" + std::to_string(parsed_ops.size()));

  *operation = std::make_shared<vision::BoundingBoxAugmentOperation>(parsed_ops[0], ratio);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
|
|
|
@ -49,6 +49,8 @@ class BoundingBoxAugmentOperation : public TensorOperation {
|
|||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorOperation> transform_;
|
||||
float ratio_;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/random_select_subpolicy_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
||||
#endif
|
||||
|
||||
|
@ -100,6 +101,33 @@ Status RandomSelectSubpolicyOperation::to_json(nlohmann::json *out_json) {
|
|||
(*out_json)["policy"] = policy_tensor_ops;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Rebuild a RandomSelectSubpolicyOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "policy": a list of sub-policies, each a list
///     of entries with "prob" and "tensor_op"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key, from Serdes::ConstructTensorOps, or when an entry does not
///     yield exactly one tensor operation; otherwise Status::OK()
Status RandomSelectSubpolicyOperation::from_json(nlohmann::json op_params,
                                                 std::shared_ptr<TensorOperation> *operation) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find policy");
  nlohmann::json policy_json = op_params["policy"];
  std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy;

  for (nlohmann::json &sub_policy_json : policy_json) {
    // A fresh list per sub-policy: a list shared across iterations would make every sub-policy
    // also contain the pairs of all sub-policies before it.
    std::vector<std::pair<std::shared_ptr<TensorOperation>, double>> policy_items;
    for (nlohmann::json &item_pair : sub_policy_json) {
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("prob") != item_pair.end(), "Failed to find prob");
      CHECK_FAIL_RETURN_UNEXPECTED(item_pair.find("tensor_op") != item_pair.end(), "Failed to find tensor_op");
      double prob = item_pair["prob"];

      // Serdes::ConstructTensorOps consumes a list, so the single op is wrapped in a one-element array.
      nlohmann::json tensor_op_json;
      tensor_op_json.push_back(item_pair["tensor_op"]);
      std::vector<std::shared_ptr<TensorOperation>> operations;
      RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(tensor_op_json, &operations));
      CHECK_FAIL_RETURN_UNEXPECTED(operations.size() == 1, "There should be only 1 tensor operation");

      policy_items.emplace_back(operations[0], prob);
    }
    policy.push_back(policy_items);
  }

  *operation = std::make_shared<vision::RandomSelectSubpolicyOperation>(policy);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
|
|
|
@ -50,6 +50,8 @@ class RandomSelectSubpolicyOperation : public TensorOperation {
|
|||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy_;
|
||||
};
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/serdes.h"
|
||||
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
||||
#endif
|
||||
|
||||
|
@ -74,6 +75,16 @@ Status UniformAugOperation::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Rebuild a UniformAugOperation from its JSON parameters.
/// \param[in] op_params JSON object expected to carry "transforms" (list of tensor operations) and "num_ops"
/// \param[out] operation Receives the deserialized operation
/// \return Failure for a missing key or from Serdes::ConstructTensorOps, otherwise Status::OK()
Status UniformAugOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  const bool has_transforms = op_params.find("transforms") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_transforms, "Failed to find transforms");
  const bool has_num_ops = op_params.find("num_ops") != op_params.end();
  CHECK_FAIL_RETURN_UNEXPECTED(has_num_ops, "Failed to find num_ops");

  std::vector<std::shared_ptr<TensorOperation>> candidate_ops;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(op_params["transforms"], &candidate_ops));

  const int32_t ops_to_apply = op_params["num_ops"];
  *operation = std::make_shared<vision::UniformAugOperation>(candidate_ops, ops_to_apply);
  return Status::OK();
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
|
|
|
@ -49,6 +49,8 @@ class UniformAugOperation : public TensorOperation {
|
|||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
||||
int32_t num_ops_;
|
||||
|
|
|
@ -462,6 +462,7 @@ TEST_F(MindDataTestDeserialize, TestDeserializeFill) {
|
|||
std::shared_ptr<TensorOperation> operation2 = std::make_shared<text::ToNumberOperation>("int32_t");
|
||||
std::vector<std::shared_ptr<TensorOperation>> ops = {operation1, operation2};
|
||||
ds = std::make_shared<MapNode>(ds, ops);
|
||||
ds = std::make_shared<TransferNode>(ds, "queue", "type", 1, true, 10, true);
|
||||
compare_dataset(ds);
|
||||
}
|
||||
|
||||
|
@ -482,3 +483,19 @@ TEST_F(MindDataTestDeserialize, TestDeserializeTensor) {
|
|||
json_ss1 << json_obj1;
|
||||
EXPECT_EQ(json_ss.str(), json_ss1.str());
|
||||
}
|
||||
|
||||
// Helper function to get the session id from SESSION_ID env variable
|
||||
Status GetSessionFromEnv(session_id_type *session_id);
|
||||
|
||||
// Builds a Cifar10 pipeline that carries a dataset cache and hands it to compare_dataset.
// The DISABLED_ prefix keeps gtest from running it by default; it needs a session id from the
// environment (see GetSessionFromEnv) and, presumably, a cache server at 127.0.0.1:50052 -- TODO confirm.
TEST_F(MindDataTestDeserialize, DISABLED_TestDeserializeCache) {
  MS_LOG(INFO) << "Doing MindDataTestDeserialize-Cache.";

  // Cache bound to the session taken from the environment.
  session_id_type env_session;
  ASSERT_TRUE(GetSessionFromEnv(&env_session));
  std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, false, "127.0.0.1", 50052, 1, 1);

  // Leaf dataset reading through the cache with a sequential sampler.
  std::string data_dir = "./data/dataset/testCache";
  std::string usage = "all";
  std::shared_ptr<SamplerObj> sampler = std::make_shared<SequentialSamplerObj>(0, 10);
  std::shared_ptr<DatasetNode> ds = std::make_shared<Cifar10Node>(data_dir, usage, sampler, some_cache);

  compare_dataset(ds);
}
|
Loading…
Reference in New Issue