Changed bindings to SamplerObj

This commit is contained in:
Mahdi 2020-12-22 17:56:35 -05:00
parent 96f007ebb4
commit 8f34faeb7a
11 changed files with 423 additions and 267 deletions
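At a high level, the commit stops exposing the runtime samplers (SamplerRT) to Python and instead binds the IR-level SamplerObj classes, so the Python sampler classes now build IR objects through parse() rather than runtime objects through create(). A rough before/after sketch of the user-visible shape of the change, based on the diffs below (the surrounding calls are illustrative only):

import mindspore.dataset as ds

# Before: create() returned a runtime sampler (the old cde.Sampler / SamplerRT binding)
# with low-level methods such as set_num_rows() and get_indices().
# sampler_rt = ds.SequentialSampler().create()

# After: parse() returns an IR sampler (cde.SequentialSamplerObj and friends);
# validation and conversion to a runtime sampler happen later, inside the C++ layer.
sampler = ds.SequentialSampler(start_index=0, num_samples=64)
ir_sampler = sampler.parse()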

View File

@ -7,11 +7,11 @@ if(ENABLE_PYTHON)
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
python/bindings/dataset/engine/datasetops/source/bindings.cc
-python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
python/bindings/dataset/engine/gnn/bindings.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/execute_binding.cc
+python/bindings/dataset/include/sampler_bindings.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc

View File

@ -1,93 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
.def("set_num_rows",
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples",
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices",
[](SamplerRT &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
THROW_IF_ERROR(self->AddChild(child));
});
}));
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>());
}));
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>());
}));
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>());
}));
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
*m, "SequentialSampler")
.def(py::init<int64_t, int64_t>());
}));
PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
*m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());
}));
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
(void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
.def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
THROW_IF_ERROR(self->AddChildSampler(child));
});
}));
PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
*m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
.def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist) {
std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
*m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
.def(py::init([](int64_t num_samples, py::object sampler) {
auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
THROW_IF_ERROR(sampler_obj->ValidateParams());
return sampler_obj;
}));
}));
PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
"to create a PKSamplerObj")
.def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
std::shared_ptr<PKSamplerObj> sampler =
std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
*m, "RandomSamplerObj", "to create a RandomSamplerObj")
.def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
std::shared_ptr<RandomSamplerObj> sampler =
std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
*m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
.def(py::init([](int64_t start_index, int64_t num_samples) {
std::shared_ptr<SequentialSamplerObj> sampler =
std::make_shared<SequentialSamplerObj>(start_index, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
*m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
std::shared_ptr<SubsetSamplerObj> sampler =
std::make_shared<SubsetSamplerObj>(indices, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
*m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
std::shared_ptr<SubsetRandomSamplerObj> sampler =
std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
*m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
.def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
std::shared_ptr<WeightedRandomSamplerObj> sampler =
std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
} // namespace dataset
} // namespace mindspore
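Each py::init lambda above forwards the Python arguments to the matching SamplerObj constructor and calls ValidateParams() before returning, so bad arguments surface as a RuntimeError at construction time. A minimal sketch of driving these bindings directly, assuming the usual mindspore._c_dataengine import alias; in practice the samplers in mindspore.dataset call them from their parse() methods:

import mindspore._c_dataengine as cde

# Argument order follows the py::init lambdas above:
# (num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist)
dist = cde.DistributedSamplerObj(4, 0, True, 0, 0, -1, True)
seq = cde.SequentialSamplerObj(0, 0)  # (start_index, num_samples)
dist.add_child(seq)                   # routed to SamplerObj::AddChildSampler()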

View File

@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
  std::shared_ptr<SamplerObj> sampler_obj;
  if (!isMindDataset) {
    // Common Sampler
-    std::shared_ptr<SamplerRT> sampler;
-    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
-    sampler = create().cast<std::shared_ptr<SamplerRT>>();
-    sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
+    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
+    sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
  } else {
    // Mindrecord Sampler
    std::shared_ptr<mindrecord::ShardOperator> sampler;
-    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
-    sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset");
+    sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>();
    sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
  }
  return sampler_obj;
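With toSamplerObj() now calling parse(), a user-defined Python sampler reaches the C++ layer wrapped as a PreBuiltSamplerObj around a PythonSamplerRT rather than as a bare runtime sampler. A hypothetical custom sampler to illustrate the flow (class name and index logic are made up):

import mindspore.dataset as ds

class EvenIndexSampler(ds.Sampler):  # hypothetical user-defined sampler
    def __iter__(self):
        return iter(range(0, self.dataset_size, 2))

sampler = EvenIndexSampler()
ir_sampler = sampler.parse()  # a cde.PreBuiltSamplerObj wrapping a PythonSamplerRT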

View File

@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
}
#endif
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
    : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() {
  return Status::OK();
}
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
  // runtime sampler object
  auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
  return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
  return Status::OK();
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// RandomSampler
-RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
-    : replacement_(replacement), num_samples_(num_samples) {}
+RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
+    : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
Status RandomSamplerObj::ValidateParams() {
  if (num_samples_ < 0) {
@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() {
  return Status::OK();
}
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
  // runtime sampler object
-  bool reshuffle_each_epoch = true;
-  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
+  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
  BuildChildren(sampler);
  return sampler;
}
@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
  // runtime mindrecord sampler object
-  bool reshuffle_each_epoch_ = true;
  auto mind_sampler =
    std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() {
  return Status::OK();
}
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
  // runtime sampler object
  auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
  return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
}
#endif
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
    : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() {
  return Status::OK();
}
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
  auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
  BuildChildren(sampler);
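The to_json() overrides added above all follow the same pattern: write the constructor arguments under a "sampler_name" key and, if child samplers are attached, nest their JSON under "child_sampler". For a DistributedSampler with a SequentialSampler child, the serialized form would look roughly like this (Python-dict notation, values illustrative):

expected = {
    "sampler_name": "DistributedSampler",
    "num_shards": 4,
    "shard_id": 0,
    "shuffle": True,
    "num_samples": 64,
    "offset": -1,
    "child_sampler": [
        {"sampler_name": "SequentialSampler", "start_index": 0, "num_samples": 0},
    ],
}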

View File

@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,6 +66,8 @@ class SamplerObj {
  virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
+  std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID
  /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
  ///     only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
  /// \brief Function to get the shard id of sampler
@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
 private:
@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj {
class RandomSamplerObj : public SamplerObj {
 public:
-  RandomSamplerObj(bool replacement, int64_t num_samples);
+  RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
  ~RandomSamplerObj() = default;
  std::shared_ptr<SamplerRT> SamplerBuild() override;
  std::shared_ptr<SamplerObj> SamplerCopy() override {
-    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
+    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
    for (auto child : children_) {
      sampler->AddChildSampler(child);
    }
@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
 private:
  bool replacement_;
  int64_t num_samples_;
+  bool reshuffle_each_epoch_;
};
class SequentialSamplerObj : public SamplerObj {
@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
 private:
@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
 protected:
@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
  ~SubsetRandomSamplerObj() = default;
+  Status to_json(nlohmann::json *out_json) override;
  std::shared_ptr<SamplerRT> SamplerBuild() override;
  std::shared_ptr<SamplerObj> SamplerCopy() override {
@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj {
    return sampler;
  }
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
  Status ValidateParams() override;
 private:

View File

@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ class BuiltinSampler:
        self.child_sampler = None
        self.num_samples = num_samples

-    def create(self):
+    def parse(self):
        pass

    def add_child(self, sampler):
@ -59,16 +59,16 @@ class BuiltinSampler:
    def get_child(self):
        return self.child_sampler

-    def create_child(self):
+    def parse_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
-            c_child_sampler = self.child_sampler.create()
+            c_child_sampler = self.child_sampler.parse()
        return c_child_sampler

-    def create_child_for_minddataset(self):
+    def parse_child_for_minddataset(self):
        c_child_sampler = None
        if self.child_sampler is not None:
-            c_child_sampler = self.child_sampler.create_for_minddataset()
+            c_child_sampler = self.child_sampler.parse_for_minddataset()
        return c_child_sampler

    def is_shuffled(self):
@ -158,6 +158,8 @@ class Sampler(BuiltinSampler):
    def __init__(self, num_samples=None):
        super().__init__(num_samples)
        self.dataset_size = 0
+        self.child_sampler = None
+        self.num_samples = num_samples

    def __iter__(self):
        """
@ -192,13 +194,26 @@
    # Instance fetcher
    # Do not override this method!
-    def create(self):
+    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.PythonSampler(num_samples, self)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
+        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

+    def add_child(self, sampler):
+        self.child_sampler = sampler
+
+    def get_child(self):
+        return self.child_sampler
+
+    def parse_child(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.parse()
+        return c_child_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False
@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler):
""" """
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1): def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
if num_shards <= 0: if not isinstance(num_shards, int):
raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards)) raise ValueError("num_shards must be integer but was: {}.".format(num_shards))
if shard_id < 0 or shard_id >= num_shards: if not isinstance(shard_id, int):
raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id)) raise ValueError("shard_id must be integer but was: {}.".format(shard_id))
if not isinstance(shuffle, bool): if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if offset > num_shards:
raise ValueError("offset should be no more than num_shards: {}, "
"but got offset: {}".format(num_shards, offset))
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
self.shuffle = shuffle self.shuffle = shuffle
@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler):
self.offset = offset self.offset = offset
super().__init__(num_samples) super().__init__(num_samples)
def create(self): def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
shuffle = self.shuffle if self.shuffle is not None else True
offset = self.offset if self.offset is not None else -1
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
self.seed += 1 self.seed += 1
c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
self.shuffle, self.seed, self.offset) shuffle, num_samples, self.seed, offset, True)
c_child_sampler = self.create_child() c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
def create_for_minddataset(self): def parse_for_minddataset(self):
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
self.seed, num_samples, self.offset) self.seed, num_samples, self.offset)
c_child_sampler = self.create_child_for_minddataset() c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler):
""" """
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
if num_val <= 0: if not isinstance(num_val, int):
raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val)) raise ValueError("num_val must be integer but was: {}.".format(num_val))
if num_class is not None: if num_class is not None:
raise NotImplementedError("Not supported to specify num_class for PKSampler.") raise NotImplementedError("Not supported to specify num_class for PKSampler.")
@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler):
if not isinstance(shuffle, bool): if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
self.num_val = num_val self.num_val = num_val
self.shuffle = shuffle self.shuffle = shuffle
self.class_column = class_column # work for minddataset self.class_column = class_column # work for minddataset
super().__init__(num_samples) super().__init__(num_samples)
def create(self): def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle) shuffle = self.shuffle if self.shuffle is not None else False
c_child_sampler = self.create_child() c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler):
return self.child_sampler.is_sharded() return self.child_sampler.is_sharded()
def create_for_minddataset(self): def parse_for_minddataset(self):
if not self.class_column or not isinstance(self.class_column, str): if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \ raise ValueError("class_column should be a not empty string value, \
but got class_column: {}.".format(class_column)) but got class_column: {}.".format(class_column))
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
c_child_sampler = self.create_child_for_minddataset() c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler):
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

-    def create(self):
+    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
-        c_child_sampler = self.create_child()
+        replacement = self.replacement if self.replacement is not None else False
+        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
+        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler):
""" """
def __init__(self, start_index=None, num_samples=None): def __init__(self, start_index=None, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if start_index is not None:
if start_index < 0:
raise ValueError("start_index should be a positive integer "
"value or 0, but got start_index: {}.".format(start_index))
self.start_index = start_index self.start_index = start_index
super().__init__(num_samples) super().__init__(num_samples)
def create(self): def parse(self):
start_index = self.start_index if self.start_index is not None else 0 start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SequentialSampler(num_samples, start_index) c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
c_child_sampler = self.create_child() c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
def create_for_minddataset(self): def parse_for_minddataset(self):
start_index = self.start_index if self.start_index is not None else 0 start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index) c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
c_child_sampler = self.create_child_for_minddataset() c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler):
""" """
def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if not isinstance(indices, list): if not isinstance(indices, list):
indices = [indices] indices = [indices]
for i, item in enumerate(indices):
if not isinstance(item, numbers.Number):
raise TypeError("type of weights element should be number, "
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
self.indices = indices self.indices = indices
super().__init__(num_samples) super().__init__(num_samples)
def create(self): def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetSampler(num_samples, self.indices) c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
c_child_sampler = self.create_child() c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded() return self.child_sampler.is_sharded()
def create_for_minddataset(self): def parse_for_minddataset(self):
c_sampler = cde.MindrecordSubsetSampler(self.indices) c_sampler = cde.MindrecordSubsetSampler(self.indices)
c_child_sampler = self.create_child_for_minddataset() c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler):
        >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
        """

-    def create(self):
+    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
+        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
        c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler):
raise TypeError("type of weights element should be number, " raise TypeError("type of weights element should be number, "
"but got w[{}]: {}, type: {}.".format(ind, w, type(w))) "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))
if weights == []:
raise ValueError("weights size should not be 0")
if list(filter(lambda x: x < 0, weights)) != []:
raise ValueError("weights should not contain negative numbers.")
if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zeros.")
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples: {}.".format(num_samples))
if not isinstance(replacement, bool): if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement)) raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))
@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler):
self.replacement = replacement self.replacement = replacement
super().__init__(num_samples) super().__init__(num_samples)
def create(self): def parse(self):
num_samples = self.num_samples if self.num_samples is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement) replacement = self.replacement if self.replacement is not None else True
c_child_sampler = self.create_child() c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler) c_sampler.add_child(c_child_sampler)
return c_sampler return c_sampler
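From the user's point of view the sampler API is unchanged; only the parse()/parse_for_minddataset() plumbing behind it is new. A short usage sketch (the dataset path is a placeholder):

import mindspore.dataset as ds

sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False)
sampler.add_child(ds.SequentialSampler(start_index=0, num_samples=8))

# The dataset calls sampler.parse() internally and receives the SamplerObj tree.
data = ds.ImageFolderDataset("/path/to/dataset", sampler=sampler)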

View File

@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception():
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)

-    error_msg_3 = "weights size should not be 0"
-    with pytest.raises(ValueError, match=error_msg_3):
+    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
+    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()

-    error_msg_4 = "weights should not contain negative numbers"
-    with pytest.raises(ValueError, match=error_msg_4):
+    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
+    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()

-    error_msg_5 = "elements of weights should not be all zero"
-    with pytest.raises(ValueError, match=error_msg_5):
+    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
+    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()
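As the updated test shows, much of the argument validation now runs in C++ when the sampler is parsed, so the failure moves from a ValueError raised by the Python constructor to a RuntimeError raised by ValidateParams(). Roughly:

import mindspore.dataset as ds

sampler = ds.WeightedRandomSampler([])  # no longer raises here
try:
    sampler.parse()                     # ValidateParams() rejects the empty weights vector
except RuntimeError as e:
    print(e)                            # "WeightedRandomSampler: weights vector must not be empty"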
def test_chained_sampler_01():

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0():
    for partition_id in range(num_shards):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                  num_shards=num_shards,
-                                 shard_id=partition_id, num_samples=0)
+                                 shard_id=partition_id, num_samples=-1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1

-    with pytest.raises(Exception) as error_info:
+    with pytest.raises(ValueError) as error_info:
        partitions(5)
    try:
-        assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
+        assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))

View File

@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False):
def test_sampler_py_api():
-    sampler = ds.SequentialSampler().create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
-    sampler = ds.RandomSampler().create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
-    sampler = ds.DistributedSampler(8, 4).create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
+    sampler = ds.SequentialSampler().parse()
+    sampler1 = ds.RandomSampler().parse()
+    sampler1.add_child(sampler)


def test_python_sampler():
@ -158,12 +144,6 @@ def test_python_sampler():
    assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()
-    sp1 = Sp1().create()
-    sp1.set_num_rows(5)
-    sp1.set_num_samples(5)
-    sp1.initialize()
-    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


def test_sequential_sampler2():
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@ -229,8 +209,8 @@ def test_subset_sampler():
    test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
    test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
    # test_config([], exception_msg="Indices list is empty")  # temporary until we check with MindDataset
-    test_config([0, 9, 3, 2], num_samples=0,
-                exception_msg="num_samples should be a positive integer value, but got num_samples: 0.")
+    test_config([0, 9, 3, 2], num_samples=-1,
+                exception_msg="SubsetRandomSampler: invalid num_samples: -1")


def test_sampler_chain():
@ -280,9 +260,9 @@ def test_add_sampler_invalid_input():
def test_distributed_sampler_invalid_offset():
-    with pytest.raises(ValueError) as info:
-        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
-    assert "offset should be no more than num_shards" in str(info.value)
+    with pytest.raises(RuntimeError) as info:
+        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
+    assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)


if __name__ == '__main__':

View File

@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -377,7 +377,7 @@ def test_serdes_exception():
def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
    """
    Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
-    after serializing and if it remains the same after repeatly saving and loading.
+    after serializing and if it remains the same after repeatedly saving and loading.
    :param data_orig: original data pipeline to be serialized
    :param filename: filename to be saved as json format
    :param remove_json_files: whether to remove the json file after testing
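The new SamplerObj::to_json() methods presumably feed the pipeline serialization path that this test file exercises; a hedged sketch of the round trip (paths and sampler arguments are illustrative):

import mindspore.dataset as ds

data = ds.ImageFolderDataset("/path/to/dataset",
                             sampler=ds.RandomSampler(replacement=False, num_samples=4))
ds.serialize(data, "pipeline.json")                    # sampler attributes written via to_json()
data2 = ds.deserialize(json_filepath="pipeline.json")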