!9534 Remove std::optional in pybind

From: @alex-yuyue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-08 05:15:16 +08:00 committed by Gitee
commit 66de0eec3a
4 changed files with 124 additions and 137 deletions

View File

@@ -96,23 +96,24 @@ PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
// PYBIND FOR LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(
CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode,
std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) {
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
toStringSet(extensions), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(celebA->ValidateParams());
return celebA;
}));
}));
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
"to create a CelebANode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
py::list extensions, std::shared_ptr<CacheClient> cc) {
auto celebA =
std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
toStringSet(extensions), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(celebA->ValidateParams());
return celebA;
}));
}));
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
"to create a Cifar10Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
std::shared_ptr<CacheClient> cc) {
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar10->ValidateParams());
@@ -123,8 +124,8 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
"to create a Cifar100Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
std::shared_ptr<CacheClient> cc) {
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar100->ValidateParams());
@@ -136,7 +137,7 @@ PYBIND_REGISTER(
CLUENode, 2, ([](const py::module *m) {
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<CLUENode> clue_node =
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
@@ -145,24 +146,24 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(coco->ValidateParams());
return coco;
}));
}));
PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
"to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<CocoNode> coco =
std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(coco->ValidateParams());
return coco;
}));
}));
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id,
std::optional<std::shared_ptr<CacheClient>> cc) {
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
column_names, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
@@ -194,10 +195,10 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
*m, "ImageFolderNode", "to create an ImageFolderNode")
.def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler,
std::optional<py::list> extensions, std::optional<py::dict> class_indexing,
std::optional<std::shared_ptr<CacheClient>> cc) {
bool recursive = false;
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
py::dict class_indexing, std::shared_ptr<CacheClient> cc) {
// Don't update recursive to true
bool recursive = false; // Will be removed in future PR
auto imagefolder = std::make_shared<ImageFolderNode>(
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
@@ -209,9 +210,8 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode")
.def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler,
std::optional<py::dict> class_indexing, bool decode,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
py::dict class_indexing, bool decode, std::shared_ptr<CacheClient> cc) {
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
toStringMap(class_indexing), decode,
toDatasetCache(std::move(cc)));
@@ -223,8 +223,8 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
"to create a MindDataNode")
.def(py::init([](std::string dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
.def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler,
py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
@@ -235,8 +235,8 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
THROW_IF_ERROR(minddata->ValidateParams());
return minddata;
}))
.def(py::init([](py::list dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
.def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler,
py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
@@ -252,8 +252,8 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
"to create an MnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
std::shared_ptr<CacheClient> cc) {
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(mnist->ValidateParams());
@@ -264,15 +264,14 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(
RandomNode, 2, ([](const py::module *m) {
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list,
std::shared_ptr<CacheClient> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
return random_node;
}))
.def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list, std::shared_ptr<CacheClient> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
@@ -284,7 +283,7 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode")
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
int32_t shard_id, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
toDatasetCache(std::move(cc)));
@@ -293,44 +292,34 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(
TFRecordNode, 2, ([](const py::module *m) {
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
"to create a TFRecordNode")
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
*num_samples = 0;
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}))
.def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
*num_samples = 0;
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}));
}));
PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
"to create a TFRecordNode")
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, py::list columns_list,
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}))
.def(py::init([](py::list dataset_files, std::string schema, py::list columns_list,
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}));
}));
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(
py::init([](std::string dataset_dir, std::string task, std::string usage,
std::optional<py::dict> class_indexing, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
py::init([](std::string dataset_dir, std::string task, std::string usage, py::dict class_indexing,
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
std::shared_ptr<VOCNode> voc =
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
@@ -416,15 +405,14 @@ PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
"to create a ConcatNode")
.def(
py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler,
py::list children_flag_and_nums, py::list children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index));
THROW_IF_ERROR(concat->ValidateParams());
return concat;
}));
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, py::handle sampler,
py::list children_flag_and_nums, py::list children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index));
THROW_IF_ERROR(concat->ValidateParams());
return concat;
}));
}));
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
@@ -441,10 +429,8 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations,
std::optional<py::list> input_columns, std::optional<py::list> output_columns,
std::optional<py::list> project_columns,
std::optional<std::shared_ptr<CacheClient>> cc,
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
py::list output_columns, py::list project_columns, std::shared_ptr<CacheClient> cc,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
auto map = std::make_shared<MapNode>(
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
@@ -465,17 +451,15 @@ PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
"to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns,
std::optional<py::list> output_columns) {
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
toStringVector(output_columns));
THROW_IF_ERROR(rename->ValidateParams());
return rename;
}));
}));
PYBIND_REGISTER(
RenameNode, 2, ([](const py::module *m) {
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode", "to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list input_columns, py::list output_columns) {
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns), toStringVector(output_columns));
THROW_IF_ERROR(rename->ValidateParams());
return rename;
}));
}));
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",

View File

