!11784 Change samplers binding in Python to SamplerObj

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-30 10:16:30 +08:00 committed by Gitee
commit 5f0f9da6c6
11 changed files with 423 additions and 267 deletions

View File

@ -7,11 +7,11 @@ if(ENABLE_PYTHON)
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
python/bindings/dataset/engine/datasetops/source/bindings.cc
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
python/bindings/dataset/engine/gnn/bindings.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/execute_binding.cc
python/bindings/dataset/include/sampler_bindings.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc

View File

@ -1,93 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// NOTE(review): this translation unit is deleted by this commit. It exposed the
// *runtime* sampler classes (SamplerRT and subclasses) directly to Python; the
// commit replaces it with bindings over the IR-level SamplerObj classes (see
// sampler_bindings.cc added in the same change).
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
namespace mindspore {
namespace dataset {
// Base runtime sampler binding. Exposes row/sample-count setters, explicit
// initialization, bulk index fetching, and child-sampler chaining to Python.
// The second PYBIND_REGISTER argument appears to be a registration-order
// priority (base classes before derived) — confirm against the macro's
// definition in pybind_register.h.
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
  (void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
    .def("set_num_rows",
         [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
    .def("set_num_samples",
         [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
    .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
    // Drains every index the sampler would produce into a numpy array, then
    // resets the sampler (see GetAllIdsThenReset).
    .def("get_indices",
         [](SamplerRT &self) {
           py::array ret;
           THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
           return ret;
         })
    .def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
      THROW_IF_ERROR(self->AddChild(child));
    });
}));
// ctor args: num_shards, shard_id, shuffle, seed, num_samples, offset (order
// per the DistributedSamplerRT constructor — verify against its header).
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
    *m, "DistributedSampler")
    .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
    .def(py::init<int64_t, int64_t, bool>());
}));
// Wraps a user-defined Python sampler object (second ctor arg) in a runtime sampler.
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
    .def(py::init<int64_t, py::object>());
}));
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
    .def(py::init<int64_t, bool, bool>());
}));
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
    *m, "SequentialSampler")
    .def(py::init<int64_t, int64_t>());
}));
// Priority 2: derives from SubsetSamplerRT (registered at priority 1 below),
// so it must bind after its base.
PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
  (void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
    *m, "SubsetRandomSampler")
    .def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
    .def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
    *m, "WeightedRandomSampler")
    .def(py::init<int64_t, std::vector<double>, bool>());
}));
}  // namespace dataset
}  // namespace mindspore

View File

