From 8f34faeb7a7f7b63d9e9497dfecc9b0b85babab1 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Tue, 22 Dec 2020 17:56:35 -0500 Subject: [PATCH] Changed bindings to SamplerObj --- .../ccsrc/minddata/dataset/api/CMakeLists.txt | 2 +- .../datasetops/source/sampler/bindings.cc | 93 ---------- .../dataset/include/sampler_bindings.cc | 127 +++++++++++++ .../dataset/api/python/pybind_conversion.cc | 12 +- .../ccsrc/minddata/dataset/api/samplers.cc | 171 +++++++++++++++--- .../ccsrc/minddata/dataset/include/samplers.h | 41 ++++- mindspore/dataset/engine/samplers.py | 171 ++++++++---------- .../dataset/test_datasets_imagefolder.py | 23 ++- .../dataset/test_minddataset_exception.py | 8 +- tests/ut/python/dataset/test_sampler.py | 38 +--- .../ut/python/dataset/test_serdes_dataset.py | 4 +- 11 files changed, 423 insertions(+), 267 deletions(-) delete mode 100644 mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc create mode 100644 mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/sampler_bindings.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt index 1bc16a42913..dda0b2bc525 100644 --- a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt @@ -7,11 +7,11 @@ if(ENABLE_PYTHON) python/bindings/dataset/engine/cache/bindings.cc python/bindings/dataset/engine/datasetops/bindings.cc python/bindings/dataset/engine/datasetops/source/bindings.cc - python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc python/bindings/dataset/engine/gnn/bindings.cc python/bindings/dataset/include/datasets_bindings.cc python/bindings/dataset/include/iterator_bindings.cc python/bindings/dataset/include/execute_binding.cc + python/bindings/dataset/include/sampler_bindings.cc python/bindings/dataset/include/schema_bindings.cc python/bindings/dataset/kernels/bindings.cc python/bindings/dataset/kernels/data/bindings.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc deleted file mode 100644 index 18b604214bb..00000000000 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/api/python/pybind_register.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" - -namespace mindspore { -namespace dataset { - -PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) { - (void)py::class_>(*m, "Sampler") - .def("set_num_rows", - [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) - .def("set_num_samples", - [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) - .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); }) - .def("get_indices", - [](SamplerRT &self) { - py::array ret; - THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); - return ret; - }) - .def("add_child", [](std::shared_ptr self, std::shared_ptr child) { - THROW_IF_ERROR(self->AddChild(child)); - }); - })); - -PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DistributedSampler") - .def(py::init()); - })); - -PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>(*m, "PKSampler") - .def(py::init()); - })); - -PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>(*m, "PythonSampler") - .def(py::init()); - })); - -PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>(*m, "RandomSampler") - .def(py::init()); - })); - -PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "SequentialSampler") - .def(py::init()); - })); - -PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) { - (void)py::class_>( - *m, "SubsetRandomSampler") - .def(py::init>()); - })); - -PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>(*m, "SubsetSampler") - .def(py::init>()); - })); - -PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "WeightedRandomSampler") - .def(py::init, bool>()); - })); - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/sampler_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/sampler_bindings.cc new file mode 100644 index 00000000000..9d7c54349e3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/sampler_bindings.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" +#include "minddata/dataset/api/python/pybind_conversion.h" +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/callback/py_ds_callback.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/datasets.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) { + (void)py::class_>(*m, "SamplerObj", "to create a SamplerObj") + .def("add_child", [](std::shared_ptr self, std::shared_ptr child) { + THROW_IF_ERROR(self->AddChildSampler(child)); + }); + })); + +PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "DistributedSamplerObj", "to create a DistributedSamplerObj") + .def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, + uint32_t seed, int64_t offset, bool even_dist) { + std::shared_ptr sampler = std::make_shared( + num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj") + .def(py::init([](int64_t num_samples, py::object sampler) { + auto sampler_rt = std::make_shared(num_samples, sampler); + auto sampler_obj = std::make_shared(std::move(sampler_rt)); + THROW_IF_ERROR(sampler_obj->ValidateParams()); + return sampler_obj; + })); + })); + +PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>(*m, "PKSamplerObj", + "to create a PKSamplerObj") + .def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) { + std::shared_ptr sampler = + std::make_shared(num_val, shuffle, num_samples); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomSamplerObj", "to create a RandomSamplerObj") + .def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) { + std::shared_ptr sampler = + std::make_shared(replacement, num_samples, reshuffle_each_epoch); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "SequentialSamplerObj", "to create a SequentialSamplerObj") + .def(py::init([](int64_t start_index, int64_t num_samples) { + std::shared_ptr sampler = + std::make_shared(start_index, num_samples); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "SubsetSamplerObj", "to create a SubsetSamplerObj") + .def(py::init([](std::vector indices, int64_t num_samples) { + std::shared_ptr sampler = + std::make_shared(indices, num_samples); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) { + (void)py::class_>( + *m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj") + .def(py::init([](std::vector indices, int64_t num_samples) { + std::shared_ptr sampler = + std::make_shared(indices, num_samples); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); + +PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj") + .def(py::init([](std::vector weights, int64_t num_samples, bool replacement) { + std::shared_ptr sampler = + std::make_shared(weights, num_samples, replacement); + THROW_IF_ERROR(sampler->ValidateParams()); + return sampler; + })); + })); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc index a979035cb3f..835d20c38f0 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -150,15 +150,13 @@ std::shared_ptr toSamplerObj(py::handle py_sampler, bool isMindDatas std::shared_ptr sampler_obj; if (!isMindDataset) { // Common Sampler - std::shared_ptr sampler; - auto create = py::reinterpret_borrow(py_sampler).attr("create"); - sampler = create().cast>(); - sampler_obj = std::make_shared(std::move(sampler)); + auto parse = py::reinterpret_borrow(py_sampler).attr("parse"); + sampler_obj = parse().cast>(); } else { // Mindrecord Sampler std::shared_ptr sampler; - auto create = py::reinterpret_borrow(py_sampler).attr("create_for_minddataset"); - sampler = create().cast>(); + auto parse = py::reinterpret_borrow(py_sampler).attr("parse_for_minddataset"); + sampler = parse().cast>(); sampler_obj = std::make_shared(std::move(sampler)); } return sampler_obj; diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index c340997898f..b02ff7f1b6e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -211,6 +211,27 @@ std::shared_ptr DistributedSamplerObj::BuildForMindDa } #endif +Status DistributedSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "DistributedSampler"; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + args["shuffle"] = shuffle_; + args["num_samples"] = num_samples_; + args["offset"] = offset_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + // PKSampler PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} @@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() { return Status::OK(); } +Status PKSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "PKSampler"; + args["num_val"] = num_val_; + args["shuffle"] = shuffle_; + args["num_samples"] = num_samples_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + std::shared_ptr PKSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, num_val_, shuffle_); @@ -233,6 +273,21 @@ std::shared_ptr PKSamplerObj::SamplerBuild() { return sampler; } +#ifndef ENABLE_ANDROID +std::shared_ptr PKSamplerObj::BuildForMindDataset() { + // runtime mindrecord sampler object + std::shared_ptr mind_sampler; + if (shuffle_ == true) { + mind_sampler = std::make_shared("label", num_val_, std::numeric_limits::max(), + GetSeed(), num_samples_); + } else { + mind_sampler = std::make_shared("label", num_val_, num_samples_); + } + + return mind_sampler; +} +#endif + // PreBuiltOperation PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr sampler) : sp_(std::move(sampler)) {} @@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) { return Status::OK(); } -#ifndef ENABLE_ANDROID -std::shared_ptr PKSamplerObj::BuildForMindDataset() { - // runtime mindrecord sampler object - std::shared_ptr mind_sampler; - if (shuffle_ == true) { - mind_sampler = std::make_shared("label", num_val_, std::numeric_limits::max(), - GetSeed(), num_samples_); - } else { - mind_sampler = std::make_shared("label", num_val_, num_samples_); - } - - return mind_sampler; -} -#endif - // RandomSampler -RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) - : replacement_(replacement), num_samples_(num_samples) {} +RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch) + : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {} Status RandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { @@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() { return Status::OK(); } +Status RandomSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "RandomSampler"; + args["replacement"] = replacement_; + args["num_samples"] = num_samples_; + args["reshuffle_each_epoch"] = reshuffle_each_epoch_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + std::shared_ptr RandomSamplerObj::SamplerBuild() { // runtime sampler object - bool reshuffle_each_epoch = true; - auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); + auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch_); BuildChildren(sampler); return sampler; } @@ -311,7 +369,6 @@ std::shared_ptr RandomSamplerObj::SamplerBuild() { #ifndef ENABLE_ANDROID std::shared_ptr RandomSamplerObj::BuildForMindDataset() { // runtime mindrecord sampler object - bool reshuffle_each_epoch_ = true; auto mind_sampler = std::make_shared(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_); @@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() { return Status::OK(); } +Status SequentialSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "SequentialSampler"; + args["start_index"] = start_index_; + args["num_samples"] = num_samples_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + std::shared_ptr SequentialSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, start_index_); @@ -378,6 +453,23 @@ std::shared_ptr SubsetSamplerObj::BuildForMindDataset return mind_sampler; } #endif +Status SubsetSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "SubsetSampler"; + args["indices"] = indices_; + args["num_samples"] = num_samples_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} // SubsetRandomSampler SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector indices, int64_t num_samples) @@ -399,6 +491,24 @@ std::shared_ptr SubsetRandomSamplerObj::BuildForMindD } #endif +Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "SubsetRandomSampler"; + args["indices"] = indices_; + args["num_samples"] = num_samples_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + // WeightedRandomSampler WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector weights, int64_t num_samples, bool replacement) : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} @@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() { return Status::OK(); } +Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "WeightedRandomSampler"; + args["weights"] = weights_; + args["num_samples"] = num_samples_; + args["replacement"] = replacement_; + if (!children_.empty()) { + std::vector children_args; + for (auto child : children_) { + nlohmann::json child_arg; + RETURN_IF_NOT_OK(child->to_json(&child_arg)); + children_args.push_back(child_arg); + } + args["child_sampler"] = children_args; + } + *out_json = args; + return Status::OK(); +} + std::shared_ptr WeightedRandomSamplerObj::SamplerBuild() { auto sampler = std::make_shared(num_samples_, weights_, replacement_); BuildChildren(sampler); diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 26768203e73..7062ff3b6b2 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,6 +66,8 @@ class SamplerObj { virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } + std::vector> GetChild() { return children_; } + #ifndef ENABLE_ANDROID /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler @@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; /// \brief Function to get the shard id of sampler @@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; private: @@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj { class RandomSamplerObj : public SamplerObj { public: - RandomSamplerObj(bool replacement, int64_t num_samples); + RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true); ~RandomSamplerObj() = default; std::shared_ptr SamplerBuild() override; std::shared_ptr SamplerCopy() override { - auto sampler = std::make_shared(replacement_, num_samples_); + auto sampler = std::make_shared(replacement_, num_samples_, reshuffle_each_epoch_); for (auto child : children_) { sampler->AddChildSampler(child); } @@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; private: bool replacement_; int64_t num_samples_; + bool reshuffle_each_epoch_; }; class SequentialSamplerObj : public SamplerObj { @@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; private: @@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; protected: @@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj { ~SubsetRandomSamplerObj() = default; + Status to_json(nlohmann::json *out_json) override; + std::shared_ptr SamplerBuild() override; std::shared_ptr SamplerCopy() override { @@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj { return sampler; } + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + Status ValidateParams() override; private: diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 4bd443a2f07..aeecb2695ae 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class BuiltinSampler: self.child_sampler = None self.num_samples = num_samples - def create(self): + def parse(self): pass def add_child(self, sampler): @@ -59,16 +59,16 @@ class BuiltinSampler: def get_child(self): return self.child_sampler - def create_child(self): + def parse_child(self): c_child_sampler = None if self.child_sampler is not None: - c_child_sampler = self.child_sampler.create() + c_child_sampler = self.child_sampler.parse() return c_child_sampler - def create_child_for_minddataset(self): + def parse_child_for_minddataset(self): c_child_sampler = None if self.child_sampler is not None: - c_child_sampler = self.child_sampler.create_for_minddataset() + c_child_sampler = self.child_sampler.parse_for_minddataset() return c_child_sampler def is_shuffled(self): @@ -158,6 +158,8 @@ class Sampler(BuiltinSampler): def __init__(self, num_samples=None): super().__init__(num_samples) self.dataset_size = 0 + self.child_sampler = None + self.num_samples = num_samples def __iter__(self): """ @@ -192,13 +194,26 @@ class Sampler(BuiltinSampler): # Instance fetcher # Do not override this method! - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.PythonSampler(num_samples, self) - c_child_sampler = self.create_child() + c_sampler = cde.PreBuiltSamplerObj(num_samples, self) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler + def add_child(self, sampler): + self.child_sampler = sampler + + def get_child(self): + return self.child_sampler + + def parse_child(self): + c_child_sampler = None + if self.child_sampler is not None: + c_child_sampler = self.child_sampler.parse() + + return c_child_sampler + def is_shuffled(self): if self.child_sampler is None: return False @@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler): """ def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1): - if num_shards <= 0: - raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards)) + if not isinstance(num_shards, int): + raise ValueError("num_shards must be integer but was: {}.".format(num_shards)) - if shard_id < 0 or shard_id >= num_shards: - raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id)) + if not isinstance(shard_id, int): + raise ValueError("shard_id must be integer but was: {}.".format(shard_id)) if not isinstance(shuffle, bool): raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - - if offset > num_shards: - raise ValueError("offset should be no more than num_shards: {}, " - "but got offset: {}".format(num_shards, offset)) - self.num_shards = num_shards self.shard_id = shard_id self.shuffle = shuffle @@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler): self.offset = offset super().__init__(num_samples) - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 + shuffle = self.shuffle if self.shuffle is not None else True + offset = self.offset if self.offset is not None else -1 # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle self.seed += 1 - c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, - self.shuffle, self.seed, self.offset) - c_child_sampler = self.create_child() + c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id, + shuffle, num_samples, self.seed, offset, True) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler - def create_for_minddataset(self): + def parse_for_minddataset(self): num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed, num_samples, self.offset) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler): """ def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): - if num_val <= 0: - raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val)) + if not isinstance(num_val, int): + raise ValueError("num_val must be integer but was: {}.".format(num_val)) if num_class is not None: raise NotImplementedError("Not supported to specify num_class for PKSampler.") @@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler): if not isinstance(shuffle, bool): raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - self.num_val = num_val self.shuffle = shuffle self.class_column = class_column # work for minddataset super().__init__(num_samples) - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle) - c_child_sampler = self.create_child() + shuffle = self.shuffle if self.shuffle is not None else False + c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler): return self.child_sampler.is_sharded() - def create_for_minddataset(self): + def parse_for_minddataset(self): if not self.class_column or not isinstance(self.class_column, str): raise ValueError("class_column should be a not empty string value, \ but got class_column: {}.".format(class_column)) num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler): if not isinstance(replacement, bool): raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement)) - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - self.deterministic = False self.replacement = replacement self.reshuffle_each_epoch = True super().__init__(num_samples) - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) - c_child_sampler = self.create_child() + replacement = self.replacement if self.replacement is not None else False + c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler - def create_for_minddataset(self): + def parse_for_minddataset(self): num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler): """ def __init__(self, start_index=None, num_samples=None): - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - - if start_index is not None: - if start_index < 0: - raise ValueError("start_index should be a positive integer " - "value or 0, but got start_index: {}.".format(start_index)) - self.start_index = start_index super().__init__(num_samples) - def create(self): + def parse(self): start_index = self.start_index if self.start_index is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.SequentialSampler(num_samples, start_index) - c_child_sampler = self.create_child() + c_sampler = cde.SequentialSamplerObj(start_index, num_samples) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler - def create_for_minddataset(self): + def parse_for_minddataset(self): start_index = self.start_index if self.start_index is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler): """ def __init__(self, indices, num_samples=None): - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - if not isinstance(indices, list): indices = [indices] + for i, item in enumerate(indices): + if not isinstance(item, numbers.Number): + raise TypeError("type of weights element should be number, " + "but got w[{}]: {}, type: {}.".format(i, item, type(item))) + self.indices = indices super().__init__(num_samples) - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.SubsetSampler(num_samples, self.indices) - c_child_sampler = self.create_child() + c_sampler = cde.SubsetSamplerObj(self.indices, num_samples) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler): return self.child_sampler.is_sharded() - def create_for_minddataset(self): + def parse_for_minddataset(self): c_sampler = cde.MindrecordSubsetSampler(self.indices) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler): >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler) """ - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.SubsetRandomSampler(num_samples, self.indices) - c_child_sampler = self.create_child() + c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler def is_shuffled(self): return True - def create_for_minddataset(self): + def parse_for_minddataset(self): c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed()) - c_child_sampler = self.create_child_for_minddataset() + c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler @@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler): raise TypeError("type of weights element should be number, " "but got w[{}]: {}, type: {}.".format(ind, w, type(w))) - if weights == []: - raise ValueError("weights size should not be 0") - - if list(filter(lambda x: x < 0, weights)) != []: - raise ValueError("weights should not contain negative numbers.") - - if list(filter(lambda x: x == 0, weights)) == weights: - raise ValueError("elements of weights should not be all zeros.") - - if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples: {}.".format(num_samples)) - if not isinstance(replacement, bool): raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement)) @@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler): self.replacement = replacement super().__init__(num_samples) - def create(self): + def parse(self): num_samples = self.num_samples if self.num_samples is not None else 0 - c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement) - c_child_sampler = self.create_child() + replacement = self.replacement if self.replacement is not None else True + c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement) + c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler diff --git a/tests/ut/python/dataset/test_datasets_imagefolder.py b/tests/ut/python/dataset/test_datasets_imagefolder.py index e52d38d04dd..aac12537b31 100644 --- a/tests/ut/python/dataset/test_datasets_imagefolder.py +++ b/tests/ut/python/dataset/test_datasets_imagefolder.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception(): weights = (0.9, 0.8, 1.1) ds.WeightedRandomSampler(weights) - error_msg_3 = "weights size should not be 0" - with pytest.raises(ValueError, match=error_msg_3): + error_msg_3 = "WeightedRandomSampler: weights vector must not be empty" + with pytest.raises(RuntimeError, match=error_msg_3): weights = [] - ds.WeightedRandomSampler(weights) + sampler = ds.WeightedRandomSampler(weights) + sampler.parse() - error_msg_4 = "weights should not contain negative numbers" - with pytest.raises(ValueError, match=error_msg_4): + error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: " + with pytest.raises(RuntimeError, match=error_msg_4): weights = [1.0, 0.1, 0.02, 0.3, -0.4] - ds.WeightedRandomSampler(weights) + sampler = ds.WeightedRandomSampler(weights) + sampler.parse() - error_msg_5 = "elements of weights should not be all zero" - with pytest.raises(ValueError, match=error_msg_5): + error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero" + with pytest.raises(RuntimeError, match=error_msg_5): weights = [0, 0, 0, 0, 0] - ds.WeightedRandomSampler(weights) + sampler = ds.WeightedRandomSampler(weights) + sampler.parse() def test_chained_sampler_01(): diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index f654084af58..0d3ccf5e44f 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0(): for partition_id in range(num_shards): data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, num_shards=num_shards, - shard_id=partition_id, num_samples=0) + shard_id=partition_id, num_samples=-1) num_iter = 0 for _ in data_set.create_dict_iterator(num_epochs=1): num_iter += 1 - with pytest.raises(Exception) as error_info: + with pytest.raises(ValueError) as error_info: partitions(5) try: - assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value) + assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value) except Exception as error: os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 7e502849620..83e0ec49d8e 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False): def test_sampler_py_api(): - sampler = ds.SequentialSampler().create() - sampler.set_num_rows(128) - sampler.set_num_samples(64) - sampler.initialize() - sampler.get_indices() - - sampler = ds.RandomSampler().create() - sampler.set_num_rows(128) - sampler.set_num_samples(64) - sampler.initialize() - sampler.get_indices() - - sampler = ds.DistributedSampler(8, 4).create() - sampler.set_num_rows(128) - sampler.set_num_samples(64) - sampler.initialize() - sampler.get_indices() + sampler = ds.SequentialSampler().parse() + sampler1 = ds.RandomSampler().parse() + sampler1.add_child(sampler) def test_python_sampler(): @@ -158,12 +144,6 @@ def test_python_sampler(): assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] test_generator() - sp1 = Sp1().create() - sp1.set_num_rows(5) - sp1.set_num_samples(5) - sp1.initialize() - assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] - def test_sequential_sampler2(): manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" @@ -229,8 +209,8 @@ def test_subset_sampler(): test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]") # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset - test_config([0, 9, 3, 2], num_samples=0, - exception_msg="num_samples should be a positive integer value, but got num_samples: 0.") + test_config([0, 9, 3, 2], num_samples=-1, + exception_msg="SubsetRandomSampler: invalid num_samples: -1") def test_sampler_chain(): @@ -280,9 +260,9 @@ def test_add_sampler_invalid_input(): def test_distributed_sampler_invalid_offset(): - with pytest.raises(ValueError) as info: - sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5) - assert "offset should be no more than num_shards" in str(info.value) + with pytest.raises(RuntimeError) as info: + sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse() + assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 725b4f7eed7..92935cd1603 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -377,7 +377,7 @@ def test_serdes_exception(): def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files): """ Utility function for testing serdes files. It is to check if a json file is indeed created with correct name - after serializing and if it remains the same after repeatly saving and loading. + after serializing and if it remains the same after repeatedly saving and loading. :param data_orig: original data pipeline to be serialized :param filename: filename to be saved as json format :param remove_json_files: whether to remove the json file after testing