@@ -28,10 +28,10 @@ bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>
std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }
std::set<std::string> toStringSet(const std::optional<py::list> list) {
std::set<std::string> toStringSet(const py::list list) {
std::set<std::string> set;
if (list) {
for (auto l : *list) {
if (!list.empty()) {
for (auto l : list) {
if (!l.is_none()) {
(void)set.insert(py::str(l));
}
@@ -40,20 +40,20 @@ std::set<std::string> toStringSet(const std::optional<py::list> list) {
return set;
}
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) {
std::map<std::string, int32_t> toStringMap(const py::dict dict) {
std::map<std::string, int32_t> map;
if (dict) {
for (auto p : *dict) {
if (!dict.empty()) {
for (auto p : dict) {
(void)map.emplace(toString(p.first), toInt(p.second));
}
}
return map;
}
std::vector<std::string> toStringVector(const std::optional<py::list> list) {
std::vector<std::string> toStringVector(const py::list list) {
std::vector<std::string> vector;
if (list) {
for (auto l : *list) {
if (!list.empty()) {
for (auto l : list) {
if (l.is_none())
vector.emplace_back("");
else
@@ -63,10 +63,10 @@ std::vector<std::string> toStringVector(const std::optional<py::list> list) {
return vector;
}
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) {
std::pair<int64_t, int64_t> toIntPair(const py::tuple tuple) {
std::pair<int64_t, int64_t> pair;
if (tuple) {
pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1]));
if (!tuple.empty()) {
pair = std::make_pair(toInt64((tuple)[0]), toInt64((tuple)[1]));
}
return pair;
}
@@ -85,10 +85,10 @@ std::vector<std::pair<int, int>> toPairVector(const py::list list) {
return vector;
}
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) {
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operations) {
std::vector<std::shared_ptr<TensorOperation>> vector;
if (operations) {
for (auto op : *operations) {
if (!operations.empty()) {
for (auto op : operations) {
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(op)) {
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
@@ -132,19 +132,19 @@ std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetN
return vector;
}
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) {
std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) {
if (py_sampler) {
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create");
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset");
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
@@ -156,10 +156,10 @@ std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, b
}
// Here we take in a python object, that holds a reference to a C++ object
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<DatasetCache> toDatasetCache(std::shared_ptr<CacheClient> cc) {
if (cc) {
std::shared_ptr<DatasetCache> built_cache;
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value()));
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc));
return built_cache;
} else {
// don't need to check here as cache is not enabled.

View File

@@ -47,25 +47,25 @@ bool toBool(const py::handle &handle);
std::string toString(const py::handle &handle);
std::set<std::string> toStringSet(const std::optional<py::list> list);
std::set<std::string> toStringSet(const py::list list);
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict);
std::map<std::string, int32_t> toStringMap(const py::dict dict);
std::vector<std::string> toStringVector(const std::optional<py::list> list);
std::vector<std::string> toStringVector(const py::list list);
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple);
std::pair<int64_t, int64_t> toIntPair(const py::tuple tuple);
std::vector<std::pair<int, int>> toPairVector(const py::list list);
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operations);
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation);
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);
std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset = false);
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc);
std::shared_ptr<DatasetCache> toDatasetCache(std::shared_ptr<CacheClient> cc);
ShuffleMode toShuffleMode(const int32_t shuffle);

View File

@@ -2266,13 +2266,13 @@ class MapDataset(Dataset):
if start_ind != end_ind:
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
operations = new_ops
self.operations = operations
self.operations = replace_none(operations, [])
if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns]
self.input_columns = replace_none(input_columns, [])
if output_columns is not None and not isinstance(output_columns, list):
output_columns = [output_columns]
self.output_columns = replace_none(output_columns, input_columns)
self.output_columns = replace_none(output_columns, self.input_columns)
self.cache = cache
self.column_order = column_order
@@ -3025,8 +3025,9 @@ class ImageFolderDataset(MappableDataset):
cc = self.cache.cache_client
else:
cc = None
class_indexing = replace_none(self.class_indexing, {})
return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions,
self.class_indexing, cc).SetNumWorkers(self.num_parallel_workers)
class_indexing, cc).SetNumWorkers(self.num_parallel_workers)
def get_args(self):
args = super().get_args()
@@ -4043,7 +4044,8 @@ class ManifestDataset(MappableDataset):
cc = self.cache.cache_client
else:
cc = None
return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, self.class_indexing,
class_indexing = replace_none(self.class_indexing, {})
return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, class_indexing,
self.decode, cc).SetNumWorkers(self.num_parallel_workers)
@check_manifestdataset
@@ -4701,7 +4703,8 @@ class VOCDataset(MappableDataset):
cc = self.cache.cache_client
else:
cc = None
return cde.VOCNode(self.dataset_dir, self.task, self.usage, self.class_indexing, self.decode,
class_indexing = replace_none(self.class_indexing, {})
return cde.VOCNode(self.dataset_dir, self.task, self.usage, class_indexing, self.decode,
self.sampler, cc).SetNumWorkers(self.num_parallel_workers)
@check_vocdataset