@ -0,0 +1,127 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Python bindings for the IR-level sampler classes (SamplerObj and subclasses).
// Each concrete binding wraps its C++ constructor in a py::init factory that
// calls ValidateParams() and throws via THROW_IF_ERROR, so invalid arguments
// surface as a Python exception at construction time rather than later at
// pipeline build time.
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
// Base IR sampler. Only child chaining is exposed; concrete parameters live on
// the subclasses. The second PYBIND_REGISTER argument appears to be a
// registration-order priority (base before derived) — confirm against the
// PYBIND_REGISTER definition.
PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
  (void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
    .def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
      THROW_IF_ERROR(self->AddChildSampler(child));
    });
}));
PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
    *m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
    .def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
                     uint32_t seed, int64_t offset, bool even_dist) {
      std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
        num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
// Bridges a user-defined Python sampler: the py::object is wrapped in a
// runtime PythonSamplerRT, which is then carried by a PreBuiltSamplerObj.
PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
    *m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
    .def(py::init([](int64_t num_samples, py::object sampler) {
      auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
      auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
      THROW_IF_ERROR(sampler_obj->ValidateParams());
      return sampler_obj;
    }));
}));
PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
                                                                            "to create a PKSamplerObj")
    .def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
      std::shared_ptr<PKSamplerObj> sampler =
        std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
    *m, "RandomSamplerObj", "to create a RandomSamplerObj")
    .def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
      std::shared_ptr<RandomSamplerObj> sampler =
        std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
    *m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
    .def(py::init([](int64_t start_index, int64_t num_samples) {
      std::shared_ptr<SequentialSamplerObj> sampler =
        std::make_shared<SequentialSamplerObj>(start_index, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
    *m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
      std::shared_ptr<SubsetSamplerObj> sampler =
        std::make_shared<SubsetSamplerObj>(indices, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
// Priority 3: derives from SubsetSamplerObj (priority 2 above), so it must
// register after its base.
PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
  (void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
    *m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
      std::shared_ptr<SubsetRandomSamplerObj> sampler =
        std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
    *m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
    .def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
      std::shared_ptr<WeightedRandomSamplerObj> sampler =
        std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
}  // namespace dataset
}  // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset");
sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
return sampler_obj;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
}
#endif
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "DistributedSampler", ...}
//     plus an optional "child_sampler" array with each child serialized
//     recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "DistributedSampler";
  args["num_shards"] = num_shards_;
  args["shard_id"] = shard_id_;
  args["shuffle"] = shuffle_;
  args["num_samples"] = num_samples_;
  args["offset"] = offset_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() {
return Status::OK();
}
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "PKSampler", ...} plus an
//     optional "child_sampler" array with each child serialized recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "PKSampler";
  args["num_val"] = num_val_;
  args["shuffle"] = shuffle_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
return sampler;
}
#ifndef ENABLE_ANDROID
// Convert this PKSamplerObj into a runtime mindrecord sampler.
// When shuffling, a seeded ShardPkSample with an unbounded per-category cap is
// built; otherwise a plain ShardPkSample over the "label" category.
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
  // runtime mindrecord sampler object
  std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
  if (shuffle_) {  // idiom: test the bool directly instead of `== true`
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
                                                               GetSeed(), num_samples_);
  } else {
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
  }
  return mind_sampler;
}
#endif
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
: replacement_(replacement), num_samples_(num_samples) {}
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() {
return Status::OK();
}
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "RandomSampler", ...} plus an
//     optional "child_sampler" array with each child serialized recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "RandomSampler";
  args["replacement"] = replacement_;
  args["num_samples"] = num_samples_;
  args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}
@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
bool reshuffle_each_epoch_ = true;
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() {
return Status::OK();
}
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "SequentialSampler", ...}
//     plus an optional "child_sampler" array with each child serialized
//     recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SequentialSampler";
  args["start_index"] = start_index_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
