!11784 Change samplers binding in Python to SamplerObj
From: @mahdirahmanihanzaki
Commit: 5f0f9da6c6
@@ -7,11 +7,11 @@ if(ENABLE_PYTHON)
        python/bindings/dataset/engine/cache/bindings.cc
        python/bindings/dataset/engine/datasetops/bindings.cc
        python/bindings/dataset/engine/datasetops/source/bindings.cc
        python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
        python/bindings/dataset/engine/gnn/bindings.cc
        python/bindings/dataset/include/datasets_bindings.cc
        python/bindings/dataset/include/iterator_bindings.cc
        python/bindings/dataset/include/execute_binding.cc
        python/bindings/dataset/include/sampler_bindings.cc
        python/bindings/dataset/include/schema_bindings.cc
        python/bindings/dataset/kernels/bindings.cc
        python/bindings/dataset/kernels/data/bindings.cc
@@ -1,93 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"

namespace mindspore {
namespace dataset {

PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
  (void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
    .def("set_num_rows",
         [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
    .def("set_num_samples",
         [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
    .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
    .def("get_indices",
         [](SamplerRT &self) {
           py::array ret;
           THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
           return ret;
         })
    .def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
      THROW_IF_ERROR(self->AddChild(child));
    });
}));

PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
    *m, "DistributedSampler")
    .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));

PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
    .def(py::init<int64_t, int64_t, bool>());
}));

PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
    .def(py::init<int64_t, py::object>());
}));

PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
    .def(py::init<int64_t, bool, bool>());
}));

PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
    *m, "SequentialSampler")
    .def(py::init<int64_t, int64_t>());
}));

PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
  (void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
    *m, "SubsetRandomSampler")
    .def(py::init<int64_t, std::vector<int64_t>>());
}));

PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
    .def(py::init<int64_t, std::vector<int64_t>>());
}));

PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
  (void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
    *m, "WeightedRandomSampler")
    .def(py::init<int64_t, std::vector<double>, bool>());
}));

} // namespace dataset
} // namespace mindspore
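For reference, this is roughly how the removed runtime "Sampler" bindings were driven from Python before this change (taken from the test code that is also deleted further below); this style of call no longer works after this commit and is shown only to contrast with the new SamplerObj path.

import mindspore.dataset as ds

sampler = ds.SequentialSampler().create()   # old API: create() returned the runtime SamplerRT binding
sampler.set_num_rows(128)                   # runtime-level setters exposed by the removed binding
sampler.set_num_samples(64)
sampler.initialize()
indices = sampler.get_indices()             # materialized sample ids as a numpy array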
@@ -0,0 +1,127 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
  (void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
    .def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
      THROW_IF_ERROR(self->AddChildSampler(child));
    });
}));

PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
    *m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
    .def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
                     uint32_t seed, int64_t offset, bool even_dist) {
      std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
        num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
    *m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
    .def(py::init([](int64_t num_samples, py::object sampler) {
      auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
      auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
      THROW_IF_ERROR(sampler_obj->ValidateParams());
      return sampler_obj;
    }));
}));

PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
                                                                            "to create a PKSamplerObj")
    .def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
      std::shared_ptr<PKSamplerObj> sampler =
        std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
    *m, "RandomSamplerObj", "to create a RandomSamplerObj")
    .def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
      std::shared_ptr<RandomSamplerObj> sampler =
        std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
    *m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
    .def(py::init([](int64_t start_index, int64_t num_samples) {
      std::shared_ptr<SequentialSamplerObj> sampler =
        std::make_shared<SequentialSamplerObj>(start_index, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
    *m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
      std::shared_ptr<SubsetSamplerObj> sampler =
        std::make_shared<SubsetSamplerObj>(indices, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
  (void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
    *m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
      std::shared_ptr<SubsetRandomSamplerObj> sampler =
        std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));

PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
  (void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
    *m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
    .def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
      std::shared_ptr<WeightedRandomSamplerObj> sampler =
        std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
      THROW_IF_ERROR(sampler->ValidateParams());
      return sampler;
    }));
}));
} // namespace dataset
} // namespace mindspore
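A minimal sketch of how these new *SamplerObj bindings are driven from the Python layer (mirroring the parse() methods further below). The module alias cde refers to mindspore._c_dataengine as used inside mindspore.dataset; the argument values are illustrative only.

import mindspore._c_dataengine as cde

num_samples = 0  # 0 means "use all samples", matching the Python layer's default
pk = cde.PKSamplerObj(3, False, num_samples)    # (num_val, shuffle, num_samples); ValidateParams() runs in the init
seq = cde.SequentialSamplerObj(0, num_samples)  # (start_index, num_samples)
pk.add_child(seq)                               # chaining now happens on IR-level SamplerObj nodes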
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
  std::shared_ptr<SamplerObj> sampler_obj;
  if (!isMindDataset) {
    // Common Sampler
    std::shared_ptr<SamplerRT> sampler;
    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
    sampler = create().cast<std::shared_ptr<SamplerRT>>();
    sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
    sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
  } else {
    // Mindrecord Sampler
    std::shared_ptr<mindrecord::ShardOperator> sampler;
    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
    sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset");
    sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>();
    sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
  }
  return sampler_obj;
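In other words, the C++ conversion no longer builds a runtime sampler itself: it calls the Python object's parse() (or parse_for_minddataset()) and casts the result directly. A hedged sketch of that contract from the Python side; values are illustrative.

import mindspore.dataset as ds

sampler_obj = ds.RandomSampler(num_samples=8).parse()              # now returns a cde SamplerObj node
shard_op = ds.RandomSampler(num_samples=8).parse_for_minddataset() # MindDataset path still yields a ShardOperator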
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
}
#endif

Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "DistributedSampler";
  args["num_shards"] = num_shards_;
  args["shard_id"] = shard_id_;
  args["shuffle"] = shuffle_;
  args["num_samples"] = num_samples_;
  args["offset"] = offset_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
    : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}

@@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() {
  return Status::OK();
}

Status PKSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "PKSampler";
  args["num_val"] = num_val_;
  args["shuffle"] = shuffle_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
  // runtime sampler object
  auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);

@@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
  return sampler;
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
  // runtime mindrecord sampler object
  std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
  if (shuffle_ == true) {
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
                                                               GetSeed(), num_samples_);
  } else {
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
  }

  return mind_sampler;
}
#endif

// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}

@@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
  return Status::OK();
}

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
  // runtime mindrecord sampler object
  std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
  if (shuffle_ == true) {
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
                                                               GetSeed(), num_samples_);
  } else {
    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
  }

  return mind_sampler;
}
#endif

// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
    : replacement_(replacement), num_samples_(num_samples) {}
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
    : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}

Status RandomSamplerObj::ValidateParams() {
  if (num_samples_ < 0) {

@@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() {
  return Status::OK();
}

Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "RandomSampler";
  args["replacement"] = replacement_;
  args["num_samples"] = num_samples_;
  args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
  // runtime sampler object
  bool reshuffle_each_epoch = true;
  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
  BuildChildren(sampler);
  return sampler;
}

@@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
  // runtime mindrecord sampler object
  bool reshuffle_each_epoch_ = true;
  auto mind_sampler =
    std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);

@@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() {
  return Status::OK();
}

Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SequentialSampler";
  args["start_index"] = start_index_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
  // runtime sampler object
  auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);

@@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
  return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SubsetSampler";
  args["indices"] = indices_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)

@@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
}
#endif

Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SubsetRandomSampler";
  args["indices"] = indices_;
  args["num_samples"] = num_samples_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
    : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}

@@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() {
  return Status::OK();
}

Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "WeightedRandomSampler";
  args["weights"] = weights_;
  args["num_samples"] = num_samples_;
  args["replacement"] = replacement_;
  if (!children_.empty()) {
    std::vector<nlohmann::json> children_args;
    for (auto child : children_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}

std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
  auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
  BuildChildren(sampler);
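For illustration, the JSON these to_json overrides emit for a distributed sampler with one chained child should look roughly like the following; the field values are made up, and the child entry omits "child_sampler" because that key is only added when children exist.

expected = {
    "sampler_name": "DistributedSampler",
    "num_shards": 8,
    "shard_id": 4,
    "shuffle": True,
    "num_samples": 64,
    "offset": -1,
    "child_sampler": [
        {"sampler_name": "SequentialSampler", "start_index": 0, "num_samples": 64},
    ],
}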
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -66,6 +66,8 @@ class SamplerObj {

  virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }

  std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }

#ifndef ENABLE_ANDROID
  /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
  ///     only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler

@@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

  /// \brief Function to get the shard id of sampler

@@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

 private:

@@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj {

class RandomSamplerObj : public SamplerObj {
 public:
  RandomSamplerObj(bool replacement, int64_t num_samples);
  RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);

  ~RandomSamplerObj() = default;

  std::shared_ptr<SamplerRT> SamplerBuild() override;

  std::shared_ptr<SamplerObj> SamplerCopy() override {
    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
    for (auto child : children_) {
      sampler->AddChildSampler(child);
    }

@@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

 private:
  bool replacement_;
  int64_t num_samples_;
  bool reshuffle_each_epoch_;
};

class SequentialSamplerObj : public SamplerObj {

@@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

 private:

@@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj {
  std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

 protected:

@@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {

  ~SubsetRandomSamplerObj() = default;

  Status to_json(nlohmann::json *out_json) override;

  std::shared_ptr<SamplerRT> SamplerBuild() override;

  std::shared_ptr<SamplerObj> SamplerCopy() override {

@@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj {
    return sampler;
  }

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *out_json) override;

  Status ValidateParams() override;

 private:
@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -36,7 +36,7 @@ class BuiltinSampler:
        self.child_sampler = None
        self.num_samples = num_samples

    def create(self):
    def parse(self):
        pass

    def add_child(self, sampler):

@@ -59,16 +59,16 @@ class BuiltinSampler:
    def get_child(self):
        return self.child_sampler

    def create_child(self):
    def parse_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.create()
            c_child_sampler = self.child_sampler.parse()
        return c_child_sampler

    def create_child_for_minddataset(self):
    def parse_child_for_minddataset(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.create_for_minddataset()
            c_child_sampler = self.child_sampler.parse_for_minddataset()
        return c_child_sampler

    def is_shuffled(self):

@@ -158,6 +158,8 @@ class Sampler(BuiltinSampler):
    def __init__(self, num_samples=None):
        super().__init__(num_samples)
        self.dataset_size = 0
        self.child_sampler = None
        self.num_samples = num_samples

    def __iter__(self):
        """

@@ -192,13 +194,26 @@ class Sampler(BuiltinSampler):

    # Instance fetcher
    # Do not override this method!
    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PythonSampler(num_samples, self)
        c_child_sampler = self.create_child()
        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def add_child(self, sampler):
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def parse_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse()

        return c_child_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

@@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler):
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        if num_shards <= 0:
            raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards))
        if not isinstance(num_shards, int):
            raise ValueError("num_shards must be integer but was: {}.".format(num_shards))

        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id))
        if not isinstance(shard_id, int):
            raise ValueError("shard_id must be integer but was: {}.".format(shard_id))

        if not isinstance(shuffle, bool):
            raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        if offset > num_shards:
            raise ValueError("offset should be no more than num_shards: {}, "
                             "but got offset: {}".format(num_shards, offset))

        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle

@@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler):
        self.offset = offset
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else True
        offset = self.offset if self.offset is not None else -1
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
        self.seed += 1
        c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id,
                                           self.shuffle, self.seed, self.offset)
        c_child_sampler = self.create_child()
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
                                              shuffle, num_samples, self.seed, offset, True)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
                                                     self.seed, num_samples, self.offset)
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
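A small sketch of the new flow for the distributed sampler shown above; the argument values are illustrative. parse() normalizes the shuffle/offset defaults and passes even_dist=True to the IR object, while the MindDataset path is unchanged.

import mindspore.dataset as ds

sampler = ds.DistributedSampler(num_shards=8, shard_id=4, shuffle=True, num_samples=64)
sampler_obj = sampler.parse()                   # cde.DistributedSamplerObj; parameters validated in C++
minddata_op = sampler.parse_for_minddataset()   # still a mindrecord-level sampler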
@@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler):
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if num_val <= 0:
            raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val))
        if not isinstance(num_val, int):
            raise ValueError("num_val must be integer but was: {}.".format(num_val))

        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

@@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler):
        if not isinstance(shuffle, bool):
            raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
        c_child_sampler = self.create_child()
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler):

        return self.child_sampler.is_sharded()

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        if not self.class_column or not isinstance(self.class_column, str):
            raise ValueError("class_column should be a not empty string value, \
                             but got class_column: {}.".format(class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler):
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
        c_child_sampler = self.create_child()
        replacement = self.replacement if self.replacement is not None else False
        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler):
    """

    def __init__(self, start_index=None, num_samples=None):
        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        if start_index is not None:
            if start_index < 0:
                raise ValueError("start_index should be a positive integer "
                                 "value or 0, but got start_index: {}.".format(start_index))

        self.start_index = start_index
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SequentialSampler(num_samples, start_index)
        c_child_sampler = self.create_child()
        c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler):
    """

    def __init__(self, indices, num_samples=None):
        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        if not isinstance(indices, list):
            indices = [indices]

        for i, item in enumerate(indices):
            if not isinstance(item, numbers.Number):
                raise TypeError("type of weights element should be number, "
                                "but got w[{}]: {}, type: {}.".format(i, item, type(item)))

        self.indices = indices
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetSampler(num_samples, self.indices)
        c_child_sampler = self.create_child()
        c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler):

        return self.child_sampler.is_sharded()

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler):
        >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
    """

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
        c_child_sampler = self.create_child()
        c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def create_for_minddataset(self):
    def parse_for_minddataset(self):
        c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
        c_child_sampler = self.create_child_for_minddataset()
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

@@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler):
                raise TypeError("type of weights element should be number, "
                                "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))

        if weights == []:
            raise ValueError("weights size should not be 0")

        if list(filter(lambda x: x < 0, weights)) != []:
            raise ValueError("weights should not contain negative numbers.")

        if list(filter(lambda x: x == 0, weights)) == weights:
            raise ValueError("elements of weights should not be all zeros.")

        if num_samples is not None:
            if num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples: {}.".format(num_samples))

        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

@@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler):
        self.replacement = replacement
        super().__init__(num_samples)

    def create(self):
    def parse(self):
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
        c_child_sampler = self.create_child()
        replacement = self.replacement if self.replacement is not None else True
        c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
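Putting the pieces together, a hedged end-to-end sketch (the class and variable names are made up for illustration): user-defined samplers are wrapped into a PreBuiltSamplerObj by the inherited parse(), children are attached before parsing, and parameter errors that used to be raised eagerly now surface as RuntimeError when the IR node is validated, as the test updates below show.

import mindspore.dataset as ds

class EvenIndexSampler(ds.Sampler):      # illustrative user-defined sampler
    def __iter__(self):
        for i in range(0, self.dataset_size, 2):
            yield i

sampler = EvenIndexSampler()
sampler.add_child(ds.SequentialSampler())   # child attached at the Python level
sampler_obj = sampler.parse()               # wrapped into a cde.PreBuiltSamplerObj with the child chained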
@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception():
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)

    error_msg_3 = "weights size should not be 0"
    with pytest.raises(ValueError, match=error_msg_3):
    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
        ds.WeightedRandomSampler(weights)
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_4 = "weights should not contain negative numbers"
    with pytest.raises(ValueError, match=error_msg_4):
    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
        ds.WeightedRandomSampler(weights)
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_5 = "elements of weights should not be all zero"
    with pytest.raises(ValueError, match=error_msg_5):
    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
        ds.WeightedRandomSampler(weights)
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()


def test_chained_sampler_01():
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0():
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=0)
                                      shard_id=partition_id, num_samples=-1)
            num_iter = 0
            for _ in data_set.create_dict_iterator(num_epochs=1):
                num_iter += 1
    with pytest.raises(Exception) as error_info:
    with pytest.raises(ValueError) as error_info:
        partitions(5)
    try:
        assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
        assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False):


def test_sampler_py_api():
    sampler = ds.SequentialSampler().create()
    sampler.set_num_rows(128)
    sampler.set_num_samples(64)
    sampler.initialize()
    sampler.get_indices()

    sampler = ds.RandomSampler().create()
    sampler.set_num_rows(128)
    sampler.set_num_samples(64)
    sampler.initialize()
    sampler.get_indices()

    sampler = ds.DistributedSampler(8, 4).create()
    sampler.set_num_rows(128)
    sampler.set_num_samples(64)
    sampler.initialize()
    sampler.get_indices()
    sampler = ds.SequentialSampler().parse()
    sampler1 = ds.RandomSampler().parse()
    sampler1.add_child(sampler)


def test_python_sampler():

@@ -158,12 +144,6 @@ def test_python_sampler():
    assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()

    sp1 = Sp1().create()
    sp1.set_num_rows(5)
    sp1.set_num_samples(5)
    sp1.initialize()
    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


def test_sequential_sampler2():
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

@@ -229,8 +209,8 @@ def test_subset_sampler():
    test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
    test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
    # test_config([], exception_msg="Indices list is empty")  # temporary until we check with MindDataset
    test_config([0, 9, 3, 2], num_samples=0,
                exception_msg="num_samples should be a positive integer value, but got num_samples: 0.")
    test_config([0, 9, 3, 2], num_samples=-1,
                exception_msg="SubsetRandomSampler: invalid num_samples: -1")


def test_sampler_chain():

@@ -280,9 +260,9 @@ def test_add_sampler_invalid_input():


def test_distributed_sampler_invalid_offset():
    with pytest.raises(ValueError) as info:
        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
    assert "offset should be no more than num_shards" in str(info.value)
    with pytest.raises(RuntimeError) as info:
        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
    assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)


if __name__ == '__main__':
@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -377,7 +377,7 @@ def test_serdes_exception():
def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
    """
    Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
    after serializing and if it remains the same after repeatly saving and loading.
    after serializing and if it remains the same after repeatedly saving and loading.
    :param data_orig: original data pipeline to be serialized
    :param filename: filename to be saved as json format
    :param remove_json_files: whether to remove the json file after testing
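The new to_json overrides are what feed sampler information into pipeline serialization, which is what this test file exercises. A hedged sketch of that round trip (the dataset path and arguments are illustrative, not from this PR):

import mindspore.dataset as ds

data = ds.ImageFolderDataset("/path/to/data", sampler=ds.RandomSampler(num_samples=4))
ds.serialize(data, "pipeline.json")               # sampler arguments serialized via SamplerObj::to_json
data2 = ds.deserialize(json_filepath="pipeline.json")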