forked from mindspore-Ecosystem/mindspore
Changed bindings to SamplerObj
parent
96f007ebb4
commit
8f34faeb7a
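In short, the Python sampler classes now hand the C++ layer an IR-level SamplerObj built by parse() / parse_for_minddataset() instead of a runtime SamplerRT built by the old create() hooks. A minimal sketch of the resulting flow, assuming mindspore.dataset is importable as ds (everything else is taken from this diff):

    import mindspore.dataset as ds

    # Public API is unchanged: construct a sampler as before.
    sampler = ds.SequentialSampler(start_index=0, num_samples=4)

    # Internally, the C++ conversion (toSamplerObj) now calls parse(), which
    # returns a cde.SequentialSamplerObj instead of the old runtime
    # cde.SequentialSampler produced by create().
    sampler_obj = sampler.parse()

    # MindDataset still receives a mindrecord ShardOperator, but through the
    # renamed hook parse_for_minddataset().
    mind_sampler = sampler.parse_for_minddataset()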
@@ -7,11 +7,11 @@ if(ENABLE_PYTHON)
         python/bindings/dataset/engine/cache/bindings.cc
         python/bindings/dataset/engine/datasetops/bindings.cc
         python/bindings/dataset/engine/datasetops/source/bindings.cc
-        python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
         python/bindings/dataset/engine/gnn/bindings.cc
         python/bindings/dataset/include/datasets_bindings.cc
         python/bindings/dataset/include/iterator_bindings.cc
         python/bindings/dataset/include/execute_binding.cc
+        python/bindings/dataset/include/sampler_bindings.cc
         python/bindings/dataset/include/schema_bindings.cc
         python/bindings/dataset/kernels/bindings.cc
         python/bindings/dataset/kernels/data/bindings.cc
@@ -1,93 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "minddata/dataset/api/python/pybind_register.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
-
-namespace mindspore {
-namespace dataset {
-
-PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
-                  (void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
-                    .def("set_num_rows",
-                         [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
-                    .def("set_num_samples",
-                         [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
-                    .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
-                    .def("get_indices",
-                         [](SamplerRT &self) {
-                           py::array ret;
-                           THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
-                           return ret;
-                         })
-                    .def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
-                      THROW_IF_ERROR(self->AddChild(child));
-                    });
-                }));
-
-PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
-                    *m, "DistributedSampler")
-                    .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
-                }));
-
-PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
-                    .def(py::init<int64_t, int64_t, bool>());
-                }));
-
-PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
-                    .def(py::init<int64_t, py::object>());
-                }));
-
-PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
-                    .def(py::init<int64_t, bool, bool>());
-                }));
-
-PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
-                    *m, "SequentialSampler")
-                    .def(py::init<int64_t, int64_t>());
-                }));
-
-PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
-                  (void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
-                    *m, "SubsetRandomSampler")
-                    .def(py::init<int64_t, std::vector<int64_t>>());
-                }));
-
-PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
-                    .def(py::init<int64_t, std::vector<int64_t>>());
-                }));
-
-PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
-                  (void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
-                    *m, "WeightedRandomSampler")
-                    .def(py::init<int64_t, std::vector<double>, bool>());
-                }));
-
-} // namespace dataset
-} // namespace mindspore
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "pybind11/stl_bind.h"
+
+#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
+#include "minddata/dataset/api/python/pybind_conversion.h"
+#include "minddata/dataset/api/python/pybind_register.h"
+#include "minddata/dataset/callback/py_ds_callback.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/include/datasets.h"
+
+namespace mindspore {
+namespace dataset {
+
+PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
+                  (void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
+                    .def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
+                      THROW_IF_ERROR(self->AddChildSampler(child));
+                    });
+                }));
+
+PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
+                    *m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
+                    .def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
+                                     uint32_t seed, int64_t offset, bool even_dist) {
+                      std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
+                        num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
+                    *m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
+                    .def(py::init([](int64_t num_samples, py::object sampler) {
+                      auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
+                      auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
+                      THROW_IF_ERROR(sampler_obj->ValidateParams());
+                      return sampler_obj;
+                    }));
+                }));
+
+PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
+                                                                                            "to create a PKSamplerObj")
+                    .def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
+                      std::shared_ptr<PKSamplerObj> sampler =
+                        std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
+                    *m, "RandomSamplerObj", "to create a RandomSamplerObj")
+                    .def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
+                      std::shared_ptr<RandomSamplerObj> sampler =
+                        std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
+                    *m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
+                    .def(py::init([](int64_t start_index, int64_t num_samples) {
+                      std::shared_ptr<SequentialSamplerObj> sampler =
+                        std::make_shared<SequentialSamplerObj>(start_index, num_samples);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
+                    *m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
+                    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
+                      std::shared_ptr<SubsetSamplerObj> sampler =
+                        std::make_shared<SubsetSamplerObj>(indices, num_samples);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
+                  (void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
+                    *m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
+                    .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
+                      std::shared_ptr<SubsetRandomSamplerObj> sampler =
+                        std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
+                  (void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
+                    *m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
+                    .def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
+                      std::shared_ptr<WeightedRandomSamplerObj> sampler =
+                        std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
+                      THROW_IF_ERROR(sampler->ValidateParams());
+                      return sampler;
+                    }));
+                }));
+
+} // namespace dataset
+} // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) {
   std::shared_ptr<SamplerObj> sampler_obj;
   if (!isMindDataset) {
     // Common Sampler
-    std::shared_ptr<SamplerRT> sampler;
-    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
-    sampler = create().cast<std::shared_ptr<SamplerRT>>();
-    sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
+    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
+    sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
   } else {
     // Mindrecord Sampler
     std::shared_ptr<mindrecord::ShardOperator> sampler;
-    auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
-    sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+    auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset");
+    sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>();
     sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
   }
   return sampler_obj;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
 }
 #endif

+Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "DistributedSampler";
+  args["num_shards"] = num_shards_;
+  args["shard_id"] = shard_id_;
+  args["shuffle"] = shuffle_;
+  args["num_samples"] = num_samples_;
+  args["offset"] = offset_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 // PKSampler
 PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
     : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
@@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() {
   return Status::OK();
 }

+Status PKSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "PKSampler";
+  args["num_val"] = num_val_;
+  args["shuffle"] = shuffle_;
+  args["num_samples"] = num_samples_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);

@@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
   return sampler;
 }

+#ifndef ENABLE_ANDROID
+std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
+  // runtime mindrecord sampler object
+  std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
+  if (shuffle_ == true) {
+    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
+                                                               GetSeed(), num_samples_);
+  } else {
+    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
+  }
+
+  return mind_sampler;
+}
+#endif
+
 // PreBuiltOperation
 PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}

@@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
   return Status::OK();
 }

-#ifndef ENABLE_ANDROID
-std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
-  // runtime mindrecord sampler object
-  std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
-  if (shuffle_ == true) {
-    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
-                                                               GetSeed(), num_samples_);
-  } else {
-    mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
-  }
-
-  return mind_sampler;
-}
-#endif
-
 // RandomSampler
-RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
-    : replacement_(replacement), num_samples_(num_samples) {}
+RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
+    : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}

 Status RandomSamplerObj::ValidateParams() {
   if (num_samples_ < 0) {

@@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() {
   return Status::OK();
 }

+Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "RandomSampler";
+  args["replacement"] = replacement_;
+  args["num_samples"] = num_samples_;
+  args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
   // runtime sampler object
-  bool reshuffle_each_epoch = true;
-  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
+  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
   BuildChildren(sampler);
   return sampler;
 }
@@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
 #ifndef ENABLE_ANDROID
 std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
   // runtime mindrecord sampler object
-  bool reshuffle_each_epoch_ = true;
   auto mind_sampler =
     std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);

@@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() {
   return Status::OK();
 }

+Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "SequentialSampler";
+  args["start_index"] = start_index_;
+  args["num_samples"] = num_samples_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
@@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
   return mind_sampler;
 }
 #endif
+Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "SubsetSampler";
+  args["indices"] = indices_;
+  args["num_samples"] = num_samples_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}

 // SubsetRandomSampler
 SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)

@@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
 }
 #endif

+Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "SubsetRandomSampler";
+  args["indices"] = indices_;
+  args["num_samples"] = num_samples_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 // WeightedRandomSampler
 WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
     : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
@@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() {
   return Status::OK();
 }

+Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "WeightedRandomSampler";
+  args["weights"] = weights_;
+  args["num_samples"] = num_samples_;
+  args["replacement"] = replacement_;
+  if (!children_.empty()) {
+    std::vector<nlohmann::json> children_args;
+    for (auto child : children_) {
+      nlohmann::json child_arg;
+      RETURN_IF_NOT_OK(child->to_json(&child_arg));
+      children_args.push_back(child_arg);
+    }
+    args["child_sampler"] = children_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
 std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
   auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
   BuildChildren(sampler);
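For orientation, the to_json overrides added above serialize a sampler and any chained child samplers into nested JSON. A DistributedSampler with one SequentialSampler child would come out roughly as below (a hand-written sketch based on the fields assigned in this diff, shown as a Python literal with example values):

    # Illustrative only: field values are examples, not output captured from the code.
    sampler_json = {
        "sampler_name": "DistributedSampler",
        "num_shards": 4,
        "shard_id": 0,
        "shuffle": True,
        "num_samples": 0,
        "offset": -1,
        "child_sampler": [
            {"sampler_name": "SequentialSampler", "start_index": 0, "num_samples": 0},
        ],
    }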
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -66,6 +66,8 @@ class SamplerObj {

   virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }

+  std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
+
 #ifndef ENABLE_ANDROID
   /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
   ///     only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj {
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

   /// \brief Function to get the shard id of sampler

@@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj {
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

  private:
@@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj {

 class RandomSamplerObj : public SamplerObj {
  public:
-  RandomSamplerObj(bool replacement, int64_t num_samples);
+  RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);

   ~RandomSamplerObj() = default;

   std::shared_ptr<SamplerRT> SamplerBuild() override;

   std::shared_ptr<SamplerObj> SamplerCopy() override {
-    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
+    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
     for (auto child : children_) {
       sampler->AddChildSampler(child);
     }

@@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj {
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

  private:
   bool replacement_;
   int64_t num_samples_;
+  bool reshuffle_each_epoch_;
 };

 class SequentialSamplerObj : public SamplerObj {
@@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj {
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

  private:

@@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj {
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

  protected:
@@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {

   ~SubsetRandomSamplerObj() = default;

+  Status to_json(nlohmann::json *out_json) override;
+
   std::shared_ptr<SamplerRT> SamplerBuild() override;

   std::shared_ptr<SamplerObj> SamplerCopy() override {

@@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj {
     return sampler;
   }

+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
   Status ValidateParams() override;

  private:
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -36,7 +36,7 @@ class BuiltinSampler:
         self.child_sampler = None
         self.num_samples = num_samples

-    def create(self):
+    def parse(self):
         pass

     def add_child(self, sampler):
@@ -59,16 +59,16 @@ class BuiltinSampler:
     def get_child(self):
         return self.child_sampler

-    def create_child(self):
+    def parse_child(self):
         c_child_sampler = None
         if self.child_sampler is not None:
-            c_child_sampler = self.child_sampler.create()
+            c_child_sampler = self.child_sampler.parse()
         return c_child_sampler

-    def create_child_for_minddataset(self):
+    def parse_child_for_minddataset(self):
         c_child_sampler = None
         if self.child_sampler is not None:
-            c_child_sampler = self.child_sampler.create_for_minddataset()
+            c_child_sampler = self.child_sampler.parse_for_minddataset()
         return c_child_sampler

     def is_shuffled(self):
@@ -158,6 +158,8 @@ class Sampler(BuiltinSampler):
     def __init__(self, num_samples=None):
         super().__init__(num_samples)
         self.dataset_size = 0
+        self.child_sampler = None
+        self.num_samples = num_samples

     def __iter__(self):
         """

@@ -192,13 +194,26 @@ class Sampler(BuiltinSampler):

     # Instance fetcher
     # Do not override this method!
-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.PythonSampler(num_samples, self)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

+    def add_child(self, sampler):
+        self.child_sampler = sampler
+
+    def get_child(self):
+        return self.child_sampler
+
+    def parse_child(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.parse()
+
+        return c_child_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return False
@@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler):
     """

     def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
-        if num_shards <= 0:
-            raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards))
+        if not isinstance(num_shards, int):
+            raise ValueError("num_shards must be integer but was: {}.".format(num_shards))

-        if shard_id < 0 or shard_id >= num_shards:
-            raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id))
+        if not isinstance(shard_id, int):
+            raise ValueError("shard_id must be integer but was: {}.".format(shard_id))

         if not isinstance(shuffle, bool):
             raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
-        if offset > num_shards:
-            raise ValueError("offset should be no more than num_shards: {}, "
-                             "but got offset: {}".format(num_shards, offset))
-
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.shuffle = shuffle

@@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler):
         self.offset = offset
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
+        shuffle = self.shuffle if self.shuffle is not None else True
+        offset = self.offset if self.offset is not None else -1
         # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
         self.seed += 1
-        c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id,
-                                           self.shuffle, self.seed, self.offset)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
+                                              shuffle, num_samples, self.seed, offset, True)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
         c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
                                                      self.seed, num_samples, self.offset)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler):
     """

     def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
-        if num_val <= 0:
-            raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val))
+        if not isinstance(num_val, int):
+            raise ValueError("num_val must be integer but was: {}.".format(num_val))

         if num_class is not None:
             raise NotImplementedError("Not supported to specify num_class for PKSampler.")

@@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler):
         if not isinstance(shuffle, bool):
             raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
         self.num_val = num_val
         self.shuffle = shuffle
         self.class_column = class_column  # work for minddataset
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
-        c_child_sampler = self.create_child()
+        shuffle = self.shuffle if self.shuffle is not None else False
+        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler):

         return self.child_sampler.is_sharded()

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
                              but got class_column: {}.".format(class_column))
         num_samples = self.num_samples if self.num_samples is not None else 0
         c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler):
         if not isinstance(replacement, bool):
             raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
         self.deterministic = False
         self.replacement = replacement
         self.reshuffle_each_epoch = True
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
-        c_child_sampler = self.create_child()
+        replacement = self.replacement if self.replacement is not None else False
+        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
         c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler):
     """

     def __init__(self, start_index=None, num_samples=None):
-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
-        if start_index is not None:
-            if start_index < 0:
-                raise ValueError("start_index should be a positive integer "
-                                 "value or 0, but got start_index: {}.".format(start_index))
-
         self.start_index = start_index
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         start_index = self.start_index if self.start_index is not None else 0
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.SequentialSampler(num_samples, start_index)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         start_index = self.start_index if self.start_index is not None else 0
         num_samples = self.num_samples if self.num_samples is not None else 0
         c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler):
     """

     def __init__(self, indices, num_samples=None):
-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
         if not isinstance(indices, list):
             indices = [indices]

+        for i, item in enumerate(indices):
+            if not isinstance(item, numbers.Number):
+                raise TypeError("type of weights element should be number, "
+                                "but got w[{}]: {}, type: {}.".format(i, item, type(item)))
+
         self.indices = indices
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.SubsetSampler(num_samples, self.indices)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler):

         return self.child_sampler.is_sharded()

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         c_sampler = cde.MindrecordSubsetSampler(self.indices)
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler):
         >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
     """

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
-        c_child_sampler = self.create_child()
+        c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

     def is_shuffled(self):
         return True

-    def create_for_minddataset(self):
+    def parse_for_minddataset(self):
         c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
-        c_child_sampler = self.create_child_for_minddataset()
+        c_child_sampler = self.parse_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler):
                 raise TypeError("type of weights element should be number, "
                                 "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))

-        if weights == []:
-            raise ValueError("weights size should not be 0")
-
-        if list(filter(lambda x: x < 0, weights)) != []:
-            raise ValueError("weights should not contain negative numbers.")
-
-        if list(filter(lambda x: x == 0, weights)) == weights:
-            raise ValueError("elements of weights should not be all zeros.")
-
-        if num_samples is not None:
-            if num_samples <= 0:
-                raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples: {}.".format(num_samples))
-
         if not isinstance(replacement, bool):
             raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

@@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler):
         self.replacement = replacement
         super().__init__(num_samples)

-    def create(self):
+    def parse(self):
         num_samples = self.num_samples if self.num_samples is not None else 0
-        c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
-        c_child_sampler = self.create_child()
+        replacement = self.replacement if self.replacement is not None else True
+        c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
+        c_child_sampler = self.parse_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception():
         weights = (0.9, 0.8, 1.1)
         ds.WeightedRandomSampler(weights)

-    error_msg_3 = "weights size should not be 0"
-    with pytest.raises(ValueError, match=error_msg_3):
+    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
+    with pytest.raises(RuntimeError, match=error_msg_3):
         weights = []
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()

-    error_msg_4 = "weights should not contain negative numbers"
-    with pytest.raises(ValueError, match=error_msg_4):
+    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
+    with pytest.raises(RuntimeError, match=error_msg_4):
         weights = [1.0, 0.1, 0.02, 0.3, -0.4]
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()

-    error_msg_5 = "elements of weights should not be all zero"
-    with pytest.raises(ValueError, match=error_msg_5):
+    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
+    with pytest.raises(RuntimeError, match=error_msg_5):
         weights = [0, 0, 0, 0, 0]
-        ds.WeightedRandomSampler(weights)
+        sampler = ds.WeightedRandomSampler(weights)
+        sampler.parse()


 def test_chained_sampler_01():
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0():
         for partition_id in range(num_shards):
             data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                       num_shards=num_shards,
-                                      shard_id=partition_id, num_samples=0)
+                                      shard_id=partition_id, num_samples=-1)
             num_iter = 0
             for _ in data_set.create_dict_iterator(num_epochs=1):
                 num_iter += 1
-    with pytest.raises(Exception) as error_info:
+    with pytest.raises(ValueError) as error_info:
         partitions(5)
     try:
-        assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
+        assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value)
     except Exception as error:
         os.remove(CV_FILE_NAME)
         os.remove("{}.db".format(CV_FILE_NAME))
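Note (not part of the diff): per the new message, num_samples is now validated against the interval (0 to 2147483647) before the pipeline is built, so 0 and negative values raise ValueError up front. A sketch reusing the test's own placeholder names (CV_FILE_NAME, columns_list, num_readers):

    import pytest
    import mindspore.dataset as ds

    with pytest.raises(ValueError, match="Input num_samples is not within the required interval"):
        ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                       num_shards=4, shard_id=0, num_samples=-1)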
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False):


 def test_sampler_py_api():
-    sampler = ds.SequentialSampler().create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
-
-    sampler = ds.RandomSampler().create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
-
-    sampler = ds.DistributedSampler(8, 4).create()
-    sampler.set_num_rows(128)
-    sampler.set_num_samples(64)
-    sampler.initialize()
-    sampler.get_indices()
+    sampler = ds.SequentialSampler().parse()
+    sampler1 = ds.RandomSampler().parse()
+    sampler1.add_child(sampler)


 def test_python_sampler():
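Note (not part of the diff): the rewritten test reflects the new object model, where parse() returns the C++-side sampler and add_child() chains samplers directly on those objects. The same pattern in isolation:

    import mindspore.dataset as ds

    child = ds.SequentialSampler().parse()
    parent = ds.RandomSampler().parse()
    parent.add_child(child)  # the parent samples from the ids produced by the child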
@@ -158,12 +144,6 @@ def test_python_sampler():
     assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
     test_generator()

-    sp1 = Sp1().create()
-    sp1.set_num_rows(5)
-    sp1.set_num_samples(5)
-    sp1.initialize()
-    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
-

 def test_sequential_sampler2():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@@ -229,8 +209,8 @@ def test_subset_sampler():
     test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
     test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
     # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
-    test_config([0, 9, 3, 2], num_samples=0,
-                exception_msg="num_samples should be a positive integer value, but got num_samples: 0.")
+    test_config([0, 9, 3, 2], num_samples=-1,
+                exception_msg="SubsetRandomSampler: invalid num_samples: -1")


 def test_sampler_chain():
@@ -280,9 +260,9 @@ def test_add_sampler_invalid_input():


 def test_distributed_sampler_invalid_offset():
-    with pytest.raises(ValueError) as info:
-        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
-    assert "offset should be no more than num_shards" in str(info.value)
+    with pytest.raises(RuntimeError) as info:
+        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
+    assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)


 if __name__ == '__main__':
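Note (not part of the diff): the offset check is assumed to have moved into the C++ sampler object, which is why the test now expects RuntimeError at parse() time. A sketch of a passing case under that assumption:

    import mindspore.dataset as ds

    sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=3)
    sampler.parse()  # expected to succeed: offset 3 does not exceed num_shards 4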
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -377,7 +377,7 @@ def test_serdes_exception():
 def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
     """
     Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
-    after serializing and if it remains the same after repeatly saving and loading.
+    after serializing and if it remains the same after repeatedly saving and loading.
     :param data_orig: original data pipeline to be serialized
     :param filename: filename to be saved as json format
     :param remove_json_files: whether to remove the json file after testing