return mind_sampler;
}
#endif
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "SubsetSampler", ...} plus an
//     optional "child_sampler" array with each child serialized recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SubsetSampler";
  args["indices"] = indices_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
}
#endif
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "SubsetRandomSampler", ...}
//     plus an optional "child_sampler" array with each child serialized
//     recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SubsetRandomSampler";
  args["indices"] = indices_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() {
return Status::OK();
}
// Serialize this sampler's construction arguments into JSON.
// \param[out] out_json Receives {"sampler_name": "WeightedRandomSampler", ...}
//     plus an optional "child_sampler" array with each child serialized
//     recursively.
// \return Status::OK(), or the first failure from a child's to_json().
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "WeightedRandomSampler";
  args["weights"] = weights_;
  args["num_samples"] = num_samples_;
  args["replacement"] = replacement_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    children_args.reserve(children_.size());
    // const-ref iteration: copying each shared_ptr costs atomic refcount ops.
    for (const auto &child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(std::move(child_arg));
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,6 +66,8 @@ class SamplerObj {
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj {
class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples);
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
~RandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};
class SequentialSamplerObj : public SamplerObj {
@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
protected:
@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
~SubsetRandomSamplerObj() = default;
Status to_json(nlohmann::json *out_json) override;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj {
return sampler;
}
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ class BuiltinSampler:
self.child_sampler = None
self.num_samples = num_samples
def create(self):
def parse(self):
pass
def add_child(self, sampler):
@ -59,16 +59,16 @@ class BuiltinSampler:
def get_child(self):
return self.child_sampler
def create_child(self):
def parse_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
c_child_sampler = self.child_sampler.parse()
return c_child_sampler
def create_child_for_minddataset(self):
def parse_child_for_minddataset(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create_for_minddataset()
c_child_sampler = self.child_sampler.parse_for_minddataset()
return c_child_sampler
def is_shuffled(self):
@ -158,6 +158,8 @@ class Sampler(BuiltinSampler):
def __init__(self, num_samples=None):
    # Initialize user-defined-sampler state on top of BuiltinSampler.
    super().__init__(num_samples)
    self.dataset_size = 0
    # NOTE(review): BuiltinSampler.__init__ already sets child_sampler to None
    # and stores num_samples (see its definition earlier in this file), so the
    # two assignments below look redundant — confirm before removing.
    self.child_sampler = None
    self.num_samples = num_samples
def __iter__(self):
"""
@ -192,13 +194,26 @@ class Sampler(BuiltinSampler):
# Instance fetcher
# Do not override this method!
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PythonSampler(num_samples, self)
c_child_sampler = self.create_child()
c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def add_child(self, sampler):
    # Attach a child sampler; it is chained below this sampler when parsed.
    self.child_sampler = sampler

def get_child(self):
    # Return the currently attached child sampler (None when absent).
    return self.child_sampler

def parse_child(self):
    # Parse the attached child into its C++ SamplerObj, or None if no child.
    c_child_sampler = None
    if self.child_sampler is not None:
        c_child_sampler = self.child_sampler.parse()
    return c_child_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler):
"""
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
if num_shards <= 0:
raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards))
if not isinstance(num_shards, int):
raise ValueError("num_shards must be integer but was: {}.".format(num_shards))
if shard_id < 0 or shard_id >= num_shards:
raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id))
if not isinstance(shard_id, int):
raise ValueError("shard_id must be integer but was: {}.".format(shard_id))
if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if offset > num_shards:
raise ValueError("offset should be no more than num_shards: {}, "
"but got offset: {}".format(num_shards, offset))
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle = shuffle
@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler):
self.offset = offset
super().__init__(num_samples)
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
shuffle = self.shuffle if self.shuffle is not None else True
offset = self.offset if self.offset is not None else -1
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
self.seed += 1
c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id,
self.shuffle, self.seed, self.offset)
c_child_sampler = self.create_child()
c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
shuffle, num_samples, self.seed, offset, True)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
def parse_for_minddataset(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
self.seed, num_samples, self.offset)
c_child_sampler = self.create_child_for_minddataset()
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler):
"""
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
if num_val <= 0:
raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val))
if not isinstance(num_val, int):
raise ValueError("num_val must be integer but was: {}.".format(num_val))
if num_class is not None:
raise NotImplementedError("Not supported to specify num_class for PKSampler.")
@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler):
if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
self.num_val = num_val
self.shuffle = shuffle
self.class_column = class_column # work for minddataset
super().__init__(num_samples)
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
c_child_sampler = self.create_child()
shuffle = self.shuffle if self.shuffle is not None else False
c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def create_for_minddataset(self):
def parse_for_minddataset(self):
    """Parse this PKSampler into a C++ mindrecord sampler object.

    Returns:
        A cde.MindrecordPkSampler with any child sampler attached.

    Raises:
        ValueError: If class_column is not a non-empty string.
    """
    if not self.class_column or not isinstance(self.class_column, str):
        # Bug fix: the original formatted the bare name `class_column`, which
        # is undefined in this scope and raised NameError instead of the
        # intended ValueError.
        raise ValueError("class_column should be a not empty string value, "
                         "but got class_column: {}.".format(self.class_column))
    num_samples = self.num_samples if self.num_samples is not None else 0
    c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
    c_child_sampler = self.parse_child_for_minddataset()
    c_sampler.add_child(c_child_sampler)
    return c_sampler
@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler):
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
self.deterministic = False
self.replacement = replacement
self.reshuffle_each_epoch = True
super().__init__(num_samples)
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.create_child()
replacement = self.replacement if self.replacement is not None else False
c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
def parse_for_minddataset(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.create_child_for_minddataset()
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler):
"""
def __init__(self, start_index=None, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if start_index is not None:
if start_index < 0:
raise ValueError("start_index should be a positive integer "
"value or 0, but got start_index: {}.".format(start_index))
self.start_index = start_index
super().__init__(num_samples)
def create(self):
def parse(self):
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SequentialSampler(num_samples, start_index)
c_child_sampler = self.create_child()
c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
def parse_for_minddataset(self):
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
c_child_sampler = self.create_child_for_minddataset()
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler):
"""
def __init__(self, indices, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if not isinstance(indices, list):
indices = [indices]
for i, item in enumerate(indices):
if not isinstance(item, numbers.Number):
raise TypeError("type of weights element should be number, "
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
self.indices = indices
super().__init__(num_samples)
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetSampler(num_samples, self.indices)
c_child_sampler = self.create_child()
c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def create_for_minddataset(self):
def parse_for_minddataset(self):
c_sampler = cde.MindrecordSubsetSampler(self.indices)
c_child_sampler = self.create_child_for_minddataset()
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler):
>>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
"""
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
c_child_sampler = self.create_child()
c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def create_for_minddataset(self):
def parse_for_minddataset(self):
c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
c_child_sampler = self.create_child_for_minddataset()
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler):
raise TypeError("type of weights element should be number, "
"but got w[{}]: {}, type: {}.".format(ind, w, type(w)))
if weights == []:
raise ValueError("weights size should not be 0")
if list(filter(lambda x: x < 0, weights)) != []:
raise ValueError("weights should not contain negative numbers.")
if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zeros.")
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))
@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler):
self.replacement = replacement
super().__init__(num_samples)
def create(self):
def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
c_child_sampler = self.create_child()
replacement = self.replacement if self.replacement is not None else True
c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception():
weights = (0.9, 0.8, 1.1)
ds.WeightedRandomSampler(weights)
error_msg_3 = "weights size should not be 0"
with pytest.raises(ValueError, match=error_msg_3):
error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
with pytest.raises(RuntimeError, match=error_msg_3):
weights = []
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_4 = "weights should not contain negative numbers"
with pytest.raises(ValueError, match=error_msg_4):
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
with pytest.raises(RuntimeError, match=error_msg_4):
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_5 = "elements of weights should not be all zero"
with pytest.raises(ValueError, match=error_msg_5):
error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
with pytest.raises(RuntimeError, match=error_msg_5):
weights = [0, 0, 0, 0, 0]
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
def test_chained_sampler_01():

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0():
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=0)
shard_id=partition_id, num_samples=-1)
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
with pytest.raises(Exception) as error_info:
with pytest.raises(ValueError) as error_info:
partitions(5)
try:
assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False):
def test_sampler_py_api():
sampler = ds.SequentialSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.RandomSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.DistributedSampler(8, 4).create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.SequentialSampler().parse()
sampler1 = ds.RandomSampler().parse()
sampler1.add_child(sampler)
def test_python_sampler():
@ -158,12 +144,6 @@ def test_python_sampler():
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
test_generator()
sp1 = Sp1().create()
sp1.set_num_rows(5)
sp1.set_num_samples(5)
sp1.initialize()
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_sequential_sampler2():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@ -229,8 +209,8 @@ def test_subset_sampler():
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
test_config([0, 9, 3, 2], num_samples=0,
exception_msg="num_samples should be a positive integer value, but got num_samples: 0.")
test_config([0, 9, 3, 2], num_samples=-1,
exception_msg="SubsetRandomSampler: invalid num_samples: -1")
def test_sampler_chain():
@ -280,9 +260,9 @@ def test_add_sampler_invalid_input():
def test_distributed_sampler_invalid_offset():
with pytest.raises(ValueError) as info:
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
assert "offset should be no more than num_shards" in str(info.value)
with pytest.raises(RuntimeError) as info:
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
if __name__ == '__main__':

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -377,7 +377,7 @@ def test_serdes_exception():
def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
"""
Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
after serializing and if it remains the same after repeatly saving and loading.
after serializing and if it remains the same after repeatedly saving and loading.
:param data_orig: original data pipeline to be serialized
:param filename: filename to be saved as json format
:param remove_json_files: whether to remove the json file after testing