forked from mindspore-Ecosystem/mindspore
!21648 [assistant][ops] Add new data loading operator YesNoDataset
Merge pull request !21648 from 杨旭华/YesNoDataset
This commit is contained in:
commit
c6821bde0a
|
@ -115,6 +115,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -1543,6 +1544,27 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_f
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
|
||||
|
||||
// IR leaf nodes disabled for android
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -448,5 +449,15 @@ PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode",
|
||||
"to create a YesNoNode")
|
||||
.def(py::init([](std::string dataset_dir, py::handle sampler) {
|
||||
auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(yes_no->ValidateParams());
|
||||
return yes_no;
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -3,33 +3,34 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
io_block.cc
|
||||
image_folder_op.cc
|
||||
mnist_op.cc
|
||||
coco_op.cc
|
||||
cifar_op.cc
|
||||
random_data_op.cc
|
||||
celeba_op.cc
|
||||
sbu_op.cc
|
||||
text_file_op.cc
|
||||
clue_op.cc
|
||||
csv_op.cc
|
||||
ag_news_op.cc
|
||||
album_op.cc
|
||||
usps_op.cc
|
||||
mappable_leaf_op.cc
|
||||
nonmappable_leaf_op.cc
|
||||
celeba_op.cc
|
||||
cifar_op.cc
|
||||
cityscapes_op.cc
|
||||
clue_op.cc
|
||||
coco_op.cc
|
||||
csv_op.cc
|
||||
dbpedia_op.cc
|
||||
div2k_op.cc
|
||||
flickr_op.cc
|
||||
qmnist_op.cc
|
||||
emnist_op.cc
|
||||
fake_image_op.cc
|
||||
lj_speech_op.cc
|
||||
places365_op.cc
|
||||
photo_tour_op.cc
|
||||
fashion_mnist_op.cc
|
||||
ag_news_op.cc
|
||||
dbpedia_op.cc
|
||||
flickr_op.cc
|
||||
image_folder_op.cc
|
||||
io_block.cc
|
||||
lj_speech_op.cc
|
||||
mappable_leaf_op.cc
|
||||
mnist_op.cc
|
||||
nonmappable_leaf_op.cc
|
||||
photo_tour_op.cc
|
||||
places365_op.cc
|
||||
qmnist_op.cc
|
||||
random_data_op.cc
|
||||
sbu_op.cc
|
||||
text_file_op.cc
|
||||
usps_op.cc
|
||||
yes_no_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/yes_no_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
constexpr float kMaxShortVal = 32767.0;
|
||||
constexpr char kExtension[] = ".wav";
|
||||
constexpr int kStrLen = 15; // the length of name.
|
||||
#ifndef _WIN32
|
||||
constexpr char kSplitSymbol[] = "/";
|
||||
#else
|
||||
constexpr char kSplitSymbol[] = "\\";
|
||||
#endif
|
||||
|
||||
YesNoOp::YesNoOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
dataset_dir_(file_dir),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
Status YesNoOp::PrepareData() {
|
||||
auto realpath = FileUtils::GetRealPath(dataset_dir_.data());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_);
|
||||
}
|
||||
Path dir(realpath.value());
|
||||
if (dir.Exists() == false || dir.IsDirectory() == false) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid parameter, failed to open speech commands: " + dataset_dir_);
|
||||
}
|
||||
std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
|
||||
RETURN_UNEXPECTED_IF_NULL(dir_itr);
|
||||
while (dir_itr->HasNext()) {
|
||||
Path file = dir_itr->Next();
|
||||
if (file.Extension() == kExtension) {
|
||||
all_wave_files_.emplace_back(file.ToString());
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!all_wave_files_.empty(), "Invalid file, no .wav files found under " + dataset_dir_);
|
||||
num_rows_ = all_wave_files_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void YesNoOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\n";
|
||||
} else {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\nNumber of rows: " << num_rows_ << "\nYesNo directory: " << dataset_dir_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status YesNoOp::Split(const std::string &line, std::vector<int32_t> *split_num) {
|
||||
RETURN_UNEXPECTED_IF_NULL(split_num);
|
||||
std::string str = line;
|
||||
int dot_pos = str.find_last_of(kSplitSymbol);
|
||||
std::string sub_line = line.substr(dot_pos + 1, kStrLen); // (dot_pos + 1) because the index start from 0.
|
||||
std::string::size_type pos;
|
||||
std::vector<std::string> split;
|
||||
sub_line += "_"; // append to sub_line indicating the end of the string.
|
||||
uint32_t size = sub_line.size();
|
||||
for (uint32_t index = 0; index < size;) {
|
||||
pos = sub_line.find("_", index);
|
||||
if (pos != index) {
|
||||
std::string s = sub_line.substr(index, pos - index);
|
||||
split.emplace_back(s);
|
||||
}
|
||||
index = pos + 1;
|
||||
}
|
||||
try {
|
||||
for (int i = 0; i < split.size(); i++) {
|
||||
split_num->emplace_back(stoi(split[i]));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Converting char to int confront with an error in function stoi().";
|
||||
RETURN_STATUS_UNEXPECTED("Converting char to int confront with an error in function stoi().");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoOp::LoadTensorRow(row_id_type index, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::shared_ptr<Tensor> waveform, sample_rate_scalar, label_scalar;
|
||||
int32_t sample_rate;
|
||||
std::string file_name = all_wave_files_[index];
|
||||
std::vector<int32_t> label;
|
||||
std::vector<float> waveform_vec;
|
||||
RETURN_IF_NOT_OK(Split(file_name, &label));
|
||||
RETURN_IF_NOT_OK(ReadWaveFile(file_name, &waveform_vec, &sample_rate));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, &waveform));
|
||||
RETURN_IF_NOT_OK(waveform->ExpandDim(0));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_scalar));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(label, &label_scalar));
|
||||
(*trow) = TensorRow(index, {waveform, sample_rate_scalar, label_scalar});
|
||||
trow->setPath({file_name, file_name, file_name});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoOp::CountTotalRows(int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
if (all_wave_files_.size() == 0) {
|
||||
RETURN_IF_NOT_OK(PrepareData());
|
||||
}
|
||||
*count = static_cast<int64_t>(all_wave_files_.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoOp::ComputeColMap() {
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class YesNoOp : public MappableLeafOp {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param std::string file_dir - dir directory of YesNo.
|
||||
/// @param int32_t num_workers - number of workers reading images in parallel.
|
||||
/// @param int32_t queue_size - connector queue size.
|
||||
/// @param std::unique_ptr<DataSchema> data_schema - the schema of the YesNo dataset.
|
||||
/// @param std::shared_ptr<Sampler> sampler - sampler tells YesNoOp what to read.
|
||||
YesNoOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// Destructor.
|
||||
~YesNoOp() = default;
|
||||
|
||||
/// A print method typically used for debugging.
|
||||
/// @param std::ostream &out - out stream.
|
||||
/// @param bool show_all - whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// Op name getter.
|
||||
/// @return Name of the current Op.
|
||||
std::string Name() const override { return "YesNoOp"; }
|
||||
|
||||
/// @param int64_t *count - output rows number of YesNoDataset.
|
||||
/// @return Status - The status code returned.
|
||||
Status CountTotalRows(int64_t *count);
|
||||
|
||||
private:
|
||||
/// Load a tensor row according to wave id.
|
||||
/// @param row_id_type row_id - id for this tensor row.
|
||||
/// @param TensorRow trow - wave & target read into this tensor row.
|
||||
/// @return Status - The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *trow) override;
|
||||
|
||||
/// Get file infos by file name.
|
||||
/// @param string line - file name.
|
||||
/// @param vector split_num - vector of annotation.
|
||||
/// @return Status - The status code returned.
|
||||
Status Split(const std::string &line, std::vector<int32_t> *split_num);
|
||||
|
||||
/// Initialize YesNoDataset related var, calls the function to walk all files.
|
||||
/// @return Status - The status code returned.
|
||||
Status PrepareData();
|
||||
|
||||
/// Private function for computing the assignment of the column name map.
|
||||
/// @return Status - The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
std::vector<std::string> all_wave_files_;
|
||||
std::string dataset_dir_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H
|
|
@ -104,6 +104,7 @@ constexpr char kTextFileNode[] = "TextFileDataset";
|
|||
constexpr char kTFRecordNode[] = "TFRecordDataset";
|
||||
constexpr char kUSPSNode[] = "USPSDataset";
|
||||
constexpr char kVOCNode[] = "VOCDataset";
|
||||
constexpr char kYesNoNode[] = "YesNoDataset";
|
||||
|
||||
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
|
||||
int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op);
|
||||
|
|
|
@ -32,6 +32,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
tf_record_node.cc
|
||||
usps_node.cc
|
||||
voc_node.cc
|
||||
yes_no_node.cc
|
||||
)
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/yes_no_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor for YesNoNode.
|
||||
YesNoNode::YesNoNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> YesNoNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<YesNoNode>(dataset_dir_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void YesNoNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
|
||||
}
|
||||
|
||||
Status YesNoNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("YesNoNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("YesNoNode", sampler_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
TensorShape sample_rate_scalar = TensorShape::CreateScalar();
|
||||
TensorShape lable_scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar)));
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &lable_scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<YesNoOp>(dataset_dir_, num_workers_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(Build(&ops));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build YesNoOp.");
|
||||
auto op = std::dynamic_pointer_cast<YesNoOp>(ops.front());
|
||||
RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status YesNoNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class YesNoNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
YesNoNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~YesNoNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return "YesNoNode"; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param out - The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id Shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H
|
|
@ -3919,6 +3919,72 @@ inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std
|
|||
MapStringToChar(class_indexing), decode, sampler, cache, extra_metadata);
|
||||
}
|
||||
|
||||
/// \class YesNoDataset.
|
||||
/// \brief A source dataset for reading and parsing YesNo dataset.
|
||||
class YesNoDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of YesNoDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of YesNoDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of YesNoDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// Destructor of YesNoDataset.
|
||||
~YesNoDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a YesNo Dataset.
|
||||
/// \note The generated dataset has three columns ["waveform", "sample_rate", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the current Dataset.
|
||||
inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir,
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a YesNo Dataset.
|
||||
/// \note The generated dataset has three columns ["waveform", "sample_rate", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the current Dataset.
|
||||
inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir, Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a YesNo Dataset.
|
||||
/// \note The generated dataset has three columns ["waveform", "sample_rate", "label"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the current Dataset.
|
||||
inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a cache to be attached to a dataset.
|
||||
/// \note The reason for providing this API is that std::string will be constrained by the
|
||||
/// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction.
|
||||
|
|
|
@ -57,6 +57,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class TFRecordDataset;
|
||||
friend class USPSDataset;
|
||||
friend class VOCDataset;
|
||||
friend class YesNoDataset;
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
|
|
@ -68,7 +68,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
|
||||
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset
|
||||
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
|
||||
check_yes_no_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size, get_auto_offload
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -8362,3 +8363,116 @@ class DIV2KDataset(MappableDataset):
|
|||
|
||||
def parse(self, children=None):
|
||||
return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler)
|
||||
|
||||
|
||||
class YesNoDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the YesNo dataset.
|
||||
|
||||
The generated dataset has three columns :py:obj:`[waveform, sample_rate, labels]`.
|
||||
The tensor of column :py:obj:`waveform` is a vector of the float32 type.
|
||||
The tensor of column :py:obj:`sample_rate` is a scalar of the int32 type.
|
||||
The tensor of column :py:obj:`labels` is a scalar of the int32 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument can only
|
||||
be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> yes_no_dataset_dir = "/path/to/yes_no_dataset_directory"
|
||||
>>>
|
||||
>>> # Read 3 samples from YesNo dataset
|
||||
>>> dataset = ds.YesNoDataset(dataset_dir=yes_no_dataset_dir, num_samples=3)
|
||||
>>>
|
||||
>>> # Note: In YesNo dataset, each dictionary has keys "waveform", "sample_rate", "label"
|
||||
|
||||
About YesNo dataset:
|
||||
|
||||
Yesno is an audio dataset consisting of 60 recordings of one individual saying yes or no in Hebrew; each
|
||||
recording is eight words long. It was created for the Kaldi audio project by an author who wishes to
|
||||
remain anonymous.
|
||||
|
||||
Here is the original YesNo dataset structure.
|
||||
You can unzip the dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── yes_no_dataset_dir
|
||||
├── 1_1_0_0_1_1_0_0.wav
|
||||
├── 1_0_0_0_1_1_0_0.wav
|
||||
├── 1_1_0_0_1_1_0_0.wav
|
||||
└──....
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@NetworkResource{Kaldi_audio_project,
|
||||
author = {anonymous},
|
||||
url = "http://wwww.openslr.org/1/"
|
||||
}
|
||||
"""
|
||||
|
||||
@check_yes_no_dataset
|
||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None,
|
||||
sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.YesNoNode(self.dataset_dir, self.sampler)
|
||||
|
|
|
@ -1807,3 +1807,29 @@ def check_dbpedia_dataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_yes_no_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(YesNoDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -43,6 +43,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_tfrecord_test.cc
|
||||
c_api_dataset_usps_test.cc
|
||||
c_api_dataset_voc_test.cc
|
||||
c_api_dataset_yes_no_test.cc
|
||||
c_api_datasets_test.cc
|
||||
c_api_epoch_ctrl_test.cc
|
||||
c_api_pull_based_test.cc
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: read data from a single file.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestYesNoDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoDataset.";
|
||||
// Create a YesNoDataset
|
||||
std::string folder_path = datasets_root_path_ + "/testYesNoData/";
|
||||
std::shared_ptr<Dataset> ds = YesNo(folder_path, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
MS_LOG(INFO) << "iter->GetNextRow(&row) OK";
|
||||
|
||||
EXPECT_NE(row.find("waveform"), row.end());
|
||||
EXPECT_NE(row.find("sample_rate"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto waveform = row["waveform"];
|
||||
MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: test YesNo dataset with pipeline.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, YesNoDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-YesNoDatasetWithPipeline.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testYesNoData/";
|
||||
std::shared_ptr<Dataset> ds1 = YesNo(folder_path, std::make_shared<RandomSampler>(false, 1));
|
||||
std::shared_ptr<Dataset> ds2 = YesNo(folder_path, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 1;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"waveform", "sample_rate", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("waveform"), row.end());
|
||||
EXPECT_NE(row.find("sample_rate"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto waveform = row["waveform"];
|
||||
MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: get the size of YesNo dataset.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestYesNoGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoGetDatasetSize.";
|
||||
|
||||
// Create a YesNo Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testYesNoData/";
|
||||
std::shared_ptr<Dataset> ds = YesNo(folder_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
}
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: getter functions.
|
||||
/// Expectation: the data is processed successfully.
|
||||
TEST_F(MindDataTestPipeline, TestYesNoGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoMixGetter.";
|
||||
// Create a YesNo Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testYesNoData/";
|
||||
std::shared_ptr<Dataset> ds = YesNo(folder_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"waveform", "sample_rate", "label"};
|
||||
EXPECT_EQ(types.size(), 3);
|
||||
EXPECT_EQ(types[0].ToString(), "float32");
|
||||
EXPECT_EQ(types[1].ToString(), "int32");
|
||||
EXPECT_EQ(types[2].ToString(), "int32");
|
||||
EXPECT_EQ(shapes.size(), 3);
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[2].ToString(), "<8>");
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
}
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: DatasetFail tests.
|
||||
/// Expectation: throw error messages when certain errors occur.
|
||||
TEST_F(MindDataTestPipeline, TestYesNoDatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoDatasetFail.";
|
||||
|
||||
// Create a YesNo Dataset
|
||||
std::shared_ptr<Dataset> ds = YesNo("", std::make_shared<RandomSampler>(false, 1));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: Invalid YesNo directory
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Test YesNo dataset.
|
||||
/// Description: NullSamplerFail tests.
|
||||
/// Expectation: Throw error messages when certain errors occur.
|
||||
TEST_F(MindDataTestPipeline, TestYesNoDatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNo10DatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a YesNo Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testYesNoData/";
|
||||
std::shared_ptr<Dataset> ds = YesNo(folder_path, nullptr);
|
||||
// Expect failure: Null Sampler
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: Null Sampler
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,185 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testYesNoData/"
|
||||
|
||||
|
||||
def test_yes_no_basic():
|
||||
"""
|
||||
Feature: YesNo Dataset
|
||||
Description: Read all files
|
||||
Expectation: Output the amount of file
|
||||
"""
|
||||
logger.info("Test YesNoDataset Op")
|
||||
|
||||
data = ds.YesNoDataset(DATA_DIR)
|
||||
num_iter = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
assert num_iter == 3
|
||||
|
||||
|
||||
def test_yes_no_num_samples():
|
||||
"""
|
||||
Feature: YesNo Dataset
|
||||
Description: Test num_samples
|
||||
Expectation: Get certain number of samples
|
||||
"""
|
||||
data = ds.YesNoDataset(DATA_DIR, num_samples=2)
|
||||
num_iter = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_yes_no_repeat():
|
||||
"""
|
||||
Feature: YesNo Dataset
|
||||
Description: Test repeat
|
||||
Expectation: Output the amount of file
|
||||
"""
|
||||
data = ds.YesNoDataset(DATA_DIR, num_samples=2)
|
||||
data = data.repeat(5)
|
||||
num_iter = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_yes_no_dataset_size():
|
||||
"""
|
||||
Feature: YesNo Dataset
|
||||
Description: Test dataset_size
|
||||
Expectation: Output the size of dataset
|
||||
"""
|
||||
data = ds.YesNoDataset(DATA_DIR, shuffle=False)
|
||||
assert data.get_dataset_size() == 3
|
||||
|
||||
|
||||
def test_yes_no_sequential_sampler():
|
||||
"""
|
||||
Feature: YesNo Dataset
|
||||
Description: Use SequentialSampler to sample data.
|
||||
Expectation: The number of samplers returned by dict_iterator is equal to the requested number of samples.
|
||||
"""
|
||||
logger.info("Test YesNoDataset Op with SequentialSampler")
|
||||
num_samples = 2
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.YesNoDataset(DATA_DIR, sampler=sampler)
|
||||
data2 = ds.YesNoDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
|
||||
sample_rate_list1, sample_rate_list2 = [], []
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
|
||||
data2.create_dict_iterator(num_epochs=1)):
|
||||
sample_rate_list1.append(item1["sample_rate"])
|
||||
sample_rate_list2.append(item2["sample_rate"])
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(sample_rate_list1, sample_rate_list2)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_yes_no_exception():
|
||||
"""
|
||||
Feature: Error tests
|
||||
Description: Throw error messages when certain errors occur
|
||||
Expectation: Output error message
|
||||
"""
|
||||
logger.info("Test error cases for YesNoDataset")
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.YesNoDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.YesNoDataset(DATA_DIR, sampler=ds.PKSampler(3),
|
||||
num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.YesNoDataset(DATA_DIR, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.YesNoDataset(DATA_DIR, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id="0")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
error_msg_8 = "The corresponding data files"
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.YesNoDataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func, input_columns=[
|
||||
"waveform"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.YesNoDataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func, input_columns=[
|
||||
"sample_rate"], num_parallel_workers=1)
|
||||
for _ in data.__iter__():
|
||||
pass
|
||||
|
||||
|
||||
def test_yes_no_pipeline():
|
||||
"""
|
||||
Feature: Pipeline test
|
||||
Description: Read a sample
|
||||
Expectation: The amount of each function are equal
|
||||
"""
|
||||
# Original waveform
|
||||
dataset = ds.YesNoDataset(DATA_DIR, num_samples=1)
|
||||
band_biquad_op = audio.BandBiquad(8000, 200.0)
|
||||
# Filtered waveform by bandbiquad
|
||||
dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2)
|
||||
num_iter = 0
|
||||
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
assert num_iter == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_yes_no_basic()
|
||||
test_yes_no_num_samples()
|
||||
test_yes_no_repeat()
|
||||
test_yes_no_dataset_size()
|
||||
test_yes_no_sequential_sampler()
|
||||
test_yes_no_exception()
|
||||
test_yes_no_pipeline()
|
Loading…
Reference in New Issue