[feat][assistant][I3T96H] add new dataset loading operator TedliumDataset

luqilin 2021-12-02 09:53:39 +08:00
parent 37cb0b7561
commit 34bffbf768
30 changed files with 1901 additions and 1 deletion

View File

@@ -112,6 +112,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
@@ -1448,6 +1449,34 @@ QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::ve
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
const std::vector<char> &usage, const std::vector<char> &extensions,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
CharToString(extensions), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
const std::vector<char> &usage, const std::vector<char> &extensions,
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
CharToString(extensions), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
const std::vector<char> &usage, const std::vector<char> &extensions,
const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
CharToString(extensions), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {

View File

@@ -44,6 +44,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
@@ -400,6 +401,18 @@ PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
"to create a TedliumNode")
.def(py::init([](std::string dataset_dir, std::string release, std::string usage,
std::string extensions, py::handle sampler) {
auto tedlium = std::make_shared<TedliumNode>(dataset_dir, release, usage, extensions,
toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(tedlium->ValidateParams());
return tedlium;
}));
}));
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode")

View File

@@ -29,6 +29,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
random_data_op.cc
sbu_op.cc
speech_commands_op.cc
tedlium_op.cc
text_file_op.cc
usps_op.cc
yes_no_op.cc

View File

@@ -0,0 +1,309 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/tedlium_op.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
TedliumOp::TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, int32_t num_parallel_workers,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t queue_size)
: MappableLeafOp(num_parallel_workers, queue_size, std::move(sampler)),
dataset_dir_(dataset_dir),
release_(release),
usage_(usage),
extensions_(extensions),
data_schema_(std::move(data_schema)),
audio_files_({}),
usage_list_({}) {}
void TedliumOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nNumber of rows: " << num_rows_ << "\nTedliumOp directory: " << dataset_dir_;
}
}
Status TedliumOp::PrepareData() {
auto real_path = FileUtils::GetRealPath(dataset_dir_.c_str());
if (!real_path.has_value()) {
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + dataset_dir_);
}
Path root_folder(real_path.value());
if (release_ == "release1" || release_ == "release2") {
if (usage_ == "train" || usage_ == "test" || usage_ == "dev") {
usage_list_.push_back(usage_);
} else if (usage_ == "all") {
usage_list_ = {"train", "test", "dev"};
}
for (int32_t i = 0; i < usage_list_.size(); ++i) {
Path stm_folder = root_folder / usage_list_[i] / "stm";
RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, usage_list_[i]));
}
} else if (release_ == "release3") {
if (usage_ == "all") {
Path stm_folder = root_folder / "data" / "stm";
RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, "data"));
}
}
std::sort(audio_files_.begin(), audio_files_.end());
num_rows_ = audio_files_.size();
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API TedliumDataset. Please check file path or dataset API.");
}
return Status::OK();
}
Status TedliumOp::ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage) {
Path dir(stm_folder);
std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&dir);
if (!dir.Exists() || dirItr == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open folder: " + dir.ToString());
}
MS_LOG(DEBUG) << "Tedlium " + release_ + " stm folder Path found: " << dir << ".";
while (dirItr->HasNext()) {
Path file = dirItr->Next();
if (file.Extension() == ".stm") {
std::ifstream handle(file.ToString());
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file.ToString());
}
std::string line;
int32_t numline = 0;
while (getline(handle, line)) {
std::string filename = line.substr(0, line.find(" "));
std::stringstream ss;
ss << numline;
audio_files_.push_back({ss.str(), filename, release_usage});
++numline;
}
handle.close();
}
}
return Status::OK();
}
Status TedliumOp::ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id,
std::string *start_time, std::string *end_time, std::string *identifier,
std::string *transcript) {
std::ifstream handle(file_stm_path.ToString().c_str());
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file_stm_path.ToString());
}
std::string line;
int32_t i = 0;
while (i <= row_line && getline(handle, line)) {
++i;
}
handle.close();
std::vector<std::string> temp;
i = 0;
const int32_t data_stm_number = 7;
// There are seven pieces of data in each row, which need to be read out and stored
// with a space as a separator.
// Talk_id, _, speaker_id, start_time, end_time, identifier, transcript.
// "_" is the data we don't need.
while (i < data_stm_number - 1) {
std::string s = line.substr(0, line.find(" "));
temp.push_back(s);
line.erase(0, line.find(" ") + 1);  // +1 so the separating space is removed as well.
++i;
}
temp.push_back(line);
if (temp.size() != data_stm_number) {
RETURN_STATUS_UNEXPECTED("Invalid data, stm data was broken.");
}
const int32_t talk_id_num = 0, speaker_id_num = 2, start_time_num = 3, end_time_num = 4, identifier_num = 5,
transcript_num = 6;
*talk_id = temp[talk_id_num];
// temp[1] is "_", which is the data we don't need.
*speaker_id = temp[speaker_id_num];
*start_time = temp[start_time_num];
*end_time = temp[end_time_num];
*identifier = temp[identifier_num];
*transcript = temp[transcript_num];
return Status::OK();
}
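For reference, each row of a TED-LIUM .stm file carries seven space-separated fields: talk_id, an unused second field, speaker_id, start_time, end_time, identifier, and the transcript (which keeps its inner spaces). Below is a minimal standalone sketch of the split that ReadStm() performs above; the sample row is hypothetical:
#include <iostream>
#include <string>
#include <vector>
int main() {
  // Hypothetical .stm row; real rows follow the same seven-field layout.
  std::string line = "AlGore_2009 1 AlGore 17.82 28.57 <o,f0,male> thank you so much chris";
  const int kStmFields = 7;
  std::vector<std::string> fields;
  // Split on the first six spaces; whatever remains is the transcript.
  for (int i = 0; i < kStmFields - 1; ++i) {
    fields.push_back(line.substr(0, line.find(' ')));
    line.erase(0, line.find(' ') + 1);
  }
  fields.push_back(line);
  std::cout << "talk_id=" << fields[0] << ", start=" << fields[3]
            << ", transcript=\"" << fields[6] << "\"\n";
  return 0;
}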
Status TedliumOp::ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate,
std::vector<float> *result) {
std::ifstream handle(file_sph_path.ToString().c_str(), std::ios::in | std::ios::binary);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_sph_path.ToString());
}
char head[1024];
handle.read(head, sizeof(head));
CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(),
"Invalid data, failed to read header part from sph file: " + file_sph_path.ToString() +
", please re-download the dataset and make sure the data is valid.");
std::vector<std::string> vec;
// The header buffer is not guaranteed to be NUL-terminated, so bound the scan by its size.
for (size_t i = 0, j = 0; i < sizeof(head) && head[i] != '\0'; ++i) {
if (head[i] == '\n' || head[i] == ' ') {
while (i + 1 < sizeof(head) && head[i + 1] == ' ') {
i++;
}
std::string strTemp(head + j, i - j);
vec.push_back(strTemp);
j = i + 1;
}
}
const int32_t dataToBytes = 2;
for (int32_t i = 0; i < vec.size(); ++i) {
if (vec[i] == "sample_rate") {
*sample_rate = atoi(vec[i + dataToBytes].c_str());
}
}
int32_t start = static_cast<int32_t>(start_time * (*sample_rate));
int32_t end = static_cast<int32_t>(end_time * (*sample_rate));
const int32_t size = (end - start);
std::vector<char> temp;
temp.reserve(size * dataToBytes);  // reserve, not resize: samples are appended with push_back below
handle.seekg(start, std::ios::beg);
int32_t j = 0;
char c;
while (j < size * dataToBytes) {
handle.read(&c, 1);
CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(),
"Invalid data, failed to read data part from sph file: " + file_sph_path.ToString() +
", please re-download the dataset and make sure the data is valid.");
temp.push_back(c);
++j;
}
const float kMaxVal = 32767.0;
for (int32_t i = 0; i < size; ++i) {
char bh = temp[2 * i];
char bl = temp[2 * i + 1];
// SPH audio files are big-endian, so convert each pair of bytes into an int16_t from
// the high 8 bits and the low 8 bits.
int16_t s = static_cast<int16_t>(((bh & 0x00FF) << 8) | (bl & 0x00FF));
// Data normalization: Convert the data from the interval [-32768,32767] to the interval [-1,1].
double t = s / kMaxVal;
(*result).push_back(t);
}
handle.close();
return Status::OK();
}
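The byte-level decode above is easy to verify in isolation. Here is a small self-contained sketch of the same big-endian int16 reconstruction and [-1, 1] normalization (kMaxVal as in ReadSph()):
#include <cstdint>
#include <iostream>
// Combine a high byte and a low byte (big-endian order, as in the SPH payload)
// into a signed 16-bit sample, then normalize it to [-1, 1].
float DecodeSample(char bh, char bl) {
  int16_t s = static_cast<int16_t>(((bh & 0x00FF) << 8) | (bl & 0x00FF));
  return s / 32767.0f;
}
int main() {
  std::cout << DecodeSample('\x7F', '\xFF') << "\n";  // 32767 -> 1
  std::cout << DecodeSample('\x80', '\x00') << "\n";  // -32768 -> about -1.00003
  return 0;
}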
Status TedliumOp::LoadTensorRow(row_id_type row_id, TensorRow *row) {
int32_t row_line = atoi(audio_files_[row_id][0].c_str());
std::string file_name = audio_files_[row_id][1];
std::string file_usage_or3_none_ = audio_files_[row_id][2];
Path dir_path(dataset_dir_);
Path file_stm_path = dir_path / file_usage_or3_none_ / "stm" / (file_name + ".stm");
Path file_sph_path = dir_path / file_usage_or3_none_ / "sph" / (file_name + extensions_);
std::string talk_id, speaker_id, start_time, end_time, identifier, transcript;
std::vector<float> result;
int32_t sample_rate;
RETURN_IF_NOT_OK(
ReadStm(file_stm_path, row_line, &talk_id, &speaker_id, &start_time, &end_time, &identifier, &transcript));
RETURN_IF_NOT_OK(ReadSph(file_sph_path, atof(start_time.c_str()), atof(end_time.c_str()), &sample_rate, &result));
std::shared_ptr<Tensor> sample_rate_tensor, talk_id_tensor, speaker_id_tensor, identifier_tensor, transcript_tensor;
RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_tensor));
RETURN_IF_NOT_OK(Tensor::CreateScalar(talk_id, &talk_id_tensor));
RETURN_IF_NOT_OK(Tensor::CreateScalar(speaker_id, &speaker_id_tensor));
RETURN_IF_NOT_OK(Tensor::CreateScalar(identifier, &identifier_tensor));
RETURN_IF_NOT_OK(Tensor::CreateScalar(transcript, &transcript_tensor));
std::shared_ptr<Tensor> audio_tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(result, &audio_tensor));
RETURN_IF_NOT_OK(audio_tensor->ExpandDim(0));
(*row) = TensorRow(row_id, {audio_tensor, sample_rate_tensor, transcript_tensor, talk_id_tensor, speaker_id_tensor,
identifier_tensor});
row->setPath({file_sph_path.ToString(), file_sph_path.ToString(), file_stm_path.ToString(), file_stm_path.ToString(),
file_stm_path.ToString(), file_stm_path.ToString()});
return Status::OK();
}
Status TedliumOp::CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, int64_t *count) {
// the logic of counting the number of samples is copied from PrepareData()
RETURN_UNEXPECTED_IF_NULL(count);
*count = 0;
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto new_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
// build a new unique schema object
auto new_schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
new_schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
TensorShape sample_rate_scalar = TensorShape::CreateScalar();
TensorShape trans_scalar = TensorShape::CreateScalar();
TensorShape talk_id_scalar = TensorShape::CreateScalar();
TensorShape speaker_id_scalar = TensorShape::CreateScalar();
TensorShape identi_scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(new_schema->AddColumn(
ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar)));
RETURN_IF_NOT_OK(new_schema->AddColumn(
ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &trans_scalar)));
RETURN_IF_NOT_OK(new_schema->AddColumn(
ColDescriptor("talk_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &talk_id_scalar)));
RETURN_IF_NOT_OK(new_schema->AddColumn(
ColDescriptor("speaker_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &speaker_id_scalar)));
RETURN_IF_NOT_OK(new_schema->AddColumn(
ColDescriptor("identifier", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &identi_scalar)));
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connect_size = cfg->op_connector_size();
std::shared_ptr<TedliumOp> op =
std::make_shared<TedliumOp>(dataset_dir, release, usage, extensions, num_workers, std::move(new_schema),
std::move(new_sampler), op_connect_size);
RETURN_IF_NOT_OK(op->PrepareData());
*count = static_cast<int64_t>(op->audio_files_.size());
return Status::OK();
}
Status TedliumOp::ComputeColMap() {
if (column_name_id_map_.empty()) {
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
column_name_id_map_[data_schema_->Column(i).Name()] = i;
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,126 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
class TedliumOp : public MappableLeafOp {
public:
/// \brief Constructor.
/// \param[in] dataset_dir Directory of tedlium dataset.
/// \param[in] release Release of tedlium dataset, can be 'release1', 'release2' or 'release3'.
/// \param[in] usage Usage of this dataset; if release is release3, it can only be 'all', else 'train', 'dev', 'test' or 'all'.
/// \param[in] extensions Extensions of the sph file, only '.sph' is valid.
/// \param[in] num_parallel_workers Number of workers in parallel.
/// \param[in] data_schema Schema of dataset.
/// \param[in] sampler Sampler tells TedliumOp what to read.
/// \param[in] queue_size Connector queue size.
TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, int32_t num_parallel_workers, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler, int32_t queue_size);
/// \brief Destructor.
~TedliumOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out Out stream.
/// \param[in] show_all Whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
std::string Name() const override { return "TedliumOp"; }
/// \brief Initialize TedliumOp related var, calls the function to walk all files.
/// \return Status The status code returned.
Status PrepareData() override;
/// \brief Function to count the number of samples in the TEDLIUM dataset.
/// \param[in] dataset_dir Directory of tedlium dataset.
/// \param[in] release Release of tedlium dataset.
/// \param[in] usage Usage of this dataset; if release is release3, it can only be 'all', else 'train', 'dev', 'test' or 'all'.
/// \param[in] extensions Extensions of the sph file, only '.sph' is valid.
/// \param[out] count Output arg that will hold the actual dataset size.
/// \return Status The status code returned.
static Status CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, int64_t *count);
private:
/// \brief Read stm file.
/// \param[in] file_stm_path The path of stm file.
/// \param[in] row_line Which line of the file we need to read.
/// \param[out] talk_id Talk identifier of the row_line in the file.
/// \param[out] speaker_id Speaker identifier of the row_line in the file.
/// \param[out] start_time Start time of the row_line in the file.
/// \param[out] end_time End time of the row_line in the file.
/// \param[out] identifier Identifier of the row_line in the file.
/// \param[out] transcript Transcript of the row_line in the file.
/// \return Status The status code returned.
Status ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id,
std::string *start_time, std::string *end_time, std::string *identifier, std::string *transcript);
/// \brief Read sph file.
/// \param[in] file_sph_path The path of sph file.
/// \param[in] start_time The start_time of row we need to use.
/// \param[in] end_time The end_time of row we need to use.
/// \param[out] sample_rate Sample rate of the row.
/// \param[out] result Waveform result vector of the row.
/// \return Status The status code returned.
Status ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate,
std::vector<float> *result);
/// \brief Read stm files according to the current release's usage.
/// \param[in] stm_folder The folder of stm files.
/// \param[in] release_usage For release1 or release2 this is usage_; for release3 it is "data".
/// \return Status The status code returned.
Status ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage);
/// \brief Load a tensor row according to the given row id.
/// \param[in] row_id Id of the row to load.
/// \param[out] row Audio & label read into this tensor row.
/// \return Status The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
/// \brief Private function for computing the assignment of the column name map.
/// \return Status The status code returned.
Status ComputeColMap() override;
const std::string dataset_dir_;
const std::string release_;
const std::string usage_;
const std::string extensions_;
std::unique_ptr<DataSchema> data_schema_;
std::vector<std::vector<std::string>> audio_files_;
std::vector<std::string> usage_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
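A hedged sketch of how the static counter declared above could be called from engine code; the dataset path is a placeholder and the usual minddata includes are assumed:
// Count TEDLIUM rows without building a full pipeline; the path is hypothetical.
int64_t count = 0;
Status rc = TedliumOp::CountTotalRows("/path/to/TEDLIUM_release1", "release1", "all", ".sph", &count);
if (rc.IsOk()) {
  MS_LOG(INFO) << "TEDLIUM total rows: " << count;
}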

View File

@@ -103,6 +103,7 @@ constexpr char kQMnistNode[] = "QMnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kSBUNode[] = "SBUDataset";
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
constexpr char kTedliumNode[] = "TedliumDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUSPSNode[] = "USPSDataset";

View File

@@ -29,6 +29,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
random_node.cc
sbu_node.cc
speech_commands_node.cc
tedlium_node.cc
text_file_node.cc
tf_record_node.cc
usps_node.cc

View File

@@ -0,0 +1,152 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include <fstream>
#include <utility>
#include "minddata/dataset/engine/datasetops/source/tedlium_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for TedliumNode.
TedliumNode::TedliumNode(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache)
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
release_(release),
usage_(usage),
extensions_(extensions),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> TedliumNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<TedliumNode>(dataset_dir_, release_, usage_, extensions_, sampler, cache_);
return node;
}
void TedliumNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
}
Status ValidateExtensionsParam(const std::string &dataset_name, const std::string &extensions) {
if (extensions != ".sph") {
std::string err_msg = dataset_name + ": extension " + extensions + " is not supported.";
MS_LOG(ERROR) << err_msg;
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status TedliumNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("TedliumNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", release_, {"release1", "release2", "release3"}));
RETURN_IF_NOT_OK(ValidateExtensionsParam("TedliumNode", extensions_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("TedliumNode", sampler_));
if (release_ == "release1" || release_ == "release2") {
RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", usage_, {"dev", "train", "test", "all"}));
} else if (release_ == "release3") {
RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", usage_, {"all"}));
}
return Status::OK();
}
Status TedliumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
TensorShape sample_rate_scalar = TensorShape::CreateScalar();
TensorShape trans_scalar = TensorShape::CreateScalar();
TensorShape talk_id_scalar = TensorShape::CreateScalar();
TensorShape speaker_id_scalar = TensorShape::CreateScalar();
TensorShape identi_scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar)));
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &trans_scalar)));
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor("talk_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &talk_id_scalar)));
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor("speaker_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &speaker_id_scalar)));
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor("identifier", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &identi_scalar)));
// Argument that is not exposed to user in the API.
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto tedlium_op = std::make_shared<TedliumOp>(dataset_dir_, release_, usage_, extensions_, num_workers_,
std::move(schema), std::move(sampler_rt), connector_que_size_);
tedlium_op->SetTotalRepeats(GetTotalRepeats());
tedlium_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(tedlium_op);
return Status::OK();
}
Status TedliumNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
Status TedliumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size = 0;
RETURN_IF_NOT_OK(TedliumOp::CountTotalRows(dataset_dir_, release_, usage_, extensions_, &num_rows));
// give sampler the total number of files and check if num_samples is smaller.
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
// We cache the dataset size to avoid duplicated runs.
dataset_size_ = *dataset_size;
return Status::OK();
}
Status TedliumNode::to_json(nlohmann::json *out_json) {
nlohmann::json args, sampler_args;
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
args["sampler"] = sampler_args;
args["num_parallel_workers"] = num_workers_;
args["release"] = release_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["extensions"] = extensions_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,110 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class TedliumNode : public MappableSourceNode {
public:
/// \brief Constructor.
TedliumNode(const std::string &dataset_dir, const std::string &release, const std::string &usage,
const std::string &extensions, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~TedliumNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kTedliumNode; }
/// \brief Print the description.
/// \param out - The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id Shard id of the node.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
/// \brief Release getter.
/// \return Release of the current node.
const std::string &Release() const { return release_; }
/// \brief DatasetDir getter.
/// \return DatasetDir of the current node.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Usage getter.
/// \return Usage of the current node.
const std::string &Usage() const { return usage_; }
/// \brief Extensions getter.
/// \return Extensions of the current node.
const std::string &Extensions() const { return extensions_; }
private:
std::string dataset_dir_;
std::string release_;
std::string usage_;
std::string extensions_;
std::shared_ptr<SamplerObj> sampler_;
}; // class TedliumNode
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_

View File

@@ -3614,6 +3614,109 @@ inline std::shared_ptr<SpeechCommandsDataset> SpeechCommands(const std::string &
return std::make_shared<SpeechCommandsDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class TedliumDataset
/// \brief A source dataset for reading and parsing the TEDLIUM dataset.
class MS_API TedliumDataset : public Dataset {
public:
/// \brief Constructor of TedliumDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all".
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
/// \param[in] cache Tensor cache to use.
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
const std::vector<char> &extensions, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of TedliumDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all".
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
const std::vector<char> &extensions, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of TedliumDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all".
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
const std::vector<char> &extensions, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of TedliumDataset.
~TedliumDataset() = default;
};
/// \brief Function to create a TedliumDataset.
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
/// "identifier"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all" (default = "all").
/// \param[in] extensions The extensions of audio file. Only support ".sph" now (default = ".sph").
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the TedliumDataset.
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(
const std::string &dataset_dir, const std::string &release, const std::string &usage = "all",
const std::string &extensions = ".sph", const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
StringToChar(extensions), sampler, cache);
}
/// \brief Function to create a TedliumDataset.
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
/// "identifier"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all".
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the TedliumDataset.
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(const std::string &dataset_dir, const std::string &release,
const std::string &usage, const std::string &extensions,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
StringToChar(extensions), sampler, cache);
}
/// \brief Function to create a TedliumDataset.
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
/// "identifier"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
/// \param[in] usage Part of dataset of TEDLIUM. For release3 it can only be "all"; for release1 and release2
/// it can be "train", "test", "dev" or "all".
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the TedliumDataset.
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(const std::string &dataset_dir, const std::string &release,
const std::string &usage, const std::string &extensions,
Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
StringToChar(extensions), sampler, cache);
}
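Taken together, the overloads above give the user-facing C++ entry point. A minimal hedged pipeline sketch follows; the directory is a placeholder, and the iterator calls mirror the unit tests added later in this commit:
// Read one row from a hypothetical TEDLIUM release1 layout.
std::shared_ptr<Dataset> ds = Tedlium("/path/to/TEDLIUM_release1", "release1", "train");
std::shared_ptr<Iterator> iter = ds->CreateIterator();
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);  // columns: waveform, sample_rate, transcript, talk_id, speaker_id, identifier
iter->Stop();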
/// \class TextFileDataset
/// \brief A source dataset that reads and parses datasets stored on disk in text format.
class MS_API TextFileDataset : public Dataset {

View File

@@ -56,6 +56,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
friend class RandomDataDataset;
friend class SBUDataset;
friend class SpeechCommandsDataset;
friend class TedliumDataset;
friend class TextFileDataset;
friend class TFRecordDataset;
friend class USPSDataset;

View File

@@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size, get_auto_offload
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -8625,3 +8625,218 @@ class YesNoDataset(MappableDataset):
def parse(self, children=None):
return cde.YesNoNode(self.dataset_dir, self.sampler)
class TedliumDataset(MappableDataset):
"""
A source dataset for reading and parsing the TEDLIUM dataset.
The columns of generated dataset depend on the source SPH files and the corresponding STM files.
The generated dataset has six columns :py:obj:`[waveform, sample_rate, transcript, talk_id, speaker_id,
identifier]`.
The tensor of column :py:obj:`waveform` is of the float32 type.
The tensor of column :py:obj:`sample_rate` is a scalar of the int32 type.
The tensor of column :py:obj:`transcript` is a scalar of the string type.
The tensor of column :py:obj:`talk_id` is a scalar of the string type.
The tensor of column :py:obj:`speaker_id` is a scalar of the string type.
The tensor of column :py:obj:`identifier` is a scalar of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
release (str): Release of the dataset, can be "release1", "release2", "release3".
usage (str, optional): Usage of this dataset.
For release1 or release2, can be `train`, `test`, `dev` or `all`.
`train` will read from train samples,
`test` will read from test samples,
`dev` will read from dev samples,
`all` will read from all samples.
For release3, can only be `all`, which reads from the data samples (default=None, which is `all`).
extensions (str, optional): Extensions of the SPH files, only '.sph' is valid
(default=None, which is ".sph").
num_samples (int, optional): The number of audio samples to be included in the dataset
(default=None, all samples).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
order behavior shown in the table).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, `num_samples` reflects
the maximum sample number of per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain stm files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
:widths: 25 25 50
:header-rows: 1
* - Parameter `sampler`
- Parameter `shuffle`
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> tedlium_dataset_dir = "/path/to/tedlium_dataset_directory"
>>> tedlium_dataset_release = ["release1", "release2", "release3"]
>>>
>>> # 1) Get all train samples from TEDLIUM_release1 dataset in sequence.
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[0],
... shuffle=False)
>>>
>>> # 2) Randomly select 10 samples from TEDLIUM_release2 dataset.
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[1],
... num_samples=10, shuffle=True)
>>>
>>> # 3) Get samples from TEDLIUM_release-3 dataset for shard 0 in a 2-way distributed training.
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[2],
... num_shards=2, shard_id=0)
>>>
>>> # In TEDLIUM dataset, each dictionary has keys : waveform, sample_rate, transcript, talk_id,
>>> # speaker_id and identifier.
About TEDLIUM_release1 dataset:
The TED-LIUM corpus consists of English-language TED talks, with transcriptions, sampled at 16kHz.
It contains about 118 hours of speech.
About TEDLIUM_release2 dataset:
This is the TED-LIUM corpus release 2, licensed under Creative Commons BY-NC-ND 3.0. All talks and text are
property of TED Conferences LLC. The TED-LIUM corpus was made from audio talks and their transcriptions available
on the TED website. We have prepared and filtered these data in order to train acoustic models to participate to
the International Workshop on Spoken Language Translation 2011 (the LIUM English/French SLT system reached the
first rank in the SLT task).
About TEDLIUM_release-3 dataset:
This is the TED-LIUM corpus release 3, licensed under Creative Commons BY-NC-ND 3.0. All talks and text are
property of TED Conferences LLC. This new TED-LIUM release was made through a collaboration between the Ubiqus
company and the LIUM (University of Le Mans, France).
You can unzip the dataset files into the following directory structure and read them by MindSpore's API.
The structure of TEDLIUM release2 is the same as TEDLIUM release1, only the data is different.
.. code-block::
.
TEDLIUM_release1
dev
sph
AlGore_2009.sph
BarrySchwartz_2005G.sph
stm
AlGore_2009.stm
BarrySchwartz_2005G.stm
test
sph
AimeeMullins_2009P.sph
BillGates_2010.sph
stm
AimeeMullins_2009P.stm
BillGates_2010.stm
train
sph
AaronHuey_2010X.sph
AdamGrosser_2007.sph
stm
AaronHuey_2010X.stm
AdamGrosser_2007.stm
readme
TEDLIUM.150k.dic
.. code-block::
.
TEDLIUM_release-3
data
ctl
sph
911Mothers_2010W.sph
AalaElKhani.sph
stm
911Mothers_2010W.stm
AalaElKhani.stm
doc
legacy
LM
speaker-adaptation
readme
TEDLIUM.150k.dic
Citation:
.. code-block::
@article{
title={TED-LIUM: an automatic speech recognition dedicated corpus},
author={A. Rousseau, P. Deléglise, Y. Estève},
journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)},
year={May 2012},
biburl={https://www.openslr.org/7/}
}
@article{
title={Enhancing the TED-LIUM Corpus with Selected Data for Language Modeling and More TED Talks},
author={A. Rousseau, P. Deléglise, and Y. Estève},
journal={Proceedings of the Ninth International Conference on Language Resources and Evaluation (LREC'14)},
year={May 2014},
biburl={https://www.openslr.org/19/}
}
@article{
title={TED-LIUM 3: twice as much data and corpus repartition for experiments on speaker adaptation},
author={François Hernandez, Vincent Nguyen, Sahar Ghannay, Natalia Tomashenko, and Yannick Estève},
journal={the 20th International Conference on Speech and Computer (SPECOM 2018)},
year={September 2018},
biburl={https://www.openslr.org/51/}
}
"""
@check_tedlium_dataset
def __init__(self, dataset_dir, release, usage=None, extensions=None, num_samples=None,
num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None,
shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.extensions = replace_none(extensions, ".sph")
self.release = release
self.usage = replace_none(usage, "all")
def parse(self, children=None):
return cde.TedliumNode(self.dataset_dir, self.release, self.usage, self.extensions, self.sampler)

View File

@ -1863,3 +1863,39 @@ def check_yes_no_dataset(method):
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
def check_tedlium_dataset(method):
"""Wrapper method to check the parameters of TedliumDataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle']
release = param_dict.get('release')
check_valid_str(release, ["release1", "release2", "release3"], "release")
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
if release in ["release1", "release2"]:
check_valid_str(usage, ["train", "test", "dev", "all"], "usage")
else:
check_valid_str(usage, ["all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

View File

@@ -40,6 +40,7 @@ SET(DE_UT_SRCS
c_api_dataset_save.cc
c_api_dataset_sbu_test.cc
c_api_dataset_speech_commands_test.cc
c_api_dataset_tedlium_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_usps_test.cc

View File

@@ -0,0 +1,383 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: TedliumDataset.
/// Description: Read some samples from all files according to different releases.
/// Expectation: 4 * 2 samples.
TEST_F(MindDataTestPipeline, TestTedliumDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDataset.";
// Create a Tedlium Dataset.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds1 =
Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
std::shared_ptr<Dataset> ds3 =
Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
EXPECT_NE(iter1, nullptr);
EXPECT_NE(iter3, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row1;
std::unordered_map<std::string, mindspore::MSTensor> row3;
ASSERT_OK(iter1->GetNextRow(&row1));
EXPECT_NE(row1.find("waveform"), row1.end());
EXPECT_NE(row1.find("sample_rate"), row1.end());
EXPECT_NE(row1.find("transcript"), row1.end());
EXPECT_NE(row1.find("talk_id"), row1.end());
EXPECT_NE(row1.find("speaker_id"), row1.end());
EXPECT_NE(row1.find("identifier"), row1.end());
ASSERT_OK(iter3->GetNextRow(&row3));
EXPECT_NE(row3.find("waveform"), row3.end());
EXPECT_NE(row3.find("sample_rate"), row3.end());
EXPECT_NE(row3.find("transcript"), row3.end());
EXPECT_NE(row3.find("talk_id"), row3.end());
EXPECT_NE(row3.find("speaker_id"), row3.end());
EXPECT_NE(row3.find("identifier"), row3.end());
uint64_t i = 0;
while (row1.size() != 0) {
i++;
auto audio = row1["waveform"];
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
ASSERT_OK(iter1->GetNextRow(&row1));
}
while (row3.size() != 0) {
i++;
auto audio = row3["waveform"];
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
ASSERT_OK(iter3->GetNextRow(&row3));
}
EXPECT_EQ(i, 4 * 2);
// Manually terminate the pipeline.
iter1->Stop();
iter3->Stop();
}
/// Feature: TedliumDataset.
/// Description: Read some samples through a pipeline from all files according to different releases.
/// Expectation: 8 * 2 samples.
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithPipeline.";
// Create four Tedlium Datasets.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds11 =
Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
std::shared_ptr<Dataset> ds31 =
Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
std::shared_ptr<Dataset> ds12 =
Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
std::shared_ptr<Dataset> ds32 =
Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
EXPECT_NE(ds11, nullptr);
EXPECT_NE(ds12, nullptr);
EXPECT_NE(ds31, nullptr);
EXPECT_NE(ds32, nullptr);
  // Create Repeat operations on the datasets.
int32_t repeat_num = 1;
ds11 = ds11->Repeat(repeat_num);
ds31 = ds31->Repeat(repeat_num);
EXPECT_NE(ds11, nullptr);
EXPECT_NE(ds31, nullptr);
repeat_num = 1;
ds12 = ds12->Repeat(repeat_num);
ds32 = ds32->Repeat(repeat_num);
EXPECT_NE(ds12, nullptr);
EXPECT_NE(ds32, nullptr);
  // Create Project operations on the datasets.
std::vector<std::string> column_project = {"waveform", "sample_rate", "transcript",
"talk_id", "speaker_id", "identifier"};
ds11 = ds11->Project(column_project);
EXPECT_NE(ds11, nullptr);
ds12 = ds12->Project(column_project);
EXPECT_NE(ds12, nullptr);
ds31 = ds31->Project(column_project);
EXPECT_NE(ds31, nullptr);
ds32 = ds32->Project(column_project);
EXPECT_NE(ds32, nullptr);
  // Create Concat operations on the datasets.
ds11 = ds11->Concat({ds12});
ds31 = ds31->Concat({ds32});
EXPECT_NE(ds11, nullptr);
EXPECT_NE(ds31, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds11->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds31->CreateIterator();
EXPECT_NE(iter1, nullptr);
EXPECT_NE(iter3, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row1;
std::unordered_map<std::string, mindspore::MSTensor> row3;
ASSERT_OK(iter1->GetNextRow(&row1));
ASSERT_OK(iter3->GetNextRow(&row3));
EXPECT_NE(row1.find("waveform"), row1.end());
EXPECT_NE(row1.find("sample_rate"), row1.end());
EXPECT_NE(row1.find("transcript"), row1.end());
EXPECT_NE(row1.find("talk_id"), row1.end());
EXPECT_NE(row1.find("speaker_id"), row1.end());
EXPECT_NE(row1.find("identifier"), row1.end());
EXPECT_NE(row3.find("waveform"), row3.end());
EXPECT_NE(row3.find("sample_rate"), row3.end());
EXPECT_NE(row3.find("transcript"), row3.end());
EXPECT_NE(row3.find("talk_id"), row3.end());
EXPECT_NE(row3.find("speaker_id"), row3.end());
EXPECT_NE(row3.find("identifier"), row3.end());
uint64_t i = 0;
while (row1.size() != 0) {
i++;
auto audio = row1["waveform"];
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
ASSERT_OK(iter1->GetNextRow(&row1));
}
while (row3.size() != 0) {
i++;
auto audio = row3["waveform"];
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
ASSERT_OK(iter3->GetNextRow(&row3));
}
EXPECT_EQ(i, 8 * 2);
// Manually terminate the pipeline.
iter1->Stop();
iter3->Stop();
}
/// Feature: TedliumDataset.
/// Description: Get the number of samples in all files for different releases.
/// Expectation: TEDLIUM_release12 : 1 + 2 + 3
/// TEDLIUM_release3 : 3 + 4
TEST_F(MindDataTestPipeline, TestTedliumGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumGetDatasetSize.";
  // Create two Tedlium Datasets.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "all", ".sph");
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "all", ".sph");
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds3, nullptr);
EXPECT_EQ(ds1->GetDatasetSize(), 1 + 2 + 3);
EXPECT_EQ(ds3->GetDatasetSize(), 3 + 4);
}
/// Feature: TedliumDataset.
/// Description: Includes tests for shape, type, size.
/// Expectation: correct shape, type, size.
TEST_F(MindDataTestPipeline, TestTedliumGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumGetters.";
// Create a Tedlium Dataset.
std::string folder_path = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::shared_ptr<Dataset> ds = Tedlium(folder_path, "release1", "all", ".sph");
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
std::vector<std::string> column_names = {"waveform", "sample_rate", "transcript",
"talk_id", "speaker_id", "identifier"};
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 6);
EXPECT_EQ(types[0].ToString(), "float32");
EXPECT_EQ(types[1].ToString(), "int32");
EXPECT_EQ(types[2].ToString(), "string");
EXPECT_EQ(types[3].ToString(), "string");
EXPECT_EQ(types[4].ToString(), "string");
EXPECT_EQ(types[5].ToString(), "string");
EXPECT_EQ(shapes.size(), 6);
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(shapes[2].ToString(), "<>");
EXPECT_EQ(shapes[3].ToString(), "<>");
EXPECT_EQ(shapes[4].ToString(), "<>");
EXPECT_EQ(shapes[5].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetColumnNames(), column_names);
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
}
/// Feature: TedliumDataset.
/// Description: test with invalid release.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidReleaseFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidReleaseFail.";
  // Create three Tedlium Datasets.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "", "all", ".sph");
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "RELEASE2", "all", ".sph");
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "2", "all", ".sph");
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid Tedlium input, "", "RELEASE2" and "2" are not a valid release.
EXPECT_EQ(iter1, nullptr);
EXPECT_EQ(iter2, nullptr);
EXPECT_EQ(iter3, nullptr);
}
/// Feature: TedliumDataset.
/// Description: test with invalid path.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestTedliumDatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetFail.";
  // Create three Tedlium Datasets.
std::shared_ptr<Dataset> ds1 = Tedlium("", "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
std::shared_ptr<Dataset> ds2 =
Tedlium("validation", "release2", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
std::shared_ptr<Dataset> ds3 = Tedlium("2", "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid Tedlium input, "", "validation" and "2" are not a valid path.
EXPECT_EQ(iter1, nullptr);
EXPECT_EQ(iter2, nullptr);
EXPECT_EQ(iter3, nullptr);
}
/// Feature: TedliumDataset.
/// Description: test with invalid usage.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidUsageFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidUsageFail.";
  // Create three Tedlium Datasets.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "", ".sph");
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "release2", "DEV", ".sph");
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "2", ".sph");
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid Tedlium input, "", "DEV" and "2" are not a valid usage.
EXPECT_EQ(iter1, nullptr);
EXPECT_EQ(iter2, nullptr);
EXPECT_EQ(iter3, nullptr);
}
/// Feature: TedliumDataset.
/// Description: test with invalid extensions.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidExtensionsFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidExtensionsFail.";
  // Create three Tedlium Datasets.
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "all", "sph");
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "release2", "all", ".SPH");
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "all", ".stm");
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid Tedlium input, "sph", ".SPH", ".stm" are not a valid extensions.
EXPECT_EQ(iter1, nullptr);
EXPECT_EQ(iter2, nullptr);
EXPECT_EQ(iter3, nullptr);
}
/// Feature: TedliumDataset.
/// Description: test with null sampler.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithNullSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithNullSamplerFail.";
// Create a Tedlium Dataset.
std::string folder_path = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
std::shared_ptr<Dataset> ds = Tedlium(folder_path, "release1", "all", ".sph", nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid Tedlium input, sampler cannot be nullptr.
EXPECT_EQ(iter, nullptr);
}
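For reference, the schema these C++ getters assert can also be checked from the Python API. A minimal sketch, assuming the fixture data added in this commit (release1: 1 dev + 2 test + 3 train records, hence six rows for usage "all"); the path mirrors the fixture directory used by the Python tests further down:

import mindspore.dataset as ds

# Inspect the TedliumDataset columns and row count from Python.
data = ds.TedliumDataset("../data/dataset/testTedliumData/TEDLIUM_release1",
                         "release1", usage="all", extensions=".sph")
print(data.get_col_names())     # ['waveform', 'sample_rate', 'transcript',
                                #  'talk_id', 'speaker_id', 'identifier']
print(data.get_dataset_size())  # 6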

View File

@ -0,0 +1 @@
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.
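Each .stm fixture line follows the TED-LIUM transcript layout that the operator maps onto its output columns. A minimal parsing sketch in Python; the field order is inferred from the columns asserted in the tests above, not taken from the operator's implementation:

# Assumed field order: talk_id channel speaker_id start_time end_time <identifier> transcript
def parse_stm_line(line):
    talk_id, channel, speaker_id, start, end, identifier, transcript = line.split(" ", 6)
    return {
        "talk_id": talk_id,
        "channel": channel,
        "speaker_id": speaker_id,
        "start": float(start),     # segment offsets into the matching .sph audio, in seconds
        "end": float(end),
        "identifier": identifier,  # e.g. "<o,f0,female>"
        "transcript": transcript,
    }

row = parse_stm_line("test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.")
assert row["transcript"] == "this is record 1 of test1."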

View File

@ -0,0 +1,2 @@
test2 1 test2 0.00 0.02 <o,f0,female> this is record 1 of test2.
test2 1 test2 0.02 0.09 <o,f0,female> this is record 2 of test2.

View File

@ -0,0 +1 @@
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.

View File

@ -0,0 +1,2 @@
test2 1 test2 0.00 0.02 <o,f0,female> this is record 1 of test2.
test2 1 test2 0.02 0.09 <o,f0,female> this is record 2 of test2.

View File

@ -0,0 +1,3 @@
test3 1 test3 0.00 0.01 <o,f0,female> this is record 1 of test3.
test3 1 test3 0.02 0.07 <o,f0,female> this is record 2 of test3.
test3 1 test3 0.07 0.09 <o,f0,female> this is record 3 of test3.

View File

@ -0,0 +1,4 @@
test4 1 test4 0.00 0.01 <o,f0,female> this is record 1 of test4.
test4 1 test4 0.02 0.03 <o,f0,female> this is record 2 of test4.
test4 1 test4 0.05 0.07 <o,f0,female> this is record 3 of test4.
test4 1 test4 0.07 0.09 <o,f0,female> this is record 4 of test4.

View File

@ -0,0 +1 @@
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.

View File

@ -0,0 +1,405 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
DATA_DIR_TEDLIUM_RELEASE12 = "../data/dataset/testTedliumData/TEDLIUM_release1"
DATA_DIR_TEDLIUM_RELEASE3 = "../data/dataset/testTedliumData/TEDLIUM_release3"
RELEASE1 = "release1"
RELEASE2 = "release2"
RELEASE3 = "release3"
NO_SPH_DIR_TEDLIUM12 = "../data/dataset/testTedliumData/else"
def test_tedlium_basic():
"""
Feature: TedliumDataset
    Description: use different data to test the functions of different releases
Expectation: num_samples
set 1 2 4
get 1 2 4
num_parallel_workers
set 1 2 4(num_samples=4)
get 4 4 4
num repeat
set 3(num_samples=5)
get 15
"""
# case1 test num_samples
data11 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1)
data12 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=2)
data13 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4)
num_iter11 = 0
num_iter12 = 0
num_iter13 = 0
for _ in data11.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter11 += 1
for _ in data12.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter12 += 1
for _ in data13.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter13 += 1
assert num_iter11 == 1
assert num_iter12 == 2
assert num_iter13 == 4
# case2 test num_parallel_workers
data21 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=4, num_parallel_workers=1)
data22 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=4, num_parallel_workers=2)
data23 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4, num_parallel_workers=4)
num_iter21 = 0
num_iter22 = 0
num_iter23 = 0
for _ in data21.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter21 += 1
for _ in data22.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter22 += 1
for _ in data23.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter23 += 1
assert num_iter21 == 4
assert num_iter22 == 4
assert num_iter23 == 4
# case3 test repeat
data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=5)
data3 = data3.repeat(3)
num_iter3 = 0
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter3 += 1
assert num_iter3 == 15
def test_tedlium_content_check():
"""
Feature: TedliumDataset
Description: Check content of the first sample
Expectation: correct content
"""
data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1, shuffle=False)
data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=1, shuffle=False)
num_iter1 = 0
num_iter3 = 0
for d in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
waveform = d["waveform"]
sample_rate = d["sample_rate"]
transcript = d["transcript"]
talk_id = d["talk_id"]
speaker_id = d["speaker_id"]
identifier = d["identifier"]
assert waveform.dtype == np.float32
assert waveform.shape == (1, 480)
assert sample_rate == 16000
assert sample_rate.dtype == np.int32
assert talk_id.item().decode("utf8") == "test1"
assert speaker_id.item().decode("utf8") == "test1"
assert transcript.item().decode("utf8") == "this is record 1 of test1."
assert identifier.item().decode("utf8") == "<o,f0,female>"
num_iter1 += 1
for d in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
waveform = d["waveform"]
sample_rate = d["sample_rate"]
transcript = d["transcript"]
talk_id = d["talk_id"]
speaker_id = d["speaker_id"]
identifier = d["identifier"]
assert waveform.dtype == np.float32
assert waveform.shape == (1, 160)
assert sample_rate == 16000
assert sample_rate.dtype == np.int32
assert talk_id.item().decode("utf8") == "test3"
assert speaker_id.item().decode("utf8") == "test3"
assert transcript.item().decode("utf8") == "this is record 1 of test3."
assert identifier.item().decode("utf8") == "<o,f0,female>"
num_iter3 += 1
assert num_iter1 == 1
assert num_iter3 == 1
def test_tedlium_exceptions():
"""
Feature: TedliumDataset
    Description: raise errors for invalid parameter combinations
    Expectation: throw correct error
"""
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=256)
error_msg_7 = "Invalid data, no valid data matching the dataset API TedliumDataset"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.TedliumDataset(NO_SPH_DIR_TEDLIUM12, RELEASE1, "train")
for _ in ds1.__iter__():
pass
def test_tedlium_exception_file_path():
"""
Feature: TedliumDataset
    Description: test that an exception raised inside map propagates for each output column
Expectation: throw error
"""
def exception_func(item):
raise Exception("Error occur!")
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1)
data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1)
data = data.map(operations=exception_func, input_columns=["sample_rate"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2)
data = data.map(operations=exception_func, input_columns=["transcript"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2)
data = data.map(operations=exception_func, input_columns=["talk_id"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3)
data = data.map(operations=exception_func, input_columns=["speaker_id"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3)
data = data.map(operations=exception_func, input_columns=["identifier"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
def test_tedlium_extensions():
"""
Feature: TedliumDataset
Description: test extensions of tedlium
Expectation: extensions
set invalid data
get throw error
"""
try:
data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, "train", "invalid")
for _ in data.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
assert "is not supported." in str(e)
def test_tedlium_release():
"""
Feature: TedliumDataset
Description: test release of tedlium
Expectation: release
set invalid data
get throw error
"""
def test_config(release):
try:
ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, release)
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return None
# test the release
assert "release is not within the valid set of ['release1', 'release2', 'release3']" in test_config("invalid")
assert "Argument release with value None is not of type [<class 'str'>]" in test_config(None)
assert "Argument release with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
def test_tedlium_sequential_sampler():
"""
Feature: TedliumDataset
Description: test tedlium sequential sampler
Expectation: correct data
"""
num_samples = 3
sampler = ds.SequentialSampler(num_samples=num_samples)
data21 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, sampler=sampler)
data22 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shuffle=False, num_samples=num_samples)
num_iter2 = 0
for item1, item2 in zip(data21.create_dict_iterator(num_epochs=1, output_numpy=True),
data22.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_equal(item1["waveform"], item2["waveform"])
num_iter2 += 1
assert num_iter2 == num_samples
def test_tedlium_sampler_get_dataset_size():
"""
Feature: TedliumDataset
Description: test TedliumDataset with SequentialSampler and get_dataset_size
Expectation: num_samples
set 5
get 5
"""
sampler = ds.SequentialSampler(start_index=0, num_samples=5)
data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, sampler=sampler)
num_iter3 = 0
ds_sz3 = data3.get_dataset_size()
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter3 += 1
assert ds_sz3 == num_iter3 == 5
def test_tedlium_usage():
"""
Feature: TedliumDataset
Description: test usage of tedlium
Expectation: usage
set valid data invalid data
get correct data throw error
"""
def test_config_tedlium12(usage):
try:
data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, usage=usage)
data2 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, usage=usage)
num_rows = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
# test the usage of TEDLIUM
assert test_config_tedlium12("dev") == 1 + 1
assert test_config_tedlium12("test") == 2 + 2
assert test_config_tedlium12("train") == 3 + 3
assert test_config_tedlium12("all") == 1 + 1 + 2 + 2 + 3 + 3
assert "usage is not within the valid set of ['train', 'test', 'dev', 'all']" in test_config_tedlium12("invalid")
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config_tedlium12(["list"])
def test_tedlium_with_chained_sampler_get_dataset_size():
"""
Feature: TedliumDataset
Description: test TedliumDataset with RandomSampler chained with a SequentialSampler and get_dataset_size
Expectation: num_samples
set 2
get 2
"""
sampler = ds.SequentialSampler(start_index=0, num_samples=2)
child_sampler = ds.RandomSampler()
sampler.add_child(child_sampler)
data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=sampler)
num_iter1 = 0
ds_sz1 = data1.get_dataset_size()
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter1 += 1
assert ds_sz1 == num_iter1 == 2
def test_tedlium_pipeline():
"""
Feature: TedliumDataset
Description: Read a sample
    Expectation: the number of output rows equals num_samples
"""
# Original waveform
dataset = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1)
band_biquad_op = audio.BandBiquad(8000, 200.0)
# Filtered waveform by bandbiquad
dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2)
i = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
i += 1
assert i == 1
if __name__ == '__main__':
test_tedlium_basic()
test_tedlium_content_check()
test_tedlium_exceptions()
test_tedlium_exception_file_path()
test_tedlium_extensions()
test_tedlium_release()
test_tedlium_sequential_sampler()
test_tedlium_sampler_get_dataset_size()
test_tedlium_usage()
test_tedlium_with_chained_sampler_get_dataset_size()
test_tedlium_pipeline()
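The exception tests above only exercise invalid num_shards/shard_id combinations. For completeness, a minimal sketch of the valid sharded form under the same fixture data (constants as defined at the top of this file); that get_dataset_size reports the per-shard count (3 of release1's 6 rows) is stated here as an assumption:

# Shard 0 of 2: shard_id must lie in [0, num_shards); the two shards read
# disjoint subsets of the six release1 rows.
sharded = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, usage="all",
                            num_shards=2, shard_id=0)
print(sharded.get_dataset_size())  # expected: 3 (per-shard size)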